- Improved audio processing to limit CPU and memory usage

- Removed Portkey and replaced it with explicit monitoring built on LangChain-native callbacks
- Optimized Business Event logging
Josako
2024-10-02 14:11:46 +02:00
parent 883175b8f5
commit b700cfac64
13 changed files with 450 additions and 228 deletions

View File

@@ -25,6 +25,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Security
- In case of vulnerabilities.
## [1.0.9-alfa] - 2024-10-01
### Added
- Business Event tracing (eveai_workers & eveai_chat_workers)
- Flower Container added for monitoring
### Changed
- Healthcheck improvements
- model_utils turned into a class with lazy loading
### Deprecated
- For soon-to-be removed features.
### Removed
- For now removed features.
### Fixed
- Set default language when registering Documents or URLs.
### Security
- In case of vulnerabilities.
## [1.0.8-alfa] - 2024-09-12
### Added

View File

@@ -0,0 +1,49 @@
import time
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

from common.utils.business_event_context import current_event


class LLMMetricsHandler(BaseCallbackHandler):
    """LangChain callback handler that reports token usage and latency per LLM call."""

    def __init__(self):
        self.total_tokens: int = 0
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.start_time: float = 0
        self.end_time: float = 0
        self.total_time: float = 0

    def reset(self):
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.start_time = 0
        self.end_time = 0
        self.total_time = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.start_time = time.time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time
        # llm_output can be None for some providers, so guard before reading usage
        usage = (response.llm_output or {}).get('token_usage', {})
        self.prompt_tokens += usage.get('prompt_tokens', 0)
        self.completion_tokens += usage.get('completion_tokens', 0)
        self.total_tokens = self.prompt_tokens + self.completion_tokens
        metrics = self.get_metrics()
        current_event.log_llm_metrics(metrics)
        self.reset()  # Reset for the next call

    def get_metrics(self) -> Dict[str, int | float]:
        return {
            'total_tokens': self.total_tokens,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            'time_elapsed': self.total_time,
            'interaction_type': 'LLM',
        }
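A minimal usage sketch of this handler (not part of the commit; the model name and prompt are illustrative, and ChatOpenAI accepts a callbacks list, as the model_utils changes below rely on):

# --- usage sketch, not part of this commit ---
from langchain_openai import ChatOpenAI
from common.langchain.llm_metrics_handler import LLMMetricsHandler

handler = LLMMetricsHandler()
llm = ChatOpenAI(model='gpt-4o-mini', callbacks=[handler])  # assumes OPENAI_API_KEY is set

# Each call fires on_llm_start/on_llm_end; the handler logs token usage and
# latency to the current business event, then resets itself for the next call.
llm.invoke('Summarize the meeting notes in one sentence.')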

View File

@@ -0,0 +1,51 @@
import time
from typing import List

import tiktoken
from langchain_openai import OpenAIEmbeddings

from common.utils.business_event_context import current_event


class TrackedOpenAIEmbeddings(OpenAIEmbeddings):
    """OpenAIEmbeddings subclass that logs token usage and latency to the current business event."""

    def _count_tokens(self, texts: List[str]) -> int:
        # Estimate token usage locally (OpenAI itself uses tiktoken for accounting)
        try:
            enc = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # Unknown model name: fall back to the encoding used by current embedding models
            enc = tiktoken.get_encoding('cl100k_base')
        return sum(len(enc.encode(text)) for text in texts)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        start_time = time.time()
        result = super().embed_documents(texts)
        end_time = time.time()
        total_tokens = self._count_tokens(texts)
        metrics = {
            'total_tokens': total_tokens,
            'prompt_tokens': total_tokens,  # For embeddings, all tokens are prompt tokens
            'completion_tokens': 0,
            'time_elapsed': end_time - start_time,
            'interaction_type': 'Embedding',
        }
        current_event.log_llm_metrics(metrics)
        return result

    def embed_query(self, text: str) -> List[float]:
        start_time = time.time()
        result = super().embed_query(text)
        end_time = time.time()
        total_tokens = self._count_tokens([text])
        metrics = {
            'total_tokens': total_tokens,
            'prompt_tokens': total_tokens,
            'completion_tokens': 0,
            'time_elapsed': end_time - start_time,
            'interaction_type': 'Embedding',
        }
        current_event.log_llm_metrics(metrics)
        return result
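LangChain's embed_documents returns only the vectors, not the provider's usage block, which is presumably why tokens are estimated locally here. A standalone sketch of that estimation (the model name is illustrative):

# --- estimation sketch, not part of this commit ---
import tiktoken

enc = tiktoken.encoding_for_model('text-embedding-3-small')
texts = ['first chunk of a document', 'second chunk']
total_tokens = sum(len(enc.encode(t)) for t in texts)  # local estimate of billed tokens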

View File

@@ -0,0 +1,27 @@
import time

from common.utils.business_event_context import current_event


def tracked_transcribe(client, *args, **kwargs):
    """Run an OpenAI transcription call and log its usage to the current business event."""
    # 'duration' is our own kwarg (the audio length in seconds); pop it so it is
    # not forwarded to the OpenAI client, which would reject it.
    duration = kwargs.pop('duration', 600)
    start_time = time.time()
    result = client.audio.transcriptions.create(*args, **kwargs)
    end_time = time.time()
    # "Token" usage for transcriptions is the duration in seconds we pass,
    # as the Whisper model is priced per second transcribed.
    metrics = {
        'total_tokens': duration,
        'prompt_tokens': 0,  # For transcriptions, all tokens are considered "completion"
        'completion_tokens': duration,
        'time_elapsed': end_time - start_time,
        'interaction_type': 'ASR',
    }
    current_event.log_llm_metrics(metrics)
    return result
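A hedged usage sketch (not part of the commit; the file name and duration are illustrative):

# --- usage sketch, not part of this commit ---
from openai import OpenAI
from common.langchain.tracked_transcribe import tracked_transcribe

client = OpenAI()  # assumes OPENAI_API_KEY is set
with open('meeting.mp3', 'rb') as audio_file:
    transcription = tracked_transcribe(
        client,
        file=audio_file,
        model='whisper-1',
        response_format='verbose_json',
        duration=540,  # audio length in seconds, used only for usage accounting
    )
print(transcription.text)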

View File

@@ -17,5 +17,11 @@ class BusinessEventLog(db.Model):
chat_session_id = db.Column(db.String(50))
interaction_id = db.Column(db.Integer)
environment = db.Column(db.String(20))
llm_metrics_total_tokens = db.Column(db.Integer)
llm_metrics_prompt_tokens = db.Column(db.Integer)
llm_metrics_completion_tokens = db.Column(db.Integer)
llm_metrics_total_time = db.Column(db.Float)
llm_metrics_call_count = db.Column(db.Integer)
llm_interaction_type = db.Column(db.String(20))
message = db.Column(db.Text)
# Add any other fields relevant for invoicing or warnings
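These columns make per-tenant usage queries straightforward; a sketch of an invoicing aggregation (assuming the surrounding Flask-SQLAlchemy setup; the filter on the "Final LLM Metrics" message matches the logging code below):

# --- query sketch, not part of this commit ---
from sqlalchemy import func

token_totals = (
    db.session.query(BusinessEventLog.tenant_id,
                     func.sum(BusinessEventLog.llm_metrics_total_tokens))
    .filter(BusinessEventLog.message == 'Final LLM Metrics')
    .group_by(BusinessEventLog.tenant_id)
    .all()
)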

View File

@@ -30,6 +30,14 @@ class BusinessEvent:
self.environment = os.environ.get("FLASK_ENV", "development")
self.span_counter = 0
self.spans = []
self.llm_metrics = {
'total_tokens': 0,
'prompt_tokens': 0,
'completion_tokens': 0,
'total_time': 0,
'call_count': 0,
'interaction_type': None
}
def update_attribute(self, attribute: str, value: any):
if hasattr(self, attribute):
@@ -37,6 +45,22 @@ class BusinessEvent:
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute}'")
def update_llm_metrics(self, metrics: dict):
self.llm_metrics['total_tokens'] += metrics['total_tokens']
self.llm_metrics['prompt_tokens'] += metrics['prompt_tokens']
self.llm_metrics['completion_tokens'] += metrics['completion_tokens']
self.llm_metrics['total_time'] += metrics['time_elapsed']
self.llm_metrics['call_count'] += 1
self.llm_metrics['interaction_type'] = metrics['interaction_type']
def reset_llm_metrics(self):
self.llm_metrics['total_tokens'] = 0
self.llm_metrics['prompt_tokens'] = 0
self.llm_metrics['completion_tokens'] = 0
self.llm_metrics['total_time'] = 0
self.llm_metrics['call_count'] = 0
self.llm_metrics['interaction_type'] = None
@contextmanager
def create_span(self, span_name: str):
# The create_span method is designed to be used as a context manager. We want to perform some actions when
@@ -61,6 +85,9 @@ class BusinessEvent:
try:
yield
finally:
if self.llm_metrics['call_count'] > 0:
self.log_final_metrics()
self.reset_llm_metrics()
self.log(f"Ending span {span_name}") self.log(f"Ending span {span_name}")
# Restore the previous span info # Restore the previous span info
if self.spans: if self.spans:
@@ -82,7 +109,7 @@ class BusinessEvent:
'document_version_id': self.document_version_id,
'chat_session_id': self.chat_session_id,
'interaction_id': self.interaction_id,
'environment': self.environment,
}
# log to Graylog
getattr(logger, level)(message, extra=log_data)
@@ -105,10 +132,108 @@ class BusinessEvent:
db.session.add(event_log)
db.session.commit()
def log_llm_metrics(self, metrics: dict, level: str = 'info'):
self.update_llm_metrics(metrics)
message = "LLM Metrics"
logger = logging.getLogger('business_events')
log_data = {
'event_type': self.event_type,
'tenant_id': self.tenant_id,
'trace_id': self.trace_id,
'span_id': self.span_id,
'span_name': self.span_name,
'parent_span_id': self.parent_span_id,
'document_version_id': self.document_version_id,
'chat_session_id': self.chat_session_id,
'interaction_id': self.interaction_id,
'environment': self.environment,
'llm_metrics_total_tokens': metrics['total_tokens'],
'llm_metrics_prompt_tokens': metrics['prompt_tokens'],
'llm_metrics_completion_tokens': metrics['completion_tokens'],
'llm_metrics_total_time': metrics['time_elapsed'],
'llm_interaction_type': metrics['interaction_type'],
}
# log to Graylog
getattr(logger, level)(message, extra=log_data)
# Log to database
event_log = BusinessEventLog(
timestamp=dt.now(tz=tz.utc),
event_type=self.event_type,
tenant_id=self.tenant_id,
trace_id=self.trace_id,
span_id=self.span_id,
span_name=self.span_name,
parent_span_id=self.parent_span_id,
document_version_id=self.document_version_id,
chat_session_id=self.chat_session_id,
interaction_id=self.interaction_id,
environment=self.environment,
llm_metrics_total_tokens=metrics['total_tokens'],
llm_metrics_prompt_tokens=metrics['prompt_tokens'],
llm_metrics_completion_tokens=metrics['completion_tokens'],
llm_metrics_total_time=metrics['time_elapsed'],
llm_interaction_type=metrics['interaction_type'],
message=message
)
db.session.add(event_log)
db.session.commit()
def log_final_metrics(self, level: str = 'info'):
logger = logging.getLogger('business_events')
message = "Final LLM Metrics"
log_data = {
'event_type': self.event_type,
'tenant_id': self.tenant_id,
'trace_id': self.trace_id,
'span_id': self.span_id,
'span_name': self.span_name,
'parent_span_id': self.parent_span_id,
'document_version_id': self.document_version_id,
'chat_session_id': self.chat_session_id,
'interaction_id': self.interaction_id,
'environment': self.environment,
'llm_metrics_total_tokens': self.llm_metrics['total_tokens'],
'llm_metrics_prompt_tokens': self.llm_metrics['prompt_tokens'],
'llm_metrics_completion_tokens': self.llm_metrics['completion_tokens'],
'llm_metrics_total_time': self.llm_metrics['total_time'],
'llm_metrics_call_count': self.llm_metrics['call_count'],
'llm_interaction_type': self.llm_metrics['interaction_type'],
}
# log to Graylog
getattr(logger, level)(message, extra=log_data)
# Log to database
event_log = BusinessEventLog(
timestamp=dt.now(tz=tz.utc),
event_type=self.event_type,
tenant_id=self.tenant_id,
trace_id=self.trace_id,
span_id=self.span_id,
span_name=self.span_name,
parent_span_id=self.parent_span_id,
document_version_id=self.document_version_id,
chat_session_id=self.chat_session_id,
interaction_id=self.interaction_id,
environment=self.environment,
llm_metrics_total_tokens=self.llm_metrics['total_tokens'],
llm_metrics_prompt_tokens=self.llm_metrics['prompt_tokens'],
llm_metrics_completion_tokens=self.llm_metrics['completion_tokens'],
llm_metrics_total_time=self.llm_metrics['total_time'],
llm_metrics_call_count=self.llm_metrics['call_count'],
llm_interaction_type=self.llm_metrics['interaction_type'],
message=message
)
db.session.add(event_log)
db.session.commit()
def __enter__(self):
self.log(f'Starting Trace for {self.event_type}')
return BusinessEventContext(self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if self.llm_metrics['call_count'] > 0:
self.log_final_metrics()
self.reset_llm_metrics()
self.log(f'Ending Trace for {self.event_type}')
return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb)
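Taken together, metrics flow like this (a sketch, not part of the commit; the BusinessEvent constructor arguments and import path are assumptions based on the fields used above):

# --- usage sketch, not part of this commit ---
from common.utils.business_event_context import current_event

with BusinessEvent(event_type='embed_document', tenant_id=tenant.id):  # constructor args assumed
    with current_event.create_span('Create Embeddings'):
        embedding_model.embed_documents(chunks)  # each call logs "LLM Metrics"
    # On span/trace exit (call_count > 0), the accumulated totals are flushed
    # once as "Final LLM Metrics" and the counters are reset for the next span.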

View File

@@ -11,6 +11,9 @@ from openai import OpenAI
 from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
 from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler
+from common.langchain.llm_metrics_handler import LLMMetricsHandler
+from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
+from common.langchain.tracked_transcribe import tracked_transcribe
 from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI
 from common.models.user import Tenant
 from config.model_config import MODEL_CONFIG
@@ -48,6 +51,8 @@ class ModelVariables(MutableMapping):
         self._transcription_client = None
         self._prompt_templates = {}
         self._embedding_db_model = None
+        self.llm_metrics_handler = LLMMetricsHandler()

     def _initialize_variables(self):
         variables = {}
@@ -89,26 +94,20 @@ class ModelVariables(MutableMapping):
         if variables['tool_calling_supported']:
             variables['cited_answer_cls'] = CitedAnswer
+        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
+        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
+        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
+        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
         return variables

     @property
     def embedding_model(self):
-        portkey_metadata = self.get_portkey_metadata()
-        portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
-                                        provider=self._variables['embedding_provider'],
-                                        metadata=portkey_metadata,
-                                        trace_id=current_event.trace_id,
-                                        span_id=current_event.span_id,
-                                        span_name=current_event.span_name,
-                                        parent_span_id=current_event.parent_span_id
-                                        )
         api_key = os.getenv('OPENAI_API_KEY')
         model = self._variables['embedding_model']
-        self._embedding_model = OpenAIEmbeddings(api_key=api_key,
-                                                 model=model,
-                                                 base_url=PORTKEY_GATEWAY_URL,
-                                                 default_headers=portkey_headers)
+        self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
+                                                        model=model)
         self._embedding_db_model = EmbeddingSmallOpenAI \
             if model == 'text-embedding-3-small' \
             else EmbeddingLargeOpenAI
@@ -117,49 +116,22 @@ class ModelVariables(MutableMapping):
     @property
     def llm(self):
-        portkey_headers = self.get_portkey_headers_for_llm()
         api_key = self.get_api_key_for_llm()
         self._llm = ChatOpenAI(api_key=api_key,
                                model=self._variables['llm_model'],
                                temperature=self._variables['RAG_temperature'],
-                               base_url=PORTKEY_GATEWAY_URL,
-                               default_headers=portkey_headers)
+                               callbacks=[self.llm_metrics_handler])
         return self._llm

     @property
     def llm_no_rag(self):
-        portkey_headers = self.get_portkey_headers_for_llm()
         api_key = self.get_api_key_for_llm()
         self._llm_no_rag = ChatOpenAI(api_key=api_key,
                                       model=self._variables['llm_model'],
                                       temperature=self._variables['RAG_temperature'],
-                                      base_url=PORTKEY_GATEWAY_URL,
-                                      default_headers=portkey_headers)
+                                      callbacks=[self.llm_metrics_handler])
         return self._llm_no_rag

-    def get_portkey_headers_for_llm(self):
-        portkey_metadata = self.get_portkey_metadata()
-        portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
-                                        metadata=portkey_metadata,
-                                        provider=self._variables['llm_provider'],
-                                        trace_id=current_event.trace_id,
-                                        span_id=current_event.span_id,
-                                        span_name=current_event.span_name,
-                                        parent_span_id=current_event.parent_span_id
-                                        )
-        return portkey_headers
-
-    def get_portkey_metadata(self):
-        environment = os.getenv('FLASK_ENV', 'development')
-        portkey_metadata = {'tenant_id': str(self.tenant.id),
-                            'environment': environment,
-                            'trace_id': current_event.trace_id,
-                            'span_id': current_event.span_id,
-                            'span_name': current_event.span_name,
-                            'parent_span_id': current_event.parent_span_id,
-                            }
-        return portkey_metadata
-
     def get_api_key_for_llm(self):
         if self._variables['llm_provider'] == 'openai':
             api_key = os.getenv('OPENAI_API_KEY')
@@ -168,57 +140,16 @@ class ModelVariables(MutableMapping):
         return api_key

-    # def _initialize_llm(self):
-    #
-    #
-    #     if self._variables['llm_provider'] == 'openai':
-    #         portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
-    #                                         metadata=portkey_metadata,
-    #                                         provider='openai')
-    #
-    #         self._llm = ChatOpenAI(api_key=api_key,
-    #                                model=self._variables['llm_model'],
-    #                                temperature=self._variables['RAG_temperature'],
-    #                                base_url=PORTKEY_GATEWAY_URL,
-    #                                default_headers=portkey_headers)
-    #         self._llm_no_rag = ChatOpenAI(api_key=api_key,
-    #                                       model=self._variables['llm_model'],
-    #                                       temperature=self._variables['no_RAG_temperature'],
-    #                                       base_url=PORTKEY_GATEWAY_URL,
-    #                                       default_headers=portkey_headers)
-    #         self._variables['tool_calling_supported'] = self._variables['llm_model'] in ['gpt-4o', 'gpt-4o-mini']
-    #     elif self._variables['llm_provider'] == 'anthropic':
-    #         api_key = os.getenv('ANTHROPIC_API_KEY')
-    #         llm_model_ext = os.getenv('ANTHROPIC_LLM_VERSIONS', {}).get(self._variables['llm_model'])
-    #         self._llm = ChatAnthropic(api_key=api_key,
-    #                                   model=llm_model_ext,
-    #                                   temperature=self._variables['RAG_temperature'])
-    #         self._llm_no_rag = ChatAnthropic(api_key=api_key,
-    #                                          model=llm_model_ext,
-    #                                          temperature=self._variables['RAG_temperature'])
-    #         self._variables['tool_calling_supported'] = True
-    #     else:
-    #         raise ValueError(f"Invalid chat provider: {self._variables['llm_provider']}")

     @property
     def transcription_client(self):
-        environment = os.getenv('FLASK_ENV', 'development')
-        portkey_metadata = self.get_portkey_metadata()
-        portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
-                                        metadata=portkey_metadata,
-                                        provider='openai',
-                                        trace_id=current_event.trace_id,
-                                        span_id=current_event.span_id,
-                                        span_name=current_event.span_name,
-                                        parent_span_id=current_event.parent_span_id
-                                        )
         api_key = os.getenv('OPENAI_API_KEY')
-        self._transcription_client = OpenAI(api_key=api_key,
-                                            base_url=PORTKEY_GATEWAY_URL,
-                                            default_headers=portkey_headers)
+        self._transcription_client = OpenAI(api_key=api_key)
         self._variables['transcription_model'] = 'whisper-1'
         return self._transcription_client

+    def transcribe(self, *args, **kwargs):
+        return tracked_transcribe(self._transcription_client, *args, **kwargs)
+
     @property
     def embedding_db_model(self):
         if self._embedding_db_model is None:

View File

@@ -1,99 +0,0 @@
import requests
import json
from typing import Optional
# Define a function to make the GET request
def get_metadata_grouped_data(
api_key: str,
metadata_key: str,
time_of_generation_min: Optional[str] = None,
time_of_generation_max: Optional[str] = None,
total_units_min: Optional[int] = None,
total_units_max: Optional[int] = None,
cost_min: Optional[float] = None,
cost_max: Optional[float] = None,
prompt_token_min: Optional[int] = None,
prompt_token_max: Optional[int] = None,
completion_token_min: Optional[int] = None,
completion_token_max: Optional[int] = None,
status_code: Optional[str] = None,
weighted_feedback_min: Optional[float] = None,
weighted_feedback_max: Optional[float] = None,
virtual_keys: Optional[str] = None,
configs: Optional[str] = None,
workspace_slug: Optional[str] = None,
api_key_ids: Optional[str] = None,
current_page: Optional[int] = 1,
page_size: Optional[int] = 20,
metadata: Optional[str] = None,
ai_org_model: Optional[str] = None,
trace_id: Optional[str] = None,
span_id: Optional[str] = None,
):
url = f"https://api.portkey.ai/v1/analytics/groups/metadata/{metadata_key}"
# Set up query parameters
params = {
"time_of_generation_min": time_of_generation_min,
"time_of_generation_max": time_of_generation_max,
"total_units_min": total_units_min,
"total_units_max": total_units_max,
"cost_min": cost_min,
"cost_max": cost_max,
"prompt_token_min": prompt_token_min,
"prompt_token_max": prompt_token_max,
"completion_token_min": completion_token_min,
"completion_token_max": completion_token_max,
"status_code": status_code,
"weighted_feedback_min": weighted_feedback_min,
"weighted_feedback_max": weighted_feedback_max,
"virtual_keys": virtual_keys,
"configs": configs,
"workspace_slug": workspace_slug,
"api_key_ids": api_key_ids,
"current_page": current_page,
"page_size": page_size,
"metadata": metadata,
"ai_org_model": ai_org_model,
"trace_id": trace_id,
"span_id": span_id,
}
# Remove any keys with None values
params = {k: v for k, v in params.items() if v is not None}
# Set up the headers
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# Make the GET request
response = requests.get(url, headers=headers, params=params)
# Check for successful response
if response.status_code == 200:
return response.json() # Return JSON data
else:
response.raise_for_status() # Raise an exception for errors
# Example usage
# Replace 'your_api_key' and 'your_metadata_key' with actual values
api_key = 'your_api_key'
metadata_key = 'your_metadata_key'
try:
data = get_metadata_grouped_data(
api_key=api_key,
metadata_key=metadata_key,
time_of_generation_min="2024-08-23T15:50:23+05:30",
time_of_generation_max="2024-09-23T15:50:23+05:30",
total_units_min=100,
total_units_max=1000,
cost_min=10,
cost_max=100,
status_code="200,201"
)
print(json.dumps(data, indent=4))
except Exception as e:
print(f"Error occurred: {str(e)}")

View File

@@ -147,6 +147,15 @@ class Config(object):
TENANT_TYPES = ['Active', 'Demo', 'Inactive', 'Test']
# The maximum number of seconds allowed for audio compression (to save resources)
MAX_COMPRESSION_DURATION = 60*10 # 10 minutes
# The maximum number of seconds allowed for transcribing audio
MAX_TRANSCRIPTION_DURATION = 60*10 # 10 minutes
# Maximum CPU usage for a compression task
COMPRESSION_CPU_LIMIT = 50
# Delay between compressing chunks in seconds
COMPRESSION_PROCESS_DELAY = 1
class DevConfig(Config):
    DEVELOPMENT = True

View File

@@ -1,6 +1,8 @@
 import io
 import os
+import time
+import psutil
 from pydub import AudioSegment
 import tempfile
 from common.extensions import minio_client
@@ -16,6 +18,11 @@ class AudioProcessor(TranscriptionProcessor):
         self.transcription_client = model_variables['transcription_client']
         self.transcription_model = model_variables['transcription_model']
         self.ffmpeg_path = 'ffmpeg'
+        self.max_compression_duration = model_variables['max_compression_duration']
+        self.max_transcription_duration = model_variables['max_transcription_duration']
+        self.compression_cpu_limit = model_variables.get('compression_cpu_limit', 50)  # CPU usage limit in percentage
+        self.compression_process_delay = model_variables.get('compression_process_delay', 0.1)  # Delay between processing chunks in seconds
+        self.file_type = document_version.file_type
     def _get_transcription(self):
         file_data = minio_client.download_document_file(
@@ -26,68 +33,121 @@ class AudioProcessor(TranscriptionProcessor):
             self.document_version.file_name
         )
-        with current_event.create_span("Audio Processing"):
+        with current_event.create_span("Audio Compression"):
             compressed_audio = self._compress_audio(file_data)
-        with current_event.create_span("Transcription Generation"):
+        with current_event.create_span("Audio Transcription"):
             transcription = self._transcribe_audio(compressed_audio)
         return transcription

     def _compress_audio(self, audio_data):
         self._log("Compressing audio")
-        with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{self.document_version.file_type}') as temp_input:
-            temp_input.write(audio_data)
-            temp_input.flush()
-        # Use a unique filename for the output to avoid conflicts
-        output_filename = f'compressed_{os.urandom(8).hex()}.mp3'
-        output_path = os.path.join(tempfile.gettempdir(), output_filename)
+        with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{self.document_version.file_type}') as temp_file:
+            temp_file.write(audio_data)
+            temp_file_path = temp_file.name
         try:
-            result = subprocess.run(
-                [self.ffmpeg_path, '-y', '-i', temp_input.name, '-b:a', '64k', '-f', 'mp3', output_path],
-                capture_output=True,
-                text=True,
-                check=True
-            )
-            with open(output_path, 'rb') as f:
-                compressed_data = f.read()
+            self._log("Creating AudioSegment from file")
+            audio_info = AudioSegment.from_file(temp_file_path, format=self.document_version.file_type)
+            self._log("Finished creating AudioSegment from file")
+            total_duration = len(audio_info)
+            self._log(f"Audio duration: {total_duration / 1000} seconds")
+            segment_length = self.max_compression_duration * 1000  # Convert to milliseconds
+            total_chunks = (total_duration + segment_length - 1) // segment_length
+            compressed_segments = AudioSegment.empty()
+            for i in range(total_chunks):
+                self._log(f"Compressing segment {i + 1} of {total_chunks}")
+                start_time = i * segment_length
+                end_time = min((i + 1) * segment_length, total_duration)
+                chunk = AudioSegment.from_file(
+                    temp_file_path,
+                    format=self.document_version.file_type,
+                    start_second=start_time / 1000,
+                    duration=(end_time - start_time) / 1000
+                )
+                compressed_chunk = self._compress_segment(chunk)
+                compressed_segments += compressed_chunk
+                time.sleep(self.compression_process_delay)
             # Save compressed audio to MinIO
             compressed_filename = f"{self.document_version.id}_compressed.mp3"
+            with io.BytesIO() as compressed_buffer:
+                compressed_segments.export(compressed_buffer, format="mp3")
+                compressed_buffer.seek(0)
                 minio_client.upload_document_file(
                     self.tenant.id,
                     self.document_version.doc_id,
                     self.document_version.language,
                     self.document_version.id,
                     compressed_filename,
-                    compressed_data
+                    compressed_buffer.read()
                 )
             self._log(f"Saved compressed audio to MinIO: {compressed_filename}")
-            return compressed_data
-        except subprocess.CalledProcessError as e:
-            error_message = f"Compression failed: {e.stderr}"
-            self._log(error_message, level='error')
-            raise Exception(error_message)
-        except Exception as e:
-            self._log(f"Error during audio processing: {str(e)}", level='error')
-            raise
+            return compressed_segments
         finally:
-            # Clean up temporary files
-            os.unlink(temp_input.name)
-            if os.path.exists(output_path):
-                os.unlink(output_path)
+            os.unlink(temp_file_path)  # Ensure the temporary file is deleted

+    def _compress_segment(self, audio_segment):
+        with io.BytesIO() as segment_buffer:
+            audio_segment.export(segment_buffer, format="wav")
+            segment_buffer.seek(0)
+            with io.BytesIO() as output_buffer:
+                command = [
+                    'nice', '-n', '19',
+                    'ffmpeg',
+                    '-i', 'pipe:0',
+                    '-ar', '16000',
+                    '-ac', '1',
+                    '-b:a', '32k',
+                    '-filter:a', 'loudnorm',
+                    '-f', 'mp3',
+                    'pipe:1'
+                ]
+                process = psutil.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+                stdout, stderr = process.communicate(input=segment_buffer.read())
+                if process.returncode != 0:
+                    self._log(f"FFmpeg error: {stderr.decode()}", level='error')
+                    raise Exception("FFmpeg compression failed")
+                output_buffer.write(stdout)
+                output_buffer.seek(0)
+                compressed_segment = AudioSegment.from_mp3(output_buffer)
+        return compressed_segment

     def _transcribe_audio(self, audio_data):
         self._log("Starting audio transcription")
-        audio = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3")
+        audio = audio_data  # _compress_audio now returns an AudioSegment, not raw bytes
-        segment_length = 10 * 60 * 1000  # 10 minutes in milliseconds
+        segment_length = self.max_transcription_duration * 1000  # convert to milliseconds
         transcriptions = []
+        total_chunks = len(audio) // segment_length + 1
         for i, chunk in enumerate(audio[::segment_length]):
-            self._log(f'Processing chunk {i + 1} of {len(audio) // segment_length + 1}')
+            self._log(f'Processing chunk {i + 1} of {total_chunks}')
+            if i == total_chunks - 1:
+                segment_duration = (len(audio) % segment_length) // 1000
+            else:
+                segment_duration = self.max_transcription_duration
             with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
                 chunk.export(temp_audio.name, format="mp3")
@@ -103,11 +163,12 @@ class AudioProcessor(TranscriptionProcessor):
                 audio_file.seek(0)  # Reset file pointer to the beginning
                 self._log("Calling transcription API")
-                transcription = self.transcription_client.audio.transcriptions.create(
+                transcription = self.model_variables.transcribe(
                     file=audio_file,
                     model=self.transcription_model,
                     language=self.document_version.language,
                     response_format='verbose_json',
+                    duration=segment_duration,
                 )
                 self._log("Transcription API call completed")

View File

@@ -171,9 +171,11 @@ def embed_markdown(tenant, model_variables, document_version, markdown, title):
     model_variables['max_chunk_size'])

     # Enrich chunks
+    with current_event.create_span("Enrich Chunks"):
         enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)

     # Create embeddings
+    with current_event.create_span("Create Embeddings"):
         embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)

     # Update document version and save embeddings
@@ -194,7 +196,6 @@ def embed_markdown(tenant, model_variables, document_version, markdown, title):
 def enrich_chunks(tenant, model_variables, document_version, title, chunks):
-    current_event.log("Starting Enriching Chunks Processing")
     current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
                              f'on document version {document_version.id}')
@@ -227,7 +228,6 @@ def enrich_chunks(tenant, model_variables, document_version, title, chunks):
     current_app.logger.debug(f'Finished enriching chunks for tenant {tenant.id} '
                              f'on document version {document_version.id}')
-    current_event.log("Finished Enriching Chunks Processing")
     return enriched_chunks
@@ -261,7 +261,6 @@ def embed_chunks(tenant, model_variables, document_version, chunks):
 def embed_chunks(tenant, model_variables, document_version, chunks):
-    current_event.log("Starting Embedding Chunks Processing")
     current_app.logger.debug(f'Embedding chunks for tenant {tenant.id} '
                              f'on document version {document_version.id}')
     embedding_model = model_variables['embedding_model']

View File

@@ -0,0 +1,40 @@
"""Add LLM metrics information to business events
Revision ID: 16f70b210557
Revises: 829094f07d44
Create Date: 2024-10-01 09:46:49.372953
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16f70b210557'
down_revision = '829094f07d44'
branch_labels = None
depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('business_event_log', schema=None) as batch_op:
        batch_op.add_column(sa.Column('llm_metrics_total_tokens', sa.Integer(), nullable=True))
        batch_op.add_column(sa.Column('llm_metrics_prompt_tokens', sa.Integer(), nullable=True))
        batch_op.add_column(sa.Column('llm_metrics_completion_tokens', sa.Integer(), nullable=True))
        batch_op.add_column(sa.Column('llm_metrics_total_time', sa.Float(), nullable=True))
        batch_op.add_column(sa.Column('llm_metrics_call_count', sa.Integer(), nullable=True))
        # Matches the llm_interaction_type column added to the BusinessEventLog model above
        batch_op.add_column(sa.Column('llm_interaction_type', sa.String(length=20), nullable=True))
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('business_event_log', schema=None) as batch_op:
        batch_op.drop_column('llm_interaction_type')
        batch_op.drop_column('llm_metrics_call_count')
        batch_op.drop_column('llm_metrics_total_time')
        batch_op.drop_column('llm_metrics_completion_tokens')
        batch_op.drop_column('llm_metrics_prompt_tokens')
        batch_op.drop_column('llm_metrics_total_tokens')
    # ### end Alembic commands ###

View File

@@ -80,3 +80,4 @@ langsmith~=0.1.121
anthropic~=0.34.2
prometheus-client~=0.20.0
flower~=2.0.1
psutil~=6.0.0