pushshift/batcher.py

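"""Batcher and gRPC service for distributing new reddit IDs to scraper clients.

IDBatcher polls reddit for the newest post and comment IDs, enumerates every ID
in between, and keeps them in a queue. IDService leases batches of those IDs to
clients over gRPC and writes the JSON they submit back into Postgres.
"""
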
import atexit
import os
import pickle
import threading
import time
from collections import deque
from concurrent import futures
from itertools import zip_longest

import grpc
import numpy as np
import psycopg2
import requests
import schedule
import yaml
from fake_useragent import UserAgent

import id_service_pb2
import id_service_pb2_grpc

DIR = os.path.dirname(os.path.realpath(__file__))
BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")

with open(os.path.join(DIR, "batcher_config.yaml")) as f:
    config = yaml.load(f, yaml.FullLoader)


class IDBatcher:
    """
    Periodically gets the newest comment/post ID from reddit, compares it to the
    last fetched one, and adds all intermediate IDs to a queue.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.ids_queue = deque()
        self.ua = UserAgent()

        # Load state from file, if it exists.
        if os.path.exists(BATCHER_STATE_FILE):
            with open(BATCHER_STATE_FILE, "rb") as f:
                self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
        else:
            self.last_post_id = self.fetch_newest_id("/r/all/new.json")
            self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
            self.save_state()

        self.fetch_new_ids()
        schedule.every(30).seconds.do(self.fetch_new_ids)
        # Daemon thread, so the endless scheduler loop cannot block interpreter
        # shutdown (otherwise the atexit state-saving handlers never run).
        thread = threading.Thread(target=self.start_scheduler, daemon=True)
        thread.start()

    def start_scheduler(self):
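        """Run pending scheduled jobs (the periodic fetch_new_ids) forever."""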
        while True:
            schedule.run_pending()
            time.sleep(1)

    def fetch_newest_id(self, endpoint):
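        """Return the newest item's fullname from a reddit listing, as an int."""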
        reddit_id = requests.get(
            f"https://www.reddit.com{endpoint}",
            headers={"User-Agent": self.ua.random},
            timeout=30,  # avoid hanging forever on a stalled request
        ).json()["data"]["children"][0]["data"]["name"]
        return self.reddit_id_to_int(reddit_id)

    def fetch_new_ids(self):
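        """Fetch the newest post/comment IDs from reddit and queue every ID
        between the last fetched ones and the new ones, interleaved."""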
        with self.lock:
            new_post_id = self.fetch_newest_id("/r/all/new.json")
            new_comment_id = self.fetch_newest_id("/r/all/comments.json")
            new_post_ids = [
                self.int_to_reddit_id("t3", i)
                for i in range(self.last_post_id + 1, new_post_id + 1)
            ]
            new_comment_ids = [
                self.int_to_reddit_id("t1", i)
                for i in range(self.last_comment_id + 1, new_comment_id + 1)
            ]
            self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
            self.last_post_id = new_post_id
            self.last_comment_id = new_comment_id
            print("Queue size:", len(self.ids_queue))
            self.save_state()

    def get_batch(self):
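        """Pop and return up to 100 IDs from the front of the queue."""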
        with self.lock:
            batch = [
                self.ids_queue.popleft() for _ in range(min(100, len(self.ids_queue)))
            ]
        return batch

    def save_state(self):
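        """Persist the ID counters and the queue so a restart loses no IDs."""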
        with open(BATCHER_STATE_FILE, "wb") as f:
            pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)

    @staticmethod
    def reddit_id_to_int(reddit_id):
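        """Convert a reddit fullname to an int, e.g. "t3_1a2b3c" -> 77370024."""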
        _prefix, base36 = reddit_id.split("_")
        return int(base36, 36)

    @staticmethod
    def int_to_reddit_id(id_type, i):
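        """Convert an int back to a fullname, e.g. ("t3", 77370024) -> "t3_1a2b3c"."""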
        id_b36 = np.base_repr(i, 36)
        return f"{id_type}_{id_b36}".lower()

    @staticmethod
    def interleave(list1, list2):
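        """Interleave two lists, e.g. [1, 3] and [2, 4, 5] -> [1, 2, 3, 4, 5]."""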
        return [
            item
            for pair in zip_longest(list1, list2)
            for item in pair
            if item is not None
        ]


class IDService(id_service_pb2_grpc.IDServiceServicer):
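    """
    Serves ID batches to clients over gRPC. A batch is leased to a client for
    120 seconds; if the client does not submit results in time, the batch is
    returned to the front of the queue for another client to pick up.
    """
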
    def __init__(self, batcher):
        self.batcher = batcher
        self.active_batches = {}
        self.timers = {}
        self.lock = threading.Lock()
        self.conn = psycopg2.connect(
            f"dbname={config['db']['name']} "
            f"user={config['db']['user']} "
            f"password={config['db']['password']}"
        )
        self.cur = self.conn.cursor()
        self.cur.execute(
            """
            CREATE TABLE IF NOT EXISTS reddit_data(
                id SERIAL PRIMARY KEY,
                data JSONB NOT NULL
            )
            """
        )
        self.conn.commit()

        # Load state from file, if it exists.
        if os.path.exists(SERVICE_STATE_FILE):
            with open(SERVICE_STATE_FILE, "rb") as f:
                self.active_batches = pickle.load(f)

    def GetBatch(self, request, context):
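        """Lease a batch of IDs to the requesting client.

        Re-sends the client's unconfirmed batch if one exists, and (re)starts
        the 120-second timer after which the lease expires.
        """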
        client_id = request.client_id
        with self.lock:
            # Check if there is an unconfirmed batch for the client.
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
            else:
                batch = self.batcher.get_batch()
                self.active_batches[client_id] = batch
            # Cancel any existing timer for this client.
            if client_id in self.timers:
                self.timers[client_id].cancel()
            # Start a new timer for the client.
            self.timers[client_id] = threading.Timer(
                120, self.return_batch, [client_id]
            )
            self.timers[client_id].start()
        return id_service_pb2.BatchResponse(ids=batch)

    def SubmitBatch(self, request, context):
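        """Store a client's scraped JSON and release its batch lease."""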
        client_id = request.client_id
        with self.lock:
            if client_id in self.active_batches:
                # Process the submitted data.
                self.write_json_to_postgres(request.data)
                # Remove batch from active batches for the client.
                del self.active_batches[client_id]
                # Cancel the timer for this client.
                if client_id in self.timers:
                    self.timers[client_id].cancel()
                    del self.timers[client_id]
        return id_service_pb2.SubmitResponse(success=True)

    def return_batch(self, client_id):
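        """Return an expired batch to the front of the queue (lease timed out)."""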
        with self.lock:
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
                # extendleft() adds items in reverse order, so reverse the batch
                # first to keep the IDs in their original order at the front.
                with self.batcher.lock:
                    self.batcher.ids_queue.extendleft(reversed(batch))
                del self.active_batches[client_id]
                self.timers.pop(client_id, None)

    def save_state(self):
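        """Persist unconfirmed batches so they survive a restart."""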
        with open(SERVICE_STATE_FILE, "wb") as f:
            pickle.dump(self.active_batches, f)

    def write_json_to_postgres(self, json_data):
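        """Bulk-insert the submitted JSON strings into the reddit_data table."""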
        json_tuples = [(json_str,) for json_str in json_data]
        self.cur.executemany(
            "INSERT INTO reddit_data (data) VALUES (%s)",
            json_tuples,
        )
        self.conn.commit()


def serve():
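    """Wire up the batcher and the IDService and serve gRPC on port 50051."""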
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    id_batcher = IDBatcher()
    id_service = IDService(id_batcher)
    id_service_pb2_grpc.add_IDServiceServicer_to_server(id_service, server)
    # Save state when the service is interrupted.
    atexit.register(id_service.save_state)
    atexit.register(id_batcher.save_state)
    server.add_insecure_port("[::]:50051")
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
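
# A minimal client sketch. Hedged: the request message names and the scrape()
# helper are assumptions; only BatchResponse.ids, SubmitResponse, and the
# client_id/data fields are confirmed by the servicer above.
#
#     channel = grpc.insecure_channel("localhost:50051")
#     stub = id_service_pb2_grpc.IDServiceStub(channel)
#     batch = stub.GetBatch(id_service_pb2.BatchRequest(client_id="worker-1"))
#     data = [scrape(reddit_id) for reddit_id in batch.ids]  # scrape() is hypothetical
#     stub.SubmitBatch(id_service_pb2.SubmitRequest(client_id="worker-1", data=data))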