Correcting the retrieval of relevant documents

This commit is contained in:
Josako
2024-06-12 16:15:48 +02:00
parent be311c440b
commit fd510c8fcd
8 changed files with 62 additions and 28 deletions

View File

@@ -4,19 +4,21 @@ from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict from typing import Any, Dict
from flask import current_app from flask import current_app
from datetime import date
from common.extensions import db from common.extensions import db
from common.models.document import Document, DocumentVersion, Embedding from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
class EveAIRetriever(BaseRetriever): class EveAIRetriever(BaseRetriever):
model_variables: Dict[str, Any] = Field(...) model_variables: Dict[str, Any] = Field(...)
tenant_info: Dict[str, Any] = Field(...)
def __init__(self, model_variables: Dict[str, Any]): def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
super().__init__() super().__init__()
current_app.logger.debug('Initializing EveAIRetriever') current_app.logger.debug('Initializing EveAIRetriever')
self.model_variables = model_variables self.model_variables = model_variables
self.tenant_info = tenant_info
current_app.logger.debug('EveAIRetriever initialized') current_app.logger.debug('EveAIRetriever initialized')
def _get_relevant_documents(self, query: str): def _get_relevant_documents(self, query: str):
@@ -27,7 +29,7 @@ class EveAIRetriever(BaseRetriever):
k = self.model_variables['k'] k = self.model_variables['k']
try: try:
current_date = date.today() current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document # Subquery to find the latest version of each document
subquery = ( subquery = (
db.session.query( db.session.query(
@@ -53,18 +55,17 @@ class EveAIRetriever(BaseRetriever):
.limit(k) .limit(k)
) )
# Print the generated SQL statement for debugging
current_app.logger.debug("SQL Statement:\n")
current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True}))
res = query_obj.all() res = query_obj.all()
# current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') if self.tenant_info['rag_tuning']:
# current_app.rag_tuning_logger.debug(f'---------------------------------------') current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
current_app.rag_tuning_logger.debug(f'---------------------------------------')
result = [] result = []
for doc in res: for doc in res:
# current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') if self.tenant_info['rag_tuning']:
# current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n') result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e: except SQLAlchemyError as e:

View File

@@ -67,6 +67,7 @@ class Tenant(db.Model):
'website': self.website, 'website': self.website,
'default_language': self.default_language, 'default_language': self.default_language,
'allowed_languages': self.allowed_languages, 'allowed_languages': self.allowed_languages,
'timezone': self.timezone,
'embedding_model': self.embedding_model, 'embedding_model': self.embedding_model,
'llm_model': self.llm_model, 'llm_model': self.llm_model,
'license_start_date': self.license_start_date, 'license_start_date': self.license_start_date,

View File

@@ -0,0 +1,19 @@
import logging
from datetime import datetime

import pytz
def get_date_in_timezone(timezone_str):
    """Return the current calendar date as observed in the given timezone.

    Args:
        timezone_str: Timezone name accepted by ``pytz.timezone``
            (for example ``"Europe/Brussels"``).

    Returns:
        A ``datetime.date`` for the present moment in that timezone, or
        ``None`` when ``timezone_str`` is not a string or does not name a
        timezone known to pytz.
    """
    logger = logging.getLogger(__name__)

    # A missing or non-string setting is reported the same way as an
    # unknown name, so callers keep the "None means unusable" contract.
    if not isinstance(timezone_str, str):
        logger.error(
            "Error getting date in timezone %r: expected a timezone name string",
            timezone_str,
        )
        return None

    try:
        tz = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError as e:
        # Only the "bad timezone name" failure is handled here; any other
        # error is unexpected and is allowed to propagate instead of being
        # hidden behind a None result.
        logger.error("Error getting date in timezone %s: %s", timezone_str, e)
        return None

    # An aware "now" in the target zone, reduced to its date part.
    return datetime.now(tz).date()

View File

@@ -77,6 +77,7 @@ class Config(object):
```{text}```""" ```{text}```"""
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes. GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
```{text}```""" ```{text}```"""
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
Use the following {language} in your communication, and cite the sources used. Use the following {language} in your communication, and cite the sources used.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question." If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
@@ -92,6 +93,17 @@ class Config(object):
Question: Question:
{question}""" {question}"""
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
in such a way that the question is understandable without the previous context.
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
The history is delimited between triple backquotes.
Your answer by stating the question in {language}.
History:
```{history}```
Question to be detailed:
{question}"""
# SocketIO settings # SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading' # SOCKETIO_ASYNC_MODE = 'threading'
SOCKETIO_ASYNC_MODE = 'gevent' SOCKETIO_ASYNC_MODE = 'gevent'

View File

@@ -3,6 +3,8 @@ from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField, from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField, FloatField) SelectField, SelectMultipleField, FieldList, FormField, FloatField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
import pytz
from common.models.user import Role from common.models.user import Role
@@ -12,6 +14,8 @@ class TenantForm(FlaskForm):
# language fields # language fields
default_language = SelectField('Default Language', choices=[], validators=[DataRequired()]) default_language = SelectField('Default Language', choices=[], validators=[DataRequired()])
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()]) allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
# Timezone
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
# LLM fields # LLM fields
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()]) embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()]) llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
@@ -44,6 +48,8 @@ class TenantForm(FlaskForm):
# initialise language fields # initialise language fields
self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']] self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']] self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
# initialise timezone
self.timezone.choices = [(tz, tz) for tz in pytz.all_timezones]
# initialise LLM fields # initialise LLM fields
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']] self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']] self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]

View File

@@ -49,6 +49,7 @@ def tenant():
website=form.website.data, website=form.website.data,
default_language=form.default_language.data, default_language=form.default_language.data,
allowed_languages=form.allowed_languages.data, allowed_languages=form.allowed_languages.data,
timezone=form.timezone.data,
embedding_model=form.embedding_model.data, embedding_model=form.embedding_model.data,
llm_model=form.llm_model.data, llm_model=form.llm_model.data,
license_start_date=form.license_start_date.data, license_start_date=form.license_start_date.data,

View File

@@ -1,5 +1,5 @@
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
from flask import current_app from flask import current_app, session
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.globals import set_debug from langchain.globals import set_debug
@@ -46,9 +46,9 @@ def ask_question(tenant_id, question, language, session_id):
# Ensure we are working in the correct database schema # Ensure we are working in the correct database schema
Database(tenant_id).switch_schema() Database(tenant_id).switch_schema()
# Ensure we have a session to store history
chat_session = ChatSession.query.filter_by(session_id=session_id).first() chat_session = ChatSession.query.filter_by(session_id=session_id).first()
if not chat_session: if not chat_session:
# Initialize a chat_session on the database
try: try:
chat_session = ChatSession() chat_session = ChatSession()
chat_session.session_id = session_id chat_session.session_id = session_id
@@ -66,22 +66,14 @@ def ask_question(tenant_id, question, language, session_id):
new_interaction.question_at = dt.now(tz.utc) new_interaction.question_at = dt.now(tz.utc)
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
# try:
# db.session.add(new_interaction)
# db.session.commit()
# except SQLAlchemyError as e:
# current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
# raise
current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')
# Select variables to work with depending on tenant model # Select variables to work with depending on tenant model
model_variables = select_model_variables(tenant) model_variables = select_model_variables(tenant)
tenant_info = tenant.to_dict()
current_app.logger.debug(f'ask_question: model_variables: {model_variables}') # Langchain debugging if required
# set_debug(True)
set_debug(True) retriever = EveAIRetriever(model_variables, tenant_info)
retriever = EveAIRetriever(model_variables)
llm = model_variables['llm'] llm = model_variables['llm']
template = model_variables['rag_template'] template = model_variables['rag_template']
language_template = create_language_template(template, language) language_template = create_language_template(template, language)
@@ -141,7 +133,8 @@ def ask_question(tenant_id, question, language, session_id):
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}') current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
raise raise
set_debug(False) # Disable langchain debugging if set above.
# set_debug(False)
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
result['interaction_id'] = new_interaction.id result['interaction_id'] = new_interaction.id

View File

@@ -15,4 +15,5 @@ redis~=5.0.4
itsdangerous~=2.2.0 itsdangerous~=2.2.0
pydantic~=2.7.1 pydantic~=2.7.1
chardet~=5.2.0 chardet~=5.2.0
langcodes~=3.4.0 langcodes~=3.4.0
pytz~=2024.1