From 04ea9c9e5ef095537900edacc7a271d7acd09483 Mon Sep 17 00:00:00 2001
From: float-trip
Date: Tue, 11 Jul 2023 04:55:36 +0000
Subject: [PATCH] Initial commit.

---
 batcher.py                  | 220 ++++++++++++++++++++++++++++++++++++
 batcher_config.example.yaml |   4 +
 compile_proto.sh            |   5 +
 fetcher.py                  | 171 ++++++++++++++++++++++++++++
 fetcher_config.example.yaml |   6 +
 id_service.proto            |  32 ++++++
 readme.md                   |  16 +++
 requirements.txt            |  10 ++
 8 files changed, 464 insertions(+)
 create mode 100644 batcher.py
 create mode 100644 batcher_config.example.yaml
 create mode 100644 compile_proto.sh
 create mode 100644 fetcher.py
 create mode 100644 fetcher_config.example.yaml
 create mode 100644 id_service.proto
 create mode 100644 readme.md
 create mode 100644 requirements.txt

diff --git a/batcher.py b/batcher.py
new file mode 100644
index 0000000..f13f548
--- /dev/null
+++ b/batcher.py
@@ -0,0 +1,220 @@
+import atexit
+import os
+import pickle
+import threading
+import time
+from collections import deque
+from concurrent import futures
+from itertools import zip_longest
+
+import grpc
+import numpy as np
+import psycopg2
+import requests
+import schedule
+import yaml
+from fake_useragent import UserAgent
+
+import id_service_pb2
+import id_service_pb2_grpc
+
+DIR = os.path.dirname(os.path.realpath(__file__))
+BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
+SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")
+
+with open(os.path.join(DIR, "batcher_config.yaml")) as f:
+    config = yaml.safe_load(f)
+
+
+class IDBatcher:
+    def __init__(self):
+        self.lock = threading.Lock()
+        self.ids_queue = deque()
+        self.ua = UserAgent()
+
+        # Load state from file, if it exists.
+        if os.path.exists(BATCHER_STATE_FILE):
+            with open(BATCHER_STATE_FILE, "rb") as f:
+                self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
+        else:
+            self.last_post_id = self.fetch_newest_id("/r/all/new.json")
+            self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
+            self.save_state()
+
+        self.fetch_new_ids()
+
+        schedule.every(5).seconds.do(self.fetch_new_ids)
+
+        # Daemonize so the scheduler loop can't keep the process alive on exit.
+        thread = threading.Thread(target=self.start_scheduler, daemon=True)
+        thread.start()
+
+    def start_scheduler(self):
+        while True:
+            schedule.run_pending()
+            time.sleep(1)
+
+    def fetch_newest_id(self, endpoint):
+        reddit_id = requests.get(
+            f"https://www.reddit.com{endpoint}",
+            headers={"User-Agent": self.ua.random},
+        ).json()["data"]["children"][0]["data"]["name"]
+
+        return self.reddit_id_to_int(reddit_id)
+
+    def fetch_new_ids(self):
+        with self.lock:
+            new_post_id = self.fetch_newest_id("/r/all/new.json")
+            new_comment_id = self.fetch_newest_id("/r/all/comments.json")
+
+            new_post_ids = [
+                self.int_to_reddit_id("t3", i)
+                for i in range(self.last_post_id + 1, new_post_id + 1)
+            ]
+            new_comment_ids = [
+                self.int_to_reddit_id("t1", i)
+                for i in range(self.last_comment_id + 1, new_comment_id + 1)
+            ]
+
+            self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
+            self.last_post_id = new_post_id
+            self.last_comment_id = new_comment_id
+
+            print("Queue size:", len(self.ids_queue))
+            self.save_state()
+
+    def get_batch(self):
+        with self.lock:
+            batch = [
+                self.ids_queue.popleft() for _ in range(min(100, len(self.ids_queue)))
+            ]
+            return batch
+
+    def save_state(self):
+        with open(BATCHER_STATE_FILE, "wb") as f:
+            pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)
+
+    @staticmethod
+    def reddit_id_to_int(reddit_id):
+        _prefix, base36 = reddit_id.split("_")
+        return int(base36, 36)
+
+    @staticmethod
+    def int_to_reddit_id(id_type, i):
+        id_b36 = np.base_repr(i, 36)
+        return f"{id_type}_{id_b36}".lower()
+
+    @staticmethod
+    def interleave(list1, list2):
+        return [
+            item
+            for pair in zip_longest(list1, list2)
+            for item in pair
+            if item is not None
+        ]
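+
+# Sanity check for the fullname round trip (values worked out by hand, shown
+# for illustration): fullnames are a type prefix plus a base-36 counter.
+#
+#   IDBatcher.reddit_id_to_int("t3_abc123")      # -> 623698779
+#   IDBatcher.int_to_reddit_id("t3", 623698779)  # -> "t3_abc123"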
+
+
+class IDService(id_service_pb2_grpc.IDServiceServicer):
+    def __init__(self, batcher):
+        self.batcher = batcher
+        self.active_batches = {}
+        self.timers = {}
+        self.lock = threading.Lock()
+
+        self.conn = psycopg2.connect(
+            f"dbname={config['db']['name']} "
+            f"user={config['db']['user']} "
+            f"password={config['db']['password']}"
+        )
+        self.cur = self.conn.cursor()
+        self.cur.execute(
+            """
+            CREATE TABLE IF NOT EXISTS reddit_data(
+                id SERIAL PRIMARY KEY,
+                data JSONB NOT NULL
+            )
+            """
+        )
+        self.conn.commit()
+
+        if os.path.exists(SERVICE_STATE_FILE):
+            with open(SERVICE_STATE_FILE, "rb") as f:
+                self.active_batches = pickle.load(f)
+
+    def GetBatch(self, request, context):
+        client_id = request.client_id
+
+        with self.lock:
+            # Re-send any unconfirmed batch for the client before leasing a new one.
+            if client_id in self.active_batches:
+                batch = self.active_batches[client_id]
+            else:
+                batch = self.batcher.get_batch()
+                self.active_batches[client_id] = batch
+
+            # Cancel any existing timer for this client.
+            if client_id in self.timers:
+                self.timers[client_id].cancel()
+
+            # Start a new timer: if the client stays silent for two minutes,
+            # its batch is returned to the queue.
+            self.timers[client_id] = threading.Timer(
+                120, self.return_batch, [client_id]
+            )
+            self.timers[client_id].start()
+
+        return id_service_pb2.BatchResponse(ids=batch)
+
+    def SubmitBatch(self, request, context):
+        client_id = request.client_id
+
+        with self.lock:
+            if client_id in self.active_batches:
+                # Process the submitted data.
+                self.write_json_to_postgres(request.data)
+
+                # Remove the batch from the client's active batches.
+                del self.active_batches[client_id]
+
+                # Cancel the timer for this client.
+                if client_id in self.timers:
+                    self.timers[client_id].cancel()
+                    del self.timers[client_id]
+
+        return id_service_pb2.SubmitResponse(success=True)
+
+    def return_batch(self, client_id):
+        with self.lock:
+            if client_id in self.active_batches:
+                batch = self.active_batches[client_id]
+                # Take the batcher's lock before touching its queue, and
+                # reverse the batch so extendleft preserves the original order.
+                with self.batcher.lock:
+                    self.batcher.ids_queue.extendleft(reversed(batch))
+                del self.active_batches[client_id]
+
+    def save_state(self):
+        with open(SERVICE_STATE_FILE, "wb") as f:
+            pickle.dump(self.active_batches, f)
+
+    def write_json_to_postgres(self, json_data):
+        json_tuples = [(json_str,) for json_str in json_data]
+        self.cur.executemany(
+            "INSERT INTO reddit_data (data) VALUES (%s)",
+            json_tuples,
+        )
+        self.conn.commit()
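+
+# The stored documents follow reddit's listing format, so a query like this
+# (illustrative only) pulls counts per item type straight from the JSONB column:
+#
+#   SELECT data->>'kind' AS kind, COUNT(*) FROM reddit_data GROUP BY kind;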
+
+
+def serve():
+    # Raise the receive limit to match the fetcher's 100 MB send limit;
+    # gRPC's default of 4 MB would reject large SubmitBatch payloads.
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=10),
+        options=[("grpc.max_receive_message_length", 100 * 10**6)],
+    )
+    id_batcher = IDBatcher()
+    id_service = IDService(id_batcher)
+    id_service_pb2_grpc.add_IDServiceServicer_to_server(id_service, server)
+
+    # Save state when the service is interrupted.
+    atexit.register(id_service.save_state)
+    atexit.register(id_batcher.save_state)
+
+    server.add_insecure_port("[::]:50051")
+    server.start()
+    server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    serve()
diff --git a/batcher_config.example.yaml b/batcher_config.example.yaml
new file mode 100644
index 0000000..3bd660f
--- /dev/null
+++ b/batcher_config.example.yaml
@@ -0,0 +1,4 @@
+db:
+  user: postgres
+  password: ""
+  name: postgres
\ No newline at end of file
diff --git a/compile_proto.sh b/compile_proto.sh
new file mode 100644
index 0000000..d7b6d8a
--- /dev/null
+++ b/compile_proto.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+full_path=$(realpath "$0")
+dir=$(dirname "$full_path")
+python -m grpc_tools.protoc --python_out="$dir" --grpc_python_out="$dir" "$dir/id_service.proto" -I="$dir"
diff --git a/fetcher.py b/fetcher.py
new file mode 100644
index 0000000..533d828
--- /dev/null
+++ b/fetcher.py
@@ -0,0 +1,171 @@
+import json
+import logging
+import os
+import time
+from logging.handlers import RotatingFileHandler
+
+import grpc
+import requests
+import yaml
+from fake_useragent import UserAgent
+
+import id_service_pb2
+import id_service_pb2_grpc
+
+# Load fetcher_config.yaml.
+DIR = os.path.dirname(os.path.realpath(__file__))
+with open(os.path.join(DIR, "fetcher_config.yaml")) as f:
+    config = yaml.safe_load(f)
+
+# Make sure the log directory exists before attaching the file handler.
+os.makedirs(os.path.join(DIR, "logs"), exist_ok=True)
+log_file = os.path.join(DIR, "logs", "fetcher.log")
+
+# Configure logger.
+logger = logging.getLogger("fetcher")
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter(
+    fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+
+# Add rotating file handler.
+rh = RotatingFileHandler(log_file, maxBytes=10 * 10**6, backupCount=20)
+rh.setLevel(logging.DEBUG)
+rh.setFormatter(formatter)
+logger.addHandler(rh)
+
+# Add stream handler.
+sh = logging.StreamHandler()
+sh.setLevel(logging.INFO)
+sh.setFormatter(formatter)
+logger.addHandler(sh)
+
+
+class RedditClient:
+    ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
+    INFO_API_URL = "https://oauth.reddit.com/api/info.json"
+    BACKOFF_INCREMENT = 5
+    MIN_DELAY = 1
+
+    def __init__(self):
+        self.headers = {"User-Agent": UserAgent().random}
+        self.last_request = time.time()
+        self.backoff = 0
+
+        self.ratelimit_remaining = 300
+        self.ratelimit_reset = time.time()
+
+        self.token = None
+        self.token_expiration = time.time()
+
+    def sleep_until_ready(self):
+        if self.ratelimit_remaining == 0 and time.time() < self.ratelimit_reset:
+            time.sleep(self.ratelimit_reset - time.time())
+
+        # The backoff is already folded into the delay; don't add it twice.
+        delay = max(
+            0, self.MIN_DELAY + self.backoff - (time.time() - self.last_request)
+        )
+        time.sleep(delay)
+
+    def authorize(self):
+        self.sleep_until_ready()
+        self.last_request = time.time()
+
+        res = requests.post(
+            self.ACCESS_TOKEN_URL,
+            params={
+                "grant_type": "refresh_token",
+                "refresh_token": config["reddit"]["refresh_token"],
+            },
+            headers=self.headers,
+            auth=(config["reddit"]["client_id"], config["reddit"]["client_secret"]),
+        )
+
+        if res.status_code == 200:
+            logger.info("Auth successful.")
+            self.backoff = 0
+            self.token = res.json()["access_token"]
+            self.token_expiration = time.time() + res.json()["expires_in"]
+        else:
+            self.backoff += self.BACKOFF_INCREMENT
+            logger.error("Auth failed.")
+
+    def update_ratelimit(self, headers):
+        if "x-ratelimit-remaining" in headers:
+            self.ratelimit_remaining = int(float(headers["x-ratelimit-remaining"]))
+            self.ratelimit_reset = time.time() + int(
+                float(headers["x-ratelimit-reset"])
+            )
+
+    def fetch(self, ids):
+        # Refresh the token shortly before it expires; returning None lets the
+        # caller retry the same batch with a fresh token.
+        if self.token is None or self.token_expiration - time.time() < 60:
+            self.authorize()
+            return None
+
+        self.sleep_until_ready()
+        self.last_request = time.time()
+
+        params = {"id": ",".join(ids)}
+        logger.debug(f"GET /api/info.json?id={params['id']}")
+        res = requests.get(
+            self.INFO_API_URL,
+            params=params,
+            headers=self.headers | {"Authorization": f"bearer {self.token}"},
+        )
+
+        if res.status_code == 200:
+            self.backoff = 0
+            self.update_ratelimit(res.headers)
+            logger.debug(f"Response: {res.text}")
+            items = res.json()["data"]["children"]
+            return [json.dumps(item) for item in items]
+        else:
+            self.update_ratelimit(res.headers)
+            self.backoff += self.BACKOFF_INCREMENT
+            logger.error(f"Bad status: {res.status_code}. Backoff: {self.backoff}.")
+            return None
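+
+# For reference, reddit reports rate-limit state on every OAuth response via
+# headers like the following (values illustrative):
+#
+#   x-ratelimit-remaining: 299.0
+#   x-ratelimit-used: 1
+#   x-ratelimit-reset: 557
+#
+# update_ratelimit() converts the relative reset into an absolute timestamp
+# so sleep_until_ready() can block until the window rolls over.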
+
+
+class FetcherClient:
+    def __init__(self, client_id, reddit_client):
+        self.client_id = client_id
+        self.reddit_client = reddit_client
+        self.channel = grpc.insecure_channel(
+            config["distributor_uri"],
+            options=[
+                ("grpc.max_send_message_length", 100 * 10**6),
+                ("grpc.max_receive_message_length", 1 * 10**6),
+            ],
+        )
+        self.stub = id_service_pb2_grpc.IDServiceStub(self.channel)
+
+    def request_and_process_batch(self):
+        # Request a batch of IDs from the server.
+        response = self.stub.GetBatch(
+            id_service_pb2.BatchRequest(client_id=self.client_id)
+        )
+
+        # Fetch data from reddit using the received IDs.
+        data = None
+        while data is None:
+            data = self.reddit_client.fetch(response.ids)
+
+        # Submit the batch to the server.
+        self.stub.SubmitBatch(
+            id_service_pb2.SubmitRequest(client_id=self.client_id, data=data)
+        )
+
+    def run(self):
+        # Continuously request and process batches.
+        while True:
+            try:
+                self.request_and_process_batch()
+                time.sleep(1)
+            except grpc.RpcError as e:
+                logger.error(f"gRPC error: {e.details()}")
+                time.sleep(10)
+
+
+if __name__ == "__main__":
+    reddit_client = RedditClient()
+    fetcher_client = FetcherClient(config["id"], reddit_client)
+    fetcher_client.run()
diff --git a/fetcher_config.example.yaml b/fetcher_config.example.yaml
new file mode 100644
index 0000000..209af0a
--- /dev/null
+++ b/fetcher_config.example.yaml
@@ -0,0 +1,6 @@
+id: fetcher-1
+distributor_uri: localhost:50051
+reddit:
+  client_id: abc
+  client_secret: abc
+  refresh_token: 123-abc
\ No newline at end of file
diff --git a/id_service.proto b/id_service.proto
new file mode 100644
index 0000000..92845b7
--- /dev/null
+++ b/id_service.proto
@@ -0,0 +1,32 @@
+syntax = "proto3";
+
+package IDService;
+
+service IDService {
+  // Fetch a batch of IDs.
+  rpc GetBatch (BatchRequest) returns (BatchResponse) {}
+
+  // Submit a processed batch of data.
+  rpc SubmitBatch (SubmitRequest) returns (SubmitResponse) {}
+}
+
+// The BatchRequest message contains the client id.
+message BatchRequest {
+  string client_id = 1;
+}
+
+// The BatchResponse message contains the IDs.
+message BatchResponse {
+  repeated string ids = 1;
+}
+
+// The SubmitRequest message contains the client id and a batch of data.
+message SubmitRequest {
+  string client_id = 1;
+  repeated string data = 2;
+}
+
+// The SubmitResponse message confirms successful batch processing.
+message SubmitResponse {
+  bool success = 1;
+}
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..3b72cde
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,16 @@
+```bash
+# ...install PostgreSQL...
+pip install -r requirements.txt
+mv batcher_config.example.yaml batcher_config.yaml
+mv fetcher_config.example.yaml fetcher_config.yaml
+# ...fill in the renamed yamls with your credentials...
+mkdir logs
+bash compile_proto.sh
+
+# Then run both of these processes:
+python batcher.py
+python fetcher.py
+```
+
+Getting a refresh token:
+https://praw.readthedocs.io/en/stable/tutorials/refresh_token.html#obtaining-refresh-tokens
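+
+A minimal sketch of that flow, assuming an app whose redirect URI is set to
+`http://localhost:8080` (the linked tutorial is authoritative; the `scopes`
+list here is just an example):
+
+```python
+import praw
+
+reddit = praw.Reddit(
+    client_id="abc",
+    client_secret="abc",
+    redirect_uri="http://localhost:8080",
+    user_agent="refresh-token-setup",
+)
+
+# Open this URL in a browser, approve the app, then copy the `code`
+# parameter out of the redirect URL.
+print(reddit.auth.url(scopes=["read"], state="...", duration="permanent"))
+
+# authorize() exchanges the code and returns the refresh token.
+print(reddit.auth.authorize(input("code: ")))
+```
\ No newline at end of file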
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9ae5868
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+fake_useragent==1.1.3
+grpcio==1.56.0
+grpcio-tools==1.56.0
+numpy==1.24.2
+praw==7.7.0
+protobuf==4.23.4
+psycopg2_binary==2.9.6
+PyYAML==6.0
+Requests==2.31.0
+schedule==1.2.0
\ No newline at end of file