68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
import os
|
|
|
|
import requests
|
|
import utils
|
|
from model import Model
|
|
from rich import traceback
|
|
|
|
# Render uncaught exceptions with rich's formatted tracebacks.
traceback.install()
|
|
|
|
|
|
def generate(model, config, context):
    """Generate candidate replies for a comment thread and pick one to post.

    Samples from ``model`` until ``config["num_candidates"]`` replies have
    survived formatting and the quality filter, then selects the body via
    ``utils.median_by_key`` keyed on ``utils.reply_length`` (presumably the
    median-length candidate -- confirm in ``utils``).

    Args:
        model: Object with a ``generate(prompt)`` method; its output is passed
            to ``utils.extract_reply``.
        config: Server-supplied settings. Must contain ``"num_candidates"``;
            it is also passed to ``utils.build_prompt`` and ``utils.format_reply``.
        context: Request context with ``"post"`` and ``"thread"`` keys. The last
            thread entry's ``"id"`` is the comment being replied to.

    Returns:
        A dict with the ``"prompt"`` used, the chosen ``"body"``, and
        ``"candidates"``: the accepted candidates followed by every rejected
        reply.
    """
    thread = context["thread"]

    print(f"Generating reply for https://rdrama.net/comment/{thread[-1]['id']}")

    post = context["post"]
    prompt = utils.build_prompt(config, post, thread)

    candidates = []
    rejects = []

    # NOTE(review): there is no cap on attempts; if the model keeps producing
    # empty or low-quality replies, this loop never terminates.
    while len(candidates) < config["num_candidates"]:
        gen_text = model.generate(prompt)
        reply = utils.extract_reply(gen_text)
        print(f"Generated text: {gen_text}\nReply:\n{reply}")
        reply = utils.format_reply(config, reply)

        if not reply:
            print("Retrying: reply empty after processing.")
            rejects.append(reply)
        elif utils.is_low_quality(reply, post, thread):
            print("Retrying: low quality reply.")
            rejects.append(reply)
        else:
            candidates.append(reply)
            print("Accepting reply.")

    # Prefer URL-free replies: if at least one exists, demote every candidate
    # containing a URL to the rejects. Each candidate is classified only once,
    # and the relative order of both groups is preserved.
    url_free = []
    with_url = []
    for candidate in candidates:
        (with_url if utils.contains_url(candidate) else url_free).append(candidate)
    if url_free:
        rejects.extend(with_url)
        candidates = url_free

    body = utils.median_by_key(candidates, key=utils.reply_length)

    return {"prompt": prompt, "body": body, "candidates": candidates + rejects}
|
|
|
|
|
|
def process_queue(timeout=30):
    """Fetch pending reply requests, generate a reply for each, and post it back.

    The queue endpoint's JSON response must contain ``"config"`` (passed to
    :func:`generate`) and ``"queue"``, a mapping of request id to generation
    context. Each generated result is POSTed to ``/bussyboy/<request_id>``,
    authenticated with the ``SERVER_KEY`` environment variable.

    Args:
        timeout: Seconds to wait on each HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled server.

    Raises:
        RuntimeError: If ``SERVER_KEY`` is unset or empty.
        requests.HTTPError: If fetching the queue returns an error status.
    """
    # Check the key before doing any work. Previously a missing key was sent
    # as the literal string "None", and only after the model had been built.
    server_key = os.getenv("SERVER_KEY")
    if not server_key:
        raise RuntimeError("SERVER_KEY environment variable is not set")

    response = requests.get("https://rdra.ma/bussyboy", timeout=timeout)
    response.raise_for_status()
    payload = response.json()

    config = payload["config"]
    queue = payload["queue"]

    # Nothing to do; skip constructing the model.
    if not queue:
        return

    model = Model()

    for request_id, context in queue.items():
        json_data = generate(model, config, context)

        # `params` URL-encodes the key. `json=` serializes the body and
        # sets the Content-Type: application/json header itself.
        post_response = requests.post(
            f"https://rdra.ma/bussyboy/{request_id}",
            params={"key": server_key},
            json=json_data,
            timeout=timeout,
        )

        print(f"Response for request {request_id}: {post_response.status_code}")
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, process the remote queue once and exit.
    process_queue()
|