pushshift/fetcher.py

176 lines
5.3 KiB
Python

import json
import logging
import os
import time
from logging.handlers import RotatingFileHandler
import grpc
import requests
import yaml
from fake_useragent import UserAgent
import id_service_pb2
import id_service_pb2_grpc
# Load fetcher_config.yaml from the directory that contains this script.
DIR = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(DIR, "fetcher_config.yaml"), encoding="utf-8") as f:
    # NOTE(review): FullLoader is acceptable for a trusted local config file;
    # switch to yaml.safe_load if this file could ever come from elsewhere.
    config = yaml.load(f, yaml.FullLoader)
log_file = os.path.join(DIR, "logs", "fetcher.log")
# RotatingFileHandler opens the log file immediately and fails if the
# directory is missing, so make sure it exists first.
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Configure logger: everything (DEBUG+) goes to the rotating file,
# INFO+ is echoed to the console.
logger = logging.getLogger("fetcher")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Add rotating file handler: up to 20 backups of ~10 MB each.
rh = RotatingFileHandler(log_file, maxBytes=10 * 10**6, backupCount=20)
rh.setLevel(logging.DEBUG)
rh.setFormatter(formatter)
logger.addHandler(rh)
# Add stream handler.
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)
class RedditClient:
    """Thin Reddit OAuth client used to look up items through /api/info.

    Spaces requests at least ``MIN_DELAY`` seconds apart, honours the
    ``x-ratelimit-*`` headers returned by Reddit, and applies a linear
    backoff that grows by ``BACKOFF_INCREMENT`` after every failed request
    and is cleared after a successful one.
    """

    ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
    INFO_API_URL = "https://oauth.reddit.com/api/info.json"
    # Seconds added to the backoff after each failed request.
    BACKOFF_INCREMENT = 5
    # Minimum spacing, in seconds, between consecutive requests.
    MIN_DELAY = 1
    # Per-request timeout in seconds; without one a stalled connection would
    # block the fetcher forever.
    REQUEST_TIMEOUT = 30
    # Re-authorize when the token has fewer than this many seconds left.
    TOKEN_REFRESH_MARGIN = 60

    def __init__(self):
        # Random User-Agent string supplied by fake_useragent.
        self.headers = {"User-Agent": UserAgent().random}
        self.last_request = time.time()
        self.backoff = 0
        # Placeholder quota; replaced by update_ratelimit() once a response
        # carries rate-limit headers.
        self.ratelimit_remaining = 300
        self.ratelimit_reset = time.time()
        self.token = None
        self.token_expiration = time.time()

    def sleep_until_ready(self):
        """Block until the next request may be sent.

        If the rate-limit quota is exhausted, wait for the reset time first.
        Then wait until at least ``MIN_DELAY + backoff`` seconds have passed
        since the previous request.
        """
        if self.ratelimit_remaining <= 0:
            # Read the clock once so the sleep length cannot turn negative
            # between the check and the sleep (time.sleep rejects negatives).
            reset_wait = self.ratelimit_reset - time.time()
            if reset_wait > 0:
                time.sleep(reset_wait)
        elapsed = time.time() - self.last_request
        # ``delay`` already includes the backoff, so it is not added again.
        delay = max(0, self.MIN_DELAY + self.backoff - elapsed)
        time.sleep(delay)

    def authorize(self):
        """Exchange the configured refresh token for a new access token.

        On success, stores the token and its absolute expiry time and resets
        the backoff. On a network error, a non-200 status, or a response
        without ``access_token``/``expires_in``, increases the backoff and
        leaves the current token untouched.
        """
        self.sleep_until_ready()
        self.last_request = time.time()
        reddit_config = config["reddit"]
        try:
            res = requests.post(
                self.ACCESS_TOKEN_URL,
                # RFC 6749 section 6: refresh parameters go in the form-encoded
                # body, which also keeps the refresh token out of URLs and logs.
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": reddit_config["refresh_token"],
                },
                headers=self.headers,
                auth=(reddit_config["client_id"], reddit_config["client_secret"]),
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as err:
            self.backoff += self.BACKOFF_INCREMENT
            logger.error(f"Auth failed: {err}")
            return

        payload = None
        if res.status_code == 200:
            try:
                payload = res.json()
            except ValueError:
                payload = None
        has_token = (
            isinstance(payload, dict)
            and "access_token" in payload
            and "expires_in" in payload
        )
        if has_token:
            logger.info("Auth successful.")
            self.backoff = 0
            self.token = payload["access_token"]
            self.token_expiration = time.time() + payload["expires_in"]
        else:
            self.backoff += self.BACKOFF_INCREMENT
            logger.error("Auth failed.")

    def update_ratelimit(self, headers):
        """Record the rate-limit state reported in response ``headers``.

        Both ``x-ratelimit-remaining`` and ``x-ratelimit-reset`` must be
        present; otherwise the stored state is left unchanged. Values may be
        sent as floats (e.g. ``"599.0"``) and are truncated to ints.
        """
        remaining = headers.get("x-ratelimit-remaining")
        reset = headers.get("x-ratelimit-reset")
        if remaining is None or reset is None:
            return
        self.ratelimit_remaining = int(float(remaining))
        self.ratelimit_reset = time.time() + int(float(reset))

    def fetch(self, ids):
        """Fetch the given items from Reddit's /api/info endpoint.

        Args:
            ids: Iterable of Reddit item ID strings; they are comma-joined
                into the ``id`` query parameter.

        Returns:
            A list of JSON strings, one per entry in ``data.children`` of the
            response, or None when nothing was fetched this call (token
            refresh, network error, bad status, or malformed body). Callers
            are expected to retry on None.
        """
        if (
            self.token is None
            or self.token_expiration - time.time() < self.TOKEN_REFRESH_MARGIN
        ):
            self.authorize()
            return None

        self.sleep_until_ready()
        self.last_request = time.time()
        params = {"id": ",".join(ids)}
        logger.debug(f"GET /api/info.json?id={params['id']}")
        try:
            res = requests.get(
                self.INFO_API_URL,
                params=params,
                headers=self.headers | {"Authorization": f"bearer {self.token}"},
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as err:
            self.backoff += self.BACKOFF_INCREMENT
            logger.error(f"Request failed: {err}. Backoff: {self.backoff}.")
            return None

        self.update_ratelimit(res.headers)
        if res.status_code == 200:
            logger.debug(f"Response: {res.text}")
            try:
                items = res.json()["data"]["children"]
            except (ValueError, KeyError, TypeError) as err:
                # Malformed body: back off and retry instead of crashing.
                self.backoff += self.BACKOFF_INCREMENT
                logger.error(f"Malformed response: {err!r}. Backoff: {self.backoff}.")
                return None
            self.backoff = 0
            return [json.dumps(item) for item in items]

        if res.status_code == 401:
            # Token was rejected before its recorded expiry; force a re-auth
            # on the next call.
            self.token = None
        self.backoff += self.BACKOFF_INCREMENT
        logger.error(f"Bad status: {res.status_code}. Backoff: {self.backoff}.")
        return None
class FetcherClient:
    """
    Retrieves lists of IDs from the batcher, fetches them from reddit, and submits the JSON back.
    """

    def __init__(self, client_id, reddit_client):
        self.client_id = client_id
        self.reddit_client = reddit_client
        self.channel = grpc.insecure_channel(
            config["batcher_uri"],
            options=[
                # Submitted batches carry the fetched JSON, so allow large
                # sends; responses presumably hold only ID lists.
                ("grpc.max_send_message_length", 100 * 10**6),
                ("grpc.max_receive_message_length", 1 * 10**6),
            ],
        )
        self.stub = id_service_pb2_grpc.IDServiceStub(self.channel)

    def request_and_process_batch(self):
        """Process one batch: get IDs, fetch them from Reddit, submit the JSON.

        The Reddit fetch is retried until it returns data; the Reddit
        client's own delay/backoff paces those retries. gRPC errors propagate
        to the caller.
        """
        # Request a batch of IDs from the server.
        response = self.stub.GetBatch(
            id_service_pb2.BatchRequest(client_id=self.client_id)
        )
        # Fetch data from reddit using the received IDs.
        data = None
        while data is None:
            data = self.reddit_client.fetch(response.ids)
        # Submit batch to the server.
        self.stub.SubmitBatch(
            id_service_pb2.SubmitRequest(client_id=self.client_id, data=data)
        )

    def run(self):
        """Process batches forever.

        Pauses 1 second between batches and 10 seconds after a gRPC error.
        Errors go through the module logger so they reach the log file too.
        """
        while True:
            try:
                self.request_and_process_batch()
                time.sleep(1)
            except grpc.RpcError as e:
                # details() is provided by grpc.Call; not every RpcError
                # implements it, so fall back to the exception itself.
                details = e.details() if isinstance(e, grpc.Call) else e
                logger.error(f"gRPC Error: {details}")
                time.sleep(10)
if __name__ == "__main__":
    # Build the Reddit client first, hand it to the batcher client, and
    # process batches until the process is stopped.
    reddit = RedditClient()
    FetcherClient(config["id"], reddit).run()