from datetime import datetime as dt, timezone as tz from flask import current_app from sqlalchemy.exc import SQLAlchemyError from celery import states from celery.exceptions import Ignore import os # Unstructured commercial client imports from unstructured_client import UnstructuredClient from unstructured_client.models import shared from unstructured_client.models.errors import SDKError # OpenAI imports from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_core.prompts import ChatPromptTemplate from langchain.chains.summarize import load_summarize_chain from langchain.text_splitter import CharacterTextSplitter from langchain_core.exceptions import LangChainException from common.utils.database import Database from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI from common.models.user import Tenant from common.extensions import db from common.utils.celery_utils import current_celery from bs4 import BeautifulSoup @current_celery.task(name='ask_question', queue='llm_interactions') def ask_question(tenant_id, question): current_app.logger.debug('In ask_question') current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') try: # Retrieve the tenant tenant = Tenant.query.get(tenant_id) if not tenant: raise Exception(f'Tenant {tenant_id} not found.') # Ensure we are working in the correct database schema Database(tenant_id).switch_schema() # Select variables to work with depending on tenant model model_variables = select_model_variables(tenant) # create embedding for the query embedded_question = create_embedding(model_variables, question) # Search the database for relevant embeddings relevant_embeddings = search_embeddings(model_variables, embedded_question) response = "" for embed in relevant_embeddings: response += relevant_embeddings.chunk + '\n' return response except Exception as e: current_app.logger.error(f'ask_question: Error processing question: {e}') raise Ignore def select_model_variables(tenant): embedding_provider = tenant.embedding_model.rsplit('.', 1)[0] embedding_model = tenant.embedding_model.rsplit('.', 1)[1] llm_provider = tenant.llm_model.rsplit('.', 1)[0] llm_model = tenant.llm_model.rsplit('.', 1)[1] # Set model variables model_variables = {} if tenant.es_k: model_variables['k'] = tenant.es_k else: model_variables['k'] = 5 if tenant.es_similarity_threshold: model_variables['similarity_threshold'] = tenant.es_similarity_threshold else: model_variables['similarity_threshold'] = 0.7 if tenant.chat_RAG_temperature: model_variables['RAG_temperature'] = tenant.chat_RAG_temperature else: model_variables['RAG_temperature'] = 0.3 if tenant.chat_no_RAG_temperature: model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature else: model_variables['no_RAG_temperature'] = 0.5 # Set Embedding variables match embedding_provider: case 'openai': match embedding_model: case 'text-embedding-3-small': api_key = current_app.config.get('OPENAI_API_KEY') model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key, model='text-embedding-3-small') model_variables['embedding_db_model'] = EmbeddingSmallOpenAI case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid embedding model') case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid embedding provider') # Set Chat model variables match llm_provider: case 'openai': api_key = current_app.config.get('OPENAI_API_KEY') model_variables['llm'] = ChatOpenAI(api_key=api_key, model=llm_model, temperature=model_variables['RAG_temperature']) match llm_model: case 'gpt-4-turbo' | 'gpt-4-o': rag_template = current_app.config.get('GPT4_RAG_TEMPLATE') case 'gpt-3-5-turbo': rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE') case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid chat model') model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template) case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid chat provider') return model_variables def create_embedding(model_variables, question): try: embeddings = model_variables['embedding'].embed_documents(question) except LangChainException as e: raise Exception(f'Error creating embedding for question (LangChain): {e}') return embeddings[0] def search_embeddings(model_variables, embedded_query): current_app.logger.debug(f'In search_embeddings searching for {embedded_query}') db_class = model_variables['embedding_db_model'] try: res = ( db.session.query(db_class, db_class.embedding.cosine_distance(embedded_query).label('distance')) .filter(db_class.embedding.cosine_distance(embedded_query) < model_variables['similarity_threshold']) .order_by("distance") .limit(model_variables['k']) .all() ) except SQLAlchemyError as e: raise Exception(f'Error searching embeddings (SQLAlchemy): {e}') current_app.logger.debug(f'Results from embedding search: {res}') return res def tasks_ping(): return 'pong'