diff --git a/bot.py b/bot.py
index 4257cca..e5b00b6 100644
--- a/bot.py
+++ b/bot.py
@@ -7,7 +7,7 @@
 import model
 import utils
 from client import DramaClient
-from config import config
+from data import config, db
 
 traceback.install()
 
@@ -24,7 +24,7 @@ class Bot:
 
         comments = [
             c
-            for c in self.client.fetch_new_comments(limit=50)
+            for c in self.client.fetch_new_comments()
             if "author_name" in c
             and not c["is_bot"]
             and c["author_name"] != config["username"]
@@ -71,13 +71,13 @@ class Bot:
 
         if not post or not thread_comments:
             print("Could not fetch context!")
+            return
 
         prompt = utils.build_prompt(post, thread_comments)
-        utils.log_prompt(prompt)
 
         candidates = []
-        num_candidates = config["num_candidates"] if random.random() < 0.6 else 1
-        while len(candidates) < num_candidates:
+        rejects = []
+        while len(candidates) < config["num_candidates"]:
             gen_text = self.model.generate(prompt)
             reply = utils.extract_reply(gen_text)
             print(f"Generated text: {gen_text}\nReply:\n{reply}")
@@ -85,14 +85,25 @@ class Bot:
 
             if len(reply) == 0:
                 print("Retrying: reply empty after processing.")
+                rejects.append(reply)
             elif utils.is_low_quality(reply, post, thread_comments):
                 print("Retrying: low quality reply.")
+                rejects.append(reply)
             else:
                 candidates.append(reply)
                 print("Accepting reply.")
 
-        # Get the longest reply, but cap the considered length at 500 chars.
-        reply = max(candidates, key=lambda r: min(utils.reply_length(r), 500))
+        reply = utils.median_by_key(candidates, key=utils.reply_length)
+
+        db["prompts"] = db["prompts"] + [
+            {
+                "prompt": prompt,
+                "candidates": candidates,
+                "rejects": rejects,
+                "selected": reply,
+            }
+        ]
+
         self.client.reply(reply, comment)
 
diff --git a/client.py b/client.py
index 0409ecc..a5643ae 100644
--- a/client.py
+++ b/client.py
@@ -7,16 +7,13 @@ import shelve
 from requests.adapters import HTTPAdapter, Retry
 
-from config import config
+from data import config, db
 
 
 class DramaClient:
     BASE_URL = "https://rdrama.net"
 
     def __init__(self):
-        self.db = shelve.open(f"{config['data_dir']}/client_state.p", writeback=True)
-        self.db.setdefault("processed_replies", set())
-
         self.session = requests.Session()
         retries = Retry(
             total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504, 521]
         )
@@ -77,22 +74,21 @@ class DramaClient:
 
         return r.json()
 
-    def fetch_new_comments(self, limit=0):
+    def fetch_new_comments(self, limit=config["num_replies"] * 25):
         comments = []
 
-        last_processed_id = self.db.get("last_processed_id", -1)
         earliest_id = math.inf
         page = 1
 
         # Fetch comments until we find the last one processed.
-        while earliest_id > last_processed_id:
+        while earliest_id > db["last_processed_id"]:
             page_comments = self.fetch_page(page)
 
             if len(page_comments) == 0:
                 break
 
             earliest_id = min([c["id"] for c in page_comments])
-            comments += [c for c in page_comments if c["id"] > last_processed_id]
+            comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]
 
             if limit > 0 and len(comments) >= limit:
                 break
@@ -102,8 +98,7 @@ class DramaClient:
         if not comments:
             return []
 
-        self.db["last_processed_id"] = max(c["id"] for c in comments)
-        self.db.sync()
+        db["last_processed_id"] = max(c["id"] for c in comments)
 
         # New comments may have pushed others to page n+1 while fetching.
         deduped_comments = {c["id"]: c for c in comments}.values()
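Note: the changed files import `config` and `db` from a new `data` module, but no `data.py` hunk appears in this diff. Below is a minimal sketch of what it plausibly contains, reconstructed from the deleted `config.py`, the new `sqlitedict` dependency, and the keys the callers read without a fallback (`last_processed_id`, `prompts`); the `db.sqlite` filename is a guess:

```python
# data.py -- hypothetical reconstruction, not part of this diff.
import os

import yaml
from sqlitedict import SqliteDict

current_dir = os.path.dirname(os.path.realpath(__file__))

# Same YAML load the deleted config.py performed.
with open(os.path.join(current_dir, "config.yaml"), "r") as f:
    config = yaml.safe_load(f)

# Shared persistent state, replacing the shelve file DramaClient used to open.
# autocommit=True writes every assignment straight to disk. In-place mutation
# of a stored value (e.g. db["prompts"].append(...)) would NOT persist, which
# is why callers reassign the key instead.
db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True)

# Defaults assumed by bot.py and client.py.
if "last_processed_id" not in db:
    db["last_processed_id"] = -1
if "prompts" not in db:
    db["prompts"] = []
```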
diff --git a/config.py b/config.py
deleted file mode 100644
index fb42326..0000000
--- a/config.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import yaml
-import os
-
-current_dir = os.path.dirname(os.path.realpath(__file__))
-config_path = os.path.join(current_dir, "config.yaml")
-
-with open(config_path, "r") as f:
-    config = yaml.safe_load(f)
diff --git a/model.py b/model.py
index 527dc88..dfad223 100644
--- a/model.py
+++ b/model.py
@@ -7,7 +7,7 @@ from transformers import (
     AutoModelForCausalLM,
     LogitsProcessorList,
 )
-from config import config
+from data import config
 
 
 class StopAfterPlusIsGenerated(LogitsProcessor):
diff --git a/requirements.txt b/requirements.txt
index 2ae7a95..d0d0988 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,6 @@ fuzzywuzzy==0.18.0
 PyYAML==6.0
 Requests==2.31.0
 rich==13.4.2
+sqlitedict==2.1.0
 torch==2.0.1
 transformers==4.31.0
diff --git a/utils.py b/utils.py
index 7e17610..3f0f4ab 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,7 @@ import re
 
 from fuzzywuzzy import fuzz
 from transformers import GPTNeoXTokenizerFast
-from config import config
+from data import config
 from maxsubstring import longest_common_substring
 
 URL_REGEX = (
@@ -131,6 +131,7 @@ def build_prompt(post, comments):
         prompt += comment_str
 
     prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
+    prompt = prompt.replace("👻", "Ghost")
     prompt = prompt.strip() + "\n"
 
     # Truncate the prompt to leave room for generation.
@@ -142,11 +143,6 @@
     return prompt
 
 
-def log_prompt(prompt):
-    with open(f"{config['data_dir']}/prompts.txt", "a") as f:
-        f.write(f"{prompt}\n==========\n")
-
-
 def reply_length(reply):
     """Return the length of the reply, without Markdown images, URLs, or quoted text."""
     # Remove Markdown images and URLs.
@@ -160,6 +156,17 @@
     return len(reply)
 
 
+def median_by_key(lst, key):
+    lst = sorted(lst, key=key)
+    mid_index = len(lst) // 2
+
+    # For lists of even length, pick either option as the median.
+    if len(lst) % 2 == 0:
+        return random.choice([lst[mid_index - 1], lst[mid_index]])
+    else:
+        return lst[mid_index]
+
+
 def count_tokens(text):
     return len(tokenizer(text).input_ids)
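bot.py now picks the median-length candidate instead of the longest, so a single rambling generation no longer wins selection. A quick usage sketch of the new `median_by_key` helper (hypothetical strings; `len` stands in for `utils.reply_length`):

```python
import random

from utils import median_by_key

candidates = ["ok", "a medium-sized reply", "a much, much longer reply than the rest"]

# Odd-length list: the middle element under the key is returned.
assert median_by_key(candidates, key=len) == "a medium-sized reply"

# Even-length list: either of the two middle elements may come back.
random.seed(0)
print(median_by_key(candidates + ["hm"], key=len))  # "hm" or "a medium-sized reply"
```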