176 lines
5.3 KiB
Python
176 lines
5.3 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from logging.handlers import RotatingFileHandler
|
|
|
|
import grpc
|
|
import requests
|
|
import yaml
|
|
from fake_useragent import UserAgent
|
|
|
|
import id_service_pb2
|
|
import id_service_pb2_grpc
|
|
|
|
# Load config.yaml.
DIR = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(DIR, "fetcher_config.yaml")) as f:
    # safe_load is sufficient for a plain-data config file and avoids
    # executing arbitrary YAML tags (yaml.load + FullLoader is broader
    # than needed here).
    config = yaml.safe_load(f)

log_file = os.path.join(DIR, "logs", "fetcher.log")
# Ensure the log directory exists: RotatingFileHandler raises
# FileNotFoundError on a fresh checkout if logs/ is missing.
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Configure logger.
logger = logging.getLogger("fetcher")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Rotating file handler: everything (DEBUG+), ~10 MB per file, 20 backups.
rh = RotatingFileHandler(log_file, maxBytes=10 * 10**6, backupCount=20)
rh.setLevel(logging.DEBUG)
rh.setFormatter(formatter)
logger.addHandler(rh)

# Stream handler: console stays at INFO to keep output readable.
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)
|
|
|
|
|
|
class RedditClient:
    """Rate-limited client for reddit's OAuth info API.

    Tracks reddit's ``x-ratelimit-*`` response headers and applies a linear
    backoff (BACKOFF_INCREMENT seconds per consecutive failure) on top of a
    MIN_DELAY-second spacing between requests.
    """

    ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
    INFO_API_URL = "https://oauth.reddit.com/api/info.json"
    BACKOFF_INCREMENT = 5  # seconds added to the delay per failed request
    MIN_DELAY = 1  # minimum seconds between any two requests

    def __init__(self):
        self.headers = {"User-Agent": UserAgent().random}
        self.last_request = time.time()
        self.backoff = 0

        # Assume a fresh request window until the first response reports
        # the real values via update_ratelimit().
        self.ratelimit_remaining = 300
        self.ratelimit_reset = time.time()

        # No token yet; fetch() triggers authorize() on first use.
        self.token = None
        self.token_expiration = time.time()

    def sleep_until_ready(self):
        """Block until the next request may be issued.

        Honors the server-reported rate-limit window first, then enforces
        MIN_DELAY plus the current backoff since the previous request.
        """
        if self.ratelimit_remaining == 0 and time.time() < self.ratelimit_reset:
            time.sleep(self.ratelimit_reset - time.time())

        delay = max(
            0, self.MIN_DELAY + self.backoff - (time.time() - self.last_request)
        )
        # `delay` already accounts for self.backoff; the previous
        # `time.sleep(delay + self.backoff)` counted the backoff twice.
        time.sleep(delay)

    def authorize(self):
        """Exchange the configured refresh token for a new access token.

        On success resets the backoff and records the token expiration;
        on failure increases the backoff and logs an error.
        """
        self.sleep_until_ready()
        self.last_request = time.time()

        res = requests.post(
            self.ACCESS_TOKEN_URL,
            params={
                "grant_type": "refresh_token",
                "refresh_token": config["reddit"]["refresh_token"],
            },
            headers=self.headers,
            auth=(config["reddit"]["client_id"], config["reddit"]["client_secret"]),
        )

        if res.status_code == 200:
            logger.info("Auth successful.")
            self.backoff = 0
            payload = res.json()  # parse the body once, not twice
            self.token = payload["access_token"]
            self.token_expiration = time.time() + payload["expires_in"]
        else:
            self.backoff += self.BACKOFF_INCREMENT
            logger.error("Auth failed.")

    def update_ratelimit(self, headers):
        """Record reddit's rate-limit state from response headers, if present."""
        if "x-ratelimit-remaining" in headers:
            # Header values arrive as decimal strings (e.g. "299.0").
            self.ratelimit_remaining = int(float(headers["x-ratelimit-remaining"]))
            self.ratelimit_reset = time.time() + int(
                float(headers["x-ratelimit-reset"])
            )

    def fetch(self, ids):
        """Fetch the given fullname `ids` from /api/info.json.

        Returns a list of JSON-encoded item strings on success, or None
        when the request failed or a (re)authorization was performed —
        callers are expected to retry on None.
        """
        # Refresh the token when missing or within 60 s of expiring.
        if self.token is None or self.token_expiration - time.time() < 60:
            self.authorize()
            return None

        self.sleep_until_ready()
        self.last_request = time.time()

        params = {"id": ",".join(ids)}
        logger.debug(f"GET /api/info.json?id={params['id']}")
        res = requests.get(
            self.INFO_API_URL,
            params=params,
            headers=self.headers | {"Authorization": f"bearer {self.token}"},
        )

        if res.status_code == 200:
            self.backoff = 0
            self.update_ratelimit(res.headers)
            logger.debug(f"Response: {res.text}")
            items = res.json()["data"]["children"]
            return [json.dumps(item) for item in items]
        else:
            # Still record rate-limit headers on errors (429 etc. carry them).
            self.update_ratelimit(res.headers)
            self.backoff += self.BACKOFF_INCREMENT
            logger.error(f"Bad status: {res.status_code}. Backoff: {self.backoff}.")
            return None
|
|
|
|
|
|
class FetcherClient:
    """
    Retrieves lists of IDs from the batcher, fetches them from reddit, and submits the JSON back.
    """

    def __init__(self, client_id, reddit_client):
        # client_id identifies this fetcher to the batcher service.
        self.client_id = client_id
        self.reddit_client = reddit_client
        # NOTE(review): send limit is 100 MB but receive is only 1 MB —
        # looks asymmetric; confirm against batcher message sizes.
        self.channel = grpc.insecure_channel(
            config["batcher_uri"],
            options=[
                ("grpc.max_send_message_length", 100 * 10**6),
                ("grpc.max_receive_message_length", 1 * 10**6),
            ],
        )
        self.stub = id_service_pb2_grpc.IDServiceStub(self.channel)

    def request_and_process_batch(self):
        """Pull one batch of IDs, fetch them from reddit, submit the results."""
        # Request a batch of IDs from the server.
        response = self.stub.GetBatch(
            id_service_pb2.BatchRequest(client_id=self.client_id)
        )

        # Fetch data from reddit using the received IDs. fetch() returns
        # None on auth refresh or request failure, so retry until it
        # succeeds (its internal backoff paces the retries).
        data = None
        while data is None:
            data = self.reddit_client.fetch(response.ids)

        # Submit batch to the server.
        self.stub.SubmitBatch(
            id_service_pb2.SubmitRequest(client_id=self.client_id, data=data)
        )

    def run(self):
        """Continuously request and process batches; never returns."""
        while True:
            try:
                self.request_and_process_batch()
                time.sleep(1)
            except grpc.RpcError as e:
                # Log through the configured logger (was a bare print, so
                # gRPC errors never reached the rotating log file).
                logger.error(f"gRPC Error: {e.details()}")
                time.sleep(10)
|
|
|
|
|
|
if __name__ == "__main__":
    # Wire a reddit client into the fetcher loop and run until killed.
    fetcher_client = FetcherClient(config["id"], RedditClient())
    fetcher_client.run()
|