Pick the reply with the median length + switch from shelve to SqliteDict.
parent
e549796935
commit
e0c7556344
25
bot.py
25
bot.py
|
@ -7,7 +7,7 @@ import model
|
||||||
import utils
|
import utils
|
||||||
from client import DramaClient
|
from client import DramaClient
|
||||||
|
|
||||||
from config import config
|
from data import config, db
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ class Bot:
|
||||||
|
|
||||||
comments = [
|
comments = [
|
||||||
c
|
c
|
||||||
for c in self.client.fetch_new_comments(limit=50)
|
for c in self.client.fetch_new_comments()
|
||||||
if "author_name" in c
|
if "author_name" in c
|
||||||
and not c["is_bot"]
|
and not c["is_bot"]
|
||||||
and c["author_name"] != config["username"]
|
and c["author_name"] != config["username"]
|
||||||
|
@ -71,13 +71,13 @@ class Bot:
|
||||||
|
|
||||||
if not post or not thread_comments:
|
if not post or not thread_comments:
|
||||||
print("Could not fetch context!")
|
print("Could not fetch context!")
|
||||||
|
return
|
||||||
|
|
||||||
prompt = utils.build_prompt(post, thread_comments)
|
prompt = utils.build_prompt(post, thread_comments)
|
||||||
utils.log_prompt(prompt)
|
|
||||||
|
|
||||||
candidates = []
|
candidates = []
|
||||||
num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
|
rejects = []
|
||||||
while len(candidates) < num_candidates:
|
while len(candidates) < config["num_candidates"]:
|
||||||
gen_text = self.model.generate(prompt)
|
gen_text = self.model.generate(prompt)
|
||||||
reply = utils.extract_reply(gen_text)
|
reply = utils.extract_reply(gen_text)
|
||||||
print(f"Generated text: {gen_text}\nReply:\n{reply}")
|
print(f"Generated text: {gen_text}\nReply:\n{reply}")
|
||||||
|
@ -85,14 +85,25 @@ class Bot:
|
||||||
|
|
||||||
if len(reply) == 0:
|
if len(reply) == 0:
|
||||||
print("Retrying: reply empty after processing.")
|
print("Retrying: reply empty after processing.")
|
||||||
|
rejects.append(reply)
|
||||||
elif utils.is_low_quality(reply, post, thread_comments):
|
elif utils.is_low_quality(reply, post, thread_comments):
|
||||||
print("Retrying: low quality reply.")
|
print("Retrying: low quality reply.")
|
||||||
|
rejects
|
||||||
else:
|
else:
|
||||||
candidates.append(reply)
|
candidates.append(reply)
|
||||||
print("Accepting reply.")
|
print("Accepting reply.")
|
||||||
|
|
||||||
# Get the longest reply, but cap the considered length at 500 chars.
|
reply = utils.median_by_key(candidates, key=utils.reply_length)
|
||||||
reply = max(candidates, key=lambda r: min(utils.reply_length(r), 500))
|
|
||||||
|
db["prompts"].append(
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"candidates": candidates,
|
||||||
|
"rejects": rejects,
|
||||||
|
"selected": reply,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.client.reply(reply, comment)
|
self.client.reply(reply, comment)
|
||||||
|
|
||||||
|
|
||||||
|
|
15
client.py
15
client.py
|
@ -7,16 +7,13 @@ import shelve
|
||||||
|
|
||||||
from requests.adapters import HTTPAdapter, Retry
|
from requests.adapters import HTTPAdapter, Retry
|
||||||
|
|
||||||
from config import config
|
from data import config, db
|
||||||
|
|
||||||
|
|
||||||
class DramaClient:
|
class DramaClient:
|
||||||
BASE_URL = "https://rdrama.net"
|
BASE_URL = "https://rdrama.net"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
|
|
||||||
self.db.setdefault("processed_replies", set())
|
|
||||||
|
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
retries = Retry(
|
retries = Retry(
|
||||||
total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
|
total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
|
||||||
|
@ -77,22 +74,21 @@ class DramaClient:
|
||||||
|
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
def fetch_new_comments(self, limit=0):
|
def fetch_new_comments(self, limit=config["num_replies"] * 25):
|
||||||
comments = []
|
comments = []
|
||||||
|
|
||||||
last_processed_id = self.db.get("last_processed_id", -1)
|
|
||||||
earliest_id = math.inf
|
earliest_id = math.inf
|
||||||
page = 1
|
page = 1
|
||||||
|
|
||||||
# Fetch comments until we find the last one processed.
|
# Fetch comments until we find the last one processed.
|
||||||
while earliest_id > last_processed_id:
|
while earliest_id > db["last_processed_id"]:
|
||||||
page_comments = self.fetch_page(page)
|
page_comments = self.fetch_page(page)
|
||||||
|
|
||||||
if len(page_comments) == 0:
|
if len(page_comments) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
earliest_id = min([c["id"] for c in page_comments])
|
earliest_id = min([c["id"] for c in page_comments])
|
||||||
comments += [c for c in page_comments if c["id"] > last_processed_id]
|
comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]
|
||||||
|
|
||||||
if limit > 0 and len(comments) >= limit:
|
if limit > 0 and len(comments) >= limit:
|
||||||
break
|
break
|
||||||
|
@ -102,8 +98,7 @@ class DramaClient:
|
||||||
if not comments:
|
if not comments:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
self.db["last_processed_id"] = max(c["id"] for c in comments)
|
db["last_processed_id"] = max(c["id"] for c in comments)
|
||||||
self.db.sync()
|
|
||||||
|
|
||||||
# New comments may have pushed others to page n+1 while fetching.
|
# New comments may have pushed others to page n+1 while fetching.
|
||||||
deduped_comments = {c["id"]: c for c in comments}.values()
|
deduped_comments = {c["id"]: c for c in comments}.values()
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
import yaml
|
|
||||||
import os
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
config_path = os.path.join(current_dir, "config.yaml")
|
|
||||||
|
|
||||||
with open(config_path, "r") as f:
|
|
||||||
config = yaml.safe_load(f)
|
|
2
model.py
2
model.py
|
@ -7,7 +7,7 @@ from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
)
|
)
|
||||||
from config import config
|
from data import config
|
||||||
|
|
||||||
|
|
||||||
class StopAfterPlusIsGenerated(LogitsProcessor):
|
class StopAfterPlusIsGenerated(LogitsProcessor):
|
||||||
|
|
|
@ -2,5 +2,6 @@ fuzzywuzzy==0.18.0
|
||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
Requests==2.31.0
|
Requests==2.31.0
|
||||||
rich==13.4.2
|
rich==13.4.2
|
||||||
|
sqlitedict==2.1.0
|
||||||
torch==2.0.1
|
torch==2.0.1
|
||||||
transformers==4.31.0
|
transformers==4.31.0
|
||||||
|
|
19
utils.py
19
utils.py
|
@ -4,7 +4,7 @@ import re
|
||||||
from fuzzywuzzy import fuzz
|
from fuzzywuzzy import fuzz
|
||||||
from transformers import GPTNeoXTokenizerFast
|
from transformers import GPTNeoXTokenizerFast
|
||||||
|
|
||||||
from config import config
|
from data import config
|
||||||
from maxsubstring import longest_common_substring
|
from maxsubstring import longest_common_substring
|
||||||
|
|
||||||
URL_REGEX = (
|
URL_REGEX = (
|
||||||
|
@ -131,6 +131,7 @@ def build_prompt(post, comments):
|
||||||
prompt += comment_str
|
prompt += comment_str
|
||||||
|
|
||||||
prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
|
prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
|
||||||
|
prompt = prompt.replace("👻", "Ghost")
|
||||||
prompt = prompt.strip() + "\n"
|
prompt = prompt.strip() + "\n"
|
||||||
|
|
||||||
# Truncate the prompt to leave room for generation.
|
# Truncate the prompt to leave room for generation.
|
||||||
|
@ -142,11 +143,6 @@ def build_prompt(post, comments):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def log_prompt(prompt):
|
|
||||||
with open(f"{config['data_dir']}/prompts.txt", "a") as f:
|
|
||||||
f.write(f"{prompt}\n==========\n")
|
|
||||||
|
|
||||||
|
|
||||||
def reply_length(reply):
|
def reply_length(reply):
|
||||||
"""Return the length of the reply, without Markdown images, URLs, or quoted text."""
|
"""Return the length of the reply, without Markdown images, URLs, or quoted text."""
|
||||||
# Remove Markdown images and URLs.
|
# Remove Markdown images and URLs.
|
||||||
|
@ -160,6 +156,17 @@ def reply_length(reply):
|
||||||
return len(reply)
|
return len(reply)
|
||||||
|
|
||||||
|
|
||||||
|
def median_by_key(lst, key):
|
||||||
|
lst = sorted(lst, key=key)
|
||||||
|
mid_index = len(lst) // 2
|
||||||
|
|
||||||
|
# For lists of even length, pick either option as the median.
|
||||||
|
if len(lst) % 2 == 0:
|
||||||
|
return random.choice([lst[mid_index - 1], lst[mid_index]])
|
||||||
|
else:
|
||||||
|
return lst[mid_index]
|
||||||
|
|
||||||
|
|
||||||
def count_tokens(text):
|
def count_tokens(text):
|
||||||
return len(tokenizer(text).input_ids)
|
return len(tokenizer(text).input_ids)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue