eveAI/eveai_chat_workers/tasks.py


from datetime import datetime as dt, timezone as tz
from flask import current_app, session
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.globals import set_debug
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
import os
# OpenAI imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI, Embedding
from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template
from common.langchain.EveAIRetriever import EveAIRetriever


@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question, language, session_id):
    """Returns a result structured as follows:

    result = {
        'answer': 'Your answer here',
        'citations': ['http://example.com/citation1', 'http://example.com/citation2'],
        'algorithm': 'algorithm_name',
        'interaction_id': 'interaction_id_value'
    }
    """
    current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')

    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()
        # Ensure we have a session to store history
        chat_session = ChatSession.query.filter_by(session_id=session_id).first()
        if not chat_session:
            try:
                chat_session = ChatSession()
                chat_session.session_id = session_id
                chat_session.session_start = dt.now(tz.utc)
                db.session.add(chat_session)
                db.session.commit()
            except SQLAlchemyError as e:
                current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}')
                raise
        new_interaction = Interaction()
        new_interaction.question = question
        new_interaction.language = language
        new_interaction.chat_session_id = chat_session.id
        new_interaction.question_at = dt.now(tz.utc)
        new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']

        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)
        tenant_info = tenant.to_dict()

        # Langchain debugging if required
        # set_debug(True)

        retriever = EveAIRetriever(model_variables, tenant_info)
        llm = model_variables['llm']
        template = model_variables['rag_template']
        language_template = create_language_template(template, language)
        rag_prompt = ChatPromptTemplate.from_template(language_template)
        setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
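        # The RunnableParallel step fans the incoming question out to two keys:
        # 'context' is populated by the retriever's documents and 'question'
        # passes the raw input through, so the prompt template can use both.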

        new_interaction_embeddings = []

        if not model_variables['cited_answer_cls']:  # The model doesn't support structured output
            output_parser = StrOutputParser()
            chain = setup_and_retrieval | rag_prompt | llm | output_parser
            # Invoke the chain with the actual question
            answer = chain.invoke(question)
            new_interaction.answer = answer
            result = {
                'answer': answer,
                'citations': []
            }
        else:  # The model supports structured output
            structured_llm = llm.with_structured_output(model_variables['cited_answer_cls'])
            chain = setup_and_retrieval | rag_prompt | structured_llm
            result = chain.invoke(question).dict()
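            # chain.invoke returns an instance of the configured cited-answer
            # class; .dict() turns it into a plain dict with 'answer' and
            # 'citations' keys.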
            current_app.logger.debug(f"ask_question: result answer: {result['answer']}")
            current_app.logger.debug(f"ask_question: result citations: {result['citations']}")
            new_interaction.answer = result['answer']

            # Keep only the cited embedding IDs that actually exist, and map them to source URLs
            given_embedding_ids = [int(emb_id) for emb_id in result['citations']]
            embeddings = (
                db.session.query(Embedding)
                .filter(Embedding.id.in_(given_embedding_ids))
                .all()
            )
            existing_embedding_ids = [emb.id for emb in embeddings]
            urls = [emb.document_version.url for emb in embeddings]
            for emb_id in existing_embedding_ids:
                new_interaction_embedding = InteractionEmbedding(embedding_id=emb_id)
                new_interaction_embedding.interaction = new_interaction
                new_interaction_embeddings.append(new_interaction_embedding)
            result['citations'] = urls

        new_interaction.answer_at = dt.now(tz.utc)
        chat_session.session_end = dt.now(tz.utc)

        try:
            db.session.add(chat_session)
            db.session.add(new_interaction)
            db.session.add_all(new_interaction_embeddings)
            db.session.commit()
        except SQLAlchemyError as e:
            current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
            raise

        # Disable langchain debugging if set above.
        # set_debug(False)

        result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
        result['interaction_id'] = new_interaction.id
        return result

    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        raise


def tasks_ping():
    return 'pong'
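

# --- Usage sketch (illustrative only; the argument values below are assumptions) ---
# A caller such as a Flask view would normally dispatch this Celery task and
# wait for (or poll) the result, e.g.:
#
#   async_result = ask_question.delay(tenant_id=1,
#                                     question='What does EveAI do?',
#                                     language='en',
#                                     session_id='abc-123')
#   result = async_result.get(timeout=60)
#   # result['answer'], result['citations'], result['algorithm'], result['interaction_id']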