import asyncio
import json
import re

import aiohttp

from drama_client import DramaClient
from data import config
from image_utils import create_grid

SERVICES = config["services"]


async def gen_worker(session, task_queue, result_queue, service):
    while True:
        # Get a task from the task queue
        comment, prompt = await task_queue.get()

        # Process the task
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        data = {
            "prompt": prompt,
            "sampler_name": "Euler a",
            "batch_size": 1,
            "n_iter": 3,
            "steps": 30,
            "cfg_scale": 7,
            "width": 1024,
            "height": 1024,
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "send_images": True,
            "save_images": False,
        }

        # Each request returns n_iter images; keep requesting until we have
        # nine for a 3x3 grid.
        b64_images = []
        while len(b64_images) < 9:
            async with session.post(
                f"{service}sdapi/v1/txt2img",
                headers=headers,
                data=json.dumps(data),
            ) as r:
                try:
                    resp_json = await r.json()
                except aiohttp.ContentTypeError:
                    # The service returned something other than JSON (e.g. an
                    # HTML error page); log the body and re-raise.
                    html_response = await r.text()
                    with open(f"{config['data_dir']}/log.txt", "a") as f:
                        f.write(html_response)
                    raise
                b64_images.extend(resp_json["images"])

        grid = create_grid(b64_images)

        # Put the result in the result queue
        await result_queue.put((comment, prompt, grid))

        # Indicate that the task is done
        task_queue.task_done()


async def feed_worker(client, task_queue):
    while True:
        # Fetch new tasks
        comments = await client.fetch_new_comments()

        # Add each task to the task queue
        for comment in comments:
            prompts = re.findall(r"^!sd (.*)$", comment["body"], re.MULTILINE)
            prompts = prompts[:5]
            for prompt in prompts:
                await task_queue.put((comment, prompt))

        await asyncio.sleep(20)


async def result_worker(client, result_queue):
    while True:
        # Get a result from the result queue
        comment, prompt, grid = await result_queue.get()

        # Post the result
        await client.reply(comment, f"`{prompt}`", images=[grid])

        # Indicate that the result has been processed
        result_queue.task_done()

        await asyncio.sleep(10)


async def main():
    client = DramaClient()

    task_queue = asyncio.Queue()    # used to pass tasks to the workers
    result_queue = asyncio.Queue()  # used to pass results to the result worker

    # Create the feed worker
    feed_worker_task = asyncio.create_task(feed_worker(client, task_queue))

    # Create the result worker to post the generated images
    result_worker_task = asyncio.create_task(result_worker(client, result_queue))

    async with aiohttp.ClientSession() as session:
        # Create a worker for each Stable Diffusion service
        gen_workers = [
            asyncio.create_task(gen_worker(session, task_queue, result_queue, service))
            for service in SERVICES
        ]

        try:
            await asyncio.gather(feed_worker_task, *gen_workers, result_worker_task)
        except asyncio.CancelledError:
            # If the main() coroutine is cancelled, propagate the cancellation to all workers
            feed_worker_task.cancel()
            for worker in gen_workers:
                worker.cancel()
            result_worker_task.cancel()
            await asyncio.gather(
                feed_worker_task,
                *gen_workers,
                result_worker_task,
                return_exceptions=True,  # ignore cancellation exceptions
            )


if __name__ == "__main__":
    asyncio.run(main())
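
# A sketch of the config shape this script appears to assume (loaded from
# data.py). The real config may differ; the keys and example URLs below are
# illustrative only. "services" should be base URLs of Stable Diffusion web UI
# instances, ending in "/" so that f"{service}sdapi/v1/txt2img" resolves, and
# "data_dir" should be a writable directory for log.txt.
#
# config = {
#     "services": [
#         "http://127.0.0.1:7860/",  # placeholder URL
#         "http://127.0.0.1:7861/",  # placeholder URL
#     ],
#     "data_dir": "/path/to/data",   # placeholder path
# }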