import os
from typing import Dict, Any, Optional

import langcodes
from datetime import datetime as dt, timezone as tz
from flask import current_app
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
from common.utils.cache.base import CacheHandler
from common.utils.eveai_exceptions import EveAITenantNotFound
from common.extensions import template_manager, cache_manager
from config.model_config import MODEL_CONFIG


def create_language_template(template: str, language: str) -> str:
    """
    Replace the {language} placeholder in a template with the language's display name

    Args:
        template: Template string with a {language} placeholder
        language: Language code to insert (e.g. 'en', 'nl')

    Returns:
        str: Template with the language placeholder replaced
    """
    try:
        full_language = langcodes.Language.make(language=language)
        language_template = template.replace('{language}', full_language.display_name())
    except ValueError:
        # Fall back to the raw code if it cannot be resolved to a display name
        language_template = template.replace('{language}', language)
    return language_template


def replace_variable_in_template(template: str, variable: str, value: str) -> str:
    """
    Replace a variable placeholder in a template with the specified value

    Args:
        template: Template string with a variable placeholder
        variable: Variable placeholder to replace (e.g. "{tenant_context}")
        value: Value to insert; None is replaced by an empty string

    Returns:
        str: Template with the variable placeholder replaced
    """
    return template.replace(variable, value or "")
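

# Illustrative usage of the helpers above (a sketch; the language tag 'nl' and
# the {tenant_context} variable are arbitrary examples):
#
#     >>> create_language_template("Answer in {language}.", "nl")
#     'Answer in Dutch.'
#     >>> replace_variable_in_template("Context: {tenant_context}", "{tenant_context}", None)
#     'Context: '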


class ModelVariables:
    """Manages model-related variables and configurations"""

    def __init__(self, tenant_id: int, variables: Optional[Dict[str, Any]] = None):
        """
        Initialize ModelVariables for a tenant, optionally from pre-built variables

        Args:
            tenant_id: ID of the tenant whose model settings to load
            variables: Optional pre-built variables dict (e.g. restored from
                cache); when omitted, variables are initialized from the
                tenant record
        """
        current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
        self.tenant_id = tenant_id
        self._variables = variables if variables is not None else self._initialize_variables()
        current_app.logger.info(f'Model _variables initialized to {self._variables}')
        self._embedding_model = None
        self._embedding_model_class = None
        self._llm_instances = {}
        self.llm_metrics_handler = LLMMetricsHandler()
        self._transcription_model = None

    def _initialize_variables(self) -> Dict[str, Any]:
        """Initialize the variables dictionary from the tenant record"""
        variables = {}
        tenant = Tenant.query.get(self.tenant_id)
        if not tenant:
            raise EveAITenantNotFound(self.tenant_id)

        # Set model providers; split on the first '.' only, as model names may contain dots
        variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.', 1)
        variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.', 1)
        variables['llm_full_model'] = tenant.llm_model

        # Set model-specific configurations
        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
        variables.update(model_config)

        # Additional configurations
        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']

        return variables

    @property
    def embedding_model(self):
        """Get the embedding model instance"""
        if self._embedding_model is None:
            api_key = os.getenv('OPENAI_API_KEY')
            self._embedding_model = TrackedOpenAIEmbeddings(
                api_key=api_key,
                model=self._variables['embedding_model']
            )
        return self._embedding_model

    @property
    def embedding_model_class(self):
        """Get the embedding model class"""
        if self._embedding_model_class is None:
            if self._variables['embedding_model'] == 'text-embedding-3-large':
                self._embedding_model_class = EmbeddingLargeOpenAI
            else:  # text-embedding-3-small
                self._embedding_model_class = EmbeddingSmallOpenAI
        return self._embedding_model_class

    @property
    def annotation_chunk_length(self):
        return self._variables['annotation_chunk_length']

    @property
    def max_compression_duration(self):
        return self._variables['max_compression_duration']

    @property
    def max_transcription_duration(self):
        return self._variables['max_transcription_duration']

    @property
    def compression_cpu_limit(self):
        return self._variables['compression_cpu_limit']

    @property
    def compression_process_delay(self):
        return self._variables['compression_process_delay']

    def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
        """
        Get an LLM instance with specific configuration

        Args:
            temperature: The temperature for the LLM
            **kwargs: Additional configuration parameters

        Returns:
            An instance of the configured LLM
        """
        # One client per (temperature, kwargs) combination; kwargs values must be hashable
        cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"

        if cache_key not in self._llm_instances:
            provider = self._variables['llm_provider']
            model = self._variables['llm_model']

            if provider == 'openai':
                self._llm_instances[cache_key] = ChatOpenAI(
                    api_key=os.getenv('OPENAI_API_KEY'),
                    model=model,
                    temperature=temperature,
                    callbacks=[self.llm_metrics_handler],
                    **kwargs
                )
            elif provider == 'anthropic':
                self._llm_instances[cache_key] = ChatAnthropic(
                    api_key=os.getenv('ANTHROPIC_API_KEY'),
                    model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
                    temperature=temperature,
                    callbacks=[self.llm_metrics_handler],
                    **kwargs
                )
            else:
                raise ValueError(f"Unsupported LLM provider: {provider}")

        return self._llm_instances[cache_key]

    @property
    def transcription_model(self) -> TrackedOpenAITranscription:
        """Get the transcription model instance"""
        if self._transcription_model is None:
            api_key = os.getenv('OPENAI_API_KEY')
            self._transcription_model = TrackedOpenAITranscription(
                api_key=api_key,
                model='whisper-1'
            )
        return self._transcription_model

    # Deprecated accessors: transcription is now handled by TrackedOpenAITranscription
    @property
    def transcription_client(self):
        raise DeprecationWarning("Use transcription_model instead")

    def transcribe(self, *args, **kwargs):
        raise DeprecationWarning("Use transcription_model.transcribe() instead")

    def get_template(self, template_name: str, version: Optional[str] = None) -> str:
        """
        Get a template for the tenant's configured LLM

        Args:
            template_name: Name of the template to retrieve
            version: Optional specific version to retrieve

        Returns:
            The template content
        """
        try:
            template = template_manager.get_template(
                self._variables['llm_full_model'],
                template_name,
                version
            )
            return template.content
        except Exception as e:
            current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
            # Fall back to old template loading if template_manager fails
            if template_name in self._variables.get('templates', {}):
                return self._variables['templates'][template_name]
            raise
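

# Note on expected shapes (illustrative values, not the project's real config):
# Tenant.llm_model and Tenant.embedding_model are '<provider>.<model>' strings,
# split on the first '.' by _initialize_variables(), and MODEL_CONFIG is keyed
# the same way, e.g.:
#
#     MODEL_CONFIG = {
#         'openai': {'gpt-4o': {'context_window': 128000}},  # hypothetical entry
#         'anthropic': {'claude-3-5-sonnet': {'context_window': 200000}},  # hypothetical entry
#     }
#
# get_llm() keeps one client per (temperature, kwargs) combination, so repeated
# calls with identical settings reuse the same instance.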


class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
    handler_name = 'model_vars_cache'  # Used to access the handler instance from cache_manager

    def __init__(self, region):
        super().__init__(region, 'model_variables')
        self.configure_keys('tenant_id')
        self.subscribe_to_model('Tenant', ['tenant_id'])

    def to_cache_data(self, instance: ModelVariables) -> Dict[str, Any]:
        return {
            'tenant_id': instance.tenant_id,
            'variables': instance._variables,
            'last_updated': dt.now(tz=tz.utc).isoformat()
        }

    def from_cache_data(self, data: Dict[str, Any], tenant_id: int, **kwargs) -> ModelVariables:
        return ModelVariables(tenant_id, data.get('variables'))

    def should_cache(self, value: Dict[str, Any]) -> bool:
        required_fields = {'tenant_id', 'variables'}
        return all(field in value for field in required_fields)


# Register the handler with the cache manager
cache_manager.register_handler(ModelVariablesCacheHandler, 'eveai_model')


# Helper function to get cached model variables
def get_model_variables(tenant_id: int) -> ModelVariables:
    return cache_manager.model_vars_cache.get(
        lambda tenant_id: ModelVariables(tenant_id),  # creator, called on cache miss
        tenant_id=tenant_id
    )


# The same helper written in long form, without the lambda:
#
# def get_model_variables(tenant_id: int) -> ModelVariables:
#     """
#     Get a ModelVariables instance, either from cache or newly created
#
#     Args:
#         tenant_id: The tenant's ID
#
#     Returns:
#         ModelVariables: Instance with either cached or fresh data
#
#     Raises:
#         EveAITenantNotFound: If the tenant doesn't exist
#         CacheStateError: If cached data is invalid
#     """
#
#     def create_new_instance(tenant_id: int) -> ModelVariables:
#         """Creator function that's called when a cache miss occurs"""
#         return ModelVariables(tenant_id)  # This will initialize fresh variables
#
#     return cache_manager.model_vars_cache.get(
#         create_new_instance,  # Function to create a new instance if needed
#         tenant_id=tenant_id  # Parameters passed to both get() and create_new_instance
#     )
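

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API. Assumes an
    # application factory `create_app` in `app` (hypothetical import), a
    # seeded tenant with id 1, and OPENAI_API_KEY / ANTHROPIC_API_KEY in the
    # environment; adjust for your project layout.
    from app import create_app  # hypothetical

    app = create_app()
    with app.app_context():
        model_vars = get_model_variables(tenant_id=1)
        llm = model_vars.get_llm(temperature=0.2)
        print(f'LLM client: {type(llm).__name__}')
        print(f'Embedding table: {model_vars.embedding_model_class.__name__}')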