diff --git a/files/classes/user.py b/files/classes/user.py index 538558886..4fec754b6 100644 --- a/files/classes/user.py +++ b/files/classes/user.py @@ -211,33 +211,32 @@ class User(Base): def charge_account(self, currency, amount, **kwargs): - in_db = g.db.query(User).filter(User.id == self.id).with_for_update().one() succeeded = False charged_coins = 0 should_check_balance = kwargs.get('should_check_balance', True) if currency == 'coins': - account_balance = in_db.coins + account_balance = self.coins if not should_check_balance or account_balance >= amount: self.coins -= amount succeeded = True charged_coins = amount elif currency == 'marseybux': - account_balance = in_db.marseybux + account_balance = self.marseybux if not should_check_balance or account_balance >= amount: self.marseybux -= amount succeeded = True elif currency == 'combined': - if in_db.marseybux >= amount: + if self.marseybux >= amount: subtracted_mbux = amount subtracted_coins = 0 else: - subtracted_mbux = in_db.marseybux + subtracted_mbux = self.marseybux subtracted_coins = amount - subtracted_mbux - if subtracted_coins > in_db.coins: + if subtracted_coins > self.coins: return (False, 0) self.coins -= subtracted_coins @@ -247,7 +246,6 @@ class User(Base): if succeeded: g.db.add(self) - g.db.flush() return (succeeded, charged_coins) diff --git a/files/helpers/alerts.py b/files/helpers/alerts.py index c8e17b97c..370fd693b 100644 --- a/files/helpers/alerts.py +++ b/files/helpers/alerts.py @@ -216,7 +216,6 @@ def push_notif(uids, title, body, url_or_comment): subscriptions = g.db.query(PushSubscription.subscription_json).filter(PushSubscription.user_id.in_(uids)).all() subscriptions = [x[0] for x in subscriptions] - g.db.flush() gevent.spawn(_push_notif_thread, subscriptions, title, body, url) diff --git a/files/helpers/get.py b/files/helpers/get.py index 84d20cf3a..85ec9e49d 100644 --- a/files/helpers/get.py +++ b/files/helpers/get.py @@ -2,7 +2,7 @@ from typing import Callable, Iterable, List, Optional, Union from flask import * from sqlalchemy import and_, any_, or_ -from sqlalchemy.orm import joinedload, selectinload, Query +from sqlalchemy.orm import joinedload, Query from files.classes import Comment, CommentVote, Hat, Sub, Post, User, UserBlock, Vote from files.helpers.config.const import * @@ -169,7 +169,7 @@ def get_post(i:Union[str, int], v:Optional[User]=None, graceful=False) -> Option return x -def get_posts(pids:Iterable[int], v:Optional[User]=None, eager:bool=False, extra:Optional[Callable[[Query], Query]]=None) -> List[Post]: +def get_posts(pids:Iterable[int], v:Optional[User]=None, extra:Optional[Callable[[Query], Query]]=None) -> List[Post]: if not pids: return [] if v: @@ -202,20 +202,6 @@ def get_posts(pids:Iterable[int], v:Optional[User]=None, eager:bool=False, extra if extra: query = extra(query) - if eager: - query = query.options( - selectinload(Post.author).options( - selectinload(User.hats_equipped.and_(Hat.equipped == True)) \ - .joinedload(Hat.hat_def, innerjoin=True), - selectinload(User.badges), - selectinload(User.sub_mods), - selectinload(User.sub_exiles), - ), - selectinload(Post.reports), - selectinload(Post.awards), - selectinload(Post.options), - ) - results = query.all() if v: diff --git a/files/routes/front.py b/files/routes/front.py index 353a5e9d2..3b6c5477e 100644 --- a/files/routes/front.py +++ b/files/routes/front.py @@ -65,7 +65,7 @@ def front_all(v, sub=None): pins=pins, ) - posts = get_posts(ids, v=v, eager=True) + posts = get_posts(ids, v=v) if v and v.hidevotedon: posts = [x for x in posts if not hasattr(x, 'voted') or not x.voted] diff --git a/files/routes/notifications.py b/files/routes/notifications.py index 685042b64..ab6cbe03d 100644 --- a/files/routes/notifications.py +++ b/files/routes/notifications.py @@ -184,7 +184,7 @@ def notifications_posts(v:User): listing = listing.order_by(Post.created_utc.desc()).offset(PAGE_SIZE * (page - 1)).limit(PAGE_SIZE).all() listing = [x.id for x in listing] - listing = get_posts(listing, v=v, eager=True) + listing = get_posts(listing, v=v) for p in listing: p.unread = p.created_utc > v.last_viewed_post_notifs diff --git a/files/routes/search.py b/files/routes/search.py index 2afdfa7ec..ee2d74ff5 100644 --- a/files/routes/search.py +++ b/files/routes/search.py @@ -167,7 +167,7 @@ def searchposts(v:User): ids = [x.id for x in posts] - posts = get_posts(ids, v=v, eager=True) + posts = get_posts(ids, v=v) if v.client: return {"total":total, "data":[x.json for x in posts]} diff --git a/files/routes/users.py b/files/routes/users.py index e88208d2d..ee19691dd 100644 --- a/files/routes/users.py +++ b/files/routes/users.py @@ -53,7 +53,7 @@ def upvoters_downvoters(v, username, uid, cls, vote_cls, vote_dir, template, sta listing = [x.id for x in listing] if cls == Post: - listing = get_posts(listing, v=v, eager=True) + listing = get_posts(listing, v=v) elif cls == Comment: listing = get_comments(listing, v=v) else: @@ -119,7 +119,7 @@ def upvoting_downvoting(v, username, uid, cls, vote_cls, vote_dir, template, sta listing = [x.id for x in listing] if cls == Post: - listing = get_posts(listing, v=v, eager=True) + listing = get_posts(listing, v=v) elif cls == Comment: listing = get_comments(listing, v=v) else: @@ -179,7 +179,7 @@ def user_voted(v, username, cls, vote_cls, template, standalone): listing = [x.id for x in listing] if cls == Post: - listing = get_posts(listing, v=v, eager=True) + listing = get_posts(listing, v=v) elif cls == Comment: listing = get_comments(listing, v=v) else: @@ -374,11 +374,10 @@ def transfer_currency(v:User, username:str, currency_name:Literal['coins', 'mars abort(400, f"You don't have enough {currency_name}") if not v.shadowbanned: - user_query = g.db.query(User).filter_by(id=receiver.id) if currency_name == 'marseybux': - user_query.update({ User.marseybux: User.marseybux + amount - tax }) + receiver.pay_account('marseybux', amount - tax) elif currency_name == 'coins': - user_query.update({ User.coins: User.coins + amount - tax }) + receiver.pay_account('coins', amount - tax) else: raise ValueError(f"Invalid currency '{currency_name}' got when transferring {amount} from {v.id} to {receiver.id}") g.db.add(receiver) @@ -999,7 +998,7 @@ def u_username(v:Optional[User], username:str): for p in sticky: ids = [p.id] + ids - listing = get_posts(ids, v=v, eager=True) + listing = get_posts(ids, v=v) if u.unban_utc: if v and v.client: @@ -1247,7 +1246,7 @@ def get_saves_and_subscribes(v, template, relationship_cls, page:int, standalone extra = lambda q:q.filter(cls.is_banned == False, cls.deleted_utc == 0) if cls is Post: - listing = get_posts(ids, v=v, eager=True, extra=extra) + listing = get_posts(ids, v=v, extra=extra) elif cls is Comment: listing = get_comments(ids, v=v, extra=extra) else: diff --git a/files/routes/votes.py b/files/routes/votes.py index c086d5f65..f037efa72 100644 --- a/files/routes/votes.py +++ b/files/routes/votes.py @@ -137,7 +137,6 @@ def vote_post_comment(target_id, new, v, cls, vote_cls): coins=coin_value ) g.db.add(vote) - g.db.flush() # this is hacky but it works, we should probably do better later def get_vote_count(dir, real_instead_of_dir):