import os

import langcodes
from flask import current_app
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Any, Iterator
from collections.abc import MutableMapping
from openai import OpenAI
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL

from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI
from common.models.user import Tenant
from config.model_config import MODEL_CONFIG


class CitedAnswer(BaseModel):
    """Default docstring - to be replaced with the actual prompt."""

    answer: str = Field(
        ...,
        description="The answer to the user question, based on the given sources",
    )
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources that were used to generate the answer",
    )
    insufficient_info: bool = Field(
        False,  # defaults to False
        description="A boolean indicating whether the given sources were sufficient to generate the answer",
    )

    @classmethod
    def set_language_prompt_template(cls, language_prompt):
        cls.__doc__ = language_prompt


class ModelVariables(MutableMapping):
    def __init__(self, tenant: Tenant):
        self.tenant = tenant
        self._variables = self._initialize_variables()
        self._embedding_model = None
        self._llm = None
        self._llm_no_rag = None
        self._transcription_client = None
        self._prompt_templates = {}
        self._embedding_db_model = None

    def _initialize_variables(self):
        variables = {}

        # Initialize the variables that are available from the tenant alone;
        # the rest are lazy-loaded on first access.
        variables['k'] = self.tenant.es_k or 5
        variables['similarity_threshold'] = self.tenant.es_similarity_threshold or 0.7
        variables['RAG_temperature'] = self.tenant.chat_RAG_temperature or 0.3
        variables['no_RAG_temperature'] = self.tenant.chat_no_RAG_temperature or 0.5
        variables['embed_tuning'] = self.tenant.embed_tuning or False
        variables['rag_tuning'] = self.tenant.rag_tuning or False
        variables['rag_context'] = self.tenant.rag_context or " "

        # HTML chunking variables
        variables['html_tags'] = self.tenant.html_tags
        variables['html_end_tags'] = self.tenant.html_end_tags
        variables['html_included_elements'] = self.tenant.html_included_elements
        variables['html_excluded_elements'] = self.tenant.html_excluded_elements
        variables['html_excluded_classes'] = self.tenant.html_excluded_classes

        # Chunk size variables
        variables['min_chunk_size'] = self.tenant.min_chunk_size
        variables['max_chunk_size'] = self.tenant.max_chunk_size

        # Model providers, stored as "<provider>.<model>", e.g. "openai.gpt-4o"
        variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
        variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)

        variables['templates'] = current_app.config['PROMPT_TEMPLATES'][
f"{variables['llm_model']}")] current_app.logger.info(f"Loaded prompt templates: \n") current_app.logger.info(f"{variables['templates']}") # Set model-specific configurations model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {}) variables.update(model_config) variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model] if variables['tool_calling_supported']: variables['cited_answer_cls'] = CitedAnswer return variables @property def embedding_model(self): if self._embedding_model is None: environment = os.getenv('FLASK_ENV', 'development') portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment} if self._variables['embedding_provider'] == 'openai': portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'), provider='openai', metadata=portkey_metadata) api_key = os.getenv('OPENAI_API_KEY') model = self._variables['embedding_model'] self._embedding_model = OpenAIEmbeddings(api_key=api_key, model=model, base_url=PORTKEY_GATEWAY_URL, default_headers=portkey_headers) self._embedding_db_model = EmbeddingSmallOpenAI \ if model == 'text-embedding-3-small' \ else EmbeddingLargeOpenAI else: raise ValueError(f"Invalid embedding provider: {self._variables['embedding_provider']}") return self._embedding_model @property def llm(self): if self._llm is None: self._initialize_llm() return self._llm @property def llm_no_rag(self): if self._llm_no_rag is None: self._initialize_llm() return self._llm_no_rag def _initialize_llm(self): environment = os.getenv('FLASK_ENV', 'development') portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment} if self._variables['llm_provider'] == 'openai': portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'), metadata=portkey_metadata, provider='openai') api_key = os.getenv('OPENAI_API_KEY') self._llm = ChatOpenAI(api_key=api_key, model=self._variables['llm_model'], temperature=self._variables['RAG_temperature'], base_url=PORTKEY_GATEWAY_URL, default_headers=portkey_headers) self._llm_no_rag = ChatOpenAI(api_key=api_key, model=self._variables['llm_model'], temperature=self._variables['no_RAG_temperature'], base_url=PORTKEY_GATEWAY_URL, default_headers=portkey_headers) self._variables['tool_calling_supported'] = self._variables['llm_model'] in ['gpt-4o', 'gpt-4o-mini'] elif self._variables['llm_provider'] == 'anthropic': api_key = os.getenv('ANTHROPIC_API_KEY') llm_model_ext = os.getenv('ANTHROPIC_LLM_VERSIONS', {}).get(self._variables['llm_model']) self._llm = ChatAnthropic(api_key=api_key, model=llm_model_ext, temperature=self._variables['RAG_temperature']) self._llm_no_rag = ChatAnthropic(api_key=api_key, model=llm_model_ext, temperature=self._variables['RAG_temperature']) self._variables['tool_calling_supported'] = True else: raise ValueError(f"Invalid chat provider: {self._variables['llm_provider']}") @property def transcription_client(self): if self._transcription_client is None: environment = os.getenv('FLASK_ENV', 'development') portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment} portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'), metadata=portkey_metadata, provider='openai') api_key = os.getenv('OPENAI_API_KEY') self._transcription_client = OpenAI(api_key=api_key, base_url=PORTKEY_GATEWAY_URL, default_headers=portkey_headers) self._variables['transcription_model'] = 'whisper-1' return self._transcription_client @property def embedding_db_model(self): if 
    @property
    def embedding_db_model(self):
        if self._embedding_db_model is None:
            self._embedding_db_model = self.get_embedding_db_model()
        return self._embedding_db_model

    def get_embedding_db_model(self):
        current_app.logger.debug("In get_embedding_db_model")
        if self._embedding_db_model is None:
            self._embedding_db_model = EmbeddingSmallOpenAI \
                if self._variables['embedding_model'] == 'text-embedding-3-small' \
                else EmbeddingLargeOpenAI
        current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
        return self._embedding_db_model

    def get_prompt_template(self, template_name: str) -> str:
        current_app.logger.info(f"Getting prompt template for {template_name}")
        if template_name not in self._prompt_templates:
            self._prompt_templates[template_name] = self._load_prompt_template(template_name)
        return self._prompt_templates[template_name]

    def _load_prompt_template(self, template_name: str) -> str:
        # In the future, this method will make an API call to Portkey.
        # For now, templates are read from the config loaded at startup.
        return self._variables['templates'][template_name]

    def __getitem__(self, key: str) -> Any:
        current_app.logger.debug(f"ModelVariables: Getting {key}")
        # Support older template names (suffixed with '_template')
        if key.endswith('_template'):
            key = key[:-len('_template')]
            current_app.logger.debug(f"ModelVariables: Getting modified {key}")
        if key == 'embedding_model':
            return self.embedding_model
        elif key == 'embedding_db_model':
            return self.embedding_db_model
        elif key == 'llm':
            return self.llm
        elif key == 'llm_no_rag':
            return self.llm_no_rag
        elif key == 'transcription_client':
            return self.transcription_client
        elif key in self._variables.get('prompt_templates', []):
            return self.get_prompt_template(key)
        return self._variables.get(key)

    def __setitem__(self, key: str, value: Any) -> None:
        self._variables[key] = value

    def __delitem__(self, key: str) -> None:
        del self._variables[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._variables)

    def __len__(self) -> int:
        return len(self._variables)

    def get(self, key: str, default: Any = None) -> Any:
        # Fall back to the default only when the key resolves to None, so that
        # legitimate falsy values (False, 0, "") are returned as stored.
        value = self.__getitem__(key)
        return value if value is not None else default

    def update(self, **kwargs) -> None:
        self._variables.update(kwargs)

    def items(self):
        return self._variables.items()

    def keys(self):
        return self._variables.keys()

    def values(self):
        return self._variables.values()


def select_model_variables(tenant):
    model_variables = ModelVariables(tenant=tenant)
    return model_variables


def create_language_template(template, language):
    # Replace the '{language}' placeholder with a human-readable language name
    # (e.g. 'fr' -> 'French'); fall back to the raw code if it cannot be parsed.
    try:
        full_language = langcodes.Language.make(language=language)
        language_template = template.replace('{language}', full_language.display_name())
    except ValueError:
        language_template = template.replace('{language}', language)
    return language_template


def replace_variable_in_template(template, variable, value):
    return template.replace(variable, value)
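
# --- Usage sketch (illustrative only) ---
# A minimal example of how this module is typically consumed, assuming a Flask
# app context and a Tenant row loaded elsewhere; `some_tenant` and the template
# name 'rag' are hypothetical placeholders, not part of this module's API.
#
#     model_variables = select_model_variables(some_tenant)
#     k = model_variables['k']                         # plain tenant-derived value
#     llm = model_variables['llm']                     # lazily builds the chat model
#     embeddings = model_variables['embedding_model']  # lazily builds the embedding client
#     prompt = model_variables.get_prompt_template('rag')
#     prompt = create_language_template(prompt, 'fr')  # '{language}' -> 'French'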