188 lines
7.2 KiB
Python
188 lines
7.2 KiB
Python
from datetime import datetime
|
|
from email.policy import default
|
|
import time
|
|
import openai
|
|
import sqlalchemy
|
|
from sqlalchemy.orm import declarative_base, Session
|
|
from sqlalchemy import Column, DateTime, String, ForeignKey, Integer, Boolean, Table, and_, or_
|
|
|
|
# Declarative base class shared by every ORM model in this module
# (User, Comment, Post and OpenAIToken all subclass it).
Base = declarative_base()
|
|
|
|
class User(Base):
    """Per-user counter of how many comments have been recorded for that user.

    Rows are keyed by the caller-supplied ``user_id``. The helpers below create
    a row lazily the first time an unknown id is looked up. Apart from
    ``reset_all_comments``, none of them commit; the caller owns the transaction.
    """

    __tablename__ = "user"

    id = Column(Integer, primary_key=True)
    # The column default is only applied at INSERT time; get_user() also sets it
    # explicitly so a not-yet-flushed row reads 0 instead of None.
    number_of_comments = Column(Integer, default=0)

    @staticmethod
    def get_user(user_id: int, session: Session) -> "User":
        """Return the ``User`` with ``user_id``, creating it if it does not exist.

        A newly created user is added to ``session`` (pending) with a count of 0.
        It is not flushed or committed here.
        """
        stmt = sqlalchemy.select(User).where(User.id == user_id)
        user = session.execute(stmt).scalar_one_or_none()
        if user is None:
            user = User(id=user_id, number_of_comments=0)
            session.add(user)
        return user

    @staticmethod
    def increase_number_of_comments(user_id: int, session: Session):
        """Increment the comment counter of ``user_id`` by one (no commit)."""
        user = User.get_user(user_id, session)
        # A NULL stored count is treated as 0 instead of raising TypeError.
        user.number_of_comments = (user.number_of_comments or 0) + 1

    @staticmethod
    def get_number_of_comments(user_id: int, session: Session) -> int:
        """Return the comment counter of ``user_id``.

        Unknown users are created with a count of 0 (see ``get_user``). A NULL
        stored count is reported as 0.
        """
        return User.get_user(user_id, session).number_of_comments or 0

    @staticmethod
    def reset_number_of_comments(user_id: int, session: Session):
        """Set the comment counter of ``user_id`` to 0 (no commit)."""
        User.get_user(user_id, session).number_of_comments = 0

    @staticmethod
    def reset_all_comments(session: Session):
        """Set every user's comment counter to 0 and commit the session."""
        users = session.execute(sqlalchemy.select(User)).scalars().all()
        for user in users:
            user.number_of_comments = 0
        # commit() flushes pending changes itself, so no separate flush() is needed.
        session.commit()
|
|
|
|
|
|
class Comment(Base):
    """One recorded comment exchange.

    Each row stores a user comment id (``user_comment_id``) and a second comment
    id (``bbbb_comment_id``; presumably the bot's own reply, so confirm with the
    caller). It also stores the thread depth (``conversation_depth``) and a piece
    of text (``comment_string``). None of the helpers below commit.
    """

    __tablename__ = "comment"

    id = Column(Integer, primary_key=True)
    user_comment_id = Column(Integer)
    bbbb_comment_id = Column(Integer)
    conversation_depth = Column(Integer)
    comment_string = Column(String)

    @staticmethod
    def get_past_comments(session: Session) -> 'list[str]':
        """Return the ``comment_string`` of up to 100 stored comments.

        No ORDER BY is applied, so the database decides which 100 rows are returned.
        """
        # Let the database apply the limit instead of loading every row and slicing.
        stmt = sqlalchemy.select(Comment).limit(100)
        return [row.comment_string for row in session.execute(stmt).scalars().all()]

    @staticmethod
    def get_user_comment(user_comment_id: int, session: Session):
        """Return the first row whose ``user_comment_id`` matches, or None."""
        stmt = (
            sqlalchemy.select(Comment)
            .where(Comment.user_comment_id == user_comment_id)
            .limit(1)
        )
        return session.execute(stmt).scalars().first()

    @staticmethod
    def get_bbbb_comment(bbbb_comment_id: int, session: Session):
        """Return the first row whose ``bbbb_comment_id`` matches, or None."""
        stmt = (
            sqlalchemy.select(Comment)
            .where(Comment.bbbb_comment_id == bbbb_comment_id)
            .limit(1)
        )
        return session.execute(stmt).scalars().first()

    @staticmethod
    def get_comment(comment_id: int, session: Session):
        """Find a row by ``comment_id``, checking ``user_comment_id`` first.

        Falls back to ``bbbb_comment_id``. Returns None when neither matches.
        """
        user_comment = Comment.get_user_comment(comment_id, session)
        if user_comment is not None:
            return user_comment
        return Comment.get_bbbb_comment(comment_id, session)

    @staticmethod
    def has_replied_to_comment(comment_id: int, session: Session) -> bool:
        """Return True when NO stored row references ``comment_id``.

        NOTE(review): despite its name, this returns True when the comment is
        *not* found, i.e. no reply has been recorded. The inversion is kept on
        purpose because existing call sites may depend on it. Audit the callers
        before flipping it.
        """
        return Comment.get_comment(comment_id, session) is None

    @staticmethod
    def get_conversation_depth(parent_comment_id: int, session: Session):
        """Return the stored depth of ``parent_comment_id``, or 0 if it is unknown."""
        parent = Comment.get_comment(parent_comment_id, session)
        if parent is None:
            return 0
        return parent.conversation_depth

    @staticmethod
    def create_new_comment(user_comment_id: int, bbbb_comment_id: int, conversation_depth: int, comment_string: str, session: Session):
        """Add a new ``Comment`` row to ``session`` (pending; not committed)."""
        comment = Comment(
            user_comment_id=user_comment_id,
            bbbb_comment_id=bbbb_comment_id,
            conversation_depth=conversation_depth,
            comment_string=comment_string,
        )
        session.add(comment)
|
|
|
|
class Post(Base):
    """Per-post bookkeeping.

    Tracks whether a reply has been registered for the post and how many replies
    have been counted. Rows are created lazily by ``get_post``. None of the
    helpers commit.
    """

    __tablename__ = "post"

    id = Column(Integer, primary_key=True)
    has_replied = Column(Boolean, default=False)
    replies_to_post = Column(Integer, default=0)

    @staticmethod
    def get_post(post_id: int, session: Session) -> "Post":
        """Return the ``Post`` with ``post_id``, creating and adding it if missing."""
        stmt = sqlalchemy.select(Post).where(Post.id == post_id)
        post = session.execute(stmt).scalar_one_or_none()
        if post is None:
            # Column defaults are only applied at INSERT (flush) time. Without
            # explicit values a freshly created, still-pending post reads None,
            # and increment_replies() would then fail on `None + 1`.
            post = Post(id=post_id, has_replied=False, replies_to_post=0)
            session.add(post)
        return post

    @staticmethod
    def has_replied_to_post(post_id: int, session: Session) -> bool:
        """Return whether a reply has been registered for ``post_id`` (NULL counts as False)."""
        return bool(Post.get_post(post_id, session).has_replied)

    @staticmethod
    def increment_replies(post_id: int, session: Session):
        """Increase the reply counter of ``post_id`` by one (no commit)."""
        post = Post.get_post(post_id, session)
        # A NULL stored counter is treated as 0, consistent with get_number_of_replies().
        post.replies_to_post = (post.replies_to_post or 0) + 1

    @staticmethod
    def get_number_of_replies(post_id: int, session: Session) -> int:
        """Return the reply counter of ``post_id``, reporting NULL as 0."""
        return Post.get_post(post_id, session).replies_to_post or 0

    @staticmethod
    def register_post_reply(post_id: int, session: Session):
        """Mark ``post_id`` as replied to (no commit)."""
        Post.get_post(post_id, session).has_replied = True
|
|
|
|
|
|
class OpenAIToken(Base):
    """An OpenAI API key plus bookkeeping about its use.

    At most one row is expected to have ``is_active`` set. ``get_active_token``
    uses ``scalar_one_or_none``, which raises if several rows are active. A key
    is marked ``is_expended`` once a call made with it raises
    ``openai.error.RateLimitError``.
    """

    __tablename__ = "openaikey"

    id = Column(Integer, primary_key=True)
    token = Column(String)
    is_active = Column(Boolean, default=False)
    is_expended = Column(Boolean, default=False)
    number_of_requests = Column(Integer, default=0)
    registered_time = Column(DateTime)  # set by add_token()
    begin_time = Column(DateTime)  # set when get_active_token() activates the key
    end_time = Column(DateTime)  # set when call_open_ai() marks the key expended

    @staticmethod
    def add_token(token: str, session: Session):
        """Register a new key as inactive and not expended, then commit."""
        new_token = OpenAIToken(
            token=token,
            is_active=False,
            is_expended=False,
            registered_time=datetime.now(),
        )
        session.add(new_token)
        # commit() flushes on its own; a separate flush() is unnecessary.
        session.commit()

    @staticmethod
    def get_all_valid_tokens(session: Session) -> 'list[OpenAIToken]':
        """Return every key not marked expended (the active key included)."""
        # `== False` builds a SQL comparison here, so it cannot be written as `is False`.
        stmt = sqlalchemy.select(OpenAIToken).where(OpenAIToken.is_expended == False)  # noqa: E712
        return session.execute(stmt).scalars().all()

    @staticmethod
    def get_active_token(session: Session) -> 'OpenAIToken':
        """Return the active key, activating the first unexpended one if none is active.

        Activation sets ``is_active`` and ``begin_time`` but does not commit.

        Raises:
            BaseException: when no unexpended key remains. BaseException (rather
                than Exception) is kept so existing handlers behave as before.
                Note that it also escapes generic ``except Exception`` clauses.
            sqlalchemy.exc.MultipleResultsFound: if more than one key is active.
        """
        stmt = sqlalchemy.select(OpenAIToken).where(OpenAIToken.is_active == True)  # noqa: E712
        active_token = session.execute(stmt).scalar_one_or_none()
        if active_token is not None:
            print("Returning token")
            return active_token

        print("Aw shucks, it's None")
        valid_tokens = OpenAIToken.get_all_valid_tokens(session)
        if not valid_tokens:
            raise BaseException("WE ARE OUT OF TOKENS!!!!")

        print("Activating token")
        new_active_token = valid_tokens[0]
        new_active_token.is_active = True
        new_active_token.begin_time = datetime.now()
        return new_active_token

    @staticmethod
    def call_open_ai(prompt: str, session: Session, *, model: str = "text-davinci-002",
                     temperature: float = 0.9, max_tokens: int = 256):
        """Send ``prompt`` to ``openai.Completion.create`` using the active key.

        On ``openai.error.RateLimitError`` the current key is marked expended and
        the request is retried with the next key. This repeats until a call
        succeeds or ``get_active_token`` raises because no keys are left.

        Args:
            prompt: completion prompt text.
            session: session used for key bookkeeping (not committed here).
            model, temperature, max_tokens: forwarded to ``openai.Completion.create``.
                The defaults match the previously hard-coded values.

        Returns:
            The response object returned by ``openai.Completion.create``. This is
            not a plain str, despite the earlier ``-> 'str'`` annotation.
        """
        while True:
            api_token = OpenAIToken.get_active_token(session)
            api_token.number_of_requests = (api_token.number_of_requests or 0) + 1
            print(f"Calling OPENAI. Token has been used {api_token.number_of_requests} times, beginning at {api_token.begin_time}.")
            # Module-level assignment: affects every openai call in this process.
            openai.api_key = api_token.token
            try:
                return openai.Completion.create(model=model, prompt=prompt, temperature=temperature, max_tokens=max_tokens)
            except openai.error.RateLimitError:
                # NOTE(review): a RateLimitError may also be a transient per-minute
                # limit rather than exhausted quota; the key is retired either way.
                api_token.is_active = False
                api_token.is_expended = True
                api_token.end_time = datetime.now()
                # begin_time can be NULL if the key was activated outside
                # get_active_token(); avoid crashing on `datetime - None`.
                if api_token.begin_time is not None:
                    lasted = api_token.end_time - api_token.begin_time
                else:
                    lasted = "unknown"
                print(f"Expended this token! Token lasted {lasted}")
                # Loop: get_active_token() will now activate the next unexpended key.