import os
import langcodes
from flask import current_app
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Any, Iterator
from collections.abc import MutableMapping
from openai import OpenAI
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler

from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.langchain.tracked_transcribe import tracked_transcribe
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
from common.models.user import Tenant
from config.model_config import MODEL_CONFIG
from common.utils.business_event_context import current_event


# Structured-output schema for answers with citations. The class docstring is
# a placeholder that set_language_prompt_template() overwrites at runtime, so
# it is deliberately left untouched here.
class CitedAnswer(BaseModel):
    """Default docstring - to be replaced with actual prompt"""
    answer: str = Field(
        ...,
        description="The answer to the user question, based on the given sources",
    )
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
    )
    insufficient_info: bool = Field(
        False,  # Default value is set to False
        description="A boolean indicating wether given sources were sufficient or not to generate the answer"
    )


def set_language_prompt_template(cls, language_prompt):
    """Replace the docstring of ``cls`` with ``language_prompt``.

    Args:
        cls: The class whose ``__doc__`` is overwritten (e.g. ``CitedAnswer``).
        language_prompt: The text to install as the new docstring.
    """
    cls.__doc__ = language_prompt


class ModelVariables(MutableMapping):
    """Dict-like container of model settings for a tenant (and optional catalog).

    Plain settings are stored in an internal dict built once at construction
    time. A few special keys (``embedding_model``, ``embedding_db_model``,
    ``llm``, ``llm_no_rag``, ``transcription_client``) are resolved through
    properties instead of the dict; see ``__getitem__``.
    """

    def __init__(self, tenant: Tenant, catalog_id=None):
        """Store the tenant/catalog and build the initial variable dict.

        Args:
            tenant: Tenant whose model settings are read.
            catalog_id: Optional catalog id; when truthy, catalog-specific
                chunking settings are loaded as well.
        """
        self.tenant = tenant
        self.catalog_id = catalog_id
        self._variables = self._initialize_variables()
        self._embedding_model = None
        self._llm = None
        self._llm_no_rag = None
        self._transcription_client = None
        self._prompt_templates = {}
        self._embedding_db_model = None
        self.llm_metrics_handler = LLMMetricsHandler()

    def _initialize_variables(self):
        """Build and return the dict of variables for this tenant/catalog.

        Reads the catalog (if ``catalog_id`` is set), the tenant's model
        identifiers, ``MODEL_CONFIG`` and several Flask config entries.
        """
        variables = {}

        # Get the Catalog if catalog_id is passed
        if self.catalog_id:
            catalog = Catalog.query.get_or_404(self.catalog_id)

            # We initialize the variables that are available knowing the tenant.
            variables['embed_tuning'] = catalog.embed_tuning or False

            # Set HTML Chunking Variables
            variables['html_tags'] = catalog.html_tags
            variables['html_end_tags'] = catalog.html_end_tags
            variables['html_included_elements'] = catalog.html_included_elements
            variables['html_excluded_elements'] = catalog.html_excluded_elements
            variables['html_excluded_classes'] = catalog.html_excluded_classes

            # Set Chunk Size variables
            variables['min_chunk_size'] = catalog.min_chunk_size
            variables['max_chunk_size'] = catalog.max_chunk_size

        # Set the RAG Context (will have to change once specialists are defined)
        variables['rag_context'] = self.tenant.rag_context or " "

        # Temporary setting until we have Specialists
        variables['rag_tuning'] = False
        variables['RAG_temperature'] = 0.3
        variables['no_RAG_temperature'] = 0.5
        variables['k'] = 8
        variables['similarity_threshold'] = 0.4

        # Set model providers: identifiers are split on the LAST dot into
        # (provider, model).
        variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
        variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
        variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
                                                                         f"{variables['llm_model']}")]
        current_app.logger.info("Loaded prompt templates: \n")
        current_app.logger.info(f"{variables['templates']}")

        # Set model-specific configurations
        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
        variables.update(model_config)

        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]

        # .get(): MODEL_CONFIG may have no entry for this model, in which case
        # the flag is simply absent and tool calling is treated as unsupported.
        if variables.get('tool_calling_supported'):
            variables['cited_answer_cls'] = CitedAnswer

        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']

        return variables

    @property
    def embedding_model(self):
        """Create and return a ``TrackedOpenAIEmbeddings`` for the configured model.

        A new instance is built on every access. As a side effect the matching
        embedding DB model class is stored on ``_embedding_db_model``.
        """
        api_key = os.getenv('OPENAI_API_KEY')
        model = self._variables['embedding_model']
        self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
                                                        model=model,
                                                        )
        self._embedding_db_model = EmbeddingSmallOpenAI \
            if model == 'text-embedding-3-small' \
            else EmbeddingLargeOpenAI
        return self._embedding_model

    def _build_chat_model(self, temperature):
        """Build a chat model for the configured LLM with the given temperature."""
        # NOTE(review): ChatOpenAI is used regardless of llm_provider, while
        # get_api_key_for_llm() returns the Anthropic key for non-OpenAI
        # providers - confirm whether ChatAnthropic should be used there.
        return ChatOpenAI(api_key=self.get_api_key_for_llm(),
                          model=self._variables['llm_model'],
                          temperature=temperature,
                          callbacks=[self.llm_metrics_handler])

    @property
    def llm(self):
        """Create and return the chat model used for RAG (``RAG_temperature``).

        A new instance is built on every access, so changes to the stored
        temperature are picked up.
        """
        self._llm = self._build_chat_model(self._variables['RAG_temperature'])
        return self._llm

    @property
    def llm_no_rag(self):
        """Create and return the chat model used without RAG.

        Uses ``no_RAG_temperature`` (previously ``RAG_temperature`` was used
        here, which made the separate setting ineffective).
        """
        self._llm_no_rag = self._build_chat_model(self._variables['no_RAG_temperature'])
        return self._llm_no_rag

    def get_api_key_for_llm(self):
        """Return the API key from the environment for the configured LLM provider.

        ``OPENAI_API_KEY`` for provider ``'openai'``; ``ANTHROPIC_API_KEY`` for
        any other provider.
        """
        if self._variables['llm_provider'] == 'openai':
            api_key = os.getenv('OPENAI_API_KEY')
        else:  # self._variables['llm_provider'] == 'anthropic'
            api_key = os.getenv('ANTHROPIC_API_KEY')

        return api_key

    @property
    def transcription_client(self):
        """Create and return an ``OpenAI`` client for transcription.

        Also records ``'whisper-1'`` under the ``transcription_model`` variable.
        """
        api_key = os.getenv('OPENAI_API_KEY')
        self._transcription_client = OpenAI(api_key=api_key, )
        self._variables['transcription_model'] = 'whisper-1'
        return self._transcription_client

    def transcribe(self, *args, **kwargs):
        """Forward to ``tracked_transcribe`` using the transcription client.

        The client is created on demand, so this works even when the
        ``transcription_client`` property has not been accessed before.
        """
        client = self._transcription_client
        if client is None:
            client = self.transcription_client
        return tracked_transcribe(client, *args, **kwargs)

    @property
    def embedding_db_model(self):
        """Return the embedding DB model class, resolving it on first use."""
        if self._embedding_db_model is None:
            self._embedding_db_model = self.get_embedding_db_model()
        return self._embedding_db_model

    def get_embedding_db_model(self):
        """Return the embedding DB model class matching ``embedding_model``.

        ``EmbeddingSmallOpenAI`` for ``'text-embedding-3-small'``, otherwise
        ``EmbeddingLargeOpenAI``. The result is cached on the instance.
        """
        current_app.logger.debug("In get_embedding_db_model")
        if self._embedding_db_model is None:
            self._embedding_db_model = EmbeddingSmallOpenAI \
                if self._variables['embedding_model'] == 'text-embedding-3-small' \
                else EmbeddingLargeOpenAI
        current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
        return self._embedding_db_model

    def get_prompt_template(self, template_name: str) -> str:
        """Return the prompt template ``template_name``, caching it per instance."""
        current_app.logger.info(f"Getting prompt template for {template_name}")
        if template_name not in self._prompt_templates:
            self._prompt_templates[template_name] = self._load_prompt_template(template_name)
        return self._prompt_templates[template_name]

    def _load_prompt_template(self, template_name: str) -> str:
        """Look up ``template_name`` in the loaded ``templates`` variable."""
        # In the future, this method will make an API call to Portkey
        # For now, we'll simulate it with a placeholder implementation
        # You can replace this with your current prompt loading logic
        return self._variables['templates'][template_name]

    def __getitem__(self, key: str) -> Any:
        """Resolve ``key`` to a model object, prompt template or stored variable.

        A trailing ``_template`` is stripped from the key first. Special keys
        are served by the corresponding properties; everything else comes from
        the internal dict.

        Raises:
            KeyError: If the key is absent or its stored value is ``None``.
        """
        current_app.logger.debug(f"ModelVariables: Getting {key}")
        # Support older template names (suffix = _template)
        if key.endswith('_template'):
            key = key[:-len('_template')]
            current_app.logger.debug(f"ModelVariables: Getting modified {key}")
        if key == 'embedding_model':
            return self.embedding_model
        elif key == 'embedding_db_model':
            return self.embedding_db_model
        elif key == 'llm':
            return self.llm
        elif key == 'llm_no_rag':
            return self.llm_no_rag
        elif key == 'transcription_client':
            return self.transcription_client
        elif key in self._variables.get('prompt_templates', []):
            return self.get_prompt_template(key)
        else:
            value = self._variables.get(key)
            if value is not None:
                return value
            else:
                raise KeyError(f'Variable {key} does not exist in ModelVariables')

    def __setitem__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the internal dict."""
        self._variables[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove ``key`` from the internal dict."""
        del self._variables[key]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys of the internal dict."""
        return iter(self._variables)

    def __len__(self) -> int:
        """Return the number of entries in the internal dict."""
        return len(self._variables)

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``self[key]``, or ``default`` when the lookup raises ``KeyError``.

        Falsy values such as ``False`` or ``0`` are returned as stored rather
        than being replaced by ``default``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def update(self, *args, **kwargs) -> None:
        """Update the internal dict, accepting the same arguments as ``dict.update``."""
        self._variables.update(*args, **kwargs)

    def items(self):
        """Return the items view of the internal dict."""
        return self._variables.items()

    def keys(self):
        """Return the keys view of the internal dict."""
        return self._variables.keys()

    def values(self):
        """Return the values view of the internal dict."""
        return self._variables.values()


def select_model_variables(tenant, catalog_id=None):
    """Create and return a ``ModelVariables`` for ``tenant`` and ``catalog_id``."""
    return ModelVariables(tenant=tenant, catalog_id=catalog_id)


def create_language_template(template, language):
    """Return ``template`` with ``{language}`` replaced by a language name.

    The display name from ``langcodes`` is used; if that raises ``ValueError``
    the raw ``language`` value is substituted instead.
    """
    try:
        full_language = langcodes.Language.make(language=language)
        language_template = template.replace('{language}', full_language.display_name())
    except ValueError:
        language_template = template.replace('{language}', language)

    return language_template


def replace_variable_in_template(template, variable, value):
    """Return ``template`` with every occurrence of ``variable`` replaced by ``value``."""
    return template.replace(variable, value)