- Move from OpenAI to Mistral Embeddings
- Move embedding model settings from tenant to catalog
- BUG: error processing configuration for chunking patterns in HTML_PROCESSOR
- Removed eveai_chat from docker-files and nginx configuration, as it is now obsolete
- BUG: error in Library Operations when creating a new default RAG library
- BUG: Added public type in migration scripts
- Removed SocketIO from all code and requirements.txt
common/eveai_model/__init__.py (new file, empty)
common/eveai_model/eveai_embedding_base.py (new file, 11 additions)
@@ -0,0 +1,11 @@
+from abc import abstractmethod
+from typing import List
+
+
+class EveAIEmbeddings:
+    @abstractmethod
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        pass
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
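The base class gives subclasses embed_query for free once embed_documents is implemented. A minimal sketch (FakeEmbeddings is a hypothetical test double, not part of this commit):

    from typing import List

    from common.eveai_model.eveai_embedding_base import EveAIEmbeddings

    class FakeEmbeddings(EveAIEmbeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            # One fixed-size zero vector per input text
            return [[0.0] * 1024 for _ in texts]

    assert FakeEmbeddings().embed_query("hello") == [0.0] * 1024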
common/eveai_model/tracked_mistral_embeddings.py (new file, 40 additions)
@@ -0,0 +1,40 @@
+from flask import current_app
+from langchain_mistralai import MistralAIEmbeddings
+from typing import List, Any
+import time
+
+from common.eveai_model.eveai_embedding_base import EveAIEmbeddings
+from common.utils.business_event_context import current_event
+from mistralai import Mistral
+
+
+class TrackedMistralAIEmbeddings(EveAIEmbeddings):
+    def __init__(self, model: str = "mistral-embed"):
+        api_key = current_app.config['MISTRAL_API_KEY']
+        self.client = Mistral(
+            api_key=api_key
+        )
+        self.model = model
+        super().__init__()
+
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        start_time = time.time()
+        result = self.client.embeddings.create(
+            model=self.model,
+            inputs=texts
+        )
+        end_time = time.time()
+
+        metrics = {
+            'total_tokens': result.usage.total_tokens,
+            'prompt_tokens': result.usage.prompt_tokens,  # For embeddings, all tokens are prompt tokens
+            'completion_tokens': result.usage.completion_tokens,
+            'time_elapsed': end_time - start_time,
+            'interaction_type': 'Embedding',
+        }
+        current_event.log_llm_metrics(metrics)
+
+        embeddings = [embedding.embedding for embedding in result.data]
+
+        return embeddings
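A usage sketch for the tracked wrapper, assuming an initialized Flask app with MISTRAL_API_KEY in its config and an active business event context for current_event.log_llm_metrics (the app setup shown is illustrative, not part of this commit):

    from flask import Flask

    from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbeddings

    app = Flask(__name__)
    app.config['MISTRAL_API_KEY'] = '...'  # assumption: injected from the environment

    with app.app_context():
        embedder = TrackedMistralAIEmbeddings(model="mistral-embed")
        vectors = embedder.embed_documents(["first chunk", "second chunk"])  # logs token/latency metrics
        query_vector = embedder.embed_query("search text")  # inherited from EveAIEmbeddings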
@@ -5,7 +5,6 @@ from flask_security import Security
 from flask_mailman import Mail
 from flask_login import LoginManager
 from flask_cors import CORS
-from flask_socketio import SocketIO
 from flask_jwt_extended import JWTManager
 from flask_session import Session
 from flask_wtf import CSRFProtect
@@ -27,7 +26,6 @@ security = Security()
 mail = Mail()
 login_manager = LoginManager()
 cors = CORS()
-socketio = SocketIO()
 jwt = JWTManager()
 session = Session()
 api_rest = Api()
@@ -12,8 +12,10 @@ class Catalog(db.Model):
     description = db.Column(db.Text, nullable=True)
     type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
 
-    min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
-    max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
+    embedding_model = db.Column(db.String(50), nullable=True)
+
+    min_chunk_size = db.Column(db.Integer, nullable=True, default=1500)
+    max_chunk_size = db.Column(db.Integer, nullable=True, default=2500)
 
     # Meta Data
     user_metadata = db.Column(JSONB, nullable=True)
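With this hunk, embedding and chunking settings live on the catalog instead of the tenant. A hedged sketch of the per-catalog lookup (catalog_id and the tenant_id attribute on Catalog are assumptions; get_embedding_model_and_class is added later in this commit):

    catalog = Catalog.query.get(catalog_id)
    min_size = catalog.min_chunk_size or 1500   # new defaults: 1500/2500 instead of 2000/3000
    model, storage_cls = get_embedding_model_and_class(
        catalog.tenant_id, catalog.id, catalog.embedding_model  # e.g. "mistral.mistral-embed"
    )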
@@ -31,7 +31,6 @@ class Tenant(db.Model):
     allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
 
     # LLM specific choices
-    embedding_model = db.Column(db.String(50), nullable=True)
     llm_model = db.Column(db.String(50), nullable=True)
 
     # Entitlements
@@ -66,7 +65,6 @@ class Tenant(db.Model):
             'type': self.type,
             'default_language': self.default_language,
             'allowed_languages': self.allowed_languages,
-            'embedding_model': self.embedding_model,
             'llm_model': self.llm_model,
             'currency': self.currency,
         }
@@ -652,12 +652,15 @@ def json_to_patterns(json_content: str) -> str:
 def json_to_pattern_list(json_content: str) -> list:
     """Convert JSON patterns list to text area content"""
     try:
-        patterns = json.loads(json_content)
-        if not isinstance(patterns, list):
-            raise ValueError("JSON must contain a list of patterns")
-        # Unescape if needed
-        patterns = [pattern.replace('\\\\', '\\') for pattern in patterns]
-        return patterns
+        if json_content:
+            patterns = json.loads(json_content)
+            if not isinstance(patterns, list):
+                raise ValueError("JSON must contain a list of patterns")
+            # Unescape if needed
+            patterns = [pattern.replace('\\\\', '\\') for pattern in patterns]
+            return patterns
+        else:
+            return []
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON format: {e}")
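The new guard changes the empty-input behavior that broke chunking-pattern configuration in HTML_PROCESSOR: json.loads('') used to raise a JSONDecodeError that surfaced as ValueError, while an empty configuration now yields an empty pattern list. A quick sketch (the pattern string is illustrative):

    json_to_pattern_list('["<script.*?>"]')  # -> ['<script.*?>']
    json_to_pattern_list('')                 # -> [] (previously: ValueError('Invalid JSON format: ...'))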
@@ -124,4 +124,15 @@ class EveAISocketInputException(EveAIException):
     """Raised when a socket call receives an invalid payload"""
 
     def __init__(self, message, status_code=400, payload=None):
         super().__init__(message, status_code, payload)
+
+
+class EveAIInvalidEmbeddingModel(EveAIException):
+    """Raised when no or an invalid embedding model is provided in the catalog"""
+
+    def __init__(self, tenant_id, catalog_id, status_code=400, payload=None):
+        self.tenant_id = tenant_id
+        self.catalog_id = catalog_id
+        # Construct the message dynamically
+        message = f"Tenant with ID '{tenant_id}' has no or an invalid embedding model in Catalog {catalog_id}."
+        super().__init__(message, status_code, payload)
@@ -1,23 +1,25 @@
 import os
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, Tuple
 
 import langcodes
+from langchain_core.language_models import BaseChatModel
 
 from common.langchain.llm_metrics_handler import LLMMetricsHandler
 from common.langchain.templates.template_manager import TemplateManager
-from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
+from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
+from langchain_mistralai import ChatMistralAI
 from flask import current_app
 from datetime import datetime as dt, timezone as tz
 
-from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
+from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbeddings
 from common.langchain.tracked_transcription import TrackedOpenAITranscription
 from common.models.user import Tenant
-from common.utils.cache.base import CacheHandler
 from config.model_config import MODEL_CONFIG
-from common.extensions import template_manager, cache_manager
-from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
-from common.utils.eveai_exceptions import EveAITenantNotFound
+from common.extensions import template_manager
+from common.models.document import EmbeddingMistral
+from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
 
+llm_model_cache: Dict[Tuple[str, float], BaseChatModel] = {}
 llm_metrics_handler = LLMMetricsHandler()
 
 
 def create_language_template(template: str, language: str) -> str:
@@ -55,6 +57,63 @@ def replace_variable_in_template(template: str, variable: str, value: str) -> str:
     return template.replace(variable, value or "")
 
 
+def get_embedding_model_and_class(tenant_id, catalog_id, full_embedding_name):
+    """
+    Retrieve the embedding model and embedding model class to store Embeddings
+
+    Args:
+        tenant_id: ID of the tenant
+        catalog_id: ID of the catalog
+        full_embedding_name: The full name of the embedding model: <provider>.<model>
+
+    Returns:
+        embedding_model, embedding_model_class
+    """
+    embedding_provider, embedding_model_name = full_embedding_name.split('.')
+
+    # Calculate the embedding model to be used
+    if embedding_provider == "mistral":
+        api_key = current_app.config['MISTRAL_API_KEY']
+        embedding_model = TrackedMistralAIEmbeddings(
+            model=embedding_model_name
+        )
+    else:
+        raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id)
+
+    # Calculate the Embedding Model Class to be used to store embeddings
+    if embedding_model_name == "mistral-embed":
+        embedding_model_class = EmbeddingMistral
+    else:
+        raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id)
+
+    return embedding_model, embedding_model_class
+
+
+def get_llm(full_model_name, temperature):
+    if not full_model_name:
+        full_model_name = 'openai.gpt-4o'  # Default to gpt-4o for now, as this is the original model developed against
+
+    llm = llm_model_cache.get((full_model_name, temperature))
+    if not llm:
+        llm_provider, llm_model_name = full_model_name.split('.')
+        if llm_provider == "openai":
+            llm = ChatOpenAI(
+                api_key=current_app.config['OPENAI_API_KEY'],
+                model=llm_model_name,
+                temperature=temperature,
+                callbacks=[llm_metrics_handler]
+            )
+        elif llm_provider == "mistral":
+            llm = ChatMistralAI(
+                api_key=current_app.config['MISTRAL_API_KEY'],
+                model=llm_model_name,
+                temperature=temperature,
+                callbacks=[llm_metrics_handler]
+            )
+
+        llm_model_cache[(full_model_name, temperature)] = llm
+
+    return llm
+
 
 class ModelVariables:
     """Manages model-related variables and configurations"""
 
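A hedged illustration of the two new helpers (IDs and model names are examples only):

    model, storage_cls = get_embedding_model_and_class(
        tenant_id=1,
        catalog_id=2,
        full_embedding_name="mistral.mistral-embed",
    )
    # -> (TrackedMistralAIEmbeddings instance, EmbeddingMistral); any other
    #    provider or model raises EveAIInvalidEmbeddingModel(tenant_id, catalog_id).

    llm = get_llm("mistral.mistral-large-latest", temperature=0.2)
    assert llm is get_llm("mistral.mistral-large-latest", temperature=0.2)  # cached per (name, temperature)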
@@ -63,15 +122,13 @@ class ModelVariables:
         Initialize ModelVariables with tenant and optional template manager
 
         Args:
-            tenant: Tenant instance
-            template_manager: Optional TemplateManager instance
+            tenant_id: Tenant instance
+            variables: Optional variables
         """
+        current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
         self.tenant_id = tenant_id
         self._variables = variables if variables is not None else self._initialize_variables()
-        self._embedding_model = None
-        self._embedding_model_class = None
-        self._llm_instances = {}
-        self.llm_metrics_handler = LLMMetricsHandler()
+        current_app.logger.info(f'Model _variables initialized to {self._variables}')
         self._transcription_model = None
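The constructor now takes a tenant_id and an optional pre-resolved variables dict instead of a Tenant and a TemplateManager. A sketch under that assumption (the def line itself sits outside this hunk, so the exact signature is inferred):

    mv = ModelVariables(tenant_id=1)  # resolves variables via _initialize_variables()
    mv = ModelVariables(tenant_id=1, variables={'llm_model': 'mistral-large-latest'})  # pre-resolved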
@@ -85,7 +142,6 @@ class ModelVariables:
             raise EveAITenantNotFound(self.tenant_id)
 
         # Set model providers
-        variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
         variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
         variables['llm_full_model'] = tenant.llm_model
 
@@ -102,28 +158,6 @@ class ModelVariables:
 
         return variables
 
-    @property
-    def embedding_model(self):
-        """Get the embedding model instance"""
-        if self._embedding_model is None:
-            api_key = os.getenv('OPENAI_API_KEY')
-            self._embedding_model = TrackedOpenAIEmbeddings(
-                api_key=api_key,
-                model=self._variables['embedding_model']
-            )
-        return self._embedding_model
-
-    @property
-    def embedding_model_class(self):
-        """Get the embedding model class"""
-        if self._embedding_model_class is None:
-            if self._variables['embedding_model'] == 'text-embedding-3-large':
-                self._embedding_model_class = EmbeddingLargeOpenAI
-            else:  # text-embedding-3-small
-                self._embedding_model_class = EmbeddingSmallOpenAI
-
-            return self._embedding_model_class
-
     @property
     def annotation_chunk_length(self):
         return self._variables['annotation_chunk_length']
@@ -13,14 +13,12 @@ def set_tenant_session_data(sender, user, **kwargs):
     tenant = Tenant.query.filter_by(id=user.tenant_id).first()
     session['tenant'] = tenant.to_dict()
     session['default_language'] = tenant.default_language
-    session['default_embedding_model'] = tenant.embedding_model
     session['default_llm_model'] = tenant.llm_model
 
 
 def clear_tenant_session_data(sender, user, **kwargs):
     session.pop('tenant', None)
     session.pop('default_language', None)
-    session.pop('default_embedding_model', None)
     session.pop('default_llm_model', None)