- Introduction of dynamic Retrievers & Specialists

- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
This commit is contained in:
Josako
2024-11-15 10:00:53 +01:00
parent 55a8a95f79
commit 1807435339
101 changed files with 4181 additions and 1764 deletions

View File

@@ -4,7 +4,6 @@ from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Any, Optional
from datetime import datetime as dt, timezone as tz
from portkey_ai import Portkey, Config
import logging
from .business_event_context import BusinessEventContext

0
common/utils/cache/__init__old.py vendored Normal file
View File

89
common/utils/cache/base.py vendored Normal file
View File

@@ -0,0 +1,89 @@
# common/utils/cache/base.py
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type
from dataclasses import dataclass
from flask import Flask
from dogpile.cache import CacheRegion
T = TypeVar('T')
@dataclass
class CacheKey:
"""Represents a cache key with multiple components"""
components: Dict[str, Any]
def __str__(self) -> str:
return ":".join(f"{k}={v}" for k, v in sorted(self.components.items()))
class CacheInvalidationManager:
"""Manages cache invalidation subscriptions"""
def __init__(self):
self._subscribers = {}
def subscribe(self, model: str, handler: 'CacheHandler', key_fields: List[str]):
if model not in self._subscribers:
self._subscribers[model] = []
self._subscribers[model].append((handler, key_fields))
def notify_change(self, model: str, **identifiers):
if model in self._subscribers:
for handler, key_fields in self._subscribers[model]:
if all(field in identifiers for field in key_fields):
handler.invalidate_by_model(model, **identifiers)
class CacheHandler(Generic[T]):
"""Base cache handler implementation"""
def __init__(self, region: CacheRegion, prefix: str):
self.region = region
self.prefix = prefix
self._key_components = []
def configure_keys(self, *components: str):
self._key_components = components
return self
def subscribe_to_model(self, model: str, key_fields: List[str]):
invalidation_manager.subscribe(model, self, key_fields)
return self
def generate_key(self, **identifiers) -> str:
missing = set(self._key_components) - set(identifiers.keys())
if missing:
raise ValueError(f"Missing key components: {missing}")
key = CacheKey({k: identifiers[k] for k in self._key_components})
return f"{self.prefix}:{str(key)}"
def get(self, creator_func, **identifiers) -> T:
cache_key = self.generate_key(**identifiers)
def creator():
instance = creator_func(**identifiers)
return self.to_cache_data(instance)
cached_data = self.region.get_or_create(
cache_key,
creator,
should_cache_fn=self.should_cache
)
return self.from_cache_data(cached_data, **identifiers)
def invalidate(self, **identifiers):
cache_key = self.generate_key(**identifiers)
self.region.delete(cache_key)
def invalidate_by_model(self, model: str, **identifiers):
try:
self.invalidate(**identifiers)
except ValueError:
pass
# Create global invalidation manager
invalidation_manager = CacheInvalidationManager()

View File

@@ -0,0 +1,32 @@
from typing import Type
from flask import Flask
from common.utils.cache.base import CacheHandler
class EveAICacheManager:
"""Cache manager with registration capabilities"""
def __init__(self):
self.model_region = None
self.eveai_chat_workers_region = None
self.eveai_workers_region = None
self._handlers = {}
def init_app(self, app: Flask):
"""Initialize cache regions"""
from common.utils.cache.regions import create_cache_regions
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
# Initialize all registered handlers with their regions
for handler_class, region_name in self._handlers.items():
region = getattr(self, f"{region_name}_region")
handler_instance = handler_class(region)
setattr(self, handler_class.handler_name, handler_instance)
def register_handler(self, handler_class: Type[CacheHandler], region: str):
"""Register a cache handler class with its region"""
if not hasattr(handler_class, 'handler_name'):
raise ValueError("Cache handler must define handler_name class attribute")
self._handlers[handler_class] = region

61
common/utils/cache/regions.py vendored Normal file
View File

@@ -0,0 +1,61 @@
# common/utils/cache/regions.py
from dogpile.cache import make_region
from flask import current_app
from urllib.parse import urlparse
import os
def get_redis_config(app):
"""
Create Redis configuration dict based on app config
Handles both authenticated and non-authenticated setups
"""
# Parse the REDIS_BASE_URI to get all components
redis_uri = urlparse(app.config['REDIS_BASE_URI'])
config = {
'host': redis_uri.hostname,
'port': int(redis_uri.port or 6379),
'db': 4, # Keep this for later use
'redis_expiration_time': 3600,
'distributed_lock': True
}
# Add authentication if provided
if redis_uri.username and redis_uri.password:
config.update({
'username': redis_uri.username,
'password': redis_uri.password
})
return config
def create_cache_regions(app):
"""Initialize all cache regions with app config"""
redis_config = get_redis_config(app)
# Region for model-related caching (ModelVariables etc)
model_region = make_region(name='model').configure(
'dogpile.cache.redis',
arguments=redis_config,
replace_existing_backend=True
)
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
eveai_chat_workers_region = make_region(name='chat_workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
replace_existing_backend=True
)
# Region for eveai_workers components (Processors, ...)
eveai_workers_region = make_region(name='workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # Same config for now
replace_existing_backend=True
)
return model_region, eveai_chat_workers_region, eveai_workers_region

View File

@@ -8,8 +8,6 @@ celery_app = Celery()
def init_celery(celery, app, is_beat=False):
celery_app.main = app.name
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
celery_config = {
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),

View File

@@ -0,0 +1,613 @@
from typing import Optional, List, Union, Dict, Any, Pattern
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Annotated
import re
from datetime import datetime
import json
from textwrap import dedent
import yaml
from dataclasses import dataclass
class TaggingField(BaseModel):
"""Represents a single tagging field configuration"""
type: str
required: bool = False
description: Optional[str] = None
allowed_values: Optional[List[Any]] = None # for enum type
min_value: Optional[Union[int, float]] = None # for numeric types
max_value: Optional[Union[int, float]] = None # for numeric types
@field_validator('type', mode='before')
@classmethod
def validate_type(cls, v: str) -> str:
valid_types = ['string', 'integer', 'float', 'date', 'enum']
if v not in valid_types:
raise ValueError(f'type must be one of {valid_types}')
return v
@model_validator(mode='after')
def validate_field_constraints(self) -> 'TaggingField':
# Validate enum constraints
if self.type == 'enum':
if not self.allowed_values:
raise ValueError('allowed_values must be provided for enum type')
elif self.allowed_values is not None:
raise ValueError('allowed_values only valid for enum type')
# Validate numeric constraints
if self.type not in ('integer', 'float'):
if self.min_value is not None or self.max_value is not None:
raise ValueError('min_value/max_value only valid for numeric types')
else:
if self.min_value is not None and self.max_value is not None and self.min_value >= self.max_value:
raise ValueError('min_value must be less than max_value')
return self
class TaggingFields(BaseModel):
"""Represents a collection of tagging fields, mapped by their names"""
fields: Dict[str, TaggingField]
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'TaggingFields':
return cls(fields={
field_name: TaggingField(**field_config)
for field_name, field_config in data.items()
})
def to_dict(self) -> Dict[str, Dict[str, Any]]:
return {
field_name: field.model_dump(exclude_none=True)
for field_name, field in self.fields.items()
}
class ArgumentConstraint(BaseModel):
"""Base class for all argument constraints"""
description: Optional[str] = None
error_message: Optional[str] = None
class NumericConstraint(ArgumentConstraint):
"""Constraints for numeric values (int/float)"""
min_value: Optional[float] = None
max_value: Optional[float] = None
include_min: bool = True # True for >= min_value, False for > min_value
include_max: bool = True # True for <= max_value, False for < max_value
@model_validator(mode='after')
def validate_ranges(self) -> 'NumericConstraint':
if self.min_value is not None and self.max_value is not None:
if self.min_value > self.max_value:
raise ValueError("min_value must be less than or equal to max_value")
return self
def validate(self, value: Union[int, float]) -> bool:
if self.min_value is not None:
if self.include_min and value < self.min_value:
return False
if not self.include_min and value <= self.min_value:
return False
if self.max_value is not None:
if self.include_max and value > self.max_value:
return False
if not self.include_max and value >= self.max_value:
return False
return True
class StringConstraint(ArgumentConstraint):
"""Constraints for string values"""
min_length: Optional[int] = None
max_length: Optional[int] = None
patterns: Optional[List[str]] = None # List of regex patterns to match
pattern_match_all: bool = False # If True, string must match all patterns
forbidden_patterns: Optional[List[str]] = None # List of regex patterns that must not match
allow_empty: bool = False
@field_validator('patterns', 'forbidden_patterns')
@classmethod
def validate_patterns(cls, v: Optional[List[str]]) -> Optional[List[str]]:
if v is not None:
# Validate each pattern compiles
for pattern in v:
try:
re.compile(pattern)
except re.error as e:
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
return v
def validate(self, value: str) -> bool:
if not self.allow_empty and not value:
return False
if self.min_length is not None and len(value) < self.min_length:
return False
if self.max_length is not None and len(value) > self.max_length:
return False
if self.patterns:
matches = [bool(re.search(pattern, value)) for pattern in self.patterns]
if self.pattern_match_all and not all(matches):
return False
if not self.pattern_match_all and not any(matches):
return False
if self.forbidden_patterns:
for pattern in self.forbidden_patterns:
if re.search(pattern, value):
return False
return True
class DateConstraint(ArgumentConstraint):
"""Constraints for date values"""
min_date: Optional[datetime] = None
max_date: Optional[datetime] = None
include_min: bool = True
include_max: bool = True
allowed_formats: Optional[List[str]] = None # List of allowed date formats
@model_validator(mode='after')
def validate_ranges(self) -> 'DateConstraint':
if self.min_date and self.max_date and self.min_date > self.max_date:
raise ValueError("min_date must be less than or equal to max_date")
return self
def validate(self, value: datetime) -> bool:
if self.min_date is not None:
if self.include_min and value < self.min_date:
return False
if not self.include_min and value <= self.min_date:
return False
if self.max_date is not None:
if self.include_max and value > self.max_date:
return False
if not self.include_max and value >= self.max_date:
return False
return True
class EnumConstraint(ArgumentConstraint):
"""Constraints for enum values"""
allowed_values: List[Any]
case_sensitive: bool = True # For string enums
allow_multiple: bool = False # If True, value can be a list of allowed values
min_selections: Optional[int] = None # When allow_multiple is True
max_selections: Optional[int] = None # When allow_multiple is True
@model_validator(mode='after')
def validate_selections(self) -> 'EnumConstraint':
if self.allow_multiple:
if self.min_selections is not None and self.max_selections is not None:
if self.min_selections > self.max_selections:
raise ValueError("min_selections must be less than or equal to max_selections")
if self.max_selections > len(self.allowed_values):
raise ValueError("max_selections cannot be greater than number of allowed values")
return self
def validate(self, value: Union[Any, List[Any]]) -> bool:
if self.allow_multiple:
if not isinstance(value, list):
return False
if self.min_selections is not None and len(value) < self.min_selections:
return False
if self.max_selections is not None and len(value) > self.max_selections:
return False
for v in value:
if not self._validate_single_value(v):
return False
else:
return self._validate_single_value(value)
return True
def _validate_single_value(self, value: Any) -> bool:
if isinstance(value, str) and not self.case_sensitive:
return any(str(value).lower() == str(v).lower() for v in self.allowed_values)
return value in self.allowed_values
class ArgumentDefinition(BaseModel):
"""Defines an argument with its type and constraints"""
name: str
type: str
description: Optional[str] = None
required: bool = False
default: Optional[Any] = None
constraints: Optional[Union[NumericConstraint, StringConstraint, DateConstraint, EnumConstraint]] = None
@field_validator('type')
@classmethod
def validate_type(cls, v: str) -> str:
valid_types = ['string', 'integer', 'float', 'date', 'enum']
if v not in valid_types:
raise ValueError(f'type must be one of {valid_types}')
return v
@model_validator(mode='after')
def validate_constraints(self) -> 'ArgumentDefinition':
if self.constraints:
expected_constraint_types = {
'string': StringConstraint,
'integer': NumericConstraint,
'float': NumericConstraint,
'date': DateConstraint,
'enum': EnumConstraint
}
expected_type = expected_constraint_types.get(self.type)
if not isinstance(self.constraints, expected_type):
raise ValueError(f'Constraints for type {self.type} must be of type {expected_type.__name__}')
if self.default is not None:
if not self.constraints.validate(self.default):
raise ValueError(f'Default value does not satisfy constraints for {self.name}')
return self
class ArgumentDefinitions(BaseModel):
"""Collection of argument definitions"""
arguments: Dict[str, ArgumentDefinition]
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'ArgumentDefinitions':
return cls(arguments={
arg_name: ArgumentDefinition(**arg_config)
for arg_name, arg_config in data.items()
})
def to_dict(self) -> Dict[str, Dict[str, Any]]:
return {
arg_name: arg.model_dump(exclude_none=True)
for arg_name, arg in self.arguments.items()
}
def validate_argument_values(self, values: Dict[str, Any]) -> Dict[str, str]:
"""
Validate a set of argument values against their definitions
Returns a dictionary of error messages for invalid arguments
"""
errors = {}
# Check for required arguments
for name, arg_def in self.arguments.items():
if arg_def.required and name not in values:
errors[name] = "Required argument missing"
continue
if name in values:
value = values[name]
# Validate type
try:
if arg_def.type == 'integer':
value = int(value)
elif arg_def.type == 'float':
value = float(value)
elif arg_def.type == 'date' and isinstance(value, str):
if arg_def.constraints and arg_def.constraints.allowed_formats:
for fmt in arg_def.constraints.allowed_formats:
try:
value = datetime.strptime(value, fmt)
break
except ValueError:
continue
else:
errors[
name] = f"Invalid date format. Allowed formats: {arg_def.constraints.allowed_formats}"
continue
except (ValueError, TypeError):
errors[name] = f"Invalid type. Expected {arg_def.type}"
continue
# Validate constraints
if arg_def.constraints and not arg_def.constraints.validate(value):
errors[name] = arg_def.constraints.error_message or "Value does not satisfy constraints"
return errors
@dataclass
class DocumentationFormat:
"""Constants for documentation formats"""
MARKDOWN = "markdown"
JSON = "json"
YAML = "yaml"
@dataclass
class DocumentationVersion:
"""Constants for documentation versions"""
BASIC = "basic" # Original documentation without retriever info
EXTENDED = "extended" # Including retriever documentation
def _generate_argument_constraints(field_config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Generate possible argument constraints based on field type"""
constraints = []
base_constraint = {
"description": f"Constraint for {field_config.get('description', 'field')}",
"error_message": "Optional custom error message"
}
if field_config["type"] == "integer" or field_config["type"] == "float":
constraints.append({
**base_constraint,
"type": "NumericConstraint",
"possible_constraints": {
"min_value": "number",
"max_value": "number",
"include_min": "boolean",
"include_max": "boolean"
},
"example": {
"min_value": field_config.get("min_value", 0),
"max_value": field_config.get("max_value", 100),
"include_min": True,
"include_max": True
}
})
elif field_config["type"] == "string":
constraints.append({
**base_constraint,
"type": "StringConstraint",
"possible_constraints": {
"min_length": "integer",
"max_length": "integer",
"patterns": "list[str]",
"pattern_match_all": "boolean",
"forbidden_patterns": "list[str]",
"allow_empty": "boolean"
},
"example": {
"min_length": 1,
"max_length": 100,
"patterns": ["^[A-Za-z0-9]+$"],
"pattern_match_all": False,
"forbidden_patterns": ["^test_", "_temp$"],
"allow_empty": False
}
})
elif field_config["type"] == "enum":
constraints.append({
**base_constraint,
"type": "EnumConstraint",
"possible_constraints": {
"allowed_values": f"list[{field_config.get('allowed_values', ['value1', 'value2'])}]",
"case_sensitive": "boolean",
"allow_multiple": "boolean",
"min_selections": "integer",
"max_selections": "integer"
},
"example": {
"allowed_values": field_config.get("allowed_values", ["value1", "value2"]),
"case_sensitive": True,
"allow_multiple": True,
"min_selections": 1,
"max_selections": 2
}
})
elif field_config["type"] == "date":
constraints.append({
**base_constraint,
"type": "DateConstraint",
"possible_constraints": {
"min_date": "datetime",
"max_date": "datetime",
"include_min": "boolean",
"include_max": "boolean",
"allowed_formats": "list[str]"
},
"example": {
"min_date": "2024-01-01T00:00:00",
"max_date": "2024-12-31T23:59:59",
"include_min": True,
"include_max": True,
"allowed_formats": ["%Y-%m-%d", "%Y/%m/%d"]
}
})
return constraints
def generate_field_documentation(
tagging_fields: Dict[str, Any],
format: str = "markdown",
version: str = "basic"
) -> str:
"""
Generate documentation for tagging fields configuration.
Args:
tagging_fields: Dictionary containing tagging fields configuration
format: Output format ("markdown", "json", or "yaml")
version: Documentation version ("basic" or "extended")
Returns:
str: Formatted documentation
"""
if version not in [DocumentationVersion.BASIC, DocumentationVersion.EXTENDED]:
raise ValueError(f"Unsupported documentation version: {version}")
# Normalize fields configuration
normalized_fields = {}
for field_name, field_config in tagging_fields.items():
field_doc = {
"name": field_name,
"type": field_config["type"],
"required": field_config.get("required", False),
"description": field_config.get("description", "No description provided"),
"constraints": []
}
# Only include possible arguments in extended version
if version == DocumentationVersion.EXTENDED:
field_doc["possible_arguments"] = _generate_argument_constraints(field_config)
# Add type-specific constraints
if field_config["type"] == "integer" or field_config["type"] == "float":
if "min_value" in field_config:
field_doc["constraints"].append(
f"Minimum value: {field_config['min_value']}")
if "max_value" in field_config:
field_doc["constraints"].append(
f"Maximum value: {field_config['max_value']}")
elif field_config["type"] == "string":
if "min_length" in field_config:
field_doc["constraints"].append(
f"Minimum length: {field_config['min_length']}")
if "max_length" in field_config:
field_doc["constraints"].append(
f"Maximum length: {field_config['max_length']}")
if "patterns" in field_config:
field_doc["constraints"].append(
f"Must match patterns: {', '.join(field_config['patterns'])}")
elif field_config["type"] == "enum":
if "allowed_values" in field_config:
field_doc["constraints"].append(
f"Allowed values: {', '.join(str(v) for v in field_config['allowed_values'])}")
elif field_config["type"] == "date":
if "min_date" in field_config:
field_doc["constraints"].append(
f"Minimum date: {field_config['min_date']}")
if "max_date" in field_config:
field_doc["constraints"].append(
f"Maximum date: {field_config['max_date']}")
if "allowed_formats" in field_config:
field_doc["constraints"].append(
f"Allowed formats: {', '.join(field_config['allowed_formats'])}")
normalized_fields[field_name] = field_doc
# Generate documentation in requested format
if format == DocumentationFormat.MARKDOWN:
return _generate_markdown_docs(normalized_fields, version)
elif format == DocumentationFormat.JSON:
return _generate_json_docs(normalized_fields, version)
elif format == DocumentationFormat.YAML:
return _generate_yaml_docs(normalized_fields, version)
else:
raise ValueError(f"Unsupported documentation format: {format}")
def _generate_markdown_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate markdown documentation"""
docs = ["# Tagging Fields Documentation\n"]
# Add overview table
docs.append("## Fields Overview\n")
docs.append("| Field Name | Type | Required | Description |")
docs.append("|------------|------|----------|-------------|")
for field_name, field in fields.items():
docs.append(
f"| {field_name} | {field['type']} | "
f"{'Yes' if field['required'] else 'No'} | {field['description']} |"
)
# Add detailed field specifications
docs.append("\n## Detailed Field Specifications\n")
for field_name, field in fields.items():
docs.append(f"### {field_name}\n")
docs.append(f"**Type:** {field['type']}")
docs.append(f"**Required:** {'Yes' if field['required'] else 'No'}")
docs.append(f"**Description:** {field['description']}\n")
if field["constraints"]:
docs.append("**Field Constraints:**")
for constraint in field["constraints"]:
docs.append(f"- {constraint}")
docs.append("")
# Add retriever argument documentation only in extended version
if version == DocumentationVersion.EXTENDED and "possible_arguments" in field:
docs.append("**Possible Retriever Arguments:**")
for arg_constraint in field["possible_arguments"]:
docs.append(f"\n*{arg_constraint['type']}*")
docs.append(f"Description: {arg_constraint['description']}")
docs.append("\nPossible constraints:")
for const_name, const_type in arg_constraint["possible_constraints"].items():
docs.append(f"- `{const_name}`: {const_type}")
docs.append("\nExample:")
docs.append("```python")
docs.append(json.dumps(arg_constraint["example"], indent=2))
docs.append("```\n")
# Add example retriever configuration only in extended version
if version == DocumentationVersion.EXTENDED:
docs.append("\n## Example Retriever Configuration\n")
docs.append("```python")
example_config = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
docs.append(json.dumps(example_config, indent=2))
docs.append("```")
return "\n".join(docs)
def _generate_json_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate JSON documentation"""
doc = {
"tagging_fields_documentation": {
"version": version,
"fields": fields
}
}
if version == DocumentationVersion.EXTENDED:
doc["tagging_fields_documentation"]["example_retriever_config"] = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
return json.dumps(doc, indent=2)
def _generate_yaml_docs(fields: Dict[str, Any], version: str) -> str:
"""Generate YAML documentation"""
doc = {
"tagging_fields_documentation": {
"version": version,
"fields": fields
}
}
if version == DocumentationVersion.EXTENDED:
doc["tagging_fields_documentation"]["example_retriever_config"] = {
"metadata_filters": {
field_name: field["possible_arguments"][0]["example"]
for field_name, field in fields.items()
if "possible_arguments" in field
}
}
return yaml.dump(doc, sort_keys=False, default_flow_style=False)

View File

@@ -5,10 +5,8 @@ from common.models.user import Tenant, TenantDomain
def get_allowed_origins(tenant_id):
session_key = f"allowed_origins_{tenant_id}"
if session_key in session:
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
return session[session_key]
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
allowed_origins = [domain.domain for domain in tenant_domains]
@@ -18,14 +16,8 @@ def get_allowed_origins(tenant_id):
def cors_after_request(response, prefix):
current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
current_app.logger.debug(f'request.headers: {request.headers}')
current_app.logger.debug(f'request.args: {request.args}')
current_app.logger.debug(f'request is json?: {request.is_json}')
# Exclude health checks from checks
if request.path.startswith('/healthz') or request.path.startswith('/_healthz'):
current_app.logger.debug('Skipping CORS headers for health checks')
response.headers.add('Access-Control-Allow-Origin', '*')
response.headers.add('Access-Control-Allow-Headers', '*')
response.headers.add('Access-Control-Allow-Methods', '*')
@@ -36,7 +28,6 @@ def cors_after_request(response, prefix):
# Try to get tenant_id from JSON payload
json_data = request.get_json(silent=True)
current_app.logger.debug(f'request.get_json(silent=True): {json_data}')
if json_data and 'tenant_id' in json_data:
tenant_id = json_data['tenant_id']
@@ -44,23 +35,17 @@ def cors_after_request(response, prefix):
# Fallback to get tenant_id from query parameters or headers if JSON is not available
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
current_app.logger.debug(f'Identified tenant_id: {tenant_id}')
if tenant_id:
allowed_origins = get_allowed_origins(tenant_id)
current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
else:
current_app.logger.warning('tenant_id not found in request')
origin = request.headers.get('Origin')
current_app.logger.debug(f'Origin: {origin}')
if origin in allowed_origins:
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
current_app.logger.debug(f'CORS headers set for origin: {origin}')
else:
current_app.logger.warning(f'Origin {origin} not allowed')

View File

@@ -36,7 +36,7 @@ def log_request_middleware(app):
@app.before_request
def log_session_state_before():
app.logger.debug(f'Session state before request: {session.items()}')
pass
# @app.after_request
# def log_response_info(response):
@@ -58,5 +58,4 @@ def log_request_middleware(app):
@app.after_request
def log_session_state_after(response):
app.logger.debug(f'Session state after request: {session.items()}')
return response

View File

@@ -24,6 +24,7 @@ def create_document_stack(api_input, file, filename, extension, tenant_id):
# Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc, tenant_id,
api_input.get('url', ''),
api_input.get('sub_file_type', ''),
api_input.get('language', 'en'),
api_input.get('user_context', ''),
api_input.get('user_metadata'),
@@ -64,7 +65,7 @@ def create_document(form, filename, catalog_id):
return new_doc
def create_version_for_document(document, tenant_id, url, language, user_context, user_metadata, catalog_properties):
def create_version_for_document(document, tenant_id, url, sub_file_type, language, user_context, user_metadata, catalog_properties):
new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
@@ -83,6 +84,9 @@ def create_version_for_document(document, tenant_id, url, language, user_context
if catalog_properties != '' and catalog_properties is not None:
new_doc_vers.catalog_properties = catalog_properties
if sub_file_type != '':
new_doc_vers.sub_file_type = sub_file_type
new_doc_vers.document = document
set_logging_information(new_doc_vers, dt.now(tz.utc))
@@ -237,8 +241,6 @@ def start_embedding_task(tenant_id, doc_vers_id):
def validate_file_type(extension):
current_app.logger.debug(f'Validating file type {extension}')
current_app.logger.debug(f'Supported file types: {current_app.config["SUPPORTED_FILE_TYPES"]}')
if extension not in current_app.config['SUPPORTED_FILE_TYPES']:
raise EveAIUnsupportedFileType(f"Filetype {extension} is currently not supported. "
f"Supported filetypes: {', '.join(current_app.config['SUPPORTED_FILE_TYPES'])}")

View File

@@ -10,6 +10,7 @@ class EveAIException(Exception):
def to_dict(self):
rv = dict(self.payload or ())
rv['message'] = self.message
rv['error'] = self.__class__.__name__
return rv
@@ -41,3 +42,9 @@ class EveAINoLicenseForTenant(EveAIException):
super().__init__(message, status_code, payload)
class EveAITenantNotFound(EveAIException):
"""Raised when a tenant is not found"""
def __init__(self, message="Tenant not found", status_code=400, payload=None):
super().__init__(message, status_code, payload)

View File

@@ -24,9 +24,6 @@ def mw_before_request():
if not tenant_id:
raise Exception('Cannot switch schema for tenant: no tenant defined in session')
for role in current_user.roles:
current_app.logger.debug(f'In middleware: User {current_user.email} has role {role.name}')
# user = User.query.get(current_user.id)
if current_user.has_role('Super User') or current_user.tenant_id == tenant_id:
Database(tenant_id).switch_schema()

View File

@@ -1,249 +1,36 @@
import os
from typing import Dict, Any, Optional
import langcodes
from flask import current_app
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Any, Iterator
from collections.abc import MutableMapping
from openai import OpenAI
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler
from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.templates.template_manager import TemplateManager
from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
from langchain_anthropic import ChatAnthropic
from flask import current_app
from datetime import datetime as dt, timezone as tz
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.langchain.tracked_transcribe import tracked_transcribe
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant
from common.utils.cache.base import CacheHandler
from config.model_config import MODEL_CONFIG
from common.utils.business_event_context import current_event
from common.extensions import template_manager, cache_manager
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
from common.utils.eveai_exceptions import EveAITenantNotFound
class CitedAnswer(BaseModel):
"""Default docstring - to be replaced with actual prompt"""
def create_language_template(template: str, language: str) -> str:
"""
Replace language placeholder in template with specified language
answer: str = Field(
...,
description="The answer to the user question, based on the given sources",
)
citations: List[int] = Field(
...,
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
)
insufficient_info: bool = Field(
False, # Default value is set to False
description="A boolean indicating wether given sources were sufficient or not to generate the answer"
)
Args:
template: Template string with {language} placeholder
language: Language code to insert
def set_language_prompt_template(cls, language_prompt):
cls.__doc__ = language_prompt
class ModelVariables(MutableMapping):
def __init__(self, tenant: Tenant, catalog_id=None):
self.tenant = tenant
self.catalog_id = catalog_id
self._variables = self._initialize_variables()
self._embedding_model = None
self._llm = None
self._llm_no_rag = None
self._transcription_client = None
self._prompt_templates = {}
self._embedding_db_model = None
self.llm_metrics_handler = LLMMetricsHandler()
self._transcription_client = None
def _initialize_variables(self):
variables = {}
# Get the Catalog if catalog_id is passed
if self.catalog_id:
catalog = Catalog.query.get_or_404(self.catalog_id)
# We initialize the variables that are available knowing the tenant.
variables['embed_tuning'] = catalog.embed_tuning or False
# Set HTML Chunking Variables
variables['html_tags'] = catalog.html_tags
variables['html_end_tags'] = catalog.html_end_tags
variables['html_included_elements'] = catalog.html_included_elements
variables['html_excluded_elements'] = catalog.html_excluded_elements
variables['html_excluded_classes'] = catalog.html_excluded_classes
# Set Chunk Size variables
variables['min_chunk_size'] = catalog.min_chunk_size
variables['max_chunk_size'] = catalog.max_chunk_size
# Set the RAG Context (will have to change once specialists are defined
variables['rag_context'] = self.tenant.rag_context or " "
# Temporary setting until we have Specialists
variables['rag_tuning'] = False
variables['RAG_temperature'] = 0.3
variables['no_RAG_temperature'] = 0.5
variables['k'] = 8
variables['similarity_threshold'] = 0.4
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
f"{variables['llm_model']}")]
current_app.logger.info(f"Loaded prompt templates: \n")
current_app.logger.info(f"{variables['templates']}")
# Set model-specific configurations
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
variables.update(model_config)
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]
if variables['tool_calling_supported']:
variables['cited_answer_cls'] = CitedAnswer
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
return variables
@property
def embedding_model(self):
api_key = os.getenv('OPENAI_API_KEY')
model = self._variables['embedding_model']
self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
model=model,
)
self._embedding_db_model = EmbeddingSmallOpenAI \
if model == 'text-embedding-3-small' \
else EmbeddingLargeOpenAI
return self._embedding_model
@property
def llm(self):
api_key = self.get_api_key_for_llm()
self._llm = ChatOpenAI(api_key=api_key,
model=self._variables['llm_model'],
temperature=self._variables['RAG_temperature'],
callbacks=[self.llm_metrics_handler])
return self._llm
@property
def llm_no_rag(self):
api_key = self.get_api_key_for_llm()
self._llm_no_rag = ChatOpenAI(api_key=api_key,
model=self._variables['llm_model'],
temperature=self._variables['RAG_temperature'],
callbacks=[self.llm_metrics_handler])
return self._llm_no_rag
def get_api_key_for_llm(self):
if self._variables['llm_provider'] == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
else: # self._variables['llm_provider'] == 'anthropic'
api_key = os.getenv('ANTHROPIC_API_KEY')
return api_key
@property
def transcription_client(self):
api_key = os.getenv('OPENAI_API_KEY')
self._transcription_client = OpenAI(api_key=api_key, )
self._variables['transcription_model'] = 'whisper-1'
return self._transcription_client
def transcribe(self, *args, **kwargs):
return tracked_transcribe(self._transcription_client, *args, **kwargs)
@property
def embedding_db_model(self):
if self._embedding_db_model is None:
self._embedding_db_model = self.get_embedding_db_model()
return self._embedding_db_model
def get_embedding_db_model(self):
current_app.logger.debug("In get_embedding_db_model")
if self._embedding_db_model is None:
self._embedding_db_model = EmbeddingSmallOpenAI \
if self._variables['embedding_model'] == 'text-embedding-3-small' \
else EmbeddingLargeOpenAI
current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
return self._embedding_db_model
def get_prompt_template(self, template_name: str) -> str:
current_app.logger.info(f"Getting prompt template for {template_name}")
if template_name not in self._prompt_templates:
self._prompt_templates[template_name] = self._load_prompt_template(template_name)
return self._prompt_templates[template_name]
def _load_prompt_template(self, template_name: str) -> str:
# In the future, this method will make an API call to Portkey
# For now, we'll simulate it with a placeholder implementation
# You can replace this with your current prompt loading logic
return self._variables['templates'][template_name]
def __getitem__(self, key: str) -> Any:
current_app.logger.debug(f"ModelVariables: Getting {key}")
# Support older template names (suffix = _template)
if key.endswith('_template'):
key = key[:-len('_template')]
current_app.logger.debug(f"ModelVariables: Getting modified {key}")
if key == 'embedding_model':
return self.embedding_model
elif key == 'embedding_db_model':
return self.embedding_db_model
elif key == 'llm':
return self.llm
elif key == 'llm_no_rag':
return self.llm_no_rag
elif key == 'transcription_client':
return self.transcription_client
elif key in self._variables.get('prompt_templates', []):
return self.get_prompt_template(key)
else:
value = self._variables.get(key)
if value is not None:
return value
else:
raise KeyError(f'Variable {key} does not exist in ModelVariables')
def __setitem__(self, key: str, value: Any) -> None:
self._variables[key] = value
def __delitem__(self, key: str) -> None:
del self._variables[key]
def __iter__(self) -> Iterator[str]:
return iter(self._variables)
def __len__(self):
return len(self._variables)
def get(self, key: str, default: Any = None) -> Any:
return self.__getitem__(key) or default
def update(self, **kwargs) -> None:
self._variables.update(kwargs)
def items(self):
return self._variables.items()
def keys(self):
return self._variables.keys()
def values(self):
return self._variables.values()
def select_model_variables(tenant, catalog_id=None):
model_variables = ModelVariables(tenant=tenant, catalog_id=catalog_id)
return model_variables
def create_language_template(template, language):
Returns:
str: Template with language placeholder replaced
"""
try:
full_language = langcodes.Language.make(language=language)
language_template = template.replace('{language}', full_language.display_name())
@@ -253,5 +40,249 @@ def create_language_template(template, language):
return language_template
def replace_variable_in_template(template, variable, value):
return template.replace(variable, value)
def replace_variable_in_template(template: str, variable: str, value: str) -> str:
"""
Replace a variable placeholder in template with specified value
Args:
template: Template string with variable placeholder
variable: Variable placeholder to replace (e.g. "{tenant_context}")
value: Value to insert
Returns:
str: Template with variable placeholder replaced
"""
return template.replace(variable, value or "")
class ModelVariables:
"""Manages model-related variables and configurations"""
def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
"""
Initialize ModelVariables with tenant and optional template manager
Args:
tenant: Tenant instance
template_manager: Optional TemplateManager instance
"""
current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
self.tenant_id = tenant_id
self._variables = variables if variables is not None else self._initialize_variables()
current_app.logger.info(f'Model _variables initialized to {self._variables}')
self._embedding_model = None
self._embedding_model_class = None
self._llm_instances = {}
self.llm_metrics_handler = LLMMetricsHandler()
self._transcription_model = None
def _initialize_variables(self) -> Dict[str, Any]:
"""Initialize the variables dictionary"""
variables = {}
tenant = Tenant.query.get(self.tenant_id)
if not tenant:
raise EveAITenantNotFound(f"Tenant {self.tenant_id} not found")
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
variables['llm_full_model'] = tenant.llm_model
# Set model-specific configurations
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
variables.update(model_config)
# Additional configurations
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
return variables
@property
def embedding_model(self):
"""Get the embedding model instance"""
if self._embedding_model is None:
api_key = os.getenv('OPENAI_API_KEY')
self._embedding_model = TrackedOpenAIEmbeddings(
api_key=api_key,
model=self._variables['embedding_model']
)
return self._embedding_model
@property
def embedding_model_class(self):
"""Get the embedding model class"""
if self._embedding_model_class is None:
if self._variables['embedding_model'] == 'text-embedding-3-large':
self._embedding_model_class = EmbeddingLargeOpenAI
else: # text-embedding-3-small
self._embedding_model_class = EmbeddingSmallOpenAI
return self._embedding_model_class
@property
def annotation_chunk_length(self):
return self._variables['annotation_chunk_length']
@property
def max_compression_duration(self):
return self._variables['max_compression_duration']
@property
def max_transcription_duration(self):
return self._variables['max_transcription_duration']
@property
def compression_cpu_limit(self):
return self._variables['compression_cpu_limit']
@property
def compression_process_delay(self):
return self._variables['compression_process_delay']
def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
"""
Get an LLM instance with specific configuration
Args:
temperature: The temperature for the LLM
**kwargs: Additional configuration parameters
Returns:
An instance of the configured LLM
"""
cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
if cache_key not in self._llm_instances:
provider = self._variables['llm_provider']
model = self._variables['llm_model']
if provider == 'openai':
self._llm_instances[cache_key] = ChatOpenAI(
api_key=os.getenv('OPENAI_API_KEY'),
model=model,
temperature=temperature,
callbacks=[self.llm_metrics_handler],
**kwargs
)
elif provider == 'anthropic':
self._llm_instances[cache_key] = ChatAnthropic(
api_key=os.getenv('ANTHROPIC_API_KEY'),
model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
temperature=temperature,
callbacks=[self.llm_metrics_handler],
**kwargs
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
return self._llm_instances[cache_key]
@property
def transcription_model(self) -> TrackedOpenAITranscription:
"""Get the transcription model instance"""
if self._transcription_model is None:
api_key = os.getenv('OPENAI_API_KEY')
self._transcription_model = TrackedOpenAITranscription(
api_key=api_key,
model='whisper-1'
)
return self._transcription_model
# Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
@property
def transcription_client(self):
raise DeprecationWarning("Use transcription_model instead")
def transcribe(self, *args, **kwargs):
raise DeprecationWarning("Use transcription_model.transcribe() instead")
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
"""
Get a template for the tenant's configured LLM
Args:
template_name: Name of the template to retrieve
version: Optional specific version to retrieve
Returns:
The template content
"""
try:
template = template_manager.get_template(
self._variables['llm_full_model'],
template_name,
version
)
return template.content
except Exception as e:
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
# Fall back to old template loading if template_manager fails
if template_name in self._variables.get('templates', {}):
return self._variables['templates'][template_name]
raise
class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
handler_name = 'model_vars_cache' # Used to access handler instance from cache_manager
def __init__(self, region):
super().__init__(region, 'model_variables')
self.configure_keys('tenant_id')
self.subscribe_to_model('Tenant', ['tenant_id'])
def to_cache_data(self, instance: ModelVariables) -> Dict[str, Any]:
return {
'tenant_id': instance.tenant_id,
'variables': instance._variables,
'last_updated': dt.now(tz=tz.utc).isoformat()
}
def from_cache_data(self, data: Dict[str, Any], tenant_id: int, **kwargs) -> ModelVariables:
instance = ModelVariables(tenant_id, data.get('variables'))
return instance
def should_cache(self, value: Dict[str, Any]) -> bool:
required_fields = {'tenant_id', 'variables'}
return all(field in value for field in required_fields)
# Register the handler with the cache manager
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
# Helper function to get cached model variables
def get_model_variables(tenant_id: int) -> ModelVariables:
return cache_manager.model_vars_cache.get(
lambda tenant_id: ModelVariables(tenant_id), # function to create ModelVariables if required
tenant_id=tenant_id
)
# Written in a long format, without lambda
# def get_model_variables(tenant_id: int) -> ModelVariables:
# """
# Get ModelVariables instance, either from cache or newly created
#
# Args:
# tenant_id: The tenant's ID
#
# Returns:
# ModelVariables: Instance with either cached or fresh data
#
# Raises:
# TenantNotFoundError: If tenant doesn't exist
# CacheStateError: If cached data is invalid
# """
#
# def create_new_instance(tenant_id: int) -> ModelVariables:
# """Creator function that's called when cache miss occurs"""
# return ModelVariables(tenant_id) # This will initialize fresh variables
#
# return cache_manager.model_vars_cache.get(
# create_new_instance, # Function to create new instance if needed
# tenant_id=tenant_id # Parameters passed to both get() and create_new_instance
# )

View File

@@ -1,4 +1,6 @@
import os
import sys
import gevent
import time
from flask import current_app
@@ -28,3 +30,17 @@ def sync_folder(file_path):
dir_fd = os.open(file_path, os.O_RDONLY)
os.fsync(dir_fd)
os.close(dir_fd)
def get_project_root():
"""Get the root directory of the project."""
# Use the module that's actually running (not this file)
module = sys.modules['__main__']
if hasattr(module, '__file__'):
# Get the path to the main module
main_path = os.path.abspath(module.__file__)
# Get the root directory (where the main module is located)
return os.path.dirname(main_path)
else:
# Fallback: use current working directory
return os.getcwd()

View File

@@ -4,7 +4,6 @@ from common.models.user import Tenant
# Definition of Trigger Handlers
def set_tenant_session_data(sender, user, **kwargs):
current_app.logger.debug(f"Setting tenant session data for user {user.id}")
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
session['tenant'] = tenant.to_dict()
session['default_language'] = tenant.default_language

View File

@@ -11,7 +11,7 @@ def confirm_token(token, expiration=3600):
try:
email = serializer.loads(token, salt=current_app.config['SECURITY_PASSWORD_SALT'], max_age=expiration)
except Exception as e:
current_app.logger.debug(f'Error confirming token: {e}')
current_app.logger.error(f'Error confirming token: {e}')
raise
return email
@@ -35,14 +35,11 @@ def generate_confirmation_token(email):
def send_confirmation_email(user):
current_app.logger.debug(f'Sending confirmation email to {user.email}')
if not test_smtp_connection():
raise Exception("Failed to connect to SMTP server")
token = generate_confirmation_token(user.email)
confirm_url = prefixed_url_for('security_bp.confirm_email', token=token, _external=True)
current_app.logger.debug(f'Confirmation URL: {confirm_url}')
html = render_template('email/activate.html', confirm_url=confirm_url)
subject = "Please confirm your email"
@@ -56,10 +53,8 @@ def send_confirmation_email(user):
def send_reset_email(user):
current_app.logger.debug(f'Sending reset email to {user.email}')
token = generate_reset_token(user.email)
reset_url = prefixed_url_for('security_bp.reset_password', token=token, _external=True)
current_app.logger.debug(f'Reset URL: {reset_url}')
html = render_template('email/reset_password.html', reset_url=reset_url)
subject = "Reset Your Password"

View File

@@ -0,0 +1,112 @@
from typing import List, Union
import re
class StringListConverter:
"""Utility class for converting between comma-separated strings and lists"""
@staticmethod
def string_to_list(input_string: Union[str, None], allow_empty: bool = True) -> List[str]:
"""
Convert a comma-separated string to a list of strings.
Args:
input_string: Comma-separated string to convert
allow_empty: If True, returns empty list for None/empty input
If False, raises ValueError for None/empty input
Returns:
List of stripped strings
Raises:
ValueError: If input is None/empty and allow_empty is False
"""
if not input_string:
if allow_empty:
return []
raise ValueError("Input string cannot be None or empty")
return [item.strip() for item in input_string.split(',') if item.strip()]
@staticmethod
def list_to_string(input_list: Union[List[str], None], allow_empty: bool = True) -> str:
"""
Convert a list of strings to a comma-separated string.
Args:
input_list: List of strings to convert
allow_empty: If True, returns empty string for None/empty input
If False, raises ValueError for None/empty input
Returns:
Comma-separated string
Raises:
ValueError: If input is None/empty and allow_empty is False
"""
if not input_list:
if allow_empty:
return ''
raise ValueError("Input list cannot be None or empty")
return ', '.join(str(item).strip() for item in input_list)
@staticmethod
def validate_format(input_string: str,
allowed_chars: str = r'a-zA-Z0-9_\-',
min_length: int = 1,
max_length: int = 50) -> bool:
"""
Validate the format of items in a comma-separated string.
Args:
input_string: String to validate
allowed_chars: String of allowed characters (for regex pattern)
min_length: Minimum length for each item
max_length: Maximum length for each item
Returns:
bool: True if format is valid, False otherwise
"""
if not input_string:
return False
# Create regex pattern for individual items
pattern = f'^[{allowed_chars}]{{{min_length},{max_length}}}$'
try:
# Convert to list and check each item
items = StringListConverter.string_to_list(input_string)
return all(bool(re.match(pattern, item)) for item in items)
except Exception:
return False
@staticmethod
def validate_and_convert(input_string: str,
allowed_chars: str = r'a-zA-Z0-9_\-',
min_length: int = 1,
max_length: int = 50) -> List[str]:
"""
Validate and convert a comma-separated string to a list.
Args:
input_string: String to validate and convert
allowed_chars: String of allowed characters (for regex pattern)
min_length: Minimum length for each item
max_length: Maximum length for each item
Returns:
List of validated and converted strings
Raises:
ValueError: If input string format is invalid
"""
if not StringListConverter.validate_format(
input_string, allowed_chars, min_length, max_length
):
raise ValueError(
f"Invalid format. Items must be {min_length}-{max_length} characters "
f"long and contain only these characters: {allowed_chars}"
)
return StringListConverter.string_to_list(input_string)

View File

@@ -44,7 +44,7 @@ def form_validation_failed(request, form):
for fieldName, errorMessages in form.errors.items():
for err in errorMessages:
flash(f"Error in {fieldName}: {err}", 'danger')
current_app.logger.debug(f"Error in {fieldName}: {err}")
current_app.logger.error(f"Error in {fieldName}: {err}")
def form_to_dict(form):