diff --git a/common/langchain/EveAIRetriever.py b/common/langchain/EveAIRetriever.py index efbacf3..252138a 100644 --- a/common/langchain/EveAIRetriever.py +++ b/common/langchain/EveAIRetriever.py @@ -4,19 +4,21 @@ from sqlalchemy.exc import SQLAlchemyError from pydantic import BaseModel, Field from typing import Any, Dict from flask import current_app -from datetime import date from common.extensions import db -from common.models.document import Document, DocumentVersion, Embedding +from common.models.document import Document, DocumentVersion +from common.utils.datetime_utils import get_date_in_timezone class EveAIRetriever(BaseRetriever): model_variables: Dict[str, Any] = Field(...) + tenant_info: Dict[str, Any] = Field(...) - def __init__(self, model_variables: Dict[str, Any]): + def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]): super().__init__() current_app.logger.debug('Initializing EveAIRetriever') self.model_variables = model_variables + self.tenant_info = tenant_info current_app.logger.debug('EveAIRetriever initialized') def _get_relevant_documents(self, query: str): @@ -27,7 +29,7 @@ class EveAIRetriever(BaseRetriever): k = self.model_variables['k'] try: - current_date = date.today() + current_date = get_date_in_timezone(self.tenant_info['timezone']) # Subquery to find the latest version of each document subquery = ( db.session.query( @@ -53,18 +55,17 @@ class EveAIRetriever(BaseRetriever): .limit(k) ) - # Print the generated SQL statement for debugging - current_app.logger.debug("SQL Statement:\n") - current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True})) - res = query_obj.all() - # current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') - # current_app.rag_tuning_logger.debug(f'---------------------------------------') + if self.tenant_info['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') + current_app.rag_tuning_logger.debug(f'---------------------------------------') + result = [] for doc in res: - # current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') - # current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') + if self.tenant_info['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') + current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n') except SQLAlchemyError as e: diff --git a/common/models/user.py b/common/models/user.py index 0d8400f..9d698b1 100644 --- a/common/models/user.py +++ b/common/models/user.py @@ -67,6 +67,7 @@ class Tenant(db.Model): 'website': self.website, 'default_language': self.default_language, 'allowed_languages': self.allowed_languages, + 'timezone': self.timezone, 'embedding_model': self.embedding_model, 'llm_model': self.llm_model, 'license_start_date': self.license_start_date, diff --git a/common/utils/datetime_utils.py b/common/utils/datetime_utils.py new file mode 100644 index 0000000..6013291 --- /dev/null +++ b/common/utils/datetime_utils.py @@ -0,0 +1,19 @@ +from datetime import datetime +import pytz + + +def get_date_in_timezone(timezone_str): + try: + # Get the timezone object from the string + timezone = pytz.timezone(timezone_str) + + # Get the current time in the specified timezone + current_time = datetime.now(timezone) + + # Extract the date part + current_date = current_time.date() + + return current_date + except Exception as e: + print(f"Error getting date in timezone {timezone_str}: {e}") + return None diff --git a/config/config.py b/config/config.py index a53a26f..f00fba3 100644 --- a/config/config.py +++ b/config/config.py @@ -77,6 +77,7 @@ class Config(object): ```{text}```""" GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes. ```{text}```""" + GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. Use the following {language} in your communication, and cite the sources used. If the question cannot be answered using the given context, say "I have insufficient information to answer this question." @@ -92,6 +93,17 @@ class Config(object): Question: {question}""" + GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context, + in such a way that the question is understandable without the previous context. + The context is a conversation history, with the HUMAN asking questions, the AI answering questions. + The history is delimited between triple backquotes. + Your answer by stating the question in {language}. + History: + ```{history}``` + Question to be detailed: + {question}""" + + # SocketIO settings # SOCKETIO_ASYNC_MODE = 'threading' SOCKETIO_ASYNC_MODE = 'gevent' diff --git a/eveai_app/views/user_forms.py b/eveai_app/views/user_forms.py index 692d1a6..326b250 100644 --- a/eveai_app/views/user_forms.py +++ b/eveai_app/views/user_forms.py @@ -3,6 +3,8 @@ from flask_wtf import FlaskForm from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField, SelectField, SelectMultipleField, FieldList, FormField, FloatField) from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional +import pytz + from common.models.user import Role @@ -12,6 +14,8 @@ class TenantForm(FlaskForm): # language fields default_language = SelectField('Default Language', choices=[], validators=[DataRequired()]) allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()]) + # Timezone + timezone = SelectField('Timezone', choices=[], validators=[DataRequired()]) # LLM fields embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()]) llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()]) @@ -44,6 +48,8 @@ class TenantForm(FlaskForm): # initialise language fields self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']] self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']] + # initialise timezone + self.timezone.choices = [(tz, tz) for tz in pytz.all_timezones] # initialise LLM fields self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']] self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']] diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py index 5d43d9b..d5ae6f0 100644 --- a/eveai_app/views/user_views.py +++ b/eveai_app/views/user_views.py @@ -49,6 +49,7 @@ def tenant(): website=form.website.data, default_language=form.default_language.data, allowed_languages=form.allowed_languages.data, + timezone=form.timezone.data, embedding_model=form.embedding_model.data, llm_model=form.llm_model.data, license_start_date=form.license_start_date.data, diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py index 48adbd3..768061f 100644 --- a/eveai_chat_workers/tasks.py +++ b/eveai_chat_workers/tasks.py @@ -1,5 +1,5 @@ from datetime import datetime as dt, timezone as tz -from flask import current_app +from flask import current_app, session from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain.globals import set_debug @@ -46,9 +46,9 @@ def ask_question(tenant_id, question, language, session_id): # Ensure we are working in the correct database schema Database(tenant_id).switch_schema() + # Ensure we have a session to story history chat_session = ChatSession.query.filter_by(session_id=session_id).first() if not chat_session: - # Initialize a chat_session on the database try: chat_session = ChatSession() chat_session.session_id = session_id @@ -66,22 +66,14 @@ def ask_question(tenant_id, question, language, session_id): new_interaction.question_at = dt.now(tz.utc) new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] - # try: - # db.session.add(new_interaction) - # db.session.commit() - # except SQLAlchemyError as e: - # current_app.logger.error(f'ask_question: Error saving interaction to database: {e}') - # raise - - current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}') - # Select variables to work with depending on tenant model model_variables = select_model_variables(tenant) + tenant_info = tenant.to_dict() - current_app.logger.debug(f'ask_question: model_variables: {model_variables}') + # Langchain debugging if required + # set_debug(True) - set_debug(True) - retriever = EveAIRetriever(model_variables) + retriever = EveAIRetriever(model_variables, tenant_info) llm = model_variables['llm'] template = model_variables['rag_template'] language_template = create_language_template(template, language) @@ -141,7 +133,8 @@ def ask_question(tenant_id, question, language, session_id): current_app.logger.error(f'ask_question: Error saving interaction to database: {e}') raise - set_debug(False) + # Disable langchain debugging if set above. + # set_debug(False) result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] result['interaction_id'] = new_interaction.id diff --git a/requirements.txt b/requirements.txt index d5cf086..e9b73d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ redis~=5.0.4 itsdangerous~=2.2.0 pydantic~=2.7.1 chardet~=5.2.0 -langcodes~=3.4.0 \ No newline at end of file +langcodes~=3.4.0 +pytz~=2024.1 \ No newline at end of file