from flask import current_app

from sqlalchemy.exc import SQLAlchemyError

from celery.exceptions import Ignore

# OpenAI / LangChain imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.exceptions import LangChainException

from common.utils.database import Database
from common.models.document import EmbeddingSmallOpenAI
from common.models.user import Tenant
from common.extensions import db
from common.utils.celery_utils import current_celery


@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question):
    current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')

    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)

        # Create an embedding for the query
        embedded_question = create_embedding(model_variables, question)

        # Search the database for relevant embeddings
        relevant_embeddings = search_embeddings(model_variables, embedded_question)

        # Each result row is an (embedding row, distance) tuple; stitch the
        # matching chunks together into the response
        response = ''
        for embedding, _distance in relevant_embeddings:
            response += embedding.chunk + '\n'

        return response
    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        # Ignore tells the worker to drop the task without recording a result
        raise Ignore()
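

# A minimal dispatch sketch (illustrative, not part of this module): with a
# configured broker, a caller can enqueue the task by name without importing
# this module, e.g.
#
#   result = current_celery.send_task('ask_question',
#                                     args=[tenant_id, question],
#                                     queue='llm_interactions')
#   answer = result.get(timeout=60)  # blocking; fine for a smoke test only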


def select_model_variables(tenant):
    # Tenant settings are stored as '<provider>.<model key>'; model keys are
    # dash-separated, so splitting on the last '.' is unambiguous
    embedding_provider, embedding_model = tenant.embedding_model.rsplit('.', 1)
    llm_provider, llm_model = tenant.llm_model.rsplit('.', 1)

    # Set model variables, falling back to defaults where the tenant has no override
    model_variables = {
        'k': tenant.es_k or 5,
        'similarity_threshold': tenant.es_similarity_threshold or 0.7,
        'RAG_temperature': tenant.chat_RAG_temperature or 0.3,
        'no_RAG_temperature': tenant.chat_no_RAG_temperature or 0.5,
    }

    # Set embedding variables
    match embedding_provider:
        case 'openai':
            match embedding_model:
                case 'text-embedding-3-small':
                    api_key = current_app.config.get('OPENAI_API_KEY')
                    model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
                                                                    model='text-embedding-3-small')
                    model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
                case _:
                    raise Exception(f'Error setting model variables for tenant {tenant.id}: '
                                    f'invalid embedding model')
        case _:
            raise Exception(f'Error setting model variables for tenant {tenant.id}: '
                            f'invalid embedding provider')

    # Set chat model variables
    match llm_provider:
        case 'openai':
            api_key = current_app.config.get('OPENAI_API_KEY')
            model_variables['llm'] = ChatOpenAI(api_key=api_key,
                                                model=llm_model,
                                                temperature=model_variables['RAG_temperature'])
            match llm_model:
                case 'gpt-4-turbo' | 'gpt-4-o':
                    rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
                case 'gpt-3-5-turbo':
                    rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
                case _:
                    raise Exception(f'Error setting model variables for tenant {tenant.id}: '
                                    f'invalid chat model')
            model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
        case _:
            raise Exception(f'Error setting model variables for tenant {tenant.id}: '
                            f'invalid chat provider')

    return model_variables
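

# For illustration (hypothetical tenant settings): a tenant configured with
#   embedding_model = 'openai.text-embedding-3-small'
#   llm_model       = 'openai.gpt-4-turbo'
# splits into provider 'openai' plus the model key, and produces roughly
#   {'k': 5, 'similarity_threshold': 0.7, 'RAG_temperature': 0.3,
#    'no_RAG_temperature': 0.5, 'embedding': OpenAIEmbeddings(...),
#    'embedding_db_model': EmbeddingSmallOpenAI, 'llm': ChatOpenAI(...),
#    'prompt': ChatPromptTemplate(...)}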


def create_embedding(model_variables, question):
    try:
        # embed_documents() expects a list of texts, so wrap the single question
        embeddings = model_variables['embedding'].embed_documents([question])
    except LangChainException as e:
        raise Exception(f'Error creating embedding for question (LangChain): {e}')

    return embeddings[0]
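

# Equivalent single-text form: LangChain embedding classes also expose
# embed_query(), which returns one vector directly:
#
#   vector = model_variables['embedding'].embed_query(question)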


def search_embeddings(model_variables, embedded_query):
    current_app.logger.debug(f'In search_embeddings with a {len(embedded_query)}-dimension query vector')
    db_class = model_variables['embedding_db_model']
    # Build the pgvector cosine-distance expression once and reuse it for both
    # the projection and the filter
    distance = db_class.embedding.cosine_distance(embedded_query)
    try:
        res = (
            db.session.query(db_class, distance.label('distance'))
            .filter(distance < model_variables['similarity_threshold'])
            .order_by('distance')
            .limit(model_variables['k'])
            .all()
        )
    except SQLAlchemyError as e:
        raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')

    current_app.logger.debug(f'Results from embedding search: {res}')
    return res
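

# Roughly the SQL the query above emits through pgvector (illustrative; the
# actual table name depends on the EmbeddingSmallOpenAI mapping):
#
#   SELECT *, embedding <=> :query AS distance
#   FROM embedding_small_openai
#   WHERE embedding <=> :query < :threshold
#   ORDER BY distance
#   LIMIT :k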


def tasks_ping():
    return 'pong'