Pick the reply with the median length + switch from shelve to SqliteDict.

master
float-trip 2023-07-10 21:13:47 +00:00
parent e549796935
commit e0c7556344
6 changed files with 38 additions and 32 deletions

25
bot.py
View File

@ -7,7 +7,7 @@ import model
import utils
from client import DramaClient
from config import config
from data import config, db
traceback.install()
@ -24,7 +24,7 @@ class Bot:
comments = [
c
for c in self.client.fetch_new_comments(limit=50)
for c in self.client.fetch_new_comments()
if "author_name" in c
and not c["is_bot"]
and c["author_name"] != config["username"]
@ -71,13 +71,13 @@ class Bot:
if not post or not thread_comments:
print("Could not fetch context!")
return
prompt = utils.build_prompt(post, thread_comments)
utils.log_prompt(prompt)
candidates = []
num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
while len(candidates) < num_candidates:
rejects = []
while len(candidates) < config["num_candidates"]:
gen_text = self.model.generate(prompt)
reply = utils.extract_reply(gen_text)
print(f"Generated text: {gen_text}\nReply:\n{reply}")
@ -85,14 +85,25 @@ class Bot:
if len(reply) == 0:
print("Retrying: reply empty after processing.")
rejects.append(reply)
elif utils.is_low_quality(reply, post, thread_comments):
print("Retrying: low quality reply.")
rejects
else:
candidates.append(reply)
print("Accepting reply.")
# Get the longest reply, but cap the considered length at 500 chars.
reply = max(candidates, key=lambda r: min(utils.reply_length(r), 500))
reply = utils.median_by_key(candidates, key=utils.reply_length)
db["prompts"].append(
{
"prompt": prompt,
"candidates": candidates,
"rejects": rejects,
"selected": reply,
}
)
self.client.reply(reply, comment)

View File

@ -7,16 +7,13 @@ import shelve
from requests.adapters import HTTPAdapter, Retry
from config import config
from data import config, db
class DramaClient:
BASE_URL = "https://rdrama.net"
def __init__(self):
self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
self.db.setdefault("processed_replies", set())
self.session = requests.Session()
retries = Retry(
total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
@ -77,22 +74,21 @@ class DramaClient:
return r.json()
def fetch_new_comments(self, limit=0):
def fetch_new_comments(self, limit=config["num_replies"] * 25):
comments = []
last_processed_id = self.db.get("last_processed_id", -1)
earliest_id = math.inf
page = 1
# Fetch comments until we find the last one processed.
while earliest_id > last_processed_id:
while earliest_id > db["last_processed_id"]:
page_comments = self.fetch_page(page)
if len(page_comments) == 0:
break
earliest_id = min([c["id"] for c in page_comments])
comments += [c for c in page_comments if c["id"] > last_processed_id]
comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]
if limit > 0 and len(comments) >= limit:
break
@ -102,8 +98,7 @@ class DramaClient:
if not comments:
return []
self.db["last_processed_id"] = max(c["id"] for c in comments)
self.db.sync()
db["last_processed_id"] = max(c["id"] for c in comments)
# New comments may have pushed others to page n+1 while fetching.
deduped_comments = {c["id"]: c for c in comments}.values()

View File

@ -1,8 +0,0 @@
import yaml
import os
current_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(current_dir, "config.yaml")
with open(config_path, "r") as f:
config = yaml.safe_load(f)

View File

@ -7,7 +7,7 @@ from transformers import (
AutoModelForCausalLM,
LogitsProcessorList,
)
from config import config
from data import config
class StopAfterPlusIsGenerated(LogitsProcessor):

View File

@ -2,5 +2,6 @@ fuzzywuzzy==0.18.0
PyYAML==6.0
Requests==2.31.0
rich==13.4.2
sqlitedict==2.1.0
torch==2.0.1
transformers==4.31.0

View File

@ -4,7 +4,7 @@ import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from config import config
from data import config
from maxsubstring import longest_common_substring
URL_REGEX = (
@ -131,6 +131,7 @@ def build_prompt(post, comments):
prompt += comment_str
prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
prompt = prompt.replace("👻", "Ghost")
prompt = prompt.strip() + "\n"
# Truncate the prompt to leave room for generation.
@ -142,11 +143,6 @@ def build_prompt(post, comments):
return prompt
def log_prompt(prompt):
with open(f"{config['data_dir']}/prompts.txt", "a") as f:
f.write(f"{prompt}\n==========\n")
def reply_length(reply):
"""Return the length of the reply, without Markdown images, URLs, or quoted text."""
# Remove Markdown images and URLs.
@ -160,6 +156,17 @@ def reply_length(reply):
return len(reply)
def median_by_key(lst, key):
lst = sorted(lst, key=key)
mid_index = len(lst) // 2
# For lists of even length, pick either option as the median.
if len(lst) % 2 == 0:
return random.choice([lst[mid_index - 1], lst[mid_index]])
else:
return lst[mid_index]
def count_tokens(text):
return len(tokenizer(text).input_ids)