Correcting the retrieval of relevant documents
This commit is contained in:
@@ -4,19 +4,21 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from datetime import date
|
|
||||||
|
|
||||||
from common.extensions import db
|
from common.extensions import db
|
||||||
from common.models.document import Document, DocumentVersion, Embedding
|
from common.models.document import Document, DocumentVersion
|
||||||
|
from common.utils.datetime_utils import get_date_in_timezone
|
||||||
|
|
||||||
|
|
||||||
class EveAIRetriever(BaseRetriever):
|
class EveAIRetriever(BaseRetriever):
|
||||||
model_variables: Dict[str, Any] = Field(...)
|
model_variables: Dict[str, Any] = Field(...)
|
||||||
|
tenant_info: Dict[str, Any] = Field(...)
|
||||||
|
|
||||||
def __init__(self, model_variables: Dict[str, Any]):
|
def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
current_app.logger.debug('Initializing EveAIRetriever')
|
current_app.logger.debug('Initializing EveAIRetriever')
|
||||||
self.model_variables = model_variables
|
self.model_variables = model_variables
|
||||||
|
self.tenant_info = tenant_info
|
||||||
current_app.logger.debug('EveAIRetriever initialized')
|
current_app.logger.debug('EveAIRetriever initialized')
|
||||||
|
|
||||||
def _get_relevant_documents(self, query: str):
|
def _get_relevant_documents(self, query: str):
|
||||||
@@ -27,7 +29,7 @@ class EveAIRetriever(BaseRetriever):
|
|||||||
k = self.model_variables['k']
|
k = self.model_variables['k']
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_date = date.today()
|
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||||
# Subquery to find the latest version of each document
|
# Subquery to find the latest version of each document
|
||||||
subquery = (
|
subquery = (
|
||||||
db.session.query(
|
db.session.query(
|
||||||
@@ -53,18 +55,17 @@ class EveAIRetriever(BaseRetriever):
|
|||||||
.limit(k)
|
.limit(k)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print the generated SQL statement for debugging
|
|
||||||
current_app.logger.debug("SQL Statement:\n")
|
|
||||||
current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True}))
|
|
||||||
|
|
||||||
res = query_obj.all()
|
res = query_obj.all()
|
||||||
|
|
||||||
# current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
|
if self.tenant_info['rag_tuning']:
|
||||||
# current_app.rag_tuning_logger.debug(f'---------------------------------------')
|
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
|
||||||
|
current_app.rag_tuning_logger.debug(f'---------------------------------------')
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for doc in res:
|
for doc in res:
|
||||||
# current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
if self.tenant_info['rag_tuning']:
|
||||||
# current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||||
|
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class Tenant(db.Model):
|
|||||||
'website': self.website,
|
'website': self.website,
|
||||||
'default_language': self.default_language,
|
'default_language': self.default_language,
|
||||||
'allowed_languages': self.allowed_languages,
|
'allowed_languages': self.allowed_languages,
|
||||||
|
'timezone': self.timezone,
|
||||||
'embedding_model': self.embedding_model,
|
'embedding_model': self.embedding_model,
|
||||||
'llm_model': self.llm_model,
|
'llm_model': self.llm_model,
|
||||||
'license_start_date': self.license_start_date,
|
'license_start_date': self.license_start_date,
|
||||||
|
|||||||
19
common/utils/datetime_utils.py
Normal file
19
common/utils/datetime_utils.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
|
||||||
|
def get_date_in_timezone(timezone_str):
    """Return the current calendar date as observed in the given timezone.

    Args:
        timezone_str: Timezone name accepted by ``pytz.timezone``
            (for example ``"Europe/Brussels"``).

    Returns:
        A ``datetime.date`` for the present moment in that timezone, or
        ``None`` when the date could not be determined (for instance when
        the timezone name is not recognised).
    """
    # Function-scope import so the block does not depend on a new
    # file-level import being added.
    import logging

    try:
        # Resolve the timezone, take "now" directly in that zone and keep
        # only the date part.
        return datetime.now(pytz.timezone(timezone_str)).date()
    except Exception:
        # Best-effort contract is preserved: callers still receive None on
        # any failure. The error is reported through logging (including the
        # traceback) instead of being printed to stdout.
        logging.getLogger(__name__).exception(
            "Error getting date in timezone %s", timezone_str
        )
        return None
|
||||||
@@ -77,6 +77,7 @@ class Config(object):
|
|||||||
```{text}```"""
|
```{text}```"""
|
||||||
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
|
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
|
||||||
```{text}```"""
|
```{text}```"""
|
||||||
|
|
||||||
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
||||||
Use the following {language} in your communication, and cite the sources used.
|
Use the following {language} in your communication, and cite the sources used.
|
||||||
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
||||||
@@ -92,6 +93,17 @@ class Config(object):
|
|||||||
Question:
|
Question:
|
||||||
{question}"""
|
{question}"""
|
||||||
|
|
||||||
|
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
|
||||||
|
in such a way that the question is understandable without the previous context.
|
||||||
|
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
|
||||||
|
The history is delimited between triple backquotes.
|
||||||
|
Your answer by stating the question in {language}.
|
||||||
|
History:
|
||||||
|
```{history}```
|
||||||
|
Question to be detailed:
|
||||||
|
{question}"""
|
||||||
|
|
||||||
|
|
||||||
# SocketIO settings
|
# SocketIO settings
|
||||||
# SOCKETIO_ASYNC_MODE = 'threading'
|
# SOCKETIO_ASYNC_MODE = 'threading'
|
||||||
SOCKETIO_ASYNC_MODE = 'gevent'
|
SOCKETIO_ASYNC_MODE = 'gevent'
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from flask_wtf import FlaskForm
|
|||||||
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
||||||
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
|
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
|
||||||
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
||||||
|
import pytz
|
||||||
|
|
||||||
from common.models.user import Role
|
from common.models.user import Role
|
||||||
|
|
||||||
|
|
||||||
@@ -12,6 +14,8 @@ class TenantForm(FlaskForm):
|
|||||||
# language fields
|
# language fields
|
||||||
default_language = SelectField('Default Language', choices=[], validators=[DataRequired()])
|
default_language = SelectField('Default Language', choices=[], validators=[DataRequired()])
|
||||||
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
|
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
|
||||||
|
# Timezone
|
||||||
|
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
|
||||||
# LLM fields
|
# LLM fields
|
||||||
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
|
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
|
||||||
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
|
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
|
||||||
@@ -44,6 +48,8 @@ class TenantForm(FlaskForm):
|
|||||||
# initialise language fields
|
# initialise language fields
|
||||||
self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
self.default_language.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
||||||
self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
self.allowed_languages.choices = [(lang, lang.lower()) for lang in current_app.config['SUPPORTED_LANGUAGES']]
|
||||||
|
# initialise timezone
|
||||||
|
self.timezone.choices = [(tz, tz) for tz in pytz.all_timezones]
|
||||||
# initialise LLM fields
|
# initialise LLM fields
|
||||||
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
|
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
|
||||||
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
|
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ def tenant():
|
|||||||
website=form.website.data,
|
website=form.website.data,
|
||||||
default_language=form.default_language.data,
|
default_language=form.default_language.data,
|
||||||
allowed_languages=form.allowed_languages.data,
|
allowed_languages=form.allowed_languages.data,
|
||||||
|
timezone=form.timezone.data,
|
||||||
embedding_model=form.embedding_model.data,
|
embedding_model=form.embedding_model.data,
|
||||||
llm_model=form.llm_model.data,
|
llm_model=form.llm_model.data,
|
||||||
license_start_date=form.license_start_date.data,
|
license_start_date=form.license_start_date.data,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime as dt, timezone as tz
|
from datetime import datetime as dt, timezone as tz
|
||||||
from flask import current_app
|
from flask import current_app, session
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||||
from langchain.globals import set_debug
|
from langchain.globals import set_debug
|
||||||
@@ -46,9 +46,9 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
# Ensure we are working in the correct database schema
|
# Ensure we are working in the correct database schema
|
||||||
Database(tenant_id).switch_schema()
|
Database(tenant_id).switch_schema()
|
||||||
|
|
||||||
|
# Ensure we have a session to store history
|
||||||
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
|
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
|
||||||
if not chat_session:
|
if not chat_session:
|
||||||
# Initialize a chat_session on the database
|
|
||||||
try:
|
try:
|
||||||
chat_session = ChatSession()
|
chat_session = ChatSession()
|
||||||
chat_session.session_id = session_id
|
chat_session.session_id = session_id
|
||||||
@@ -66,22 +66,14 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
new_interaction.question_at = dt.now(tz.utc)
|
new_interaction.question_at = dt.now(tz.utc)
|
||||||
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
||||||
|
|
||||||
# try:
|
|
||||||
# db.session.add(new_interaction)
|
|
||||||
# db.session.commit()
|
|
||||||
# except SQLAlchemyError as e:
|
|
||||||
# current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
|
|
||||||
# raise
|
|
||||||
|
|
||||||
current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')
|
|
||||||
|
|
||||||
# Select variables to work with depending on tenant model
|
# Select variables to work with depending on tenant model
|
||||||
model_variables = select_model_variables(tenant)
|
model_variables = select_model_variables(tenant)
|
||||||
|
tenant_info = tenant.to_dict()
|
||||||
|
|
||||||
current_app.logger.debug(f'ask_question: model_variables: {model_variables}')
|
# Langchain debugging if required
|
||||||
|
# set_debug(True)
|
||||||
|
|
||||||
set_debug(True)
|
retriever = EveAIRetriever(model_variables, tenant_info)
|
||||||
retriever = EveAIRetriever(model_variables)
|
|
||||||
llm = model_variables['llm']
|
llm = model_variables['llm']
|
||||||
template = model_variables['rag_template']
|
template = model_variables['rag_template']
|
||||||
language_template = create_language_template(template, language)
|
language_template = create_language_template(template, language)
|
||||||
@@ -141,7 +133,8 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
|
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
set_debug(False)
|
# Disable langchain debugging if set above.
|
||||||
|
# set_debug(False)
|
||||||
|
|
||||||
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
||||||
result['interaction_id'] = new_interaction.id
|
result['interaction_id'] = new_interaction.id
|
||||||
|
|||||||
@@ -16,3 +16,4 @@ itsdangerous~=2.2.0
|
|||||||
pydantic~=2.7.1
|
pydantic~=2.7.1
|
||||||
chardet~=5.2.0
|
chardet~=5.2.0
|
||||||
langcodes~=3.4.0
|
langcodes~=3.4.0
|
||||||
|
pytz~=2024.1
|
||||||
Reference in New Issue
Block a user