forked from rDrama/rDrama
1
0
Fork 0

dedup sql-escaping code

master
Aevann 2023-09-27 01:12:01 +03:00
parent c3870505ea
commit 46633a4529
5 changed files with 20 additions and 13 deletions

View File

@ -7,9 +7,14 @@ from files.classes import Comment, CommentVote, Hat, Sub, Post, User, UserBlock,
from files.helpers.config.const import * from files.helpers.config.const import *
from files.__main__ import cache from files.__main__ import cache
# Escape SQL pattern-matching special characters
def escape_for_search(string):
return string.replace('\\', '').replace('_', '\_').replace('%', '\%').strip()
def sanitize_username(username): def sanitize_username(username):
if not username: return username username = username.lstrip('@').replace('(', '').replace(')', '')
return username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').replace('(', '').replace(')', '').strip() username = escape_for_search(username)
return username
def get_user(username, v=None, graceful=False, include_blocks=False, attributes=None): def get_user(username, v=None, graceful=False, include_blocks=False, attributes=None):
if not username: if not username:

View File

@ -128,8 +128,8 @@ def frontlist(v=None, sort="hot", page=1, t="all", ids_only=True, filter_words='
if v and filter_words: if v and filter_words:
for word in filter_words: for word in filter_words:
word = word.replace('\\', '').replace('_', '\_').replace('%', '\%').strip() word = escape_for_search(word)
posts=posts.filter(not_(Post.title.ilike(f'%{word}%'))) posts = posts.filter(not_(Post.title.ilike(f'%{word}%')))
total = posts.count() total = posts.count()

View File

@ -40,7 +40,7 @@ def login_post(v):
username = request.values.get("username") username = request.values.get("username")
if not username: abort(400) if not username: abort(400)
username = username.lstrip('@').replace('\\', '').replace('_', '\_').replace('%', '').strip() username = sanitize_username(username)
if not username: abort(400) if not username: abort(400)
if username.startswith('@'): username = username[1:] if username.startswith('@'): username = username[1:]
@ -379,7 +379,7 @@ def post_forgot():
user = get_user(username, graceful=True) user = get_user(username, graceful=True)
email = email.replace('\\', '').replace('_', '\_').replace('%', '').strip() email = escape_for_search(email)
if user and user.email.lower() == email.lower(): if user and user.email.lower() == email.lower():
now = int(time.time()) now = int(time.time())

View File

@ -417,12 +417,14 @@ def is_repost(v):
url = normalize_url(url) url = normalize_url(url)
search_url = url.replace('%', '').replace('\\', '').replace('_', '\_').strip() url = escape_for_search(url)
repost = g.db.query(Post).filter( repost = g.db.query(Post).filter(
Post.url.ilike(search_url), Post.url.ilike(url),
Post.deleted_utc == 0, Post.deleted_utc == 0,
Post.is_banned == False Post.is_banned == False
).first() ).first()
if repost: return {'permalink': repost.permalink} if repost: return {'permalink': repost.permalink}
else: return not_a_repost else: return not_a_repost

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import load_only
from files.helpers.regex import * from files.helpers.regex import *
from files.helpers.sorting_and_time import * from files.helpers.sorting_and_time import *
from files.helpers.get import *
from files.routes.wrappers import * from files.routes.wrappers import *
from files.__main__ import app from files.__main__ import app
@ -40,8 +41,7 @@ def searchparse(text):
for m in search_token_regex.finditer(text): for m in search_token_regex.finditer(text):
token = m[1] if m[1] else m[2] token = m[1] if m[1] else m[2]
if not token: token = '' if not token: token = ''
# Escape SQL pattern matching special characters token = escape_for_search(token)
token = token.replace('\\', '').replace('_', '\_').replace('%', '\%')
criteria['q'].append(token) criteria['q'].append(token)
return criteria return criteria
@ -117,7 +117,7 @@ def searchposts(v):
if 'domain' in criteria: if 'domain' in criteria:
domain = criteria['domain'] domain = criteria['domain']
domain = domain.replace('\\', '').replace('_', '\_').replace('%', '').strip() domain = escape_for_search(domain)
posts = posts.filter( posts = posts.filter(
or_( or_(
@ -413,8 +413,8 @@ def searchusers(v):
if 'q' in criteria: if 'q' in criteria:
term = criteria['q'][0] term = criteria['q'][0]
term = term.lstrip('@')
term = term.replace('\\','').replace('_','\_').replace('%','') term = sanitize_username(term)
users = users.filter( users = users.filter(
or_( or_(