135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
import os
|
|
import random
|
|
import sys
|
|
|
|
import requests
|
|
from rich import traceback
|
|
|
|
import model
|
|
import utils
|
|
from client import DramaClient
|
|
from data import config, db
|
|
|
|
# Install rich's traceback handler (rich.traceback.install) at import time so
# uncaught exceptions raised anywhere below are rendered by rich.
traceback.install()
|
|
|
|
|
|
class Bot:
    """Reply bot that generates comment replies with a model and posts them.

    Instance attributes (set in ``__init__``):
        model:  ``model.Model`` instance used to generate text from a prompt.
        client: ``DramaClient`` instance used for every site request.
    """

    def __init__(self):
        """Load the text model and create the site client, logging progress."""
        print("Loading model...")
        self.model = model.Model()
        self.client = DramaClient()
        print("Ready.")

    def _try_reply(self, comment):
        """Call ``reply`` for *comment*, logging request failures.

        A ``requests`` failure on one comment is printed and swallowed so it
        does not abort the rest of a batch.
        """
        try:
            self.reply(comment)
        except requests.exceptions.RequestException as e:
            print(f"Error while replying: {e}")

    def post_random_reply(self):
        """Reply to a random selection of newly fetched comments.

        Comments are kept only if they carry an ``author_name``, are not
        flagged ``is_bot``, were not written by ``config["username"]``, are
        not from an id in ``config["ignore_user_ids"]``, and have a non-zero
        ``post_id``. At most ``config["num_replies"]`` of them, picked at
        random, are replied to.
        """
        print("Looking for comments...")

        comments = [
            c
            for c in self.client.fetch_new_comments()
            if "author_name" in c
            and not c["is_bot"]
            and c["author_name"] != config["username"]
            and c["author_id"] not in config["ignore_user_ids"]
            and c["post_id"] != 0
        ]

        if not comments:
            print("No comments found.")
            return

        random.shuffle(comments)
        for comment in comments[: config["num_replies"]]:
            # Same error handling as the other entry points: a failed request
            # on one comment must not prevent replying to the remaining ones.
            self._try_reply(comment)

    def respond_to_replies(self):
        """Reply to every new reply that has an author and is not from a bot."""
        for reply in self.client.fetch_new_replies():
            # Entries without "author_name" are skipped outright.
            if "author_name" not in reply:
                continue

            if not reply["is_bot"]:
                self._try_reply(reply)

    def make_forced_replies(self):
        """Reply to the comment ids listed in ``<data_dir>/forced.txt``.

        The file holds one comment id per line. It is deleted after all ids
        have been processed; if the file does not exist nothing happens.
        """
        file_path = f"{config['data_dir']}/forced.txt"
        if not os.path.isfile(file_path):
            return

        with open(file_path, "r") as f:
            lines = f.read().splitlines()

        for comment_id in lines:
            comment_id = comment_id.strip()
            # Blank lines (e.g. a trailing newline artifact) would otherwise
            # turn into a request for "/comment/".
            if not comment_id:
                continue

            try:
                # The fetch sits inside the try so that a failed lookup is
                # logged like a failed reply and the queue file still gets
                # removed below.
                comment = self.client.get(f"/comment/{comment_id}")
                self.reply(comment)
            except requests.exceptions.RequestException as e:
                print(f"Error while replying: {e}")

        os.remove(file_path)

    def reply(self, comment):
        """Generate a reply to *comment*, post it, and report the result.

        Args:
            comment: Mapping describing the comment to answer; its ``"id"``
                and ``"body"`` entries are read here, and the whole object is
                passed to ``client.fetch_context`` and ``client.reply``.

        Returns:
            None. Returns early without posting when the context for the
            comment cannot be fetched.

        Raises:
            requests.exceptions.RequestException: presumably propagated from
                the client calls when a site request fails (callers in this
                class catch it) -- TODO confirm against DramaClient.
        """
        print(f"Generating reply for https://rdrama.net/comment/{comment['id']}")

        post, thread_comments = self.client.fetch_context(comment)
        if not post or not thread_comments:
            print("Could not fetch context!")
            return

        prompt = utils.build_prompt(post, thread_comments)

        candidates = []
        rejects = []
        # NOTE(review): there is no retry cap; if the model never produces an
        # acceptable reply this loop does not terminate.
        while len(candidates) < config["num_candidates"]:
            gen_text = self.model.generate(prompt)
            text = utils.extract_reply(gen_text)
            print(f"Generated text: {gen_text}\nReply:\n{text}")
            text = utils.format_reply(text)

            if len(text) == 0:
                print("Retrying: reply empty after processing.")
                rejects.append(text)
            elif utils.is_low_quality(text, post, thread_comments):
                print("Retrying: low quality reply.")
                rejects.append(text)
            else:
                candidates.append(text)
                print("Accepting reply.")

        # Prefer URL-free replies: when at least one candidate has no URL,
        # every candidate containing one is moved to the rejects. When all
        # candidates contain URLs they are all kept.
        if any(not utils.contains_url(c) for c in candidates):
            rejects.extend(c for c in candidates if utils.contains_url(c))
            candidates = [c for c in candidates if not utils.contains_url(c)]

        selected = utils.median_by_key(candidates, key=utils.reply_length)

        response = self.client.reply(selected, comment)
        if "id" not in response:
            print("Error posting reply", response)
            return

        data = {
            "key": config["web_key"],
            "comment_id": response["id"],
            "prompt": prompt,
            "parent_comment": comment["body"],
            "candidates": candidates + rejects,
            "selected": selected,
        }
        try:
            # requests has no default timeout; without one a stalled server
            # would hang the bot indefinitely.
            requests.post("https://rdra.ma/submit", json=data, timeout=30)
        except requests.exceptions.RequestException as e:
            # The reply itself is already posted at this point, so a failed
            # submission is reported here instead of as a reply failure.
            print(f"Error submitting reply data: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Forced replies and pending replies are always
    # handled; random replies are posted unless the first command-line
    # argument is exactly "reply".
    bot = Bot()
    bot.make_forced_replies()
    bot.respond_to_replies()

    # The slice is empty when no argument was given, so this is true only
    # when argv[1] exists and equals "reply".
    reply_only = sys.argv[1:2] == ["reply"]
    if not reply_only:
        bot.post_random_reply()

    # Second pass over incoming replies after the optional random posting.
    bot.respond_to_replies()
|