New Marseygen.
parent
e813c12852
commit
d9a9518f54
|
@ -1,17 +0,0 @@
|
||||||
beat_schedule = {
|
|
||||||
"find-prompts": {"task": "tasks.find_prompts", "schedule": 60.0, "args": ()},
|
|
||||||
}
|
|
||||||
|
|
||||||
task_annotations = {
|
|
||||||
"tasks.post_reply": {"rate_limit": "5/m"},
|
|
||||||
"tasks.find_prompts": {"rate_limit": "1/m"},
|
|
||||||
}
|
|
||||||
|
|
||||||
task_default_queue = "api"
|
|
||||||
|
|
||||||
broker_url = ""
|
|
||||||
result_backend = ""
|
|
||||||
|
|
||||||
task_routes = {"tasks.generate_reply": {"queue": "gen"}}
|
|
||||||
|
|
||||||
worker_prefetch_multiplier = 1
|
|
89
client.py
89
client.py
|
@ -1,89 +0,0 @@
|
||||||
import requests
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class DramaClient:
|
|
||||||
BASE_URL = "https://rdrama.net"
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.token = os.environ.get("RDRAMA_TOKEN", "")
|
|
||||||
self.last_processed_id = 2821161 # Most recent comment seen.
|
|
||||||
|
|
||||||
def get(self, endpoint):
|
|
||||||
print(endpoint)
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
r = requests.get(
|
|
||||||
f"{self.BASE_URL}{endpoint}", headers={"Authorization": self.token}
|
|
||||||
)
|
|
||||||
|
|
||||||
if r.status_code != 200:
|
|
||||||
print("Error!", r, r.status_code, r.content)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
return r.json()["data"]
|
|
||||||
|
|
||||||
def post(self, endpoint, payload, files=[]):
|
|
||||||
print(endpoint)
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
r = requests.post(
|
|
||||||
f"{self.BASE_URL}{endpoint}",
|
|
||||||
payload,
|
|
||||||
headers={"Authorization": self.token},
|
|
||||||
files=files,
|
|
||||||
)
|
|
||||||
|
|
||||||
if r.status_code != 200:
|
|
||||||
print("Error!", r, r.status_code, r.content)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
return r.json()
|
|
||||||
|
|
||||||
def fetch_new_comments(self):
|
|
||||||
comments = []
|
|
||||||
if self.last_processed_id is None:
|
|
||||||
comments += self.fetch_page(1)
|
|
||||||
else:
|
|
||||||
earliest_id = None
|
|
||||||
page = 1
|
|
||||||
# Fetch comments until we find the last one processed.
|
|
||||||
while earliest_id is None or earliest_id > self.last_processed_id:
|
|
||||||
page_comments = self.fetch_page(page)
|
|
||||||
earliest_id = min([c["id"] for c in page_comments])
|
|
||||||
comments += [
|
|
||||||
c for c in page_comments if c["id"] > self.last_processed_id
|
|
||||||
]
|
|
||||||
page += 1
|
|
||||||
|
|
||||||
if not comments:
|
|
||||||
return []
|
|
||||||
|
|
||||||
self.last_processed_id = max(c["id"] for c in comments)
|
|
||||||
|
|
||||||
# New comments may have pushed others to page n+1 while fetching.
|
|
||||||
deduped_comments = {c["id"]: c for c in comments}.values()
|
|
||||||
|
|
||||||
# Oldest first.
|
|
||||||
comments.reverse()
|
|
||||||
|
|
||||||
return comments
|
|
||||||
|
|
||||||
def fetch_page(self, page):
|
|
||||||
return self.get(f"/comments?page={page}")
|
|
||||||
|
|
||||||
def reply(self, parent_fullname, submission, body, image_path=None):
|
|
||||||
payload = {
|
|
||||||
"parent_fullname": parent_fullname,
|
|
||||||
"submission": submission,
|
|
||||||
"body": body,
|
|
||||||
}
|
|
||||||
|
|
||||||
files = []
|
|
||||||
if image_path:
|
|
||||||
filename = image_path.split("/")[-1]
|
|
||||||
files = {"file": (filename, open(image_path, "rb"), "image/webp")}
|
|
||||||
|
|
||||||
self.post("/comment", payload, files=files)
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from sqlitedict import SqliteDict
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
config_path = os.path.join(current_dir, "config.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def load_config():
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
db = SqliteDict(f"{config['data_dir']}/db.sqlite", autocommit=True)
|
|
@ -0,0 +1,135 @@
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import shelve
|
||||||
|
import io
|
||||||
|
|
||||||
|
from aiohttp_retry import RetryClient, ExponentialRetry
|
||||||
|
from collections import OrderedDict
|
||||||
|
from data import config, db
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
|
||||||
|
def copy_file_obj(file_obj):
|
||||||
|
# Read the contents of the original file
|
||||||
|
content = file_obj.read()
|
||||||
|
|
||||||
|
# Reset the position of the original file in case it needs to be read again
|
||||||
|
file_obj.seek(0)
|
||||||
|
|
||||||
|
# Create a new BytesIO object with the same content
|
||||||
|
new_file_obj = io.BytesIO(content)
|
||||||
|
|
||||||
|
return new_file_obj
|
||||||
|
|
||||||
|
class DramaClient:
|
||||||
|
BASE_URL = "https://rdrama.net"
|
||||||
|
|
||||||
|
def __init__(self, client=None):
|
||||||
|
#self.client = client or RetryClient(retry_options=ExponentialRetry(attempts=5))
|
||||||
|
self.client = client or aiohttp.ClientSession()
|
||||||
|
self.max_retries = 5 # define a maximum number of retries
|
||||||
|
|
||||||
|
|
||||||
|
#self.chud_phrase = asyncio.run(self.get("/@me")).get("chud_phrase", "")
|
||||||
|
|
||||||
|
async def get(self, endpoint):
|
||||||
|
print("GET", endpoint)
|
||||||
|
print(config["api_token"])
|
||||||
|
|
||||||
|
async with self.client.get(
|
||||||
|
f"{self.BASE_URL}{endpoint}",
|
||||||
|
headers={"Authorization": config["api_token"]},
|
||||||
|
) as r:
|
||||||
|
if r.status != 200:
|
||||||
|
print("Error!", r, r.status, await r.text())
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
return await r.json()
|
||||||
|
|
||||||
|
async def post(self, endpoint, data=None, images=None):
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
try:
|
||||||
|
form_data = aiohttp.FormData()
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
for key, value in data.items():
|
||||||
|
form_data.add_field(key, str(value))
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
for file in images:
|
||||||
|
form_data.add_field('file', file, filename='image.webp', content_type='image/webp')
|
||||||
|
|
||||||
|
async with self.client.post(f"{self.BASE_URL}{endpoint}", data=form_data, headers={"Authorization": config["api_token"]}) as r:
|
||||||
|
if r.status != 200:
|
||||||
|
print("Error!", r, r.status, await r.text())
|
||||||
|
raise Exception("HTTP error") # raise an exception to trigger the retry
|
||||||
|
return await r.json()
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < self.max_retries - 1: # if this wasn't the last attempt, continue to the next one
|
||||||
|
continue
|
||||||
|
else: # this was the last attempt, re-raise the exception
|
||||||
|
print("Exception", data)
|
||||||
|
print(e)
|
||||||
|
raise e # this was the last attempt, re-raise the exception
|
||||||
|
|
||||||
|
async def fetch_new_comments(self):
|
||||||
|
comments = []
|
||||||
|
|
||||||
|
earliest_id = math.inf
|
||||||
|
page = 1
|
||||||
|
|
||||||
|
if "last_processed_id" not in db:
|
||||||
|
page_comments = await self.fetch_page(1)
|
||||||
|
db["last_processed_id"] = max(c["id"] for c in page_comments)
|
||||||
|
db.commit()
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch comments until we find the last one processed.
|
||||||
|
while earliest_id > db["last_processed_id"]:
|
||||||
|
page_comments = await self.fetch_page(page)
|
||||||
|
|
||||||
|
if len(page_comments) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
earliest_id = min([c["id"] for c in page_comments])
|
||||||
|
comments += [c for c in page_comments if c["id"] > db["last_processed_id"]]
|
||||||
|
|
||||||
|
page += 1
|
||||||
|
|
||||||
|
if not comments:
|
||||||
|
return []
|
||||||
|
|
||||||
|
db["last_processed_id"] = max(c["id"] for c in comments)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# New comments may have pushed others to page n+1 while fetching.
|
||||||
|
comments = {c["id"]: c for c in comments}.values()
|
||||||
|
comments = list(OrderedDict((c['id'], c) for c in comments).values())
|
||||||
|
|
||||||
|
# Oldest first.
|
||||||
|
comments.reverse()
|
||||||
|
|
||||||
|
return comments
|
||||||
|
|
||||||
|
async def fetch_page(self, page):
|
||||||
|
return (await self.get(f"/comments?page={page}"))["data"]
|
||||||
|
|
||||||
|
async def reply(self, comment, body, images=None):
|
||||||
|
#if self.chud_phrase and self.chud_phrase not in body:
|
||||||
|
# body += f"\n{self.chud_phrase}"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"parent_fullname": f"c_{comment['id']}",
|
||||||
|
"body": f"{body}<sub><sub><sub><sub><sub>{random.randint(1, 1000000000)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
return await self.post("/comment", data=data, images=images)
|
|
@ -0,0 +1,29 @@
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
import base64
|
||||||
|
|
||||||
|
def decode_and_resize(image_string):
|
||||||
|
img_data = base64.b64decode(image_string)
|
||||||
|
img = Image.open(BytesIO(img_data))
|
||||||
|
return img.resize((512, 512))
|
||||||
|
|
||||||
|
def combine_images(images):
|
||||||
|
combined = Image.new('RGB', (1536, 1536)) # 3 * 512 = 1536
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
x = i % 3 * 512
|
||||||
|
y = i // 3 * 512
|
||||||
|
combined.paste(img, (x, y))
|
||||||
|
|
||||||
|
return combined
|
||||||
|
|
||||||
|
def create_grid(b64_images):
|
||||||
|
# decode, resize, and combine images
|
||||||
|
images = [decode_and_resize(img_str) for img_str in b64_images]
|
||||||
|
combined = combine_images(images)
|
||||||
|
|
||||||
|
# convert combined image to byte stream for posting
|
||||||
|
img_byte_arr = BytesIO()
|
||||||
|
combined.save(img_byte_arr, format='WEBP')
|
||||||
|
return img_byte_arr.getvalue()
|
|
@ -0,0 +1,120 @@
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
from itertools import cycle
|
||||||
|
import re
|
||||||
|
|
||||||
|
import io
|
||||||
|
from drama_client import DramaClient
|
||||||
|
from data import config
|
||||||
|
|
||||||
|
from image_utils import create_grid
|
||||||
|
|
||||||
|
SERVICES = config["services"]
|
||||||
|
|
||||||
|
|
||||||
|
async def gen_worker(session, task_queue, result_queue, service):
|
||||||
|
while True:
|
||||||
|
# Get a task from the task queue
|
||||||
|
comment, prompt = await task_queue.get()
|
||||||
|
|
||||||
|
# Process the task
|
||||||
|
headers = {
|
||||||
|
'accept': 'application/json',
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"sampler_name": "Euler a",
|
||||||
|
"batch_size": 1,
|
||||||
|
"n_iter": 3,
|
||||||
|
"steps": 30,
|
||||||
|
"cfg_scale": 7,
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"do_not_save_samples": True,
|
||||||
|
"do_not_save_grid": True,
|
||||||
|
"send_images": True,
|
||||||
|
"save_images": False
|
||||||
|
}
|
||||||
|
|
||||||
|
b64_images = []
|
||||||
|
while len(b64_images) < 9:
|
||||||
|
try:
|
||||||
|
async with session.post(f"{service}sdapi/v1/txt2img", headers=headers, data=json.dumps(data)) as r:
|
||||||
|
resp_json = await r.json()
|
||||||
|
b64_images.extend(resp_json["images"])
|
||||||
|
except aiohttp.client_exceptions.ContentTypeError:
|
||||||
|
html_response = await r.text()
|
||||||
|
with open(f"{config['data_dir']}/log.txt", "a") as f:
|
||||||
|
f.write(html_response)
|
||||||
|
raise
|
||||||
|
|
||||||
|
grid = create_grid(b64_images)
|
||||||
|
|
||||||
|
# Put the result in the result queue
|
||||||
|
await result_queue.put((comment, prompt, grid))
|
||||||
|
|
||||||
|
# Indicate that the task is done
|
||||||
|
task_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def feed_worker(client, task_queue):
|
||||||
|
while True:
|
||||||
|
# Fetch new tasks
|
||||||
|
comments = await client.fetch_new_comments()
|
||||||
|
|
||||||
|
# Add each task to the task queue
|
||||||
|
for comment in comments:
|
||||||
|
prompts = re.findall(r"^!sd (.*)$", comment["body"], re.MULTILINE)
|
||||||
|
prompts = prompts[:5]
|
||||||
|
for prompt in prompts:
|
||||||
|
await task_queue.put((comment, prompt))
|
||||||
|
|
||||||
|
await asyncio.sleep(20)
|
||||||
|
|
||||||
|
|
||||||
|
async def result_worker(client, result_queue):
|
||||||
|
while True:
|
||||||
|
# Get a result from the result queue
|
||||||
|
comment, prompt, grid = await result_queue.get()
|
||||||
|
|
||||||
|
# Post the result
|
||||||
|
await client.reply(comment, f"`{prompt}`", images=[grid])
|
||||||
|
|
||||||
|
# Indicate that the result has been processed
|
||||||
|
result_queue.task_done()
|
||||||
|
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
client = DramaClient()
|
||||||
|
|
||||||
|
task_queue = asyncio.Queue() # used to pass tasks to the workers
|
||||||
|
result_queue = asyncio.Queue() # used to pass results to the result worker
|
||||||
|
|
||||||
|
# Create the feed worker
|
||||||
|
feed_worker_task = asyncio.create_task(feed_worker(client, task_queue))
|
||||||
|
# Create the result worker to post the generated images
|
||||||
|
result_worker_task = asyncio.create_task(result_worker(client, result_queue))
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create a worker for each Stable Diffusion service
|
||||||
|
gen_workers = [
|
||||||
|
asyncio.create_task(gen_worker(session, task_queue, result_queue, service))
|
||||||
|
for service in SERVICES
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(feed_worker_task, *gen_workers, result_worker_task)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# If the main() coroutine is cancelled, propagate the cancellation to all workers
|
||||||
|
feed_worker_task.cancel()
|
||||||
|
for worker in gen_workers:
|
||||||
|
worker.cancel()
|
||||||
|
result_worker_task.cancel()
|
||||||
|
await asyncio.gather(feed_worker_task, *gen_workers, result_worker_task, return_exceptions=True) # ignore cancellation exceptions
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(main())
|
31
readme.md
31
readme.md
|
@ -1,31 +0,0 @@
|
||||||
# Marseygen
|
|
||||||
|
|
||||||
Stable Diffusion bot with distributed inference.
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
|
|
||||||
* Set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) on the gen workers and activate the `ldm` environment
|
|
||||||
|
|
||||||
* Install rabbitmq and redis, add URLs to `celeryconfig.py`
|
|
||||||
|
|
||||||
* `git clone https://github.com/float-trip/marseygen`
|
|
||||||
|
|
||||||
* `pip install -r marseygen/requirements.txt`
|
|
||||||
|
|
||||||
* `mv marseygen/*.py InvokeAI && cd InvokeAI`
|
|
||||||
* Running the gen workers from this dir circumvents some Python import issues that I don't care to figure out right now
|
|
||||||
|
|
||||||
* Start the API worker
|
|
||||||
|
|
||||||
`celery -A tasks worker -B --concurrency 1 --loglevel=INFO`
|
|
||||||
|
|
||||||
* Start a gen worker for each GPU
|
|
||||||
|
|
||||||
```sh
|
|
||||||
export CUDA_VISIBLE_DEVICES=0,
|
|
||||||
export WORKER_HOST="user@gen_worker_ip"
|
|
||||||
export WORKER_SSH_PORT="22"
|
|
||||||
export WORKER_ID="unique_id"
|
|
||||||
celery -A tasks worker -Q gen -n unique_name -B --concurrency 1 --loglevel=INFO`
|
|
||||||
```
|
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
celery==5.2.7
|
|
||||||
celery_singleton==0.3.1
|
|
||||||
Pillow==9.2.0
|
|
||||||
pry.py==0.1.1
|
|
||||||
requests==2.27.1
|
|
140
tasks.py
140
tasks.py
|
@ -1,140 +0,0 @@
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import random
|
|
||||||
from celery import Celery, Task, chain
|
|
||||||
import time
|
|
||||||
import celeryconfig
|
|
||||||
|
|
||||||
from client import DramaClient
|
|
||||||
from utils import concat_images
|
|
||||||
|
|
||||||
app = Celery("tasks")
|
|
||||||
app.config_from_object(celeryconfig)
|
|
||||||
|
|
||||||
client = DramaClient()
|
|
||||||
|
|
||||||
generator = None
|
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# API worker tasks
|
|
||||||
#
|
|
||||||
@app.task
|
|
||||||
def post_reply(context):
|
|
||||||
basename = os.path.basename(context["image_path"])
|
|
||||||
save_path = f"/fs/marseys/{basename}"
|
|
||||||
|
|
||||||
print(f"Copying {basename}")
|
|
||||||
|
|
||||||
# Copy image from remote machine.
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
"rsync",
|
|
||||||
"-a",
|
|
||||||
f"{context['worker_host']}:{context['image_path']}",
|
|
||||||
save_path,
|
|
||||||
"-e",
|
|
||||||
f"ssh -p {context['worker_ssh_port']}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Replying for prompt {context['prompt']}")
|
|
||||||
|
|
||||||
client.reply(
|
|
||||||
context["parent_fullname"],
|
|
||||||
context["submission"],
|
|
||||||
f"`{context['prompt']}`",
|
|
||||||
save_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FindPromptsTask(Task):
|
|
||||||
last_call = None
|
|
||||||
|
|
||||||
# Temp fix for comments being replied to multiple times.
|
|
||||||
queued_ids = set()
|
|
||||||
|
|
||||||
|
|
||||||
@app.task(base=FindPromptsTask)
|
|
||||||
def find_prompts():
|
|
||||||
if find_prompts.last_call is not None and time.time() - find_prompts.last_call < 60:
|
|
||||||
return
|
|
||||||
|
|
||||||
find_prompts.last_call = time.time()
|
|
||||||
|
|
||||||
print("Looking for prompts.")
|
|
||||||
comments = client.fetch_new_comments()
|
|
||||||
|
|
||||||
for comment in comments:
|
|
||||||
if comment["id"] in find_prompts.queued_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
find_prompts.queued_ids.add(comment["id"])
|
|
||||||
|
|
||||||
reply_contexts = [
|
|
||||||
{
|
|
||||||
"parent_fullname": f"c_{comment['id']}",
|
|
||||||
"submission": comment["post_id"],
|
|
||||||
"prompt": line[4:],
|
|
||||||
}
|
|
||||||
for line in comment["body"].split("\n")
|
|
||||||
if line.startswith("!sd ")
|
|
||||||
]
|
|
||||||
|
|
||||||
# Max 5 prompts per comment.
|
|
||||||
reply_contexts = reply_contexts[:5]
|
|
||||||
|
|
||||||
for context in reply_contexts:
|
|
||||||
print(f"Queueing prompt `{context['prompt']}`.")
|
|
||||||
chain(
|
|
||||||
generate_reply.s(context).set(queue="gen"),
|
|
||||||
post_reply.s().set(queue="api"),
|
|
||||||
).apply_async()
|
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# Generation worker tasks
|
|
||||||
#
|
|
||||||
class GenTask(Task):
|
|
||||||
_generator = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def generator(self):
|
|
||||||
if self._generator is None:
|
|
||||||
from ldm.generate import Generate
|
|
||||||
|
|
||||||
self._generator = Generate(sampler_name="k_euler_a")
|
|
||||||
self._generator.load_model()
|
|
||||||
|
|
||||||
print("Model loaded.")
|
|
||||||
|
|
||||||
return self._generator
|
|
||||||
|
|
||||||
|
|
||||||
@app.task(base=GenTask)
|
|
||||||
def generate_reply(context):
|
|
||||||
print(f"Generating `{context['prompt']}`.")
|
|
||||||
|
|
||||||
if not os.path.exists("out"):
|
|
||||||
os.makedirs("out")
|
|
||||||
|
|
||||||
results = generate_reply.generator.prompt2png(
|
|
||||||
context["prompt"], outdir=f"out/{os.environ['WORKER_ID']}", iterations=9
|
|
||||||
)
|
|
||||||
|
|
||||||
image_paths = [r[0] for r in results]
|
|
||||||
grid = concat_images(image_paths, size=(512, 512), shape=(3, 3))
|
|
||||||
|
|
||||||
grid_basename = f"{random.randrange(10**6, 10**7)}.webp"
|
|
||||||
|
|
||||||
if not os.path.exists("grid"):
|
|
||||||
os.makedirs("grid")
|
|
||||||
|
|
||||||
grid_path = f"grid/{grid_basename}"
|
|
||||||
grid.save(grid_path, "WEBP")
|
|
||||||
|
|
||||||
context["image_path"] = os.path.abspath(grid_path)
|
|
||||||
context["worker_host"] = os.environ["WORKER_HOST"]
|
|
||||||
context["worker_ssh_port"] = os.environ["WORKER_SSH_PORT"]
|
|
||||||
|
|
||||||
return context
|
|
24
utils.py
24
utils.py
|
@ -1,24 +0,0 @@
|
||||||
from PIL import Image, ImageOps
|
|
||||||
|
|
||||||
#
|
|
||||||
# https://gist.github.com/njanakiev/1932e0a450df6d121c05069d5f7d7d6f
|
|
||||||
#
|
|
||||||
def concat_images(image_paths, size, shape=None):
|
|
||||||
# Open images and resize them
|
|
||||||
width, height = size
|
|
||||||
images = map(Image.open, image_paths)
|
|
||||||
images = [ImageOps.fit(image, size, Image.ANTIALIAS) for image in images]
|
|
||||||
|
|
||||||
# Create canvas for the final image with total size
|
|
||||||
shape = shape if shape else (1, len(images))
|
|
||||||
image_size = (width * shape[1], height * shape[0])
|
|
||||||
image = Image.new("RGB", image_size)
|
|
||||||
|
|
||||||
# Paste images into final image
|
|
||||||
for row in range(shape[0]):
|
|
||||||
for col in range(shape[1]):
|
|
||||||
offset = width * col, height * row
|
|
||||||
idx = row * shape[1] + col
|
|
||||||
image.paste(images[idx], offset)
|
|
||||||
|
|
||||||
return image
|
|
Loading…
Reference in New Issue