121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
import asyncio
|
|
import aiohttp
|
|
import json
|
|
from itertools import cycle
|
|
import re
|
|
|
|
import io
|
|
from drama_client import DramaClient
|
|
from data import config
|
|
|
|
from image_utils import create_grid
|
|
|
|
# Base URLs of the Stable Diffusion services, read from the project config.
# main() starts one gen_worker per entry, and gen_worker appends
# "sdapi/v1/txt2img" directly to each value, so every URL needs a trailing "/".
SERVICES = config["services"]
|
|
|
|
|
|
async def gen_worker(session, task_queue, result_queue, service):
    """Consume prompts from ``task_queue`` and render one image grid per prompt.

    Runs forever. For every ``(comment, prompt)`` task the worker calls the
    ``sdapi/v1/txt2img`` endpoint of ``service`` until at least nine base64
    images have been collected, combines them with ``create_grid`` and puts
    ``(comment, prompt, grid)`` on ``result_queue``.

    Args:
        session: The ``aiohttp.ClientSession`` used for the HTTP requests.
        task_queue: ``asyncio.Queue`` yielding ``(comment, prompt)`` tuples.
        result_queue: ``asyncio.Queue`` receiving ``(comment, prompt, grid)``.
        service: Base URL of the service. The endpoint path is appended
            directly, so the URL is expected to end with a slash.

    Raises:
        aiohttp.ContentTypeError: The service answered with a non-JSON body.
            The body is appended to ``log.txt`` in ``config['data_dir']``
            before the error is re-raised.
        RuntimeError: The service answered with JSON that contains no images.
    """
    images_needed = 9
    url = f"{service}sdapi/v1/txt2img"
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
    }

    while True:
        # Get a task from the task queue
        comment, prompt = await task_queue.get()

        # Serialise the request once; it is identical for every retry below.
        payload = json.dumps({
            "prompt": prompt,
            "sampler_name": "Euler a",
            "batch_size": 1,
            "n_iter": 3,
            "steps": 30,
            "cfg_scale": 7,
            "width": 1024,
            "height": 1024,
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "send_images": True,
            "save_images": False,
        })

        b64_images = []
        while len(b64_images) < images_needed:
            async with session.post(url, headers=headers, data=payload) as r:
                try:
                    resp_json = await r.json()
                except aiohttp.ContentTypeError:
                    # Read and log the body while the response is still open,
                    # then let the original error propagate.
                    html_response = await r.text()
                    log_path = f"{config['data_dir']}/log.txt"
                    # Explicit encoding so a non-ASCII error page cannot raise
                    # an encoding error that would hide the real failure.
                    with open(log_path, "a", encoding="utf-8") as f:
                        f.write(html_response)
                    raise

            batch = resp_json["images"]
            if not batch:
                # Without this guard an empty batch would make the loop spin
                # forever, re-sending the same request.
                raise RuntimeError(f"{url} returned no images")
            b64_images.extend(batch)

        grid = create_grid(b64_images)

        # Put the result in the result queue
        await result_queue.put((comment, prompt, grid))

        # Indicate that the task is done
        task_queue.task_done()
|
|
|
|
|
|
# Matches one "!sd <prompt>" command per line of a comment body.
_SD_COMMAND_RE = re.compile(r"^!sd (.*)$", re.MULTILINE)


async def feed_worker(client, task_queue, *, poll_interval=20, max_prompts=5):
    """Poll ``client`` for new comments and enqueue their ``!sd`` prompts.

    Runs forever. Every line of a comment body that starts with ``!sd `` is
    treated as a prompt. Prompts are stripped of surrounding whitespace
    (including a trailing carriage return from CRLF bodies), blank prompts
    are dropped, and at most ``max_prompts`` prompts per comment are put on
    ``task_queue`` as ``(comment, prompt)`` tuples.

    Args:
        client: Object exposing an awaitable ``fetch_new_comments()`` that
            returns an iterable of mappings with a ``"body"`` string.
        task_queue: ``asyncio.Queue`` that receives the tasks.
        poll_interval: Seconds to sleep between two polls.
        max_prompts: Maximum number of prompts accepted per comment.
    """
    while True:
        # Fetch new tasks
        comments = await client.fetch_new_comments()

        # Add each task to the task queue
        for comment in comments:
            candidates = (
                raw.strip() for raw in _SD_COMMAND_RE.findall(comment["body"])
            )
            # Filter before truncating so blank commands do not use up the quota.
            prompts = [text for text in candidates if text][:max_prompts]
            for prompt in prompts:
                await task_queue.put((comment, prompt))

        await asyncio.sleep(poll_interval)
|
|
|
|
|
|
async def result_worker(client, result_queue, *, post_delay=10):
    """Post finished image grids as replies, one at a time.

    Runs forever. Each ``(comment, prompt, grid)`` taken from
    ``result_queue`` is posted through ``client.reply`` with the prompt
    wrapped in backticks as the reply text and the grid as the only image.
    The worker then pauses before taking the next result.

    Args:
        client: Object exposing an awaitable
            ``reply(comment, text, images=...)``.
        result_queue: ``asyncio.Queue`` yielding ``(comment, prompt, grid)``.
        post_delay: Seconds to wait after each reply before posting the next.
    """
    while True:
        # Get a result from the result queue
        comment, prompt, grid = await result_queue.get()

        # Post the result
        await client.reply(comment, f"`{prompt}`", images=[grid])

        # Indicate that the result has been processed
        result_queue.task_done()

        await asyncio.sleep(post_delay)
|
|
|
|
|
|
async def main():
    """Start the feed, generation and result workers and run them until stopped.

    One ``feed_worker`` and one ``result_worker`` share a ``DramaClient``;
    one ``gen_worker`` is started per entry in ``SERVICES``, all sharing a
    single ``aiohttp.ClientSession``. Tasks flow through ``task_queue`` and
    finished grids through ``result_queue``.

    Cancellation of ``main`` is absorbed after the workers have been stopped.
    Any other exception raised by a worker propagates to the caller, but only
    after the remaining workers have been cancelled and awaited, so none of
    them is still using the session when it is closed.
    """
    client = DramaClient()

    task_queue = asyncio.Queue()    # used to pass tasks to the workers
    result_queue = asyncio.Queue()  # used to pass results to the result worker

    async with aiohttp.ClientSession() as session:
        workers = [
            # Feeds new prompts into task_queue
            asyncio.create_task(feed_worker(client, task_queue)),
            # One generation worker for each Stable Diffusion service
            *(
                asyncio.create_task(
                    gen_worker(session, task_queue, result_queue, service)
                )
                for service in SERVICES
            ),
            # Posts the generated images
            asyncio.create_task(result_worker(client, result_queue)),
        ]

        try:
            await asyncio.gather(*workers)
        except asyncio.CancelledError:
            # main() itself was cancelled: shut down quietly. The workers are
            # stopped in the finally clause below.
            pass
        finally:
            # Stop every worker on ANY exit path, not just on cancellation,
            # before the session is closed by the enclosing `async with`.
            for worker in workers:
                worker.cancel()
            # return_exceptions=True: ignore the workers' cancellation errors
            await asyncio.gather(*workers, return_exceptions=True)
|
|
|
|
|
|
# Start the bot only when executed as a script, so importing this module
# (for example to reuse a worker) does not launch the event loop.
if __name__ == "__main__":
    asyncio.run(main())
|