- Introduction of dynamic Retrievers & Specialists

- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
This commit is contained in:
Josako
2024-11-15 10:00:53 +01:00
parent 55a8a95f79
commit 1807435339
101 changed files with 4181 additions and 1764 deletions

View File

@@ -12,6 +12,8 @@ from flask_wtf import CSRFProtect
from flask_restx import Api
from prometheus_flask_exporter import PrometheusMetrics
from .langchain.templates.template_manager import TemplateManager
from .utils.cache.eveai_cache_manager import EveAICacheManager
from .utils.simple_encryption import SimpleEncryption
from .utils.minio_utils import MinioClient
@@ -32,3 +34,5 @@ api_rest = Api()
simple_encryption = SimpleEncryption()
minio_client = MinioClient()
metrics = PrometheusMetrics.for_app_factory()
template_manager = TemplateManager()
cache_manager = EveAICacheManager()

View File

@@ -0,0 +1,23 @@
# Output Schema Management - common/langchain/outputs/base.py
from typing import Dict, Type, Any
from pydantic import BaseModel
class BaseSpecialistOutput(BaseModel):
"""Base class for all specialist outputs"""
pass
class OutputRegistry:
"""Registry for specialist output schemas"""
_schemas: Dict[str, Type[BaseSpecialistOutput]] = {}
@classmethod
def register(cls, specialist_type: str, schema_class: Type[BaseSpecialistOutput]):
cls._schemas[specialist_type] = schema_class
@classmethod
def get_schema(cls, specialist_type: str) -> Type[BaseSpecialistOutput]:
if specialist_type not in cls._schemas:
raise ValueError(f"No output schema registered for {specialist_type}")
return cls._schemas[specialist_type]

View File

@@ -0,0 +1,22 @@
# RAG Specialist Output - common/langchain/outputs/rag.py
from typing import List
from pydantic import Field
from .base import BaseSpecialistOutput
class RAGOutput(BaseSpecialistOutput):
"""Output schema for RAG specialist"""
"""Default docstring - to be replaced with actual prompt"""
answer: str = Field(
...,
description="The answer to the user question, based on the given sources",
)
citations: List[int] = Field(
...,
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
)
insufficient_info: bool = Field(
False, # Default value is set to False
description="A boolean indicating whether given sources were sufficient or not to generate the answer"
)

View File

@@ -1,145 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
@property
def catalog_id(self) -> int:
return self._catalog_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.model_variables['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding

View File

@@ -1,154 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc, cast, JSON
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict, List, Optional
from flask import current_app
from contextlib import contextmanager
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDossierRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
_active_filters: Optional[Dict[str, Any]] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
self._active_filters = None
@contextmanager
def filtering(self, metadata_filters: Dict[str, Any]):
"""Context manager for temporarily setting metadata filters"""
previous_filters = self._active_filters
self._active_filters = metadata_filters
try:
yield self
finally:
self._active_filters = previous_filters
def _build_metadata_filter_conditions(self, query):
"""Build SQL conditions for metadata filtering"""
if not self._active_filters:
return query
conditions = []
for field, value in self._active_filters.items():
if value is None:
continue
# Handle both single values and lists of values
if isinstance(value, (list, tuple)):
# Multiple values - create OR condition
or_conditions = []
for val in value:
or_conditions.append(
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(val)
)
if or_conditions:
conditions.append(or_(*or_conditions))
else:
# Single value - direct comparison
conditions.append(
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(value)
)
if conditions:
query = query.filter(and_(*conditions))
return query
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
if self._active_filters:
current_app.logger.debug(f'Using metadata filters: {self._active_filters}')
query_embedding = self._get_query_embedding(query)
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Build base query
# Build base query
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
)
# Apply metadata filters
query_obj = self._build_metadata_filter_conditions(query_obj)
# Order and limit results
query_obj = query_obj.order_by(desc('similarity')).limit(k)
# Debug logging for RAG tuning if enabled
if self.model_variables['rag_tuning']:
self._log_rag_tuning(query_obj, query_embedding)
res = query_obj.all()
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _log_rag_tuning(self, query_obj, query_embedding):
"""Log debug information for RAG tuning"""
current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
if self._active_filters:
current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
current_app.rag_tuning_logger.debug(f"{self._active_filters}")
def _get_query_embedding(self, query: str):
"""Get embedding for the query text"""
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info

View File

@@ -1,52 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import asc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import Field, BaseModel, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.interaction import ChatSession, Interaction
from common.utils.model_utils import ModelVariables
class EveAIHistoryRetriever(BaseRetriever, BaseModel):
_model_variables: ModelVariables = PrivateAttr()
_session_id: str = PrivateAttr()
def __init__(self, model_variables: ModelVariables, session_id: str):
super().__init__()
self._model_variables = model_variables
self._session_id = session_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def session_id(self) -> str:
return self._session_id
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
try:
query_obj = (
db.session.query(Interaction)
.join(ChatSession, Interaction.chat_session_id == ChatSession.id)
.filter(ChatSession.session_id == self.session_id)
.order_by(asc(Interaction.id))
)
interactions = query_obj.all()
result = []
for interaction in interactions:
result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving history of interactions: {e}')
db.session.rollback()
return []
return result

View File

@@ -1,40 +0,0 @@
from pydantic import BaseModel, PrivateAttr
from typing import Dict, Any
from common.utils.model_utils import ModelVariables
class EveAIRetriever(BaseModel):
_catalog_id: int = PrivateAttr()
_user_metadata: Dict[str, Any] = PrivateAttr()
_system_metadata: Dict[str, Any] = PrivateAttr()
_configuration: Dict[str, Any] = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tuning: bool = PrivateAttr()
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
configuration: Dict[str, Any]):
super().__init__()
self._catalog_id = catalog_id
self._user_metadata = user_metadata
self._system_metadata = system_metadata
self._configuration = configuration
@property
def catalog_id(self):
return self._catalog_id
@property
def user_metadata(self):
return self._user_metadata
@property
def system_metadata(self):
return self._system_metadata
@property
def configuration(self):
return self._configuration
# Any common methods that should be shared among retrievers can go here.

View File

@@ -0,0 +1,154 @@
import os
import yaml
from typing import Dict, Optional, Any
from packaging import version
from dataclasses import dataclass
from flask import current_app, Flask
from common.utils.os_utils import get_project_root
@dataclass
class PromptTemplate:
"""Represents a versioned prompt template"""
content: str
version: str
metadata: Dict[str, Any]
class TemplateManager:
"""Manages versioned prompt templates"""
def __init__(self):
self.templates_dir = None
self._templates = None
self.app = None
def init_app(self, app: Flask) -> None:
# Initialize template manager
base_dir = "/app"
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
app.logger.debug(f'Loading templates from {self.templates_dir}')
self.app = app
self._templates = self._load_templates()
# Log available templates for each supported model
for llm in app.config['SUPPORTED_LLMS']:
try:
available_templates = self.list_templates(llm)
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
except ValueError:
app.logger.warning(f"No templates found for {llm}")
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
"""
Load all template versions from the templates directory.
Structure: {provider.model -> {template_name -> {version -> template}}}
Directory structure:
prompts/
├── provider/
│ └── model/
│ └── template_name/
│ └── version.yaml
"""
templates = {}
# Iterate through providers (anthropic, openai)
for provider in os.listdir(self.templates_dir):
provider_path = os.path.join(self.templates_dir, provider)
if not os.path.isdir(provider_path):
continue
# Iterate through models (claude-3, gpt-4o)
for model in os.listdir(provider_path):
model_path = os.path.join(provider_path, model)
if not os.path.isdir(model_path):
continue
provider_model = f"{provider}.{model}"
templates[provider_model] = {}
# Iterate through template types (rag, summary, etc.)
for template_name in os.listdir(model_path):
template_path = os.path.join(model_path, template_name)
if not os.path.isdir(template_path):
continue
template_versions = {}
# Load all version files for this template
for version_file in os.listdir(template_path):
if not version_file.endswith('.yaml'):
continue
version_str = version_file[:-5] # Remove .yaml
if not self._is_valid_version(version_str):
current_app.logger.warning(
f"Invalid version format for {template_name}: {version_str}")
continue
try:
with open(os.path.join(template_path, version_file)) as f:
template_data = yaml.safe_load(f)
# Verify required fields
if not template_data.get('content'):
raise ValueError("Template content is required")
template_versions[version_str] = PromptTemplate(
content=template_data['content'],
version=version_str,
metadata=template_data.get('metadata', {})
)
except Exception as e:
current_app.logger.error(
f"Error loading template {template_name} version {version_str}: {e}")
continue
if template_versions:
templates[provider_model][template_name] = template_versions
return templates
def _is_valid_version(self, version_str: str) -> bool:
"""Validate semantic versioning string"""
try:
version.parse(version_str)
return True
except version.InvalidVersion:
return False
def get_template(self,
provider_model: str,
template_name: str,
template_version: Optional[str] = None) -> PromptTemplate:
"""
Get a specific template version. If version not specified,
returns the latest version.
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
if template_name not in self._templates[provider_model]:
raise ValueError(f"Unknown template: {template_name}")
versions = self._templates[provider_model][template_name]
if template_version:
if template_version not in versions:
raise ValueError(f"Template version {template_version} not found")
return versions[template_version]
# Return latest version
latest = max(versions.keys(), key=version.parse)
return versions[latest]
def list_templates(self, provider_model: str) -> Dict[str, list]:
"""
List all available templates and their versions for a provider.model
Returns: {template_name: [version1, version2, ...]}
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
return {
template_name: sorted(versions.keys(), key=version.parse)
for template_name, versions in self._templates[provider_model].items()
}

View File

@@ -1,27 +0,0 @@
import time
from common.utils.business_event_context import current_event
def tracked_transcribe(client, *args, **kwargs):
start_time = time.time()
# Extract the file and model from kwargs if present, otherwise use defaults
file = kwargs.get('file')
model = kwargs.get('model', 'whisper-1')
duration = kwargs.pop('duration', 600)
result = client.audio.transcriptions.create(*args, **kwargs)
end_time = time.time()
# Token usage for transcriptions is actually the duration in seconds we pass, as the whisper model is priced per second transcribed
metrics = {
'total_tokens': duration,
'prompt_tokens': 0, # For transcriptions, all tokens are considered "completion"
'completion_tokens': duration,
'time_elapsed': end_time - start_time,
'interaction_type': 'ASR',
}
current_event.log_llm_metrics(metrics)
return result

View File

@@ -0,0 +1,77 @@
# common/langchain/tracked_transcription.py
from typing import Any, Optional, Dict
import time
from openai import OpenAI
from common.utils.business_event_context import current_event
class TrackedOpenAITranscription:
"""Wrapper for OpenAI transcription with metric tracking"""
def __init__(self, api_key: str, **kwargs: Any):
"""Initialize with OpenAI client settings"""
self.client = OpenAI(api_key=api_key)
self.model = kwargs.get('model', 'whisper-1')
def transcribe(self,
file: Any,
model: Optional[str] = None,
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Optional[str] = None,
temperature: Optional[float] = None,
duration: Optional[int] = None) -> str:
"""
Transcribe audio with metrics tracking
Args:
file: Audio file to transcribe
model: Model to use (defaults to whisper-1)
language: Optional language of the audio
prompt: Optional prompt to guide transcription
response_format: Response format (json, text, etc)
temperature: Sampling temperature
duration: Duration of audio in seconds for metrics
Returns:
Transcription text
"""
start_time = time.time()
try:
# Create transcription options
options = {
"file": file,
"model": model or self.model,
}
if language:
options["language"] = language
if prompt:
options["prompt"] = prompt
if response_format:
options["response_format"] = response_format
if temperature:
options["temperature"] = temperature
response = self.client.audio.transcriptions.create(**options)
# Calculate metrics
end_time = time.time()
# Token usage for transcriptions is based on audio duration
metrics = {
'total_tokens': duration or 600, # Default to 10 minutes if duration not provided
'prompt_tokens': 0, # For transcriptions, all tokens are completion
'completion_tokens': duration or 600,
'time_elapsed': end_time - start_time,
'interaction_type': 'ASR',
}
current_event.log_llm_metrics(metrics)
# Return text from response
if isinstance(response, str):
return response
return response.text
except Exception as e:
raise Exception(f"Transcription failed: {str(e)}")

View File

@@ -12,22 +12,31 @@ class Catalog(db.Model):
description = db.Column(db.Text, nullable=True)
type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
# Embedding variables
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
html_end_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'li'])
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_classes = db.Column(ARRAY(sa.String(200)), nullable=True)
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
# Chat variables ==> Move to Specialist?
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
configuration = db.Column(JSONB, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
class Processor(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
type = db.Column(db.String(50), nullable=False)
sub_file_type = db.Column(db.String(50), nullable=True)
# Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
tuning = db.Column(db.Boolean, nullable=True, default=False)
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
@@ -90,6 +99,7 @@ class DocumentVersion(db.Model):
bucket_name = db.Column(db.String(255), nullable=True)
object_name = db.Column(db.String(200), nullable=True)
file_type = db.Column(db.String(20), nullable=True)
sub_file_type = db.Column(db.String(50), nullable=True)
file_size = db.Column(db.Float, nullable=True)
language = db.Column(db.String(2), nullable=False)
user_context = db.Column(db.Text, nullable=True)

View File

@@ -20,34 +20,6 @@ class ChatSession(db.Model):
return f"<ChatSession {self.id} by {self.user_id}>"
class Interaction(db.Model):
id = db.Column(db.Integer, primary_key=True)
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
question = db.Column(db.Text, nullable=False)
detailed_question = db.Column(db.Text, nullable=True)
answer = db.Column(db.Text, nullable=True)
algorithm_used = db.Column(db.String(20), nullable=True)
language = db.Column(db.String(2), nullable=False)
timezone = db.Column(db.String(30), nullable=True)
appreciation = db.Column(db.Integer, nullable=True)
# Timing information
question_at = db.Column(db.DateTime, nullable=False)
detailed_question_at = db.Column(db.DateTime, nullable=True)
answer_at = db.Column(db.DateTime, nullable=True)
# Relations
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
def __repr__(self):
return f"<Interaction {self.id}>"
class InteractionEmbedding(db.Model):
interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)
class Specialist(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), nullable=False)
@@ -68,7 +40,34 @@ class Specialist(db.Model):
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
class Interaction(db.Model):
id = db.Column(db.Integer, primary_key=True)
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=True)
specialist_arguments = db.Column(JSONB, nullable=True)
specialist_results = db.Column(JSONB, nullable=True)
timezone = db.Column(db.String(30), nullable=True)
appreciation = db.Column(db.Integer, nullable=True)
# Timing information
question_at = db.Column(db.DateTime, nullable=False)
detailed_question_at = db.Column(db.DateTime, nullable=True)
answer_at = db.Column(db.DateTime, nullable=True)
# Relations
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
def __repr__(self):
return f"<Interaction {self.id}>"
class InteractionEmbedding(db.Model):
interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)
class SpecialistRetriever(db.Model):
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), primary_key=True)
retriever_id = db.Column(db.Integer, db.ForeignKey(Retriever.id, ondelete='CASCADE'), primary_key=True)
retriever = db.relationship("Retriever", backref="specialist_retrievers")

View File

@@ -4,7 +4,6 @@ from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Any, Optional
from datetime import datetime as dt, timezone as tz
from portkey_ai import Portkey, Config
import logging
from .business_event_context import BusinessEventContext

89
common/utils/cache/base.py vendored Normal file
View File

@@ -0,0 +1,89 @@
# common/utils/cache/base.py
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type
from dataclasses import dataclass
from flask import Flask
from dogpile.cache import CacheRegion
T = TypeVar('T')
@dataclass
class CacheKey:
"""Represents a cache key with multiple components"""
components: Dict[str, Any]
def __str__(self) -> str:
return ":".join(f"{k}={v}" for k, v in sorted(self.components.items()))
class CacheInvalidationManager:
"""Manages cache invalidation subscriptions"""
def __init__(self):
self._subscribers = {}
def subscribe(self, model: str, handler: 'CacheHandler', key_fields: List[str]):
if model not in self._subscribers:
self._subscribers[model] = []
self._subscribers[model].append((handler, key_fields))
def notify_change(self, model: str, **identifiers):
if model in self._subscribers:
for handler, key_fields in self._subscribers[model]:
if all(field in identifiers for field in key_fields):
handler.invalidate_by_model(model, **identifiers)
class CacheHandler(Generic[T]):
"""Base cache handler implementation"""
def __init__(self, region: CacheRegion, prefix: str):
self.region = region
self.prefix = prefix
self._key_components = []
def configure_keys(self, *components: str):
self._key_components = components
return self
def subscribe_to_model(self, model: str, key_fields: List[str]):
invalidation_manager.subscribe(model, self, key_fields)
return self
def generate_key(self, **identifiers) -> str:
missing = set(self._key_components) - set(identifiers.keys())
if missing:
raise ValueError(f"Missing key components: {missing}")
key = CacheKey({k: identifiers[k] for k in self._key_components})
return f"{self.prefix}:{str(key)}"
def get(self, creator_func, **identifiers) -> T:
cache_key = self.generate_key(**identifiers)
def creator():
instance = creator_func(**identifiers)
return self.to_cache_data(instance)
cached_data = self.region.get_or_create(
cache_key,
creator,
should_cache_fn=self.should_cache
)
return self.from_cache_data(cached_data, **identifiers)
def invalidate(self, **identifiers):
cache_key = self.generate_key(**identifiers)
self.region.delete(cache_key)
def invalidate_by_model(self, model: str, **identifiers):
try:
self.invalidate(**identifiers)
except ValueError:
pass
# Create global invalidation manager
invalidation_manager = CacheInvalidationManager()

View File

@@ -0,0 +1,32 @@
from typing import Type
from flask import Flask
from common.utils.cache.base import CacheHandler
class EveAICacheManager:
"""Cache manager with registration capabilities"""
def __init__(self):
self.model_region = None
self.eveai_chat_workers_region = None
self.eveai_workers_region = None
self._handlers = {}
def init_app(self, app: Flask):
"""Initialize cache regions"""
from common.utils.cache.regions import create_cache_regions
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
# Initialize all registered handlers with their regions
for handler_class, region_name in self._handlers.items():
region = getattr(self, f"{region_name}_region")
handler_instance = handler_class(region)
setattr(self, handler_class.handler_name, handler_instance)
def register_handler(self, handler_class: Type[CacheHandler], region: str):
"""Register a cache handler class with its region"""
if not hasattr(handler_class, 'handler_name'):
raise ValueError("Cache handler must define handler_name class attribute")
self._handlers[handler_class] = region

61
common/utils/cache/regions.py vendored Normal file
View File

@@ -0,0 +1,61 @@
# common/utils/cache/regions.py
from dogpile.cache import make_region
from flask import current_app
from urllib.parse import urlparse
import os
def get_redis_config(app):
"""
Create Redis configuration dict based on app config
Handles both authenticated and non-authenticated setups
"""
# Parse the REDIS_BASE_URI to get all components
redis_uri = urlparse(app.config['REDIS_BASE_URI'])
config = {
'host': redis_uri.hostname,
'port': int(redis_uri.port or 6379),
'db': 4, # Keep this for later use
'redis_expiration_time': 3600,
'distributed_lock': True
}
# Add authentication if provided
if redis_uri.username and redis_uri.password:
config.update({
'username': redis_uri.username,
'password': redis_uri.password
})
return config
def create_cache_regions(app):
"""Initialize all cache regions with app config"""
redis_config = get_redis_config(app)
# Region for model-related caching (ModelVariables etc)
model_region = make_region(name='model').configure(
'dogpile.cache.redis',
arguments=redis_config,
replace_existing_backend=True
)
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
eveai_chat_workers_region = make_region(name='chat_workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
replace_existing_backend=True
)
# Region for eveai_workers components (Processors, ...)
eveai_workers_region = make_region(name='workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # Same config for now
replace_existing_backend=True
)
return model_region, eveai_chat_workers_region, eveai_workers_region

View File

@@ -8,8 +8,6 @@ celery_app = Celery()
def init_celery(celery, app, is_beat=False):
celery_app.main = app.name
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
celery_config = {
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),

View File

@@ -0,0 +1,613 @@
from typing import Optional, List, Union, Dict, Any, Pattern
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Annotated
import re
from datetime import datetime
import json
from textwrap import dedent
import yaml
from dataclasses import dataclass
class TaggingField(BaseModel):
"""Represents a single tagging field configuration"""
type: str
required: bool = False
description: Optional[str] = None
allowed_values: Optional[List[Any]] = None # for enum type
min_value: Optional[Union[int, float]] = None # for numeric types
max_value: Optional[Union[int, float]] = None # for numeric types
@field_validator('type', mode='before')
@classmethod
def validate_type(cls, v: str) -> str:
valid_types = ['string', 'integer', 'float', 'date', 'enum']
if v not in valid_types:
raise ValueError(f'type must be one of {valid_types}')
return v
@model_validator(mode='after')
def validate_field_constraints(self) -> 'TaggingField':
# Validate enum constraints
if self.type == 'enum':
if not self.allowed_values:
raise ValueError('allowed_values must be provided for enum type')
elif self.allowed_values is not None:
raise ValueError('allowed_values only valid for enum type')
# Validate numeric constraints
if self.type not in ('integer', 'float'):
if self.min_value is not None or self.max_value is not None:
raise ValueError('min_value/max_value only valid for numeric types')
else:
if self.min_value is not None and self.max_value is not None and self.min_value >= self.max_value:
raise ValueError('min_value must be less than max_value')
return self
class TaggingFields(BaseModel):
"""Represents a collection of tagging fields, mapped by their names"""
fields: Dict[str, TaggingField]
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'TaggingFields':
return cls(fields={
field_name: TaggingField(**field_config)
for field_name, field_config in data.items()
})
def to_dict(self) -> Dict[str, Dict[str, Any]]:
return {
field_name: field.model_dump(exclude_none=True)
for field_name, field in self.fields.items()
}
class ArgumentConstraint(BaseModel):
"""Base class for all argument constraints"""
description: Optional[str] = None
error_message: Optional[str] = None
class NumericConstraint(ArgumentConstraint):
"""Constraints for numeric values (int/float)"""
min_value: Optional[float] = None
max_value: Optional[float] = None
include_min: bool = True # True for >= min_value, False for > min_value
include_max: bool = True # True for <= max_value, False for < max_value
@model_validator(mode='after')
def validate_ranges(self) -> 'NumericConstraint':
if self.min_value is not None and self.max_value is not None:
if self.min_value > self.max_value:
raise ValueError("min_value must be less than or equal to max_value")
return self
def validate(self, value: Union[int, float]) -> bool:
if self.min_value is not None:
if self.include_min and value < self.min_value:
return False
if not self.include_min and value <= self.min_value:
return False
if self.max_value is not None:
if self.include_max and value > self.max_value:
return False
if not self.include_max and value >= self.max_value:
return False
return True
class StringConstraint(ArgumentConstraint):
"""Constraints for string values"""
min_length: Optional[int] = None
max_length: Optional[int] = None
patterns: Optional[List[str]] = None # List of regex patterns to match
pattern_match_all: bool = False # If True, string must match all patterns
forbidden_patterns: Optional[List[str]] = None # List of regex patterns that must not match
allow_empty: bool = False
@field_validator('patterns', 'forbidden_patterns')
@classmethod
def validate_patterns(cls, v: Optional[List[str]]) -> Optional[List[str]]:
if v is not None:
# Validate each pattern compiles
for pattern in v:
try:
re.compile(pattern)
except re.error as e:
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
return v
def validate(self, value: str) -> bool:
if not self.allow_empty and not value:
return False
if self.min_length is not None and len(value) < self.min_length:
return False
if self.max_length is not None and len(value) > self.max_length:
return False
if self.patterns:
matches = [bool(re.search(pattern, value)) for pattern in self.patterns]
if self.pattern_match_all and not all(matches):
return False
if not self.pattern_match_all and not any(matches):
return False
if self.forbidden_patterns:
for pattern in self.forbidden_patterns:
if re.search(pattern, value):
return False
return True
class DateConstraint(ArgumentConstraint):
"""Constraints for date values"""
min_date: Optional[datetime] = None
max_date: Optional[datetime] = None
include_min: bool = True
include_max: bool = True
allowed_formats: Optional[List[str]] = None # List of allowed date formats
@model_validator(mode='after')
def validate_ranges(self) -> 'DateConstraint':
if self.min_date and self.max_date and self.min_date > self.max_date:
raise ValueError("min_date must be less than or equal to max_date")
return self
def validate(self, value: datetime) -> bool:
if self.min_date is not None:
if self.include_min and value < self.min_date:
return False
if not self.include_min and value <= self.min_date:
return False
if self.max_date is not None:
if self.include_max and value > self.max_date:
return False
if not self.include_max and value >= self.max_date:
return False
return True
class EnumConstraint(ArgumentConstraint):
"""Constraints for enum values"""
allowed_values: List[Any]
case_sensitive: bool = True # For string enums
allow_multiple: bool = False # If True, value can be a list of allowed values
min_selections: Optional[int] = None # When allow_multiple is True
max_selections: Optional[int] = None # When allow_multiple is True
@model_validator(mode='after')
def validate_selections(self) -> 'EnumConstraint':
if self.allow_multiple:
if self.min_selections is not None and self.max_selections is not None:
if self.min_selections > self.max_selections:
raise ValueError("min_selections must be less than or equal to max_selections")
if self.max_selections > len(self.allowed_values):
raise ValueError("max_selections cannot be greater than number of allowed values")
return self
def validate(self, value: Union[Any, List[Any]]) -> bool:
if self.allow_multiple:
if not isinstance(value, list):
return False
if self.min_selections is not None and len(value) < self.min_selections:
return False
if self.max_selections is not None and len(value) > self.max_selections:
return False
for v in value:
if not self._validate_single_value(v):
return False
else:
return self._validate_single_value(value)
return True
def _validate_single_value(self, value: Any) -> bool:
if isinstance(value, str) and not self.case_sensitive:
return any(str(value).lower() == str(v).lower() for v in self.allowed_values)
return value in self.allowed_values
class ArgumentDefinition(BaseModel):
"""Defines an argument with its type and constraints"""
name: str
type: str
description: Optional[str] = None
required: bool = False
default: Optional[Any] = None
constraints: Optional[Union[NumericConstraint, StringConstraint, DateConstraint, EnumConstraint]] = None
@field_validator('type')
@classmethod
def validate_type(cls, v: str) -> str:
valid_types = ['string', 'integer', 'float', 'date', 'enum']
if v not in valid_types:
raise ValueError(f'type must be one of {valid_types}')
return v
@model_validator(mode='after')
def validate_constraints(self) -> 'ArgumentDefinition':
if self.constraints:
expected_constraint_types = {
'string': StringConstraint,
'integer': NumericConstraint,
'float': NumericConstraint,
'date': DateConstraint,
'enum': EnumConstraint
}
expected_type = expected_constraint_types.get(self.type)
if not isinstance(self.constraints, expected_type):
raise ValueError(f'Constraints for type {self.type} must be of type {expected_type.__name__}')
if self.default is not None:
if not self.constraints.validate(self.default):
raise ValueError(f'Default value does not satisfy constraints for {self.name}')
return self
class ArgumentDefinitions(BaseModel):
"""Collection of argument definitions"""
arguments: Dict[str, ArgumentDefinition]
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'ArgumentDefinitions':
return cls(arguments={
arg_name: ArgumentDefinition(**arg_config)
for arg_name, arg_config in data.items()
})
def to_dict(self) -> Dict[str, Dict[str, Any]]:
return {
arg_name: arg.model_dump(exclude_none=True)
for arg_name, arg in self.arguments.items()
}
def validate_argument_values(self, values: Dict[str, Any]) -> Dict[str, str]:
"""
Validate a set of argument values against their definitions
Returns a dictionary of error messages for invalid arguments
"""
errors = {}
# Check for required arguments
for name, arg_def in self.arguments.items():
if arg_def.required and name not in values:
errors[name] = "Required argument missing"
continue
if name in values:
value = values[name]
# Validate type
try:
if arg_def.type == 'integer':
value = int(value)
elif arg_def.type == 'float':
value = float(value)
elif arg_def.type == 'date' and isinstance(value, str):
if arg_def.constraints and arg_def.constraints.allowed_formats:
for fmt in arg_def.constraints.allowed_formats:
try:
value = datetime.strptime(value, fmt)
break
except ValueError:
continue
else:
errors[
name] = f"Invalid date format. Allowed formats: {arg_def.constraints.allowed_formats}"
continue
except (ValueError, TypeError):
errors[name] = f"Invalid type. Expected {arg_def.type}"
continue
# Validate constraints
if arg_def.constraints and not arg_def.constraints.validate(value):
errors[name] = arg_def.constraints.error_message or "Value does not satisfy constraints"
return errors
@dataclass
class DocumentationFormat:
"""Constants for documentation formats"""
MARKDOWN = "markdown"
JSON = "json"
YAML = "yaml"
@dataclass
class DocumentationVersion:
"""Constants for documentation versions"""
BASIC = "basic" # Original documentation without retriever info
EXTENDED = "extended" # Including retriever documentation
def _generate_argument_constraints(field_config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Generate possible argument constraints based on field type"""
constraints = []
base_constraint = {
"description": f"Constraint for {field_config.get('description', 'field')}",
"error_message": "Optional custom error message"
}
if field_config["type"] == "integer" or field_config["type"] == "float":
constraints.append({
**base_constraint,
"type": "NumericConstraint",
"possible_constraints": {
"min_value": "number",
"max_value": "number",
"include_min": "boolean",
"include_max": "boolean"
},
"example": {
"min_value": field_config.get("min_value", 0),
"max_value": field_config.get("max_value", 100),
"include_min": True,
"include_max": True
}
})
elif field_config["type"] == "string":
constraints.append({
**base_constraint,
"type": "StringConstraint",
"possible_constraints": {
"min_length": "integer",
"max_length": "integer",
"patterns": "list[str]",
"pattern_match_all": "boolean",
"forbidden_patterns": "list[str]",
"allow_empty": "boolean"
},
"example": {
"min_length": 1,
"max_length": 100,
"patterns": ["^[A-Za-z0-9]+$"],
"pattern_match_all": False,
"forbidden_patterns": ["^test_", "_temp$"],
"allow_empty": False
}
})
elif field_config["type"] == "enum":
constraints.append({
**base_constraint,
"type": "EnumConstraint",
"possible_constraints": {
"allowed_values": f"list[{field_config.get('allowed_values', ['value1', 'value2'])}]",
"case_sensitive": "boolean",
"allow_multiple": "boolean",
"min_selections": "integer",
"max_selections": "integer"
},
"example": {
"allowed_values": field_config.get("allowed_values", ["value1", "value2"]),
"case_sensitive": True,
"allow_multiple": True,
"min_selections": 1,
"max_selections": 2
}
})
elif field_config["type"] == "date":
constraints.append({
**base_constraint,
"type": "DateConstraint",
"possible_constraints": {
"min_date": "datetime",
"max_date": "datetime",
"include_min": "boolean",
"include_max": "boolean",
"allowed_formats": "list[str]"
},
"example": {
"min_date": "2024-01-01T00:00:00",
"max_date": "2024-12-31T23:59:59",
"include_min": True,
"include_max": True,
"allowed_formats": ["%Y-%m-%d", "%Y/%m/%d"]
}
})
return constraints
def generate_field_documentation(
tagging_fields: Dict[str, Any],
format: str = "markdown",
version: str = "basic"
) -> str:
"""
Generate documentation for tagging fields configuration.
Args:
tagging_fields: Dictionary containing tagging fields configuration
format: Output format ("markdown", "json", or "yaml")
version: Documentation version ("basic" or "extended")
Returns:
str: Formatted documentation
"""
if version not in [DocumentationVersion.BASIC, DocumentationVersion.EXTENDED]:
raise ValueError(f"Unsupported documentation version: {version}")
# Normalize fields configuration
normalized_fields = {}
for field_name, field_config in tagging_fields.items():
field_doc = {
"name": field_name,
"type": field_config["type"],
"required": field_config.get("required", False),
"description": field_config.get("description", "No description provided"),
"constraints": []
}
# Only include possible arguments in extended version
if version == DocumentationVersion.EXTENDED:
field_doc["possible_arguments"] = _generate_argument_constraints(field_config)
# Add type-specific constraints
if field_config["type"] == "integer" or field_config["type"] == "float":
if "min_value" in field_config:
field_doc["constraints"].append(
f"Minimum value: {field_config['min_value']}")
if "max_value" in field_config:
field_doc["constraints"].append(
f"Maximum value: {field_config['max_value']}")
elif field_config["type"] == "string":
if "min_length" in field_config:
field_doc["constraints"].append(
f"Minimum length: {field_config['min_length']}")
if "max_length" in field_config:
field_doc["constraints"].append(
f"Maximum length: {field_config['max_length']}")
if "patterns" in field_config:
field_doc["constraints"].append(
f"Must match patterns: {', '.join(field_config['patterns'])}")
elif field_config["type"] == "enum":
if "allowed_values" in field_config:
field_doc["constraints"].append(
f"Allowed values: {', '.join(str(v) for v in field_config['allowed_values'])}")
elif field_config["type"] == "date":
if "min_date" in field_config:
field_doc["constraints"].append(
f"Minimum date: {field_config['min_date']}")
if "max_date" in field_config:
field_doc["constraints"].append(
f"Maximum date: {field_config['max_date']}")
if "allowed_formats" in field_config:
field_doc["constraints"].append(
f"Allowed formats: {', '.join(field_config['allowed_formats'])}")
normalized_fields[field_name] = field_doc
# Generate documentation in requested format
if format == DocumentationFormat.MARKDOWN:
return _generate_markdown_docs(normalized_fields, version)
elif format == DocumentationFormat.JSON:
return _generate_json_docs(normalized_fields, version)
elif format == DocumentationFormat.YAML:
return _generate_yaml_docs(normalized_fields, version)
else:
raise ValueError(f"Unsupported documentation format: {format}")
def _generate_markdown_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate markdown documentation"""
docs = ["# Tagging Fields Documentation\n"]
# Add overview table
docs.append("## Fields Overview\n")
docs.append("| Field Name | Type | Required | Description |")
docs.append("|------------|------|----------|-------------|")
for field_name, field in fields.items():
docs.append(
f"| {field_name} | {field['type']} | "
f"{'Yes' if field['required'] else 'No'} | {field['description']} |"
)
# Add detailed field specifications
docs.append("\n## Detailed Field Specifications\n")
for field_name, field in fields.items():
docs.append(f"### {field_name}\n")
docs.append(f"**Type:** {field['type']}")
docs.append(f"**Required:** {'Yes' if field['required'] else 'No'}")
docs.append(f"**Description:** {field['description']}\n")
if field["constraints"]:
docs.append("**Field Constraints:**")
for constraint in field["constraints"]:
docs.append(f"- {constraint}")
docs.append("")
# Add retriever argument documentation only in extended version
if version == DocumentationVersion.EXTENDED and "possible_arguments" in field:
docs.append("**Possible Retriever Arguments:**")
for arg_constraint in field["possible_arguments"]:
docs.append(f"\n*{arg_constraint['type']}*")
docs.append(f"Description: {arg_constraint['description']}")
docs.append("\nPossible constraints:")
for const_name, const_type in arg_constraint["possible_constraints"].items():
docs.append(f"- `{const_name}`: {const_type}")
docs.append("\nExample:")
docs.append("```python")
docs.append(json.dumps(arg_constraint["example"], indent=2))
docs.append("```\n")
# Add example retriever configuration only in extended version
if version == DocumentationVersion.EXTENDED:
docs.append("\n## Example Retriever Configuration\n")
docs.append("```python")
example_config = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
docs.append(json.dumps(example_config, indent=2))
docs.append("```")
return "\n".join(docs)
def _generate_json_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate JSON documentation"""
doc = {
"tagging_fields_documentation": {
"version": version,
"fields": fields
}
}
if version == DocumentationVersion.EXTENDED:
doc["tagging_fields_documentation"]["example_retriever_config"] = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
return json.dumps(doc, indent=2)
def _generate_yaml_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate YAML documentation"""
doc = {
"tagging_fields_documentation": {
"version": version,
"fields": fields
}
}
if version == DocumentationVersion.EXTENDED:
doc["tagging_fields_documentation"]["example_retriever_config"] = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
return yaml.dump(doc, sort_keys=False, default_flow_style=False)

View File

@@ -5,10 +5,8 @@ from common.models.user import Tenant, TenantDomain
def get_allowed_origins(tenant_id):
session_key = f"allowed_origins_{tenant_id}"
if session_key in session:
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
return session[session_key]
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
allowed_origins = [domain.domain for domain in tenant_domains]
@@ -18,14 +16,8 @@ def get_allowed_origins(tenant_id):
def cors_after_request(response, prefix):
current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
current_app.logger.debug(f'request.headers: {request.headers}')
current_app.logger.debug(f'request.args: {request.args}')
current_app.logger.debug(f'request is json?: {request.is_json}')
# Exclude health checks from checks
if request.path.startswith('/healthz') or request.path.startswith('/_healthz'):
current_app.logger.debug('Skipping CORS headers for health checks')
response.headers.add('Access-Control-Allow-Origin', '*')
response.headers.add('Access-Control-Allow-Headers', '*')
response.headers.add('Access-Control-Allow-Methods', '*')
@@ -36,7 +28,6 @@ def cors_after_request(response, prefix):
# Try to get tenant_id from JSON payload
json_data = request.get_json(silent=True)
current_app.logger.debug(f'request.get_json(silent=True): {json_data}')
if json_data and 'tenant_id' in json_data:
tenant_id = json_data['tenant_id']
@@ -44,23 +35,17 @@ def cors_after_request(response, prefix):
# Fallback to get tenant_id from query parameters or headers if JSON is not available
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
current_app.logger.debug(f'Identified tenant_id: {tenant_id}')
if tenant_id:
allowed_origins = get_allowed_origins(tenant_id)
current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
else:
current_app.logger.warning('tenant_id not found in request')
origin = request.headers.get('Origin')
current_app.logger.debug(f'Origin: {origin}')
if origin in allowed_origins:
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
current_app.logger.debug(f'CORS headers set for origin: {origin}')
else:
current_app.logger.warning(f'Origin {origin} not allowed')

View File

@@ -36,7 +36,7 @@ def log_request_middleware(app):
@app.before_request
def log_session_state_before():
app.logger.debug(f'Session state before request: {session.items()}')
pass
# @app.after_request
# def log_response_info(response):
@@ -58,5 +58,4 @@ def log_request_middleware(app):
@app.after_request
def log_session_state_after(response):
app.logger.debug(f'Session state after request: {session.items()}')
return response

View File

@@ -24,6 +24,7 @@ def create_document_stack(api_input, file, filename, extension, tenant_id):
# Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc, tenant_id,
api_input.get('url', ''),
api_input.get('sub_file_type', ''),
api_input.get('language', 'en'),
api_input.get('user_context', ''),
api_input.get('user_metadata'),
@@ -64,7 +65,7 @@ def create_document(form, filename, catalog_id):
return new_doc
def create_version_for_document(document, tenant_id, url, language, user_context, user_metadata, catalog_properties):
def create_version_for_document(document, tenant_id, url, sub_file_type, language, user_context, user_metadata, catalog_properties):
new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
@@ -83,6 +84,9 @@ def create_version_for_document(document, tenant_id, url, language, user_context
if catalog_properties != '' and catalog_properties is not None:
new_doc_vers.catalog_properties = catalog_properties
if sub_file_type != '':
new_doc_vers.sub_file_type = sub_file_type
new_doc_vers.document = document
set_logging_information(new_doc_vers, dt.now(tz.utc))
@@ -237,8 +241,6 @@ def start_embedding_task(tenant_id, doc_vers_id):
def validate_file_type(extension):
current_app.logger.debug(f'Validating file type {extension}')
current_app.logger.debug(f'Supported file types: {current_app.config["SUPPORTED_FILE_TYPES"]}')
if extension not in current_app.config['SUPPORTED_FILE_TYPES']:
raise EveAIUnsupportedFileType(f"Filetype {extension} is currently not supported. "
f"Supported filetypes: {', '.join(current_app.config['SUPPORTED_FILE_TYPES'])}")

View File

@@ -10,6 +10,7 @@ class EveAIException(Exception):
def to_dict(self):
rv = dict(self.payload or ())
rv['message'] = self.message
rv['error'] = self.__class__.__name__
return rv
@@ -41,3 +42,9 @@ class EveAINoLicenseForTenant(EveAIException):
super().__init__(message, status_code, payload)
class EveAITenantNotFound(EveAIException):
"""Raised when a tenant is not found"""
def __init__(self, message="Tenant not found", status_code=400, payload=None):
super().__init__(message, status_code, payload)

View File

@@ -24,9 +24,6 @@ def mw_before_request():
if not tenant_id:
raise Exception('Cannot switch schema for tenant: no tenant defined in session')
for role in current_user.roles:
current_app.logger.debug(f'In middleware: User {current_user.email} has role {role.name}')
# user = User.query.get(current_user.id)
if current_user.has_role('Super User') or current_user.tenant_id == tenant_id:
Database(tenant_id).switch_schema()

View File

@@ -1,249 +1,36 @@
import os
from typing import Dict, Any, Optional
import langcodes
from flask import current_app
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Any, Iterator
from collections.abc import MutableMapping
from openai import OpenAI
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler
from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.templates.template_manager import TemplateManager
from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
from langchain_anthropic import ChatAnthropic
from flask import current_app
from datetime import datetime as dt, timezone as tz
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.langchain.tracked_transcribe import tracked_transcribe
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant
from common.utils.cache.base import CacheHandler
from config.model_config import MODEL_CONFIG
from common.utils.business_event_context import current_event
from common.extensions import template_manager, cache_manager
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
from common.utils.eveai_exceptions import EveAITenantNotFound
class CitedAnswer(BaseModel):
"""Default docstring - to be replaced with actual prompt"""
def create_language_template(template: str, language: str) -> str:
"""
Replace language placeholder in template with specified language
answer: str = Field(
...,
description="The answer to the user question, based on the given sources",
)
citations: List[int] = Field(
...,
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
)
insufficient_info: bool = Field(
False, # Default value is set to False
description="A boolean indicating wether given sources were sufficient or not to generate the answer"
)
Args:
template: Template string with {language} placeholder
language: Language code to insert
def set_language_prompt_template(cls, language_prompt):
cls.__doc__ = language_prompt
class ModelVariables(MutableMapping):
def __init__(self, tenant: Tenant, catalog_id=None):
self.tenant = tenant
self.catalog_id = catalog_id
self._variables = self._initialize_variables()
self._embedding_model = None
self._llm = None
self._llm_no_rag = None
self._transcription_client = None
self._prompt_templates = {}
self._embedding_db_model = None
self.llm_metrics_handler = LLMMetricsHandler()
self._transcription_client = None
def _initialize_variables(self):
variables = {}
# Get the Catalog if catalog_id is passed
if self.catalog_id:
catalog = Catalog.query.get_or_404(self.catalog_id)
# We initialize the variables that are available knowing the tenant.
variables['embed_tuning'] = catalog.embed_tuning or False
# Set HTML Chunking Variables
variables['html_tags'] = catalog.html_tags
variables['html_end_tags'] = catalog.html_end_tags
variables['html_included_elements'] = catalog.html_included_elements
variables['html_excluded_elements'] = catalog.html_excluded_elements
variables['html_excluded_classes'] = catalog.html_excluded_classes
# Set Chunk Size variables
variables['min_chunk_size'] = catalog.min_chunk_size
variables['max_chunk_size'] = catalog.max_chunk_size
# Set the RAG Context (will have to change once specialists are defined
variables['rag_context'] = self.tenant.rag_context or " "
# Temporary setting until we have Specialists
variables['rag_tuning'] = False
variables['RAG_temperature'] = 0.3
variables['no_RAG_temperature'] = 0.5
variables['k'] = 8
variables['similarity_threshold'] = 0.4
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
f"{variables['llm_model']}")]
current_app.logger.info(f"Loaded prompt templates: \n")
current_app.logger.info(f"{variables['templates']}")
# Set model-specific configurations
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
variables.update(model_config)
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]
if variables['tool_calling_supported']:
variables['cited_answer_cls'] = CitedAnswer
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
return variables
@property
def embedding_model(self):
api_key = os.getenv('OPENAI_API_KEY')
model = self._variables['embedding_model']
self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
model=model,
)
self._embedding_db_model = EmbeddingSmallOpenAI \
if model == 'text-embedding-3-small' \
else EmbeddingLargeOpenAI
return self._embedding_model
@property
def llm(self):
api_key = self.get_api_key_for_llm()
self._llm = ChatOpenAI(api_key=api_key,
model=self._variables['llm_model'],
temperature=self._variables['RAG_temperature'],
callbacks=[self.llm_metrics_handler])
return self._llm
@property
def llm_no_rag(self):
api_key = self.get_api_key_for_llm()
self._llm_no_rag = ChatOpenAI(api_key=api_key,
model=self._variables['llm_model'],
temperature=self._variables['RAG_temperature'],
callbacks=[self.llm_metrics_handler])
return self._llm_no_rag
def get_api_key_for_llm(self):
if self._variables['llm_provider'] == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
else: # self._variables['llm_provider'] == 'anthropic'
api_key = os.getenv('ANTHROPIC_API_KEY')
return api_key
@property
def transcription_client(self):
api_key = os.getenv('OPENAI_API_KEY')
self._transcription_client = OpenAI(api_key=api_key, )
self._variables['transcription_model'] = 'whisper-1'
return self._transcription_client
def transcribe(self, *args, **kwargs):
return tracked_transcribe(self._transcription_client, *args, **kwargs)
@property
def embedding_db_model(self):
if self._embedding_db_model is None:
self._embedding_db_model = self.get_embedding_db_model()
return self._embedding_db_model
def get_embedding_db_model(self):
current_app.logger.debug("In get_embedding_db_model")
if self._embedding_db_model is None:
self._embedding_db_model = EmbeddingSmallOpenAI \
if self._variables['embedding_model'] == 'text-embedding-3-small' \
else EmbeddingLargeOpenAI
current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
return self._embedding_db_model
def get_prompt_template(self, template_name: str) -> str:
current_app.logger.info(f"Getting prompt template for {template_name}")
if template_name not in self._prompt_templates:
self._prompt_templates[template_name] = self._load_prompt_template(template_name)
return self._prompt_templates[template_name]
def _load_prompt_template(self, template_name: str) -> str:
# In the future, this method will make an API call to Portkey
# For now, we'll simulate it with a placeholder implementation
# You can replace this with your current prompt loading logic
return self._variables['templates'][template_name]
def __getitem__(self, key: str) -> Any:
current_app.logger.debug(f"ModelVariables: Getting {key}")
# Support older template names (suffix = _template)
if key.endswith('_template'):
key = key[:-len('_template')]
current_app.logger.debug(f"ModelVariables: Getting modified {key}")
if key == 'embedding_model':
return self.embedding_model
elif key == 'embedding_db_model':
return self.embedding_db_model
elif key == 'llm':
return self.llm
elif key == 'llm_no_rag':
return self.llm_no_rag
elif key == 'transcription_client':
return self.transcription_client
elif key in self._variables.get('prompt_templates', []):
return self.get_prompt_template(key)
else:
value = self._variables.get(key)
if value is not None:
return value
else:
raise KeyError(f'Variable {key} does not exist in ModelVariables')
def __setitem__(self, key: str, value: Any) -> None:
self._variables[key] = value
def __delitem__(self, key: str) -> None:
del self._variables[key]
def __iter__(self) -> Iterator[str]:
return iter(self._variables)
def __len__(self):
return len(self._variables)
def get(self, key: str, default: Any = None) -> Any:
return self.__getitem__(key) or default
def update(self, **kwargs) -> None:
self._variables.update(kwargs)
def items(self):
return self._variables.items()
def keys(self):
return self._variables.keys()
def values(self):
return self._variables.values()
def select_model_variables(tenant, catalog_id=None):
model_variables = ModelVariables(tenant=tenant, catalog_id=catalog_id)
return model_variables
def create_language_template(template, language):
Returns:
str: Template with language placeholder replaced
"""
try:
full_language = langcodes.Language.make(language=language)
language_template = template.replace('{language}', full_language.display_name())
@@ -253,5 +40,249 @@ def create_language_template(template, language):
return language_template
def replace_variable_in_template(template, variable, value):
return template.replace(variable, value)
def replace_variable_in_template(template: str, variable: str, value: str) -> str:
"""
Replace a variable placeholder in template with specified value
Args:
template: Template string with variable placeholder
variable: Variable placeholder to replace (e.g. "{tenant_context}")
value: Value to insert
Returns:
str: Template with variable placeholder replaced
"""
return template.replace(variable, value or "")
class ModelVariables:
"""Manages model-related variables and configurations"""
def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
"""
Initialize ModelVariables with tenant and optional template manager
Args:
tenant: Tenant instance
template_manager: Optional TemplateManager instance
"""
current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
self.tenant_id = tenant_id
self._variables = variables if variables is not None else self._initialize_variables()
current_app.logger.info(f'Model _variables initialized to {self._variables}')
self._embedding_model = None
self._embedding_model_class = None
self._llm_instances = {}
self.llm_metrics_handler = LLMMetricsHandler()
self._transcription_model = None
def _initialize_variables(self) -> Dict[str, Any]:
"""Initialize the variables dictionary"""
variables = {}
tenant = Tenant.query.get(self.tenant_id)
if not tenant:
raise EveAITenantNotFound(f"Tenant {self.tenant_id} not found")
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
variables['llm_full_model'] = tenant.llm_model
# Set model-specific configurations
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
variables.update(model_config)
# Additional configurations
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
return variables
@property
def embedding_model(self):
"""Get the embedding model instance"""
if self._embedding_model is None:
api_key = os.getenv('OPENAI_API_KEY')
self._embedding_model = TrackedOpenAIEmbeddings(
api_key=api_key,
model=self._variables['embedding_model']
)
return self._embedding_model
@property
def embedding_model_class(self):
"""Get the embedding model class"""
if self._embedding_model_class is None:
if self._variables['embedding_model'] == 'text-embedding-3-large':
self._embedding_model_class = EmbeddingLargeOpenAI
else: # text-embedding-3-small
self._embedding_model_class = EmbeddingSmallOpenAI
return self._embedding_model_class
@property
def annotation_chunk_length(self):
return self._variables['annotation_chunk_length']
@property
def max_compression_duration(self):
return self._variables['max_compression_duration']
@property
def max_transcription_duration(self):
return self._variables['max_transcription_duration']
@property
def compression_cpu_limit(self):
return self._variables['compression_cpu_limit']
@property
def compression_process_delay(self):
return self._variables['compression_process_delay']
def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
"""
Get an LLM instance with specific configuration
Args:
temperature: The temperature for the LLM
**kwargs: Additional configuration parameters
Returns:
An instance of the configured LLM
"""
cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
if cache_key not in self._llm_instances:
provider = self._variables['llm_provider']
model = self._variables['llm_model']
if provider == 'openai':
self._llm_instances[cache_key] = ChatOpenAI(
api_key=os.getenv('OPENAI_API_KEY'),
model=model,
temperature=temperature,
callbacks=[self.llm_metrics_handler],
**kwargs
)
elif provider == 'anthropic':
self._llm_instances[cache_key] = ChatAnthropic(
api_key=os.getenv('ANTHROPIC_API_KEY'),
model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
temperature=temperature,
callbacks=[self.llm_metrics_handler],
**kwargs
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
return self._llm_instances[cache_key]
@property
def transcription_model(self) -> TrackedOpenAITranscription:
"""Get the transcription model instance"""
if self._transcription_model is None:
api_key = os.getenv('OPENAI_API_KEY')
self._transcription_model = TrackedOpenAITranscription(
api_key=api_key,
model='whisper-1'
)
return self._transcription_model
# Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
@property
def transcription_client(self):
raise DeprecationWarning("Use transcription_model instead")
def transcribe(self, *args, **kwargs):
raise DeprecationWarning("Use transcription_model.transcribe() instead")
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
"""
Get a template for the tenant's configured LLM
Args:
template_name: Name of the template to retrieve
version: Optional specific version to retrieve
Returns:
The template content
"""
try:
template = template_manager.get_template(
self._variables['llm_full_model'],
template_name,
version
)
return template.content
except Exception as e:
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
# Fall back to old template loading if template_manager fails
if template_name in self._variables.get('templates', {}):
return self._variables['templates'][template_name]
raise
class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
handler_name = 'model_vars_cache' # Used to access handler instance from cache_manager
def __init__(self, region):
super().__init__(region, 'model_variables')
self.configure_keys('tenant_id')
self.subscribe_to_model('Tenant', ['tenant_id'])
def to_cache_data(self, instance: ModelVariables) -> Dict[str, Any]:
return {
'tenant_id': instance.tenant_id,
'variables': instance._variables,
'last_updated': dt.now(tz=tz.utc).isoformat()
}
def from_cache_data(self, data: Dict[str, Any], tenant_id: int, **kwargs) -> ModelVariables:
instance = ModelVariables(tenant_id, data.get('variables'))
return instance
def should_cache(self, value: Dict[str, Any]) -> bool:
required_fields = {'tenant_id', 'variables'}
return all(field in value for field in required_fields)
# Register the handler with the cache manager
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
# Helper function to get cached model variables
def get_model_variables(tenant_id: int) -> ModelVariables:
return cache_manager.model_vars_cache.get(
lambda tenant_id: ModelVariables(tenant_id), # function to create ModelVariables if required
tenant_id=tenant_id
)
# Written in a long format, without lambda
# def get_model_variables(tenant_id: int) -> ModelVariables:
# """
# Get ModelVariables instance, either from cache or newly created
#
# Args:
# tenant_id: The tenant's ID
#
# Returns:
# ModelVariables: Instance with either cached or fresh data
#
# Raises:
# TenantNotFoundError: If tenant doesn't exist
# CacheStateError: If cached data is invalid
# """
#
# def create_new_instance(tenant_id: int) -> ModelVariables:
# """Creator function that's called when cache miss occurs"""
# return ModelVariables(tenant_id) # This will initialize fresh variables
#
# return cache_manager.model_vars_cache.get(
# create_new_instance, # Function to create new instance if needed
# tenant_id=tenant_id # Parameters passed to both get() and create_new_instance
# )

View File

@@ -1,4 +1,6 @@
import os
import sys
import gevent
import time
from flask import current_app
@@ -28,3 +30,17 @@ def sync_folder(file_path):
dir_fd = os.open(file_path, os.O_RDONLY)
os.fsync(dir_fd)
os.close(dir_fd)
def get_project_root():
"""Get the root directory of the project."""
# Use the module that's actually running (not this file)
module = sys.modules['__main__']
if hasattr(module, '__file__'):
# Get the path to the main module
main_path = os.path.abspath(module.__file__)
# Get the root directory (where the main module is located)
return os.path.dirname(main_path)
else:
# Fallback: use current working directory
return os.getcwd()

View File

@@ -4,7 +4,6 @@ from common.models.user import Tenant
# Definition of Trigger Handlers
def set_tenant_session_data(sender, user, **kwargs):
current_app.logger.debug(f"Setting tenant session data for user {user.id}")
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
session['tenant'] = tenant.to_dict()
session['default_language'] = tenant.default_language

View File

@@ -11,7 +11,7 @@ def confirm_token(token, expiration=3600):
try:
email = serializer.loads(token, salt=current_app.config['SECURITY_PASSWORD_SALT'], max_age=expiration)
except Exception as e:
current_app.logger.debug(f'Error confirming token: {e}')
current_app.logger.error(f'Error confirming token: {e}')
raise
return email
@@ -35,14 +35,11 @@ def generate_confirmation_token(email):
def send_confirmation_email(user):
current_app.logger.debug(f'Sending confirmation email to {user.email}')
if not test_smtp_connection():
raise Exception("Failed to connect to SMTP server")
token = generate_confirmation_token(user.email)
confirm_url = prefixed_url_for('security_bp.confirm_email', token=token, _external=True)
current_app.logger.debug(f'Confirmation URL: {confirm_url}')
html = render_template('email/activate.html', confirm_url=confirm_url)
subject = "Please confirm your email"
@@ -56,10 +53,8 @@ def send_confirmation_email(user):
def send_reset_email(user):
current_app.logger.debug(f'Sending reset email to {user.email}')
token = generate_reset_token(user.email)
reset_url = prefixed_url_for('security_bp.reset_password', token=token, _external=True)
current_app.logger.debug(f'Reset URL: {reset_url}')
html = render_template('email/reset_password.html', reset_url=reset_url)
subject = "Reset Your Password"

View File

@@ -0,0 +1,112 @@
from typing import List, Union
import re
class StringListConverter:
"""Utility class for converting between comma-separated strings and lists"""
@staticmethod
def string_to_list(input_string: Union[str, None], allow_empty: bool = True) -> List[str]:
"""
Convert a comma-separated string to a list of strings.
Args:
input_string: Comma-separated string to convert
allow_empty: If True, returns empty list for None/empty input
If False, raises ValueError for None/empty input
Returns:
List of stripped strings
Raises:
ValueError: If input is None/empty and allow_empty is False
"""
if not input_string:
if allow_empty:
return []
raise ValueError("Input string cannot be None or empty")
return [item.strip() for item in input_string.split(',') if item.strip()]
@staticmethod
def list_to_string(input_list: Union[List[str], None], allow_empty: bool = True) -> str:
"""
Convert a list of strings to a comma-separated string.
Args:
input_list: List of strings to convert
allow_empty: If True, returns empty string for None/empty input
If False, raises ValueError for None/empty input
Returns:
Comma-separated string
Raises:
ValueError: If input is None/empty and allow_empty is False
"""
if not input_list:
if allow_empty:
return ''
raise ValueError("Input list cannot be None or empty")
return ', '.join(str(item).strip() for item in input_list)
@staticmethod
def validate_format(input_string: str,
allowed_chars: str = r'a-zA-Z0-9_\-',
min_length: int = 1,
max_length: int = 50) -> bool:
"""
Validate the format of items in a comma-separated string.
Args:
input_string: String to validate
allowed_chars: String of allowed characters (for regex pattern)
min_length: Minimum length for each item
max_length: Maximum length for each item
Returns:
bool: True if format is valid, False otherwise
"""
if not input_string:
return False
# Create regex pattern for individual items
pattern = f'^[{allowed_chars}]{{{min_length},{max_length}}}$'
try:
# Convert to list and check each item
items = StringListConverter.string_to_list(input_string)
return all(bool(re.match(pattern, item)) for item in items)
except Exception:
return False
@staticmethod
def validate_and_convert(input_string: str,
allowed_chars: str = r'a-zA-Z0-9_\-',
min_length: int = 1,
max_length: int = 50) -> List[str]:
"""
Validate and convert a comma-separated string to a list.
Args:
input_string: String to validate and convert
allowed_chars: String of allowed characters (for regex pattern)
min_length: Minimum length for each item
max_length: Maximum length for each item
Returns:
List of validated and converted strings
Raises:
ValueError: If input string format is invalid
"""
if not StringListConverter.validate_format(
input_string, allowed_chars, min_length, max_length
):
raise ValueError(
f"Invalid format. Items must be {min_length}-{max_length} characters "
f"long and contain only these characters: {allowed_chars}"
)
return StringListConverter.string_to_list(input_string)

View File

@@ -44,7 +44,7 @@ def form_validation_failed(request, form):
for fieldName, errorMessages in form.errors.items():
for err in errorMessages:
flash(f"Error in {fieldName}: {err}", 'danger')
current_app.logger.debug(f"Error in {fieldName}: {err}")
current_app.logger.error(f"Error in {fieldName}: {err}")
def form_to_dict(form):