dedup sql-escaping code

pull/210/head
Aevann 2023-09-27 01:12:01 +03:00
parent c3870505ea
commit 46633a4529
5 changed files with 20 additions and 13 deletions

View File

@ -7,9 +7,14 @@ from files.classes import Comment, CommentVote, Hat, Sub, Post, User, UserBlock,
from files.helpers.config.const import *
from files.__main__ import cache
# Escape SQL pattern-matching special characters
def escape_for_search(string):
return string.replace('\\', '').replace('_', '\_').replace('%', '\%').strip()
def sanitize_username(username):
if not username: return username
return username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').replace('(', '').replace(')', '').strip()
username = username.lstrip('@').replace('(', '').replace(')', '')
username = escape_for_search(username)
return username
def get_user(username, v=None, graceful=False, include_blocks=False, attributes=None):
if not username:

View File

@ -128,8 +128,8 @@ def frontlist(v=None, sort="hot", page=1, t="all", ids_only=True, filter_words='
if v and filter_words:
for word in filter_words:
word = word.replace('\\', '').replace('_', '\_').replace('%', '\%').strip()
posts=posts.filter(not_(Post.title.ilike(f'%{word}%')))
word = escape_for_search(word)
posts = posts.filter(not_(Post.title.ilike(f'%{word}%')))
total = posts.count()

View File

@ -40,7 +40,7 @@ def login_post(v):
username = request.values.get("username")
if not username: abort(400)
username = username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').strip()
username = sanitize_username(username)
if not username: abort(400)
if username.startswith('@'): username = username[1:]
@ -379,7 +379,7 @@ def post_forgot():
user = get_user(username, graceful=True)
email = email.replace('\\', '').replace('_', '\_').replace('%', '').strip()
email = escape_for_search(email)
if user and user.email.lower() == email.lower():
now = int(time.time())

View File

@ -417,12 +417,14 @@ def is_repost(v):
url = normalize_url(url)
search_url = url.replace('%', '').replace('\\', '').replace('_', '\_').strip()
url = escape_for_search(url)
repost = g.db.query(Post).filter(
Post.url.ilike(search_url),
Post.url.ilike(url),
Post.deleted_utc == 0,
Post.is_banned == False
).first()
if repost: return {'permalink': repost.permalink}
else: return not_a_repost

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import load_only
from files.helpers.regex import *
from files.helpers.sorting_and_time import *
from files.helpers.get import *
from files.routes.wrappers import *
from files.__main__ import app
@ -40,8 +41,7 @@ def searchparse(text):
for m in search_token_regex.finditer(text):
token = m[1] if m[1] else m[2]
if not token: token = ''
# Escape SQL pattern matching special characters
token = token.replace('\\', '').replace('_', '\_').replace('%', '\%')
token = escape_for_search(token)
criteria['q'].append(token)
return criteria
@ -117,7 +117,7 @@ def searchposts(v):
if 'domain' in criteria:
domain = criteria['domain']
domain = domain.replace('\\', '').replace('_', '\_').replace('%', '').strip()
domain = escape_for_search(domain)
posts = posts.filter(
or_(
@ -413,8 +413,8 @@ def searchusers(v):
if 'q' in criteria:
term = criteria['q'][0]
term = term.lstrip('@')
term = term.replace('\\','').replace('_','\_').replace('%','')
term = sanitize_username(term)
users = users.filter(
or_(