Correcting the retrieval of relevant documents
This commit is contained in:
@@ -4,19 +4,21 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
from datetime import date
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Embedding
|
||||
from common.models.document import Document, DocumentVersion
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
|
||||
|
||||
class EveAIRetriever(BaseRetriever):
|
||||
model_variables: Dict[str, Any] = Field(...)
|
||||
tenant_info: Dict[str, Any] = Field(...)
|
||||
|
||||
def __init__(self, model_variables: Dict[str, Any]):
|
||||
def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
|
||||
super().__init__()
|
||||
current_app.logger.debug('Initializing EveAIRetriever')
|
||||
self.model_variables = model_variables
|
||||
self.tenant_info = tenant_info
|
||||
current_app.logger.debug('EveAIRetriever initialized')
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
@@ -27,7 +29,7 @@ class EveAIRetriever(BaseRetriever):
|
||||
k = self.model_variables['k']
|
||||
|
||||
try:
|
||||
current_date = date.today()
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
# Subquery to find the latest version of each document
|
||||
subquery = (
|
||||
db.session.query(
|
||||
@@ -53,18 +55,17 @@ class EveAIRetriever(BaseRetriever):
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
# Print the generated SQL statement for debugging
|
||||
current_app.logger.debug("SQL Statement:\n")
|
||||
current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True}))
|
||||
|
||||
res = query_obj.all()
|
||||
|
||||
# current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
|
||||
# current_app.rag_tuning_logger.debug(f'---------------------------------------')
|
||||
if self.tenant_info['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------')
|
||||
|
||||
result = []
|
||||
for doc in res:
|
||||
# current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
# current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
if self.tenant_info['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
|
||||
@@ -67,6 +67,7 @@ class Tenant(db.Model):
|
||||
'website': self.website,
|
||||
'default_language': self.default_language,
|
||||
'allowed_languages': self.allowed_languages,
|
||||
'timezone': self.timezone,
|
||||
'embedding_model': self.embedding_model,
|
||||
'llm_model': self.llm_model,
|
||||
'license_start_date': self.license_start_date,
|
||||
|
||||
19
common/utils/datetime_utils.py
Normal file
19
common/utils/datetime_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
|
||||
|
||||
def get_date_in_timezone(timezone_str):
    """Return the current calendar date as observed in the given timezone.

    Args:
        timezone_str: Timezone name passed to ``pytz.timezone``
            (for example ``"Europe/Brussels"``).

    Returns:
        The ``datetime.date`` of the present moment in that timezone, or
        ``None`` when the date could not be determined (for instance when
        pytz rejects the timezone name).
    """
    # Local import keeps this block self-contained; failures are reported
    # through the logging framework instead of being printed to stdout.
    import logging

    try:
        # Resolve the timezone object from its name.
        timezone = pytz.timezone(timezone_str)

        # Take "now" directly in that timezone and keep only the date part.
        return datetime.now(timezone).date()
    except Exception as e:
        # Deliberately best-effort: callers receive None rather than an
        # exception, exactly as before; only the reporting channel changed.
        logging.getLogger(__name__).error(
            "Error getting date in timezone %s: %s", timezone_str, e
        )
        return None
|
||||
@@ -77,6 +77,7 @@ class Config(object):
|
||||
```{text}```"""
|
||||
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
|
||||
```{text}```"""
|
||||
|
||||
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
||||
Use the following {language} in your communication, and cite the sources used.
|
||||
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
||||
@@ -92,6 +93,17 @@ class Config(object):
|
||||
Question:
|
||||
{question}"""
|
||||
|
||||
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
|
||||
in such a way that the question is understandable without the previous context.
|
||||
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
|
||||
The history is delimited between triple backquotes.
|
||||
Your answer by stating the question in {language}.
|
||||
History:
|
||||
```{history}```
|
||||
Question to be detailed:
|
||||
{question}"""
|
||||
|
||||
|
||||
# SocketIO settings
|
||||
# SOCKETIO_ASYNC_MODE = 'threading'
|
||||
SOCKETIO_ASYNC_MODE = 'gevent'
|
||||
|
||||
@@ -3,6 +3,8 @@ from flask_wtf import FlaskForm
|
||||
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
||||
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
|
||||
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
||||
import pytz
|
||||
|
||||
from common.models.user import Role
|
||||
|
||||
|
||||
@@ -12,6 +14,8 @@ class TenantForm(FlaskForm):
|
||||
# language fields
|
||||
default_language = SelectField('Default Language', choices=[], validators=[DataRequired()])
|
||||
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
|
||||
# Timezone
|
||||
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
|
||||
# LLM fields
|
||||
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
|
||||
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
|
||||
@@ -44,6 +48,8 @@ class TenantForm(FlaskForm):
|
||||
# initialise language fields
|
||||
self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
||||
self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
||||
# initialise timezone
|
||||
self.timezone.choices = [(tz, tz) for tz in pytz.all_timezones]
|
||||
# initialise LLM fields
|
||||
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
|
||||
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
|
||||
|
||||
@@ -49,6 +49,7 @@ def tenant():
|
||||
website=form.website.data,
|
||||
default_language=form.default_language.data,
|
||||
allowed_languages=form.allowed_languages.data,
|
||||
timezone=form.timezone.data,
|
||||
embedding_model=form.embedding_model.data,
|
||||
llm_model=form.llm_model.data,
|
||||
license_start_date=form.license_start_date.data,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import current_app
|
||||
from flask import current_app, session
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||
from langchain.globals import set_debug
|
||||
@@ -46,9 +46,9 @@ def ask_question(tenant_id, question, language, session_id):
|
||||
# Ensure we are working in the correct database schema
|
||||
Database(tenant_id).switch_schema()
|
||||
|
||||
# Ensure we have a session to store history
|
||||
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
|
||||
if not chat_session:
|
||||
# Initialize a chat_session on the database
|
||||
try:
|
||||
chat_session = ChatSession()
|
||||
chat_session.session_id = session_id
|
||||
@@ -66,22 +66,14 @@ def ask_question(tenant_id, question, language, session_id):
|
||||
new_interaction.question_at = dt.now(tz.utc)
|
||||
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
||||
|
||||
# try:
|
||||
# db.session.add(new_interaction)
|
||||
# db.session.commit()
|
||||
# except SQLAlchemyError as e:
|
||||
# current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
|
||||
# raise
|
||||
|
||||
current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')
|
||||
|
||||
# Select variables to work with depending on tenant model
|
||||
model_variables = select_model_variables(tenant)
|
||||
tenant_info = tenant.to_dict()
|
||||
|
||||
current_app.logger.debug(f'ask_question: model_variables: {model_variables}')
|
||||
# Langchain debugging if required
|
||||
# set_debug(True)
|
||||
|
||||
set_debug(True)
|
||||
retriever = EveAIRetriever(model_variables)
|
||||
retriever = EveAIRetriever(model_variables, tenant_info)
|
||||
llm = model_variables['llm']
|
||||
template = model_variables['rag_template']
|
||||
language_template = create_language_template(template, language)
|
||||
@@ -141,7 +133,8 @@ def ask_question(tenant_id, question, language, session_id):
|
||||
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
|
||||
raise
|
||||
|
||||
set_debug(False)
|
||||
# Disable langchain debugging if set above.
|
||||
# set_debug(False)
|
||||
|
||||
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
||||
result['interaction_id'] = new_interaction.id
|
||||
|
||||
@@ -15,4 +15,5 @@ redis~=5.0.4
|
||||
itsdangerous~=2.2.0
|
||||
pydantic~=2.7.1
|
||||
chardet~=5.2.0
|
||||
langcodes~=3.4.0
|
||||
langcodes~=3.4.0
|
||||
pytz~=2024.1
|
||||
Reference in New Issue
Block a user