type hints for get functions :)

remotes/1693176582716663532/tmp_refs/heads/watchparty
justcool393 2022-10-12 02:22:13 -07:00
parent 32bede574b
commit db9b37de40
1 changed files with 11 additions and 21 deletions

View File

@ -1,10 +1,9 @@
from typing import Optional
from files.classes import *
from flask import g
def get_id(username, graceful=False):
def get_id(username, graceful=False) -> Optional[int]:
username = username.replace('\\', '').replace('_', '\_').replace('%', '').strip()
user = g.db.query(
User.id
).filter(
@ -21,13 +20,12 @@ def get_id(username, graceful=False):
return user[0]
def get_user(username, v=None, graceful=False, rendered=False, include_blocks=False, include_shadowbanned=True):
def get_user(username, v=None, graceful=False, rendered=False, include_blocks=False, include_shadowbanned=True) -> Optional[User]:
if not username:
if not graceful: abort(404)
else: return None
username = username.replace('\\', '').replace('_', '\_').replace('%', '').replace('(', '').replace(')', '').strip()
user = g.db.query(
User
).filter(
@ -66,13 +64,11 @@ def get_user(username, v=None, graceful=False, rendered=False, include_blocks=Fa
return user
def get_users(usernames, graceful=False):
def get_users(usernames, graceful=False) -> list[User]:
def clean(n):
return n.replace('\\', '').replace('_', '\_').replace('%', '').strip()
usernames = [clean(n) for n in usernames]
users = g.db.query(User).filter(
or_(
User.username.ilike(any_(usernames)),
@ -85,8 +81,7 @@ def get_users(usernames, graceful=False):
return users
def get_account(id, v=None, graceful=False, include_blocks=False, include_shadowbanned=True):
def get_account(id, v=None, graceful=False, include_blocks=False, include_shadowbanned=True) -> Optional[User]:
try:
id = int(id)
except:
@ -118,8 +113,7 @@ def get_account(id, v=None, graceful=False, include_blocks=False, include_shadow
return user
def get_post(i, v=None, graceful=False):
def get_post(i, v=None, graceful=False) -> Optional[Submission]:
try: i = int(i)
except: abort(404)
@ -167,8 +161,7 @@ def get_post(i, v=None, graceful=False):
return x
def get_posts(pids, v=None):
def get_posts(pids, v=None) -> list[Submission]:
if not pids:
return []
@ -210,8 +203,7 @@ def get_posts(pids, v=None):
return sorted(output, key=lambda x: pids.index(x.id))
def get_comment(i, v=None, graceful=False):
def get_comment(i, v=None, graceful=False) -> Optional[Comment]:
try: i = int(i)
except: abort(404)
@ -246,8 +238,7 @@ def get_comment(i, v=None, graceful=False):
return comment
def get_comments(cids, v=None, load_parent=False):
def get_comments(cids, v=None, load_parent=False) -> list[Comment]:
if not cids: return []
if v:
@ -295,7 +286,7 @@ def get_comments(cids, v=None, load_parent=False):
return sorted(output, key=lambda x: cids.index(x.id))
def get_sub_by_name(sub, v=None, graceful=False):
def get_sub_by_name(sub, v=None, graceful=False) -> Optional[Sub]:
if not sub:
if graceful: return None
else: abort(404)
@ -309,8 +300,7 @@ def get_sub_by_name(sub, v=None, graceful=False):
else: abort(404)
return sub
def get_domain(s):
def get_domain(s) -> Optional[BannedDomain]:
parts = s.split(".")
domain_list = set()
for i in range(len(parts)):