master
commit
e813c12852
|
@ -0,0 +1,17 @@
|
||||||
|
beat_schedule = {
|
||||||
|
"find-prompts": {"task": "tasks.find_prompts", "schedule": 60.0, "args": ()},
|
||||||
|
}
|
||||||
|
|
||||||
|
task_annotations = {
|
||||||
|
"tasks.post_reply": {"rate_limit": "5/m"},
|
||||||
|
"tasks.find_prompts": {"rate_limit": "1/m"},
|
||||||
|
}
|
||||||
|
|
||||||
|
task_default_queue = "api"
|
||||||
|
|
||||||
|
broker_url = ""
|
||||||
|
result_backend = ""
|
||||||
|
|
||||||
|
task_routes = {"tasks.generate_reply": {"queue": "gen"}}
|
||||||
|
|
||||||
|
worker_prefetch_multiplier = 1
|
|
@ -0,0 +1,89 @@
|
||||||
|
import requests
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class DramaClient:
|
||||||
|
BASE_URL = "https://rdrama.net"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.token = os.environ.get("RDRAMA_TOKEN", "")
|
||||||
|
self.last_processed_id = 2821161 # Most recent comment seen.
|
||||||
|
|
||||||
|
def get(self, endpoint):
|
||||||
|
print(endpoint)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
r = requests.get(
|
||||||
|
f"{self.BASE_URL}{endpoint}", headers={"Authorization": self.token}
|
||||||
|
)
|
||||||
|
|
||||||
|
if r.status_code != 200:
|
||||||
|
print("Error!", r, r.status_code, r.content)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
return r.json()["data"]
|
||||||
|
|
||||||
|
def post(self, endpoint, payload, files=[]):
|
||||||
|
print(endpoint)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
r = requests.post(
|
||||||
|
f"{self.BASE_URL}{endpoint}",
|
||||||
|
payload,
|
||||||
|
headers={"Authorization": self.token},
|
||||||
|
files=files,
|
||||||
|
)
|
||||||
|
|
||||||
|
if r.status_code != 200:
|
||||||
|
print("Error!", r, r.status_code, r.content)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
def fetch_new_comments(self):
|
||||||
|
comments = []
|
||||||
|
if self.last_processed_id is None:
|
||||||
|
comments += self.fetch_page(1)
|
||||||
|
else:
|
||||||
|
earliest_id = None
|
||||||
|
page = 1
|
||||||
|
# Fetch comments until we find the last one processed.
|
||||||
|
while earliest_id is None or earliest_id > self.last_processed_id:
|
||||||
|
page_comments = self.fetch_page(page)
|
||||||
|
earliest_id = min([c["id"] for c in page_comments])
|
||||||
|
comments += [
|
||||||
|
c for c in page_comments if c["id"] > self.last_processed_id
|
||||||
|
]
|
||||||
|
page += 1
|
||||||
|
|
||||||
|
if not comments:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.last_processed_id = max(c["id"] for c in comments)
|
||||||
|
|
||||||
|
# New comments may have pushed others to page n+1 while fetching.
|
||||||
|
deduped_comments = {c["id"]: c for c in comments}.values()
|
||||||
|
|
||||||
|
# Oldest first.
|
||||||
|
comments.reverse()
|
||||||
|
|
||||||
|
return comments
|
||||||
|
|
||||||
|
def fetch_page(self, page):
|
||||||
|
return self.get(f"/comments?page={page}")
|
||||||
|
|
||||||
|
def reply(self, parent_fullname, submission, body, image_path=None):
|
||||||
|
payload = {
|
||||||
|
"parent_fullname": parent_fullname,
|
||||||
|
"submission": submission,
|
||||||
|
"body": body,
|
||||||
|
}
|
||||||
|
|
||||||
|
files = []
|
||||||
|
if image_path:
|
||||||
|
filename = image_path.split("/")[-1]
|
||||||
|
files = {"file": (filename, open(image_path, "rb"), "image/webp")}
|
||||||
|
|
||||||
|
self.post("/comment", payload, files=files)
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Marseygen
|
||||||
|
|
||||||
|
Stable Diffusion bot with distributed inference.
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
|
||||||
|
* Set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) on the gen workers and activate the `ldm` environment
|
||||||
|
|
||||||
|
* Install rabbitmq and redis, add URLs to `celeryconfig.py`
|
||||||
|
|
||||||
|
* `git clone https://github.com/float-trip/marseygen`
|
||||||
|
|
||||||
|
* `pip install -r marseygen/requirements.txt`
|
||||||
|
|
||||||
|
* `mv marseygen/*.py InvokeAI && cd InvokeAI`
|
||||||
|
* Running the gen workers from this dir circumvents some Python import issues that I don't care to figure out right now
|
||||||
|
|
||||||
|
* Start the API worker
|
||||||
|
|
||||||
|
`celery -A tasks worker -B --concurrency 1 --loglevel=INFO`
|
||||||
|
|
||||||
|
* Start a gen worker for each GPU
|
||||||
|
|
||||||
|
```sh
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,
|
||||||
|
export WORKER_HOST="user@gen_worker_ip"
|
||||||
|
export WORKER_SSH_PORT="22"
|
||||||
|
export WORKER_ID="unique_id"
|
||||||
|
celery -A tasks worker -Q gen -n unique_name -B --concurrency 1 --loglevel=INFO`
|
||||||
|
```
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
celery==5.2.7
|
||||||
|
celery_singleton==0.3.1
|
||||||
|
Pillow==9.2.0
|
||||||
|
pry.py==0.1.1
|
||||||
|
requests==2.27.1
|
|
@ -0,0 +1,140 @@
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import random
|
||||||
|
from celery import Celery, Task, chain
|
||||||
|
import time
|
||||||
|
import celeryconfig
|
||||||
|
|
||||||
|
from client import DramaClient
|
||||||
|
from utils import concat_images
|
||||||
|
|
||||||
|
app = Celery("tasks")
|
||||||
|
app.config_from_object(celeryconfig)
|
||||||
|
|
||||||
|
client = DramaClient()
|
||||||
|
|
||||||
|
generator = None
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# API worker tasks
|
||||||
|
#
|
||||||
|
@app.task
|
||||||
|
def post_reply(context):
|
||||||
|
basename = os.path.basename(context["image_path"])
|
||||||
|
save_path = f"/fs/marseys/{basename}"
|
||||||
|
|
||||||
|
print(f"Copying {basename}")
|
||||||
|
|
||||||
|
# Copy image from remote machine.
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"rsync",
|
||||||
|
"-a",
|
||||||
|
f"{context['worker_host']}:{context['image_path']}",
|
||||||
|
save_path,
|
||||||
|
"-e",
|
||||||
|
f"ssh -p {context['worker_ssh_port']}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Replying for prompt {context['prompt']}")
|
||||||
|
|
||||||
|
client.reply(
|
||||||
|
context["parent_fullname"],
|
||||||
|
context["submission"],
|
||||||
|
f"`{context['prompt']}`",
|
||||||
|
save_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FindPromptsTask(Task):
|
||||||
|
last_call = None
|
||||||
|
|
||||||
|
# Temp fix for comments being replied to multiple times.
|
||||||
|
queued_ids = set()
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(base=FindPromptsTask)
|
||||||
|
def find_prompts():
|
||||||
|
if find_prompts.last_call is not None and time.time() - find_prompts.last_call < 60:
|
||||||
|
return
|
||||||
|
|
||||||
|
find_prompts.last_call = time.time()
|
||||||
|
|
||||||
|
print("Looking for prompts.")
|
||||||
|
comments = client.fetch_new_comments()
|
||||||
|
|
||||||
|
for comment in comments:
|
||||||
|
if comment["id"] in find_prompts.queued_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
find_prompts.queued_ids.add(comment["id"])
|
||||||
|
|
||||||
|
reply_contexts = [
|
||||||
|
{
|
||||||
|
"parent_fullname": f"c_{comment['id']}",
|
||||||
|
"submission": comment["post_id"],
|
||||||
|
"prompt": line[4:],
|
||||||
|
}
|
||||||
|
for line in comment["body"].split("\n")
|
||||||
|
if line.startswith("!sd ")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Max 5 prompts per comment.
|
||||||
|
reply_contexts = reply_contexts[:5]
|
||||||
|
|
||||||
|
for context in reply_contexts:
|
||||||
|
print(f"Queueing prompt `{context['prompt']}`.")
|
||||||
|
chain(
|
||||||
|
generate_reply.s(context).set(queue="gen"),
|
||||||
|
post_reply.s().set(queue="api"),
|
||||||
|
).apply_async()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Generation worker tasks
|
||||||
|
#
|
||||||
|
class GenTask(Task):
|
||||||
|
_generator = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def generator(self):
|
||||||
|
if self._generator is None:
|
||||||
|
from ldm.generate import Generate
|
||||||
|
|
||||||
|
self._generator = Generate(sampler_name="k_euler_a")
|
||||||
|
self._generator.load_model()
|
||||||
|
|
||||||
|
print("Model loaded.")
|
||||||
|
|
||||||
|
return self._generator
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(base=GenTask)
|
||||||
|
def generate_reply(context):
|
||||||
|
print(f"Generating `{context['prompt']}`.")
|
||||||
|
|
||||||
|
if not os.path.exists("out"):
|
||||||
|
os.makedirs("out")
|
||||||
|
|
||||||
|
results = generate_reply.generator.prompt2png(
|
||||||
|
context["prompt"], outdir=f"out/{os.environ['WORKER_ID']}", iterations=9
|
||||||
|
)
|
||||||
|
|
||||||
|
image_paths = [r[0] for r in results]
|
||||||
|
grid = concat_images(image_paths, size=(512, 512), shape=(3, 3))
|
||||||
|
|
||||||
|
grid_basename = f"{random.randrange(10**6, 10**7)}.webp"
|
||||||
|
|
||||||
|
if not os.path.exists("grid"):
|
||||||
|
os.makedirs("grid")
|
||||||
|
|
||||||
|
grid_path = f"grid/{grid_basename}"
|
||||||
|
grid.save(grid_path, "WEBP")
|
||||||
|
|
||||||
|
context["image_path"] = os.path.abspath(grid_path)
|
||||||
|
context["worker_host"] = os.environ["WORKER_HOST"]
|
||||||
|
context["worker_ssh_port"] = os.environ["WORKER_SSH_PORT"]
|
||||||
|
|
||||||
|
return context
|
|
@ -0,0 +1,24 @@
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
|
#
|
||||||
|
# https://gist.github.com/njanakiev/1932e0a450df6d121c05069d5f7d7d6f
|
||||||
|
#
|
||||||
|
def concat_images(image_paths, size, shape=None):
|
||||||
|
# Open images and resize them
|
||||||
|
width, height = size
|
||||||
|
images = map(Image.open, image_paths)
|
||||||
|
images = [ImageOps.fit(image, size, Image.ANTIALIAS) for image in images]
|
||||||
|
|
||||||
|
# Create canvas for the final image with total size
|
||||||
|
shape = shape if shape else (1, len(images))
|
||||||
|
image_size = (width * shape[1], height * shape[0])
|
||||||
|
image = Image.new("RGB", image_size)
|
||||||
|
|
||||||
|
# Paste images into final image
|
||||||
|
for row in range(shape[0]):
|
||||||
|
for col in range(shape[1]):
|
||||||
|
offset = width * col, height * row
|
||||||
|
idx = row * shape[1] + col
|
||||||
|
image.paste(images[idx], offset)
|
||||||
|
|
||||||
|
return image
|
Loading…
Reference in New Issue