# (file-listing metadata removed from code: 226 lines, 6.8 KiB, Python)
import atexit
|
|
import os
|
|
import pickle
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from concurrent import futures
|
|
from itertools import zip_longest
|
|
|
|
import grpc
|
|
import numpy as np
|
|
import psycopg2
|
|
import requests
|
|
import schedule
|
|
import yaml
|
|
from fake_useragent import UserAgent
|
|
|
|
import id_service_pb2
|
|
import id_service_pb2_grpc
|
|
|
|
# Paths for persisted state, resolved relative to this file's directory.
DIR = os.path.dirname(os.path.realpath(__file__))
BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")

# Load service configuration (DB name/user/password) from alongside this file.
with open(os.path.join(DIR, "batcher_config.yaml")) as f:
    # safe_load: the config is plain data; FullLoader would allow the YAML
    # to instantiate arbitrary Python objects.
    config = yaml.safe_load(f)
|
|
|
|
|
|
class IDBatcher:
    """
    Periodically gets the newest comment/post ID from reddit, compares it to the
    last fetched one, and adds all intermediate values into a queue.

    State (last seen IDs + pending queue) is pickled to BATCHER_STATE_FILE so
    no IDs are lost across restarts.
    """

    def __init__(self):
        self.lock = threading.Lock()  # guards ids_queue and last_*_id
        self.ids_queue = deque()
        self.ua = UserAgent()

        # Restore persisted state, if any; otherwise bootstrap from reddit.
        if os.path.exists(BATCHER_STATE_FILE):
            with open(BATCHER_STATE_FILE, "rb") as f:
                # NOTE: pickle is only safe because this file is written by
                # save_state() below; never point this at untrusted data.
                self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
        else:
            self.last_post_id = self.fetch_newest_id("/r/all/new.json")
            self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
            self.save_state()

        self.fetch_new_ids()

        schedule.every(30).seconds.do(self.fetch_new_ids)

        # Daemon thread: the scheduler loop below never returns, and a
        # non-daemon thread would keep the process alive (and block atexit
        # handlers) after the gRPC server shuts down.
        thread = threading.Thread(target=self.start_scheduler, daemon=True)
        thread.start()

    def start_scheduler(self):
        """Run pending scheduled jobs forever (executed on a daemon thread)."""
        while True:
            schedule.run_pending()
            time.sleep(1)

    def fetch_newest_id(self, endpoint):
        """Return the newest item's ID (as int) from a reddit listing endpoint.

        Raises requests.RequestException on network/HTTP failure.
        """
        # timeout: requests has no default timeout and would otherwise hang
        # forever on a stalled connection, stopping the scheduler thread.
        response = requests.get(
            f"https://www.reddit.com{endpoint}",
            headers={"User-Agent": self.ua.random},
            timeout=30,
        )
        response.raise_for_status()
        reddit_id = response.json()["data"]["children"][0]["data"]["name"]
        return self.reddit_id_to_int(reddit_id)

    def fetch_new_ids(self):
        """Fetch the newest post/comment IDs and queue all intermediate IDs."""
        # Do the network I/O *outside* the lock so get_batch() callers are not
        # blocked for the duration of two HTTP round-trips.
        try:
            new_post_id = self.fetch_newest_id("/r/all/new.json")
            new_comment_id = self.fetch_newest_id("/r/all/comments.json")
        except (requests.RequestException, KeyError, IndexError) as e:
            # A transient reddit failure (or an unexpected response shape,
            # e.g. rate limiting) must not kill the scheduler thread.
            print("fetch_new_ids failed, will retry on next cycle:", e)
            return

        with self.lock:
            new_post_ids = [
                self.int_to_reddit_id("t3", i)
                for i in range(self.last_post_id + 1, new_post_id + 1)
            ]
            new_comment_ids = [
                self.int_to_reddit_id("t1", i)
                for i in range(self.last_comment_id + 1, new_comment_id + 1)
            ]

            self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
            self.last_post_id = new_post_id
            self.last_comment_id = new_comment_id

            print("Queue size:", len(self.ids_queue))
            self.save_state()

    def get_batch(self):
        """Pop and return up to 100 IDs from the front of the queue."""
        with self.lock:
            n = min(100, len(self.ids_queue))
            return [self.ids_queue.popleft() for _ in range(n)]

    def save_state(self):
        """Persist (last_post_id, last_comment_id, ids_queue) to disk."""
        with open(BATCHER_STATE_FILE, "wb") as f:
            pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)

    @staticmethod
    def reddit_id_to_int(reddit_id):
        """Convert a fullname like 't3_1a2b' to its integer ID (base-36 part)."""
        _prefix, base36 = reddit_id.split("_")
        return int(base36, 36)

    @staticmethod
    def int_to_reddit_id(id_type, i):
        """Convert an integer ID back to a lowercase fullname, e.g. 't3_1a2b'."""
        id_b36 = np.base_repr(i, 36)
        return f"{id_type}_{id_b36}".lower()

    @staticmethod
    def interleave(list1, list2):
        """Alternate elements of both lists; the longer list's tail is kept."""
        return [
            item
            for pair in zip_longest(list1, list2)
            for item in pair
            if item is not None
        ]
|
|
|
|
|
|
class IDService(id_service_pb2_grpc.IDServiceServicer):
    """gRPC servicer that leases ID batches to clients and stores results.

    A batch handed out by GetBatch stays "active" for a client until either
    SubmitBatch confirms it (data is written to Postgres) or a 120-second
    timer expires and the batch is returned to the front of the queue.
    """

    def __init__(self, batcher):
        self.batcher = batcher
        self.active_batches = {}  # client_id -> leased, unconfirmed batch
        self.timers = {}          # client_id -> reclaim Timer
        self.lock = threading.Lock()  # guards active_batches and timers

        self.conn = psycopg2.connect(
            f"dbname={config['db']['name']} "
            f"user={config['db']['user']} "
            f"password={config['db']['password']}"
        )
        self.cur = self.conn.cursor()
        self.cur.execute(
            """
            CREATE TABLE IF NOT EXISTS reddit_data(
                id SERIAL PRIMARY KEY,
                data JSONB NOT NULL
            )
            """
        )
        self.conn.commit()

        # Restore leases that were outstanding when the process last exited.
        if os.path.exists(SERVICE_STATE_FILE):
            with open(SERVICE_STATE_FILE, "rb") as f:
                # NOTE: pickle is only safe because this file is written by
                # save_state() below; never load untrusted data this way.
                self.active_batches = pickle.load(f)

    def GetBatch(self, request, context):
        """Lease a batch of IDs to the client, re-sending any unconfirmed one."""
        client_id = request.client_id

        with self.lock:
            # Re-deliver the unconfirmed batch instead of leasing a new one,
            # so a client that crashed mid-batch does not skip IDs.
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
            else:
                batch = self.batcher.get_batch()
                self.active_batches[client_id] = batch

            # Restart the reclaim timer for this client.
            if client_id in self.timers:
                self.timers[client_id].cancel()

            timer = threading.Timer(120, self.return_batch, [client_id])
            # Daemon timer: an outstanding lease must not block process exit.
            timer.daemon = True
            self.timers[client_id] = timer
            timer.start()

        return id_service_pb2.BatchResponse(ids=batch)

    def SubmitBatch(self, request, context):
        """Store the client's results and release its lease.

        Idempotent: a submission from a client with no active lease (e.g. a
        duplicate retry) is acknowledged without writing anything.
        """
        client_id = request.client_id

        with self.lock:
            if client_id in self.active_batches:
                self.write_json_to_postgres(request.data)
                del self.active_batches[client_id]

            if client_id in self.timers:
                self.timers[client_id].cancel()
                del self.timers[client_id]

        return id_service_pb2.SubmitResponse(success=True)

    def return_batch(self, client_id):
        """Timer callback: put an expired lease's IDs back at the queue front."""
        with self.lock:
            # Drop the expired timer entry (the original leaked it).
            self.timers.pop(client_id, None)

            batch = self.active_batches.pop(client_id, None)
            if batch:
                # The queue is shared with IDBatcher, so take its lock.
                # extendleft() inserts items one at a time and therefore
                # reverses its input; feed it the reversed batch so the IDs
                # land back in their original order.
                with self.batcher.lock:
                    self.batcher.ids_queue.extendleft(reversed(batch))

    def save_state(self):
        """Persist outstanding leases so they survive a restart."""
        with open(SERVICE_STATE_FILE, "wb") as f:
            pickle.dump(self.active_batches, f)

    def write_json_to_postgres(self, json_data):
        """Insert each JSON string in json_data as one row of reddit_data."""
        json_tuples = [(json_str,) for json_str in json_data]
        try:
            self.cur.executemany(
                "INSERT INTO reddit_data (data) VALUES (%s)",
                json_tuples,
            )
            self.conn.commit()
        except psycopg2.Error:
            # Roll back so the connection is usable for the next batch
            # (a failed statement otherwise leaves it in an aborted
            # transaction); re-raise so the lease stays active and the
            # batch is eventually retried or reclaimed.
            self.conn.rollback()
            raise
|
|
|
|
|
|
def serve():
    """Build the batcher and ID service, then serve gRPC on port 50051 forever."""
    batcher = IDBatcher()
    service = IDService(batcher)

    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    id_service_pb2_grpc.add_IDServiceServicer_to_server(service, grpc_server)

    # Persist both components' state when the interpreter shuts down.
    for persist in (service.save_state, batcher.save_state):
        atexit.register(persist)

    grpc_server.add_insecure_port("[::]:50051")
    grpc_server.start()
    grpc_server.wait_for_termination()
|
|
|
|
|
|
# Entry point: start the batcher and gRPC service when run as a script.
if __name__ == "__main__":
    serve()
|