- turned model_variables into a class with lazy loading

- some improvements to Healthchecks
Josako committed 2024-09-24 10:48:52 +02:00
parent 67bdeac434 · commit a740c96630
16 changed files with 382 additions and 191 deletions
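In short: model_variables used to be a plain dict built eagerly in select_model_variables(); it is now a MutableMapping subclass that computes the cheap tenant-derived settings up front and defers the expensive objects (embedding model, LLMs, transcription client) until first access. A minimal sketch of the lazy-loading pattern follows; it is an illustration, not the committed class, and build_chat_client is a placeholder that does not exist in this commit:

    from collections.abc import MutableMapping
    from typing import Any, Iterator


    def build_chat_client(tenant):
        """Placeholder for an expensive client constructor (not part of this commit)."""
        return object()


    class LazyVariables(MutableMapping):
        """Dict-like settings object whose expensive entries are built on first access."""

        def __init__(self, tenant):
            self.tenant = tenant
            self._variables = {'k': getattr(tenant, 'es_k', None) or 5}  # cheap, eager
            self._llm = None  # expensive, lazy

        @property
        def llm(self):
            if self._llm is None:  # built once, then cached
                self._llm = build_chat_client(self.tenant)
            return self._llm

        def __getitem__(self, key: str) -> Any:
            if key == 'llm':  # route expensive keys through the lazy property
                return self.llm
            return self._variables.get(key)

        def __setitem__(self, key: str, value: Any) -> None:
            self._variables[key] = value

        def __delitem__(self, key: str) -> None:
            del self._variables[key]

        def __iter__(self) -> Iterator[str]:
            return iter(self._variables)

        def __len__(self) -> int:
            return len(self._variables)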


@@ -1,23 +1,31 @@
 from langchain_core.retrievers import BaseRetriever
 from sqlalchemy import asc
 from sqlalchemy.exc import SQLAlchemyError
-from pydantic import BaseModel, Field
+from pydantic import Field, BaseModel, PrivateAttr
 from typing import Any, Dict
 from flask import current_app
 from common.extensions import db
 from common.models.interaction import ChatSession, Interaction
 from common.utils.datetime_utils import get_date_in_timezone
+from common.utils.model_utils import ModelVariables


-class EveAIHistoryRetriever(BaseRetriever):
-    model_variables: Dict[str, Any] = Field(...)
-    session_id: str = Field(...)
+class EveAIHistoryRetriever(BaseRetriever, BaseModel):
+    _model_variables: ModelVariables = PrivateAttr()
+    _session_id: str = PrivateAttr()

-    def __init__(self, model_variables: Dict[str, Any], session_id: str):
+    def __init__(self, model_variables: ModelVariables, session_id: str):
         super().__init__()
-        self.model_variables = model_variables
-        self.session_id = session_id
+        self._model_variables = model_variables
+        self._session_id = session_id
+
+    @property
+    def model_variables(self) -> ModelVariables:
+        return self._model_variables
+
+    @property
+    def session_id(self) -> str:
+        return self._session_id

     def _get_relevant_documents(self, query: str):
         current_app.logger.debug(f'Retrieving history of interactions for query: {query}')


@@ -1,30 +1,39 @@
 from langchain_core.retrievers import BaseRetriever
 from sqlalchemy import func, and_, or_, desc
 from sqlalchemy.exc import SQLAlchemyError
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PrivateAttr
 from typing import Any, Dict
 from flask import current_app
 from common.extensions import db
 from common.models.document import Document, DocumentVersion
 from common.utils.datetime_utils import get_date_in_timezone
+from common.utils.model_utils import ModelVariables


-class EveAIRetriever(BaseRetriever):
-    model_variables: Dict[str, Any] = Field(...)
-    tenant_info: Dict[str, Any] = Field(...)
+class EveAIRetriever(BaseRetriever, BaseModel):
+    _model_variables: ModelVariables = PrivateAttr()
+    _tenant_info: Dict[str, Any] = PrivateAttr()

-    def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
+    def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
         super().__init__()
-        self.model_variables = model_variables
-        self.tenant_info = tenant_info
+        current_app.logger.debug(f'Model variables type: {type(model_variables)}')
+        self._model_variables = model_variables
+        self._tenant_info = tenant_info
+
+    @property
+    def model_variables(self) -> ModelVariables:
+        return self._model_variables
+
+    @property
+    def tenant_info(self) -> Dict[str, Any]:
+        return self._tenant_info

     def _get_relevant_documents(self, query: str):
         current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
         query_embedding = self._get_query_embedding(query)
+        current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
+        current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
         db_class = self.model_variables['embedding_db_model']
         similarity_threshold = self.model_variables['similarity_threshold']
         k = self.model_variables['k']
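Both retrievers now mix in BaseModel and hold their state in PrivateAttr slots exposed through read-only properties. A plausible motivation (the commit message does not say) is that ModelVariables is no longer a plain Dict[str, Any] that pydantic can validate as a declared field; private attributes are exempt from validation. The pattern in isolation, with a hypothetical payload:

    from typing import Any
    from pydantic import BaseModel, PrivateAttr


    class Wrapper(BaseModel):
        # Private attrs are not validated or serialized by pydantic.
        _payload: Any = PrivateAttr()

        def __init__(self, payload: Any):
            super().__init__()       # no declared fields to validate
            self._payload = payload  # assigned after pydantic init

        @property
        def payload(self) -> Any:
            return self._payload


    w = Wrapper(payload={'k': 5})
    print(w.payload)  # {'k': 5}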


@@ -0,0 +1,28 @@
+from common.extensions import db
+from sqlalchemy.dialects.postgresql import JSONB
+import sqlalchemy as sa
+
+
+class LLMUsageMetric(db.Model):
+    __bind_key__ = 'public'
+    __table_args__ = {'schema': 'public'}
+
+    id = db.Column(db.Integer, primary_key=True)
+    tenant_id = db.Column(db.Integer, nullable=False)
+    environment = db.Column(db.String(20), nullable=False)
+    activity = db.Column(db.String(20), nullable=False)
+    sub_activity = db.Column(db.String(20), nullable=False)
+    activity_detail = db.Column(db.String(50), nullable=True)
+    session_id = db.Column(db.String(50), nullable=True)  # Chat Session ID
+    interaction_id = db.Column(db.Integer, nullable=True)  # Chat Interaction ID
+    document_version_id = db.Column(db.Integer, nullable=True)
+    prompt_tokens = db.Column(db.Integer, nullable=True)
+    completion_tokens = db.Column(db.Integer, nullable=True)
+    total_tokens = db.Column(db.Integer, nullable=True)
+    cost = db.Column(db.Float, nullable=True)
+    latency = db.Column(db.Float, nullable=True)
+    model_name = db.Column(db.String(50), nullable=False)
+    timestamp = db.Column(db.DateTime, nullable=False)
+    additional_info = db.Column(JSONB, nullable=True)
+
+    # Add any additional fields or methods as needed
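For illustration only (this diff adds the model but no write path), a usage row would be recorded with the usual Flask-SQLAlchemy session. The import path and all values below are assumptions:

    from datetime import datetime, timezone

    from common.extensions import db
    from common.models.usage import LLMUsageMetric  # assumed module path

    metric = LLMUsageMetric(
        tenant_id=1,                    # illustrative values throughout
        environment='development',
        activity='chat',
        sub_activity='rag',
        model_name='gpt-4o-mini',
        prompt_tokens=512,
        completion_tokens=128,
        total_tokens=640,
        cost=0.0004,
        latency=1.2,
        timestamp=datetime.now(timezone.utc),
    )
    db.session.add(metric)
    db.session.commit()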


@@ -2,7 +2,6 @@ from common.extensions import db
 from flask_security import UserMixin, RoleMixin
 from sqlalchemy.dialects.postgresql import ARRAY
 import sqlalchemy as sa
-from sqlalchemy import CheckConstraint


 class Tenant(db.Model):


@@ -5,14 +5,14 @@ from flask import current_app
 from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_core.pydantic_v1 import BaseModel, Field
-from langchain.prompts import ChatPromptTemplate
-import ast
-from typing import List
+from typing import List, Any, Iterator
+from collections.abc import MutableMapping
 from openai import OpenAI
 # from groq import Groq
 from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
 from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI
+from common.models.user import Tenant
+from config.model_config import MODEL_CONFIG


 class CitedAnswer(BaseModel):
@@ -36,180 +36,221 @@ def set_language_prompt_template(cls, language_prompt):
     cls.__doc__ = language_prompt


-def select_model_variables(tenant):
-    embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
-    embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
-    llm_provider = tenant.llm_model.rsplit('.', 1)[0]
-    llm_model = tenant.llm_model.rsplit('.', 1)[1]
-
-    # Set model variables
-    model_variables = {}
-    if tenant.es_k:
-        model_variables['k'] = tenant.es_k
-    else:
-        model_variables['k'] = 5
-    if tenant.es_similarity_threshold:
-        model_variables['similarity_threshold'] = tenant.es_similarity_threshold
-    else:
-        model_variables['similarity_threshold'] = 0.7
-    if tenant.chat_RAG_temperature:
-        model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
-    else:
-        model_variables['RAG_temperature'] = 0.3
-    if tenant.chat_no_RAG_temperature:
-        model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
-    else:
-        model_variables['no_RAG_temperature'] = 0.5
-
-    # Set Tuning variables
-    if tenant.embed_tuning:
-        model_variables['embed_tuning'] = tenant.embed_tuning
-    else:
-        model_variables['embed_tuning'] = False
-    if tenant.rag_tuning:
-        model_variables['rag_tuning'] = tenant.rag_tuning
-    else:
-        model_variables['rag_tuning'] = False
-    if tenant.rag_context:
-        model_variables['rag_context'] = tenant.rag_context
-    else:
-        model_variables['rag_context'] = " "
-
-    # Set HTML Chunking Variables
-    model_variables['html_tags'] = tenant.html_tags
-    model_variables['html_end_tags'] = tenant.html_end_tags
-    model_variables['html_included_elements'] = tenant.html_included_elements
-    model_variables['html_excluded_elements'] = tenant.html_excluded_elements
-    model_variables['html_excluded_classes'] = tenant.html_excluded_classes
-
-    # Set Chunk Size variables
-    model_variables['min_chunk_size'] = tenant.min_chunk_size
-    model_variables['max_chunk_size'] = tenant.max_chunk_size
-
-    environment = os.getenv('FLASK_ENV', 'development')
-    portkey_metadata = {'tenant_id': str(tenant.id), 'environment': environment}
-
-    # Set Embedding variables
-    match embedding_provider:
-        case 'openai':
-            portkey_headers = createHeaders(api_key=current_app.config.get('PORTKEY_API_KEY'),
-                                            provider='openai',
-                                            metadata=portkey_metadata)
-            match embedding_model:
-                case 'text-embedding-3-small':
-                    api_key = current_app.config.get('OPENAI_API_KEY')
-                    model_variables['embedding_model'] = OpenAIEmbeddings(api_key=api_key,
-                                                                          model='text-embedding-3-small',
-                                                                          base_url=PORTKEY_GATEWAY_URL,
-                                                                          default_headers=portkey_headers
-                                                                          )
-                    model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
-                case 'text-embedding-3-large':
-                    api_key = current_app.config.get('OPENAI_API_KEY')
-                    model_variables['embedding_model'] = OpenAIEmbeddings(api_key=api_key,
-                                                                          model='text-embedding-3-large',
-                                                                          base_url=PORTKEY_GATEWAY_URL,
-                                                                          default_headers=portkey_headers
-                                                                          )
-                    model_variables['embedding_db_model'] = EmbeddingLargeOpenAI
-                case _:
-                    raise Exception(f'Error setting model variables for tenant {tenant.id} '
-                                    f'error: Invalid embedding model')
-        case _:
-            raise Exception(f'Error setting model variables for tenant {tenant.id} '
-                            f'error: Invalid embedding provider')
-
-    # Set Chat model variables
-    match llm_provider:
-        case 'openai':
-            portkey_headers = createHeaders(api_key=current_app.config.get('PORTKEY_API_KEY'),
-                                            metadata=portkey_metadata,
-                                            provider='openai')
-            api_key = current_app.config.get('OPENAI_API_KEY')
-            model_variables['llm'] = ChatOpenAI(api_key=api_key,
-                                                model=llm_model,
-                                                temperature=model_variables['RAG_temperature'],
-                                                base_url=PORTKEY_GATEWAY_URL,
-                                                default_headers=portkey_headers)
-            model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
-                                                       model=llm_model,
-                                                       temperature=model_variables['no_RAG_temperature'],
-                                                       base_url=PORTKEY_GATEWAY_URL,
-                                                       default_headers=portkey_headers)
-            tool_calling_supported = False
-            match llm_model:
-                case 'gpt-4o' | 'gpt-4o-mini':
-                    tool_calling_supported = True
-                    processing_chunk_size = 10000
-                    processing_chunk_overlap = 200
-                    processing_min_chunk_size = 8000
-                    processing_max_chunk_size = 12000
-                case _:
-                    raise Exception(f'Error setting model variables for tenant {tenant.id} '
-                                    f'error: Invalid chat model')
-        case 'anthropic':
-            api_key = current_app.config.get('ANTHROPIC_API_KEY')
-            # Anthropic does not have the same 'generic' model names as OpenAI
-            llm_model_ext = current_app.config.get('ANTHROPIC_LLM_VERSIONS').get(llm_model)
-            model_variables['llm'] = ChatAnthropic(api_key=api_key,
-                                                   model=llm_model_ext,
-                                                   temperature=model_variables['RAG_temperature'])
-            model_variables['llm_no_rag'] = ChatAnthropic(api_key=api_key,
-                                                          model=llm_model_ext,
-                                                          temperature=model_variables['RAG_temperature'])
-            tool_calling_supported = True
-            processing_chunk_size = 10000
-            processing_chunk_overlap = 200
-            processing_min_chunk_size = 8000
-            processing_max_chunk_size = 12000
-        case _:
-            raise Exception(f'Error setting model variables for tenant {tenant.id} '
-                            f'error: Invalid chat provider')
-
-    model_variables['processing_chunk_size'] = processing_chunk_size
-    model_variables['processing_chunk_overlap'] = processing_chunk_overlap
-    model_variables['processing_min_chunk_size'] = processing_min_chunk_size
-    model_variables['processing_max_chunk_size'] = processing_max_chunk_size
-
-    if tool_calling_supported:
-        model_variables['cited_answer_cls'] = CitedAnswer
-
-    templates = current_app.config['PROMPT_TEMPLATES'][f'{llm_provider}.{llm_model}']
-    model_variables['summary_template'] = templates['summary']
-    model_variables['rag_template'] = templates['rag']
-    model_variables['history_template'] = templates['history']
-    model_variables['encyclopedia_template'] = templates['encyclopedia']
-    model_variables['transcript_template'] = templates['transcript']
-    model_variables['html_parse_template'] = templates['html_parse']
-    model_variables['pdf_parse_template'] = templates['pdf_parse']
-
-    model_variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
-
-    # Transcription Client Variables.
-    # Using Groq
-    # api_key = current_app.config.get('GROQ_API_KEY')
-    # model_variables['transcription_client'] = Groq(api_key=api_key)
-    # model_variables['transcription_model'] = 'whisper-large-v3'
-
-    # Using OpenAI for transcriptions
-    portkey_metadata = {'tenant_id': str(tenant.id)}
-    portkey_headers = createHeaders(api_key=current_app.config.get('PORTKEY_API_KEY'),
-                                    metadata=portkey_metadata,
-                                    provider='openai'
-                                    )
-    api_key = current_app.config.get('OPENAI_API_KEY')
-    model_variables['transcription_client'] = OpenAI(api_key=api_key,
-                                                     base_url=PORTKEY_GATEWAY_URL,
-                                                     default_headers=portkey_headers)
-    model_variables['transcription_model'] = 'whisper-1'
-
-    return model_variables
+class ModelVariables(MutableMapping):
+    def __init__(self, tenant: Tenant):
+        self.tenant = tenant
+        self._variables = self._initialize_variables()
+        self._embedding_model = None
+        self._llm = None
+        self._llm_no_rag = None
+        self._transcription_client = None
+        self._prompt_templates = {}
+        self._embedding_db_model = None
+
+    def _initialize_variables(self):
+        variables = {}
+
+        # We initialize the variables that are available knowing the tenant.
+        # For the others, we will apply 'lazy loading'.
+        variables['k'] = self.tenant.es_k or 5
+        variables['similarity_threshold'] = self.tenant.es_similarity_threshold or 0.7
+        variables['RAG_temperature'] = self.tenant.chat_RAG_temperature or 0.3
+        variables['no_RAG_temperature'] = self.tenant.chat_no_RAG_temperature or 0.5
+        variables['embed_tuning'] = self.tenant.embed_tuning or False
+        variables['rag_tuning'] = self.tenant.rag_tuning or False
+        variables['rag_context'] = self.tenant.rag_context or " "
+
+        # Set HTML Chunking Variables
+        variables['html_tags'] = self.tenant.html_tags
+        variables['html_end_tags'] = self.tenant.html_end_tags
+        variables['html_included_elements'] = self.tenant.html_included_elements
+        variables['html_excluded_elements'] = self.tenant.html_excluded_elements
+        variables['html_excluded_classes'] = self.tenant.html_excluded_classes
+
+        # Set Chunk Size variables
+        variables['min_chunk_size'] = self.tenant.min_chunk_size
+        variables['max_chunk_size'] = self.tenant.max_chunk_size
+
+        # Set model providers
+        variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
+        variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
+        variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
+                                                                         f"{variables['llm_model']}")]
+        current_app.logger.info(f"Loaded prompt templates: \n")
+        current_app.logger.info(f"{variables['templates']}")
+
+        # Set model-specific configurations
+        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
+        variables.update(model_config)
+
+        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]
+
+        if variables['tool_calling_supported']:
+            variables['cited_answer_cls'] = CitedAnswer
+
+        return variables
+
+    @property
+    def embedding_model(self):
+        if self._embedding_model is None:
+            environment = os.getenv('FLASK_ENV', 'development')
+            portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment}
+            if self._variables['embedding_provider'] == 'openai':
+                portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
+                                                provider='openai',
+                                                metadata=portkey_metadata)
+                api_key = os.getenv('OPENAI_API_KEY')
+                model = self._variables['embedding_model']
+                self._embedding_model = OpenAIEmbeddings(api_key=api_key,
+                                                         model=model,
+                                                         base_url=PORTKEY_GATEWAY_URL,
+                                                         default_headers=portkey_headers)
+                self._embedding_db_model = EmbeddingSmallOpenAI \
+                    if model == 'text-embedding-3-small' \
+                    else EmbeddingLargeOpenAI
+            else:
+                raise ValueError(f"Invalid embedding provider: {self._variables['embedding_provider']}")
+        return self._embedding_model
+
+    @property
+    def llm(self):
+        if self._llm is None:
+            self._initialize_llm()
+        return self._llm
+
+    @property
+    def llm_no_rag(self):
+        if self._llm_no_rag is None:
+            self._initialize_llm()
+        return self._llm_no_rag
+
+    def _initialize_llm(self):
+        environment = os.getenv('FLASK_ENV', 'development')
+        portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment}
+        if self._variables['llm_provider'] == 'openai':
+            portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
+                                            metadata=portkey_metadata,
+                                            provider='openai')
+            api_key = os.getenv('OPENAI_API_KEY')
+            self._llm = ChatOpenAI(api_key=api_key,
+                                   model=self._variables['llm_model'],
+                                   temperature=self._variables['RAG_temperature'],
+                                   base_url=PORTKEY_GATEWAY_URL,
+                                   default_headers=portkey_headers)
+            self._llm_no_rag = ChatOpenAI(api_key=api_key,
+                                          model=self._variables['llm_model'],
+                                          temperature=self._variables['no_RAG_temperature'],
+                                          base_url=PORTKEY_GATEWAY_URL,
+                                          default_headers=portkey_headers)
+            self._variables['tool_calling_supported'] = self._variables['llm_model'] in ['gpt-4o', 'gpt-4o-mini']
+        elif self._variables['llm_provider'] == 'anthropic':
+            api_key = os.getenv('ANTHROPIC_API_KEY')
+            llm_model_ext = os.getenv('ANTHROPIC_LLM_VERSIONS', {}).get(self._variables['llm_model'])
+            self._llm = ChatAnthropic(api_key=api_key,
+                                      model=llm_model_ext,
+                                      temperature=self._variables['RAG_temperature'])
+            self._llm_no_rag = ChatAnthropic(api_key=api_key,
+                                             model=llm_model_ext,
+                                             temperature=self._variables['RAG_temperature'])
+            self._variables['tool_calling_supported'] = True
+        else:
+            raise ValueError(f"Invalid chat provider: {self._variables['llm_provider']}")
+
+    @property
+    def transcription_client(self):
+        if self._transcription_client is None:
+            environment = os.getenv('FLASK_ENV', 'development')
+            portkey_metadata = {'tenant_id': str(self.tenant.id), 'environment': environment}
+            portkey_headers = createHeaders(api_key=os.getenv('PORTKEY_API_KEY'),
+                                            metadata=portkey_metadata,
+                                            provider='openai')
+            api_key = os.getenv('OPENAI_API_KEY')
+            self._transcription_client = OpenAI(api_key=api_key,
+                                                base_url=PORTKEY_GATEWAY_URL,
+                                                default_headers=portkey_headers)
+            self._variables['transcription_model'] = 'whisper-1'
+        return self._transcription_client
+
+    @property
+    def embedding_db_model(self):
+        if self._embedding_db_model is None:
+            self._embedding_db_model = self.get_embedding_db_model()
+        return self._embedding_db_model
+
+    def get_embedding_db_model(self):
+        current_app.logger.debug("In get_embedding_db_model")
+        if self._embedding_db_model is None:
+            self._embedding_db_model = EmbeddingSmallOpenAI \
+                if self._variables['embedding_model'] == 'text-embedding-3-small' \
+                else EmbeddingLargeOpenAI
+        current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
+        return self._embedding_db_model
+
+    def get_prompt_template(self, template_name: str) -> str:
+        current_app.logger.info(f"Getting prompt template for {template_name}")
+        if template_name not in self._prompt_templates:
+            self._prompt_templates[template_name] = self._load_prompt_template(template_name)
+        return self._prompt_templates[template_name]
+
+    def _load_prompt_template(self, template_name: str) -> str:
+        # In the future, this method will make an API call to Portkey.
+        # For now, we'll simulate it with a placeholder implementation.
+        # You can replace this with your current prompt loading logic.
+        return self._variables['templates'][template_name]
+
+    def __getitem__(self, key: str) -> Any:
+        current_app.logger.debug(f"ModelVariables: Getting {key}")
+        # Support older template names (suffix = _template)
+        if key.endswith('_template'):
+            key = key[:-len('_template')]
+            current_app.logger.debug(f"ModelVariables: Getting modified {key}")
+        if key == 'embedding_model':
+            return self.embedding_model
+        elif key == 'embedding_db_model':
+            return self.embedding_db_model
+        elif key == 'llm':
+            return self.llm
+        elif key == 'llm_no_rag':
+            return self.llm_no_rag
+        elif key == 'transcription_client':
+            return self.transcription_client
+        elif key in self._variables.get('prompt_templates', []):
+            return self.get_prompt_template(key)
+        return self._variables.get(key)
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self._variables[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self._variables[key]
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._variables)
+
+    def __len__(self):
+        return len(self._variables)
+
+    def get(self, key: str, default: Any = None) -> Any:
+        return self.__getitem__(key) or default
+
+    def update(self, **kwargs) -> None:
+        self._variables.update(kwargs)
+
+    def items(self):
+        return self._variables.items()
+
+    def keys(self):
+        return self._variables.keys()
+
+    def values(self):
+        return self._variables.values()
+
+
+def select_model_variables(tenant):
+    model_variables = ModelVariables(tenant=tenant)
+    return model_variables
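Call sites keep the old dict-style access, so nothing downstream has to change; the first lookup of an expensive key builds and caches the underlying client. A hypothetical call site (tenant is assumed to be a loaded Tenant row):

    from common.utils.model_utils import select_model_variables

    model_variables = select_model_variables(tenant)   # cheap: only tenant-derived settings

    k = model_variables['k']               # plain value from the eager dict
    llm = model_variables['llm']           # first access builds both ChatOpenAI clients
    llm_again = model_variables['llm']     # cached: no second construction
    tpl = model_variables['rag_template']  # legacy '_template' suffix is stripped before lookup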