import random
import re

from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast

from data import config
from maxsubstring import longest_common_substring

URL_REGEX = (
    r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)

tokenizer = GPTNeoXTokenizerFast.from_pretrained(
    f"{config['data_dir']}/mpt-30b-drama-ba678"
)

# Zero-width space, assumed here to be the invisible character inserted into
# mentions and usernames so they don't trigger notifications.
ZWS = "\u200b"


def remove_notifications(text):
    """Insert a zero-width space into @-mentions and known usernames so they don't ping."""
    text = re.sub(rf"@(?!{config['username']}\b)", f"@{ZWS}", text)

    notified_users = [
        "aevan",
        "avean",
        "joan",
        "pewkie",
        "carp",
        "idio3",
        "idio ",
        "the_homocracy",
        "schizocel",
        "scitzocel",
        "snakes",
        "sneks",
        "jc",
        "justcool",
        "clit",
        "geese",
        "kippy",
        "mccox",
        "chiobu",
        "donger",
        "soren",
    ]

    def replace(match):
        # Insert the zero-width space after the first character of the matched name.
        user = match.group()
        return f"{user[:1]}{ZWS}{user[1:]}"

    for user in notified_users:
        text = re.sub(user, replace, text, flags=re.IGNORECASE)

    return text


def format_reply(text):
    for username in config["fake_usernames"]:
        text = text.replace(username, config["username"])
    text = replace_rdrama_images(text)
    return text.strip()


def is_low_quality(reply, post, comments):
    """
    Label the reply as low quality if:
    - The Levenshtein ratio says it's similar to a previous comment in the thread.
    - Its longest internally repeated substring is >= 100 characters.
    - After removing links, Markdown images, and quoted text, its length is < 10.
    """
    for comment in comments:
        if fuzz.ratio(reply, comment["body"]) > 90:
            return True

    lcs = list(longest_common_substring(reply).keys())[0]
    if len(lcs) >= 100:
        return True

    if reply_length(reply) < 10:
        return True

    return False


def replace_rdrama_images(text):
    """Replace images pointing to rdrama.net with a loading image."""
    loading = "https://i.rdrama.net/i/l.webp"
    webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
    md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"
    text = re.sub(webp_pattern, loading, text)
    text = re.sub(md_img_pattern, f"![]({loading})", text)
    return text


def normalize_emojis(s):
    """Bring # and ! to the front of an emoji, e.g. :emoji#!: -> :#!emoji:."""

    def repl(match):
        # Extract the word between colons and the special characters.
        word = match.group(0)
        specials = set(re.findall(r"[#!]", word))
        # Sort the specials and prepend them to the word without specials.
        new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
        return new_emoji

    emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
    s = re.sub(emoji_pattern, repl, s)
    return s


def build_prompt(post, comments):
    prompt = (
        f"[Post] [Author] {post['author_name']} "
        f"[Title] {post['title']} [URL] {post['url']} "
        f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
        f"{post['body']}\n\n[Comments]"
    )

    comments.append({"author_name": config["username"], "body": ""})

    for depth, comment in enumerate(comments):
        body = normalize_emojis(comment["body"])
        author = comment["author_name"]
        comment_str = f"\n\n{author} +45 / -0\n{body}"
        indent = depth * " "
        comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
        prompt += comment_str

    prompt = prompt.replace(
        config["username"], random.choice(config["fake_usernames"])
    )
    prompt = prompt.replace("👻", "Ghost")
    prompt = prompt.strip() + "\n"

    # Truncate the prompt to leave room for generation.
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) > config["prompt_token_limit"]:
        tokens = tokens[-config["prompt_token_limit"] :]
        prompt = tokenizer.convert_tokens_to_string(tokens)

    return prompt


def reply_length(reply):
    """Return the length of the reply, without Markdown images, URLs, or quoted text."""
    # Remove Markdown images and URLs.
    reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
    reply = re.sub(URL_REGEX, "", reply)

    # Remove quoted text.
    lines = reply.splitlines()
    lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
    reply = "\n".join(lines).strip()

    return len(reply)


def median_by_key(lst, key):
    lst = sorted(lst, key=key)
    mid_index = len(lst) // 2

    # For lists of even length, pick either option as the median.
    if len(lst) % 2 == 0:
        return random.choice([lst[mid_index - 1], lst[mid_index]])
    else:
        return lst[mid_index]


def count_tokens(text):
    return len(tokenizer(text).input_ids)


def extract_reply(text):
    """
    Generated text will either:
    - Be cut off at the token limit
    - End with the start of a new comment: `float-trip +10`

    For the latter case, drop the last line.
    """
    new_comment_pattern = r"^ *[\w-]* +\+.*$"
    lines = text.split("\n")
    if re.match(new_comment_pattern, lines[-1]):
        lines = lines[:-1]
    return "\n".join([line.strip() for line in lines]).strip()
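

# Minimal usage sketch with hypothetical inputs; running this module directly assumes
# the config file and tokenizer directory referenced above are available.
if __name__ == "__main__":
    sample = "Look :emoji#!: ![pic](https://i.rdrama.net/i/abc123.webp)"
    print(normalize_emojis(sample))  # specials moved to the front of the emoji
    print(replace_rdrama_images(sample))  # rdrama image swapped for the loading image
    print(extract_reply("A generated reply\nfloat-trip +10 / -0"))  # comment stub dropped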