commit e813c12852228390a9f44fe4571f616c8e86d471
Author: float-trip <102226344+float-trip@users.noreply.github.com>
Date:   Sat Oct 1 15:10:01 2022 +0000

    .

diff --git a/celeryconfig.py b/celeryconfig.py
new file mode 100644
index 0000000..99e7289
--- /dev/null
+++ b/celeryconfig.py
@@ -0,0 +1,17 @@
+beat_schedule = {
+    "find-prompts": {"task": "tasks.find_prompts", "schedule": 60.0, "args": ()},
+}
+
+task_annotations = {
+    "tasks.post_reply": {"rate_limit": "5/m"},
+    "tasks.find_prompts": {"rate_limit": "1/m"},
+}
+
+task_default_queue = "api"
+
+broker_url = ""  # RabbitMQ URL; see readme.
+result_backend = ""  # Redis URL; see readme.
+
+task_routes = {"tasks.generate_reply": {"queue": "gen"}}
+
+worker_prefetch_multiplier = 1
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..7f171e6
--- /dev/null
+++ b/client.py
@@ -0,0 +1,89 @@
+import requests
+import sys
+import os
+import time
+
+
+class DramaClient:
+    BASE_URL = "https://rdrama.net"
+
+    def __init__(self):
+        self.token = os.environ.get("RDRAMA_TOKEN", "")
+        self.last_processed_id = 2821161  # Most recent comment seen.
+
+    def get(self, endpoint):
+        print(endpoint)
+        time.sleep(5)
+
+        r = requests.get(
+            f"{self.BASE_URL}{endpoint}", headers={"Authorization": self.token}
+        )
+
+        if r.status_code != 200:
+            print("Error!", r, r.status_code, r.content)
+            sys.exit(1)
+
+        return r.json()["data"]
+
+    def post(self, endpoint, payload, files=None):
+        print(endpoint)
+        time.sleep(5)
+
+        r = requests.post(
+            f"{self.BASE_URL}{endpoint}",
+            payload,
+            headers={"Authorization": self.token},
+            files=files,
+        )
+
+        if r.status_code != 200:
+            print("Error!", r, r.status_code, r.content)
+            sys.exit(1)
+
+        return r.json()
+
+    def fetch_new_comments(self):
+        comments = []
+        if self.last_processed_id is None:
+            comments += self.fetch_page(1)
+        else:
+            earliest_id = None
+            page = 1
+            # Fetch comments until we find the last one processed.
+            while earliest_id is None or earliest_id > self.last_processed_id:
+                page_comments = self.fetch_page(page)
+                earliest_id = min(c["id"] for c in page_comments)
+                comments += [
+                    c for c in page_comments if c["id"] > self.last_processed_id
+                ]
+                page += 1
+
+        if not comments:
+            return []
+
+        self.last_processed_id = max(c["id"] for c in comments)
+
+        # New comments may have pushed others onto page n+1 while fetching,
+        # so drop any duplicates.
+        deduped_comments = {c["id"]: c for c in comments}.values()
+
+        # Oldest first.
+        return sorted(deduped_comments, key=lambda c: c["id"])
+
+    def fetch_page(self, page):
+        return self.get(f"/comments?page={page}")
+
+    def reply(self, parent_fullname, submission, body, image_path=None):
+        payload = {
+            "parent_fullname": parent_fullname,
+            "submission": submission,
+            "body": body,
+        }
+
+        if image_path:
+            filename = os.path.basename(image_path)
+            with open(image_path, "rb") as f:
+                files = {"file": (filename, f, "image/webp")}
+                self.post("/comment", payload, files=files)
+        else:
+            self.post("/comment", payload)
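For context, `client.py` above is the bot's whole rDrama API surface. A minimal sketch of driving it by hand, outside of Celery — this loop is hypothetical and not part of the commit, and it assumes `RDRAMA_TOKEN` is set in the environment:

```python
# Hypothetical driver for DramaClient; not part of this commit.
from client import DramaClient

client = DramaClient()  # reads RDRAMA_TOKEN from the environment

# fetch_new_comments() pages backwards until it passes last_processed_id,
# dedupes, and returns the comments oldest-first.
for comment in client.fetch_new_comments():
    print(comment["id"], comment["body"][:80])

# reply() posts a comment, optionally attaching a webp image:
# client.reply("c_2821161", comment["post_id"], "`some prompt`", "grid/1234567.webp")
```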
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..d351107
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,31 @@
+# Marseygen
+
+Stable Diffusion bot with distributed inference.
+
+# Usage
+
+* Set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) on the gen workers and activate the `ldm` environment
+
+* Install RabbitMQ and Redis, then add their URLs to `celeryconfig.py`
+
+* `git clone https://github.com/float-trip/marseygen`
+
+* `pip install -r marseygen/requirements.txt`
+
+* `mv marseygen/*.py InvokeAI && cd InvokeAI`
+  * Running the gen workers from this dir circumvents some Python import issues that I don't care to figure out right now
+
+* Start the API worker
+
+`celery -A tasks worker -B --concurrency 1 --loglevel=INFO`
+
+* Start a gen worker for each GPU
+
+```sh
+export CUDA_VISIBLE_DEVICES=0,
+export WORKER_HOST="user@gen_worker_ip"
+export WORKER_SSH_PORT="22"
+export WORKER_ID="unique_id"
+celery -A tasks worker -Q gen -n unique_name --concurrency 1 --loglevel=INFO
+```
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..16660e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+celery==5.2.7
+celery_singleton==0.3.1
+Pillow==9.2.0
+pry.py==0.1.1
+requests==2.27.1
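The readme's "add their URLs to `celeryconfig.py`" step fills in the empty `broker_url` and `result_backend` above. A sketch of plausible values, assuming RabbitMQ and Redis on their default local ports — hosts and credentials here are placeholders, not part of this commit:

```python
# Example celeryconfig.py values; adjust host/credentials for your setup.
broker_url = "amqp://guest:guest@localhost:5672//"  # RabbitMQ broker
result_backend = "redis://localhost:6379/0"         # Redis result store
```

With `task_default_queue = "api"` and the `task_routes` entry, `generate_reply` lands on the GPU workers listening on the `gen` queue while everything else stays on the `api` worker.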
diff --git a/tasks.py b/tasks.py
new file mode 100644
index 0000000..67d8f5b
--- /dev/null
+++ b/tasks.py
@@ -0,0 +1,139 @@
+import os
+import subprocess
+import random
+from celery import Celery, Task, chain
+import time
+import celeryconfig
+
+from client import DramaClient
+from utils import concat_images
+
+app = Celery("tasks")
+app.config_from_object(celeryconfig)
+
+client = DramaClient()
+
+
+#
+# API worker tasks
+#
+@app.task
+def post_reply(context):
+    basename = os.path.basename(context["image_path"])
+    save_path = f"/fs/marseys/{basename}"
+
+    print(f"Copying {basename}")
+
+    # Copy the image from the gen worker; check=True so a failed rsync raises.
+    subprocess.run(
+        [
+            "rsync",
+            "-a",
+            f"{context['worker_host']}:{context['image_path']}",
+            save_path,
+            "-e",
+            f"ssh -p {context['worker_ssh_port']}",
+        ],
+        check=True,
+    )
+
+    print(f"Replying for prompt {context['prompt']}")
+
+    client.reply(
+        context["parent_fullname"],
+        context["submission"],
+        f"`{context['prompt']}`",
+        save_path,
+    )
+
+
+class FindPromptsTask(Task):
+    last_call = None
+
+    # Temp fix for comments being replied to multiple times.
+    queued_ids = set()
+
+
+@app.task(base=FindPromptsTask)
+def find_prompts():
+    if find_prompts.last_call is not None and time.time() - find_prompts.last_call < 60:
+        return
+
+    find_prompts.last_call = time.time()
+
+    print("Looking for prompts.")
+    comments = client.fetch_new_comments()
+
+    for comment in comments:
+        if comment["id"] in find_prompts.queued_ids:
+            continue
+
+        find_prompts.queued_ids.add(comment["id"])
+
+        reply_contexts = [
+            {
+                "parent_fullname": f"c_{comment['id']}",
+                "submission": comment["post_id"],
+                "prompt": line[4:],
+            }
+            for line in comment["body"].split("\n")
+            if line.startswith("!sd ")
+        ]
+
+        # Max 5 prompts per comment.
+        reply_contexts = reply_contexts[:5]
+
+        for context in reply_contexts:
+            print(f"Queueing prompt `{context['prompt']}`.")
+            chain(
+                generate_reply.s(context).set(queue="gen"),
+                post_reply.s().set(queue="api"),
+            ).apply_async()
+
+
+#
+# Generation worker tasks
+#
+class GenTask(Task):
+    _generator = None
+
+    @property
+    def generator(self):
+        # Lazily load the model once per worker process.
+        if self._generator is None:
+            from ldm.generate import Generate
+
+            self._generator = Generate(sampler_name="k_euler_a")
+            self._generator.load_model()
+
+            print("Model loaded.")
+
+        return self._generator
+
+
+@app.task(base=GenTask)
+def generate_reply(context):
+    print(f"Generating `{context['prompt']}`.")
+
+    outdir = f"out/{os.environ['WORKER_ID']}"
+    os.makedirs(outdir, exist_ok=True)
+
+    results = generate_reply.generator.prompt2png(
+        context["prompt"], outdir=outdir, iterations=9
+    )
+
+    image_paths = [r[0] for r in results]
+    grid = concat_images(image_paths, size=(512, 512), shape=(3, 3))
+
+    grid_basename = f"{random.randrange(10**6, 10**7)}.webp"
+
+    os.makedirs("grid", exist_ok=True)
+
+    grid_path = f"grid/{grid_basename}"
+    grid.save(grid_path, "WEBP")
+
+    context["image_path"] = os.path.abspath(grid_path)
+    context["worker_host"] = os.environ["WORKER_HOST"]
+    context["worker_ssh_port"] = os.environ["WORKER_SSH_PORT"]
+
+    return context
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..5313ce3
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,24 @@
+from PIL import Image, ImageOps
+
+#
+# https://gist.github.com/njanakiev/1932e0a450df6d121c05069d5f7d7d6f
+#
+def concat_images(image_paths, size, shape=None):
+    # Open images and resize them
+    width, height = size
+    images = map(Image.open, image_paths)
+    images = [ImageOps.fit(image, size, Image.Resampling.LANCZOS) for image in images]
+
+    # Create canvas for the final image with total size
+    shape = shape if shape else (1, len(images))
+    image_size = (width * shape[1], height * shape[0])
+    image = Image.new("RGB", image_size)
+
+    # Paste images into final image
+    for row in range(shape[0]):
+        for col in range(shape[1]):
+            offset = width * col, height * row
+            idx = row * shape[1] + col
+            image.paste(images[idx], offset)
+
+    return image
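To make the grid logic concrete: `generate_reply` requests nine 512×512 images, and `concat_images` pastes them row-major onto a 3×3 canvas, i.e. a 1536×1536 grid. A hypothetical standalone call (the file paths are invented for illustration):

```python
# Hypothetical usage of concat_images; paths are made up.
from utils import concat_images

paths = [f"out/worker_1/{i:06}.png" for i in range(9)]
grid = concat_images(paths, size=(512, 512), shape=(3, 3))  # offsets are (512*col, 512*row)
grid.save("example.webp", "WEBP")
```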