From 46633a452913fea2ade7bdd5fde0d094e7504e66 Mon Sep 17 00:00:00 2001 From: Aevann Date: Wed, 27 Sep 2023 01:12:01 +0300 Subject: [PATCH] dedup sql-escaping code --- files/helpers/get.py | 9 +++++++-- files/routes/front.py | 4 ++-- files/routes/login.py | 4 ++-- files/routes/posts.py | 6 ++++-- files/routes/search.py | 10 +++++----- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/files/helpers/get.py b/files/helpers/get.py index 5294c6d65..8deba9eb0 100644 --- a/files/helpers/get.py +++ b/files/helpers/get.py @@ -7,9 +7,14 @@ from files.classes import Comment, CommentVote, Hat, Sub, Post, User, UserBlock, from files.helpers.config.const import * from files.__main__ import cache +# Escape SQL pattern-matching special characters +def escape_for_search(string): + return string.replace('\\', '').replace('_', '\_').replace('%', '\%').strip() + def sanitize_username(username): - if not username: return username - return username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').replace('(', '').replace(')', '').strip() + username = username.lstrip('@').replace('(', '').replace(')', '') + username = escape_for_search(username) + return username def get_user(username, v=None, graceful=False, include_blocks=False, attributes=None): if not username: diff --git a/files/routes/front.py b/files/routes/front.py index 544847da9..7f2f08e4e 100644 --- a/files/routes/front.py +++ b/files/routes/front.py @@ -128,8 +128,8 @@ def frontlist(v=None, sort="hot", page=1, t="all", ids_only=True, filter_words=' if v and filter_words: for word in filter_words: - word = word.replace('\\', '').replace('_', '\_').replace('%', '\%').strip() - posts=posts.filter(not_(Post.title.ilike(f'%{word}%'))) + word = escape_for_search(word) + posts = posts.filter(not_(Post.title.ilike(f'%{word}%'))) total = posts.count() diff --git a/files/routes/login.py b/files/routes/login.py index 75513782d..a8602c8a3 100644 --- a/files/routes/login.py +++ b/files/routes/login.py @@ -40,7 +40,7 @@ def login_post(v): username = request.values.get("username") if not username: abort(400) - username = username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').strip() + username = sanitize_username(username) if not username: abort(400) if username.startswith('@'): username = username[1:] @@ -379,7 +379,7 @@ def post_forgot(): user = get_user(username, graceful=True) - email = email.replace('\\', '').replace('_', '\_').replace('%', '').strip() + email = escape_for_search(email) if user and user.email.lower() == email.lower(): now = int(time.time()) diff --git a/files/routes/posts.py b/files/routes/posts.py index ffc523eec..66fd78fca 100644 --- a/files/routes/posts.py +++ b/files/routes/posts.py @@ -417,12 +417,14 @@ def is_repost(v): url = normalize_url(url) - search_url = url.replace('%', '').replace('\\', '').replace('_', '\_').strip() + url = escape_for_search(url) + repost = g.db.query(Post).filter( - Post.url.ilike(search_url), + Post.url.ilike(url), Post.deleted_utc == 0, Post.is_banned == False ).first() + if repost: return {'permalink': repost.permalink} else: return not_a_repost diff --git a/files/routes/search.py b/files/routes/search.py index 3dd4dd115..897213952 100644 --- a/files/routes/search.py +++ b/files/routes/search.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import load_only from files.helpers.regex import * from files.helpers.sorting_and_time import * +from files.helpers.get import * from files.routes.wrappers import * from files.__main__ import app @@ -40,8 +41,7 @@ def searchparse(text): for m in search_token_regex.finditer(text): token = m[1] if m[1] else m[2] if not token: token = '' - # Escape SQL pattern matching special characters - token = token.replace('\\', '').replace('_', '\_').replace('%', '\%') + token = escape_for_search(token) criteria['q'].append(token) return criteria @@ -117,7 +117,7 @@ def searchposts(v): if 'domain' in criteria: domain = criteria['domain'] - domain = domain.replace('\\', '').replace('_', '\_').replace('%', '').strip() + domain = escape_for_search(domain) posts = posts.filter( or_( @@ -413,8 +413,8 @@ def searchusers(v): if 'q' in criteria: term = criteria['q'][0] - term = term.lstrip('@') - term = term.replace('\\','').replace('_','\_').replace('%','') + + term = sanitize_username(term) users = users.filter( or_(