- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors - Introduction of caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start adaptation of chat client
This commit is contained in:
@@ -12,6 +12,8 @@ from flask_wtf import CSRFProtect
|
||||
from flask_restx import Api
|
||||
from prometheus_flask_exporter import PrometheusMetrics
|
||||
|
||||
from .langchain.templates.template_manager import TemplateManager
|
||||
from .utils.cache.eveai_cache_manager import EveAICacheManager
|
||||
from .utils.simple_encryption import SimpleEncryption
|
||||
from .utils.minio_utils import MinioClient
|
||||
|
||||
@@ -32,3 +34,5 @@ api_rest = Api()
|
||||
simple_encryption = SimpleEncryption()
|
||||
minio_client = MinioClient()
|
||||
metrics = PrometheusMetrics.for_app_factory()
|
||||
template_manager = TemplateManager()
|
||||
cache_manager = EveAICacheManager()
|
||||
|
||||
23
common/langchain/outputs/base.py
Normal file
23
common/langchain/outputs/base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Output Schema Management - common/langchain/outputs/base.py
|
||||
from typing import Dict, Type, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseSpecialistOutput(BaseModel):
|
||||
"""Base class for all specialist outputs"""
|
||||
pass
|
||||
|
||||
|
||||
class OutputRegistry:
|
||||
"""Registry for specialist output schemas"""
|
||||
_schemas: Dict[str, Type[BaseSpecialistOutput]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, specialist_type: str, schema_class: Type[BaseSpecialistOutput]):
|
||||
cls._schemas[specialist_type] = schema_class
|
||||
|
||||
@classmethod
|
||||
def get_schema(cls, specialist_type: str) -> Type[BaseSpecialistOutput]:
|
||||
if specialist_type not in cls._schemas:
|
||||
raise ValueError(f"No output schema registered for {specialist_type}")
|
||||
return cls._schemas[specialist_type]
|
||||
22
common/langchain/outputs/rag.py
Normal file
22
common/langchain/outputs/rag.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# RAG Specialist Output - common/langchain/outputs/rag.py
|
||||
from typing import List
|
||||
from pydantic import Field
|
||||
from .base import BaseSpecialistOutput
|
||||
|
||||
|
||||
class RAGOutput(BaseSpecialistOutput):
|
||||
"""Output schema for RAG specialist"""
|
||||
"""Default docstring - to be replaced with actual prompt"""
|
||||
|
||||
answer: str = Field(
|
||||
...,
|
||||
description="The answer to the user question, based on the given sources",
|
||||
)
|
||||
citations: List[int] = Field(
|
||||
...,
|
||||
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
|
||||
)
|
||||
insufficient_info: bool = Field(
|
||||
False, # Default value is set to False
|
||||
description="A boolean indicating whether given sources were sufficient or not to generate the answer"
|
||||
)
|
||||
@@ -1,145 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import func, and_, or_, desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
|
||||
super().__init__()
|
||||
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
|
||||
self._catalog_id = catalog_id
|
||||
self._model_variables = model_variables
|
||||
self._tenant_info = tenant_info
|
||||
|
||||
@property
|
||||
def catalog_id(self) -> int:
|
||||
return self._catalog_id
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def tenant_info(self) -> Dict[str, Any]:
|
||||
return self._tenant_info
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
|
||||
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
|
||||
db_class = self.model_variables['embedding_db_model']
|
||||
similarity_threshold = self.model_variables['similarity_threshold']
|
||||
k = self.model_variables['k']
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
|
||||
|
||||
# Debug query to show similarity for all valid documents (without chunk text)
|
||||
debug_query = (
|
||||
db.session.query(
|
||||
Document.id.label('document_id'),
|
||||
DocumentVersion.id.label('version_id'),
|
||||
db_class.id.label('embedding_id'),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
|
||||
)
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
|
||||
)
|
||||
.order_by(desc('similarity'))
|
||||
)
|
||||
|
||||
debug_results = debug_query.all()
|
||||
|
||||
current_app.logger.debug("Debug: Similarity for all valid documents:")
|
||||
for row in debug_results:
|
||||
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
|
||||
f"Version ID: {row.version_id}, "
|
||||
f"Embedding ID: {row.embedding_id}, "
|
||||
f"Similarity: {row.similarity}")
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error generating overview: {e}')
|
||||
db.session.rollback()
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
|
||||
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
|
||||
current_app.rag_tuning_logger.debug(f'K: {k}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
# Subquery to find the latest version of each document
|
||||
subquery = (
|
||||
db.session.query(
|
||||
DocumentVersion.doc_id,
|
||||
func.max(DocumentVersion.id).label('latest_version_id')
|
||||
)
|
||||
.group_by(DocumentVersion.doc_id)
|
||||
.subquery()
|
||||
)
|
||||
# Main query to filter embeddings
|
||||
query_obj = (
|
||||
db.session.query(db_class,
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
|
||||
Document.catalog_id == self._catalog_id
|
||||
)
|
||||
.order_by(desc('similarity'))
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
|
||||
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
res = query_obj.all()
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
|
||||
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
|
||||
current_app.rag_tuning_logger.debug(f'{res}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
result = []
|
||||
for doc in res:
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving relevant documents: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
return result
|
||||
|
||||
def _get_query_embedding(self, query: str):
|
||||
embedding_model = self.model_variables['embedding_model']
|
||||
query_embedding = embedding_model.embed_query(query)
|
||||
return query_embedding
|
||||
@@ -1,154 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import func, and_, or_, desc, cast, JSON
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Any, Dict, List, Optional
|
||||
from flask import current_app
|
||||
from contextlib import contextmanager
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Catalog
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIDossierRetriever(BaseRetriever, BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
_active_filters: Optional[Dict[str, Any]] = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self._catalog_id = catalog_id
|
||||
self._model_variables = model_variables
|
||||
self._tenant_info = tenant_info
|
||||
self._active_filters = None
|
||||
|
||||
@contextmanager
|
||||
def filtering(self, metadata_filters: Dict[str, Any]):
|
||||
"""Context manager for temporarily setting metadata filters"""
|
||||
previous_filters = self._active_filters
|
||||
self._active_filters = metadata_filters
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._active_filters = previous_filters
|
||||
|
||||
def _build_metadata_filter_conditions(self, query):
|
||||
"""Build SQL conditions for metadata filtering"""
|
||||
if not self._active_filters:
|
||||
return query
|
||||
|
||||
conditions = []
|
||||
for field, value in self._active_filters.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
# Handle both single values and lists of values
|
||||
if isinstance(value, (list, tuple)):
|
||||
# Multiple values - create OR condition
|
||||
or_conditions = []
|
||||
for val in value:
|
||||
or_conditions.append(
|
||||
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(val)
|
||||
)
|
||||
if or_conditions:
|
||||
conditions.append(or_(*or_conditions))
|
||||
else:
|
||||
# Single value - direct comparison
|
||||
conditions.append(
|
||||
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(value)
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
return query
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
|
||||
if self._active_filters:
|
||||
current_app.logger.debug(f'Using metadata filters: {self._active_filters}')
|
||||
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
db_class = self.model_variables['embedding_db_model']
|
||||
similarity_threshold = self.model_variables['similarity_threshold']
|
||||
k = self.model_variables['k']
|
||||
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
|
||||
# Subquery to find the latest version of each document
|
||||
subquery = (
|
||||
db.session.query(
|
||||
DocumentVersion.doc_id,
|
||||
func.max(DocumentVersion.id).label('latest_version_id')
|
||||
)
|
||||
.group_by(DocumentVersion.doc_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Build base query
|
||||
# Build base query
|
||||
query_obj = (
|
||||
db.session.query(db_class,
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
|
||||
Document.catalog_id == self._catalog_id
|
||||
)
|
||||
)
|
||||
|
||||
# Apply metadata filters
|
||||
query_obj = self._build_metadata_filter_conditions(query_obj)
|
||||
|
||||
# Order and limit results
|
||||
query_obj = query_obj.order_by(desc('similarity')).limit(k)
|
||||
|
||||
# Debug logging for RAG tuning if enabled
|
||||
if self.model_variables['rag_tuning']:
|
||||
self._log_rag_tuning(query_obj, query_embedding)
|
||||
|
||||
res = query_obj.all()
|
||||
|
||||
result = []
|
||||
for doc in res:
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
current_app.logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving relevant documents: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
|
||||
return result
|
||||
|
||||
def _log_rag_tuning(self, query_obj, query_embedding):
|
||||
"""Log debug information for RAG tuning"""
|
||||
current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
|
||||
current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
|
||||
if self._active_filters:
|
||||
current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
|
||||
current_app.rag_tuning_logger.debug(f"{self._active_filters}")
|
||||
|
||||
def _get_query_embedding(self, query: str):
|
||||
"""Get embedding for the query text"""
|
||||
embedding_model = self.model_variables['embedding_model']
|
||||
query_embedding = embedding_model.embed_query(query)
|
||||
return query_embedding
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def tenant_info(self) -> Dict[str, Any]:
|
||||
return self._tenant_info
|
||||
@@ -1,52 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import Field, BaseModel, PrivateAttr
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.interaction import ChatSession, Interaction
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIHistoryRetriever(BaseRetriever, BaseModel):
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_session_id: str = PrivateAttr()
|
||||
|
||||
def __init__(self, model_variables: ModelVariables, session_id: str):
|
||||
super().__init__()
|
||||
self._model_variables = model_variables
|
||||
self._session_id = session_id
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
|
||||
|
||||
try:
|
||||
query_obj = (
|
||||
db.session.query(Interaction)
|
||||
.join(ChatSession, Interaction.chat_session_id == ChatSession.id)
|
||||
.filter(ChatSession.session_id == self.session_id)
|
||||
.order_by(asc(Interaction.id))
|
||||
)
|
||||
|
||||
interactions = query_obj.all()
|
||||
|
||||
result = []
|
||||
for interaction in interactions:
|
||||
result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving history of interactions: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
|
||||
return result
|
||||
@@ -1,40 +0,0 @@
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from typing import Dict, Any
|
||||
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIRetriever(BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_user_metadata: Dict[str, Any] = PrivateAttr()
|
||||
_system_metadata: Dict[str, Any] = PrivateAttr()
|
||||
_configuration: Dict[str, Any] = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tuning: bool = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
|
||||
configuration: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self._catalog_id = catalog_id
|
||||
self._user_metadata = user_metadata
|
||||
self._system_metadata = system_metadata
|
||||
self._configuration = configuration
|
||||
|
||||
@property
|
||||
def catalog_id(self):
|
||||
return self._catalog_id
|
||||
|
||||
@property
|
||||
def user_metadata(self):
|
||||
return self._user_metadata
|
||||
|
||||
@property
|
||||
def system_metadata(self):
|
||||
return self._system_metadata
|
||||
|
||||
@property
|
||||
def configuration(self):
|
||||
return self._configuration
|
||||
|
||||
# Any common methods that should be shared among retrievers can go here.
|
||||
154
common/langchain/templates/template_manager.py
Normal file
154
common/langchain/templates/template_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, Optional, Any
|
||||
from packaging import version
|
||||
from dataclasses import dataclass
|
||||
from flask import current_app, Flask
|
||||
|
||||
from common.utils.os_utils import get_project_root
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Represents a versioned prompt template"""
|
||||
content: str
|
||||
version: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""Manages versioned prompt templates"""
|
||||
|
||||
def __init__(self):
|
||||
self.templates_dir = None
|
||||
self._templates = None
|
||||
self.app = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
# Initialize template manager
|
||||
base_dir = "/app"
|
||||
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
|
||||
app.logger.debug(f'Loading templates from {self.templates_dir}')
|
||||
self.app = app
|
||||
self._templates = self._load_templates()
|
||||
# Log available templates for each supported model
|
||||
for llm in app.config['SUPPORTED_LLMS']:
|
||||
try:
|
||||
available_templates = self.list_templates(llm)
|
||||
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
|
||||
except ValueError:
|
||||
app.logger.warning(f"No templates found for {llm}")
|
||||
|
||||
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
|
||||
"""
|
||||
Load all template versions from the templates directory.
|
||||
Structure: {provider.model -> {template_name -> {version -> template}}}
|
||||
Directory structure:
|
||||
prompts/
|
||||
├── provider/
|
||||
│ └── model/
|
||||
│ └── template_name/
|
||||
│ └── version.yaml
|
||||
"""
|
||||
templates = {}
|
||||
|
||||
# Iterate through providers (anthropic, openai)
|
||||
for provider in os.listdir(self.templates_dir):
|
||||
provider_path = os.path.join(self.templates_dir, provider)
|
||||
if not os.path.isdir(provider_path):
|
||||
continue
|
||||
|
||||
# Iterate through models (claude-3, gpt-4o)
|
||||
for model in os.listdir(provider_path):
|
||||
model_path = os.path.join(provider_path, model)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
provider_model = f"{provider}.{model}"
|
||||
templates[provider_model] = {}
|
||||
|
||||
# Iterate through template types (rag, summary, etc.)
|
||||
for template_name in os.listdir(model_path):
|
||||
template_path = os.path.join(model_path, template_name)
|
||||
if not os.path.isdir(template_path):
|
||||
continue
|
||||
|
||||
template_versions = {}
|
||||
# Load all version files for this template
|
||||
for version_file in os.listdir(template_path):
|
||||
if not version_file.endswith('.yaml'):
|
||||
continue
|
||||
|
||||
version_str = version_file[:-5] # Remove .yaml
|
||||
if not self._is_valid_version(version_str):
|
||||
current_app.logger.warning(
|
||||
f"Invalid version format for {template_name}: {version_str}")
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(os.path.join(template_path, version_file)) as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
# Verify required fields
|
||||
if not template_data.get('content'):
|
||||
raise ValueError("Template content is required")
|
||||
|
||||
template_versions[version_str] = PromptTemplate(
|
||||
content=template_data['content'],
|
||||
version=version_str,
|
||||
metadata=template_data.get('metadata', {})
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error loading template {template_name} version {version_str}: {e}")
|
||||
continue
|
||||
|
||||
if template_versions:
|
||||
templates[provider_model][template_name] = template_versions
|
||||
|
||||
return templates
|
||||
|
||||
def _is_valid_version(self, version_str: str) -> bool:
|
||||
"""Validate semantic versioning string"""
|
||||
try:
|
||||
version.parse(version_str)
|
||||
return True
|
||||
except version.InvalidVersion:
|
||||
return False
|
||||
|
||||
def get_template(self,
|
||||
provider_model: str,
|
||||
template_name: str,
|
||||
template_version: Optional[str] = None) -> PromptTemplate:
|
||||
"""
|
||||
Get a specific template version. If version not specified,
|
||||
returns the latest version.
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
if template_name not in self._templates[provider_model]:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
versions = self._templates[provider_model][template_name]
|
||||
|
||||
if template_version:
|
||||
if template_version not in versions:
|
||||
raise ValueError(f"Template version {template_version} not found")
|
||||
return versions[template_version]
|
||||
|
||||
# Return latest version
|
||||
latest = max(versions.keys(), key=version.parse)
|
||||
return versions[latest]
|
||||
|
||||
def list_templates(self, provider_model: str) -> Dict[str, list]:
|
||||
"""
|
||||
List all available templates and their versions for a provider.model
|
||||
Returns: {template_name: [version1, version2, ...]}
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
return {
|
||||
template_name: sorted(versions.keys(), key=version.parse)
|
||||
for template_name, versions in self._templates[provider_model].items()
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
import time
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
def tracked_transcribe(client, *args, **kwargs):
|
||||
start_time = time.time()
|
||||
|
||||
# Extract the file and model from kwargs if present, otherwise use defaults
|
||||
file = kwargs.get('file')
|
||||
model = kwargs.get('model', 'whisper-1')
|
||||
duration = kwargs.pop('duration', 600)
|
||||
|
||||
result = client.audio.transcriptions.create(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Token usage for transcriptions is actually the duration in seconds we pass, as the whisper model is priced per second transcribed
|
||||
|
||||
metrics = {
|
||||
'total_tokens': duration,
|
||||
'prompt_tokens': 0, # For transcriptions, all tokens are considered "completion"
|
||||
'completion_tokens': duration,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'ASR',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
return result
|
||||
77
common/langchain/tracked_transcription.py
Normal file
77
common/langchain/tracked_transcription.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# common/langchain/tracked_transcription.py
|
||||
from typing import Any, Optional, Dict
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedOpenAITranscription:
|
||||
"""Wrapper for OpenAI transcription with metric tracking"""
|
||||
|
||||
def __init__(self, api_key: str, **kwargs: Any):
|
||||
"""Initialize with OpenAI client settings"""
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.model = kwargs.get('model', 'whisper-1')
|
||||
|
||||
def transcribe(self,
|
||||
file: Any,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
duration: Optional[int] = None) -> str:
|
||||
"""
|
||||
Transcribe audio with metrics tracking
|
||||
|
||||
Args:
|
||||
file: Audio file to transcribe
|
||||
model: Model to use (defaults to whisper-1)
|
||||
language: Optional language of the audio
|
||||
prompt: Optional prompt to guide transcription
|
||||
response_format: Response format (json, text, etc)
|
||||
temperature: Sampling temperature
|
||||
duration: Duration of audio in seconds for metrics
|
||||
|
||||
Returns:
|
||||
Transcription text
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Create transcription options
|
||||
options = {
|
||||
"file": file,
|
||||
"model": model or self.model,
|
||||
}
|
||||
if language:
|
||||
options["language"] = language
|
||||
if prompt:
|
||||
options["prompt"] = prompt
|
||||
if response_format:
|
||||
options["response_format"] = response_format
|
||||
if temperature:
|
||||
options["temperature"] = temperature
|
||||
|
||||
response = self.client.audio.transcriptions.create(**options)
|
||||
|
||||
# Calculate metrics
|
||||
end_time = time.time()
|
||||
|
||||
# Token usage for transcriptions is based on audio duration
|
||||
metrics = {
|
||||
'total_tokens': duration or 600, # Default to 10 minutes if duration not provided
|
||||
'prompt_tokens': 0, # For transcriptions, all tokens are completion
|
||||
'completion_tokens': duration or 600,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'ASR',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
# Return text from response
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Transcription failed: {str(e)}")
|
||||
@@ -12,22 +12,31 @@ class Catalog(db.Model):
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
|
||||
|
||||
# Embedding variables
|
||||
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
|
||||
html_end_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'li'])
|
||||
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
html_excluded_classes = db.Column(ARRAY(sa.String(200)), nullable=True)
|
||||
|
||||
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
|
||||
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
|
||||
|
||||
# Chat variables ==> Move to Specialist?
|
||||
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
||||
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
||||
# Meta Data
|
||||
user_metadata = db.Column(JSONB, nullable=True)
|
||||
system_metadata = db.Column(JSONB, nullable=True)
|
||||
configuration = db.Column(JSONB, nullable=True)
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Processor(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
name = db.Column(db.String(50), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
|
||||
type = db.Column(db.String(50), nullable=False)
|
||||
sub_file_type = db.Column(db.String(50), nullable=True)
|
||||
|
||||
# Tuning enablers
|
||||
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
|
||||
tuning = db.Column(db.Boolean, nullable=True, default=False)
|
||||
|
||||
# Meta Data
|
||||
user_metadata = db.Column(JSONB, nullable=True)
|
||||
@@ -90,6 +99,7 @@ class DocumentVersion(db.Model):
|
||||
bucket_name = db.Column(db.String(255), nullable=True)
|
||||
object_name = db.Column(db.String(200), nullable=True)
|
||||
file_type = db.Column(db.String(20), nullable=True)
|
||||
sub_file_type = db.Column(db.String(50), nullable=True)
|
||||
file_size = db.Column(db.Float, nullable=True)
|
||||
language = db.Column(db.String(2), nullable=False)
|
||||
user_context = db.Column(db.Text, nullable=True)
|
||||
|
||||
@@ -20,34 +20,6 @@ class ChatSession(db.Model):
|
||||
return f"<ChatSession {self.id} by {self.user_id}>"
|
||||
|
||||
|
||||
class Interaction(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
|
||||
question = db.Column(db.Text, nullable=False)
|
||||
detailed_question = db.Column(db.Text, nullable=True)
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
algorithm_used = db.Column(db.String(20), nullable=True)
|
||||
language = db.Column(db.String(2), nullable=False)
|
||||
timezone = db.Column(db.String(30), nullable=True)
|
||||
appreciation = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Timing information
|
||||
question_at = db.Column(db.DateTime, nullable=False)
|
||||
detailed_question_at = db.Column(db.DateTime, nullable=True)
|
||||
answer_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relations
|
||||
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Interaction {self.id}>"
|
||||
|
||||
|
||||
class InteractionEmbedding(db.Model):
|
||||
interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
|
||||
embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)
|
||||
|
||||
|
||||
class Specialist(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
name = db.Column(db.String(50), nullable=False)
|
||||
@@ -68,7 +40,34 @@ class Specialist(db.Model):
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Interaction(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
|
||||
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=True)
|
||||
specialist_arguments = db.Column(JSONB, nullable=True)
|
||||
specialist_results = db.Column(JSONB, nullable=True)
|
||||
timezone = db.Column(db.String(30), nullable=True)
|
||||
appreciation = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Timing information
|
||||
question_at = db.Column(db.DateTime, nullable=False)
|
||||
detailed_question_at = db.Column(db.DateTime, nullable=True)
|
||||
answer_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relations
|
||||
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Interaction {self.id}>"
|
||||
|
||||
|
||||
class InteractionEmbedding(db.Model):
|
||||
interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
|
||||
embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)
|
||||
|
||||
|
||||
class SpecialistRetriever(db.Model):
|
||||
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), primary_key=True)
|
||||
retriever_id = db.Column(db.Integer, db.ForeignKey(Retriever.id, ondelete='CASCADE'), primary_key=True)
|
||||
|
||||
retriever = db.relationship("Retriever", backref="specialist_retrievers")
|
||||
|
||||
@@ -4,7 +4,6 @@ from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from portkey_ai import Portkey, Config
|
||||
import logging
|
||||
|
||||
from .business_event_context import BusinessEventContext
|
||||
|
||||
89
common/utils/cache/base.py
vendored
Normal file
89
common/utils/cache/base.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
# common/utils/cache/base.py
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type
|
||||
from dataclasses import dataclass
|
||||
from flask import Flask
|
||||
from dogpile.cache import CacheRegion
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheKey:
|
||||
"""Represents a cache key with multiple components"""
|
||||
components: Dict[str, Any]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ":".join(f"{k}={v}" for k, v in sorted(self.components.items()))
|
||||
|
||||
|
||||
class CacheInvalidationManager:
|
||||
"""Manages cache invalidation subscriptions"""
|
||||
|
||||
def __init__(self):
|
||||
self._subscribers = {}
|
||||
|
||||
def subscribe(self, model: str, handler: 'CacheHandler', key_fields: List[str]):
|
||||
if model not in self._subscribers:
|
||||
self._subscribers[model] = []
|
||||
self._subscribers[model].append((handler, key_fields))
|
||||
|
||||
def notify_change(self, model: str, **identifiers):
|
||||
if model in self._subscribers:
|
||||
for handler, key_fields in self._subscribers[model]:
|
||||
if all(field in identifiers for field in key_fields):
|
||||
handler.invalidate_by_model(model, **identifiers)
|
||||
|
||||
|
||||
class CacheHandler(Generic[T]):
|
||||
"""Base cache handler implementation"""
|
||||
|
||||
def __init__(self, region: CacheRegion, prefix: str):
|
||||
self.region = region
|
||||
self.prefix = prefix
|
||||
self._key_components = []
|
||||
|
||||
def configure_keys(self, *components: str):
|
||||
self._key_components = components
|
||||
return self
|
||||
|
||||
def subscribe_to_model(self, model: str, key_fields: List[str]):
|
||||
invalidation_manager.subscribe(model, self, key_fields)
|
||||
return self
|
||||
|
||||
def generate_key(self, **identifiers) -> str:
|
||||
missing = set(self._key_components) - set(identifiers.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Missing key components: {missing}")
|
||||
|
||||
key = CacheKey({k: identifiers[k] for k in self._key_components})
|
||||
return f"{self.prefix}:{str(key)}"
|
||||
|
||||
def get(self, creator_func, **identifiers) -> T:
|
||||
cache_key = self.generate_key(**identifiers)
|
||||
|
||||
def creator():
|
||||
instance = creator_func(**identifiers)
|
||||
return self.to_cache_data(instance)
|
||||
|
||||
cached_data = self.region.get_or_create(
|
||||
cache_key,
|
||||
creator,
|
||||
should_cache_fn=self.should_cache
|
||||
)
|
||||
|
||||
return self.from_cache_data(cached_data, **identifiers)
|
||||
|
||||
def invalidate(self, **identifiers):
|
||||
cache_key = self.generate_key(**identifiers)
|
||||
self.region.delete(cache_key)
|
||||
|
||||
def invalidate_by_model(self, model: str, **identifiers):
|
||||
try:
|
||||
self.invalidate(**identifiers)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# Create global invalidation manager
|
||||
invalidation_manager = CacheInvalidationManager()
|
||||
32
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
32
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Type
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from common.utils.cache.base import CacheHandler
|
||||
|
||||
|
||||
class EveAICacheManager:
|
||||
"""Cache manager with registration capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_region = None
|
||||
self.eveai_chat_workers_region = None
|
||||
self.eveai_workers_region = None
|
||||
self._handlers = {}
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
"""Initialize cache regions"""
|
||||
from common.utils.cache.regions import create_cache_regions
|
||||
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
|
||||
|
||||
# Initialize all registered handlers with their regions
|
||||
for handler_class, region_name in self._handlers.items():
|
||||
region = getattr(self, f"{region_name}_region")
|
||||
handler_instance = handler_class(region)
|
||||
setattr(self, handler_class.handler_name, handler_instance)
|
||||
|
||||
def register_handler(self, handler_class: Type[CacheHandler], region: str):
|
||||
"""Register a cache handler class with its region"""
|
||||
if not hasattr(handler_class, 'handler_name'):
|
||||
raise ValueError("Cache handler must define handler_name class attribute")
|
||||
self._handlers[handler_class] = region
|
||||
61
common/utils/cache/regions.py
vendored
Normal file
61
common/utils/cache/regions.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
# common/utils/cache/regions.py
|
||||
|
||||
from dogpile.cache import make_region
|
||||
from flask import current_app
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
|
||||
|
||||
def get_redis_config(app):
|
||||
"""
|
||||
Create Redis configuration dict based on app config
|
||||
Handles both authenticated and non-authenticated setups
|
||||
"""
|
||||
# Parse the REDIS_BASE_URI to get all components
|
||||
redis_uri = urlparse(app.config['REDIS_BASE_URI'])
|
||||
|
||||
config = {
|
||||
'host': redis_uri.hostname,
|
||||
'port': int(redis_uri.port or 6379),
|
||||
'db': 4, # Keep this for later use
|
||||
'redis_expiration_time': 3600,
|
||||
'distributed_lock': True
|
||||
}
|
||||
|
||||
# Add authentication if provided
|
||||
if redis_uri.username and redis_uri.password:
|
||||
config.update({
|
||||
'username': redis_uri.username,
|
||||
'password': redis_uri.password
|
||||
})
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_cache_regions(app):
|
||||
"""Initialize all cache regions with app config"""
|
||||
redis_config = get_redis_config(app)
|
||||
|
||||
# Region for model-related caching (ModelVariables etc)
|
||||
model_region = make_region(name='model').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config,
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
|
||||
eveai_chat_workers_region = make_region(name='chat_workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
# Region for eveai_workers components (Processors, ...)
|
||||
eveai_workers_region = make_region(name='workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # Same config for now
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
return model_region, eveai_chat_workers_region, eveai_workers_region
|
||||
|
||||
@@ -8,8 +8,6 @@ celery_app = Celery()
|
||||
|
||||
def init_celery(celery, app, is_beat=False):
|
||||
celery_app.main = app.name
|
||||
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
|
||||
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
|
||||
|
||||
celery_config = {
|
||||
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
|
||||
|
||||
613
common/utils/config_field_types.py
Normal file
613
common/utils/config_field_types.py
Normal file
@@ -0,0 +1,613 @@
|
||||
from typing import Optional, List, Union, Dict, Any, Pattern
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from typing_extensions import Annotated
|
||||
import re
|
||||
from datetime import datetime
|
||||
import json
|
||||
from textwrap import dedent
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TaggingField(BaseModel):
|
||||
"""Represents a single tagging field configuration"""
|
||||
type: str
|
||||
required: bool = False
|
||||
description: Optional[str] = None
|
||||
allowed_values: Optional[List[Any]] = None # for enum type
|
||||
min_value: Optional[Union[int, float]] = None # for numeric types
|
||||
max_value: Optional[Union[int, float]] = None # for numeric types
|
||||
|
||||
@field_validator('type', mode='before')
|
||||
@classmethod
|
||||
def validate_type(cls, v: str) -> str:
|
||||
valid_types = ['string', 'integer', 'float', 'date', 'enum']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'type must be one of {valid_types}')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_field_constraints(self) -> 'TaggingField':
|
||||
# Validate enum constraints
|
||||
if self.type == 'enum':
|
||||
if not self.allowed_values:
|
||||
raise ValueError('allowed_values must be provided for enum type')
|
||||
elif self.allowed_values is not None:
|
||||
raise ValueError('allowed_values only valid for enum type')
|
||||
|
||||
# Validate numeric constraints
|
||||
if self.type not in ('integer', 'float'):
|
||||
if self.min_value is not None or self.max_value is not None:
|
||||
raise ValueError('min_value/max_value only valid for numeric types')
|
||||
else:
|
||||
if self.min_value is not None and self.max_value is not None and self.min_value >= self.max_value:
|
||||
raise ValueError('min_value must be less than max_value')
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TaggingFields(BaseModel):
|
||||
"""Represents a collection of tagging fields, mapped by their names"""
|
||||
fields: Dict[str, TaggingField]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'TaggingFields':
|
||||
return cls(fields={
|
||||
field_name: TaggingField(**field_config)
|
||||
for field_name, field_config in data.items()
|
||||
})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
field_name: field.model_dump(exclude_none=True)
|
||||
for field_name, field in self.fields.items()
|
||||
}
|
||||
|
||||
|
||||
class ArgumentConstraint(BaseModel):
|
||||
"""Base class for all argument constraints"""
|
||||
description: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class NumericConstraint(ArgumentConstraint):
|
||||
"""Constraints for numeric values (int/float)"""
|
||||
min_value: Optional[float] = None
|
||||
max_value: Optional[float] = None
|
||||
include_min: bool = True # True for >= min_value, False for > min_value
|
||||
include_max: bool = True # True for <= max_value, False for < max_value
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_ranges(self) -> 'NumericConstraint':
|
||||
if self.min_value is not None and self.max_value is not None:
|
||||
if self.min_value > self.max_value:
|
||||
raise ValueError("min_value must be less than or equal to max_value")
|
||||
return self
|
||||
|
||||
def validate(self, value: Union[int, float]) -> bool:
|
||||
if self.min_value is not None:
|
||||
if self.include_min and value < self.min_value:
|
||||
return False
|
||||
if not self.include_min and value <= self.min_value:
|
||||
return False
|
||||
if self.max_value is not None:
|
||||
if self.include_max and value > self.max_value:
|
||||
return False
|
||||
if not self.include_max and value >= self.max_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class StringConstraint(ArgumentConstraint):
|
||||
"""Constraints for string values"""
|
||||
min_length: Optional[int] = None
|
||||
max_length: Optional[int] = None
|
||||
patterns: Optional[List[str]] = None # List of regex patterns to match
|
||||
pattern_match_all: bool = False # If True, string must match all patterns
|
||||
forbidden_patterns: Optional[List[str]] = None # List of regex patterns that must not match
|
||||
allow_empty: bool = False
|
||||
|
||||
@field_validator('patterns', 'forbidden_patterns')
|
||||
@classmethod
|
||||
def validate_patterns(cls, v: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if v is not None:
|
||||
# Validate each pattern compiles
|
||||
for pattern in v:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
|
||||
return v
|
||||
|
||||
def validate(self, value: str) -> bool:
|
||||
if not self.allow_empty and not value:
|
||||
return False
|
||||
|
||||
if self.min_length is not None and len(value) < self.min_length:
|
||||
return False
|
||||
|
||||
if self.max_length is not None and len(value) > self.max_length:
|
||||
return False
|
||||
|
||||
if self.patterns:
|
||||
matches = [bool(re.search(pattern, value)) for pattern in self.patterns]
|
||||
if self.pattern_match_all and not all(matches):
|
||||
return False
|
||||
if not self.pattern_match_all and not any(matches):
|
||||
return False
|
||||
|
||||
if self.forbidden_patterns:
|
||||
for pattern in self.forbidden_patterns:
|
||||
if re.search(pattern, value):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class DateConstraint(ArgumentConstraint):
|
||||
"""Constraints for date values"""
|
||||
min_date: Optional[datetime] = None
|
||||
max_date: Optional[datetime] = None
|
||||
include_min: bool = True
|
||||
include_max: bool = True
|
||||
allowed_formats: Optional[List[str]] = None # List of allowed date formats
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_ranges(self) -> 'DateConstraint':
|
||||
if self.min_date and self.max_date and self.min_date > self.max_date:
|
||||
raise ValueError("min_date must be less than or equal to max_date")
|
||||
return self
|
||||
|
||||
def validate(self, value: datetime) -> bool:
|
||||
if self.min_date is not None:
|
||||
if self.include_min and value < self.min_date:
|
||||
return False
|
||||
if not self.include_min and value <= self.min_date:
|
||||
return False
|
||||
|
||||
if self.max_date is not None:
|
||||
if self.include_max and value > self.max_date:
|
||||
return False
|
||||
if not self.include_max and value >= self.max_date:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class EnumConstraint(ArgumentConstraint):
|
||||
"""Constraints for enum values"""
|
||||
allowed_values: List[Any]
|
||||
case_sensitive: bool = True # For string enums
|
||||
allow_multiple: bool = False # If True, value can be a list of allowed values
|
||||
min_selections: Optional[int] = None # When allow_multiple is True
|
||||
max_selections: Optional[int] = None # When allow_multiple is True
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_selections(self) -> 'EnumConstraint':
|
||||
if self.allow_multiple:
|
||||
if self.min_selections is not None and self.max_selections is not None:
|
||||
if self.min_selections > self.max_selections:
|
||||
raise ValueError("min_selections must be less than or equal to max_selections")
|
||||
if self.max_selections > len(self.allowed_values):
|
||||
raise ValueError("max_selections cannot be greater than number of allowed values")
|
||||
return self
|
||||
|
||||
def validate(self, value: Union[Any, List[Any]]) -> bool:
|
||||
if self.allow_multiple:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if self.min_selections is not None and len(value) < self.min_selections:
|
||||
return False
|
||||
|
||||
if self.max_selections is not None and len(value) > self.max_selections:
|
||||
return False
|
||||
|
||||
for v in value:
|
||||
if not self._validate_single_value(v):
|
||||
return False
|
||||
else:
|
||||
return self._validate_single_value(value)
|
||||
|
||||
return True
|
||||
|
||||
def _validate_single_value(self, value: Any) -> bool:
|
||||
if isinstance(value, str) and not self.case_sensitive:
|
||||
return any(str(value).lower() == str(v).lower() for v in self.allowed_values)
|
||||
return value in self.allowed_values
|
||||
|
||||
|
||||
class ArgumentDefinition(BaseModel):
|
||||
"""Defines an argument with its type and constraints"""
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
required: bool = False
|
||||
default: Optional[Any] = None
|
||||
constraints: Optional[Union[NumericConstraint, StringConstraint, DateConstraint, EnumConstraint]] = None
|
||||
|
||||
@field_validator('type')
|
||||
@classmethod
|
||||
def validate_type(cls, v: str) -> str:
|
||||
valid_types = ['string', 'integer', 'float', 'date', 'enum']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'type must be one of {valid_types}')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_constraints(self) -> 'ArgumentDefinition':
|
||||
if self.constraints:
|
||||
expected_constraint_types = {
|
||||
'string': StringConstraint,
|
||||
'integer': NumericConstraint,
|
||||
'float': NumericConstraint,
|
||||
'date': DateConstraint,
|
||||
'enum': EnumConstraint
|
||||
}
|
||||
|
||||
expected_type = expected_constraint_types.get(self.type)
|
||||
if not isinstance(self.constraints, expected_type):
|
||||
raise ValueError(f'Constraints for type {self.type} must be of type {expected_type.__name__}')
|
||||
|
||||
if self.default is not None:
|
||||
if not self.constraints.validate(self.default):
|
||||
raise ValueError(f'Default value does not satisfy constraints for {self.name}')
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class ArgumentDefinitions(BaseModel):
|
||||
"""Collection of argument definitions"""
|
||||
arguments: Dict[str, ArgumentDefinition]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'ArgumentDefinitions':
|
||||
return cls(arguments={
|
||||
arg_name: ArgumentDefinition(**arg_config)
|
||||
for arg_name, arg_config in data.items()
|
||||
})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
arg_name: arg.model_dump(exclude_none=True)
|
||||
for arg_name, arg in self.arguments.items()
|
||||
}
|
||||
|
||||
def validate_argument_values(self, values: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""
|
||||
Validate a set of argument values against their definitions
|
||||
Returns a dictionary of error messages for invalid arguments
|
||||
"""
|
||||
errors = {}
|
||||
|
||||
# Check for required arguments
|
||||
for name, arg_def in self.arguments.items():
|
||||
if arg_def.required and name not in values:
|
||||
errors[name] = "Required argument missing"
|
||||
continue
|
||||
|
||||
if name in values:
|
||||
value = values[name]
|
||||
|
||||
# Validate type
|
||||
try:
|
||||
if arg_def.type == 'integer':
|
||||
value = int(value)
|
||||
elif arg_def.type == 'float':
|
||||
value = float(value)
|
||||
elif arg_def.type == 'date' and isinstance(value, str):
|
||||
if arg_def.constraints and arg_def.constraints.allowed_formats:
|
||||
for fmt in arg_def.constraints.allowed_formats:
|
||||
try:
|
||||
value = datetime.strptime(value, fmt)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
errors[
|
||||
name] = f"Invalid date format. Allowed formats: {arg_def.constraints.allowed_formats}"
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
errors[name] = f"Invalid type. Expected {arg_def.type}"
|
||||
continue
|
||||
|
||||
# Validate constraints
|
||||
if arg_def.constraints and not arg_def.constraints.validate(value):
|
||||
errors[name] = arg_def.constraints.error_message or "Value does not satisfy constraints"
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentationFormat:
|
||||
"""Constants for documentation formats"""
|
||||
MARKDOWN = "markdown"
|
||||
JSON = "json"
|
||||
YAML = "yaml"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentationVersion:
|
||||
"""Constants for documentation versions"""
|
||||
BASIC = "basic" # Original documentation without retriever info
|
||||
EXTENDED = "extended" # Including retriever documentation
|
||||
|
||||
|
||||
def _generate_argument_constraints(field_config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Generate possible argument constraints based on field type"""
|
||||
constraints = []
|
||||
|
||||
base_constraint = {
|
||||
"description": f"Constraint for {field_config.get('description', 'field')}",
|
||||
"error_message": "Optional custom error message"
|
||||
}
|
||||
|
||||
if field_config["type"] == "integer" or field_config["type"] == "float":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "NumericConstraint",
|
||||
"possible_constraints": {
|
||||
"min_value": "number",
|
||||
"max_value": "number",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_value": field_config.get("min_value", 0),
|
||||
"max_value": field_config.get("max_value", 100),
|
||||
"include_min": True,
|
||||
"include_max": True
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "string":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "StringConstraint",
|
||||
"possible_constraints": {
|
||||
"min_length": "integer",
|
||||
"max_length": "integer",
|
||||
"patterns": "list[str]",
|
||||
"pattern_match_all": "boolean",
|
||||
"forbidden_patterns": "list[str]",
|
||||
"allow_empty": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_length": 1,
|
||||
"max_length": 100,
|
||||
"patterns": ["^[A-Za-z0-9]+$"],
|
||||
"pattern_match_all": False,
|
||||
"forbidden_patterns": ["^test_", "_temp$"],
|
||||
"allow_empty": False
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "enum":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "EnumConstraint",
|
||||
"possible_constraints": {
|
||||
"allowed_values": f"list[{field_config.get('allowed_values', ['value1', 'value2'])}]",
|
||||
"case_sensitive": "boolean",
|
||||
"allow_multiple": "boolean",
|
||||
"min_selections": "integer",
|
||||
"max_selections": "integer"
|
||||
},
|
||||
"example": {
|
||||
"allowed_values": field_config.get("allowed_values", ["value1", "value2"]),
|
||||
"case_sensitive": True,
|
||||
"allow_multiple": True,
|
||||
"min_selections": 1,
|
||||
"max_selections": 2
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "date":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "DateConstraint",
|
||||
"possible_constraints": {
|
||||
"min_date": "datetime",
|
||||
"max_date": "datetime",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean",
|
||||
"allowed_formats": "list[str]"
|
||||
},
|
||||
"example": {
|
||||
"min_date": "2024-01-01T00:00:00",
|
||||
"max_date": "2024-12-31T23:59:59",
|
||||
"include_min": True,
|
||||
"include_max": True,
|
||||
"allowed_formats": ["%Y-%m-%d", "%Y/%m/%d"]
|
||||
}
|
||||
})
|
||||
|
||||
return constraints
|
||||
|
||||
|
||||
def generate_field_documentation(
|
||||
tagging_fields: Dict[str, Any],
|
||||
format: str = "markdown",
|
||||
version: str = "basic"
|
||||
) -> str:
|
||||
"""
|
||||
Generate documentation for tagging fields configuration.
|
||||
|
||||
Args:
|
||||
tagging_fields: Dictionary containing tagging fields configuration
|
||||
format: Output format ("markdown", "json", or "yaml")
|
||||
version: Documentation version ("basic" or "extended")
|
||||
|
||||
Returns:
|
||||
str: Formatted documentation
|
||||
"""
|
||||
if version not in [DocumentationVersion.BASIC, DocumentationVersion.EXTENDED]:
|
||||
raise ValueError(f"Unsupported documentation version: {version}")
|
||||
|
||||
# Normalize fields configuration
|
||||
normalized_fields = {}
|
||||
|
||||
for field_name, field_config in tagging_fields.items():
|
||||
field_doc = {
|
||||
"name": field_name,
|
||||
"type": field_config["type"],
|
||||
"required": field_config.get("required", False),
|
||||
"description": field_config.get("description", "No description provided"),
|
||||
"constraints": []
|
||||
}
|
||||
|
||||
# Only include possible arguments in extended version
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
field_doc["possible_arguments"] = _generate_argument_constraints(field_config)
|
||||
|
||||
# Add type-specific constraints
|
||||
if field_config["type"] == "integer" or field_config["type"] == "float":
|
||||
if "min_value" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum value: {field_config['min_value']}")
|
||||
if "max_value" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum value: {field_config['max_value']}")
|
||||
|
||||
elif field_config["type"] == "string":
|
||||
if "min_length" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum length: {field_config['min_length']}")
|
||||
if "max_length" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum length: {field_config['max_length']}")
|
||||
if "patterns" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Must match patterns: {', '.join(field_config['patterns'])}")
|
||||
|
||||
elif field_config["type"] == "enum":
|
||||
if "allowed_values" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Allowed values: {', '.join(str(v) for v in field_config['allowed_values'])}")
|
||||
|
||||
elif field_config["type"] == "date":
|
||||
if "min_date" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum date: {field_config['min_date']}")
|
||||
if "max_date" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum date: {field_config['max_date']}")
|
||||
if "allowed_formats" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Allowed formats: {', '.join(field_config['allowed_formats'])}")
|
||||
|
||||
normalized_fields[field_name] = field_doc
|
||||
|
||||
# Generate documentation in requested format
|
||||
if format == DocumentationFormat.MARKDOWN:
|
||||
return _generate_markdown_docs(normalized_fields, version)
|
||||
elif format == DocumentationFormat.JSON:
|
||||
return _generate_json_docs(normalized_fields, version)
|
||||
elif format == DocumentationFormat.YAML:
|
||||
return _generate_yaml_docs(normalized_fields, version)
|
||||
else:
|
||||
raise ValueError(f"Unsupported documentation format: {format}")
|
||||
|
||||
|
||||
def _generate_markdown_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate markdown documentation"""
|
||||
docs = ["# Tagging Fields Documentation\n"]
|
||||
|
||||
# Add overview table
|
||||
docs.append("## Fields Overview\n")
|
||||
docs.append("| Field Name | Type | Required | Description |")
|
||||
docs.append("|------------|------|----------|-------------|")
|
||||
|
||||
for field_name, field in fields.items():
|
||||
docs.append(
|
||||
f"| {field_name} | {field['type']} | "
|
||||
f"{'Yes' if field['required'] else 'No'} | {field['description']} |"
|
||||
)
|
||||
|
||||
# Add detailed field specifications
|
||||
docs.append("\n## Detailed Field Specifications\n")
|
||||
|
||||
for field_name, field in fields.items():
|
||||
docs.append(f"### {field_name}\n")
|
||||
docs.append(f"**Type:** {field['type']}")
|
||||
docs.append(f"**Required:** {'Yes' if field['required'] else 'No'}")
|
||||
docs.append(f"**Description:** {field['description']}\n")
|
||||
|
||||
if field["constraints"]:
|
||||
docs.append("**Field Constraints:**")
|
||||
for constraint in field["constraints"]:
|
||||
docs.append(f"- {constraint}")
|
||||
docs.append("")
|
||||
|
||||
# Add retriever argument documentation only in extended version
|
||||
if version == DocumentationVersion.EXTENDED and "possible_arguments" in field:
|
||||
docs.append("**Possible Retriever Arguments:**")
|
||||
for arg_constraint in field["possible_arguments"]:
|
||||
docs.append(f"\n*{arg_constraint['type']}*")
|
||||
docs.append(f"Description: {arg_constraint['description']}")
|
||||
docs.append("\nPossible constraints:")
|
||||
for const_name, const_type in arg_constraint["possible_constraints"].items():
|
||||
docs.append(f"- `{const_name}`: {const_type}")
|
||||
|
||||
docs.append("\nExample:")
|
||||
docs.append("```python")
|
||||
docs.append(json.dumps(arg_constraint["example"], indent=2))
|
||||
docs.append("```\n")
|
||||
|
||||
# Add example retriever configuration only in extended version
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
docs.append("\n## Example Retriever Configuration\n")
|
||||
docs.append("```python")
|
||||
example_config = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
docs.append(json.dumps(example_config, indent=2))
|
||||
docs.append("```")
|
||||
|
||||
return "\n".join(docs)
|
||||
|
||||
|
||||
def _generate_json_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate JSON documentation"""
|
||||
doc = {
|
||||
"tagging_fields_documentation": {
|
||||
"version": version,
|
||||
"fields": fields
|
||||
}
|
||||
}
|
||||
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
doc["tagging_fields_documentation"]["example_retriever_config"] = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
|
||||
return json.dumps(doc, indent=2)
|
||||
|
||||
|
||||
def _generate_yaml_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate YAML documentation"""
|
||||
doc = {
|
||||
"tagging_fields_documentation": {
|
||||
"version": version,
|
||||
"fields": fields
|
||||
}
|
||||
}
|
||||
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
doc["tagging_fields_documentation"]["example_retriever_config"] = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
|
||||
return yaml.dump(doc, sort_keys=False, default_flow_style=False)
|
||||
@@ -5,10 +5,8 @@ from common.models.user import Tenant, TenantDomain
|
||||
def get_allowed_origins(tenant_id):
|
||||
session_key = f"allowed_origins_{tenant_id}"
|
||||
if session_key in session:
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
|
||||
return session[session_key]
|
||||
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
|
||||
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
|
||||
allowed_origins = [domain.domain for domain in tenant_domains]
|
||||
|
||||
@@ -18,14 +16,8 @@ def get_allowed_origins(tenant_id):
|
||||
|
||||
|
||||
def cors_after_request(response, prefix):
|
||||
current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
|
||||
current_app.logger.debug(f'request.headers: {request.headers}')
|
||||
current_app.logger.debug(f'request.args: {request.args}')
|
||||
current_app.logger.debug(f'request is json?: {request.is_json}')
|
||||
|
||||
# Exclude health checks from checks
|
||||
if request.path.startswith('/healthz') or request.path.startswith('/_healthz'):
|
||||
current_app.logger.debug('Skipping CORS headers for health checks')
|
||||
response.headers.add('Access-Control-Allow-Origin', '*')
|
||||
response.headers.add('Access-Control-Allow-Headers', '*')
|
||||
response.headers.add('Access-Control-Allow-Methods', '*')
|
||||
@@ -36,7 +28,6 @@ def cors_after_request(response, prefix):
|
||||
|
||||
# Try to get tenant_id from JSON payload
|
||||
json_data = request.get_json(silent=True)
|
||||
current_app.logger.debug(f'request.get_json(silent=True): {json_data}')
|
||||
|
||||
if json_data and 'tenant_id' in json_data:
|
||||
tenant_id = json_data['tenant_id']
|
||||
@@ -44,23 +35,17 @@ def cors_after_request(response, prefix):
|
||||
# Fallback to get tenant_id from query parameters or headers if JSON is not available
|
||||
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
|
||||
|
||||
current_app.logger.debug(f'Identified tenant_id: {tenant_id}')
|
||||
|
||||
if tenant_id:
|
||||
allowed_origins = get_allowed_origins(tenant_id)
|
||||
current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
|
||||
else:
|
||||
current_app.logger.warning('tenant_id not found in request')
|
||||
|
||||
origin = request.headers.get('Origin')
|
||||
current_app.logger.debug(f'Origin: {origin}')
|
||||
|
||||
if origin in allowed_origins:
|
||||
response.headers.add('Access-Control-Allow-Origin', origin)
|
||||
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
|
||||
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
|
||||
response.headers.add('Access-Control-Allow-Credentials', 'true')
|
||||
current_app.logger.debug(f'CORS headers set for origin: {origin}')
|
||||
else:
|
||||
current_app.logger.warning(f'Origin {origin} not allowed')
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def log_request_middleware(app):
|
||||
|
||||
@app.before_request
|
||||
def log_session_state_before():
|
||||
app.logger.debug(f'Session state before request: {session.items()}')
|
||||
pass
|
||||
|
||||
# @app.after_request
|
||||
# def log_response_info(response):
|
||||
@@ -58,5 +58,4 @@ def log_request_middleware(app):
|
||||
|
||||
@app.after_request
|
||||
def log_session_state_after(response):
|
||||
app.logger.debug(f'Session state after request: {session.items()}')
|
||||
return response
|
||||
|
||||
@@ -24,6 +24,7 @@ def create_document_stack(api_input, file, filename, extension, tenant_id):
|
||||
# Create the DocumentVersion
|
||||
new_doc_vers = create_version_for_document(new_doc, tenant_id,
|
||||
api_input.get('url', ''),
|
||||
api_input.get('sub_file_type', ''),
|
||||
api_input.get('language', 'en'),
|
||||
api_input.get('user_context', ''),
|
||||
api_input.get('user_metadata'),
|
||||
@@ -64,7 +65,7 @@ def create_document(form, filename, catalog_id):
|
||||
return new_doc
|
||||
|
||||
|
||||
def create_version_for_document(document, tenant_id, url, language, user_context, user_metadata, catalog_properties):
|
||||
def create_version_for_document(document, tenant_id, url, sub_file_type, language, user_context, user_metadata, catalog_properties):
|
||||
new_doc_vers = DocumentVersion()
|
||||
if url != '':
|
||||
new_doc_vers.url = url
|
||||
@@ -83,6 +84,9 @@ def create_version_for_document(document, tenant_id, url, language, user_context
|
||||
if catalog_properties != '' and catalog_properties is not None:
|
||||
new_doc_vers.catalog_properties = catalog_properties
|
||||
|
||||
if sub_file_type != '':
|
||||
new_doc_vers.sub_file_type = sub_file_type
|
||||
|
||||
new_doc_vers.document = document
|
||||
|
||||
set_logging_information(new_doc_vers, dt.now(tz.utc))
|
||||
@@ -237,8 +241,6 @@ def start_embedding_task(tenant_id, doc_vers_id):
|
||||
|
||||
|
||||
def validate_file_type(extension):
|
||||
current_app.logger.debug(f'Validating file type {extension}')
|
||||
current_app.logger.debug(f'Supported file types: {current_app.config["SUPPORTED_FILE_TYPES"]}')
|
||||
if extension not in current_app.config['SUPPORTED_FILE_TYPES']:
|
||||
raise EveAIUnsupportedFileType(f"Filetype {extension} is currently not supported. "
|
||||
f"Supported filetypes: {', '.join(current_app.config['SUPPORTED_FILE_TYPES'])}")
|
||||
|
||||
@@ -10,6 +10,7 @@ class EveAIException(Exception):
|
||||
def to_dict(self):
|
||||
rv = dict(self.payload or ())
|
||||
rv['message'] = self.message
|
||||
rv['error'] = self.__class__.__name__
|
||||
return rv
|
||||
|
||||
|
||||
@@ -41,3 +42,9 @@ class EveAINoLicenseForTenant(EveAIException):
|
||||
super().__init__(message, status_code, payload)
|
||||
|
||||
|
||||
class EveAITenantNotFound(EveAIException):
|
||||
"""Raised when a tenant is not found"""
|
||||
|
||||
def __init__(self, message="Tenant not found", status_code=400, payload=None):
|
||||
super().__init__(message, status_code, payload)
|
||||
|
||||
|
||||
@@ -24,9 +24,6 @@ def mw_before_request():
|
||||
if not tenant_id:
|
||||
raise Exception('Cannot switch schema for tenant: no tenant defined in session')
|
||||
|
||||
for role in current_user.roles:
|
||||
current_app.logger.debug(f'In middleware: User {current_user.email} has role {role.name}')
|
||||
|
||||
# user = User.query.get(current_user.id)
|
||||
if current_user.has_role('Super User') or current_user.tenant_id == tenant_id:
|
||||
Database(tenant_id).switch_schema()
|
||||
|
||||
@@ -1,249 +1,36 @@
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import langcodes
|
||||
from flask import current_app
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from typing import List, Any, Iterator
|
||||
from collections.abc import MutableMapping
|
||||
from openai import OpenAI
|
||||
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
|
||||
from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler
|
||||
|
||||
from common.langchain.llm_metrics_handler import LLMMetricsHandler
|
||||
from common.langchain.templates.template_manager import TemplateManager
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from flask import current_app
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
|
||||
from common.langchain.tracked_transcribe import tracked_transcribe
|
||||
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
|
||||
from common.langchain.tracked_transcription import TrackedOpenAITranscription
|
||||
from common.models.user import Tenant
|
||||
from common.utils.cache.base import CacheHandler
|
||||
from config.model_config import MODEL_CONFIG
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.extensions import template_manager, cache_manager
|
||||
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
|
||||
from common.utils.eveai_exceptions import EveAITenantNotFound
|
||||
|
||||
|
||||
class CitedAnswer(BaseModel):
|
||||
"""Default docstring - to be replaced with actual prompt"""
|
||||
def create_language_template(template: str, language: str) -> str:
|
||||
"""
|
||||
Replace language placeholder in template with specified language
|
||||
|
||||
answer: str = Field(
|
||||
...,
|
||||
description="The answer to the user question, based on the given sources",
|
||||
)
|
||||
citations: List[int] = Field(
|
||||
...,
|
||||
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
|
||||
)
|
||||
insufficient_info: bool = Field(
|
||||
False, # Default value is set to False
|
||||
description="A boolean indicating wether given sources were sufficient or not to generate the answer"
|
||||
)
|
||||
Args:
|
||||
template: Template string with {language} placeholder
|
||||
language: Language code to insert
|
||||
|
||||
|
||||
def set_language_prompt_template(cls, language_prompt):
|
||||
cls.__doc__ = language_prompt
|
||||
|
||||
|
||||
class ModelVariables(MutableMapping):
|
||||
def __init__(self, tenant: Tenant, catalog_id=None):
|
||||
self.tenant = tenant
|
||||
self.catalog_id = catalog_id
|
||||
self._variables = self._initialize_variables()
|
||||
self._embedding_model = None
|
||||
self._llm = None
|
||||
self._llm_no_rag = None
|
||||
self._transcription_client = None
|
||||
self._prompt_templates = {}
|
||||
self._embedding_db_model = None
|
||||
self.llm_metrics_handler = LLMMetricsHandler()
|
||||
self._transcription_client = None
|
||||
|
||||
def _initialize_variables(self):
|
||||
variables = {}
|
||||
|
||||
# Get the Catalog if catalog_id is passed
|
||||
if self.catalog_id:
|
||||
catalog = Catalog.query.get_or_404(self.catalog_id)
|
||||
|
||||
# We initialize the variables that are available knowing the tenant.
|
||||
variables['embed_tuning'] = catalog.embed_tuning or False
|
||||
|
||||
# Set HTML Chunking Variables
|
||||
variables['html_tags'] = catalog.html_tags
|
||||
variables['html_end_tags'] = catalog.html_end_tags
|
||||
variables['html_included_elements'] = catalog.html_included_elements
|
||||
variables['html_excluded_elements'] = catalog.html_excluded_elements
|
||||
variables['html_excluded_classes'] = catalog.html_excluded_classes
|
||||
|
||||
# Set Chunk Size variables
|
||||
variables['min_chunk_size'] = catalog.min_chunk_size
|
||||
variables['max_chunk_size'] = catalog.max_chunk_size
|
||||
|
||||
# Set the RAG Context (will have to change once specialists are defined
|
||||
variables['rag_context'] = self.tenant.rag_context or " "
|
||||
# Temporary setting until we have Specialists
|
||||
variables['rag_tuning'] = False
|
||||
variables['RAG_temperature'] = 0.3
|
||||
variables['no_RAG_temperature'] = 0.5
|
||||
variables['k'] = 8
|
||||
variables['similarity_threshold'] = 0.4
|
||||
|
||||
# Set model providers
|
||||
variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
|
||||
variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
|
||||
variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
|
||||
f"{variables['llm_model']}")]
|
||||
current_app.logger.info(f"Loaded prompt templates: \n")
|
||||
current_app.logger.info(f"{variables['templates']}")
|
||||
|
||||
# Set model-specific configurations
|
||||
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
|
||||
variables.update(model_config)
|
||||
|
||||
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]
|
||||
|
||||
if variables['tool_calling_supported']:
|
||||
variables['cited_answer_cls'] = CitedAnswer
|
||||
|
||||
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
|
||||
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
|
||||
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
|
||||
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
|
||||
|
||||
return variables
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
model = self._variables['embedding_model']
|
||||
self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
|
||||
model=model,
|
||||
)
|
||||
self._embedding_db_model = EmbeddingSmallOpenAI \
|
||||
if model == 'text-embedding-3-small' \
|
||||
else EmbeddingLargeOpenAI
|
||||
|
||||
return self._embedding_model
|
||||
|
||||
@property
|
||||
def llm(self):
|
||||
api_key = self.get_api_key_for_llm()
|
||||
self._llm = ChatOpenAI(api_key=api_key,
|
||||
model=self._variables['llm_model'],
|
||||
temperature=self._variables['RAG_temperature'],
|
||||
callbacks=[self.llm_metrics_handler])
|
||||
return self._llm
|
||||
|
||||
@property
|
||||
def llm_no_rag(self):
|
||||
api_key = self.get_api_key_for_llm()
|
||||
self._llm_no_rag = ChatOpenAI(api_key=api_key,
|
||||
model=self._variables['llm_model'],
|
||||
temperature=self._variables['RAG_temperature'],
|
||||
callbacks=[self.llm_metrics_handler])
|
||||
return self._llm_no_rag
|
||||
|
||||
def get_api_key_for_llm(self):
|
||||
if self._variables['llm_provider'] == 'openai':
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
else: # self._variables['llm_provider'] == 'anthropic'
|
||||
api_key = os.getenv('ANTHROPIC_API_KEY')
|
||||
|
||||
return api_key
|
||||
|
||||
@property
|
||||
def transcription_client(self):
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._transcription_client = OpenAI(api_key=api_key, )
|
||||
self._variables['transcription_model'] = 'whisper-1'
|
||||
return self._transcription_client
|
||||
|
||||
def transcribe(self, *args, **kwargs):
|
||||
return tracked_transcribe(self._transcription_client, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def embedding_db_model(self):
|
||||
if self._embedding_db_model is None:
|
||||
self._embedding_db_model = self.get_embedding_db_model()
|
||||
return self._embedding_db_model
|
||||
|
||||
def get_embedding_db_model(self):
|
||||
current_app.logger.debug("In get_embedding_db_model")
|
||||
if self._embedding_db_model is None:
|
||||
self._embedding_db_model = EmbeddingSmallOpenAI \
|
||||
if self._variables['embedding_model'] == 'text-embedding-3-small' \
|
||||
else EmbeddingLargeOpenAI
|
||||
current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
|
||||
return self._embedding_db_model
|
||||
|
||||
def get_prompt_template(self, template_name: str) -> str:
|
||||
current_app.logger.info(f"Getting prompt template for {template_name}")
|
||||
if template_name not in self._prompt_templates:
|
||||
self._prompt_templates[template_name] = self._load_prompt_template(template_name)
|
||||
return self._prompt_templates[template_name]
|
||||
|
||||
def _load_prompt_template(self, template_name: str) -> str:
|
||||
# In the future, this method will make an API call to Portkey
|
||||
# For now, we'll simulate it with a placeholder implementation
|
||||
# You can replace this with your current prompt loading logic
|
||||
return self._variables['templates'][template_name]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
current_app.logger.debug(f"ModelVariables: Getting {key}")
|
||||
# Support older template names (suffix = _template)
|
||||
if key.endswith('_template'):
|
||||
key = key[:-len('_template')]
|
||||
current_app.logger.debug(f"ModelVariables: Getting modified {key}")
|
||||
if key == 'embedding_model':
|
||||
return self.embedding_model
|
||||
elif key == 'embedding_db_model':
|
||||
return self.embedding_db_model
|
||||
elif key == 'llm':
|
||||
return self.llm
|
||||
elif key == 'llm_no_rag':
|
||||
return self.llm_no_rag
|
||||
elif key == 'transcription_client':
|
||||
return self.transcription_client
|
||||
elif key in self._variables.get('prompt_templates', []):
|
||||
return self.get_prompt_template(key)
|
||||
else:
|
||||
value = self._variables.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
else:
|
||||
raise KeyError(f'Variable {key} does not exist in ModelVariables')
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self._variables[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._variables[key]
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._variables)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._variables)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self.__getitem__(key) or default
|
||||
|
||||
def update(self, **kwargs) -> None:
|
||||
self._variables.update(kwargs)
|
||||
|
||||
def items(self):
|
||||
return self._variables.items()
|
||||
|
||||
def keys(self):
|
||||
return self._variables.keys()
|
||||
|
||||
def values(self):
|
||||
return self._variables.values()
|
||||
|
||||
|
||||
def select_model_variables(tenant, catalog_id=None):
|
||||
model_variables = ModelVariables(tenant=tenant, catalog_id=catalog_id)
|
||||
return model_variables
|
||||
|
||||
|
||||
def create_language_template(template, language):
|
||||
Returns:
|
||||
str: Template with language placeholder replaced
|
||||
"""
|
||||
try:
|
||||
full_language = langcodes.Language.make(language=language)
|
||||
language_template = template.replace('{language}', full_language.display_name())
|
||||
@@ -253,5 +40,249 @@ def create_language_template(template, language):
|
||||
return language_template
|
||||
|
||||
|
||||
def replace_variable_in_template(template, variable, value):
|
||||
return template.replace(variable, value)
|
||||
def replace_variable_in_template(template: str, variable: str, value: str) -> str:
|
||||
"""
|
||||
Replace a variable placeholder in template with specified value
|
||||
|
||||
Args:
|
||||
template: Template string with variable placeholder
|
||||
variable: Variable placeholder to replace (e.g. "{tenant_context}")
|
||||
value: Value to insert
|
||||
|
||||
Returns:
|
||||
str: Template with variable placeholder replaced
|
||||
"""
|
||||
return template.replace(variable, value or "")
|
||||
|
||||
|
||||
class ModelVariables:
|
||||
"""Manages model-related variables and configurations"""
|
||||
|
||||
def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
|
||||
"""
|
||||
Initialize ModelVariables with tenant and optional template manager
|
||||
|
||||
Args:
|
||||
tenant: Tenant instance
|
||||
template_manager: Optional TemplateManager instance
|
||||
"""
|
||||
current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
|
||||
self.tenant_id = tenant_id
|
||||
self._variables = variables if variables is not None else self._initialize_variables()
|
||||
current_app.logger.info(f'Model _variables initialized to {self._variables}')
|
||||
self._embedding_model = None
|
||||
self._embedding_model_class = None
|
||||
self._llm_instances = {}
|
||||
self.llm_metrics_handler = LLMMetricsHandler()
|
||||
self._transcription_model = None
|
||||
|
||||
def _initialize_variables(self) -> Dict[str, Any]:
|
||||
"""Initialize the variables dictionary"""
|
||||
variables = {}
|
||||
|
||||
tenant = Tenant.query.get(self.tenant_id)
|
||||
if not tenant:
|
||||
raise EveAITenantNotFound(f"Tenant {self.tenant_id} not found")
|
||||
|
||||
# Set model providers
|
||||
variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
|
||||
variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
|
||||
variables['llm_full_model'] = tenant.llm_model
|
||||
|
||||
# Set model-specific configurations
|
||||
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
|
||||
variables.update(model_config)
|
||||
|
||||
# Additional configurations
|
||||
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
|
||||
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
|
||||
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
|
||||
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
|
||||
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
|
||||
|
||||
return variables
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""Get the embedding model instance"""
|
||||
if self._embedding_model is None:
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._embedding_model = TrackedOpenAIEmbeddings(
|
||||
api_key=api_key,
|
||||
model=self._variables['embedding_model']
|
||||
)
|
||||
return self._embedding_model
|
||||
|
||||
@property
|
||||
def embedding_model_class(self):
|
||||
"""Get the embedding model class"""
|
||||
if self._embedding_model_class is None:
|
||||
if self._variables['embedding_model'] == 'text-embedding-3-large':
|
||||
self._embedding_model_class = EmbeddingLargeOpenAI
|
||||
else: # text-embedding-3-small
|
||||
self._embedding_model_class = EmbeddingSmallOpenAI
|
||||
|
||||
return self._embedding_model_class
|
||||
|
||||
@property
|
||||
def annotation_chunk_length(self):
|
||||
return self._variables['annotation_chunk_length']
|
||||
|
||||
@property
|
||||
def max_compression_duration(self):
|
||||
return self._variables['max_compression_duration']
|
||||
|
||||
@property
|
||||
def max_transcription_duration(self):
|
||||
return self._variables['max_transcription_duration']
|
||||
|
||||
@property
|
||||
def compression_cpu_limit(self):
|
||||
return self._variables['compression_cpu_limit']
|
||||
|
||||
@property
|
||||
def compression_process_delay(self):
|
||||
return self._variables['compression_process_delay']
|
||||
|
||||
def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
|
||||
"""
|
||||
Get an LLM instance with specific configuration
|
||||
|
||||
Args:
|
||||
temperature: The temperature for the LLM
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
An instance of the configured LLM
|
||||
"""
|
||||
cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
|
||||
|
||||
if cache_key not in self._llm_instances:
|
||||
provider = self._variables['llm_provider']
|
||||
model = self._variables['llm_model']
|
||||
|
||||
if provider == 'openai':
|
||||
self._llm_instances[cache_key] = ChatOpenAI(
|
||||
api_key=os.getenv('OPENAI_API_KEY'),
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
callbacks=[self.llm_metrics_handler],
|
||||
**kwargs
|
||||
)
|
||||
elif provider == 'anthropic':
|
||||
self._llm_instances[cache_key] = ChatAnthropic(
|
||||
api_key=os.getenv('ANTHROPIC_API_KEY'),
|
||||
model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
|
||||
temperature=temperature,
|
||||
callbacks=[self.llm_metrics_handler],
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
return self._llm_instances[cache_key]
|
||||
|
||||
@property
|
||||
def transcription_model(self) -> TrackedOpenAITranscription:
|
||||
"""Get the transcription model instance"""
|
||||
if self._transcription_model is None:
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._transcription_model = TrackedOpenAITranscription(
|
||||
api_key=api_key,
|
||||
model='whisper-1'
|
||||
)
|
||||
return self._transcription_model
|
||||
|
||||
# Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
|
||||
@property
|
||||
def transcription_client(self):
|
||||
raise DeprecationWarning("Use transcription_model instead")
|
||||
|
||||
def transcribe(self, *args, **kwargs):
|
||||
raise DeprecationWarning("Use transcription_model.transcribe() instead")
|
||||
|
||||
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get a template for the tenant's configured LLM
|
||||
|
||||
Args:
|
||||
template_name: Name of the template to retrieve
|
||||
version: Optional specific version to retrieve
|
||||
|
||||
Returns:
|
||||
The template content
|
||||
"""
|
||||
try:
|
||||
template = template_manager.get_template(
|
||||
self._variables['llm_full_model'],
|
||||
template_name,
|
||||
version
|
||||
)
|
||||
return template.content
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
|
||||
# Fall back to old template loading if template_manager fails
|
||||
if template_name in self._variables.get('templates', {}):
|
||||
return self._variables['templates'][template_name]
|
||||
raise
|
||||
|
||||
|
||||
class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
|
||||
handler_name = 'model_vars_cache' # Used to access handler instance from cache_manager
|
||||
|
||||
def __init__(self, region):
|
||||
super().__init__(region, 'model_variables')
|
||||
self.configure_keys('tenant_id')
|
||||
self.subscribe_to_model('Tenant', ['tenant_id'])
|
||||
|
||||
def to_cache_data(self, instance: ModelVariables) -> Dict[str, Any]:
|
||||
return {
|
||||
'tenant_id': instance.tenant_id,
|
||||
'variables': instance._variables,
|
||||
'last_updated': dt.now(tz=tz.utc).isoformat()
|
||||
}
|
||||
|
||||
def from_cache_data(self, data: Dict[str, Any], tenant_id: int, **kwargs) -> ModelVariables:
|
||||
instance = ModelVariables(tenant_id, data.get('variables'))
|
||||
return instance
|
||||
|
||||
def should_cache(self, value: Dict[str, Any]) -> bool:
|
||||
required_fields = {'tenant_id', 'variables'}
|
||||
return all(field in value for field in required_fields)
|
||||
|
||||
|
||||
# Register the handler with the cache manager
|
||||
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
|
||||
|
||||
|
||||
# Helper function to get cached model variables
|
||||
def get_model_variables(tenant_id: int) -> ModelVariables:
|
||||
return cache_manager.model_vars_cache.get(
|
||||
lambda tenant_id: ModelVariables(tenant_id), # function to create ModelVariables if required
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Written in a long format, without lambda
|
||||
# def get_model_variables(tenant_id: int) -> ModelVariables:
|
||||
# """
|
||||
# Get ModelVariables instance, either from cache or newly created
|
||||
#
|
||||
# Args:
|
||||
# tenant_id: The tenant's ID
|
||||
#
|
||||
# Returns:
|
||||
# ModelVariables: Instance with either cached or fresh data
|
||||
#
|
||||
# Raises:
|
||||
# TenantNotFoundError: If tenant doesn't exist
|
||||
# CacheStateError: If cached data is invalid
|
||||
# """
|
||||
#
|
||||
# def create_new_instance(tenant_id: int) -> ModelVariables:
|
||||
# """Creator function that's called when cache miss occurs"""
|
||||
# return ModelVariables(tenant_id) # This will initialize fresh variables
|
||||
#
|
||||
# return cache_manager.model_vars_cache.get(
|
||||
# create_new_instance, # Function to create new instance if needed
|
||||
# tenant_id=tenant_id # Parameters passed to both get() and create_new_instance
|
||||
# )
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gevent
|
||||
import time
|
||||
from flask import current_app
|
||||
@@ -28,3 +30,17 @@ def sync_folder(file_path):
|
||||
dir_fd = os.open(file_path, os.O_RDONLY)
|
||||
os.fsync(dir_fd)
|
||||
os.close(dir_fd)
|
||||
|
||||
|
||||
def get_project_root():
|
||||
"""Get the root directory of the project."""
|
||||
# Use the module that's actually running (not this file)
|
||||
module = sys.modules['__main__']
|
||||
if hasattr(module, '__file__'):
|
||||
# Get the path to the main module
|
||||
main_path = os.path.abspath(module.__file__)
|
||||
# Get the root directory (where the main module is located)
|
||||
return os.path.dirname(main_path)
|
||||
else:
|
||||
# Fallback: use current working directory
|
||||
return os.getcwd()
|
||||
|
||||
@@ -4,7 +4,6 @@ from common.models.user import Tenant
|
||||
|
||||
# Definition of Trigger Handlers
|
||||
def set_tenant_session_data(sender, user, **kwargs):
|
||||
current_app.logger.debug(f"Setting tenant session data for user {user.id}")
|
||||
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
|
||||
session['tenant'] = tenant.to_dict()
|
||||
session['default_language'] = tenant.default_language
|
||||
|
||||
@@ -11,7 +11,7 @@ def confirm_token(token, expiration=3600):
|
||||
try:
|
||||
email = serializer.loads(token, salt=current_app.config['SECURITY_PASSWORD_SALT'], max_age=expiration)
|
||||
except Exception as e:
|
||||
current_app.logger.debug(f'Error confirming token: {e}')
|
||||
current_app.logger.error(f'Error confirming token: {e}')
|
||||
raise
|
||||
return email
|
||||
|
||||
@@ -35,14 +35,11 @@ def generate_confirmation_token(email):
|
||||
|
||||
|
||||
def send_confirmation_email(user):
|
||||
current_app.logger.debug(f'Sending confirmation email to {user.email}')
|
||||
|
||||
if not test_smtp_connection():
|
||||
raise Exception("Failed to connect to SMTP server")
|
||||
|
||||
token = generate_confirmation_token(user.email)
|
||||
confirm_url = prefixed_url_for('security_bp.confirm_email', token=token, _external=True)
|
||||
current_app.logger.debug(f'Confirmation URL: {confirm_url}')
|
||||
|
||||
html = render_template('email/activate.html', confirm_url=confirm_url)
|
||||
subject = "Please confirm your email"
|
||||
@@ -56,10 +53,8 @@ def send_confirmation_email(user):
|
||||
|
||||
|
||||
def send_reset_email(user):
|
||||
current_app.logger.debug(f'Sending reset email to {user.email}')
|
||||
token = generate_reset_token(user.email)
|
||||
reset_url = prefixed_url_for('security_bp.reset_password', token=token, _external=True)
|
||||
current_app.logger.debug(f'Reset URL: {reset_url}')
|
||||
|
||||
html = render_template('email/reset_password.html', reset_url=reset_url)
|
||||
subject = "Reset Your Password"
|
||||
|
||||
112
common/utils/string_list_converter.py
Normal file
112
common/utils/string_list_converter.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import List, Union
|
||||
import re
|
||||
|
||||
|
||||
class StringListConverter:
|
||||
"""Utility class for converting between comma-separated strings and lists"""
|
||||
|
||||
@staticmethod
|
||||
def string_to_list(input_string: Union[str, None], allow_empty: bool = True) -> List[str]:
|
||||
"""
|
||||
Convert a comma-separated string to a list of strings.
|
||||
|
||||
Args:
|
||||
input_string: Comma-separated string to convert
|
||||
allow_empty: If True, returns empty list for None/empty input
|
||||
If False, raises ValueError for None/empty input
|
||||
|
||||
Returns:
|
||||
List of stripped strings
|
||||
|
||||
Raises:
|
||||
ValueError: If input is None/empty and allow_empty is False
|
||||
"""
|
||||
if not input_string:
|
||||
if allow_empty:
|
||||
return []
|
||||
raise ValueError("Input string cannot be None or empty")
|
||||
|
||||
return [item.strip() for item in input_string.split(',') if item.strip()]
|
||||
|
||||
@staticmethod
|
||||
def list_to_string(input_list: Union[List[str], None], allow_empty: bool = True) -> str:
|
||||
"""
|
||||
Convert a list of strings to a comma-separated string.
|
||||
|
||||
Args:
|
||||
input_list: List of strings to convert
|
||||
allow_empty: If True, returns empty string for None/empty input
|
||||
If False, raises ValueError for None/empty input
|
||||
|
||||
Returns:
|
||||
Comma-separated string
|
||||
|
||||
Raises:
|
||||
ValueError: If input is None/empty and allow_empty is False
|
||||
"""
|
||||
if not input_list:
|
||||
if allow_empty:
|
||||
return ''
|
||||
raise ValueError("Input list cannot be None or empty")
|
||||
|
||||
return ', '.join(str(item).strip() for item in input_list)
|
||||
|
||||
@staticmethod
|
||||
def validate_format(input_string: str,
|
||||
allowed_chars: str = r'a-zA-Z0-9_\-',
|
||||
min_length: int = 1,
|
||||
max_length: int = 50) -> bool:
|
||||
"""
|
||||
Validate the format of items in a comma-separated string.
|
||||
|
||||
Args:
|
||||
input_string: String to validate
|
||||
allowed_chars: String of allowed characters (for regex pattern)
|
||||
min_length: Minimum length for each item
|
||||
max_length: Maximum length for each item
|
||||
|
||||
Returns:
|
||||
bool: True if format is valid, False otherwise
|
||||
"""
|
||||
if not input_string:
|
||||
return False
|
||||
|
||||
# Create regex pattern for individual items
|
||||
pattern = f'^[{allowed_chars}]{{{min_length},{max_length}}}$'
|
||||
|
||||
try:
|
||||
# Convert to list and check each item
|
||||
items = StringListConverter.string_to_list(input_string)
|
||||
return all(bool(re.match(pattern, item)) for item in items)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def validate_and_convert(input_string: str,
|
||||
allowed_chars: str = r'a-zA-Z0-9_\-',
|
||||
min_length: int = 1,
|
||||
max_length: int = 50) -> List[str]:
|
||||
"""
|
||||
Validate and convert a comma-separated string to a list.
|
||||
|
||||
Args:
|
||||
input_string: String to validate and convert
|
||||
allowed_chars: String of allowed characters (for regex pattern)
|
||||
min_length: Minimum length for each item
|
||||
max_length: Maximum length for each item
|
||||
|
||||
Returns:
|
||||
List of validated and converted strings
|
||||
|
||||
Raises:
|
||||
ValueError: If input string format is invalid
|
||||
"""
|
||||
if not StringListConverter.validate_format(
|
||||
input_string, allowed_chars, min_length, max_length
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid format. Items must be {min_length}-{max_length} characters "
|
||||
f"long and contain only these characters: {allowed_chars}"
|
||||
)
|
||||
|
||||
return StringListConverter.string_to_list(input_string)
|
||||
@@ -44,7 +44,7 @@ def form_validation_failed(request, form):
|
||||
for fieldName, errorMessages in form.errors.items():
|
||||
for err in errorMessages:
|
||||
flash(f"Error in {fieldName}: {err}", 'danger')
|
||||
current_app.logger.debug(f"Error in {fieldName}: {err}")
|
||||
current_app.logger.error(f"Error in {fieldName}: {err}")
|
||||
|
||||
|
||||
def form_to_dict(form):
|
||||
|
||||
Reference in New Issue
Block a user