# Bussy-boy/utils.py

import random
import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from data import config
from maxsubstring import longest_common_substring

URL_REGEX = (
    r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)

tokenizer = GPTNeoXTokenizerFast.from_pretrained("float-trip/mpt-30b-drama")


def remove_notifications(text):
    """Change @float-trip to <span>@</span>float-trip and carp to <span>c</span>arp."""
    text = re.sub(rf"@(?!{config['username']}\b)", "<span>@</span>", text)

    notified_users = [
        "aevan",
        "avean",
        " capy",
        "the rodent",
        "carp",
        "clit",
        "snakes",
        "sneks",
        "snekky",
        "snekchad",
        "jc",
        "justcool",
        "lawlz",
        "transgirltradwife",
        "impassionata",
        "pizzashill",
        "idio3",
        "idio ",
        "telegram ",
        "schizo",
        "joan",
        "pewkie",
        "homocracy",
        "donger",
        "geese",
        "soren",
        "marseyismywaifu",
        "mimw",
        "heymoon",
        "gaypoon",
        "jollymoon",
        "chiobu",
        "mccox",
        "august",
        "marco",
        "klen",
    ]

    def replace(match):
        # Insert <span></span> around the first character of the matched string.
        user = match.group()
        return f"<span>{user[:1]}</span>{user[1:]}"

    for user in notified_users:
        text = re.sub(user, replace, text, flags=re.IGNORECASE)

    return text
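
# Illustrative usage (hypothetical input, not part of the original module; the <span>
# wrapper is what stops rDrama from sending a notification for the mentioned name):
#
#     remove_notifications("carp is here")  # -> "<span>c</span>arp is here"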


def format_reply(text):
    for username in config["fake_usernames"]:
        # Map any fake usernames the model generated back to the real one.
        # str.replace returns a new string, so the result must be reassigned.
        text = text.replace(username, config["username"])
    text = replace_rdrama_images(text)
    text = remove_notifications(text)
    return text.strip()


def is_low_quality(reply, _post, comments):
    """
    Label the reply as low quality if:
    - The Levenshtein ratio says it's similar to a previous comment in the thread.
    - len(longest_common_substring) >= 100
    - After removing links, Markdown images, and quoted text, the length is < 10.
    - It contains a group ping (`!group`), since the bot lost pinging rights.
    """
    for comment in comments:
        if fuzz.ratio(reply, comment["body"]) > 90:
            return True
    lcs = list(longest_common_substring(reply).keys())[0]
    if len(lcs) >= 100:
        return True
    if reply_length(reply) < 10:
        return True
    # Lost pinging rights.
    if re.findall(r"!\w+", reply):
        return True
    return False


def contains_url(text):
    return re.search(URL_REGEX, text) is not None


def replace_rdrama_images(text):
    """Replace images pointing to rdrama.net with a loading image."""
    loading = "https://i.rdrama.net/i/l.webp"
    webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
    md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"
    text = re.sub(webp_pattern, loading, text)
    text = re.sub(md_img_pattern, f"![]({loading})", text)
    return text
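
# Illustrative usage (hypothetical image path): both bare .webp links and Markdown
# images hosted on rdrama.net collapse to the loading placeholder.
#
#     replace_rdrama_images("![](https://i.rdrama.net/i/abc123.webp)")
#     # -> "![](https://i.rdrama.net/i/l.webp)"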


def normalize_emojis(s):
    """Bring # and ! to the front of an emoji."""

    def repl(match):
        # Extract the word between colons and the special characters.
        word = match.group(0)
        specials = set(re.findall(r"[#!]", word))
        # Sort specials and append the word without specials.
        new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
        return new_emoji

    emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
    s = re.sub(emoji_pattern, repl, s)
    return s
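
# Illustrative usage (hypothetical emoji codes): modifier characters move to the front
# of the emoji, with "#" sorted before "!".
#
#     normalize_emojis(":marseylove!:")   # -> ":!marseylove:"
#     normalize_emojis(":marsey#wave!:")  # -> ":#!marseywave:"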


def build_prompt(post, comments):
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    comments.append({"author_name": config["username"], "body": ""})
    for depth, comment in enumerate(comments):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]
        comment_str = f"\n\n{author} +45 / -0\n{body}"
        indent = depth * " "
        comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
        prompt += comment_str

    prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
    prompt = prompt.replace("👻", "Ghost")
    prompt = prompt.strip() + "\n"

    # Truncate the prompt to leave room for generation.
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > config["prompt_token_limit"]:
        tokens = tokens[-config["prompt_token_limit"] :]
        prompt = tokenizer.convert_tokens_to_string(tokens)
    return prompt
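
# Rough sketch of the prompt this produces (hypothetical post and comment data; votes
# are fixed at +71 / -0 for the post and +45 / -0 for every comment, each successive
# comment is indented one extra space, and an empty comment by the bot is appended so
# the model writes the reply as the final author):
#
#     [Post] [Author] some_user [Title] Some title [URL] https://example.com [Hole] N/A [Votes] +71 / -0
#
#     Post body text.
#
#     [Comments]
#
#     some_user +45 / -0
#     A top-level comment.
#
#      another_user +45 / -0
#      A reply to the comment above.
#
#       fake_username +45 / -0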


def reply_length(reply):
    """Return the length of the reply, without Markdown images, URLs, or quoted text."""
    # Remove Markdown images and URLs.
    reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    reply = re.sub(URL_REGEX, "", reply)
    # Remove quoted text.
    lines = reply.splitlines()
    lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
    reply = "\n".join(lines).strip()
    return len(reply)
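
# Illustrative usage (hypothetical reply): quoted lines, Markdown images, and URLs do
# not count toward the length check in is_low_quality.
#
#     reply_length("> quoted text\n\nshort reply")  # -> 11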


def median_by_key(lst, key):
    lst = sorted(lst, key=key)
    mid_index = len(lst) // 2
    # For lists of even length, pick either option as the median.
    if len(lst) % 2 == 0:
        return random.choice([lst[mid_index - 1], lst[mid_index]])
    else:
        return lst[mid_index]
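
# Illustrative usage: with an even number of items either middle element may be
# returned; with an odd number the result is deterministic.
#
#     median_by_key(["a", "abc", "ab"], key=len)  # -> "ab"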


def count_tokens(text):
    return len(tokenizer(text).input_ids)


def extract_reply(text):
    """
    Generated text will either:
    - Be cut off at the token limit
    - End with the start of a new comment: `float-trip +10`
    For the latter case, drop the last line.
    """
    new_comment_pattern = r"^ *[\w-]* +\+.*$"
    lines = text.split("\n")
    if re.match(new_comment_pattern, lines[-1]):
        lines = lines[:-1]
    return "\n".join([line.strip() for line in lines]).strip()