176 lines
4.9 KiB
Python
176 lines
4.9 KiB
Python
import random
|
|
import re
|
|
|
|
from fuzzywuzzy import fuzz
|
|
from transformers import GPTNeoXTokenizerFast
|
|
|
|
from config import config
|
|
from maxsubstring import longest_common_substring
|
|
|
|
# Pattern for http(s) URLs; reply_length() uses it to strip links before
# measuring a reply. The pattern is split into pieces only for readability --
# the concatenated string is unchanged.
# NOTE(review): `[$-_@.&+]` is a *range* from `$` to `_`, so it also accepts
# digits, uppercase letters, `(`, `)`, `/`, `:`, `?`, `=` and more; a match can
# therefore run past the closing `)` of a Markdown link.
URL_REGEX = (
    r"http[s]?://"
    r"(?:"
    r"[a-zA-Z]"
    r"|[0-9]"
    r"|[$-_@.&+]"
    r"|[!*\\(\\),]"
    r"|(?:%[0-9a-fA-F][0-9a-fA-F])"  # percent-encoded byte
    r")+"
)
|
|
|
|
# Tokenizer loaded from the checkpoint directory under config["data_dir"];
# count_tokens() and build_prompt() use it to measure and truncate text.
tokenizer = GPTNeoXTokenizerFast.from_pretrained(
    "{}/mpt-30b-drama-ba678".format(config["data_dir"])
)
|
|
|
|
|
|
def remove_notifications(text):
    """Defang @-mentions and known nicknames so posting *text* pings nobody.

    Every ``@`` that is not immediately followed by the bot's own
    ``config["username"]`` (as a whole word) becomes ``@<i></i>``.  Every
    case-insensitive occurrence of a name in ``notified_users`` gets an empty
    ``<i></i>`` inserted after its first character, e.g. ``carp`` becomes
    ``c<i></i>arp``; the original casing of the text is preserved.

    Args:
        text: The reply text to sanitize.

    Returns:
        The sanitized text.
    """
    # Escape the bot's name so any regex metacharacters in it are literal.
    own_name = re.escape(config["username"])
    text = re.sub(rf"@(?!{own_name}\b)", "@<i></i>", text)

    notified_users = (
        "aevan",
        "avean",
        "joan",
        "pewkie",
        "carp",
        "idio3",
        "idio ",
        "the_homocracy",
        "schizocel",
        "scitzocel",
        "snakes",
        "sneks",
        "jc",
        "justcool",
        "clit",
        "geese",
        "kippy",
        "mccox",
        "chiobu",
        "donger",
        "soren",
    )

    def break_name(match):
        # Split the matched word after its first character, keeping its case.
        word = match.group(0)
        return f"{word[0]}<i></i>{word[1:]}"

    for user in notified_users:
        # Rewrite every occurrence, not just the first one, so later mentions
        # in the same text are defanged too.
        text = re.sub(re.escape(user), break_name, text, flags=re.IGNORECASE)

    return text
|
|
|
|
|
|
def format_reply(text):
    """Prepare generated text for posting.

    Replaces every entry of ``config["fake_usernames"]`` with the real
    ``config["username"]`` (build_prompt() substitutes a fake name into the
    prompt, so the model may echo it back).  It then swaps rdrama-hosted images
    for a placeholder via replace_rdrama_images() and strips surrounding
    whitespace.

    Args:
        text: Raw reply text extracted from the model output.

    Returns:
        The cleaned reply.
    """
    for username in config["fake_usernames"]:
        # str.replace returns a new string; it must be reassigned to take effect.
        text = text.replace(username, config["username"])

    text = replace_rdrama_images(text)
    return text.strip()
|
|
|
|
|
|
def is_low_quality(reply, post, comments):
    """Return True if *reply* should be discarded as low quality.

    A reply is labelled low quality if any of these hold:

    - ``fuzz.ratio`` (a Levenshtein-based 0-100 similarity score) between the
      reply and the body of any comment in *comments* is above 90.
    - The first key returned by ``longest_common_substring(reply)`` is at
      least 100 characters long.
    - After removing links, Markdown images, and quoted text (see
      ``reply_length``), fewer than 10 characters remain.

    Args:
        reply: Generated reply text.
        post: Not used by the current checks; kept for interface compatibility.
        comments: Iterable of comment mappings with a ``"body"`` key.

    Returns:
        True if the reply should be discarded, otherwise False.
    """
    if any(fuzz.ratio(reply, comment["body"]) > 90 for comment in comments):
        return True

    # NOTE(review): assumes the helper returns a mapping whose first key is the
    # longest repeated substring -- confirm against maxsubstring.  An empty
    # result is treated as "no repetition" instead of raising IndexError.
    lcs = next(iter(longest_common_substring(reply).keys()), "")
    if len(lcs) >= 100:
        return True

    return reply_length(reply) < 10
|
|
|
|
|
|
def replace_rdrama_images(text):
    """Replace images hosted on rdrama.net with a loading placeholder.

    Two passes are made:

    1. Any ``https://`` URL whose host ends in ``.rdrama.net`` and whose path
       ends in ``.webp`` is replaced with the placeholder URL.
    2. Any Markdown image whose target is such an rdrama.net URL, or a
       site-relative path starting with ``/``, is replaced with a Markdown
       image of the placeholder (alt text is dropped).

    URL bodies stop at whitespace *and* at ``)``.  As a result, adjacent images
    such as ``![](/a.webp)![](/b.webp)`` are replaced one by one instead of
    being merged into a single match, which would lose images.

    Args:
        text: Reply text.

    Returns:
        The text with rdrama-hosted images replaced.
    """
    loading = "https://i.rdrama.net/i/l.webp"
    webp_pattern = r"https://[^\s)]*\.rdrama\.net/[^\s)]*\.webp"
    md_img_pattern = r"!\[[^\]]*\]\((https://[^\s)]*\.rdrama\.net)?/[^\s)]*\)"

    text = re.sub(webp_pattern, loading, text)
    text = re.sub(md_img_pattern, f"![]({loading})", text)
    return text
|
|
|
|
|
|
def normalize_emojis(s):
    """Move ``#`` and ``!`` modifiers to the front of ``:emoji:`` names.

    Only text between two colons that consists solely of letters, ``@``, ``#``
    and ``!`` (and contains at least one ``#`` or ``!``) is touched.  Each
    modifier is kept at most once, ``#`` before ``!``, followed by the name with
    every modifier removed.  For example, ``:marsey!#:`` becomes ``:#!marsey:``.
    """

    def _hoist_modifiers(match):
        name = match.group(0)
        prefix = ""
        if "#" in name:
            prefix += "#"
        if "!" in name:
            prefix += "!"
        return prefix + name.replace("#", "").replace("!", "")

    return re.sub(r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)", _hoist_modifiers, s)
|
|
|
|
|
|
def build_prompt(post, comments):
    """Render *post* and its comment chain into a text prompt for the model.

    Layout: a ``[Post]`` header line (author, title, URL, hole, and a fixed
    ``+71 / -0`` vote string), a blank line, the post body, a blank line, and
    ``[Comments]``.  Each comment follows as ``"{author} +45 / -0"`` plus its
    emoji-normalized body.  Each comment is indented one extra space per
    position in the list, i.e. it is rendered as a reply to the previous one.
    A final comment by ``config["username"]`` with an empty body is added, so
    the prompt ends with the bot's own comment header.

    Every occurrence of ``config["username"]`` is then replaced by one randomly
    chosen entry of ``config["fake_usernames"]``.  If the prompt exceeds
    ``config["prompt_token_limit"]`` tokens, only its last tokens are kept.

    Args:
        post: Mapping with ``author_name``, ``title``, ``url``, ``sub`` and
            ``body`` keys (a falsy ``sub`` is rendered as ``N/A``).
        comments: Iterable of mappings with ``author_name`` and ``body`` keys.
            It is not modified.

    Returns:
        The prompt string.
    """
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    # Build a new list instead of appending to the caller's.  An in-place
    # append would leak the placeholder into *comments*, and a second call
    # with the same list would then render two empty bot comments.
    thread = [*comments, {"author_name": config["username"], "body": ""}]

    for depth, comment in enumerate(thread):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]

        # The fixed "+45 / -0" score and the one-space-per-level indent are
        # part of the prompt format the model is fed.
        comment_str = f"\n\n{author} +45 / -0\n{body}"
        indent = depth * " "
        comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
        prompt += comment_str

    prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
    prompt = prompt.strip() + "\n"

    # Truncate from the front so the end of the thread (including the bot's
    # header) survives and room is left for generation.
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > config["prompt_token_limit"]:
        tokens = tokens[-config["prompt_token_limit"] :]
        prompt = tokenizer.convert_tokens_to_string(tokens)

    return prompt
|
|
|
|
|
|
def log_prompt(prompt):
    """Append *prompt* to ``<config["data_dir"]>/prompts.txt``.

    Each entry is the prompt followed by a newline, a ``==========`` line, and
    a newline, so successive prompts are delimited in the log file.

    Args:
        prompt: The prompt text to record.
    """
    path = f"{config['data_dir']}/prompts.txt"
    # Prompts carry arbitrary user text and emoji, so write UTF-8 explicitly
    # rather than depending on the platform's locale encoding.
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"{prompt}\n==========\n")
|
|
|
|
|
|
def reply_length(reply):
    """Count the characters of *reply* that are the author's own prose.

    Markdown images are removed first, then anything matching ``URL_REGEX``.
    Next, every line whose first non-blank characters mark a quote (a ``>``, or
    a backslash-escaped ``>``) is dropped.  The remaining lines are rejoined
    with newlines and stripped before measuring.

    Returns:
        The number of characters left after the removals above.
    """
    without_images = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    without_links = re.sub(URL_REGEX, "", without_images)

    kept_lines = []
    for line in without_links.splitlines():
        if line.lstrip().startswith((">", "\\>")):
            continue  # quoted text
        kept_lines.append(line)

    return len("\n".join(kept_lines).strip())
|
|
|
|
|
|
def count_tokens(text):
    """Return the number of token ids the module-level tokenizer yields for *text*."""
    encoding = tokenizer(text)
    token_ids = encoding.input_ids
    return len(token_ids)
|
|
|
|
|
|
def extract_reply(text):
    """Pull the bot's reply out of raw generated text.

    Generation either stops at the token limit or runs on into the header of
    the next comment (e.g. ``float-trip +10``).  If the final line looks like
    such a header -- optional spaces, a word of letters/digits/``_``/``-``, at
    least one space, then ``+`` -- that line is discarded.  Every remaining line
    is stripped, and the joined result is stripped as well.

    Args:
        text: Raw model output following the prompt.

    Returns:
        The cleaned reply text.
    """
    *body, last_line = text.split("\n")
    if not re.match(r"^ *[\w-]* +\+.*$", last_line):
        body.append(last_line)
    return "\n".join(line.strip() for line in body).strip()
|