# Bussy-boy/utils.py

import random
import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from config import config
from maxsubstring import longest_common_substring

URL_REGEX = (
    r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)

tokenizer = GPTNeoXTokenizerFast.from_pretrained(
    f"{config['data_dir']}/mpt-30b-drama-ba678"
)


def remove_notifications(text):
    """Change @float-trip to @<i></i>float-trip and carp to c<i></i>arp."""
    text = re.sub(rf"@(?!{config['username']}\b)", "@<i></i>", text)

    notified_users = [
        "aevan",
        "avean",
        "joan",
        "pewkie",
        "carp",
        "idio3",
        "idio ",
        "the_homocracy",
        "schizocel",
        "scitzocel",
        "snakes",
        "sneks",
        "jc",
        "justcool",
        "clit",
        "geese",
        "kippy",
        "mccox",
        "chiobu",
        "donger",
        "soren",
    ]

    for user in notified_users:
        match = re.search(user, text, re.IGNORECASE)
        if match:
            text = f"{text[:match.start() + 1]}<i></i>{text[match.start() + 1:]}"
    return text


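# A quick illustration of the masking above (hypothetical strings; assumes
# config["username"] == "bussy-boy", so the bot's own mentions are left alone):
#
#   remove_notifications("@carp hi @bussy-boy")
#   -> "@<i></i>c<i></i>arp hi @bussy-boy"
#
# Note that only the first occurrence of each listed name is masked, matching
# the original behavior.

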
def format_reply(text):
    for username in config["fake_usernames"]:
        # str.replace returns a new string, so the result must be reassigned;
        # the original call discarded it.
        text = text.replace(username, config["username"])
    text = replace_rdrama_images(text)
    return text.strip()


def is_low_quality(reply, post, comments):
    """
    Label the reply as low quality if:
    - fuzz.ratio (Levenshtein-based similarity) scores it > 90 against a
      previous comment in the thread.
    - Its longest repeated substring, per longest_common_substring, is
      >= 100 characters.
    - After removing links, Markdown images, and quoted text, fewer than 10
      characters remain.
    """
    for comment in comments:
        if fuzz.ratio(reply, comment["body"]) > 90:
            return True
    lcs = list(longest_common_substring(reply).keys())[0]
    if len(lcs) >= 100:
        return True
    if reply_length(reply) < 10:
        return True
    return False


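# For reference, fuzz.ratio returns an integer similarity score in [0, 100]
# (roughly 2 * matches / total length); e.g. fuzz.ratio("abcd", "abce") == 75,
# so the > 90 cutoff only catches near-duplicates.

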
def replace_rdrama_images(text):
    """Replace images pointing to rdrama.net with a loading image."""
    loading = "https://i.rdrama.net/i/l.webp"
    webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
    md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"
    text = re.sub(webp_pattern, loading, text)
    text = re.sub(md_img_pattern, f"![]({loading})", text)
    return text


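# Both bare .webp links and Markdown image syntax get rewritten; a hypothetical
# input/output pair:
#
#   replace_rdrama_images("look ![pic](https://i.rdrama.net/i/abc.webp)")
#   -> "look ![](https://i.rdrama.net/i/l.webp)"

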
def normalize_emojis(s):
    """Bring # and ! to the front of an emoji."""

    def repl(match):
        # Extract the word between colons and the special characters.
        word = match.group(0)
        specials = set(re.findall(r"[#!]", word))
        # Sort specials and append the word without specials.
        new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
        return new_emoji

    emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
    s = re.sub(emoji_pattern, repl, s)
    return s


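# Hypothetical example: the marker characters move to the front of the emoji
# name, sorted so that "#" precedes "!":
#
#   normalize_emojis(":marseylove!: and :marseysad#!:")
#   -> ":!marseylove: and :#!marseysad:"

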
def build_prompt(post, comments):
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    comments.append({"author_name": config["username"], "body": ""})
    for depth, comment in enumerate(comments):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]
        comment_str = f"\n\n{author} +45 / -0\n{body}"
        indent = depth * " "
        comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
        prompt += comment_str

    prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
    prompt = prompt.strip() + "\n"

    # Truncate the prompt to leave room for generation.
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > config["prompt_token_limit"]:
        tokens = tokens[-config["prompt_token_limit"] :]
        prompt = tokenizer.convert_tokens_to_string(tokens)
    return prompt


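# The assembled prompt looks roughly like this (hypothetical values; each
# comment in the chain is indented one extra space per position, and the
# trailing empty comment appended above is where the model continues as the
# bot, whose username has been swapped for a random fake one):
#
#   [Post] [Author] Alice [Title] Some title [URL] https://example.com [Hole] N/A [Votes] +71 / -0
#
#   Post body here.
#
#   [Comments]
#
#   Bob +45 / -0
#   Top-level comment.
#
#    Carol +45 / -0
#    Reply, one space deeper.

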
def log_prompt(prompt):
    with open(f"{config['data_dir']}/prompts.txt", "a") as f:
        f.write(f"{prompt}\n==========\n")


def reply_length(reply):
    """Return the length of the reply, without Markdown images, URLs, or quoted text."""
    # Remove Markdown images and URLs.
    reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    reply = re.sub(URL_REGEX, "", reply)
    # Remove quoted text.
    lines = reply.splitlines()
    lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
    reply = "\n".join(lines).strip()
    return len(reply)


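# Hypothetical example: images, links, and quoted lines contribute nothing, so
# only "ok" is counted:
#
#   reply_length("> quoted\n![](https://i.rdrama.net/i/l.webp)\nok")  # -> 2

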
def count_tokens(text):
    return len(tokenizer(text).input_ids)


def extract_reply(text):
    """
    Generated text will either:
    - Be cut off at the token limit
    - End with the start of a new comment: `float-trip +10`
    For the latter case, drop the last line.
    """
    new_comment_pattern = r"^ *[\w-]* +\+.*$"
    lines = text.split("\n")
    if re.match(new_comment_pattern, lines[-1]):
        lines = lines[:-1]
    return "\n".join([line.strip() for line in lines]).strip()