From 24a3747b993f951da35517cb02519659addcd9ef Mon Sep 17 00:00:00 2001
From: Josako
Date: Thu, 13 Jun 2024 10:35:15 +0200
Subject: [PATCH] Include history to improve query if required.

---
 common/langchain/EveAIHistoryRetriever.py | 44 +++++++++++++++++++++++
 common/langchain/EveAIRetriever.py        |  2 --
 common/models/interaction.py              |  2 ++
 common/utils/model_utils.py               |  3 ++
 eveai_chat_workers/tasks.py               | 30 ++++++++++++++--
 5 files changed, 77 insertions(+), 4 deletions(-)
 create mode 100644 common/langchain/EveAIHistoryRetriever.py

diff --git a/common/langchain/EveAIHistoryRetriever.py b/common/langchain/EveAIHistoryRetriever.py
new file mode 100644
index 0000000..1810169
--- /dev/null
+++ b/common/langchain/EveAIHistoryRetriever.py
@@ -0,0 +1,44 @@
+from langchain_core.retrievers import BaseRetriever
+from sqlalchemy import asc
+from sqlalchemy.exc import SQLAlchemyError
+from pydantic import BaseModel, Field
+from typing import Any, Dict
+from flask import current_app
+
+from common.extensions import db
+from common.models.interaction import ChatSession, Interaction
+from common.utils.datetime_utils import get_date_in_timezone
+
+
+class EveAIHistoryRetriever(BaseRetriever):
+    model_variables: Dict[str, Any] = Field(...)
+    session_id: str = Field(...)
+
+    def __init__(self, model_variables: Dict[str, Any], session_id: str):
+        super().__init__()
+        self.model_variables = model_variables
+        self.session_id = session_id
+
+    def _get_relevant_documents(self, query: str):
+        current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
+
+        try:
+            query_obj = (
+                db.session.query(Interaction)
+                .join(ChatSession, Interaction.chat_session_id == ChatSession.id)
+                .filter(ChatSession.session_id == self.session_id)
+                .order_by(asc(Interaction.id))
+            )
+
+            interactions = query_obj.all()
+
+            result = []
+            for interaction in interactions:
+                result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
+
+        except SQLAlchemyError as e:
+            current_app.logger.error(f'Error retrieving history of interactions: {e}')
+            db.session.rollback()
+            return []
+
+        return result
\ No newline at end of file
diff --git a/common/langchain/EveAIRetriever.py b/common/langchain/EveAIRetriever.py
index 252138a..258da0e 100644
--- a/common/langchain/EveAIRetriever.py
+++ b/common/langchain/EveAIRetriever.py
@@ -16,10 +16,8 @@ class EveAIRetriever(BaseRetriever):
 
     def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
         super().__init__()
-        current_app.logger.debug('Initializing EveAIRetriever')
        self.model_variables = model_variables
         self.tenant_info = tenant_info
-        current_app.logger.debug('EveAIRetriever initialized')
 
     def _get_relevant_documents(self, query: str):
         current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
diff --git a/common/models/interaction.py b/common/models/interaction.py
index 1c1a144..ad81513 100644
--- a/common/models/interaction.py
+++ b/common/models/interaction.py
@@ -21,6 +21,7 @@ class Interaction(db.Model):
     id = db.Column(db.Integer, primary_key=True)
     chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
     question = db.Column(db.Text, nullable=False)
+    detailed_question = db.Column(db.Text, nullable=True)
     answer = db.Column(db.Text, nullable=True)
     algorithm_used = db.Column(db.String(20), nullable=True)
     language = db.Column(db.String(2), nullable=False)
@@ -28,6 +29,7 @@
 
     # Timing information
     question_at = db.Column(db.DateTime, nullable=False)
+    detailed_question_at = db.Column(db.DateTime, nullable=True)
     answer_at = db.Column(db.DateTime, nullable=True)
 
     # Relations
diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py
index fcb344e..afcd7b9 100644
--- a/common/utils/model_utils.py
+++ b/common/utils/model_utils.py
@@ -103,15 +103,18 @@ def select_model_variables(tenant):
                 case 'gpt-4-turbo' | 'gpt-4o':
                     summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE')
                     rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
+                    history_template = current_app.config.get('GPT4_HISTORY_TEMPLATE')
                     tool_calling_supported = True
                 case 'gpt-3-5-turbo':
                     summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE')
                     rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
+                    history_template = current_app.config.get('GPT3_5_HISTORY_TEMPLATE')
                 case _:
                     raise Exception(f'Error setting model variables for tenant {tenant.id} '
                                     f'error: Invalid chat model')
             model_variables['summary_template'] = summary_template
             model_variables['rag_template'] = rag_template
+            model_variables['history_template'] = history_template
             if tool_calling_supported:
                 model_variables['cited_answer_cls'] = CitedAnswer
         case _:
diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py
index 768061f..c6bd0e1 100644
--- a/eveai_chat_workers/tasks.py
+++ b/eveai_chat_workers/tasks.py
@@ -23,6 +23,27 @@ from common.extensions import db
 from common.utils.celery_utils import current_celery
 from common.utils.model_utils import select_model_variables, create_language_template
 from common.langchain.EveAIRetriever import EveAIRetriever
+from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
+
+
+def detail_question(question, language, model_variables, session_id):
+    retriever = EveAIHistoryRetriever(model_variables, session_id)
+    llm = model_variables['llm']
+    template = model_variables['history_template']
+    language_template = create_language_template(template, language)
+    history_prompt = ChatPromptTemplate.from_template(language_template)
+    setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
+    output_parser = StrOutputParser()
+
+    chain = setup_and_retrieval | history_prompt | llm | output_parser
+
+
+    try:
+        answer = chain.invoke(question)
+        return answer
+    except LangChainException as e:
+        current_app.logger.error(f'Error detailing question: {e}')
+        raise
 
 
 @current_celery.task(name='ask_question', queue='llm_interactions')
@@ -73,6 +94,11 @@ def ask_question(tenant_id, question, language, session_id):
 
     # Langchain debugging if required
     # set_debug(True)
+    detailed_question = detail_question(question, language, model_variables, session_id)
+    current_app.logger.debug(f'Original question:\n {question}\n\nDetailed question: {detailed_question}')
+    new_interaction.detailed_question = detailed_question
+    new_interaction.detailed_question_at = dt.now(tz.utc)
+
     retriever = EveAIRetriever(model_variables, tenant_info)
     llm = model_variables['llm']
     template = model_variables['rag_template']
@@ -87,7 +113,7 @@
         chain = setup_and_retrieval | rag_prompt | llm | output_parser
 
         # Invoke the chain with the actual question
-        answer = chain.invoke(question)
+        answer = chain.invoke(detailed_question)
         new_interaction.answer = answer
         result = {
             'answer': answer,
@@ -99,7 +125,7 @@
 
         chain = setup_and_retrieval | rag_prompt | structured_llm
 
-        result = chain.invoke(question).dict()
+        result = chain.invoke(detailed_question).dict()
         current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
         current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
         new_interaction.answer = result['answer']
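
Supplementary notes (not part of the patch):

The new detail_question() chain feeds a "history" and a "question" value into
history_template, so the GPT4_HISTORY_TEMPLATE / GPT3_5_HISTORY_TEMPLATE config keys it
reads must define both placeholders. A minimal sketch of such a template (the wording is
an assumption; only the {history} and {question} placeholders follow from the chain):

    # Hypothetical history template; wording is illustrative only.
    GPT4_HISTORY_TEMPLATE = (
        "Given the conversation history below, rewrite the user's latest question "
        "as a standalone, fully detailed question.\n\n"
        "History:\n{history}\n\n"
        "Question: {question}"
    )

The patch also adds two nullable columns (detailed_question, detailed_question_at) to the
Interaction model but includes no schema migration. A minimal Alembic sketch, assuming the
project uses Flask-Migrate/Alembic and the default 'interaction' table name (revision ids
are placeholders):

    from alembic import op
    import sqlalchemy as sa

    # Placeholder revision identifiers; replace with generated values.
    revision = 'xxxxxxxxxxxx'
    down_revision = 'previous_revision_id'
    branch_labels = None
    depends_on = None


    def upgrade():
        # Mirror the nullable columns added to common/models/interaction.py.
        op.add_column('interaction', sa.Column('detailed_question', sa.Text(), nullable=True))
        op.add_column('interaction', sa.Column('detailed_question_at', sa.DateTime(), nullable=True))


    def downgrade():
        op.drop_column('interaction', 'detailed_question_at')
        op.drop_column('interaction', 'detailed_question')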