Enable model variables & start working on RAG task
This commit is contained in:
158
eveai_chat_workers/tasks.py
Normal file
158
eveai_chat_workers/tasks.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from celery import states
|
||||
from celery.exceptions import Ignore
|
||||
import os
|
||||
|
||||
# Unstructured commercial client imports
|
||||
from unstructured_client import UnstructuredClient
|
||||
from unstructured_client.models import shared
|
||||
from unstructured_client.models.errors import SDKError
|
||||
|
||||
# OpenAI imports
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_core.exceptions import LangChainException
|
||||
|
||||
from common.utils.database import Database
|
||||
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
|
||||
from common.models.user import Tenant
|
||||
from common.extensions import db
|
||||
from common.utils.celery_utils import current_celery
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question):
    """Answer a chat question for a tenant via RAG retrieval.

    Looks up the tenant, switches to its database schema, embeds the
    question and returns the text chunks of the most similar stored
    embeddings joined with newlines.

    :param tenant_id: primary key of the Tenant to answer for
    :param question: free-text question from the user
    :returns: newline-terminated concatenation of the matching chunks
    :raises Ignore: on any error, so Celery drops the task (the error is
        logged first)
    """
    current_app.logger.debug('In ask_question')
    current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')

    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)

        # create embedding for the query
        embedded_question = create_embedding(model_variables, question)

        # Search the database for relevant embeddings
        relevant_embeddings = search_embeddings(model_variables, embedded_question)

        # BUG FIX: accumulate from the loop variable, not the whole result
        # list -- the original read `relevant_embeddings.chunk`, which would
        # raise AttributeError on a list.
        # NOTE(review): search_embeddings() selects (entity, distance) rows,
        # so `embed.chunk` may actually need to be `embed[0].chunk` -- verify
        # against a live query result.
        response = ""
        for embed in relevant_embeddings:
            response += embed.chunk + '\n'

        return response
    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        # Raise an *instance* so Celery receives a proper exception object.
        raise Ignore()
|
||||
|
||||
|
||||
def select_model_variables(tenant):
|
||||
embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
|
||||
embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
|
||||
|
||||
llm_provider = tenant.llm_model.rsplit('.', 1)[0]
|
||||
llm_model = tenant.llm_model.rsplit('.', 1)[1]
|
||||
|
||||
# Set model variables
|
||||
model_variables = {}
|
||||
if tenant.es_k:
|
||||
model_variables['k'] = tenant.es_k
|
||||
else:
|
||||
model_variables['k'] = 5
|
||||
|
||||
if tenant.es_similarity_threshold:
|
||||
model_variables['similarity_threshold'] = tenant.es_similarity_threshold
|
||||
else:
|
||||
model_variables['similarity_threshold'] = 0.7
|
||||
|
||||
if tenant.chat_RAG_temperature:
|
||||
model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
|
||||
else:
|
||||
model_variables['RAG_temperature'] = 0.3
|
||||
|
||||
if tenant.chat_no_RAG_temperature:
|
||||
model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
|
||||
else:
|
||||
model_variables['no_RAG_temperature'] = 0.5
|
||||
|
||||
# Set Embedding variables
|
||||
match embedding_provider:
|
||||
case 'openai':
|
||||
match embedding_model:
|
||||
case 'text-embedding-3-small':
|
||||
api_key = current_app.config.get('OPENAI_API_KEY')
|
||||
model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
|
||||
model='text-embedding-3-small')
|
||||
model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
|
||||
case _:
|
||||
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
||||
f'error: Invalid embedding model')
|
||||
case _:
|
||||
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
||||
f'error: Invalid embedding provider')
|
||||
|
||||
# Set Chat model variables
|
||||
match llm_provider:
|
||||
case 'openai':
|
||||
api_key = current_app.config.get('OPENAI_API_KEY')
|
||||
model_variables['llm'] = ChatOpenAI(api_key=api_key,
|
||||
model=llm_model,
|
||||
temperature=model_variables['RAG_temperature'])
|
||||
match llm_model:
|
||||
case 'gpt-4-turbo' | 'gpt-4-o':
|
||||
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
|
||||
case 'gpt-3-5-turbo':
|
||||
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
|
||||
case _:
|
||||
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
||||
f'error: Invalid chat model')
|
||||
model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
|
||||
case _:
|
||||
raise Exception(f'Error setting model variables for tenant {tenant.id} '
|
||||
f'error: Invalid chat provider')
|
||||
|
||||
return model_variables
|
||||
|
||||
|
||||
def create_embedding(model_variables, question):
    """Embed a single question string with the tenant's embedding model.

    :param model_variables: dict produced by select_model_variables();
        only the 'embedding' entry (a LangChain embeddings object) is used
    :param question: the question text to embed
    :returns: the embedding vector for *question*
    :raises Exception: wraps any LangChainException from the provider
    """
    try:
        # BUG FIX: embed_documents() expects a *list* of texts.  Passing the
        # bare string made LangChain iterate it per character, so
        # embeddings[0] was the embedding of the first character only.
        embeddings = model_variables['embedding'].embed_documents([question])
    except LangChainException as e:
        raise Exception(f'Error creating embedding for question (LangChain): {e}')

    return embeddings[0]
|
||||
|
||||
|
||||
def search_embeddings(model_variables, embedded_query):
    """Return the top-k embedding rows closest to *embedded_query*.

    Rows are (embedding_entity, distance) results ordered by ascending
    cosine distance and filtered to distances below the tenant's
    similarity threshold.

    :param model_variables: dict from select_model_variables(); uses
        'embedding_db_model', 'similarity_threshold' and 'k'
    :param embedded_query: the query's embedding vector
    :raises Exception: wraps any SQLAlchemyError from the query
    """
    current_app.logger.debug(f'In search_embeddings searching for {embedded_query}')
    db_class = model_variables['embedding_db_model']
    try:
        # Build the distance expression once and reuse it for the select
        # list, the filter and (via its label) the ordering.
        distance = db_class.embedding.cosine_distance(embedded_query)
        results = (
            db.session.query(db_class, distance.label('distance'))
            .filter(distance < model_variables['similarity_threshold'])
            .order_by("distance")
            .limit(model_variables['k'])
            .all()
        )
    except SQLAlchemyError as e:
        raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')

    current_app.logger.debug(f'Results from embedding search: {results}')
    return results
|
||||
|
||||
|
||||
def tasks_ping():
    """Liveness probe for this worker module; always answers 'pong'."""
    reply = 'pong'
    return reply
|
||||
Reference in New Issue
Block a user