bots/web/main.py

149 lines
4.2 KiB
Python

import logging
import re
import secrets
from io import BytesIO
from typing import List

import requests
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sqlitedict import SqliteDict
from starlette.responses import RedirectResponse

from bots.data import config, marseygen_queue, bussyboy_queue, db, bussyboy_log
from bots.clients.drama import DramaClient
# Application object and the template loader used to render the HTML pages.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Set up logging: INFO and above is appended to web.log under the configured
# data directory.
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s - %(levelname)s - %(message)s",
    filename=f"{config['data_dir']}/logs/web.log",
    filemode="a",
)
logger = logging.getLogger(__name__)
@app.get("/")
def read_root():
    """Redirect requests for the site root to rdrama.net."""
    target = "https://rdrama.net"
    return RedirectResponse(url=target)
@app.get(
    "/h/{_hole}/post/{_post_id}/{_post_slug}/{comment_id}",
    response_class=HTMLResponse,
)
async def get_alt_comments_with_hole(
    request: Request,
    _hole: str,
    _post_id: int,
    _post_slug: str,
    comment_id: str,
):
    """Serve the logged candidates page for a comment URL that includes a hole.

    Only ``comment_id`` is used for the lookup; the remaining path parameters
    are declared so the route matches but are not read by the handler body.
    """
    page = _get_alt_comments(request, comment_id)
    return page
@app.get("/post/{_post_id}/{_post_slug}/{comment_id}", response_class=HTMLResponse)
async def get_alt_comments_without_hole(
    request: Request,
    _post_id: int,
    _post_slug: str,
    comment_id: str,
):
    """Serve the logged candidates page for a comment URL without a hole.

    Only ``comment_id`` is used for the lookup; the remaining path parameters
    are declared so the route matches but are not read by the handler body.
    """
    page = _get_alt_comments(request, comment_id)
    return page
def _get_alt_comments(request, comment_id):
    """Render the logged candidate replies for a comment.

    Args:
        request: The incoming request, passed through to the template context.
        comment_id: Key to look up in ``bussyboy_log``.

    Returns:
        The rendered ``comment.html`` page built from the logged parent body
        and candidates, or the ``404.html`` page with HTTP status 404 when
        ``comment_id`` has no log entry.
    """
    logger.info(f"Request received for {comment_id}")

    if comment_id not in bussyboy_log:
        # Send the error page with a real 404 status; without an explicit
        # status_code the response would go out as 200.
        return templates.TemplateResponse(
            "404.html", {"request": request}, status_code=404
        )

    log = bussyboy_log[comment_id]
    context = {
        "request": request,
        "parent_body": log["parent_body"],
        "candidates": log["candidates"],
    }
    return templates.TemplateResponse("comment.html", context)
@app.get("/bussyboy")
async def bussyboy():
    """Return a fixed subset of the Bussyboy settings plus the request queue."""
    # Only these settings are copied into the response.
    exposed_keys = (
        "prompt_token_limit",
        "num_candidates",
        "username",
        "fake_usernames",
    )
    selected = {}
    for name in exposed_keys:
        selected[name] = config["bussyboy"][name]
    return {"config": selected, "queue": bussyboy_queue}
class BussyboyReplyInfo(BaseModel):
    """Request body accepted by the Bussyboy submission endpoint."""

    # Prompt text, stored in the log entry for the reply.
    prompt: str
    # Text that gets posted as the reply.
    body: str
    # Candidate texts, stored in the log entry for the reply.
    candidates: List[str]
@app.post("/bussyboy/{request_id}")
async def bussyboy_submit(
    request: Request, request_id: str, key: str, info: BussyboyReplyInfo
):
    """Post a generated Bussyboy reply and record how it was produced.

    Args:
        request: The incoming request (not read by the handler body).
        request_id: Key of the pending entry in ``bussyboy_queue``.
        key: Shared secret; must equal ``config["server_key"]``.
        info: Prompt, reply body and candidates for this request.

    Returns:
        An empty 200 response on success, or an empty 400 response when
        ``key`` does not match the configured server key.

    Raises:
        HTTPException: With status 500 when ``request_id`` is not queued.
    """
    # Compare the secret in constant time. Both sides are encoded to bytes
    # because compare_digest rejects non-ASCII str arguments.
    # NOTE(review): assumes config["server_key"] is a str — confirm.
    if not secrets.compare_digest(key.encode(), config["server_key"].encode()):
        return Response(status_code=400)

    if request_id not in bussyboy_queue:
        logger.error(f"Unknown request ID: {request_id}")
        raise HTTPException(status_code=500)

    # Post reply. The queued entry gets its own name so it does not overwrite
    # the ``request`` parameter.
    queued = bussyboy_queue[request_id]
    parent = queued["thread"][-1]
    drama_client = DramaClient(config["bussyboy"]["token"], logger=logger)
    reply_response = drama_client.reply(parent, info.body)
    del bussyboy_queue[request_id]

    # Store the log under the new reply's ID.
    bussyboy_log[str(reply_response["id"])] = {
        # Log thread context.
        "parent_body": parent["body"],
        "post": queued["post"],
        "thread": queued["thread"],
        # Log generation info.
        "prompt": info.prompt,
        "body": info.body,
        "candidates": info.candidates,
    }
    return Response(status_code=200)
@app.get("/marseygen")
async def marseygen():
    """Return, for each queued Marseygen request ID, element 1 of its entry."""
    pending = {}
    for request_id, entry in marseygen_queue.items():
        pending[request_id] = entry[1]
    return pending
@app.post("/marseygen/{request_id}")
async def marseygen_submit(request_id: str, key: str, file: UploadFile):
    """Post an uploaded Marseygen image as a reply to the queued comment.

    Args:
        request_id: Key of the pending entry in ``marseygen_queue``.
        key: Shared secret; must equal ``config["server_key"]``.
        file: Uploaded image, sent on as ``image.webp``.

    Returns:
        An empty 200 response on success, or an empty 400 response when
        ``key`` does not match the configured server key.

    Raises:
        HTTPException: With status 500 when ``request_id`` is not queued.
    """
    # Compare the secret in constant time. Both sides are encoded to bytes
    # because compare_digest rejects non-ASCII str arguments.
    # NOTE(review): assumes config["server_key"] is a str — confirm.
    if not secrets.compare_digest(key.encode(), config["server_key"].encode()):
        return Response(status_code=400)

    if request_id not in marseygen_queue:
        logger.error(f"Unknown request ID: {request_id}")
        raise HTTPException(status_code=500)

    # The entry is removed before the first ``await`` below, so there is no
    # suspension point between the membership check and the delete.
    comment, prompt = marseygen_queue[request_id]
    del marseygen_queue[request_id]

    # A newly constructed BytesIO already starts at offset 0, so no seek is
    # needed before handing it on.
    image_bytes = BytesIO(await file.read())

    # No pinging or gambling for Marseygen: the "!" of every "!word" is
    # wrapped in a <span>.
    prompt = re.sub(r"(!)(\w+)", r"<span>\1</span>\2", prompt)

    # Post reply.
    image = {"file": ("image.webp", image_bytes, "image/webp")}
    drama_client = DramaClient(config["marseygen"]["token"], logger=logger)
    drama_client.reply(comment, f"`{prompt}`<br>[image.webp]", image)
    return Response(status_code=200)