- Significantly changed the PDF Processor to use Mistral's OCR model

- ensure very long chunks get split into smaller chunks
- ensure TrackedMistralAIEmbedding is batched if needed to ensure correct execution
- upgraded some of the packages to a higher version
This commit is contained in:
Josako
2025-04-16 15:39:16 +02:00
parent 5f58417d24
commit 4bf12db142
10 changed files with 518 additions and 91 deletions

View File

@@ -9,32 +9,133 @@ from mistralai import Mistral
class TrackedMistralAIEmbeddings(EveAIEmbeddings):
def __init__(self, model: str = "mistral_embed"):
def __init__(self, model: str = "mistral_embed", batch_size: int = 10):
"""
Initialize the TrackedMistralAIEmbeddings class.
Args:
model: The embedding model to use
batch_size: Maximum number of texts to send in a single API call
"""
api_key = current_app.config['MISTRAL_API_KEY']
self.client = Mistral(
api_key=api_key
)
self.model = model
self.batch_size = batch_size
super().__init__()
def embed_documents(self, texts: list[str]) -> list[list[float]]:
start_time = time.time()
result = self.client.embeddings.create(
model=self.model,
inputs=texts
)
end_time = time.time()
"""
Embed a list of texts, processing in batches to avoid API limitations.
metrics = {
'total_tokens': result.usage.total_tokens,
'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
'completion_tokens': result.usage.completion_tokens,
'time_elapsed': end_time - start_time,
'interaction_type': 'Embedding',
}
current_event.log_llm_metrics(metrics)
Args:
texts: A list of texts to embed
embeddings = [embedding.embedding for embedding in result.data]
Returns:
A list of embeddings, one for each input text
"""
if not texts:
return []
return embeddings
all_embeddings = []
# Process texts in batches
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i + self.batch_size]
batch_num = i // self.batch_size + 1
current_app.logger.debug(f"Processing embedding batch {batch_num}, size: {len(batch)}")
start_time = time.time()
try:
result = self.client.embeddings.create(
model=self.model,
inputs=batch
)
end_time = time.time()
batch_time = end_time - start_time
batch_embeddings = [embedding.embedding for embedding in result.data]
all_embeddings.extend(batch_embeddings)
# Log metrics for this batch
metrics = {
'total_tokens': result.usage.total_tokens,
'prompt_tokens': result.usage.prompt_tokens,
'completion_tokens': result.usage.completion_tokens,
'time_elapsed': batch_time,
'interaction_type': 'Embedding',
'batch': batch_num,
'batch_size': len(batch)
}
current_event.log_llm_metrics(metrics)
current_app.logger.debug(f"Batch {batch_num} processed: {len(batch)} texts, "
f"{result.usage.total_tokens} tokens, {batch_time:.2f}s")
# If processing multiple batches, add a small delay to avoid rate limits
if len(texts) > self.batch_size and i + self.batch_size < len(texts):
time.sleep(0.25) # 250ms pause between batches
except Exception as e:
current_app.logger.error(f"Error in embedding batch {batch_num}: {str(e)}")
# If a batch fails, try to process each text individually
for j, text in enumerate(batch):
try:
current_app.logger.debug(f"Attempting individual embedding for item {i + j}")
single_start_time = time.time()
single_result = self.client.embeddings.create(
model=self.model,
inputs=[text]
)
single_end_time = time.time()
# Add the single embedding
single_embedding = single_result.data[0].embedding
all_embeddings.append(single_embedding)
# Log metrics for this individual embedding
single_metrics = {
'total_tokens': single_result.usage.total_tokens,
'prompt_tokens': single_result.usage.prompt_tokens,
'completion_tokens': single_result.usage.completion_tokens,
'time_elapsed': single_end_time - single_start_time,
'interaction_type': 'Embedding',
'batch': f"{batch_num}-recovery-{j}",
'batch_size': 1
}
current_event.log_llm_metrics(single_metrics)
except Exception as inner_e:
current_app.logger.error(f"Failed to embed individual text at index {i + j}: {str(inner_e)}")
# Add a zero vector as a placeholder for failed embeddings
# Use the correct dimensionality for the model (1024 for mistral_embed)
embedding_dim = 1024
all_embeddings.append([0.0] * embedding_dim)
total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
current_app.logger.info(f"Embedded {len(texts)} texts in {total_batches} batches")
return all_embeddings
# def embed_documents(self, texts: list[str]) -> list[list[float]]:
# start_time = time.time()
# result = self.client.embeddings.create(
# model=self.model,
# inputs=texts
# )
# end_time = time.time()
#
# metrics = {
# 'total_tokens': result.usage.total_tokens,
# 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
# 'completion_tokens': result.usage.completion_tokens,
# 'time_elapsed': end_time - start_time,
# 'interaction_type': 'Embedding',
# }
# current_event.log_llm_metrics(metrics)
#
# embeddings = [embedding.embedding for embedding in result.data]
#
# return embeddings

View File

@@ -0,0 +1,53 @@
import re
import time
from flask import current_app
from mistralai import Mistral
from common.utils.business_event_context import current_event
class TrackedMistralOcrClient:
def __init__(self):
api_key = current_app.config['MISTRAL_API_KEY']
self.client = Mistral(
api_key=api_key,
)
self.model = "mistral-ocr-latest"
def _get_title(self, markdown):
# Look for the first level-1 heading
match = re.search(r'^# (.+)', markdown, re.MULTILINE)
return match.group(1).strip() if match else None
def process_pdf(self, file_name, file_content):
start_time = time.time()
uploaded_pdf = self.client.files.upload(
file={
"file_name": file_name,
"content": file_content
},
purpose="ocr"
)
signed_url = self.client.files.get_signed_url(file_id=uploaded_pdf.id)
ocr_response = self.client.ocr.process(
model=self.model,
document={
"type": "document_url",
"document_url": signed_url.url
},
include_image_base64=False
)
nr_of_pages = len(ocr_response.pages)
all_markdown = " ".join(page.markdown for page in ocr_response.pages)
title = self._get_title(all_markdown)
end_time = time.time()
metrics = {
'nr_of_pages': nr_of_pages,
'time_elapsed': end_time - start_time,
'interaction_type': 'OCR',
}
current_event.log_llm_metrics(metrics)
return all_markdown, title

View File

@@ -25,6 +25,7 @@ class BusinessEventLog(db.Model):
llm_metrics_prompt_tokens = db.Column(db.Integer)
llm_metrics_completion_tokens = db.Column(db.Integer)
llm_metrics_total_time = db.Column(db.Float)
llm_metrics_nr_of_pages = db.Column(db.Integer)
llm_metrics_call_count = db.Column(db.Integer)
llm_interaction_type = db.Column(db.String(20))
message = db.Column(db.Text)

View File

@@ -106,6 +106,7 @@ class BusinessEvent:
'total_tokens': 0,
'prompt_tokens': 0,
'completion_tokens': 0,
'nr_of_pages': 0,
'total_time': 0,
'call_count': 0,
'interaction_type': None
@@ -121,13 +122,6 @@ class BusinessEvent:
if self.specialist_type_version else ""
self.span_name_str = ""
current_app.logger.debug(f"Labels for metrics: "
f"tenant_id={self.tenant_id_str}, "
f"event_type={self.event_type_str},"
f"specialist_id={self.specialist_id_str}, "
f"specialist_type={self.specialist_type_str}, " +
f"specialist_type_version={self.specialist_type_version_str}")
# Increment concurrent events gauge when initialized
CONCURRENT_TRACES.labels(
tenant_id=self.tenant_id_str,
@@ -168,24 +162,17 @@ class BusinessEvent:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute}'")
def update_llm_metrics(self, metrics: dict):
self.llm_metrics['total_tokens'] += metrics['total_tokens']
self.llm_metrics['prompt_tokens'] += metrics['prompt_tokens']
self.llm_metrics['completion_tokens'] += metrics['completion_tokens']
self.llm_metrics['total_time'] += metrics['time_elapsed']
self.llm_metrics['total_tokens'] += metrics.get('total_tokens', 0)
self.llm_metrics['prompt_tokens'] += metrics.get('prompt_tokens', 0)
self.llm_metrics['completion_tokens'] += metrics.get('completion_tokens', 0)
self.llm_metrics['nr_of_pages'] += metrics.get('nr_of_pages', 0)
self.llm_metrics['total_time'] += metrics.get('time_elapsed', 0)
self.llm_metrics['call_count'] += 1
self.llm_metrics['interaction_type'] = metrics['interaction_type']
# Track in Prometheus metrics
interaction_type_str = sanitize_label(metrics['interaction_type']) if metrics['interaction_type'] else ""
current_app.logger.debug(f"Labels for metrics: "
f"tenant_id={self.tenant_id_str}, "
f"event_type={self.event_type_str},"
f"interaction_type={interaction_type_str}, "
f"specialist_id={self.specialist_id_str}, "
f"specialist_type={self.specialist_type_str}, "
f"specialist_type_version={self.specialist_type_version_str}")
# Track token usage
LLM_TOKENS_COUNTER.labels(
tenant_id=self.tenant_id_str,
@@ -195,7 +182,7 @@ class BusinessEvent:
specialist_id=self.specialist_id_str,
specialist_type=self.specialist_type_str,
specialist_type_version=self.specialist_type_version_str
).inc(metrics['total_tokens'])
).inc(metrics.get('total_tokens', 0))
LLM_TOKENS_COUNTER.labels(
tenant_id=self.tenant_id_str,
@@ -205,7 +192,7 @@ class BusinessEvent:
specialist_id=self.specialist_id_str,
specialist_type=self.specialist_type_str,
specialist_type_version=self.specialist_type_version_str
).inc(metrics['prompt_tokens'])
).inc(metrics.get('prompt_tokens', 0))
LLM_TOKENS_COUNTER.labels(
tenant_id=self.tenant_id_str,
@@ -215,7 +202,7 @@ class BusinessEvent:
specialist_id=self.specialist_id_str,
specialist_type=self.specialist_type_str,
specialist_type_version=self.specialist_type_version_str
).inc(metrics['completion_tokens'])
).inc(metrics.get('completion_tokens', 0))
# Track duration
LLM_DURATION.labels(
@@ -225,7 +212,7 @@ class BusinessEvent:
specialist_id=self.specialist_id_str,
specialist_type=self.specialist_type_str,
specialist_type_version=self.specialist_type_version_str
).observe(metrics['time_elapsed'])
).observe(metrics.get('time_elapsed', 0))
# Track call count
LLM_CALLS_COUNTER.labels(
@@ -243,6 +230,7 @@ class BusinessEvent:
self.llm_metrics['total_tokens'] = 0
self.llm_metrics['prompt_tokens'] = 0
self.llm_metrics['completion_tokens'] = 0
self.llm_metrics['nr_of_pages'] = 0
self.llm_metrics['total_time'] = 0
self.llm_metrics['call_count'] = 0
self.llm_metrics['interaction_type'] = None
@@ -270,14 +258,6 @@ class BusinessEvent:
# Track start time for the span
span_start_time = time.time()
current_app.logger.debug(f"Labels for metrics: "
f"tenant_id={self.tenant_id_str}, "
f"event_type={self.event_type_str}, "
f"activity_name={self.span_name_str}, "
f"specialist_id={self.specialist_id_str}, "
f"specialist_type={self.specialist_type_str}, "
f"specialist_type_version={self.specialist_type_version_str}")
# Increment span metrics - using span_name as activity_name for metrics
SPAN_COUNTER.labels(
tenant_id=self.tenant_id_str,
@@ -363,14 +343,6 @@ class BusinessEvent:
# Track start time for the span
span_start_time = time.time()
current_app.logger.debug(f"Labels for metrics: "
f"tenant_id={self.tenant_id_str}, "
f"event_type={self.event_type_str}, "
f"activity_name={self.span_name_str}, "
f"specialist_id={self.specialist_id_str}, "
f"specialist_type={self.specialist_type_str}, "
f"specialist_type_version={self.specialist_type_version_str}")
# Increment span metrics - using span_name as activity_name for metrics
SPAN_COUNTER.labels(
tenant_id=self.tenant_id_str,
@@ -487,10 +459,11 @@ class BusinessEvent:
'specialist_type': self.specialist_type,
'specialist_type_version': self.specialist_type_version,
'environment': self.environment,
'llm_metrics_total_tokens': metrics['total_tokens'],
'llm_metrics_prompt_tokens': metrics['prompt_tokens'],
'llm_metrics_completion_tokens': metrics['completion_tokens'],
'llm_metrics_total_time': metrics['time_elapsed'],
'llm_metrics_total_tokens': metrics.get('total_tokens', 0),
'llm_metrics_prompt_tokens': metrics.get('prompt_tokens', 0),
'llm_metrics_completion_tokens': metrics.get('completion_tokens', 0),
'llm_metrics_nr_of_pages': metrics.get('nr_of_pages', 0),
'llm_metrics_total_time': metrics.get('time_elapsed', 0),
'llm_interaction_type': metrics['interaction_type'],
'message': message,
}
@@ -518,6 +491,7 @@ class BusinessEvent:
'llm_metrics_total_tokens': self.llm_metrics['total_tokens'],
'llm_metrics_prompt_tokens': self.llm_metrics['prompt_tokens'],
'llm_metrics_completion_tokens': self.llm_metrics['completion_tokens'],
'llm_metrics_nr_of_pages': self.llm_metrics['nr_of_pages'],
'llm_metrics_total_time': self.llm_metrics['total_time'],
'llm_metrics_call_count': self.llm_metrics['call_count'],
'llm_interaction_type': self.llm_metrics['interaction_type'],

View File

@@ -135,6 +135,11 @@ def get_crewai_llm(full_model_name='mistral.mistral-large-latest', temperature=0
return llm
def process_pdf():
full_model_name = 'mistral-ocr-latest'
class ModelVariables:
"""Manages model-related variables and configurations"""