from flask import current_app from langchain_mistralai import MistralAIEmbeddings from typing import List, Any import time from common.eveai_model.eveai_embedding_base import EveAIEmbeddings from common.utils.business_event_context import current_event from mistralai import Mistral class TrackedMistralAIEmbeddings(EveAIEmbeddings): def __init__(self, model: str = "mistral_embed", batch_size: int = 10): """ Initialize the TrackedMistralAIEmbeddings class. Args: model: The embedding model to use batch_size: Maximum number of texts to send in a single API call """ api_key = current_app.config['MISTRAL_API_KEY'] self.client = Mistral( api_key=api_key ) self.model = model self.batch_size = batch_size super().__init__() def embed_documents(self, texts: list[str]) -> list[list[float]]: """ Embed a list of texts, processing in batches to avoid API limitations. Args: texts: A list of texts to embed Returns: A list of embeddings, one for each input text """ if not texts: return [] all_embeddings = [] # Process texts in batches for i in range(0, len(texts), self.batch_size): batch = texts[i:i + self.batch_size] batch_num = i // self.batch_size + 1 current_app.logger.debug(f"Processing embedding batch {batch_num}, size: {len(batch)}") start_time = time.time() try: result = self.client.embeddings.create( model=self.model, inputs=batch ) end_time = time.time() batch_time = end_time - start_time batch_embeddings = [embedding.embedding for embedding in result.data] all_embeddings.extend(batch_embeddings) # Log metrics for this batch metrics = { 'total_tokens': result.usage.total_tokens, 'prompt_tokens': result.usage.prompt_tokens, 'completion_tokens': result.usage.completion_tokens, 'time_elapsed': batch_time, 'interaction_type': 'Embedding', 'batch': batch_num, 'batch_size': len(batch) } current_event.log_llm_metrics(metrics) current_app.logger.debug(f"Batch {batch_num} processed: {len(batch)} texts, " f"{result.usage.total_tokens} tokens, {batch_time:.2f}s") # If processing multiple batches, add a small delay to avoid rate limits if len(texts) > self.batch_size and i + self.batch_size < len(texts): time.sleep(0.25) # 250ms pause between batches except Exception as e: current_app.logger.error(f"Error in embedding batch {batch_num}: {str(e)}") # If a batch fails, try to process each text individually for j, text in enumerate(batch): try: current_app.logger.debug(f"Attempting individual embedding for item {i + j}") single_start_time = time.time() single_result = self.client.embeddings.create( model=self.model, inputs=[text] ) single_end_time = time.time() # Add the single embedding single_embedding = single_result.data[0].embedding all_embeddings.append(single_embedding) # Log metrics for this individual embedding single_metrics = { 'total_tokens': single_result.usage.total_tokens, 'prompt_tokens': single_result.usage.prompt_tokens, 'completion_tokens': single_result.usage.completion_tokens, 'time_elapsed': single_end_time - single_start_time, 'interaction_type': 'Embedding', 'batch': f"{batch_num}-recovery-{j}", 'batch_size': 1 } current_event.log_llm_metrics(single_metrics) except Exception as inner_e: current_app.logger.error(f"Failed to embed individual text at index {i + j}: {str(inner_e)}") # Add a zero vector as a placeholder for failed embeddings # Use the correct dimensionality for the model (1024 for mistral_embed) embedding_dim = 1024 all_embeddings.append([0.0] * embedding_dim) total_batches = (len(texts) + self.batch_size - 1) // self.batch_size current_app.logger.info(f"Embedded {len(texts)} texts in {total_batches} batches") return all_embeddings # def embed_documents(self, texts: list[str]) -> list[list[float]]: # start_time = time.time() # result = self.client.embeddings.create( # model=self.model, # inputs=texts # ) # end_time = time.time() # # metrics = { # 'total_tokens': result.usage.total_tokens, # 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens # 'completion_tokens': result.usage.completion_tokens, # 'time_elapsed': end_time - start_time, # 'interaction_type': 'Embedding', # } # current_event.log_llm_metrics(metrics) # # embeddings = [embedding.embedding for embedding in result.data] # # return embeddings