import time
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

from common.utils.business_event_context import current_event


class PersistentLLMMetricsHandler(BaseCallbackHandler):
    """Metrics handler that allows metrics to be retrieved from within any call.

    In case metrics are required for other purposes than business event logging.

    Counters accumulate across successive LLM calls until `reset()` is invoked,
    so `get_metrics()` reports cumulative totals for the whole interaction.
    """

    def __init__(self):
        self.total_tokens: int = 0
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.start_time: float = 0
        self.end_time: float = 0
        self.total_time: float = 0

    def reset(self) -> None:
        """Zero all counters so the handler can be reused for a new interaction."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.start_time = 0
        self.end_time = 0
        self.total_time = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Record the wall-clock start of an LLM call."""
        self.start_time = time.time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token usage and elapsed time from a finished LLM call.

        `llm_output` may be None (e.g. for providers/streaming modes that do not
        report usage), so guard before reading `token_usage`.
        """
        self.end_time = time.time()
        # Accumulate, not overwrite: token counters below accumulate across
        # calls, so elapsed time must too for get_metrics() to stay consistent.
        self.total_time += self.end_time - self.start_time
        usage = (response.llm_output or {}).get('token_usage', {})
        self.prompt_tokens += usage.get('prompt_tokens', 0)
        self.completion_tokens += usage.get('completion_tokens', 0)
        self.total_tokens = self.prompt_tokens + self.completion_tokens

    def get_metrics(self) -> Dict[str, int | float]:
        """Return the cumulative metrics collected since the last reset()."""
        return {
            'total_tokens': self.total_tokens,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            'time_elapsed': self.total_time,
            'interaction_type': 'LLM',
        }