Bussy-boy/utils.py

208 lines
5.7 KiB
Python
Raw Normal View History

2023-07-09 23:00:43 +00:00
import random
import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from data import config
2023-07-09 23:15:25 +00:00
from maxsubstring import longest_common_substring
2023-07-09 23:00:43 +00:00
# Loose URL matcher shared by contains_url() and reply_length().
# NOTE(review): "$-_" inside the third character class is a *range* (0x24-0x5F),
# so it also admits digits, upper-case letters and punctuation such as
# ( ) < = > ? [ ] -- e.g. a trailing ")" after a Markdown link is swallowed.
URL_REGEX = r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"

# Tokenizer used to count and truncate prompt tokens; loaded once at import time.
tokenizer = GPTNeoXTokenizerFast.from_pretrained("float-trip/mpt-30b-drama")
def remove_notifications(text):
    """Break up mentions so the text does not trigger user notifications.

    Change @float-trip to <span>@</span>float-trip and carp to <span>c</span>arp.

    - Every "@" is wrapped in a <span> unless it is immediately followed by
      the configured username (config["username"]) and a word boundary.
    - For each nickname in ``notified_users`` (matched case-insensitively),
      the first character of every occurrence is wrapped in a <span>.

    Args:
        text: the text to rewrite.

    Returns:
        The rewritten text.
    """
    # re.escape so the configured username is matched literally even if it
    # ever contains regex metacharacters.
    own_name = re.escape(config["username"])
    text = re.sub(rf"@(?!{own_name}\b)", "<span>@</span>", text)

    notified_users = [
        "aevan",
        "avean",
        " capy",
        "the rodent",
        "carp",
        "clit",
        "snakes",
        "sneks",
        "snekky",
        "snekchad",
        "jc",
        "justcool",
        "lawlz",
        "transgirltradwife",
        "impassionata",
        "pizzashill",
        "idio3",
        "idio ",
        "telegram ",
        "schizo",
        "joan",
        "pewkie",
        "homocracy",
        "donger",
        "geese",
        "soren",
        "marseyismywaifu",
        "mimw",
        "heymoon",
        "gaypoon",
        "jollymoon",
        "chiobu",
        "mccox",
        "august",
        "marco",
        "klen",
    ]

    def replace(match):
        # Insert <span></span> around the first character of the matched string.
        user = match.group()
        return f"<span>{user[:1]}</span>{user[1:]}"

    # One pass per nickname, in list order, so earlier entries win where two
    # nicknames overlap. Entries are escaped so they are always literal text.
    for user in notified_users:
        text = re.sub(re.escape(user), replace, text, flags=re.IGNORECASE)

    return text
def format_reply(text):
    """Post-process generated text before it is posted.

    Steps, in order:
    1. Map every fake username (config["fake_usernames"]) back to the real
       username (config["username"]). This undoes the swap build_prompt
       performs on the prompt.
    2. Replace rdrama.net-hosted images with the loading placeholder
       (replace_rdrama_images).
    3. Defuse @-mentions and nickname pings (remove_notifications).
    4. Strip surrounding whitespace.

    Args:
        text: raw reply text.

    Returns:
        The cleaned reply text.
    """
    for fake_name in config["fake_usernames"]:
        # str.replace returns a new string, so the result must be reassigned.
        text = text.replace(fake_name, config["username"])
    text = replace_rdrama_images(text)
    text = remove_notifications(text)
    return text.strip()
def is_low_quality(reply, _post, comments):
    """Return True if a generated reply should be rejected as low quality.

    The reply is labelled low quality if any of these hold:
    - fuzz.ratio between the reply and some previous comment body in the
      thread is above 90 (near-duplicate of an existing comment).
    - The first key returned by longest_common_substring(reply) is at least
      100 characters long.
    - After removing links, Markdown images, and quoted text
      (reply_length), fewer than 10 characters remain.
    - It contains a "!word" group ping.

    Args:
        reply: the generated reply text.
        _post: unused; kept so the call signature stays stable.
        comments: thread comments, each a mapping with a "body" key.

    Returns:
        True if the reply should be discarded, False otherwise.
    """
    if any(fuzz.ratio(reply, comment["body"]) > 90 for comment in comments):
        return True

    # NOTE(review): assumes longest_common_substring returns a mapping whose
    # first key is the longest repeated substring -- confirm in maxsubstring.
    # An empty result is treated as "no repetition" instead of raising.
    lcs = next(iter(longest_common_substring(reply).keys()), "")
    if len(lcs) >= 100:
        return True

    if reply_length(reply) < 10:
        return True

    # Lost pinging rights: reject anything that looks like a "!group" ping.
    if re.search(r"!\w+", reply):
        return True

    return False
def contains_url(text):
    """Report whether ``text`` contains at least one match of URL_REGEX."""
    found = re.search(URL_REGEX, text)
    return bool(found)
def replace_rdrama_images(text):
    """Swap rdrama.net-hosted images for the site's loading placeholder.

    Two passes run in order:
    1. Every bare ``https://....rdrama.net/....webp`` URL becomes the
       placeholder URL.
    2. Every Markdown image whose target is on a ``*.rdrama.net`` host, or is
       a site-relative path starting with ``/``, becomes ``![](placeholder)``.
    """
    placeholder = "https://i.rdrama.net/i/l.webp"
    substitutions = (
        (r"https://\S*\.rdrama\.net/\S*\.webp", placeholder),
        (r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)", f"![]({placeholder})"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
def normalize_emojis(s):
    """Move ``#`` / ``!`` modifiers inside ``:emoji:`` names to the front.

    ``:marsey!#:`` becomes ``:#!marsey:``. Repeated modifiers collapse to a
    single one, and ``#`` always comes before ``!``. Names without a
    modifier are left alone.
    """

    def _move_modifiers(m):
        name = m.group(0)
        # Fixed "#!" order; each modifier appears at most once.
        prefix = "".join(ch for ch in "#!" if ch in name)
        bare_name = re.sub(r"[#!]", "", name)
        return prefix + bare_name

    return re.sub(r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)", _move_modifiers, s)
def build_prompt(post, comments):
    """Render a post and its comment chain as a plain-text prompt.

    Layout:
    - A "[Post]" header with author, title, URL, hole (post["sub"] or "N/A")
      and a fixed vote count.
    - The post body.
    - A "[Comments]" section. Each comment is rendered as
      "<author> +45 / -0" followed by its body, with emojis normalized.
      The n-th comment (0-based) is indented by n spaces.

    An empty comment authored by config["username"] is appended last. This
    leaves an open slot at the end of the prompt, presumably for the model
    to fill in as the reply.

    Afterwards:
    - Every occurrence of config["username"] is replaced by one randomly
      chosen entry of config["fake_usernames"].
    - The ghost emoji is spelled out as "Ghost".
    - The prompt is truncated from the front to at most
      config["prompt_token_limit"] tokens.

    Args:
        post: mapping with "author_name", "title", "url", "sub" and "body" keys.
        comments: iterable of mappings with "author_name" and "body" keys.
            It is not modified.

    Returns:
        The prompt string.
    """
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    # Build a new list instead of appending in place, so the caller's list is
    # not mutated and repeated calls don't stack up placeholder comments.
    thread = [*comments, {"author_name": config["username"], "body": ""}]
    for depth, comment in enumerate(thread):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]
        comment_str = f"\n\n{author} +45 / -0\n{body}"
        # Each successive comment is nested one space deeper.
        indent = depth * " "
        comment_str = "\n".join(indent + line for line in comment_str.split("\n"))
        prompt += comment_str

    prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
    prompt = prompt.replace("👻", "Ghost")
    prompt = prompt.strip() + "\n"

    # Truncate the prompt to leave room for generation, keeping the most
    # recent (last) tokens.
    limit = config["prompt_token_limit"]
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > limit:
        prompt = tokenizer.convert_tokens_to_string(tokens[-limit:])
    return prompt
def reply_length(reply):
    """Count the characters of ``reply`` that are the author's own text.

    The following are discarded first:
    - Markdown images (``![alt](target)``).
    - Anything matching URL_REGEX.
    - Every line that starts (after leading whitespace) with ">" or with its
      backslash-escaped form.

    What remains is stripped of surrounding whitespace and then measured.
    """
    cleaned = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    cleaned = re.sub(URL_REGEX, "", cleaned)

    kept = []
    for line in cleaned.splitlines():
        if line.lstrip().startswith((">", "\\>")):
            continue  # quoted text
        kept.append(line)

    return len("\n".join(kept).strip())
def median_by_key(lst, key):
    """Return the median element of ``lst`` when ordered by ``key``.

    With an odd number of elements, the middle element is returned. With an
    even number, one of the two middle elements is picked via random.choice.
    An empty list raises IndexError. ``lst`` itself is not modified; a
    sorted copy is used.
    """
    ordered = sorted(lst, key=key)
    half, is_odd = divmod(len(ordered), 2)
    if is_odd:
        return ordered[half]
    # Even length: there is no single middle element, so pick either one.
    return random.choice([ordered[half - 1], ordered[half]])
def count_tokens(text):
    """Return the number of input ids the module-level tokenizer yields for ``text``."""
    encoding = tokenizer(text)
    return len(encoding.input_ids)
def extract_reply(text):
    """Turn raw generated text into a reply body.

    Generation ends in one of two ways:
    - It is cut off at the token limit.
    - The model starts a new comment header such as ``float-trip +10``.

    In the latter case the final line is that header, and it is dropped.
    Every remaining line is stripped, and so is the joined result.
    """
    *kept, final_line = text.split("\n")
    if re.match(r"^ *[\w-]* +\+.*$", final_line) is None:
        kept.append(final_line)
    return "\n".join(line.strip() for line in kept).strip()