import time
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

from common.utils.business_event_context import current_event


class LLMMetricsHandler(BaseCallbackHandler):
    """LangChain callback handler that records token usage and wall-clock
    latency for each LLM call and logs them to the current business event.

    Counters are reset after every completed call, so each logged metrics
    dict describes exactly one LLM interaction.
    """

    def __init__(self) -> None:
        self.total_tokens: int = 0
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.start_time: float = 0
        self.end_time: float = 0
        self.total_time: float = 0

    def reset(self) -> None:
        """Zero all counters and timers so the handler can track the next call."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.start_time = 0
        self.end_time = 0
        self.total_time = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Mark the start of an LLM call (timestamp used to compute latency)."""
        self.start_time = time.time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token usage from the finished call and emit metrics.

        ``LLMResult.llm_output`` is optional — some providers return ``None``,
        and some return a dict without (or with a null) ``token_usage`` — so
        both levels are defensively defaulted to ``{}`` before lookup.
        """
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time
        usage = (response.llm_output or {}).get('token_usage') or {}
        self.prompt_tokens += usage.get('prompt_tokens', 0)
        self.completion_tokens += usage.get('completion_tokens', 0)
        self.total_tokens = self.prompt_tokens + self.completion_tokens
        metrics = self.get_metrics()
        current_event.log_llm_metrics(metrics)
        self.reset()  # Reset for the next call

    def get_metrics(self) -> Dict[str, int | float]:
        """Return a snapshot of the current call's metrics as a flat dict."""
        return {
            'total_tokens': self.total_tokens,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            'time_elapsed': self.total_time,
            'interaction_type': 'LLM',
        }