Files
eveAI/common/eveai_model/tracked_mistral_embeddings.py
Josako 51fd16bcc6 - RAG Specialist fully implemented new style
- Selection Specialist - VA version - fully implemented
- Correction of TRAICIE_ROLE_DEFINITION_SPECIALIST - adaptation to new style
- Removal of 'debug' statements
2025-07-10 10:39:42 +02:00

137 lines
5.4 KiB
Python

from flask import current_app
from langchain_mistralai import MistralAIEmbeddings
from typing import List, Any
import time
from common.eveai_model.eveai_embedding_base import EveAIEmbeddings
from common.utils.business_event_context import current_event
from mistralai import Mistral
class TrackedMistralAIEmbeddings(EveAIEmbeddings):
    """Mistral embedding client that reports usage metrics for every API call.

    Texts are sent to the Mistral embeddings endpoint in batches. For each
    call the token usage and elapsed time are forwarded to
    ``current_event.log_llm_metrics``. When a batch call fails, its texts are
    retried one by one; texts that still fail are represented by a zero
    vector so the result always has exactly one entry per input text, in
    input order.
    """

    # Width of the zero-vector placeholder when no successful embedding is
    # available to infer it from (1024 for mistral_embed).
    DEFAULT_EMBEDDING_DIM = 1024

    # Pause between consecutive batches to avoid rate limits (250ms).
    BATCH_PAUSE_SECONDS = 0.25

    def __init__(self, model: str = "mistral_embed", batch_size: int = 10):
        """
        Initialize the TrackedMistralAIEmbeddings class.

        Reads ``MISTRAL_API_KEY`` from ``current_app.config``, so it must be
        called inside a Flask application context.

        Args:
            model: The embedding model to use
            batch_size: Maximum number of texts to send in a single API call

        Raises:
            ValueError: If batch_size is smaller than 1.
            KeyError: If ``MISTRAL_API_KEY`` is not present in the app config.
        """
        # NOTE(review): the default model name uses an underscore; confirm it
        # matches the identifier the Mistral API expects for this deployment.
        if batch_size < 1:
            # A non-positive batch size would otherwise surface later as an
            # opaque range() error (0) or a silently empty result (negative).
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        api_key = current_app.config['MISTRAL_API_KEY']
        self.client = Mistral(
            api_key=api_key
        )
        self.model = model
        self.batch_size = batch_size
        super().__init__()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """
        Embed a list of texts, processing in batches to avoid API limitations.

        Args:
            texts: A list of texts to embed

        Returns:
            A list of embeddings, one for each input text and in the same
            order. Texts that could not be embedded, even individually, are
            returned as zero vectors.
        """
        if not texts:
            return []

        total = len(texts)
        # Entries are embeddings, or None for texts that could not be embedded.
        collected = []

        for batch_num, offset in enumerate(range(0, total, self.batch_size), start=1):
            batch = texts[offset:offset + self.batch_size]
            try:
                # Only the API round trip lives inside the try: nothing is
                # appended to the result until the whole batch has succeeded,
                # so the recovery path can never duplicate embeddings.
                collected.extend(self._embed_batch(batch, batch_num))
            except Exception as e:
                current_app.logger.error(f"Error in embedding batch {batch_num}: {str(e)}")
                # If a batch fails, try to process each text individually
                collected.extend(self._embed_individually(batch, batch_num, offset))

            # Pause only when another batch follows.
            if offset + self.batch_size < total:
                time.sleep(self.BATCH_PAUSE_SECONDS)

        # Size the placeholders after any vector that did succeed, so they
        # match the model actually in use; fall back to the default otherwise.
        embedding_dim = next(
            (len(vector) for vector in collected if vector is not None),
            self.DEFAULT_EMBEDDING_DIM,
        )
        all_embeddings = [
            vector if vector is not None else [0.0] * embedding_dim
            for vector in collected
        ]

        total_batches = (total + self.batch_size - 1) // self.batch_size
        current_app.logger.info(f"Embedded {len(texts)} texts in {total_batches} batches")
        return all_embeddings

    def _embed_batch(self, batch: list[str], batch_label) -> list[list[float]]:
        """Embed one batch with a single API call and log its usage metrics.

        Raises:
            ValueError: If the response does not hold one embedding per input.
            Exception: Whatever the Mistral client raises for a failed call.
        """
        start_time = time.perf_counter()
        result = self.client.embeddings.create(
            model=self.model,
            inputs=batch
        )
        elapsed = time.perf_counter() - start_time

        # Usage is logged before validation: the call was made either way.
        self._log_usage(result, elapsed, batch_label, len(batch))

        embeddings = [item.embedding for item in result.data]
        if len(embeddings) != len(batch):
            raise ValueError(
                f"Expected {len(batch)} embeddings but received {len(embeddings)}"
            )
        return embeddings

    def _embed_individually(self, batch: list[str], batch_num: int, offset: int) -> list:
        """Retry each text of a failed batch on its own.

        Returns one entry per text: its embedding, or None when that text
        could not be embedded.
        """
        recovered = []
        for j, text in enumerate(batch):
            try:
                recovered.append(
                    self._embed_batch([text], f"{batch_num}-recovery-{j}")[0]
                )
            except Exception as inner_e:
                current_app.logger.error(
                    f"Failed to embed individual text at index {offset + j}: {str(inner_e)}"
                )
                recovered.append(None)
        return recovered

    def _log_usage(self, result, elapsed: float, batch_label, batch_size: int) -> None:
        """Report token usage and latency of one API call (best effort).

        A failure while building or logging the metrics is written to the
        application log and otherwise ignored, so it cannot discard or
        duplicate embeddings that were already retrieved.
        """
        try:
            usage = result.usage
            metrics = {
                'total_tokens': usage.total_tokens,
                'prompt_tokens': usage.prompt_tokens,
                'completion_tokens': usage.completion_tokens,
                'time_elapsed': elapsed,
                'interaction_type': 'Embedding',
                'batch': batch_label,
                'batch_size': batch_size
            }
            current_event.log_llm_metrics(metrics)
        except Exception as e:
            current_app.logger.error(
                f"Failed to log embedding metrics for batch {batch_label}: {str(e)}"
            )
# def embed_documents(self, texts: list[str]) -> list[list[float]]:
# start_time = time.time()
# result = self.client.embeddings.create(
# model=self.model,
# inputs=texts
# )
# end_time = time.time()
#
# metrics = {
# 'total_tokens': result.usage.total_tokens,
# 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
# 'completion_tokens': result.usage.completion_tokens,
# 'time_elapsed': end_time - start_time,
# 'interaction_type': 'Embedding',
# }
# current_event.log_llm_metrics(metrics)
#
# embeddings = [embedding.embedding for embedding in result.data]
#
# return embeddings