Initial commit.
commit 04ea9c9e5e

@@ -0,0 +1,220 @@ batcher.py
import atexit
import os
import pickle
import threading
import time
from collections import deque
from concurrent import futures
from itertools import zip_longest

import grpc
import numpy as np
import psycopg2
import requests
import schedule
import yaml
from fake_useragent import UserAgent

import id_service_pb2
import id_service_pb2_grpc

DIR = os.path.dirname(os.path.realpath(__file__))
BATCHER_STATE_FILE = os.path.join(DIR, "batcher_state.pkl")
SERVICE_STATE_FILE = os.path.join(DIR, "id_service_state.pkl")

with open(os.path.join(DIR, "batcher_config.yaml")) as f:
    config = yaml.load(f, yaml.FullLoader)


class IDBatcher:
    def __init__(self):
        self.lock = threading.Lock()
        self.ids_queue = deque()
        self.ua = UserAgent()

        # Load state from file, if it exists.
        if os.path.exists(BATCHER_STATE_FILE):
            with open(BATCHER_STATE_FILE, "rb") as f:
                self.last_post_id, self.last_comment_id, self.ids_queue = pickle.load(f)
        else:
            self.last_post_id = self.fetch_newest_id("/r/all/new.json")
            self.last_comment_id = self.fetch_newest_id("/r/all/comments.json")
            self.save_state()

        self.fetch_new_ids()

        schedule.every(5).seconds.do(self.fetch_new_ids)

        # Run the scheduler in a daemon thread so the process can exit and
        # the atexit handlers registered in serve() can run.
        thread = threading.Thread(target=self.start_scheduler, daemon=True)
        thread.start()

    def start_scheduler(self):
        while True:
            schedule.run_pending()
            time.sleep(1)

    def fetch_newest_id(self, endpoint):
        # "name" is the fullname of the newest listing item, e.g. "t3_abc123".
        reddit_id = requests.get(
            f"https://www.reddit.com{endpoint}",
            headers={"User-Agent": self.ua.random},
        ).json()["data"]["children"][0]["data"]["name"]

        return self.reddit_id_to_int(reddit_id)

    def fetch_new_ids(self):
        with self.lock:
            new_post_id = self.fetch_newest_id("/r/all/new.json")
            new_comment_id = self.fetch_newest_id("/r/all/comments.json")

            new_post_ids = [
                self.int_to_reddit_id("t3", i)
                for i in range(self.last_post_id + 1, new_post_id + 1)
            ]
            new_comment_ids = [
                self.int_to_reddit_id("t1", i)
                for i in range(self.last_comment_id + 1, new_comment_id + 1)
            ]

            self.ids_queue.extend(self.interleave(new_post_ids, new_comment_ids))
            self.last_post_id = new_post_id
            self.last_comment_id = new_comment_id

            print("Queue size:", len(self.ids_queue))
            self.save_state()

    def get_batch(self):
        with self.lock:
            batch = [
                self.ids_queue.popleft() for _ in range(min(100, len(self.ids_queue)))
            ]
            return batch

    def save_state(self):
        with open(BATCHER_STATE_FILE, "wb") as f:
            pickle.dump((self.last_post_id, self.last_comment_id, self.ids_queue), f)

    @staticmethod
    def reddit_id_to_int(reddit_id):
        # Fullnames look like "t3_abc123": a type prefix plus a base36 ID.
        _prefix, base36 = reddit_id.split("_")
        return int(base36, 36)

    @staticmethod
    def int_to_reddit_id(id_type, i):
        id_b36 = np.base_repr(i, 36)
        return f"{id_type}_{id_b36}".lower()

    @staticmethod
    def interleave(list1, list2):
        # Alternate items from both lists; zip_longest pads the shorter
        # list with None, which is filtered back out here.
        return [
            item
            for pair in zip_longest(list1, list2)
            for item in pair
            if item is not None
        ]


class IDService(id_service_pb2_grpc.IDServiceServicer):
    def __init__(self, batcher):
        self.batcher = batcher
        self.active_batches = {}
        self.timers = {}
        self.lock = threading.Lock()

        self.conn = psycopg2.connect(
            f"dbname={config['db']['name']} "
            f"user={config['db']['user']} "
            f"password={config['db']['password']}"
        )
        self.cur = self.conn.cursor()
        self.cur.execute(
            """
            CREATE TABLE IF NOT EXISTS reddit_data(
                id SERIAL PRIMARY KEY,
                data JSONB NOT NULL
            )
            """
        )
        self.conn.commit()

        if os.path.exists(SERVICE_STATE_FILE):
            with open(SERVICE_STATE_FILE, "rb") as f:
                self.active_batches = pickle.load(f)

    def GetBatch(self, request, context):
        client_id = request.client_id

        with self.lock:
            # Check if there is an unconfirmed batch for the client.
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
            else:
                batch = self.batcher.get_batch()
                self.active_batches[client_id] = batch

            # Cancel any existing timer for this client.
            if client_id in self.timers:
                self.timers[client_id].cancel()

            # Start a new timer for the client.
            self.timers[client_id] = threading.Timer(
                120, self.return_batch, [client_id]
            )
            self.timers[client_id].start()

        return id_service_pb2.BatchResponse(ids=batch)

    def SubmitBatch(self, request, context):
        client_id = request.client_id

        with self.lock:
            if client_id in self.active_batches:
                # Process the submitted data.
                self.write_json_to_postgres(request.data)

                # Remove batch from active batches for the client.
                del self.active_batches[client_id]

                # Cancel the timer for this client.
                if client_id in self.timers:
                    self.timers[client_id].cancel()
                    del self.timers[client_id]

        return id_service_pb2.SubmitResponse(success=True)

    def return_batch(self, client_id):
        # Lease expired: put the client's unconfirmed batch back at the
        # front of the queue so the IDs are handed out again.
        with self.lock:
            if client_id in self.active_batches:
                batch = self.active_batches[client_id]
                with self.batcher.lock:
                    # extendleft reverses its argument, so reverse the batch
                    # first to preserve its original order.
                    self.batcher.ids_queue.extendleft(reversed(batch))
                del self.active_batches[client_id]

    def save_state(self):
        with open(SERVICE_STATE_FILE, "wb") as f:
            pickle.dump(self.active_batches, f)

    def write_json_to_postgres(self, json_data):
        json_tuples = [(json_str,) for json_str in json_data]
        self.cur.executemany(
            "INSERT INTO reddit_data (data) VALUES (%s)",
            json_tuples,
        )
        self.conn.commit()


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    id_batcher = IDBatcher()
    id_service = IDService(id_batcher)
    id_service_pb2_grpc.add_IDServiceServicer_to_server(id_service, server)

    # Save state when the service is interrupted.
    atexit.register(id_service.save_state)
    atexit.register(id_batcher.save_state)

    server.add_insecure_port("[::]:50051")
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()

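As a quick sanity check of the base36 helpers above, a fullname round-trips through the same two conversions the static methods perform (a minimal sketch; the example values are illustrative):

```python
import numpy as np

# int(..., 36) and np.base_repr(..., 36) are inverses of each other.
print(int("abc123", 36))                            # 623698779
print(f"t3_{np.base_repr(623698779, 36)}".lower())  # t3_abc123
```
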
@@ -0,0 +1,4 @@ batcher_config.example.yaml
db:
  user: postgres
  password: ""
  name: postgres

@@ -0,0 +1,5 @@ compile_proto.sh
#!/usr/bin/env bash

full_path=$(realpath "$0")
dir=$(dirname "$full_path")
python -m grpc_tools.protoc --python_out="$dir" --grpc_python_out="$dir" "$dir"/id_service.proto -I="$dir"

@@ -0,0 +1,171 @@ fetcher.py
import json
import logging
import os
import time
from logging.handlers import RotatingFileHandler

import grpc
import requests
import yaml
from fake_useragent import UserAgent

import id_service_pb2
import id_service_pb2_grpc

# Load fetcher_config.yaml.
DIR = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(DIR, "fetcher_config.yaml")) as f:
    config = yaml.load(f, yaml.FullLoader)

log_file = os.path.join(DIR, "logs", "fetcher.log")

# Configure logger.
logger = logging.getLogger("fetcher")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Add rotating file handler.
rh = RotatingFileHandler(log_file, maxBytes=10 * 10**6, backupCount=20)
rh.setLevel(logging.DEBUG)
rh.setFormatter(formatter)
logger.addHandler(rh)

# Add stream handler.
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)


class RedditClient:
    ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
    INFO_API_URL = "https://oauth.reddit.com/api/info.json"
    BACKOFF_INCREMENT = 5
    MIN_DELAY = 1

    def __init__(self):
        self.headers = {"User-Agent": UserAgent().random}
        self.last_request = time.time()
        self.backoff = 0

        self.ratelimit_remaining = 300
        self.ratelimit_reset = time.time()

        self.token = None
        self.token_expiration = time.time()

    def sleep_until_ready(self):
        # Block until the rate limit window resets.
        if self.ratelimit_remaining == 0 and time.time() < self.ratelimit_reset:
            time.sleep(self.ratelimit_reset - time.time())

        # Enforce the minimum delay plus any accumulated backoff, counting
        # time already spent since the last request. The backoff is part of
        # the delay, so it must not be added to the sleep a second time.
        delay = max(
            0, self.MIN_DELAY + self.backoff - (time.time() - self.last_request)
        )
        time.sleep(delay)

    def authorize(self):
        self.sleep_until_ready()
        self.last_request = time.time()

        # The token endpoint expects the grant in the POST body, not in the
        # query string.
        res = requests.post(
            self.ACCESS_TOKEN_URL,
            data={
                "grant_type": "refresh_token",
                "refresh_token": config["reddit"]["refresh_token"],
            },
            headers=self.headers,
            auth=(config["reddit"]["client_id"], config["reddit"]["client_secret"]),
        )

        if res.status_code == 200:
            logger.info("Auth successful.")
            self.backoff = 0
            self.token = res.json()["access_token"]
            self.token_expiration = time.time() + res.json()["expires_in"]
        else:
            self.backoff += self.BACKOFF_INCREMENT
            logger.error("Auth failed.")

    def update_ratelimit(self, headers):
        if "x-ratelimit-remaining" in headers:
            self.ratelimit_remaining = int(float(headers["x-ratelimit-remaining"]))
            self.ratelimit_reset = time.time() + int(
                float(headers["x-ratelimit-reset"])
            )

    def fetch(self, ids):
        # Refresh the token if it is missing or about to expire, and return
        # None so the caller retries the fetch.
        if self.token is None or self.token_expiration - time.time() < 60:
            self.authorize()
            return None

        self.sleep_until_ready()
        self.last_request = time.time()

        params = {"id": ",".join(ids)}
        logger.debug(f"GET /api/info.json?id={params['id']}")
        res = requests.get(
            self.INFO_API_URL,
            params=params,
            headers=self.headers | {"Authorization": f"bearer {self.token}"},
        )

        if res.status_code == 200:
            self.backoff = 0
            self.update_ratelimit(res.headers)
            logger.debug(f"Response: {res.text}")
            items = res.json()["data"]["children"]
            return [json.dumps(item) for item in items]
        else:
            self.update_ratelimit(res.headers)
            self.backoff += self.BACKOFF_INCREMENT
            logger.error(f"Bad status: {res.status_code}. Backoff: {self.backoff}.")
            return None


class FetcherClient:
    def __init__(self, client_id, reddit_client):
        self.client_id = client_id
        self.reddit_client = reddit_client
        self.channel = grpc.insecure_channel(
            config["distributor_uri"],
            options=[
                ("grpc.max_send_message_length", 100 * 10**6),
                ("grpc.max_receive_message_length", 1 * 10**6),
            ],
        )
        self.stub = id_service_pb2_grpc.IDServiceStub(self.channel)

    def request_and_process_batch(self):
        # Request a batch of IDs from the server.
        response = self.stub.GetBatch(
            id_service_pb2.BatchRequest(client_id=self.client_id)
        )

        # Fetch data from reddit using the received IDs.
        data = None
        while data is None:
            data = self.reddit_client.fetch(response.ids)

        # Submit batch to the server.
        self.stub.SubmitBatch(
            id_service_pb2.SubmitRequest(client_id=self.client_id, data=data)
        )

    def run(self):
        # Continuously request and process batches.
        while True:
            try:
                self.request_and_process_batch()
                time.sleep(1)
            except grpc.RpcError as e:
                logger.error(f"gRPC Error: {e.details()}")
                time.sleep(10)


if __name__ == "__main__":
    reddit_client = RedditClient()
    fetcher_client = FetcherClient(config["id"], reddit_client)
    fetcher_client.run()

@@ -0,0 +1,6 @@ fetcher_config.example.yaml
id: fetcher-1
distributor_uri: localhost:50051
reddit:
  client_id: abc
  client_secret: abc
  refresh_token: 123-abc

@@ -0,0 +1,32 @@ id_service.proto
syntax = "proto3";

package IDService;

service IDService {
  // Fetch a batch of IDs.
  rpc GetBatch (BatchRequest) returns (BatchResponse) {}

  // Submit a processed batch of data.
  rpc SubmitBatch (SubmitRequest) returns (SubmitResponse) {}
}

// The BatchRequest message contains the client ID.
message BatchRequest {
  string client_id = 1;
}

// The BatchResponse message contains the IDs.
message BatchResponse {
  repeated string ids = 1;
}

// The SubmitRequest message contains the client ID and a batch of data.
message SubmitRequest {
  string client_id = 1;
  repeated string data = 2;
}

// The SubmitResponse message confirms successful batch processing.
message SubmitResponse {
  bool success = 1;
}

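For a quick end-to-end check of this service definition, a minimal sketch using the stubs generated by compile_proto.sh (it assumes batcher.py is running on localhost:50051; the client_id "smoke-test" is an arbitrary example):

```python
import grpc

import id_service_pb2
import id_service_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")
stub = id_service_pb2_grpc.IDServiceStub(channel)

# Request a batch; without a matching SubmitBatch, the server's 120 s
# timer eventually returns these IDs to the queue.
response = stub.GetBatch(id_service_pb2.BatchRequest(client_id="smoke-test"))
print(len(response.ids), list(response.ids[:5]))
```
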
@@ -0,0 +1,16 @@ README.md
```bash
# ...install PostgreSQL...
pip install -r requirements.txt
# ...modify the example yamls...
mv batcher_config.example.yaml batcher_config.yaml
mv fetcher_config.example.yaml fetcher_config.yaml
mkdir logs
bash compile_proto.sh

# Then run both of these files:
python batcher.py
python fetcher.py
```

Getting a refresh token:
https://praw.readthedocs.io/en/stable/tutorials/refresh_token.html#obtaining-refresh-tokens

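A minimal sketch of that flow, following the PRAW tutorial above (the redirect_uri must match the one configured for your reddit app; the credentials, user_agent, and scopes here are illustrative):

```python
import praw

reddit = praw.Reddit(
    client_id="abc",
    client_secret="abc",
    redirect_uri="http://localhost:8080",
    user_agent="refresh-token-setup",
)

# Open the printed URL in a browser, approve the app, then copy the
# "code" parameter out of the redirect URL.
print(reddit.auth.url(scopes=["read"], state="setup", duration="permanent"))

code = input("code: ")
print("refresh_token:", reddit.auth.authorize(code))
```
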
@@ -0,0 +1,10 @@ requirements.txt
fake_useragent==1.1.3
grpcio==1.56.0
grpcio-tools==1.56.0
numpy==1.24.2
praw==7.7.0
protobuf==4.23.4
psycopg2_binary==2.9.6
PyYAML==6.0
Requests==2.31.0
schedule==1.2.0