- Significantly changed the PDF Processor to use Mistral's OCR model
- ensure very long chunks get split into smaller chunks - ensure TrackedMistralAIEmbedding is batched if needed to ensure correct execution - upgraded some of the packages to a higher version
This commit is contained in:
@@ -9,32 +9,133 @@ from mistralai import Mistral
|
||||
|
||||
|
||||
class TrackedMistralAIEmbeddings(EveAIEmbeddings):
|
||||
def __init__(self, model: str = "mistral_embed"):
|
||||
def __init__(self, model: str = "mistral_embed", batch_size: int = 10):
|
||||
"""
|
||||
Initialize the TrackedMistralAIEmbeddings class.
|
||||
|
||||
Args:
|
||||
model: The embedding model to use
|
||||
batch_size: Maximum number of texts to send in a single API call
|
||||
"""
|
||||
api_key = current_app.config['MISTRAL_API_KEY']
|
||||
self.client = Mistral(
|
||||
api_key=api_key
|
||||
)
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
super().__init__()
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
start_time = time.time()
|
||||
result = self.client.embeddings.create(
|
||||
model=self.model,
|
||||
inputs=texts
|
||||
)
|
||||
end_time = time.time()
|
||||
"""
|
||||
Embed a list of texts, processing in batches to avoid API limitations.
|
||||
|
||||
metrics = {
|
||||
'total_tokens': result.usage.total_tokens,
|
||||
'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
|
||||
'completion_tokens': result.usage.completion_tokens,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'Embedding',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
Args:
|
||||
texts: A list of texts to embed
|
||||
|
||||
embeddings = [embedding.embedding for embedding in result.data]
|
||||
Returns:
|
||||
A list of embeddings, one for each input text
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
return embeddings
|
||||
all_embeddings = []
|
||||
|
||||
# Process texts in batches
|
||||
for i in range(0, len(texts), self.batch_size):
|
||||
batch = texts[i:i + self.batch_size]
|
||||
batch_num = i // self.batch_size + 1
|
||||
current_app.logger.debug(f"Processing embedding batch {batch_num}, size: {len(batch)}")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = self.client.embeddings.create(
|
||||
model=self.model,
|
||||
inputs=batch
|
||||
)
|
||||
end_time = time.time()
|
||||
batch_time = end_time - start_time
|
||||
|
||||
batch_embeddings = [embedding.embedding for embedding in result.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Log metrics for this batch
|
||||
metrics = {
|
||||
'total_tokens': result.usage.total_tokens,
|
||||
'prompt_tokens': result.usage.prompt_tokens,
|
||||
'completion_tokens': result.usage.completion_tokens,
|
||||
'time_elapsed': batch_time,
|
||||
'interaction_type': 'Embedding',
|
||||
'batch': batch_num,
|
||||
'batch_size': len(batch)
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
current_app.logger.debug(f"Batch {batch_num} processed: {len(batch)} texts, "
|
||||
f"{result.usage.total_tokens} tokens, {batch_time:.2f}s")
|
||||
|
||||
# If processing multiple batches, add a small delay to avoid rate limits
|
||||
if len(texts) > self.batch_size and i + self.batch_size < len(texts):
|
||||
time.sleep(0.25) # 250ms pause between batches
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error in embedding batch {batch_num}: {str(e)}")
|
||||
# If a batch fails, try to process each text individually
|
||||
for j, text in enumerate(batch):
|
||||
try:
|
||||
current_app.logger.debug(f"Attempting individual embedding for item {i + j}")
|
||||
single_start_time = time.time()
|
||||
single_result = self.client.embeddings.create(
|
||||
model=self.model,
|
||||
inputs=[text]
|
||||
)
|
||||
single_end_time = time.time()
|
||||
|
||||
# Add the single embedding
|
||||
single_embedding = single_result.data[0].embedding
|
||||
all_embeddings.append(single_embedding)
|
||||
|
||||
# Log metrics for this individual embedding
|
||||
single_metrics = {
|
||||
'total_tokens': single_result.usage.total_tokens,
|
||||
'prompt_tokens': single_result.usage.prompt_tokens,
|
||||
'completion_tokens': single_result.usage.completion_tokens,
|
||||
'time_elapsed': single_end_time - single_start_time,
|
||||
'interaction_type': 'Embedding',
|
||||
'batch': f"{batch_num}-recovery-{j}",
|
||||
'batch_size': 1
|
||||
}
|
||||
current_event.log_llm_metrics(single_metrics)
|
||||
|
||||
except Exception as inner_e:
|
||||
current_app.logger.error(f"Failed to embed individual text at index {i + j}: {str(inner_e)}")
|
||||
# Add a zero vector as a placeholder for failed embeddings
|
||||
# Use the correct dimensionality for the model (1024 for mistral_embed)
|
||||
embedding_dim = 1024
|
||||
all_embeddings.append([0.0] * embedding_dim)
|
||||
|
||||
total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
|
||||
current_app.logger.info(f"Embedded {len(texts)} texts in {total_batches} batches")
|
||||
|
||||
return all_embeddings
|
||||
|
||||
# def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
# start_time = time.time()
|
||||
# result = self.client.embeddings.create(
|
||||
# model=self.model,
|
||||
# inputs=texts
|
||||
# )
|
||||
# end_time = time.time()
|
||||
#
|
||||
# metrics = {
|
||||
# 'total_tokens': result.usage.total_tokens,
|
||||
# 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
|
||||
# 'completion_tokens': result.usage.completion_tokens,
|
||||
# 'time_elapsed': end_time - start_time,
|
||||
# 'interaction_type': 'Embedding',
|
||||
# }
|
||||
# current_event.log_llm_metrics(metrics)
|
||||
#
|
||||
# embeddings = [embedding.embedding for embedding in result.data]
|
||||
#
|
||||
# return embeddings
|
||||
|
||||
|
||||
53
common/eveai_model/tracked_mistral_ocr_client.py
Normal file
53
common/eveai_model/tracked_mistral_ocr_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import re
|
||||
import time
|
||||
|
||||
from flask import current_app
|
||||
from mistralai import Mistral
|
||||
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedMistralOcrClient:
|
||||
def __init__(self):
|
||||
api_key = current_app.config['MISTRAL_API_KEY']
|
||||
self.client = Mistral(
|
||||
api_key=api_key,
|
||||
)
|
||||
self.model = "mistral-ocr-latest"
|
||||
|
||||
def _get_title(self, markdown):
|
||||
# Look for the first level-1 heading
|
||||
match = re.search(r'^# (.+)', markdown, re.MULTILINE)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
def process_pdf(self, file_name, file_content):
|
||||
start_time = time.time()
|
||||
uploaded_pdf = self.client.files.upload(
|
||||
file={
|
||||
"file_name": file_name,
|
||||
"content": file_content
|
||||
},
|
||||
purpose="ocr"
|
||||
)
|
||||
signed_url = self.client.files.get_signed_url(file_id=uploaded_pdf.id)
|
||||
ocr_response = self.client.ocr.process(
|
||||
model=self.model,
|
||||
document={
|
||||
"type": "document_url",
|
||||
"document_url": signed_url.url
|
||||
},
|
||||
include_image_base64=False
|
||||
)
|
||||
nr_of_pages = len(ocr_response.pages)
|
||||
all_markdown = " ".join(page.markdown for page in ocr_response.pages)
|
||||
title = self._get_title(all_markdown)
|
||||
end_time = time.time()
|
||||
|
||||
metrics = {
|
||||
'nr_of_pages': nr_of_pages,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'OCR',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
return all_markdown, title
|
||||
Reference in New Issue
Block a user