Include history to improve query if required.

This commit is contained in:
Josako
2024-06-13 10:35:15 +02:00
parent 50851dc51c
commit 24a3747b99
5 changed files with 77 additions and 4 deletions

View File

@@ -23,6 +23,27 @@ from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template
from common.langchain.EveAIRetriever import EveAIRetriever
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
def detail_question(question, language, model_variables, session_id):
retriever = EveAIHistoryRetriever(model_variables, session_id)
llm = model_variables['llm']
template = model_variables['history_template']
language_template = create_language_template(template, language)
history_prompt = ChatPromptTemplate.from_template(language_template)
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
output_parser = StrOutputParser()
chain = setup_and_retrieval | history_prompt | llm | output_parser
try:
answer = chain.invoke(question)
return answer
except LangChainException as e:
current_app.logger.error(f'Error detailing question: {e}')
raise
@current_celery.task(name='ask_question', queue='llm_interactions')
@@ -73,6 +94,11 @@ def ask_question(tenant_id, question, language, session_id):
# Langchain debugging if required
# set_debug(True)
detailed_question = detail_question(question, language, model_variables, session_id)
current_app.logger.debug(f'Original question:\n {question}\n\nDetailed question: {detailed_question}')
new_interaction.detailed_question = detailed_question
new_interaction.detailed_question_at = dt.now(tz.utc)
retriever = EveAIRetriever(model_variables, tenant_info)
llm = model_variables['llm']
template = model_variables['rag_template']
@@ -87,7 +113,7 @@ def ask_question(tenant_id, question, language, session_id):
chain = setup_and_retrieval | rag_prompt | llm | output_parser
# Invoke the chain with the actual question
answer = chain.invoke(question)
answer = chain.invoke(detailed_question)
new_interaction.answer = answer
result = {
'answer': answer,
@@ -99,7 +125,7 @@ def ask_question(tenant_id, question, language, session_id):
chain = setup_and_retrieval | rag_prompt | structured_llm
result = chain.invoke(question).dict()
result = chain.invoke(detailed_question).dict()
current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
new_interaction.answer = result['answer']