Improve chat functionality significantly throughout the application.

This commit is contained in:
Josako
2024-06-12 11:07:18 +02:00
parent 27b6de8734
commit be311c440b
22 changed files with 604 additions and 127 deletions

View File

@@ -1,7 +1,8 @@
from datetime import datetime as dt, timezone as tz
from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.globals import set_debug
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
@@ -15,17 +16,25 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI, Embedding
from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables
from common.utils.model_utils import select_model_variables, create_language_template
from common.langchain.EveAIRetriever import EveAIRetriever
# Celery task entry point: answers a user question via a RAG chain and records
# the interaction. NOTE(review): this is a diff rendering — the +/- markers and
# indentation were stripped, so removed and added lines appear interleaved below.
@current_celery.task(name='ask_question', queue='llm_interactions')
# Old (pre-commit) signature and body line — replaced by the version below:
def ask_question(tenant_id, question):
current_app.logger.debug('In ask_question')
# New signature: adds the answer language and the chat-session identifier.
def ask_question(tenant_id, question, language, session_id):
"""returns result structured as follows:
result = {
'answer': 'Your answer here',
'citations': ['http://example.com/citation1', 'http://example.com/citation2'],
'algorithm': 'algorithm_name',
'interaction_id': 'interaction_id_value'
}
"""
current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')
try:
# Hunk boundary: the unchanged context lines between here and the previous
# hunk are elided by the diff — they presumably include the lookup that
# binds `tenant` (used with select_model_variables below). TODO confirm.
@@ -37,17 +46,106 @@ def ask_question(tenant_id, question):
# Ensure we are working in the correct database schema
Database(tenant_id).switch_schema()
# Get-or-create the ChatSession for this session_id.
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
if not chat_session:
# Initialize a chat_session on the database
try:
chat_session = ChatSession()
chat_session.session_id = session_id
chat_session.session_start = dt.now(tz.utc)
db.session.add(chat_session)
# Committed immediately so chat_session.id is populated for the
# Interaction FK below.
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}')
raise
# Build the Interaction row; it is persisted only once, in the single
# commit near the end of the task.
new_interaction = Interaction()
new_interaction.question = question
new_interaction.language = language
new_interaction.chat_session_id = chat_session.id
new_interaction.question_at = dt.now(tz.utc)
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
# NOTE(review): dead commented-out code — superseded by the single commit
# at the end of the task; should be deleted rather than kept commented.
# try:
# db.session.add(new_interaction)
# db.session.commit()
# except SQLAlchemyError as e:
# current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
# raise
current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')
# Select variables to work with depending on tenant model
model_variables = select_model_variables(tenant)
current_app.logger.debug(f'ask_question: model_variables: {model_variables}')
# NOTE(review): enables LangChain global debug; the except path below
# re-raises without calling set_debug(False), so debug stays on for the
# whole worker process after any failure. Prefer try/finally.
set_debug(True)
retriever = EveAIRetriever(model_variables)
llm = model_variables['llm']
template = model_variables['rag_template']
# Inject the requested answer language into the RAG prompt template.
language_template = create_language_template(template, language)
rag_prompt = ChatPromptTemplate.from_template(language_template)
# The chain retrieves context and forwards the raw question in parallel.
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
# Search the database for relevant embeddings
# NOTE(review): relevant_embeddings is never used afterwards, and the
# retriever runs again inside the chain — this is a duplicate retrieval.
relevant_embeddings = retriever.invoke(question)
new_interaction_embeddings = []
if not model_variables['cited_answer_cls']: # The model doesn't support structured feedback
output_parser = StrOutputParser()
# NOTE(review): likely a removed (old) line left in by the stripped
# diff — if actually present, it makes the rest of this branch
# unreachable. Verify against the real file.
return 'No response yet, check back later.'
chain = setup_and_retrieval | rag_prompt | llm | output_parser
# Invoke the chain with the actual question
answer = chain.invoke(question)
new_interaction.answer = answer
# Plain-text models yield no citations.
result = {
'answer': answer,
'citations': []
}
else: # The model supports structured feedback
# Structured output: the model returns an object matching
# cited_answer_cls (answer + citation embedding IDs).
structured_llm = llm.with_structured_output(model_variables['cited_answer_cls'])
chain = setup_and_retrieval | rag_prompt | structured_llm
result = chain.invoke(question).dict()
# BUG(review): nested single quotes inside a single-quoted f-string —
# SyntaxError on Python < 3.12; use result["answer"] as on the next line.
current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
new_interaction.answer = result['answer']
# Filter out the existing Embedding IDs
# NOTE(review): int(emb_id) raises ValueError if the model emits a
# non-numeric citation; consider validating before conversion.
given_embedding_ids = [int(emb_id) for emb_id in result['citations']]
embeddings = (
db.session.query(Embedding)
.filter(Embedding.id.in_(given_embedding_ids))
.all()
)
# Citations that do not resolve to a stored Embedding are silently dropped.
existing_embedding_ids = [emb.id for emb in embeddings]
urls = [emb.document_version.url for emb in embeddings]
# Link each cited embedding to the interaction for later auditing.
for emb_id in existing_embedding_ids:
new_interaction_embedding = InteractionEmbedding(embedding_id=emb_id)
new_interaction_embedding.interaction = new_interaction
new_interaction_embeddings.append(new_interaction_embedding)
# Replace raw embedding IDs with the source-document URLs in the response.
result['citations'] = urls
new_interaction.answer_at = dt.now(tz.utc)
# session_end is refreshed on every answered question in the session.
chat_session.session_end = dt.now(tz.utc)
# Single commit for session update, interaction, and its embedding links.
try:
db.session.add(chat_session)
db.session.add(new_interaction)
db.session.add_all(new_interaction_embeddings)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
raise
set_debug(False)
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
# interaction_id is populated by the commit above.
result['interaction_id'] = new_interaction.id
return result
except Exception as e:
# Broad catch at the task boundary: log, then re-raise so Celery marks
# the task failed. NOTE(review): set_debug(False) is skipped on this path.
current_app.logger.error(f'ask_question: Error processing question: {e}')
raise