Initial commit.

master
float-trip 2023-07-11 04:55:36 +00:00
commit 04ea9c9e5e
8 changed files with 464 additions and 0 deletions

220
batcher.py 100644
View File

@ -0,0 +1,220 @@
import atexit
import os
import pickle
import threading
import time
from collections import deque
from concurrent import futures
from itertools import zip_longest
import grpc
import numpy as np
import psycopg2
import requests
import schedule
import yaml
from fake_useragent import UserAgent
import id_service_pb2
import id_service_pb2_grpc
DIR = os.path.dirname(os.path.realpath(__file__))
BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")
with open(os.path.join(DIR, "batcher_config.yaml")) as f:
config = yaml.load(f, yaml.FullLoader)
class IDBatcher:
def __init__(self):
self.lock = threading.Lock()
self.ids_queue = deque()
self.ua = UserAgent()
# Load state from file, if exists.
if os.path.exists(BATCHER_STATE_FILE):
with open(BATCHER_STATE_FILE, "rb") as f:
self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
else:
self.last_post_id = self.fetch_newest_id("/r/all/new.json")
self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
self.save_state()
self.fetch_new_ids()
schedule.every(5).seconds.do(self.fetch_new_ids)
thread = threading.Thread(target=self.start_scheduler)
thread.start()
def start_scheduler(self):
while True:
schedule.run_pending()
time.sleep(1)
def fetch_newest_id(self, endpoint):
reddit_id = requests.get(
f"https://www.reddit.com{endpoint}",
headers={"User-Agent": self.ua.random},
).json()["data"]["children"][0]["data"]["name"]
return self.reddit_id_to_int(reddit_id)
def fetch_new_ids(self):
with self.lock:
new_post_id = self.fetch_newest_id("/r/all/new.json")
new_comment_id = self.fetch_newest_id("/r/all/comments.json")
new_post_ids = [
self.int_to_reddit_id("t3", i)
for i in range(self.last_post_id + 1, new_post_id + 1)
]
new_comment_ids = [
self.int_to_reddit_id("t1", i)
for i in range(self.last_comment_id + 1, new_comment_id + 1)
]
self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
self.last_post_id = new_post_id
self.last_comment_id = new_comment_id
print("Queue size:", len(self.ids_queue))
self.save_state()
def get_batch(self):
with self.lock:
batch = [
self.ids_queue.popleft() for _ in range(min(100, len(self.ids_queue)))
]
return batch
def save_state(self):
with open(BATCHER_STATE_FILE, "wb") as f:
pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)
@staticmethod
def reddit_id_to_int(reddit_id):
_prefix, base36 = reddit_id.split("_")
return int(base36, 36)
@staticmethod
def int_to_reddit_id(id_type, i):
id_b36 = np.base_repr(i, 36)
return f"{id_type}_{id_b36}".lower()
@staticmethod
def interleave(list1, list2):
return [
item
for pair in zip_longest(list1, list2)
for item in pair
if item is not None
]
class IDService(id_service_pb2_grpc.IDServiceServicer):
def __init__(self, batcher):
self.batcher = batcher
self.active_batches = {}
self.timers = {}
self.lock = threading.Lock()
self.conn = psycopg2.connect(
f"dbname={config['db']['name']} "
f"user={config['db']['user']} "
f"password={config['db']['password']}"
)
self.cur = self.conn.cursor()
self.cur.execute(
"""
CREATE TABLE IF NOT EXISTS reddit_data(
id SERIAL PRIMARY KEY,
data JSONB NOT NULL
)
"""
)
self.conn.commit()
if os.path.exists(SERVICE_STATE_FILE):
with open(SERVICE_STATE_FILE, "rb") as f:
self.active_batches = pickle.load(f)
def GetBatch(self, request, context):
client_id = request.client_id
with self.lock:
# Check if there is an unconfirmed batch for the client.
if client_id in self.active_batches:
batch = self.active_batches[client_id]
else:
batch = self.batcher.get_batch()
self.active_batches[client_id] = batch
# Cancel any existing timer for this client.
if client_id in self.timers:
self.timers[client_id].cancel()
# Start a new timer for the client.
self.timers[client_id] = threading.Timer(
120, self.return_batch, [client_id]
)
self.timers[client_id].start()
return id_service_pb2.BatchResponse(ids=batch)
def SubmitBatch(self, request, context):
client_id = request.client_id
with self.lock:
if client_id in self.active_batches:
# Process the submitted data.
self.write_json_to_postgres(request.data)
# Remove batch from active batches for the client.
del self.active_batches[client_id]
# Cancel the timer for this client.
if client_id in self.timers:
self.timers[client_id].cancel()
del self.timers[client_id]
return id_service_pb2.SubmitResponse(success=True)
def return_batch(self, client_id):
with self.lock:
if client_id in self.active_batches:
batch = self.active_batches[client_id]
self.batcher.ids_queue.extendleft(batch)
del self.active_batches[client_id]
def save_state(self):
with open(SERVICE_STATE_FILE, "wb") as f:
pickle.dump(self.active_batches, f)
def write_json_to_postgres(self, json_data):
json_tuples = [(json_str,) for json_str in json_data]
self.cur.executemany(
"INSERT INTO reddit_data (data) VALUES (%s)",
json_tuples,
)
self.conn.commit()
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
id_batcher = IDBatcher()
id_service = IDService(id_batcher)
id_service_pb2_grpc.add_IDServiceServicer_to_server(id_service, server)
# Save state when the service is interrupted.
atexit.register(id_service.save_state)
atexit.register(id_batcher.save_state)
server.add_insecure_port("[::]:50051")
server.start()
server.wait_for_termination()
if __name__ == "__main__":
serve()

View File

@ -0,0 +1,4 @@
db:
user: postgres
password: ""
name: postgres

5
compile_proto.sh 100644
View File

@ -0,0 +1,5 @@
#!/usr/bin/env bash
full_path=$(realpath $0)
dir=$(dirname $full_path)
python -m grpc_tools.protoc --python_out=$dir --grpc_python_out=$dir $dir/id_service.proto -I=$dir

171
fetcher.py 100644
View File

@ -0,0 +1,171 @@
import json
import logging
import os
import time
from logging.handlers import RotatingFileHandler
import grpc
import requests
import yaml
from fake_useragent import UserAgent
import id_service_pb2
import id_service_pb2_grpc
# Load config.yaml.
DIR = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(DIR, "fetcher_config.yaml")) as f:
config = yaml.load(f, yaml.FullLoader)
log_file = os.path.join(DIR, "logs", "fetcher.log")
# Configure logger.
logger = logging.getLogger("fetcher")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Add rotating file handler.
rh = RotatingFileHandler(log_file, maxBytes=10 * 10**6, backupCount=20)
rh.setLevel(logging.DEBUG)
rh.setFormatter(formatter)
logger.addHandler(rh)
# Add stream handler.
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)
class RedditClient:
ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
INFO_API_URL = "https://oauth.reddit.com/api/info.json"
BACKOFF_INCREMENT = 5
MIN_DELAY = 1
def __init__(self):
self.headers = {"User-Agent": UserAgent().random}
self.last_request = time.time()
self.backoff = 0
self.ratelimit_remaining = 300
self.ratelimit_reset = time.time()
self.token = None
self.token_expiration = time.time()
def sleep_until_ready(self):
if self.ratelimit_remaining == 0 and time.time() < self.ratelimit_reset:
time.sleep(self.ratelimit_reset - time.time())
delay = max(
0, self.MIN_DELAY + self.backoff - (time.time() - self.last_request)
)
time.sleep(delay + self.backoff)
def authorize(self):
self.sleep_until_ready()
self.last_request = time.time()
res = requests.post(
self.ACCESS_TOKEN_URL,
params={
"grant_type": "refresh_token",
"refresh_token": config["reddit"]["refresh_token"],
},
headers=self.headers,
auth=(config["reddit"]["client_id"], config["reddit"]["client_secret"]),
)
if res.status_code == 200:
logger.info("Auth successful.")
self.backoff = 0
self.token = res.json()["access_token"]
self.token_expiration = time.time() + res.json()["expires_in"]
else:
self.backoff += self.BACKOFF_INCREMENT
logger.error("Auth failed.")
def update_ratelimit(self, headers):
if "x-ratelimit-remaining" in headers:
self.ratelimit_remaining = int(float(headers["x-ratelimit-remaining"]))
self.ratelimit_reset = time.time() + int(
float(headers["x-ratelimit-reset"])
)
def fetch(self, ids):
if self.token is None or self.token_expiration - time.time() < 60:
self.authorize()
return None
self.sleep_until_ready()
self.last_request = time.time()
params = {"id": ",".join(ids)}
logger.debug(f"GET /api/info.json?id={params['id']}")
res = requests.get(
self.INFO_API_URL,
params=params,
headers=self.headers | {"Authorization": f"bearer {self.token}"},
)
if res.status_code == 200:
self.backoff = 0
self.update_ratelimit(res.headers)
logger.debug(f"Response: {res.text}")
items = res.json()["data"]["children"]
return [json.dumps(item) for item in items]
else:
self.update_ratelimit(res.headers)
self.backoff += self.BACKOFF_INCREMENT
logger.error(f"Bad status: {res.status_code}. Backoff: {self.backoff}.")
return None
class FetcherClient:
def __init__(self, client_id, reddit_client):
self.client_id = client_id
self.reddit_client = reddit_client
self.channel = grpc.insecure_channel(
config["distributor_uri"],
options=[
("grpc.max_send_message_length", 100 * 10**6),
("grpc.max_receive_message_length", 1 * 10**6),
],
)
self.stub = id_service_pb2_grpc.IDServiceStub(self.channel)
def request_and_process_batch(self):
# Request a batch of IDs from the server.
response = self.stub.GetBatch(
id_service_pb2.BatchRequest(client_id=self.client_id)
)
# Fetch data from reddit using the received IDs.
data = None
while data is None:
data = self.reddit_client.fetch(response.ids)
# Submit batch to the server.
self.stub.SubmitBatch(
id_service_pb2.SubmitRequest(client_id=self.client_id, data=data)
)
def run(self):
# Continuously request and process batches.
while True:
try:
self.request_and_process_batch()
time.sleep(1)
except grpc.RpcError as e:
print(f"gRPC Error: {e.details()}")
time.sleep(10)
if __name__ == "__main__":
reddit_client = RedditClient()
fetcher_client = FetcherClient(config["id"], reddit_client)
fetcher_client.run()

View File

@ -0,0 +1,6 @@
id: fetcher-1
distributor_uri: localhost:50051
reddit:
client_id: abc
client_secret: abc
refresh_token: 123-abc

32
id_service.proto 100644
View File

@ -0,0 +1,32 @@
syntax = "proto3";
package IDService;
service IDService {
// Fetch a batch of IDs.
rpc GetBatch (BatchRequest) returns (BatchResponse) {}
// Submit processed batch of data.
rpc SubmitBatch (SubmitRequest) returns (SubmitResponse) {}
}
// The BatchRequest message contains the client id.
message BatchRequest {
string client_id = 1;
}
// The BatchResponse message contains the IDs.
message BatchResponse {
repeated string ids = 1;
}
// The SubmitRequest message contains the client id and a batch of data.
message SubmitRequest {
string client_id = 1;
repeated string data = 2;
}
// The SubmitResponse message confirms successful batch processing.
message SubmitResponse {
bool success = 1;
}

16
readme.md 100644
View File

@ -0,0 +1,16 @@
```bash
# ...install PostgreSQL...
pip install -r requirements.txt
# ...modify example yamls...
mv batcher_config.example.yaml batcher_config.yaml
mv fetcher_config.example.yaml fetcher_config.yaml
mkdir logs
bash compile_proto.sh
# Then run both of these files:
python batcher.py
python fetcher.py
```
Getting a refresh token:
https://praw.readthedocs.io/en/stable/tutorials/refresh_token.html#obtaining-refresh-tokens

10
requirements.txt 100644
View File

@ -0,0 +1,10 @@
fake_useragent==1.1.3
grpcio==1.56.0
grpcio-tools==1.56.0
numpy==1.24.2
praw==7.7.0
protobuf==4.23.4
psycopg2_binary==2.9.6
PyYAML==6.0
Requests==2.31.0
schedule==1.2.0