- Move to Mistral instead of OpenAI as primary choice

Josako
2025-03-06 14:19:35 +01:00
parent 55a89c11bb
commit c15cabc289
11 changed files with 74 additions and 36 deletions

@@ -17,8 +17,10 @@ from config.model_config import MODEL_CONFIG
 from common.extensions import template_manager
 from common.models.document import EmbeddingMistral
 from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
+from crewai import LLM

-llm_model_cache: Dict[Tuple[str, float], BaseChatModel] = {}
+embedding_llm_model_cache: Dict[Tuple[str, float], BaseChatModel] = {}
+crewai_llm_model_cache: Dict[Tuple[str, float], LLM] = {}

 llm_metrics_handler = LLMMetricsHandler()
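The two new caches follow the pattern of the existing llm_model_cache: one client instance per (full_model_name, temperature) key, so repeated lookups with the same settings reuse a single object instead of constructing a fresh client on every call. A minimal, self-contained sketch of that memoization pattern (FakeLLM is a hypothetical stand-in for the real model classes):

from typing import Dict, Tuple

class FakeLLM:
    """Hypothetical stand-in for the chat-model classes used in the diff."""
    def __init__(self, model: str, temperature: float):
        self.model = model
        self.temperature = temperature

_llm_cache: Dict[Tuple[str, float], FakeLLM] = {}

def get_cached_llm(full_model_name: str, temperature: float) -> FakeLLM:
    # Key on the (name, temperature) pair: a different temperature is a
    # separate cache entry, never a mutation of an existing client.
    llm = _llm_cache.get((full_model_name, temperature))
    if not llm:
        llm = FakeLLM(full_model_name, temperature)
        _llm_cache[(full_model_name, temperature)] = llm
    return llm

assert get_cached_llm('mistral.mistral-small-latest', 0.3) is \
       get_cached_llm('mistral.mistral-small-latest', 0.3)

One caveat of module-level caches like these: entries live for the lifetime of the process, which is fine for a bounded set of (model, temperature) combinations but worth keeping in mind if callers can pass arbitrary values.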
@@ -89,11 +91,8 @@ def get_embedding_model_and_class(tenant_id, catalog_id, full_embedding_name):
     return embedding_model, embedding_model_class

-def get_llm(full_model_name, temperature):
-    if not full_model_name:
-        full_model_name = 'openai.gpt-4o'  # Default to gpt-4o for now, as this is the original model developed against
-    llm = llm_model_cache.get((full_model_name, temperature))
+def get_embedding_llm(full_model_name='mistral.mistral-small-latest', temperature=0.3):
+    llm = embedding_llm_model_cache.get((full_model_name, temperature))
     if not llm:
         llm_provider, llm_model_name = full_model_name.split('.')
         if llm_provider == "openai":
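Throughout the module, full_model_name uses a provider.model convention that is split on the dot to choose the API key and client class. A small illustration of that parsing (the maxsplit=1 variant is a defensive assumption on my part; the diff's bare split('.') works because none of these model names contain further dots):

from typing import Tuple

def split_model_name(full_model_name: str) -> Tuple[str, str]:
    # 'mistral.mistral-small-latest' -> ('mistral', 'mistral-small-latest')
    # maxsplit=1 tolerates dots inside the model name (e.g. a hypothetical
    # 'openai.gpt-3.5-turbo'), which a bare split('.') would fail to unpack.
    provider, model_name = full_model_name.split('.', 1)
    return provider, model_name

print(split_model_name('mistral.mistral-small-latest'))  # ('mistral', 'mistral-small-latest')
print(split_model_name('openai.gpt-4o'))                 # ('openai', 'gpt-4o')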
@@ -110,8 +109,30 @@ def get_llm(full_model_name, temperature):
             temperature=temperature,
             callbacks=[llm_metrics_handler]
         )
-        llm_model_cache[(full_model_name, temperature)] = llm
+        embedding_llm_model_cache[(full_model_name, temperature)] = llm
     return llm

+def get_crewai_llm(full_model_name='mistral.mistral-large-latest', temperature=0.3):
+    llm = crewai_llm_model_cache.get((full_model_name, temperature))
+    if not llm:
+        llm_provider, llm_model_name = full_model_name.split('.')
+        crew_full_model_name = f"{llm_provider}/{llm_model_name}"
+        api_key = None
+        if llm_provider == "openai":
+            api_key = current_app.config['OPENAI_API_KEY']
+        elif llm_provider == "mistral":
+            api_key = current_app.config['MISTRAL_API_KEY']
+        llm = LLM(
+            model=crew_full_model_name,
+            temperature=temperature,
+            api_key=api_key
+        )
+        crewai_llm_model_cache[(full_model_name, temperature)] = llm
+    return llm

 class ModelVariables:
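In get_crewai_llm the dot separator becomes a slash: crew_full_model_name builds the provider/model string that crewai's LLM wrapper (which routes through LiteLLM) expects. A usage sketch under stated assumptions: the diffed module is importable as model_utils (a hypothetical name) and a Flask application context supplies the two API keys:

# Assumes a hypothetical module name and an active Flask app context whose
# config defines MISTRAL_API_KEY and OPENAI_API_KEY, per the diff above.
from model_utils import get_crewai_llm

default_llm = get_crewai_llm()  # mistral.mistral-large-latest at temperature 0.3
assert default_llm is get_crewai_llm('mistral.mistral-large-latest', 0.3)  # cache hit

gpt = get_crewai_llm('openai.gpt-4o', temperature=0.0)  # routed to OPENAI_API_KEY

Note the fall-through in the diff: a provider other than openai or mistral leaves api_key as None, presumably deferring to whatever credentials the underlying client resolves from the environment.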