Pick the reply with the median length + switch from shelve to SqliteDict.

master
float-trip 2023-07-10 21:13:47 +00:00
parent e549796935
commit e0c7556344
6 changed files with 38 additions and 32 deletions
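
Not shown in this diff: the data module that bot.py, client.py, model.py, and utils.py now import config and db from. A minimal sketch of what it might contain, assuming it keeps the config.yaml loading from the deleted config.py and backs db with SqliteDict; the filename, autocommit setting, and seeded keys below are guesses, not part of this commit:

import os

import yaml
from sqlitedict import SqliteDict

# Load the YAML config the same way the deleted config.py did.
current_dir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(current_dir, "config.yaml"), "r") as f:
    config = yaml.safe_load(f)

# Hypothetical on-disk store; autocommit=True persists every `db[key] = value` immediately.
db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True)

# Seed the keys the bot reads before it ever writes them (guessed defaults).
if "last_processed_id" not in db:
    db["last_processed_id"] = -1
if "prompts" not in db:
    db["prompts"] = []

One caveat with a plain SqliteDict: in-place mutation of a stored value (e.g. db["prompts"].append(...)) is not written back unless the value is assigned to the key again, so the real data module may wrap writes or re-assign after mutating.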

bot.py

@@ -7,7 +7,7 @@ import model
 import utils
 from client import DramaClient
-from config import config
+from data import config, db

 traceback.install()
@@ -24,7 +24,7 @@ class Bot:
         comments = [
             c
-            for c in self.client.fetch_new_comments(limit=50)
+            for c in self.client.fetch_new_comments()
            if "author_name" in c
            and not c["is_bot"]
            and c["author_name"] != config["username"]
@@ -71,13 +71,13 @@ class Bot:
         if not post or not thread_comments:
             print("Could not fetch context!")
+            return

         prompt = utils.build_prompt(post, thread_comments)
-        utils.log_prompt(prompt)

         candidates = []
-        num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
-        while len(candidates) < num_candidates:
+        rejects = []
+        while len(candidates) < config["num_candidates"]:
             gen_text = self.model.generate(prompt)
             reply = utils.extract_reply(gen_text)
             print(f"Generated text: {gen_text}\nReply:\n{reply}")
@@ -85,14 +85,25 @@ class Bot:
             if len(reply) == 0:
                 print("Retrying: reply empty after processing.")
+                rejects.append(reply)
             elif utils.is_low_quality(reply, post, thread_comments):
                 print("Retrying: low quality reply.")
+                rejects.append(reply)
             else:
                 candidates.append(reply)
                 print("Accepting reply.")

-        # Get the longest reply, but cap the considered length at 500 chars.
-        reply = max(candidates, key=lambda r: min(utils.reply_length(r), 500))
+        reply = utils.median_by_key(candidates, key=utils.reply_length)
+
+        db["prompts"].append(
+            {
+                "prompt": prompt,
+                "candidates": candidates,
+                "rejects": rejects,
+                "selected": reply,
+            }
+        )

         self.client.reply(reply, comment)

client.py

@@ -7,16 +7,13 @@ import shelve
 from requests.adapters import HTTPAdapter, Retry

-from config import config
+from data import config, db


 class DramaClient:
     BASE_URL = "https://rdrama.net"

     def __init__(self):
-        self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
-        self.db.setdefault("processed_replies", set())
         self.session = requests.Session()
         retries = Retry(
             total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
@@ -77,22 +74,21 @@ class DramaClient:
         return r.json()

-    def fetch_new_comments(self, limit=0):
+    def fetch_new_comments(self, limit=config["num_replies"] * 25):
         comments = []

-        last_processed_id = self.db.get("last_processed_id", -1)
         earliest_id = math.inf
         page = 1

         # Fetch comments until we find the last one processed.
-        while earliest_id > last_processed_id:
+        while earliest_id > db["last_processed_id"]:
             page_comments = self.fetch_page(page)

             if len(page_comments) == 0:
                 break

             earliest_id = min([c["id"] for c in page_comments])
-            comments += [c for c in page_comments if c["id"] > last_processed_id]
+            comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]

             if limit > 0 and len(comments) >= limit:
                 break
@@ -102,8 +98,7 @@ class DramaClient:
         if not comments:
             return []

-        self.db["last_processed_id"] = max(c["id"] for c in comments)
-        self.db.sync()
+        db["last_processed_id"] = max(c["id"] for c in comments)

         # New comments may have pushed others to page n+1 while fetching.
         deduped_comments = {c["id"]: c for c in comments}.values()

config.py

@@ -1,8 +0,0 @@
-import yaml
-import os
-
-current_dir = os.path.dirname(os.path.realpath(__file__))
-config_path = os.path.join(current_dir, "config.yaml")
-
-with open(config_path, "r") as f:
-    config = yaml.safe_load(f)

model.py

@@ -7,7 +7,7 @@ from transformers import (
     AutoModelForCausalLM,
     LogitsProcessorList,
 )

-from config import config
+from data import config

 class StopAfterPlusIsGenerated(LogitsProcessor):

requirements.txt

@@ -2,5 +2,6 @@ fuzzywuzzy==0.18.0
 PyYAML==6.0
 Requests==2.31.0
 rich==13.4.2
+sqlitedict==2.1.0
 torch==2.0.1
 transformers==4.31.0

utils.py

@@ -4,7 +4,7 @@ import re
 from fuzzywuzzy import fuzz
 from transformers import GPTNeoXTokenizerFast

-from config import config
+from data import config
 from maxsubstring import longest_common_substring

 URL_REGEX = (
@@ -131,6 +131,7 @@ def build_prompt(post, comments):
         prompt += comment_str

     prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
+    prompt = prompt.replace("👻", "Ghost")
     prompt = prompt.strip() + "\n"

     # Truncate the prompt to leave room for generation.
@@ -142,11 +143,6 @@ def build_prompt(post, comments):
     return prompt


-def log_prompt(prompt):
-    with open(f"{config['data_dir']}/prompts.txt", "a") as f:
-        f.write(f"{prompt}\n==========\n")
-
-
 def reply_length(reply):
     """Return the length of the reply, without Markdown images, URLs, or quoted text."""
     # Remove Markdown images and URLs.
@@ -160,6 +156,17 @@ def reply_length(reply):
     return len(reply)


+def median_by_key(lst, key):
+    lst = sorted(lst, key=key)
+    mid_index = len(lst) // 2
+
+    # For lists of even length, pick either option as the median.
+    if len(lst) % 2 == 0:
+        return random.choice([lst[mid_index - 1], lst[mid_index]])
+    else:
+        return lst[mid_index]
+
+
 def count_tokens(text):
     return len(tokenizer(text).input_ids)
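
For reference, a standalone illustration of the selection change named in the commit title (median_by_key is repeated here so the snippet runs on its own, with len standing in for utils.reply_length):

import random

def median_by_key(lst, key):
    lst = sorted(lst, key=key)
    mid_index = len(lst) // 2
    # For lists of even length, pick either option as the median.
    if len(lst) % 2 == 0:
        return random.choice([lst[mid_index - 1], lst[mid_index]])
    return lst[mid_index]

replies = ["short", "a medium sized reply", "a very very very long reply indeed of many words"]

# Old behavior: favor the longest reply, capping the considered length at 500 chars.
print(max(replies, key=lambda r: min(len(r), 500)))  # picks the longest reply

# New behavior: favor the median-length reply.
print(median_by_key(replies, key=len))  # picks "a medium sized reply"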