from datetime import datetime as dt, timezone as tz
from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.globals import set_debug
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
import os

# OpenAI imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException

from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI, Embedding
from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template
from common.langchain.EveAIRetriever import EveAIRetriever


@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question, language, session_id):
    """Answer a question for a tenant and persist the resulting interaction.

    Looks up the tenant, switches to the tenant's database schema, finds or
    creates the ``ChatSession`` for ``session_id``, builds a retrieval chain
    from the tenant's model variables, invokes it with ``question`` and stores
    the ``Interaction`` together with one ``InteractionEmbedding`` per cited
    embedding that actually exists in the database.

    Args:
        tenant_id: Identifier used to load the ``Tenant`` and select its schema.
        question: Question text handed to the chain.
        language: Language handed to ``create_language_template``.
        session_id: Identifier used to find or create the ``ChatSession``.

    Returns:
        result structured as follows:
        result = {
            'answer': 'Your answer here',
            'citations': ['http://example.com/citation1', 'http://example.com/citation2'],
            'algorithm': 'algorithm_name',
            'interaction_id': 'interaction_id_value'
        }
        ``citations`` is an empty list when the model has no
        ``cited_answer_cls`` (no structured output).

    Raises:
        Exception: If the tenant does not exist. Any other error raised while
            processing (including ``SQLAlchemyError``) is logged and re-raised.
    """
    current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')

    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        chat_session = ChatSession.query.filter_by(session_id=session_id).first()
        if not chat_session:
            # Initialize a chat_session on the database
            try:
                chat_session = ChatSession()
                chat_session.session_id = session_id
                chat_session.session_start = dt.now(tz.utc)
                db.session.add(chat_session)
                db.session.commit()
            except SQLAlchemyError as e:
                # Leave the session usable for whoever handles the error next.
                db.session.rollback()
                current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}')
                raise

        algorithm_name = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']

        # The interaction is only written to the database once the answer is
        # known, together with the chat session and the cited embeddings.
        new_interaction = Interaction()
        new_interaction.question = question
        new_interaction.language = language
        new_interaction.chat_session_id = chat_session.id
        new_interaction.question_at = dt.now(tz.utc)
        new_interaction.algorithm_used = algorithm_name
        current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')

        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)
        current_app.logger.debug(f'ask_question: model_variables: {model_variables}')

        # set_debug toggles a process-wide LangChain flag, so it must be
        # switched off again even when the chain or the database write fails.
        set_debug(True)
        try:
            retriever = EveAIRetriever(model_variables)
            llm = model_variables['llm']
            template = model_variables['rag_template']
            language_template = create_language_template(template, language)
            rag_prompt = ChatPromptTemplate.from_template(language_template)
            setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})

            new_interaction_embeddings = []
            if not model_variables['cited_answer_cls']:
                # The model doesn't support structured feedback
                output_parser = StrOutputParser()
                chain = setup_and_retrieval | rag_prompt | llm | output_parser

                # Invoke the chain with the actual question
                answer = chain.invoke(question)
                new_interaction.answer = answer
                result = {
                    'answer': answer,
                    'citations': []
                }
            else:
                # The model supports structured feedback
                structured_llm = llm.with_structured_output(model_variables['cited_answer_cls'])
                chain = setup_and_retrieval | rag_prompt | structured_llm

                result = chain.invoke(question).dict()
                current_app.logger.debug(f'ask_question: result answer: {result["answer"]}')
                current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
                new_interaction.answer = result['answer']

                # Filter out the existing Embedding IDs: only citations that
                # match a stored Embedding are linked and returned as URLs.
                given_embedding_ids = [int(emb_id) for emb_id in result['citations']]
                if given_embedding_ids:
                    embeddings = (
                        db.session.query(Embedding)
                        .filter(Embedding.id.in_(given_embedding_ids))
                        .all()
                    )
                else:
                    # Nothing cited, so there is nothing to look up.
                    embeddings = []

                urls = []
                for emb in embeddings:
                    urls.append(emb.document_version.url)
                    new_interaction_embedding = InteractionEmbedding(embedding_id=emb.id)
                    new_interaction_embedding.interaction = new_interaction
                    new_interaction_embeddings.append(new_interaction_embedding)

                result['citations'] = urls

            new_interaction.answer_at = dt.now(tz.utc)
            chat_session.session_end = dt.now(tz.utc)

            try:
                db.session.add(chat_session)
                db.session.add(new_interaction)
                db.session.add_all(new_interaction_embeddings)
                db.session.commit()
            except SQLAlchemyError as e:
                # Discard the failed transaction before propagating the error.
                db.session.rollback()
                current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
                raise
        finally:
            set_debug(False)

        result['algorithm'] = algorithm_name
        result['interaction_id'] = new_interaction.id
        return result
    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        raise


def tasks_ping():
    """Return the fixed string ``'pong'``.

    NOTE(review): presumably a liveness probe for this tasks module; it is not
    registered as a Celery task here — confirm against callers.
    """
    return 'pong'