Marseygen/main.py

121 lines
3.7 KiB
Python

import asyncio
import io
import json
import re
import traceback
from itertools import cycle

import aiohttp

from drama_client import DramaClient
from data import config
from image_utils import create_grid
# Stable Diffusion services to generate with, read from the project config.
# main() starts one gen_worker per entry, and each entry is used as the base
# URL of that worker's txt2img endpoint.
# NOTE(review): presumably a list of URL strings - confirm against the config.
SERVICES = config["services"]
# Number of images collected for one prompt before the grid is built.
_GRID_IMAGE_COUNT = 9


def _txt2img_payload(prompt):
    """Return the JSON body of one txt2img request for *prompt*.

    Server-side saving is switched off; the images are read back from the
    response's ``images`` field.  (``batch_size`` 1 x ``n_iter`` 3 presumably
    yields three images per request - confirm against the service.)
    """
    return {
        "prompt": prompt,
        "sampler_name": "Euler a",
        "batch_size": 1,
        "n_iter": 3,
        "steps": 30,
        "cfg_scale": 7,
        "width": 1024,
        "height": 1024,
        "do_not_save_samples": True,
        "do_not_save_grid": True,
        "send_images": True,
        "save_images": False,
    }


def _append_gen_log(text):
    """Append *text* to ``log.txt`` in ``config['data_dir']`` as UTF-8."""
    # Explicit encoding: error pages may contain non-ASCII characters that
    # the platform default encoding cannot represent.
    with open(f"{config['data_dir']}/log.txt", "a", encoding="utf-8") as f:
        f.write(text)


async def _collect_txt2img_images(session, url, prompt):
    """POST txt2img requests for *prompt* until at least nine images arrived.

    Returns the concatenated ``images`` lists of the responses.

    Raises:
        aiohttp.client_exceptions.ContentTypeError: a response was not JSON;
            its body is appended to the log file first.
        ValueError: a response contained no images (previously this looped
            forever).
    """
    headers = {
        "accept": "application/json",
        "Content-Type": "application/json",
    }
    # The payload never changes between retries, so serialize it once.
    body = json.dumps(_txt2img_payload(prompt))
    images = []
    while len(images) < _GRID_IMAGE_COUNT:
        async with session.post(url, headers=headers, data=body) as resp:
            try:
                resp_json = await resp.json()
            except aiohttp.client_exceptions.ContentTypeError:
                # Usually an HTML error page; keep it for debugging.
                _append_gen_log(await resp.text())
                raise
        batch = resp_json.get("images") if isinstance(resp_json, dict) else None
        if not batch:
            raise ValueError(f"{url} returned no images: {resp_json!r:.200}")
        images.extend(batch)
    return images


async def gen_worker(session, task_queue, result_queue, service):
    """Turn queued prompts into image grids using one Stable Diffusion service.

    Runs forever.  Each ``(comment, prompt)`` task taken from *task_queue* is
    sent to ``<service>/sdapi/v1/txt2img`` until at least nine images have
    been collected.  The images are combined with ``create_grid`` and
    ``(comment, prompt, grid)`` is put on *result_queue*.

    If generating one task fails (network error, non-JSON or image-less
    response, ``create_grid`` error), the traceback is printed and the task is
    dropped.  The worker keeps running, and ``task_queue.task_done()`` is
    called for every task taken.

    Args:
        session: aiohttp client session used for the HTTP requests.
        task_queue: ``asyncio.Queue`` of ``(comment, prompt)`` tuples.
        result_queue: ``asyncio.Queue`` receiving ``(comment, prompt, grid)``.
        service: base URL of the service; the trailing ``/`` is optional.
    """
    # Normalize the slash so "http://host:7860" and "http://host:7860/" both
    # produce a valid endpoint URL.
    url = f"{str(service).rstrip('/')}/sdapi/v1/txt2img"
    while True:
        comment, prompt = await task_queue.get()
        try:
            images = await _collect_txt2img_images(session, url, prompt)
            grid = create_grid(images)
        except asyncio.CancelledError:
            # Never swallow cancellation (it subclasses Exception on 3.7).
            raise
        except Exception:
            # An exception escaping here would end asyncio.gather() in main()
            # and stop every worker; report it and drop only this task.
            traceback.print_exc()
        else:
            await result_queue.put((comment, prompt, grid))
        finally:
            task_queue.task_done()
# Matches one "!sd <prompt>" command per line of a comment body.
_SD_COMMAND_RE = re.compile(r"^!sd (.*)$", re.MULTILINE)


async def feed_worker(client, task_queue, max_prompts=5, poll_interval=20):
    """Poll for new comments and queue their ``!sd`` prompts as tasks.

    Runs forever.  Each poll awaits ``client.fetch_new_comments()``.  In
    every returned comment, each line of ``comment["body"]`` of the form
    ``!sd <prompt>`` yields a prompt.  Prompts are stripped of surrounding
    whitespace, which also drops the carriage return left by CRLF line
    endings.  Empty prompts are skipped, at most *max_prompts* are kept per
    comment, and ``(comment, prompt)`` is put on *task_queue* for each.

    A failing fetch is reported with a traceback and retried on the next
    poll instead of stopping the worker (and with it the whole bot).

    Args:
        client: object with an ``async fetch_new_comments()`` returning an
            iterable of comments indexable by ``"body"``.
        task_queue: ``asyncio.Queue`` receiving ``(comment, prompt)`` tuples.
        max_prompts: maximum number of prompts queued per comment.
        poll_interval: seconds to sleep between polls.
    """
    while True:
        try:
            comments = await client.fetch_new_comments()
        except asyncio.CancelledError:
            raise
        except Exception:
            traceback.print_exc()
            comments = []
        for comment in comments:
            stripped = (p.strip() for p in _SD_COMMAND_RE.findall(comment["body"]))
            prompts = [p for p in stripped if p][:max_prompts]
            for prompt in prompts:
                await task_queue.put((comment, prompt))
        await asyncio.sleep(poll_interval)
async def result_worker(client, result_queue, post_interval=10):
    """Post each generated grid as a reply to the comment that requested it.

    Runs forever.  For every ``(comment, prompt, grid)`` taken from
    *result_queue*, awaits ``client.reply(comment, "`<prompt>`",
    images=[grid])``, marks the item done, then sleeps *post_interval*
    seconds before the next reply.

    A failing reply is reported with a traceback and skipped instead of
    stopping the worker (and with it the whole bot).
    ``result_queue.task_done()`` is called for every item taken.

    Args:
        client: object with an ``async reply(comment, body, images=...)``.
        result_queue: ``asyncio.Queue`` of ``(comment, prompt, grid)``.
        post_interval: seconds to wait after each reply attempt.
    """
    while True:
        comment, prompt, grid = await result_queue.get()
        try:
            await client.reply(comment, f"`{prompt}`", images=[grid])
        except asyncio.CancelledError:
            raise
        except Exception:
            traceback.print_exc()
        finally:
            result_queue.task_done()
        # Pause between posts, including after failures.
        await asyncio.sleep(post_interval)
async def main():
    """Run the bot until a worker fails or ``main`` is cancelled.

    Starts one feed worker (polls comments into the task queue), one result
    worker (posts replies from the result queue) and one generation worker
    per entry in ``SERVICES``.  The generation workers share one aiohttp
    session.  However the wait ends - cancellation or any worker raising -
    every worker task is cancelled and awaited before the session is closed,
    and the original exception (including ``CancelledError``) propagates to
    the caller.
    """
    client = DramaClient()
    task_queue = asyncio.Queue()  # (comment, prompt) tasks for gen workers
    result_queue = asyncio.Queue()  # (comment, prompt, grid) for the poster
    async with aiohttp.ClientSession() as session:
        workers = [
            asyncio.create_task(feed_worker(client, task_queue)),
            asyncio.create_task(result_worker(client, result_queue)),
        ]
        workers.extend(
            asyncio.create_task(gen_worker(session, task_queue, result_queue, service))
            for service in SERVICES
        )
        try:
            await asyncio.gather(*workers)
        finally:
            # Runs on any exit path, so no worker outlives the session.
            for worker in workers:
                worker.cancel()
            # return_exceptions=True collects the CancelledErrors (and any
            # error from the failed worker) instead of raising them here.
            await asyncio.gather(*workers, return_exceptions=True)
if __name__ == "__main__":
    # Start the bot only when executed as a script, not when imported.
    asyncio.run(main())