149 lines
4.2 KiB
Python
149 lines
4.2 KiB
Python
import logging
import re
import secrets
from io import BytesIO
from typing import List

import requests
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sqlitedict import SqliteDict
from starlette.responses import RedirectResponse

from bots.clients.drama import DramaClient
from bots.data import config, marseygen_queue, bussyboy_queue, db, bussyboy_log
|
|
|
|
# ASGI application that every route in this module is registered on.
app = FastAPI()
|
|
|
|
# Jinja2 templates used by the HTML routes (e.g. comment.html, 404.html).
# The relative "templates" directory is resolved against the process's
# working directory.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
|
# Set up logging: logging.basicConfig configures the root logger to append
# INFO-and-above records to <data_dir>/logs/web.log.
_WEB_LOG_FILE = f"{config['data_dir']}/logs/web.log"
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s - %(levelname)s - %(message)s",
    filemode="a",
    filename=_WEB_LOG_FILE,
)

# Module-scoped logger used by the handlers in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Redirect requests for the bare root URL to rdrama.net."""
    destination = "https://rdrama.net"
    return RedirectResponse(url=destination)
|
|
|
|
|
|
@app.get(
    "/h/{_hole}/post/{_post_id}/{_post_slug}/{comment_id}", response_class=HTMLResponse
)
async def get_alt_comments_with_hole(
    request: Request, _hole: str, _post_id: int, _post_slug: str, comment_id: str
):
    """Render the alternate-candidates page for a comment URL inside a hole.

    Matches paths shaped like ``/h/<hole>/post/<post_id>/<slug>/<comment_id>``.
    Only ``comment_id`` is used. The hole, post id (which must parse as an
    int) and slug segments exist so the path matches, and are otherwise ignored.
    """
    page = _get_alt_comments(request, comment_id)
    return page
|
|
|
|
|
|
@app.get("/post/{_post_id}/{_post_slug}/{comment_id}", response_class=HTMLResponse)
async def get_alt_comments_without_hole(
    request: Request, _post_id: int, _post_slug: str, comment_id: str
):
    """Render the alternate-candidates page for a comment URL outside any hole.

    Matches paths shaped like ``/post/<post_id>/<slug>/<comment_id>``. Only
    ``comment_id`` is used. The post id (which must parse as an int) and slug
    are accepted so the path matches, and are otherwise ignored.
    """
    page = _get_alt_comments(request, comment_id)
    return page
|
|
|
|
|
|
def _get_alt_comments(request, comment_id):
    """Render the page listing the generation candidates logged for a comment.

    Args:
        request: The incoming request. It is passed into the template context
            as ``"request"``, which the templates integration expects.
        comment_id: Key into ``bussyboy_log``. Entries are stored under the
            string id of a posted Bussyboy reply.

    Returns:
        ``comment.html`` rendered with the logged parent comment body and the
        candidate replies. If no log entry exists for ``comment_id``, returns
        ``404.html`` with HTTP status 404.
    """
    logger.info("Request received for %s", comment_id)

    # Single lookup (EAFP) instead of `in` followed by `[]`. This avoids a
    # second lookup (a second query if bussyboy_log is database-backed) and a
    # check-then-read race if the entry disappears in between.
    try:
        log = bussyboy_log[comment_id]
    except KeyError:
        # Serve the not-found page with a real 404 status, not the default 200.
        return templates.TemplateResponse(
            "404.html", {"request": request}, status_code=404
        )

    return templates.TemplateResponse(
        "comment.html",
        {
            "request": request,
            "parent_body": log["parent_body"],
            "candidates": log["candidates"],
        },
    )
|
|
|
|
|
|
@app.get("/bussyboy")
async def bussyboy():
    """Return Bussyboy's public settings together with its pending request queue.

    Only an allow-listed subset of ``config["bussyboy"]`` is copied out. Other
    entries in that section, such as the account ``token`` used when posting
    replies, are not exposed.

    NOTE(review): this route performs no ``server_key`` check, so the queue
    contents are readable by any client. Confirm that this is intended.
    """
    section = config["bussyboy"]
    public_settings = {}
    for name in ("prompt_token_limit", "num_candidates", "username", "fake_usernames"):
        public_settings[name] = section[name]
    return {"config": public_settings, "queue": bussyboy_queue}
|
|
|
|
|
|
class BussyboyReplyInfo(BaseModel):
    """Request body accepted by ``POST /bussyboy/{request_id}``."""

    # Generation prompt. The submit handler only logs it; it is not posted.
    prompt: str
    # Text that is posted as the reply, and also logged.
    body: str
    # Generated candidate replies. The submit handler logs them; the comment
    # page later renders them.
    candidates: List[str]
|
|
|
|
|
|
@app.post("/bussyboy/{request_id}")
async def bussyboy_submit(
    request: Request, request_id: str, key: str, info: BussyboyReplyInfo
):
    """Post a generated Bussyboy reply for a queued request and log it.

    Args:
        request: The incoming HTTP request. Unused; it is part of the route
            signature only.
        request_id: Key of the pending entry in ``bussyboy_queue``.
        key: Shared secret. Must equal ``config["server_key"]`` (assumed to be
            a str).
        info: Reply text to post, plus the prompt and candidates behind it.

    Returns:
        An empty 200 response on success, or an empty 400 response when
        ``key`` is wrong.

    Raises:
        HTTPException: 500 when ``request_id`` is not in the queue.

    NOTE(review): 401/403 and 404 would be the conventional status codes
    here. The original codes are kept for client compatibility.
    """
    # Constant-time comparison, so response timing doesn't reveal how much of
    # the key matched. Compare bytes because compare_digest raises TypeError
    # for non-ASCII str arguments.
    if not secrets.compare_digest(key.encode(), config["server_key"].encode()):
        return Response(status_code=400)

    if request_id not in bussyboy_queue:
        # Use the module logger (not the root logger) so records carry this
        # module's name, consistent with the rest of the file.
        logger.error("Unknown request ID: %s", request_id)
        raise HTTPException(status_code=500)

    # Post reply. The queue entry is named `queued` so it no longer shadows
    # the `request` parameter.
    queued = bussyboy_queue[request_id]
    parent = queued["thread"][-1]
    drama_client = DramaClient(config["bussyboy"]["token"], logger=logger)
    # NOTE(review): reply() is not awaited, so it runs synchronously inside
    # this async handler and blocks the event loop while it executes.
    reply_response = drama_client.reply(parent, info.body)
    # Dequeue only after reply() returned, so a post that raises stays queued.
    del bussyboy_queue[request_id]

    log = {
        # Log thread context.
        "parent_body": parent["body"],
        "post": queued["post"],
        "thread": queued["thread"],
        # Log generation info.
        "prompt": info.prompt,
        "body": info.body,
        "candidates": info.candidates,
    }

    # Keyed by the posted reply's id, as a string.
    bussyboy_log[str(reply_response["id"])] = log

    return Response(status_code=200)
|
|
|
|
|
|
@app.get("/marseygen")
async def marseygen():
    """List pending Marseygen jobs as a ``{request_id: prompt}`` mapping.

    Each queue entry is a ``(comment, prompt)`` pair, and only the prompt
    (index 1) is returned.

    NOTE(review): this route performs no ``server_key`` check. Confirm that
    this is intended.
    """
    pending = {}
    for request_id, entry in marseygen_queue.items():
        pending[request_id] = entry[1]
    return pending
|
|
|
|
|
|
@app.post("/marseygen/{request_id}")
async def marseygen_submit(request_id: str, key: str, file: UploadFile):
    """Post a generated Marseygen image as a reply to the requesting comment.

    Args:
        request_id: Key of the pending ``(comment, prompt)`` entry in
            ``marseygen_queue``.
        key: Shared secret. Must equal ``config["server_key"]`` (assumed to be
            a str).
        file: Uploaded image, attached to the reply as ``image.webp``.

    Returns:
        An empty 200 response on success, or an empty 400 response when
        ``key`` is wrong.

    Raises:
        HTTPException: 500 when ``request_id`` is not in the queue.
    """
    # Constant-time comparison, so response timing doesn't reveal how much of
    # the key matched. Compare bytes because compare_digest raises TypeError
    # for non-ASCII str arguments.
    if not secrets.compare_digest(key.encode(), config["server_key"].encode()):
        return Response(status_code=400)

    if request_id not in marseygen_queue:
        # Use the module logger (not the root logger) so records carry this
        # module's name, consistent with the rest of the file.
        logger.error("Unknown request ID: %s", request_id)
        raise HTTPException(status_code=500)

    # The entry is claimed and removed before the first `await` below. As a
    # consequence, the job is dropped (not retried) if the upload read or the
    # post fails.
    comment, prompt = marseygen_queue[request_id]
    del marseygen_queue[request_id]

    # A freshly constructed BytesIO is already positioned at offset 0.
    image_bytes = BytesIO(await file.read())

    # No pinging or gambling for Marseygen: wrap each "!" that directly
    # precedes a word in a <span>, presumably so the site does not treat
    # "!word" as a command.
    prompt = re.sub(r"(!)(\w+)", r"<span>\1</span>\2", prompt)

    # Post reply.
    image = {"file": ("image.webp", image_bytes, "image/webp")}
    drama_client = DramaClient(config["marseygen"]["token"], logger=logger)
    drama_client.reply(comment, f"`{prompt}`<br>[image.webp]", image)

    return Response(status_code=200)
|