import atexit
import os
import pickle
import threading
import time
from collections import deque
from concurrent import futures
from itertools import zip_longest

import grpc
import numpy as np
import psycopg2
import requests
import schedule
import yaml
from fake_useragent import UserAgent

import id_service_pb2
import id_service_pb2_grpc

DIR = os.path.dirname(os.path.realpath(__file__))
BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")

with open(os.path.join(DIR, "batcher_config.yaml")) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


class IDBatcher:
    """
    Periodically gets the newest comment/post ID from reddit, compares it to
    the last fetched one, and adds all intermediate values into a queue.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.ids_queue = deque()
        self.ua = UserAgent()

        # Load state from file, if it exists.
        if os.path.exists(BATCHER_STATE_FILE):
            with open(BATCHER_STATE_FILE, "rb") as f:
                self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
        else:
            self.last_post_id = self.fetch_newest_id("/r/all/new.json")
            self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
            self.save_state()

        self.fetch_new_ids()
        schedule.every(30).seconds.do(self.fetch_new_ids)
        # Daemon thread so the scheduler loop doesn't block interpreter
        # shutdown (and the atexit handlers registered in serve()).
        thread = threading.Thread(target=self.start_scheduler, daemon=True)
        thread.start()

    def start_scheduler(self):
        while True:
            schedule.run_pending()
            time.sleep(1)

    def fetch_newest_id(self, endpoint):
        # The first child of the listing is the newest item; its "name" is a
        # fullname such as "t3_abc123" (type prefix + base-36 ID).
        reddit_id = requests.get(
            f"https://www.reddit.com{endpoint}",
            headers={"User-Agent": self.ua.random},
        ).json()["data"]["children"][0]["data"]["name"]
        return self.reddit_id_to_int(reddit_id)

    def fetch_new_ids(self):
        # Enqueue every ID between the last seen post/comment and the newest
        # one, interleaved so posts and comments are processed evenly.
        with self.lock:
            new_post_id = self.fetch_newest_id("/r/all/new.json")
            new_comment_id = self.fetch_newest_id("/r/all/comments.json")
            new_post_ids = [
                self.int_to_reddit_id("t3", i)
                for i in range(self.last_post_id + 1, new_post_id + 1)
            ]
            new_comment_ids = [
                self.int_to_reddit_id("t1", i)
                for i in range(self.last_comment_id + 1, new_comment_id + 1)
            ]
            self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
            self.last_post_id = new_post_id
            self.last_comment_id = new_comment_id
            print("Queue size:", len(self.ids_queue))
            self.save_state()

    def get_batch(self):
        # Hand out up to 100 IDs at a time.
        with self.lock:
            batch = [
                self.ids_queue.popleft()
                for _ in range(min(100, len(self.ids_queue)))
            ]
        return batch

    def save_state(self):
        with open(BATCHER_STATE_FILE, "wb") as f:
            pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)

    @staticmethod
    def reddit_id_to_int(reddit_id):
        _prefix, base36 = reddit_id.split("_")
        return int(base36, 36)

    @staticmethod
    def int_to_reddit_id(id_type, i):
        id_b36 = np.base_repr(i, 36)
        return f"{id_type}_{id_b36}".lower()

    @staticmethod
    def interleave(list1, list2):
        return [
            item
            for pair in zip_longest(list1, list2)
            for item in pair
            if item is not None
        ]


class IDService(id_service_pb2_grpc.IDServiceServicer):
    def __init__(self, batcher):
        self.batcher = batcher
        self.active_batches = {}
        self.timers = {}
        self.lock = threading.Lock()
        self.conn = psycopg2.connect(
            f"dbname={config['db']['name']} "
            f"user={config['db']['user']} "
            f"password={config['db']['password']}"
        )
        self.cur = self.conn.cursor()
        self.cur.execute(
            """
            CREATE TABLE IF NOT EXISTS reddit_data(
                id SERIAL PRIMARY KEY,
                data JSONB NOT NULL
            )
            """
        )
        self.conn.commit()
        if os.path.exists(SERVICE_STATE_FILE):
            with open(SERVICE_STATE_FILE, "rb") as f:
                self.active_batches = pickle.load(f)

    def GetBatch(self, request, context):
        client_id = request.client_id
        with self.lock:
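            # Lease protocol: the client keeps ownership of its batch until it
            # either submits results via SubmitBatch or its 120-second timer
            # fires and return_batch puts the IDs back on the queue.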
            # Check if there is an unconfirmed batch for the client.
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
            else:
                batch = self.batcher.get_batch()
                self.active_batches[client_id] = batch
            # Cancel any existing timer for this client.
            if client_id in self.timers:
                self.timers[client_id].cancel()
            # Start a new timer for the client.
            self.timers[client_id] = threading.Timer(
                120, self.return_batch, [client_id]
            )
            self.timers[client_id].start()
        return id_service_pb2.BatchResponse(ids=batch)

    def SubmitBatch(self, request, context):
        client_id = request.client_id
        with self.lock:
            if client_id in self.active_batches:
                # Process the submitted data.
                self.write_json_to_postgres(request.data)
                # Remove batch from active batches for the client.
                del self.active_batches[client_id]
                # Cancel the timer for this client.
                if client_id in self.timers:
                    self.timers[client_id].cancel()
                    del self.timers[client_id]
        # Note: if the lease already expired, the submission is dropped but
        # success is still reported; the IDs were requeued by return_batch.
        return id_service_pb2.SubmitResponse(success=True)

    def return_batch(self, client_id):
        with self.lock:
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
                # extendleft inserts in reverse order, so reverse the batch
                # first to keep the IDs in their original order at the front.
                self.batcher.ids_queue.extendleft(reversed(batch))
                del self.active_batches[client_id]

    def save_state(self):
        with open(SERVICE_STATE_FILE, "wb") as f:
            pickle.dump(self.active_batches, f)

    def write_json_to_postgres(self, json_data):
        # json_data is an iterable of JSON strings; executemany expects a
        # sequence of parameter tuples.
        json_tuples = [(json_str,) for json_str in json_data]
        self.cur.executemany(
            "INSERT INTO reddit_data (data) VALUES (%s)",
            json_tuples,
        )
        self.conn.commit()


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    id_batcher = IDBatcher()
    id_service = IDService(id_batcher)
    id_service_pb2_grpc.add_IDServiceServicer_to_server(id_service, server)
    # Save state when the service is interrupted.
    atexit.register(id_service.save_state)
    atexit.register(id_batcher.save_state)
    server.add_insecure_port("[::]:50051")
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
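

# Example (sketch) of a worker consuming this service. The request message
# names below (BatchRequest, SubmitRequest) are assumptions about the .proto
# that generated id_service_pb2; only the fields used by the servicer above
# (client_id, data, ids, success) are taken from this file.
#
#     channel = grpc.insecure_channel("localhost:50051")
#     stub = id_service_pb2_grpc.IDServiceStub(channel)
#     ids = stub.GetBatch(id_service_pb2.BatchRequest(client_id="worker-1")).ids
#     json_strings = fetch_from_reddit(ids)  # hypothetical scraping helper
#     stub.SubmitBatch(
#         id_service_pb2.SubmitRequest(client_id="worker-1", data=json_strings)
#     )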