- Move from OpenAI to Mistral Embeddings

- Move embedding model settings from tenant to catalog
- BUG: error processing configuration for chunking patterns in HTML_PROCESSOR
- Removed eveai_chat from docker-files and nginx configuration, as it is now obsolete
- BUG: error in Library Operations when creating a new default RAG library
- BUG: Added public type in migration scripts
- Removed SocketIO from all code and requirements.txt
This commit is contained in:
Josako
2025-02-25 11:17:19 +01:00
parent c037d4135e
commit 55a89c11bb
34 changed files with 457 additions and 444 deletions

View File

View File

@@ -0,0 +1,11 @@
from abc import abstractmethod
from typing import List
class EveAIEmbeddings:
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
pass
def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]

View File

@@ -0,0 +1,40 @@
from flask import current_app
from langchain_mistralai import MistralAIEmbeddings
from typing import List, Any
import time
from common.eveai_model.eveai_embedding_base import EveAIEmbeddings
from common.utils.business_event_context import current_event
from mistralai import Mistral
class TrackedMistralAIEmbeddings(EveAIEmbeddings):
def __init__(self, model: str = "mistral_embed"):
api_key = current_app.config['MISTRAL_API_KEY']
self.client = Mistral(
api_key=api_key
)
self.model = model
super().__init__()
def embed_documents(self, texts: list[str]) -> list[list[float]]:
start_time = time.time()
result = self.client.embeddings.create(
model=self.model,
inputs=texts
)
end_time = time.time()
metrics = {
'total_tokens': result.usage.total_tokens,
'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
'completion_tokens': result.usage.completion_tokens,
'time_elapsed': end_time - start_time,
'interaction_type': 'Embedding',
}
current_event.log_llm_metrics(metrics)
embeddings = [embedding.embedding for embedding in result.data]
return embeddings

View File

@@ -5,7 +5,6 @@ from flask_security import Security
from flask_mailman import Mail
from flask_login import LoginManager
from flask_cors import CORS
from flask_socketio import SocketIO
from flask_jwt_extended import JWTManager
from flask_session import Session
from flask_wtf import CSRFProtect
@@ -27,7 +26,6 @@ security = Security()
mail = Mail()
login_manager = LoginManager()
cors = CORS()
socketio = SocketIO()
jwt = JWTManager()
session = Session()
api_rest = Api()

View File

@@ -12,8 +12,10 @@ class Catalog(db.Model):
description = db.Column(db.Text, nullable=True)
type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
embedding_model = db.Column(db.String(50), nullable=True)
min_chunk_size = db.Column(db.Integer, nullable=True, default=1500)
max_chunk_size = db.Column(db.Integer, nullable=True, default=2500)
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)

View File

@@ -31,7 +31,6 @@ class Tenant(db.Model):
allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
# LLM specific choices
embedding_model = db.Column(db.String(50), nullable=True)
llm_model = db.Column(db.String(50), nullable=True)
# Entitlements
@@ -66,7 +65,6 @@ class Tenant(db.Model):
'type': self.type,
'default_language': self.default_language,
'allowed_languages': self.allowed_languages,
'embedding_model': self.embedding_model,
'llm_model': self.llm_model,
'currency': self.currency,
}

View File

@@ -652,12 +652,15 @@ def json_to_patterns(json_content: str) -> str:
def json_to_pattern_list(json_content: str) -> list:
"""Convert JSON patterns list to text area content"""
try:
patterns = json.loads(json_content)
if not isinstance(patterns, list):
raise ValueError("JSON must contain a list of patterns")
# Unescape if needed
patterns = [pattern.replace('\\\\', '\\') for pattern in patterns]
return patterns
if json_content:
patterns = json.loads(json_content)
if not isinstance(patterns, list):
raise ValueError("JSON must contain a list of patterns")
# Unescape if needed
patterns = [pattern.replace('\\\\', '\\') for pattern in patterns]
return patterns
else:
return []
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e}")

View File

@@ -124,4 +124,15 @@ class EveAISocketInputException(EveAIException):
"""Raised when a socket call receives an invalid payload"""
def __init__(self, message, status_code=400, payload=None):
super.__init__(message, status_code, payload)
super.__init__(message, status_code, payload)
class EveAIInvalidEmbeddingModel(EveAIException):
"""Raised when no or an invalid embedding model is provided in the catalog"""
def __init__(self, tenant_id, catalog_id, status_code=400, payload=None):
self.tenant_id = tenant_id
self.catalog_id = catalog_id
# Construct the message dynamically
message = f"Tenant with ID '{tenant_id}' has no or an invalid embedding model in Catalog {catalog_id}."
super().__init__(message, status_code, payload)

View File

@@ -1,23 +1,25 @@
import os
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Tuple
import langcodes
from langchain_core.language_models import BaseChatModel
from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.templates.template_manager import TemplateManager
from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_mistralai import ChatMistralAI
from flask import current_app
from datetime import datetime as dt, timezone as tz
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbeddings
from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant
from common.utils.cache.base import CacheHandler
from config.model_config import MODEL_CONFIG
from common.extensions import template_manager, cache_manager
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
from common.utils.eveai_exceptions import EveAITenantNotFound
from common.extensions import template_manager
from common.models.document import EmbeddingMistral
from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
llm_model_cache: Dict[Tuple[str, float], BaseChatModel] = {}
llm_metrics_handler = LLMMetricsHandler()
def create_language_template(template: str, language: str) -> str:
@@ -55,6 +57,63 @@ def replace_variable_in_template(template: str, variable: str, value: str) -> st
return template.replace(variable, value or "")
def get_embedding_model_and_class(tenant_id, catalog_id, full_embedding_name):
"""
Retrieve the embedding model and embedding model class to store Embeddings
Args:
tenant_id: ID of the tenant
catalog_id: ID of the catalog
full_embedding_name: The full name of the embedding model: <provider>.<model>
Returns:
embedding_model, embedding_model_class
"""
embedding_provider, embedding_model_name = full_embedding_name.split('.')
# Calculate the embedding model to be used
if embedding_provider == "mistral":
api_key = current_app.config['MISTRAL_API_KEY']
embedding_model = TrackedMistralAIEmbeddings(
model=embedding_model_name
)
else:
raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id)
# Calculate the Embedding Model Class to be used to store embeddings
if embedding_model_name == "mistral-embed":
embedding_model_class = EmbeddingMistral
else:
raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id)
return embedding_model, embedding_model_class
def get_llm(full_model_name, temperature):
if not full_model_name:
full_model_name = 'openai.gpt-4o' # Default to gpt-4o for now, as this is the original model developed against
llm = llm_model_cache.get((full_model_name, temperature))
if not llm:
llm_provider, llm_model_name = full_model_name.split('.')
if llm_provider == "openai":
llm = ChatOpenAI(
api_key=current_app.config['OPENAI_API_KEY'],
model=llm_model_name,
temperature=temperature,
callbacks=[llm_metrics_handler]
)
elif llm_provider == "mistral":
llm = ChatMistralAI(
api_key=current_app.config['MISTRAL_API_KEY'],
model=llm_model_name,
temperature=temperature,
callbacks=[llm_metrics_handler]
)
llm_model_cache[(full_model_name, temperature)] = llm
class ModelVariables:
"""Manages model-related variables and configurations"""
@@ -63,15 +122,13 @@ class ModelVariables:
Initialize ModelVariables with tenant and optional template manager
Args:
tenant: Tenant instance
template_manager: Optional TemplateManager instance
tenant_id: Tenant instance
variables: Optional variables
"""
current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
self.tenant_id = tenant_id
self._variables = variables if variables is not None else self._initialize_variables()
current_app.logger.info(f'Model _variables initialized to {self._variables}')
self._embedding_model = None
self._embedding_model_class = None
self._llm_instances = {}
self.llm_metrics_handler = LLMMetricsHandler()
self._transcription_model = None
@@ -85,7 +142,6 @@ class ModelVariables:
raise EveAITenantNotFound(self.tenant_id)
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
variables['llm_full_model'] = tenant.llm_model
@@ -102,28 +158,6 @@ class ModelVariables:
return variables
@property
def embedding_model(self):
"""Get the embedding model instance"""
if self._embedding_model is None:
api_key = os.getenv('OPENAI_API_KEY')
self._embedding_model = TrackedOpenAIEmbeddings(
api_key=api_key,
model=self._variables['embedding_model']
)
return self._embedding_model
@property
def embedding_model_class(self):
"""Get the embedding model class"""
if self._embedding_model_class is None:
if self._variables['embedding_model'] == 'text-embedding-3-large':
self._embedding_model_class = EmbeddingLargeOpenAI
else: # text-embedding-3-small
self._embedding_model_class = EmbeddingSmallOpenAI
return self._embedding_model_class
@property
def annotation_chunk_length(self):
return self._variables['annotation_chunk_length']

View File

@@ -13,14 +13,12 @@ def set_tenant_session_data(sender, user, **kwargs):
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
session['tenant'] = tenant.to_dict()
session['default_language'] = tenant.default_language
session['default_embedding_model'] = tenant.embedding_model
session['default_llm_model'] = tenant.llm_model
def clear_tenant_session_data(sender, user, **kwargs):
session.pop('tenant', None)
session.pop('default_language', None)
session.pop('default_embedding_model', None)
session.pop('default_llm_model', None)