From d9a9518f543dfd8843b26c89b4a844fa905c04c6 Mon Sep 17 00:00:00 2001 From: float-trip Date: Sat, 22 Jul 2023 17:05:35 +0000 Subject: [PATCH] New Marseygen. --- celeryconfig.py | 17 ------ client.py | 89 ------------------------------ data.py | 16 ++++++ drama_client.py | 135 +++++++++++++++++++++++++++++++++++++++++++++ image_utils.py | 29 ++++++++++ main.py | 120 ++++++++++++++++++++++++++++++++++++++++ readme.md | 31 ----------- requirements.txt | 5 -- tasks.py | 140 ----------------------------------------------- utils.py | 24 -------- 10 files changed, 300 insertions(+), 306 deletions(-) delete mode 100644 celeryconfig.py delete mode 100644 client.py create mode 100644 data.py create mode 100644 drama_client.py create mode 100644 image_utils.py create mode 100644 main.py delete mode 100644 readme.md delete mode 100644 requirements.txt delete mode 100644 tasks.py delete mode 100644 utils.py diff --git a/celeryconfig.py b/celeryconfig.py deleted file mode 100644 index 99e7289..0000000 --- a/celeryconfig.py +++ /dev/null @@ -1,17 +0,0 @@ -beat_schedule = { - "find-prompts": {"task": "tasks.find_prompts", "schedule": 60.0, "args": ()}, -} - -task_annotations = { - "tasks.post_reply": {"rate_limit": "5/m"}, - "tasks.find_prompts": {"rate_limit": "1/m"}, -} - -task_default_queue = "api" - -broker_url = "" -result_backend = "" - -task_routes = {"tasks.generate_reply": {"queue": "gen"}} - -worker_prefetch_multiplier = 1 diff --git a/client.py b/client.py deleted file mode 100644 index 7f171e6..0000000 --- a/client.py +++ /dev/null @@ -1,89 +0,0 @@ -import requests -import sys -import os -import time - - -class DramaClient: - BASE_URL = "https://rdrama.net" - - def __init__(self): - self.token = os.environ.get("RDRAMA_TOKEN", "") - self.last_processed_id = 2821161 # Most recent comment seen. - - def get(self, endpoint): - print(endpoint) - time.sleep(5) - - r = requests.get( - f"{self.BASE_URL}{endpoint}", headers={"Authorization": self.token} - ) - - if r.status_code != 200: - print("Error!", r, r.status_code, r.content) - sys.exit(1) - - return r.json()["data"] - - def post(self, endpoint, payload, files=[]): - print(endpoint) - time.sleep(5) - - r = requests.post( - f"{self.BASE_URL}{endpoint}", - payload, - headers={"Authorization": self.token}, - files=files, - ) - - if r.status_code != 200: - print("Error!", r, r.status_code, r.content) - sys.exit(1) - - return r.json() - - def fetch_new_comments(self): - comments = [] - if self.last_processed_id is None: - comments += self.fetch_page(1) - else: - earliest_id = None - page = 1 - # Fetch comments until we find the last one processed. - while earliest_id is None or earliest_id > self.last_processed_id: - page_comments = self.fetch_page(page) - earliest_id = min([c["id"] for c in page_comments]) - comments += [ - c for c in page_comments if c["id"] > self.last_processed_id - ] - page += 1 - - if not comments: - return [] - - self.last_processed_id = max(c["id"] for c in comments) - - # New comments may have pushed others to page n+1 while fetching. - deduped_comments = {c["id"]: c for c in comments}.values() - - # Oldest first. 
- comments.reverse() - - return comments - - def fetch_page(self, page): - return self.get(f"/comments?page={page}") - - def reply(self, parent_fullname, submission, body, image_path=None): - payload = { - "parent_fullname": parent_fullname, - "submission": submission, - "body": body, - } - - files = [] - if image_path: - filename = image_path.split("/")[-1] - files = {"file": (filename, open(image_path, "rb"), "image/webp")} - - self.post("/comment", payload, files=files) diff --git a/data.py b/data.py new file mode 100644 index 0000000..3cefd87 --- /dev/null +++ b/data.py @@ -0,0 +1,16 @@ +import os + +import yaml +from sqlitedict import SqliteDict + +current_dir = os.path.dirname(os.path.realpath(__file__)) +config_path = os.path.join(current_dir, "config.yaml") + + +def load_config(): + with open(config_path, "r") as f: + return yaml.safe_load(f) + + +config = load_config() +db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True) diff --git a/drama_client.py b/drama_client.py new file mode 100644 index 0000000..d453019 --- /dev/null +++ b/drama_client.py @@ -0,0 +1,135 @@ +import aiohttp +import asyncio +import json + +import sys +import os +import math +import random +import shelve +import io + +from aiohttp_retry import RetryClient, ExponentialRetry +from collections import OrderedDict +from data import config, db + +import logging + +logging.basicConfig(level=logging.WARNING) + +def copy_file_obj(file_obj): + # Read the contents of the original file + content = file_obj.read() + + # Reset the position of the original file in case it needs to be read again + file_obj.seek(0) + + # Create a new BytesIO object with the same content + new_file_obj = io.BytesIO(content) + + return new_file_obj + +class DramaClient: + BASE_URL = "https://rdrama.net" + + def __init__(self, client=None): + #self.client = client or RetryClient(retry_options=ExponentialRetry(attempts=5)) + self.client = client or aiohttp.ClientSession() + self.max_retries = 5 # define a maximum number of retries + + + #self.chud_phrase = asyncio.run(self.get("/@me")).get("chud_phrase", "") + + async def get(self, endpoint): + print("GET", endpoint) + print(config["api_token"]) + + async with self.client.get( + f"{self.BASE_URL}{endpoint}", + headers={"Authorization": config["api_token"]}, + ) as r: + if r.status != 200: + print("Error!", r, r.status, await r.text()) + sys.exit(1) + + return await r.json() + + async def post(self, endpoint, data=None, images=None): + for attempt in range(self.max_retries): + await asyncio.sleep(5) + try: + form_data = aiohttp.FormData() + + if data is not None: + for key, value in data.items(): + form_data.add_field(key, str(value)) + + if images is not None: + for file in images: + form_data.add_field('file', file, filename='image.webp', content_type='image/webp') + + async with self.client.post(f"{self.BASE_URL}{endpoint}", data=form_data, headers={"Authorization": config["api_token"]}) as r: + if r.status != 200: + print("Error!", r, r.status, await r.text()) + raise Exception("HTTP error") # raise an exception to trigger the retry + return await r.json() + except Exception as e: + if attempt < self.max_retries - 1: # if this wasn't the last attempt, continue to the next one + continue + else: # this was the last attempt, re-raise the exception + print("Exception", data) + print(e) + raise e # this was the last attempt, re-raise the exception + + async def fetch_new_comments(self): + comments = [] + + earliest_id = math.inf + page = 1 + + if "last_processed_id" not in db: + 
page_comments = await self.fetch_page(1) + db["last_processed_id"] = max(c["id"] for c in page_comments) + db.commit() + return [] + + # Fetch comments until we find the last one processed. + while earliest_id > db["last_processed_id"]: + page_comments = await self.fetch_page(page) + + if len(page_comments) == 0: + break + + earliest_id = min([c["id"] for c in page_comments]) + comments += [c for c in page_comments if c["id"] > db["last_processed_id"]] + + page += 1 + + if not comments: + return [] + + db["last_processed_id"] = max(c["id"] for c in comments) + db.commit() + + # New comments may have pushed others to page n+1 while fetching. + comments = {c["id"]: c for c in comments}.values() + comments = list(OrderedDict((c['id'], c) for c in comments).values()) + + # Oldest first. + comments.reverse() + + return comments + + async def fetch_page(self, page): + return (await self.get(f"/comments?page={page}"))["data"] + + async def reply(self, comment, body, images=None): + #if self.chud_phrase and self.chud_phrase not in body: + # body += f"\n{self.chud_phrase}" + + data = { + "parent_fullname": f"c_{comment['id']}", + "body": f"{body}{random.randint(1, 1000000000)}", + } + + return await self.post("/comment", data=data, images=images) diff --git a/image_utils.py b/image_utils.py new file mode 100644 index 0000000..af006b3 --- /dev/null +++ b/image_utils.py @@ -0,0 +1,29 @@ +import numpy as np +from PIL import Image +from io import BytesIO +import base64 + +def decode_and_resize(image_string): + img_data = base64.b64decode(image_string) + img = Image.open(BytesIO(img_data)) + return img.resize((512, 512)) + +def combine_images(images): + combined = Image.new('RGB', (1536, 1536)) # 3 * 512 = 1536 + + for i, img in enumerate(images): + x = i % 3 * 512 + y = i // 3 * 512 + combined.paste(img, (x, y)) + + return combined + +def create_grid(b64_images): + # decode, resize, and combine images + images = [decode_and_resize(img_str) for img_str in b64_images] + combined = combine_images(images) + + # convert combined image to byte stream for posting + img_byte_arr = BytesIO() + combined.save(img_byte_arr, format='WEBP') + return img_byte_arr.getvalue() \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..21d54b0 --- /dev/null +++ b/main.py @@ -0,0 +1,120 @@ +import asyncio +import aiohttp +import json +from itertools import cycle +import re + +import io +from drama_client import DramaClient +from data import config + +from image_utils import create_grid + +SERVICES = config["services"] + + +async def gen_worker(session, task_queue, result_queue, service): + while True: + # Get a task from the task queue + comment, prompt = await task_queue.get() + + # Process the task + headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json' + } + data = { + "prompt": prompt, + "sampler_name": "Euler a", + "batch_size": 1, + "n_iter": 3, + "steps": 30, + "cfg_scale": 7, + "width": 1024, + "height": 1024, + "do_not_save_samples": True, + "do_not_save_grid": True, + "send_images": True, + "save_images": False + } + + b64_images = [] + while len(b64_images) < 9: + try: + async with session.post(f"{service}sdapi/v1/txt2img", headers=headers, data=json.dumps(data)) as r: + resp_json = await r.json() + b64_images.extend(resp_json["images"]) + except aiohttp.client_exceptions.ContentTypeError: + html_response = await r.text() + with open(f"{config['data_dir']}/log.txt", "a") as f: + f.write(html_response) + raise + + grid = create_grid(b64_images) + + # 
Put the result in the result queue + await result_queue.put((comment, prompt, grid)) + + # Indicate that the task is done + task_queue.task_done() + + +async def feed_worker(client, task_queue): + while True: + # Fetch new tasks + comments = await client.fetch_new_comments() + + # Add each task to the task queue + for comment in comments: + prompts = re.findall(r"^!sd (.*)$", comment["body"], re.MULTILINE) + prompts = prompts[:5] + for prompt in prompts: + await task_queue.put((comment, prompt)) + + await asyncio.sleep(20) + + +async def result_worker(client, result_queue): + while True: + # Get a result from the result queue + comment, prompt, grid = await result_queue.get() + + # Post the result + await client.reply(comment, f"`{prompt}`", images=[grid]) + + # Indicate that the result has been processed + result_queue.task_done() + + await asyncio.sleep(10) + + +async def main(): + client = DramaClient() + + task_queue = asyncio.Queue() # used to pass tasks to the workers + result_queue = asyncio.Queue() # used to pass results to the result worker + + # Create the feed worker + feed_worker_task = asyncio.create_task(feed_worker(client, task_queue)) + # Create the result worker to post the generated images + result_worker_task = asyncio.create_task(result_worker(client, result_queue)) + + async with aiohttp.ClientSession() as session: + # Create a worker for each Stable Diffusion service + gen_workers = [ + asyncio.create_task(gen_worker(session, task_queue, result_queue, service)) + for service in SERVICES + ] + + try: + await asyncio.gather(feed_worker_task, *gen_workers, result_worker_task) + except asyncio.CancelledError: + # If the main() coroutine is cancelled, propagate the cancellation to all workers + feed_worker_task.cancel() + for worker in gen_workers: + worker.cancel() + result_worker_task.cancel() + await asyncio.gather(feed_worker_task, *gen_workers, result_worker_task, return_exceptions=True) # ignore cancellation exceptions + + +asyncio.run(main()) diff --git a/readme.md b/readme.md deleted file mode 100644 index d351107..0000000 --- a/readme.md +++ /dev/null @@ -1,31 +0,0 @@ -# Marseygen - -Stable Diffusion bot with distributed inference. 
- -# Usage - -* Set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) on the gen workers and activate the `ldm` environment - -* Install rabbitmq and redis, add URLs to `celeryconfig.py` - -* `git clone https://github.com/float-trip/marseygen` - -* `pip install -r marseygen/requirements.txt` - -* `mv marseygen/*.py InvokeAI && cd InvokeAI` - * Running the gen workers from this dir circumvents some Python import issues that I don't care to figure out right now - -* Start the API worker - -`celery -A tasks worker -B --concurrency 1 --loglevel=INFO` - -* Start a gen worker for each GPU - -```sh -export CUDA_VISIBLE_DEVICES=0, -export WORKER_HOST="user@gen_worker_ip" -export WORKER_SSH_PORT="22" -export WORKER_ID="unique_id" -celery -A tasks worker -Q gen -n unique_name -B --concurrency 1 --loglevel=INFO` -``` - diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 16660e6..0000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -celery==5.2.7 -celery_singleton==0.3.1 -Pillow==9.2.0 -pry.py==0.1.1 -requests==2.27.1 diff --git a/tasks.py b/tasks.py deleted file mode 100644 index 67d8f5b..0000000 --- a/tasks.py +++ /dev/null @@ -1,140 +0,0 @@ -import os -import subprocess -import random -from celery import Celery, Task, chain -import time -import celeryconfig - -from client import DramaClient -from utils import concat_images - -app = Celery("tasks") -app.config_from_object(celeryconfig) - -client = DramaClient() - -generator = None - - -# -# API worker tasks -# -@app.task -def post_reply(context): - basename = os.path.basename(context["image_path"]) - save_path = f"/fs/marseys/{basename}" - - print(f"Copying {basename}") - - # Copy image from remote machine. - subprocess.run( - [ - "rsync", - "-a", - f"{context['worker_host']}:{context['image_path']}", - save_path, - "-e", - f"ssh -p {context['worker_ssh_port']}", - ] - ) - - print(f"Replying for prompt {context['prompt']}") - - client.reply( - context["parent_fullname"], - context["submission"], - f"`{context['prompt']}`", - save_path, - ) - - -class FindPromptsTask(Task): - last_call = None - - # Temp fix for comments being replied to multiple times. - queued_ids = set() - - -@app.task(base=FindPromptsTask) -def find_prompts(): - if find_prompts.last_call is not None and time.time() - find_prompts.last_call < 60: - return - - find_prompts.last_call = time.time() - - print("Looking for prompts.") - comments = client.fetch_new_comments() - - for comment in comments: - if comment["id"] in find_prompts.queued_ids: - continue - - find_prompts.queued_ids.add(comment["id"]) - - reply_contexts = [ - { - "parent_fullname": f"c_{comment['id']}", - "submission": comment["post_id"], - "prompt": line[4:], - } - for line in comment["body"].split("\n") - if line.startswith("!sd ") - ] - - # Max 5 prompts per comment. 
- reply_contexts = reply_contexts[:5] - - for context in reply_contexts: - print(f"Queueing prompt `{context['prompt']}`.") - chain( - generate_reply.s(context).set(queue="gen"), - post_reply.s().set(queue="api"), - ).apply_async() - - -# -# Generation worker tasks -# -class GenTask(Task): - _generator = None - - @property - def generator(self): - if self._generator is None: - from ldm.generate import Generate - - self._generator = Generate(sampler_name="k_euler_a") - self._generator.load_model() - - print("Model loaded.") - - return self._generator - - -@app.task(base=GenTask) -def generate_reply(context): - print(f"Generating `{context['prompt']}`.") - - if not os.path.exists("out"): - os.makedirs("out") - - results = generate_reply.generator.prompt2png( - context["prompt"], outdir=f"out/{os.environ['WORKER_ID']}", iterations=9 - ) - - image_paths = [r[0] for r in results] - grid = concat_images(image_paths, size=(512, 512), shape=(3, 3)) - - grid_basename = f"{random.randrange(10**6, 10**7)}.webp" - - if not os.path.exists("grid"): - os.makedirs("grid") - - grid_path = f"grid/{grid_basename}" - grid.save(grid_path, "WEBP") - - context["image_path"] = os.path.abspath(grid_path) - context["worker_host"] = os.environ["WORKER_HOST"] - context["worker_ssh_port"] = os.environ["WORKER_SSH_PORT"] - - return context diff --git a/utils.py b/utils.py deleted file mode 100644 index 5313ce3..0000000 --- a/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from PIL import Image, ImageOps - -# -# https://gist.github.com/njanakiev/1932e0a450df6d121c05069d5f7d7d6f -# -def concat_images(image_paths, size, shape=None): - # Open images and resize them - width, height = size - images = map(Image.open, image_paths) - images = [ImageOps.fit(image, size, Image.ANTIALIAS) for image in images] - - # Create canvas for the final image with total size - shape = shape if shape else (1, len(images)) - image_size = (width * shape[1], height * shape[0]) - image = Image.new("RGB", image_size) - - # Paste images into final image - for row in range(shape[0]): - for col in range(shape[1]): - offset = width * col, height * row - idx = row * shape[1] + col - image.paste(images[idx], offset) - - return image
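
Configuration note: the patch removes readme.md without a replacement, and the rewritten bot reads all of its settings from a config.yaml that data.py loads from its own directory. The keys referenced in the new code are api_token (sent as the Authorization header in drama_client.py), data_dir (location of db.sqlite and log.txt), and services (the list of Stable Diffusion web UI base URLs that main.py posts to at {service}sdapi/v1/txt2img, so each entry should end with a trailing slash). The sketch below only illustrates that expected shape; every value is a placeholder and is not part of the patch.

```yaml
# config.yaml (placeholder values; adjust for your deployment)
api_token: "YOUR_RDRAMA_API_TOKEN"   # used as the Authorization header by DramaClient
data_dir: "/path/to/data"            # holds db.sqlite (last_processed_id) and log.txt
services:                            # Stable Diffusion web UI API base URLs, one gen worker per entry
  - "http://127.0.0.1:7860/"
  - "http://127.0.0.1:7861/"
```

With that file in place beside data.py, the bot runs as a single process: `python main.py`.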