Unify Marseygen and Bussyboy code + change Bussyboy to client/server setup.

master
float-trip 2023-07-26 03:21:01 +00:00
commit 1791c553bb
20 changed files with 1266 additions and 0 deletions
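
Both bots now share the clients, data, and web plumbing: a cron script queues work in SQLite and boots a Runpod pod, the pod-side worker pulls its queue over HTTP, and the FastAPI server posts results back to rdrama and logs them. A minimal sketch of one bussyboy round trip, assuming the web app is reachable at rdra.ma and SERVER_KEY is exported on the pod (marseygen has the same shape, but uploads a WEBP grid instead of JSON):

# Sketch of the client/server round trip. Endpoint shapes are taken from
# web/main.py and the worker scripts; the host and key are deployment-specific.
import os
import requests

# 1. The GPU worker pulls the queued requests and the bot config.
state = requests.get("https://rdra.ma/bussyboy").json()

for request_id, context in state["queue"].items():
    # 2. Generation happens on the pod (model.py plus the worker script).
    body = "placeholder reply"

    # 3. The result goes back to the server, which replies on rdrama.
    requests.post(
        f"https://rdra.ma/bussyboy/{request_id}",
        params={"key": os.getenv("SERVER_KEY")},
        json={"prompt": "...", "body": body, "candidates": [body]},
    )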

0
__init__.py 100644

27
bussyboy/create_runpod.json 100644
@@ -0,0 +1,27 @@
{
"operationName": "Mutation",
"variables": {
"input": {
"deployCost": 1.79,
"cloudType": "SECURE",
"containerDiskInGb": 20,
"volumeInGb": 0,
"dataCenterId": "EU-RO-1",
"gpuCount": 1,
"name": "bussyboy-run",
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"minMemoryInGb": 83,
"minVcpuCount": 16,
"networkVolumeId": "0gmsaggsjx",
"startJupyter": false,
"startSsh": true,
"templateId": "runpod-torch",
"cudaVersion": "11.8",
"volumeKey": null,
"ports": "8888/http,22/tcp",
"dockerArgs": "bash -c 'echo \"\n===\n\" >> /workspace/logs/bussyboy.txt && /workspace/start-bussyboy.sh >> /workspace/logs/bussyboy.txt 2>&1'"
}
},
"query": "mutation Mutation($input: PodFindAndDeployOnDemandInput) { podFindAndDeployOnDemand(input: $input) { id machineId __typename } }"
}

94
bussyboy/cron.py 100644
@@ -0,0 +1,94 @@
import logging
import random
import time
import requests
from rich import traceback
from bots.data import bussyboy_queue, config, db
from bots.clients.drama import DramaClient
from bots.clients.runpod import RunpodClient
traceback.install()
logging.basicConfig(
filename=f"{config['data_dir']}/logs/bussyboy.log",
filemode="a",
format="%(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
logger = logging.getLogger("bussyboy")
drama_client = DramaClient(config["bussyboy"]["token"], logger=logger)
runpod_client = RunpodClient(logger=logger)
def queue_reply(comment):
try:
post, thread = drama_client.fetch_context(comment)
except requests.exceptions.HTTPError as err:
logger.error(f"Error when fetching context: {err}")
return
db["bussyboy_id"] += 1
bussyboy_queue[str(db["bussyboy_id"])] = {"post": post, "thread": thread}
def queue_random_reply():
comments, newest_id = drama_client.fetch_new_comments(
after=db["bussyboy_last_processed"], limit=25
)
db["bussyboy_last_processed"] = newest_id
comments = [
c
for c in comments
if "author_name" in c
and not c["is_bot"]
and c["author_name"] != config["bussyboy"]["username"]
and c["author_name"] != "👻"
and c["author_id"] not in config["bussyboy"]["ignore_user_ids"]
and c["post_id"] != 0
]
if len(comments) == 0:
        logger.warning("No comments found.")
return
random.shuffle(comments)
queue_reply(comments[0])
def main():
if "bussyboy_id" not in db:
db["bussyboy_id"] = 0
if "bussyboy_last_processed" not in db:
db["bussyboy_last_processed"] = 0
if "bussyboy_last_random_reply" not in db:
db["bussyboy_last_random_reply"] = 0
# If both the queue and an instance are active, return and let it complete.
    if bussyboy_queue and runpod_client.is_running("bussyboy-run"):
return
time_since_reply = time.time() - db["bussyboy_last_random_reply"]
if time_since_reply > config["bussyboy"]["reply_frequency"]:
db["bussyboy_last_random_reply"] = time.time()
queue_random_reply()
for notif in drama_client.fetch_notifications():
if not notif["is_bot"]:
queue_reply(notif)
# Create instance if there are requests to be fulfilled.
if bussyboy_queue:
runpod_client.create_instance(
f"{config['data_dir']}/bussyboy/create_runpod.json"
)
if __name__ == "__main__":
main()
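
queue_reply stores the full thread context the stateless GPU worker will need, keyed by an incrementing counter, so the worker never has to call the rdrama API itself. An illustrative queue entry (abbreviated; the real post/comment dicts carry the full rdrama API responses):

# Hypothetical, abbreviated bussyboy_queue entry.
{
    "post": {"id": 123, "title": "...", "author_name": "...", "body": "..."},
    "thread": [
        {"id": 456, "author_name": "...", "body": "root comment"},
        {"id": 789, "author_name": "...", "body": "comment being replied to"},
    ],
}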

113
clients/drama.py 100644
@@ -0,0 +1,113 @@
import logging
import math
import time
from collections import OrderedDict
import requests
class DramaClient:
BASE_URL = "https://rdrama.net"
def __init__(self, token, logger=None):
self.session = requests.Session()
self.token = token
retries = requests.adapters.Retry(
total=5, backoff_factor=10, status_forcelist=[500, 502, 503, 504, 521]
)
self.session.mount(
"https://", requests.adapters.HTTPAdapter(max_retries=retries)
)
self.logger = logger or logging.getLogger(__name__)
self.chud_phrase = self.get("/@me").get("chud_phrase", "")
def get(self, endpoint):
self.logger.info(f"GET {endpoint}")
time.sleep(1)
r = self.session.get(
f"{self.BASE_URL}{endpoint}", headers={"Authorization": self.token}
)
if r.status_code != 200:
            self.logger.error(f"Error! {r}, {r.status_code}, {r.text}")
r.raise_for_status()
return r.json()
def post(self, endpoint, data=None, images=None):
self.logger.info(f"POST {endpoint}")
time.sleep(5)
if data is not None:
for key, value in data.items():
data[key] = str(value)
r = self.session.post(
f"{self.BASE_URL}{endpoint}",
data=data,
headers={"Authorization": self.token},
files=images,
)
if r.status_code != 200:
            self.logger.error(f"Error! {r}, {r.status_code}, {r.text}")
r.raise_for_status()
return r.json()
# Return comments with a newer ID than `after`, up to `limit`.
def fetch_new_comments(self, after=0, limit=math.inf):
def newest_id(comments):
return max(c["id"] for c in comments) if comments else 0
def oldest_id(comments):
return min(c["id"] for c in comments) if comments else math.inf
comments = []
page = 1
# Fetch /comment?page=x until we've reached `after` or have satisfied `limit`.
while oldest_id(comments) > after and len(comments) < limit:
page_comments = self.fetch_page(page)
if not page_comments:
break
comments.extend(page_comments)
page += 1
# Filter for new comments.
comments = [c for c in comments if c["id"] > after]
# Deduplicate comments in case one was pushed to the next page while fetching.
comments = list(OrderedDict((c["id"], c) for c in comments).values())
# Oldest first.
comments.reverse()
        # Keep the cursor where it was if nothing new was found.
        return comments, max(after, newest_id(comments))
# Return replies and mentions.
def fetch_notifications(self):
notifs = self.get("/unread")["data"]
notifs = [n for n in notifs if n["body"]]
return notifs
# Return the post and comment thread (only including parents) for a comment.
def fetch_context(self, comment):
post = self.get(f"/post/{comment['post_id']}")
comments = [comment]
while parent_id := comments[-1].get("parent_comment_id", None):
parent = self.get(f"/comment/{parent_id}")
comments.append(parent)
# Make the top-level comment be first.
comments.reverse()
return post, comments
def fetch_page(self, page):
return self.get(f"/comments?page={page}")["data"]
def reply(self, comment, body, images=None):
data = {"parent_fullname": f"c_{comment['id']}", "body": body}
return self.post("/comment", data=data, images=images)
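
A short usage sketch of DramaClient; the token is an assumption, the calls mirror the client above:

from bots.clients.drama import DramaClient

client = DramaClient("YOUR_RDRAMA_TOKEN")

# Poll incrementally by passing the last seen comment ID back in as `after`.
comments, cursor = client.fetch_new_comments(after=0, limit=25)

if comments:
    # Thread context for the newest comment, root comment first.
    post, thread = client.fetch_context(comments[-1])
    client.reply(thread[-1], "example reply body")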

59
clients/runpod.py 100644
@@ -0,0 +1,59 @@
import json
import logging
import time
import requests
from bots.data import config
class RunpodClient:
INSTANCE_UNAVAILABLE_MSG = (
"There are no longer any instances available with the requested specifications. "
"Please refresh and try again."
)
def __init__(self, logger=None):
self.url = f"https://api.runpod.io/graphql?api_key={config['runpod_token']}"
self.logger = logger or logging.getLogger(__name__)
# Return True if a pod with pod_name is currently running.
def is_running(self, pod_name):
fetch_pods_data = {
"query": "query Pods { myself { pods { id name } } }",
}
fetch_response = requests.post(
self.url,
headers={"Content-Type": "application/json"},
data=json.dumps(fetch_pods_data),
)
if not fetch_response.ok:
self.logger.error("Error fetching pods.")
            # Treat a failed fetch as "not running".
            return False
pods = fetch_response.json()["data"]["myself"]["pods"]
return any(pod["name"] == pod_name for pod in pods)
def create_instance(self, json_file):
with open(json_file, "r") as file:
runpod_query = json.load(file)
while True:
response = requests.post(
self.url,
headers={"Content-Type": "application/json"},
json=runpod_query,
)
if "errors" in response.json():
error_message = response.json()["errors"][0]["message"]
if error_message == RunpodClient.INSTANCE_UNAVAILABLE_MSG:
                    self.logger.warning("No instances available, retrying in 1 second...")
time.sleep(1)
else:
                    self.logger.error(f"Unhandled error: {error_message}")
                    break
else:
break
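
RunpodClient drives the same GraphQL mutation stored in the create_runpod.json payloads. A minimal sketch, assuming runpod_token is set in the config; the pod name must match the "name" field of the payload:

from bots.clients.runpod import RunpodClient

client = RunpodClient()

# Boot a pod from a saved deploy payload unless one is already up.
if not client.is_running("bussyboy-run"):
    client.create_instance("/path/to/bussyboy/create_runpod.json")  # hypothetical path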

28
data.py 100644
@@ -0,0 +1,28 @@
import os
import yaml
from sqlitedict import SqliteDict
current_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(current_dir, "config.yaml")
def load_config():
with open(config_path, "r") as f:
return yaml.safe_load(f)
config = load_config()
db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True)
marseygen_queue = SqliteDict(
f"{config['data_dir']}/db.sqlite", tablename="marseygen_queue", autocommit=True
)
bussyboy_queue = SqliteDict(
f"{config['data_dir']}/db.sqlite", tablename="bussyboy_queue", autocommit=True
)
bussyboy_log = SqliteDict(
f"{config['data_dir']}/db.sqlite", tablename="bussyboy_log", autocommit=True
)
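
All four stores share one SQLite file; SqliteDict keeps them apart by tablename, and autocommit=True makes every assignment durable immediately, which is what lets the cron scripts stay stateless between runs. For example:

# Counters live in the default table; queue values are pickled transparently.
from bots.data import db, bussyboy_queue

db["bussyboy_id"] = db.get("bussyboy_id", 0) + 1
bussyboy_queue[str(db["bussyboy_id"])] = {"post": {}, "thread": []}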

26
marseygen/create_runpod.json 100644
@@ -0,0 +1,26 @@
{
"operationName": "Mutation",
"variables": {
"input": {
"deployCost": 1.79,
"cloudType": "SECURE",
"containerDiskInGb": 20,
"volumeInGb": 0,
"dataCenterId": "EU-RO-1",
"gpuCount": 1,
"name": "marseygen-run",
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"minMemoryInGb": 83,
"minVcpuCount": 16,
"networkVolumeId": "0gmsaggsjx",
"startJupyter": false,
"startSsh": true,
"templateId": "runpod-torch",
"cudaVersion": "11.8",
"volumeKey": null,
"ports": "8888/http,22/tcp",
"dockerArgs": "bash -c 'echo \"\n===\n\" >> /workspace/logs/marseygen.txt && /workspace/start-marseygen.sh >> /workspace/logs/marseygen.txt 2>&1'"
}
},
"query": "mutation Mutation($input: PodFindAndDeployOnDemandInput) { podFindAndDeployOnDemand(input: $input) { id machineId __typename } }"
}

66
marseygen/cron.py 100644
@@ -0,0 +1,66 @@
import logging
import re
from rich import traceback
from bots.data import config, db, marseygen_queue
from bots.clients.drama import DramaClient
from bots.clients.runpod import RunpodClient
traceback.install()
# Set up logging
logging.basicConfig(
filename=f"{config['data_dir']}/logs/marseygen.log",
filemode="a",
format="%(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
logger = logging.getLogger("marseygen")
def main():
if "marseygen_id" not in db:
db["marseygen_id"] = 0
if "marseygen_last_processed" not in db:
db["marseygen_last_processed"] = 0
runpod_client = RunpodClient(logger=logger)
# If requests are still in the queue and
# an instance is running, return and let it complete.
    if marseygen_queue and runpod_client.is_running("marseygen-run"):
return
drama_client = DramaClient(config["marseygen"]["token"], logger=logger)
# Fetch new requests and add each to the queue.
comments, newest_id = drama_client.fetch_new_comments(
after=db["marseygen_last_processed"]
)
db["marseygen_last_processed"] = newest_id
# Add new requests to queue.
for comment in comments:
prompts = re.findall(
r"^!sd (.*)$", comment["body"], re.MULTILINE | re.IGNORECASE
)
prompts = prompts[:5]
for prompt in prompts:
prompt = prompt.replace("`", "")
prompt = prompt.replace("marsey", "Marsey")
db["marseygen_id"] += 1
marseygen_queue[str(db["marseygen_id"])] = (comment, prompt)
# Create instance if there are requests to be fulfilled.
if marseygen_queue:
runpod_client.create_instance(
f"{config['data_dir']}/marseygen/create_runpod.json"
)
if __name__ == "__main__":
main()
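
Requests are one !sd command per line, capped at five per comment. A quick check of the extraction regex:

import re

body = "!sd Marsey as an astronaut\nsome text\n!SD another one"
prompts = re.findall(r"^!sd (.*)$", body, re.MULTILINE | re.IGNORECASE)
assert prompts == ["Marsey as an astronaut", "another one"]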

249
bussyboy/maxsubstring.py 100644
@@ -0,0 +1,249 @@
#!/usr/bin/env python
# https://gist.github.com/hynekcer/fa340f3b63826168ffc0c4b33310ae9c
"""Find the longest repeated substring.
"Efficient way to find longest duplicate string for Python (From Programming Pearls)"
http://stackoverflow.com/questions/13560037/
The algorithm is based on "Prefix doubling".
The worst time complexity is O(n (log n)^2). Memory requirements are linear.
"""
import time
from random import randint
import itertools
import sys
import unittest
from itertools import groupby
from operator import itemgetter
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
try:
log.addHandler(logging.NullHandler())
except AttributeError:
pass
def run():
if sys.argv[1:] == ["-"]:
text = sys.stdin.read()
elif sys.argv[1:]:
print("Reading data...")
text = open(sys.argv[1]).read()
else:
text = "banana"
print("Sorting...")
result = longest_common_substring(text)
print('Longest common substrings in "{0}..." are:\n{1}'.format(text[:20], result))
def longest_common_substring(text):
"""Get the longest common substrings and their positions.
>>> longest_common_substring('banana')
{'ana': [1, 3]}
>>> text = "not so Agamemnon, who spoke fiercely to "
>>> sorted(longest_common_substring(text).items())
[(' s', [3, 21]), ('no', [0, 13]), ('o ', [5, 20, 38])]
    This function can be easily modified for other criteria, e.g. for
    searching the ten longest non-overlapping repeated substrings.
"""
sa, rsa, lcp = suffix_array(text)
maxlen = max(lcp)
result = {}
for i in range(1, len(text)):
if lcp[i] == maxlen:
j1, j2, h = sa[i - 1], sa[i], lcp[i]
assert text[j1 : j1 + h] == text[j2 : j2 + h]
substring = text[j1 : j1 + h]
if substring not in result:
result[substring] = [j1]
result[substring].append(j2)
return dict((k, sorted(v)) for k, v in result.items())
def suffix_array(text, _step=16):
"""Analyze all common strings in the text.
    Short substrings of length _step are pre-sorted first. The results are
    then repeatedly merged so that the guaranteed number of compared
    characters (bytes) doubles in every iteration, until all substrings are
    sorted exactly.
Arguments:
text: The text to be analyzed.
_step: Is only for optimization and testing. It is the optimal length
of substrings used for initial pre-sorting. The bigger value is
faster if there is enough memory. Memory requirements are
approximately (estimate for 32 bit Python 3.3):
len(text) * (29 + (_size + 20 if _size > 2 else 0)) + 1MB
Return value: (tuple)
(sa, rsa, lcp)
sa: Suffix array for i in range(1, size):
assert text[sa[i-1]:] < text[sa[i]:]
rsa: Reverse suffix array for i in range(size):
assert rsa[sa[i]] == i
lcp: Longest common prefix for i in range(1, size):
assert text[sa[i-1]:sa[i-1]+lcp[i]] == text[sa[i]:sa[i]+lcp[i]]
if sa[i-1] + lcp[i] < len(text):
assert text[sa[i-1] + lcp[i]] < text[sa[i] + lcp[i]]
>>> suffix_array(text='banana')
([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
Explanation: 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
The Longest Common String is 'ana': lcp[2] == 3 == len('ana')
It is between tx[sa[1]:] == 'ana' < 'anana' == tx[sa[2]:]
"""
tx = text
t0 = time.time()
size = len(tx)
step = min(max(_step, 1), len(tx))
sa = list(range(len(tx)))
log.debug("%6.3f pre sort", time.time() - t0)
sa.sort(key=lambda i: tx[i : i + step])
log.debug("%6.3f after sort", time.time() - t0)
grpstart = size * [False] + [True] # a boolean map for iteration speedup.
# It helps to skip yet resolved values. The last value True is a sentinel.
rsa = size * [None]
stgrp, igrp = "", 0
for i, pos in enumerate(sa):
st = tx[pos : pos + step]
if st != stgrp:
grpstart[igrp] = igrp < i - 1
stgrp = st
igrp = i
rsa[pos] = igrp
sa[i] = pos
grpstart[igrp] = igrp < size - 1 or size == 0
log.debug("%6.3f after group", time.time() - t0)
while grpstart.index(True) < size:
# assert step <= size
nmerge = 0
nextgr = grpstart.index(True)
while nextgr < size:
igrp = nextgr
nextgr = grpstart.index(True, igrp + 1)
glist = []
for ig in range(igrp, nextgr):
pos = sa[ig]
if rsa[pos] != igrp:
break
newgr = rsa[pos + step] if pos + step < size else -1
glist.append((newgr, pos))
glist.sort()
for ig, g in groupby(glist, key=itemgetter(0)):
g = [x[1] for x in g]
sa[igrp : igrp + len(g)] = g
grpstart[igrp] = len(g) > 1
for pos in g:
rsa[pos] = igrp
igrp += len(g)
nmerge += len(glist)
log.debug("%6.3f for step=%d nmerge=%d", time.time() - t0, step, nmerge)
step *= 2
del grpstart
# create LCP array
lcp = size * [None]
h = 0
for i in range(size):
if rsa[i] > 0:
j = sa[rsa[i] - 1]
while i != size - h and j != size - h and tx[i + h] == tx[j + h]:
h += 1
lcp[rsa[i]] = h
if h > 0:
h -= 1
if size > 0:
lcp[0] = 0
log.debug("%6.3f end", time.time() - t0)
return sa, rsa, lcp
# ---
class TestMixin(object):
def suffix_verify(self, text, step=16):
tx = text
sa, rsa, lcp = suffix_array(text=tx, _step=step)
self.assertEqual(set(sa), set(range(len(tx))))
ok = True
for i0, i1, h in zip(sa[:-1], sa[1:], lcp[1:]):
self.assertEqual(
tx[i1 : i1 + h],
tx[i0 : i0 + h],
"Verify LCP characters equal on text '%s...'" % text[:20],
)
self.assertGreater(
tx[i1 + h : i1 + h + 1],
tx[i0 + h : i0 + h + 1],
"Verify LCP+1 char is different '%s...'" % text[:20],
)
self.assertLessEqual(
max(i0, i1),
len(tx) - h,
"Verify LCP is not more than length of string '%s...'" % text[:20],
)
self.assertTrue(ok)
class SuffixArrayTest(unittest.TestCase, TestMixin):
def test_16(self):
# 'a' < 'ana' < 'anana' < 'banana' < 'na' < 'nana'
expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
self.assertEqual(suffix_array(text="banana", _step=16), expect)
def test_1(self):
expect = ([5, 3, 1, 0, 4, 2], [3, 2, 5, 1, 4, 0], [0, 1, 3, 0, 0, 2])
self.assertEqual(suffix_array(text="banana", _step=1), expect)
def test_mini(self):
self.assertEqual(suffix_array(text="", _step=1), ([], [], []))
self.assertEqual(suffix_array(text="a", _step=1), ([0], [0], [0]))
self.assertEqual(suffix_array(text="aa", _step=1), ([1, 0], [1, 0], [0, 1]))
self.assertEqual(
suffix_array(text="aaa", _step=1), ([2, 1, 0], [2, 1, 0], [0, 1, 2])
)
def test_example(self):
self.suffix_verify("abracadabra")
def test_cartesian(self):
"""Test all combinations of alphabet "ABC" up to length 4 characters"""
for size in range(7):
for cartesian in itertools.product(*(size * ["ABC"])):
text = "".join(cartesian)
log.debug('Testing "%s"', text)
self.suffix_verify(text, 1)
def test_lcp(self):
expect = {"ana": [1, 3]}
self.assertDictEqual(longest_common_substring("banana"), expect)
expect = {" s": [3, 21], "no": [0, 13], "o ": [5, 20, 38]}
self.assertDictEqual(
longest_common_substring("not so Agamemnon, who spoke fiercely to "), expect
)
class SlowTests(unittest.TestCase, TestMixin):
"""Slow development tests running many minutes.
It can be run only by an EXPLICIT command!
e.g.: python -m unittest maxsubstring.SlowTests._test_random
"""
def _test_random(self):
for power in range(2, 21, 2):
size = randint(2 ** (power - 1), 2**power)
for alphabet in (2, 4, 16, 256):
text = "".join(chr(65 + randint(0, alphabet - 1)) for _ in range(size))
log.debug("%s %s %s", size, alphabet, 1)
self.suffix_verify(text, 1)
log.debug("%s %s %s", size, alphabet, 16)
self.suffix_verify(text, 16)
if __name__ == "__main__":
run()

74
bussyboy/model.py 100644
@@ -0,0 +1,74 @@
import torch
from transformers import (
AutoModelForCausalLM,
BitsAndBytesConfig,
GPTNeoXTokenizerFast,
LogitsProcessor,
LogitsProcessorList,
)
class StopAfterPlusIsGenerated(LogitsProcessor):
def __init__(self, plus_token_id, eos_token_id):
super().__init__()
self.plus_token_id = plus_token_id
self.eos_token_id = eos_token_id
def __call__(self, input_ids, scores):
forced_eos = torch.full((scores.size(1),), -float("inf")).to(
device=scores.device, dtype=scores.dtype
)
forced_eos[self.eos_token_id] = 0
scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
return scores
class Model:
def __init__(self):
name = "/workspace/models/mpt-30b-drama"
self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(
name, pad_token="<|endoftext|>"
)
# model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
# model_config.attn_config["attn_impl"] = "triton"
# model_config.init_device = "cuda:0"
# model_config.eos_token_id = self.tokenizer.eos_token_id
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
self.model = AutoModelForCausalLM.from_pretrained(
name,
device_map="auto",
quantization_config=quantization_config,
trust_remote_code=True,
)
self.logits_processor = LogitsProcessorList(
[StopAfterPlusIsGenerated(559, self.tokenizer.eos_token_id)]
)
def generate(self, prompt):
encoded = self.tokenizer(
prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
gen_tokens = self.model.generate(
input_ids=encoded.input_ids,
attention_mask=encoded.attention_mask,
pad_token_id=0,
eos_token_id=self.tokenizer.eos_token_id,
do_sample=True,
temperature=0.90,
use_cache=True,
max_new_tokens=512,
logits_processor=self.logits_processor,
)
return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
len(prompt) :
]
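
A usage sketch for Model; it assumes a CUDA GPU and merged weights already present at /workspace/models/mpt-30b-drama on the pod. Generation ends after 512 new tokens, or as soon as token 559 (presumably "+", the vote separator in the prompt format) is emitted, which cuts off hallucinated next-comment headers:

# Hypothetical invocation on the GPU pod; real prompts come from
# utils.build_prompt (post header followed by the indented comment thread).
model = Model()
reply = model.generate("[Post] [Author] carp [Title] hi [URL] N/A [Hole] N/A [Votes] +71 / -0\n\n[Comments]\n")
print(reply)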

@@ -0,0 +1,67 @@
import os
import requests
import utils
from model import Model
from rich import traceback
traceback.install()
def generate(model, config, context):
print(
f"Generating reply for https://rdrama.net/comment/{context['thread'][-1]['id']}"
)
prompt = utils.build_prompt(config, context["post"], context["thread"])
candidates = []
rejects = []
while len(candidates) < config["num_candidates"]:
gen_text = model.generate(prompt)
reply = utils.extract_reply(gen_text)
print(f"Generated text: {gen_text}\nReply:\n{reply}")
reply = utils.format_reply(config, reply)
if len(reply) == 0:
print("Retrying: reply empty after processing.")
rejects.append(reply)
elif utils.is_low_quality(reply, context["post"], context["thread"]):
print("Retrying: low quality reply.")
rejects.append(reply)
else:
candidates.append(reply)
print("Accepting reply.")
if any(not utils.contains_url(c) for c in candidates):
for candidate in candidates:
if utils.contains_url(candidate):
rejects.append(candidate)
candidates = [c for c in candidates if not utils.contains_url(c)]
body = utils.median_by_key(candidates, key=utils.reply_length)
return {"prompt": prompt, "body": body, "candidates": candidates + rejects}
def process_queue():
response = requests.get("https://rdra.ma/bussyboy").json()
config = response["config"]
queue = response["queue"]
model = Model()
for request_id, context in queue.items():
json_data = generate(model, config, context)
post_response = requests.post(
f"https://rdra.ma/bussyboy/{request_id}?key={os.getenv('SERVER_KEY')}",
json=json_data,
headers={"Content-Type": "application/json"},
)
print(f"Response for request {request_id}: {post_response.status_code}")
if __name__ == "__main__":
process_queue()

206
bussyboy/utils.py 100644
@@ -0,0 +1,206 @@
import random
import re
from fuzzywuzzy import fuzz
from transformers import GPTNeoXTokenizerFast
from maxsubstring import longest_common_substring
URL_REGEX = (
r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
)
tokenizer = GPTNeoXTokenizerFast.from_pretrained("float-trip/mpt-30b-drama")
def remove_notifications(text):
"""Change @float-trip to <span>@</span>float-trip and carp to <span>c</span>arp."""
text = text.replace("@", "<span>@</span>")
notified_users = [
"aevan",
"avean",
" capy",
"the rodent",
"carp",
"clit",
"snakes",
"sneks",
"snekky",
"snekchad",
"jc",
"justcool",
"lawlz",
"transgirltradwife",
"impassionata",
"pizzashill",
"idio3",
"idio ",
"telegram ",
"schizo",
"joan",
"pewkie",
"homocracy",
"donger",
"geese",
"soren",
"marseyismywaifu",
"mimw",
"heymoon",
"gaypoon",
"jollymoon",
"chiobu",
"mccox",
"august",
"marco",
"klen",
]
def replace(match):
# Insert <span></span> around the first character of the matched string.
user = match.group()
return f"<span>{user[:1]}</span>{user[1:]}"
for user in notified_users:
text = re.sub(user, replace, text, flags=re.IGNORECASE)
return text
def format_reply(config, text):
for username in config["fake_usernames"]:
        text = text.replace(username, config["username"])
text = replace_rdrama_images(text)
text = remove_notifications(text)
return text.strip()
def is_low_quality(reply, _post, comments):
"""
Label the reply as low quality if:
    - Its fuzz (Levenshtein) ratio to a previous comment in the thread exceeds 90.
    - len(longest_common_substring) >= 100
    - After removing links, Markdown images, and quoted text, the length is < 10.
    - It contains a group ping (the account lost pinging rights).
"""
for comment in comments:
if fuzz.ratio(reply, comment["body"]) > 90:
return True
lcs = list(longest_common_substring(reply).keys())[0]
if len(lcs) >= 100:
return True
if reply_length(reply) < 10:
return True
# Lost pinging rights.
if re.findall(r"!\w+", reply):
return True
return False
def contains_url(text):
return re.search(URL_REGEX, text) is not None
def replace_rdrama_images(text):
"""Replace images pointing to rdrama.net with a loading image."""
loading = "https://i.rdrama.net/i/l.webp"
webp_pattern = r"https://\S*\.rdrama\.net/\S*\.webp"
md_img_pattern = r"!\[[^\]]*\]\((https://\S*\.rdrama\.net)?/\S*\)"
text = re.sub(webp_pattern, loading, text)
text = re.sub(md_img_pattern, f"![]({loading})", text)
return text
def normalize_emojis(s):
"""Bring # and ! to the front of an emoji."""
def repl(match):
# Extract the word between colons and the special characters.
word = match.group(0)
specials = set(re.findall(r"[#!]", word))
# Sort specials and append the word without specials.
new_emoji = "".join(sorted(specials, reverse=True)) + re.sub(r"[#!]", "", word)
return new_emoji
emoji_pattern = r"(?<=:)[a-zA-Z@#!]*[#!][a-zA-Z@#!]*(?=:)"
s = re.sub(emoji_pattern, repl, s)
return s
def build_prompt(config, post, comments):
prompt = (
f"[Post] [Author] {post['author_name']} "
f"[Title] {post['title']} [URL] {post['url']} "
f"[Hole] {post['sub'] or 'N/A'} [Votes] +71 / -0\n\n"
f"{post['body']}\n\n[Comments]"
)
comments.append({"author_name": config["username"], "body": ""})
for depth, comment in enumerate(comments):
body = normalize_emojis(comment["body"])
author = comment["author_name"]
comment_str = f"\n\n{author} +45 / -0\n{body}"
indent = depth * " "
comment_str = "\n".join([indent + line for line in comment_str.split("\n")])
prompt += comment_str
prompt = prompt.replace(config["username"], random.choice(config["fake_usernames"]))
prompt = prompt.replace("👻", "Ghost")
prompt = prompt.strip() + "\n"
# Truncate the prompt to leave room for generation.
tokens = tokenizer.tokenize(prompt)
if len(tokens) > config["prompt_token_limit"]:
tokens = tokens[-config["prompt_token_limit"] :]
prompt = tokenizer.convert_tokens_to_string(tokens)
return prompt
def reply_length(reply):
"""Return the length of the reply, without Markdown images, URLs, or quoted text."""
# Remove Markdown images and URLs.
reply = re.sub(r"!\[.*?\]\(.*?\)", "", reply)
reply = re.sub(URL_REGEX, "", reply)
# Remove quoted text.
lines = reply.splitlines()
lines = [line for line in lines if not line.lstrip().startswith((">", "\\>"))]
reply = "\n".join(lines).strip()
return len(reply)
def median_by_key(lst, key):
lst = sorted(lst, key=key)
mid_index = len(lst) // 2
# For lists of even length, pick either option as the median.
if len(lst) % 2 == 0:
return random.choice([lst[mid_index - 1], lst[mid_index]])
else:
return lst[mid_index]
def count_tokens(text):
return len(tokenizer(text).input_ids)
def extract_reply(text):
"""
Generated text will either:
- Be cut off at the token limit
- End with the start of a new comment: `float-trip +10`
For the latter case, drop the last line.
"""
new_comment_pattern = r"^ *[\w-]* +\+.*$"
lines = text.split("\n")
if re.match(new_comment_pattern, lines[-1]):
lines = lines[:-1]
return "\n".join([line.strip() for line in lines]).strip()
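
Two of the smaller helpers illustrated; the expected values are hand-derived from the regexes above:

# normalize_emojis pulls '#' and '!' to the front of an emoji code.
assert normalize_emojis(":marsey#love!:") == ":#!marseylove:"

# extract_reply drops a trailing, partially generated comment header.
assert extract_reply("Nice take.\n  float-trip +10 / -0") == "Nice take."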

@@ -0,0 +1,72 @@
import os
from io import BytesIO
import requests
import torch
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline
from PIL import Image
from rich import traceback
traceback.install()
scheduler = EulerAncestralDiscreteScheduler()
pipe = StableDiffusionXLPipeline.from_single_file(
"/workspace/models/marsey-xl.safetensors",
torch_dtype=torch.bfloat16,
safety_checker=None,
scheduler=scheduler,
)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
def generate(prompt):
images = []
for _ in range(3):
batch = pipe(prompt=[prompt] * 3, guidance_scale=7.0, num_inference_steps=30)
images.extend(batch.images)
return images
def image_grid(images, rows, cols):
assert len(images) == rows * cols
width = max(image.width for image in images)
height = max(image.height for image in images)
grid = Image.new("RGB", size=(width * cols, height * rows))
for index, image in enumerate(images):
row = index // cols
col = index % cols
grid.paste(image, (col * width, row * height))
return grid
def process_queue():
queue_json = requests.get("https://rdra.ma/marseygen").json()
for request_id, prompt in queue_json.items():
images = generate(prompt)
grid_image = image_grid(images, 3, 3)
# Save the image to a BytesIO object.
buffered = BytesIO()
grid_image.save(buffered, format="WEBP")
# Reset the buffer position to the beginning.
buffered.seek(0)
# Post the image to the server.
response = requests.post(
f"https://rdra.ma/marseygen/{request_id}?key={os.getenv('SERVER_KEY')}",
files={"file": ("image.webp", buffered, "image/webp")},
)
print(f"Response for request {request_id}: {response.status_code}")
if __name__ == "__main__":
process_queue()

0
web/__init__.py 100644
148
web/main.py 100644
@@ -0,0 +1,148 @@
import re
import logging
from io import BytesIO
from typing import List
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from bots.data import config, marseygen_queue, bussyboy_queue, bussyboy_log
from bots.clients.drama import DramaClient
app = FastAPI()
templates = Jinja2Templates(directory="templates")
# Set up logging.
logging.basicConfig(
filename=f"{config['data_dir']}/logs/web.log",
filemode="a",
format="%(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
@app.get("/")
def read_root():
return RedirectResponse(url="https://rdrama.net")
@app.get(
"/h/{_hole}/post/{_post_id}/{_post_slug}/{comment_id}", response_class=HTMLResponse
)
async def get_alt_comments_with_hole(
request: Request, _hole: str, _post_id: int, _post_slug: str, comment_id: str
):
return _get_alt_comments(request, comment_id)
@app.get("/post/{_post_id}/{_post_slug}/{comment_id}", response_class=HTMLResponse)
async def get_alt_comments_without_hole(
request: Request, _post_id: int, _post_slug: str, comment_id: str
):
return _get_alt_comments(request, comment_id)
def _get_alt_comments(request, comment_id):
logger.info(f"Request received for {comment_id}")
if comment_id not in bussyboy_log:
return templates.TemplateResponse("404.html", {"request": request})
log = bussyboy_log[comment_id]
return templates.TemplateResponse(
"comment.html",
{
"request": request,
"parent_body": log["parent_body"],
"candidates": log["candidates"],
},
)
@app.get("/bussyboy")
async def bussyboy():
config_keys = ["prompt_token_limit", "num_candidates", "username", "fake_usernames"]
bussyboy_config = {key: config["bussyboy"][key] for key in config_keys}
return {"config": bussyboy_config, "queue": bussyboy_queue}
class BussyboyReplyInfo(BaseModel):
prompt: str
body: str
candidates: List[str]
@app.post("/bussyboy/{request_id}")
async def bussyboy_submit(
request: Request, request_id: str, key: str, info: BussyboyReplyInfo
):
if key != config["server_key"]:
return Response(status_code=400)
if request_id not in bussyboy_queue:
        logger.error(f"Unknown request ID: {request_id}")
raise HTTPException(status_code=500)
    # Post reply. `entry` avoids shadowing the `request: Request` parameter.
    entry = bussyboy_queue[request_id]
    drama_client = DramaClient(config["bussyboy"]["token"], logger=logger)
    reply_response = drama_client.reply(entry["thread"][-1], info.body)
    del bussyboy_queue[request_id]
    log = {
        # Log thread context.
        "parent_body": entry["thread"][-1]["body"],
        "post": entry["post"],
        "thread": entry["thread"],
# Log generation info.
"prompt": info.prompt,
"body": info.body,
"candidates": info.candidates,
}
bussyboy_log[str(reply_response["id"])] = log
return Response(status_code=200)
@app.get("/marseygen")
async def marseygen():
return {key: value[1] for key, value in marseygen_queue.items()}
@app.post("/marseygen/{request_id}")
async def marseygen_submit(request_id: str, key: str, file: UploadFile):
if key != config["server_key"]:
return Response(status_code=400)
if request_id not in marseygen_queue:
        logger.error(f"Unknown request ID: {request_id}")
raise HTTPException(status_code=500)
comment, prompt = marseygen_queue[request_id]
del marseygen_queue[request_id]
contents = await file.read()
image_bytes = BytesIO(contents)
# Reset the stream to the start.
image_bytes.seek(0)
# No pinging or gambling for Marseygen.
prompt = re.sub(r"(!)(\w+)", r"<span>\1</span>\2", prompt)
# Post reply.
image = {"file": ("image.webp", image_bytes, "image/webp")}
drama_client = DramaClient(config["marseygen"]["token"], logger=logger)
drama_client.reply(comment, f"`{prompt}`<br>[image.webp]", image)
return Response(status_code=200)
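
web/main.py exposes a standard ASGI app, so any ASGI server can host it. A hypothetical dev entry point, assuming the repo root is on PYTHONPATH and that TLS for rdra.ma is terminated by a proxy in front:

import uvicorn

uvicorn.run("bots.web.main:app", host="127.0.0.1", port=8000)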

14
templates/404.html 100644
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Bot comments - 404</title>
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
</head>
<body class="bg-gray-800 text-gray-200 flex justify-center items-center min-h-screen">
<div class="w-full max-w-xl bg-gray-700 rounded-xl shadow-lg p-6 space-y-6">
<div class="text-2xl font-bold mb-4">Comment ID not in the database. 404.</div>
</div>
</body>
</html>

23
templates/comment.html 100644
@@ -0,0 +1,23 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Bot comments</title>
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
</head>
<body class="bg-gray-800 text-gray-200 flex justify-center items-center min-h-screen mt-10 mb-10">
<div class="w-full max-w-xl bg-gray-700 rounded-xl shadow-lg p-6 space-y-6">
<div class="text-base font-bold mb-4">{{ parent_body }}</div>
<div class="space-y-4">
{% for candidate in candidates %}
<div class="bg-gray-600 rounded-md shadow p-4">
<p class="text-sm overflow-auto break-words">
{{ candidate | replace("\n", "<br/>") | safe }}
</p>
</div>
{% endfor %}
</div>
</div>
</body>
</html>