Include history to improve query if required.
This commit is contained in:
44
common/langchain/EveAIHistoryRetriever.py
Normal file
44
common/langchain/EveAIHistoryRetriever.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from sqlalchemy import asc
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Any, Dict
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from common.extensions import db
|
||||||
|
from common.models.interaction import ChatSession, Interaction
|
||||||
|
from common.utils.datetime_utils import get_date_in_timezone
|
||||||
|
|
||||||
|
|
||||||
|
class EveAIHistoryRetriever(BaseRetriever):
    """Retriever that returns a chat session's history as transcript strings.

    Unlike a vector-store retriever, this loads every ``Interaction`` that
    belongs to ``session_id`` (oldest first) and renders each one as a
    ``HUMAN:`` / ``AI:`` snippet.  The list of strings is consumed as the
    ``history`` input of a prompt chain.
    """

    # Shared model/LLM configuration; opaque to this class (only stored).
    model_variables: Dict[str, Any] = Field(...)
    # External chat-session identifier used to look up interactions.
    session_id: str = Field(...)

    def __init__(self, model_variables: Dict[str, Any], session_id: str):
        # BaseRetriever is a pydantic model: its required fields must be
        # supplied to the pydantic initializer.  Calling super().__init__()
        # with no arguments fails validation before any attribute
        # assignment can run, so pass the fields straight through.
        super().__init__(model_variables=model_variables, session_id=session_id)

    def _get_relevant_documents(self, query: str):
        """Return the session's interactions as formatted transcript strings.

        ``query`` is only logged — the history is keyed purely on
        ``self.session_id``.  On a database error the session is rolled
        back and an empty list is returned instead of propagating, so the
        surrounding chain degrades to "no history" rather than failing.
        """
        current_app.logger.debug(f'Retrieving history of interactions for query: {query}')

        try:
            interactions = (
                db.session.query(Interaction)
                .join(ChatSession, Interaction.chat_session_id == ChatSession.id)
                .filter(ChatSession.session_id == self.session_id)
                .order_by(asc(Interaction.id))
                .all()
            )
            result = [
                f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n'
                for interaction in interactions
            ]
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving history of interactions: {e}')
            db.session.rollback()
            return []

        return result
|
||||||
@@ -16,10 +16,8 @@ class EveAIRetriever(BaseRetriever):
|
|||||||
|
|
||||||
def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
|
def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
current_app.logger.debug('Initializing EveAIRetriever')
|
|
||||||
self.model_variables = model_variables
|
self.model_variables = model_variables
|
||||||
self.tenant_info = tenant_info
|
self.tenant_info = tenant_info
|
||||||
current_app.logger.debug('EveAIRetriever initialized')
|
|
||||||
|
|
||||||
def _get_relevant_documents(self, query: str):
|
def _get_relevant_documents(self, query: str):
|
||||||
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
|
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class Interaction(db.Model):
|
|||||||
id = db.Column(db.Integer, primary_key=True)
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
|
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
|
||||||
question = db.Column(db.Text, nullable=False)
|
question = db.Column(db.Text, nullable=False)
|
||||||
|
detailed_question = db.Column(db.Text, nullable=True)
|
||||||
answer = db.Column(db.Text, nullable=True)
|
answer = db.Column(db.Text, nullable=True)
|
||||||
algorithm_used = db.Column(db.String(20), nullable=True)
|
algorithm_used = db.Column(db.String(20), nullable=True)
|
||||||
language = db.Column(db.String(2), nullable=False)
|
language = db.Column(db.String(2), nullable=False)
|
||||||
@@ -28,6 +29,7 @@ class Interaction(db.Model):
|
|||||||
|
|
||||||
# Timing information
|
# Timing information
|
||||||
question_at = db.Column(db.DateTime, nullable=False)
|
question_at = db.Column(db.DateTime, nullable=False)
|
||||||
|
detailed_question_at = db.Column(db.DateTime, nullable=True)
|
||||||
answer_at = db.Column(db.DateTime, nullable=True)
|
answer_at = db.Column(db.DateTime, nullable=True)
|
||||||
|
|
||||||
# Relations
|
# Relations
|
||||||
|
|||||||
@@ -103,15 +103,18 @@ def select_model_variables(tenant):
|
|||||||
case 'gpt-4-turbo' | 'gpt-4o':
|
case 'gpt-4-turbo' | 'gpt-4o':
|
||||||
summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE')
|
summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE')
|
||||||
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
|
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
|
||||||
|
history_template = current_app.config.get('GPT4_HISTORY_TEMPLATE')
|
||||||
tool_calling_supported = True
|
tool_calling_supported = True
|
||||||
case 'gpt-3-5-turbo':
|
case 'gpt-3-5-turbo':
|
||||||
summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE')
|
summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE')
|
||||||
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
|
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
|
||||||
|
history_template = current_app.config.get('GPT3_5_HISTORY_TEMPLATE')
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
||||||
f'error: Invalid chat model')
|
f'error: Invalid chat model')
|
||||||
model_variables['summary_template'] = summary_template
|
model_variables['summary_template'] = summary_template
|
||||||
model_variables['rag_template'] = rag_template
|
model_variables['rag_template'] = rag_template
|
||||||
|
model_variables['history_template'] = history_template
|
||||||
if tool_calling_supported:
|
if tool_calling_supported:
|
||||||
model_variables['cited_answer_cls'] = CitedAnswer
|
model_variables['cited_answer_cls'] = CitedAnswer
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@@ -23,6 +23,27 @@ from common.extensions import db
|
|||||||
from common.utils.celery_utils import current_celery
|
from common.utils.celery_utils import current_celery
|
||||||
from common.utils.model_utils import select_model_variables, create_language_template
|
from common.utils.model_utils import select_model_variables, create_language_template
|
||||||
from common.langchain.EveAIRetriever import EveAIRetriever
|
from common.langchain.EveAIRetriever import EveAIRetriever
|
||||||
|
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
|
||||||
|
|
||||||
|
|
||||||
|
def detail_question(question, language, model_variables, session_id):
    """Rewrite *question* into a standalone, detailed question using the
    chat history of *session_id*.

    Builds an LCEL chain that feeds the session history (via
    ``EveAIHistoryRetriever``) and the raw question into the configured
    history template, runs the LLM, and parses the output to a plain
    string.

    Raises:
        LangChainException: re-raised after logging when the chain fails.
    """
    llm = model_variables['llm']
    language_template = create_language_template(
        model_variables['history_template'], language
    )
    history_prompt = ChatPromptTemplate.from_template(language_template)

    # Fan the single input out into the two prompt variables.
    inputs = RunnableParallel({
        "history": EveAIHistoryRetriever(model_variables, session_id),
        "question": RunnablePassthrough(),
    })
    chain = inputs | history_prompt | llm | StrOutputParser()

    try:
        return chain.invoke(question)
    except LangChainException as e:
        current_app.logger.error(f'Error detailing question: {e}')
        raise
|
||||||
|
|
||||||
|
|
||||||
@current_celery.task(name='ask_question', queue='llm_interactions')
|
@current_celery.task(name='ask_question', queue='llm_interactions')
|
||||||
@@ -73,6 +94,11 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
# Langchain debugging if required
|
# Langchain debugging if required
|
||||||
# set_debug(True)
|
# set_debug(True)
|
||||||
|
|
||||||
|
detailed_question = detail_question(question, language, model_variables, session_id)
|
||||||
|
current_app.logger.debug(f'Original question:\n {question}\n\nDetailed question: {detailed_question}')
|
||||||
|
new_interaction.detailed_question = detailed_question
|
||||||
|
new_interaction.detailed_question_at = dt.now(tz.utc)
|
||||||
|
|
||||||
retriever = EveAIRetriever(model_variables, tenant_info)
|
retriever = EveAIRetriever(model_variables, tenant_info)
|
||||||
llm = model_variables['llm']
|
llm = model_variables['llm']
|
||||||
template = model_variables['rag_template']
|
template = model_variables['rag_template']
|
||||||
@@ -87,7 +113,7 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
chain = setup_and_retrieval | rag_prompt | llm | output_parser
|
chain = setup_and_retrieval | rag_prompt | llm | output_parser
|
||||||
|
|
||||||
# Invoke the chain with the actual question
|
# Invoke the chain with the actual question
|
||||||
answer = chain.invoke(question)
|
answer = chain.invoke(detailed_question)
|
||||||
new_interaction.answer = answer
|
new_interaction.answer = answer
|
||||||
result = {
|
result = {
|
||||||
'answer': answer,
|
'answer': answer,
|
||||||
@@ -99,7 +125,7 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
|
|
||||||
chain = setup_and_retrieval | rag_prompt | structured_llm
|
chain = setup_and_retrieval | rag_prompt | structured_llm
|
||||||
|
|
||||||
result = chain.invoke(question).dict()
|
result = chain.invoke(detailed_question).dict()
|
||||||
current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
|
current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
|
||||||
current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
|
current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
|
||||||
new_interaction.answer = result['answer']
|
new_interaction.answer = result['answer']
|
||||||
|
|||||||
Reference in New Issue
Block a user