# eveAI/eveai_chat_workers/tasks.py
# (originally listed as: 159 lines, 6.0 KiB, Python)
from datetime import datetime as dt, timezone as tz
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
import os
# Unstructured commercial client imports
from unstructured_client import UnstructuredClient
from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError
# OpenAI imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
from common.models.user import Tenant
from common.extensions import db
from common.utils.celery_utils import current_celery
from bs4 import BeautifulSoup
@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question):
    """Answer a tenant's question by returning the relevant embedded chunks.

    Looks up the tenant, switches to its database schema, embeds the
    question, and concatenates the text of the nearest stored chunks.

    :param tenant_id: primary key of the Tenant to answer for.
    :param question: free-text question string.
    :return: newline-separated chunk texts (empty string if no matches).
    :raises Ignore: on any failure, so Celery discards the task result.
    """
    current_app.logger.debug('In ask_question')
    current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')
    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')
        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()
        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)
        # create embedding for the query
        embedded_question = create_embedding(model_variables, question)
        # Search the database for relevant embeddings
        relevant_embeddings = search_embeddings(model_variables, embedded_question)
        # BUG FIX: the original read `relevant_embeddings.chunk` (the whole
        # result list) inside the loop instead of the per-row `embed.chunk`,
        # which raised AttributeError on any non-empty result. Also build
        # the response with join instead of quadratic `+=`.
        # NOTE(review): search_embeddings selects (entity, distance); if rows
        # come back as Row tuples, this may need `embed[0].chunk` — verify.
        return ''.join(embed.chunk + '\n' for embed in relevant_embeddings)
    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        raise Ignore()
def select_model_variables(tenant):
embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
llm_provider = tenant.llm_model.rsplit('.', 1)[0]
llm_model = tenant.llm_model.rsplit('.', 1)[1]
# Set model variables
model_variables = {}
if tenant.es_k:
model_variables['k'] = tenant.es_k
else:
model_variables['k'] = 5
if tenant.es_similarity_threshold:
model_variables['similarity_threshold'] = tenant.es_similarity_threshold
else:
model_variables['similarity_threshold'] = 0.7
if tenant.chat_RAG_temperature:
model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
else:
model_variables['RAG_temperature'] = 0.3
if tenant.chat_no_RAG_temperature:
model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
else:
model_variables['no_RAG_temperature'] = 0.5
# Set Embedding variables
match embedding_provider:
case 'openai':
match embedding_model:
case 'text-embedding-3-small':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
model='text-embedding-3-small')
model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding model')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding provider')
# Set Chat model variables
match llm_provider:
case 'openai':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['RAG_temperature'])
match llm_model:
case 'gpt-4-turbo' | 'gpt-4-o':
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
case 'gpt-3-5-turbo':
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model')
model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat provider')
return model_variables
def create_embedding(model_variables, question):
    """Embed a single question string and return its vector.

    BUG FIX: the original called ``embed_documents(question)`` with the raw
    string. The LangChain Embeddings interface expects a *list* of texts, so
    a bare string is treated as an iterable of one-character texts. The
    question is now wrapped in a single-element list and the first (only)
    vector is returned.

    :param model_variables: config dict; 'embedding' holds the embedding client.
    :param question: question text to embed.
    :return: the embedding vector for the question.
    :raises Exception: wrapping any LangChainException from the client.
    """
    try:
        embeddings = model_variables['embedding'].embed_documents([question])
    except LangChainException as e:
        raise Exception(f'Error creating embedding for question (LangChain): {e}')
    return embeddings[0]
def search_embeddings(model_variables, embedded_query):
    """Fetch the k nearest stored embeddings within the similarity threshold.

    Orders rows by cosine distance to the query vector, keeps only those
    below 'similarity_threshold', and returns at most 'k' of them.

    :param model_variables: config dict; uses 'embedding_db_model',
        'similarity_threshold' and 'k'.
    :param embedded_query: embedding vector to search against.
    :return: list of (row, distance) query results, nearest first.
    :raises Exception: wrapping any SQLAlchemyError from the query.
    """
    current_app.logger.debug(f'In search_embeddings searching for {embedded_query}')
    embedding_table = model_variables['embedding_db_model']
    # Build the distance expression once and reuse it in select/filter.
    distance = embedding_table.embedding.cosine_distance(embedded_query)
    try:
        results = (
            db.session.query(embedding_table, distance.label('distance'))
            .filter(distance < model_variables['similarity_threshold'])
            .order_by('distance')
            .limit(model_variables['k'])
            .all()
        )
    except SQLAlchemyError as e:
        raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')
    current_app.logger.debug(f'Results from embedding search: {results}')
    return results
def tasks_ping():
    """Liveness check for this task module; always answers 'pong'."""
    reply = 'pong'
    return reply