- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors - Introduction of caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start adaptation of chat client
This commit is contained in:
@@ -4,7 +4,6 @@ from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from portkey_ai import Portkey, Config
|
||||
import logging
|
||||
|
||||
from .business_event_context import BusinessEventContext
|
||||
|
||||
0
common/utils/cache/__init__old.py
vendored
Normal file
0
common/utils/cache/__init__old.py
vendored
Normal file
89
common/utils/cache/base.py
vendored
Normal file
89
common/utils/cache/base.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
# common/utils/cache/base.py
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type
|
||||
from dataclasses import dataclass
|
||||
from flask import Flask
|
||||
from dogpile.cache import CacheRegion
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheKey:
|
||||
"""Represents a cache key with multiple components"""
|
||||
components: Dict[str, Any]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ":".join(f"{k}={v}" for k, v in sorted(self.components.items()))
|
||||
|
||||
|
||||
class CacheInvalidationManager:
|
||||
"""Manages cache invalidation subscriptions"""
|
||||
|
||||
def __init__(self):
|
||||
self._subscribers = {}
|
||||
|
||||
def subscribe(self, model: str, handler: 'CacheHandler', key_fields: List[str]):
|
||||
if model not in self._subscribers:
|
||||
self._subscribers[model] = []
|
||||
self._subscribers[model].append((handler, key_fields))
|
||||
|
||||
def notify_change(self, model: str, **identifiers):
|
||||
if model in self._subscribers:
|
||||
for handler, key_fields in self._subscribers[model]:
|
||||
if all(field in identifiers for field in key_fields):
|
||||
handler.invalidate_by_model(model, **identifiers)
|
||||
|
||||
|
||||
class CacheHandler(Generic[T]):
|
||||
"""Base cache handler implementation"""
|
||||
|
||||
def __init__(self, region: CacheRegion, prefix: str):
|
||||
self.region = region
|
||||
self.prefix = prefix
|
||||
self._key_components = []
|
||||
|
||||
def configure_keys(self, *components: str):
|
||||
self._key_components = components
|
||||
return self
|
||||
|
||||
def subscribe_to_model(self, model: str, key_fields: List[str]):
|
||||
invalidation_manager.subscribe(model, self, key_fields)
|
||||
return self
|
||||
|
||||
def generate_key(self, **identifiers) -> str:
|
||||
missing = set(self._key_components) - set(identifiers.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Missing key components: {missing}")
|
||||
|
||||
key = CacheKey({k: identifiers[k] for k in self._key_components})
|
||||
return f"{self.prefix}:{str(key)}"
|
||||
|
||||
def get(self, creator_func, **identifiers) -> T:
|
||||
cache_key = self.generate_key(**identifiers)
|
||||
|
||||
def creator():
|
||||
instance = creator_func(**identifiers)
|
||||
return self.to_cache_data(instance)
|
||||
|
||||
cached_data = self.region.get_or_create(
|
||||
cache_key,
|
||||
creator,
|
||||
should_cache_fn=self.should_cache
|
||||
)
|
||||
|
||||
return self.from_cache_data(cached_data, **identifiers)
|
||||
|
||||
def invalidate(self, **identifiers):
|
||||
cache_key = self.generate_key(**identifiers)
|
||||
self.region.delete(cache_key)
|
||||
|
||||
def invalidate_by_model(self, model: str, **identifiers):
|
||||
try:
|
||||
self.invalidate(**identifiers)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# Create global invalidation manager
|
||||
invalidation_manager = CacheInvalidationManager()
|
||||
32
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
32
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Type
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from common.utils.cache.base import CacheHandler
|
||||
|
||||
|
||||
class EveAICacheManager:
|
||||
"""Cache manager with registration capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_region = None
|
||||
self.eveai_chat_workers_region = None
|
||||
self.eveai_workers_region = None
|
||||
self._handlers = {}
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
"""Initialize cache regions"""
|
||||
from common.utils.cache.regions import create_cache_regions
|
||||
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
|
||||
|
||||
# Initialize all registered handlers with their regions
|
||||
for handler_class, region_name in self._handlers.items():
|
||||
region = getattr(self, f"{region_name}_region")
|
||||
handler_instance = handler_class(region)
|
||||
setattr(self, handler_class.handler_name, handler_instance)
|
||||
|
||||
def register_handler(self, handler_class: Type[CacheHandler], region: str):
|
||||
"""Register a cache handler class with its region"""
|
||||
if not hasattr(handler_class, 'handler_name'):
|
||||
raise ValueError("Cache handler must define handler_name class attribute")
|
||||
self._handlers[handler_class] = region
|
||||
61
common/utils/cache/regions.py
vendored
Normal file
61
common/utils/cache/regions.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
# common/utils/cache/regions.py
|
||||
|
||||
from dogpile.cache import make_region
|
||||
from flask import current_app
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
|
||||
|
||||
def get_redis_config(app):
|
||||
"""
|
||||
Create Redis configuration dict based on app config
|
||||
Handles both authenticated and non-authenticated setups
|
||||
"""
|
||||
# Parse the REDIS_BASE_URI to get all components
|
||||
redis_uri = urlparse(app.config['REDIS_BASE_URI'])
|
||||
|
||||
config = {
|
||||
'host': redis_uri.hostname,
|
||||
'port': int(redis_uri.port or 6379),
|
||||
'db': 4, # Keep this for later use
|
||||
'redis_expiration_time': 3600,
|
||||
'distributed_lock': True
|
||||
}
|
||||
|
||||
# Add authentication if provided
|
||||
if redis_uri.username and redis_uri.password:
|
||||
config.update({
|
||||
'username': redis_uri.username,
|
||||
'password': redis_uri.password
|
||||
})
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_cache_regions(app):
|
||||
"""Initialize all cache regions with app config"""
|
||||
redis_config = get_redis_config(app)
|
||||
|
||||
# Region for model-related caching (ModelVariables etc)
|
||||
model_region = make_region(name='model').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config,
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
|
||||
eveai_chat_workers_region = make_region(name='chat_workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
# Region for eveai_workers components (Processors, ...)
|
||||
eveai_workers_region = make_region(name='workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # Same config for now
|
||||
replace_existing_backend=True
|
||||
)
|
||||
|
||||
return model_region, eveai_chat_workers_region, eveai_workers_region
|
||||
|
||||
@@ -8,8 +8,6 @@ celery_app = Celery()
|
||||
|
||||
def init_celery(celery, app, is_beat=False):
|
||||
celery_app.main = app.name
|
||||
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
|
||||
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
|
||||
|
||||
celery_config = {
|
||||
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
|
||||
|
||||
613
common/utils/config_field_types.py
Normal file
613
common/utils/config_field_types.py
Normal file
@@ -0,0 +1,613 @@
|
||||
from typing import Optional, List, Union, Dict, Any, Pattern
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from typing_extensions import Annotated
|
||||
import re
|
||||
from datetime import datetime
|
||||
import json
|
||||
from textwrap import dedent
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TaggingField(BaseModel):
|
||||
"""Represents a single tagging field configuration"""
|
||||
type: str
|
||||
required: bool = False
|
||||
description: Optional[str] = None
|
||||
allowed_values: Optional[List[Any]] = None # for enum type
|
||||
min_value: Optional[Union[int, float]] = None # for numeric types
|
||||
max_value: Optional[Union[int, float]] = None # for numeric types
|
||||
|
||||
@field_validator('type', mode='before')
|
||||
@classmethod
|
||||
def validate_type(cls, v: str) -> str:
|
||||
valid_types = ['string', 'integer', 'float', 'date', 'enum']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'type must be one of {valid_types}')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_field_constraints(self) -> 'TaggingField':
|
||||
# Validate enum constraints
|
||||
if self.type == 'enum':
|
||||
if not self.allowed_values:
|
||||
raise ValueError('allowed_values must be provided for enum type')
|
||||
elif self.allowed_values is not None:
|
||||
raise ValueError('allowed_values only valid for enum type')
|
||||
|
||||
# Validate numeric constraints
|
||||
if self.type not in ('integer', 'float'):
|
||||
if self.min_value is not None or self.max_value is not None:
|
||||
raise ValueError('min_value/max_value only valid for numeric types')
|
||||
else:
|
||||
if self.min_value is not None and self.max_value is not None and self.min_value >= self.max_value:
|
||||
raise ValueError('min_value must be less than max_value')
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TaggingFields(BaseModel):
|
||||
"""Represents a collection of tagging fields, mapped by their names"""
|
||||
fields: Dict[str, TaggingField]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'TaggingFields':
|
||||
return cls(fields={
|
||||
field_name: TaggingField(**field_config)
|
||||
for field_name, field_config in data.items()
|
||||
})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
field_name: field.model_dump(exclude_none=True)
|
||||
for field_name, field in self.fields.items()
|
||||
}
|
||||
|
||||
|
||||
class ArgumentConstraint(BaseModel):
|
||||
"""Base class for all argument constraints"""
|
||||
description: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class NumericConstraint(ArgumentConstraint):
|
||||
"""Constraints for numeric values (int/float)"""
|
||||
min_value: Optional[float] = None
|
||||
max_value: Optional[float] = None
|
||||
include_min: bool = True # True for >= min_value, False for > min_value
|
||||
include_max: bool = True # True for <= max_value, False for < max_value
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_ranges(self) -> 'NumericConstraint':
|
||||
if self.min_value is not None and self.max_value is not None:
|
||||
if self.min_value > self.max_value:
|
||||
raise ValueError("min_value must be less than or equal to max_value")
|
||||
return self
|
||||
|
||||
def validate(self, value: Union[int, float]) -> bool:
|
||||
if self.min_value is not None:
|
||||
if self.include_min and value < self.min_value:
|
||||
return False
|
||||
if not self.include_min and value <= self.min_value:
|
||||
return False
|
||||
if self.max_value is not None:
|
||||
if self.include_max and value > self.max_value:
|
||||
return False
|
||||
if not self.include_max and value >= self.max_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class StringConstraint(ArgumentConstraint):
|
||||
"""Constraints for string values"""
|
||||
min_length: Optional[int] = None
|
||||
max_length: Optional[int] = None
|
||||
patterns: Optional[List[str]] = None # List of regex patterns to match
|
||||
pattern_match_all: bool = False # If True, string must match all patterns
|
||||
forbidden_patterns: Optional[List[str]] = None # List of regex patterns that must not match
|
||||
allow_empty: bool = False
|
||||
|
||||
@field_validator('patterns', 'forbidden_patterns')
|
||||
@classmethod
|
||||
def validate_patterns(cls, v: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if v is not None:
|
||||
# Validate each pattern compiles
|
||||
for pattern in v:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
|
||||
return v
|
||||
|
||||
def validate(self, value: str) -> bool:
|
||||
if not self.allow_empty and not value:
|
||||
return False
|
||||
|
||||
if self.min_length is not None and len(value) < self.min_length:
|
||||
return False
|
||||
|
||||
if self.max_length is not None and len(value) > self.max_length:
|
||||
return False
|
||||
|
||||
if self.patterns:
|
||||
matches = [bool(re.search(pattern, value)) for pattern in self.patterns]
|
||||
if self.pattern_match_all and not all(matches):
|
||||
return False
|
||||
if not self.pattern_match_all and not any(matches):
|
||||
return False
|
||||
|
||||
if self.forbidden_patterns:
|
||||
for pattern in self.forbidden_patterns:
|
||||
if re.search(pattern, value):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class DateConstraint(ArgumentConstraint):
|
||||
"""Constraints for date values"""
|
||||
min_date: Optional[datetime] = None
|
||||
max_date: Optional[datetime] = None
|
||||
include_min: bool = True
|
||||
include_max: bool = True
|
||||
allowed_formats: Optional[List[str]] = None # List of allowed date formats
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_ranges(self) -> 'DateConstraint':
|
||||
if self.min_date and self.max_date and self.min_date > self.max_date:
|
||||
raise ValueError("min_date must be less than or equal to max_date")
|
||||
return self
|
||||
|
||||
def validate(self, value: datetime) -> bool:
|
||||
if self.min_date is not None:
|
||||
if self.include_min and value < self.min_date:
|
||||
return False
|
||||
if not self.include_min and value <= self.min_date:
|
||||
return False
|
||||
|
||||
if self.max_date is not None:
|
||||
if self.include_max and value > self.max_date:
|
||||
return False
|
||||
if not self.include_max and value >= self.max_date:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class EnumConstraint(ArgumentConstraint):
|
||||
"""Constraints for enum values"""
|
||||
allowed_values: List[Any]
|
||||
case_sensitive: bool = True # For string enums
|
||||
allow_multiple: bool = False # If True, value can be a list of allowed values
|
||||
min_selections: Optional[int] = None # When allow_multiple is True
|
||||
max_selections: Optional[int] = None # When allow_multiple is True
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_selections(self) -> 'EnumConstraint':
|
||||
if self.allow_multiple:
|
||||
if self.min_selections is not None and self.max_selections is not None:
|
||||
if self.min_selections > self.max_selections:
|
||||
raise ValueError("min_selections must be less than or equal to max_selections")
|
||||
if self.max_selections > len(self.allowed_values):
|
||||
raise ValueError("max_selections cannot be greater than number of allowed values")
|
||||
return self
|
||||
|
||||
def validate(self, value: Union[Any, List[Any]]) -> bool:
|
||||
if self.allow_multiple:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if self.min_selections is not None and len(value) < self.min_selections:
|
||||
return False
|
||||
|
||||
if self.max_selections is not None and len(value) > self.max_selections:
|
||||
return False
|
||||
|
||||
for v in value:
|
||||
if not self._validate_single_value(v):
|
||||
return False
|
||||
else:
|
||||
return self._validate_single_value(value)
|
||||
|
||||
return True
|
||||
|
||||
def _validate_single_value(self, value: Any) -> bool:
|
||||
if isinstance(value, str) and not self.case_sensitive:
|
||||
return any(str(value).lower() == str(v).lower() for v in self.allowed_values)
|
||||
return value in self.allowed_values
|
||||
|
||||
|
||||
class ArgumentDefinition(BaseModel):
|
||||
"""Defines an argument with its type and constraints"""
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
required: bool = False
|
||||
default: Optional[Any] = None
|
||||
constraints: Optional[Union[NumericConstraint, StringConstraint, DateConstraint, EnumConstraint]] = None
|
||||
|
||||
@field_validator('type')
|
||||
@classmethod
|
||||
def validate_type(cls, v: str) -> str:
|
||||
valid_types = ['string', 'integer', 'float', 'date', 'enum']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'type must be one of {valid_types}')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_constraints(self) -> 'ArgumentDefinition':
|
||||
if self.constraints:
|
||||
expected_constraint_types = {
|
||||
'string': StringConstraint,
|
||||
'integer': NumericConstraint,
|
||||
'float': NumericConstraint,
|
||||
'date': DateConstraint,
|
||||
'enum': EnumConstraint
|
||||
}
|
||||
|
||||
expected_type = expected_constraint_types.get(self.type)
|
||||
if not isinstance(self.constraints, expected_type):
|
||||
raise ValueError(f'Constraints for type {self.type} must be of type {expected_type.__name__}')
|
||||
|
||||
if self.default is not None:
|
||||
if not self.constraints.validate(self.default):
|
||||
raise ValueError(f'Default value does not satisfy constraints for {self.name}')
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class ArgumentDefinitions(BaseModel):
|
||||
"""Collection of argument definitions"""
|
||||
arguments: Dict[str, ArgumentDefinition]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'ArgumentDefinitions':
|
||||
return cls(arguments={
|
||||
arg_name: ArgumentDefinition(**arg_config)
|
||||
for arg_name, arg_config in data.items()
|
||||
})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, Any]]:
|
||||
return {
|
||||
arg_name: arg.model_dump(exclude_none=True)
|
||||
for arg_name, arg in self.arguments.items()
|
||||
}
|
||||
|
||||
def validate_argument_values(self, values: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""
|
||||
Validate a set of argument values against their definitions
|
||||
Returns a dictionary of error messages for invalid arguments
|
||||
"""
|
||||
errors = {}
|
||||
|
||||
# Check for required arguments
|
||||
for name, arg_def in self.arguments.items():
|
||||
if arg_def.required and name not in values:
|
||||
errors[name] = "Required argument missing"
|
||||
continue
|
||||
|
||||
if name in values:
|
||||
value = values[name]
|
||||
|
||||
# Validate type
|
||||
try:
|
||||
if arg_def.type == 'integer':
|
||||
value = int(value)
|
||||
elif arg_def.type == 'float':
|
||||
value = float(value)
|
||||
elif arg_def.type == 'date' and isinstance(value, str):
|
||||
if arg_def.constraints and arg_def.constraints.allowed_formats:
|
||||
for fmt in arg_def.constraints.allowed_formats:
|
||||
try:
|
||||
value = datetime.strptime(value, fmt)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
errors[
|
||||
name] = f"Invalid date format. Allowed formats: {arg_def.constraints.allowed_formats}"
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
errors[name] = f"Invalid type. Expected {arg_def.type}"
|
||||
continue
|
||||
|
||||
# Validate constraints
|
||||
if arg_def.constraints and not arg_def.constraints.validate(value):
|
||||
errors[name] = arg_def.constraints.error_message or "Value does not satisfy constraints"
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentationFormat:
|
||||
"""Constants for documentation formats"""
|
||||
MARKDOWN = "markdown"
|
||||
JSON = "json"
|
||||
YAML = "yaml"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentationVersion:
|
||||
"""Constants for documentation versions"""
|
||||
BASIC = "basic" # Original documentation without retriever info
|
||||
EXTENDED = "extended" # Including retriever documentation
|
||||
|
||||
|
||||
def _generate_argument_constraints(field_config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Generate possible argument constraints based on field type"""
|
||||
constraints = []
|
||||
|
||||
base_constraint = {
|
||||
"description": f"Constraint for {field_config.get('description', 'field')}",
|
||||
"error_message": "Optional custom error message"
|
||||
}
|
||||
|
||||
if field_config["type"] == "integer" or field_config["type"] == "float":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "NumericConstraint",
|
||||
"possible_constraints": {
|
||||
"min_value": "number",
|
||||
"max_value": "number",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_value": field_config.get("min_value", 0),
|
||||
"max_value": field_config.get("max_value", 100),
|
||||
"include_min": True,
|
||||
"include_max": True
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "string":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "StringConstraint",
|
||||
"possible_constraints": {
|
||||
"min_length": "integer",
|
||||
"max_length": "integer",
|
||||
"patterns": "list[str]",
|
||||
"pattern_match_all": "boolean",
|
||||
"forbidden_patterns": "list[str]",
|
||||
"allow_empty": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_length": 1,
|
||||
"max_length": 100,
|
||||
"patterns": ["^[A-Za-z0-9]+$"],
|
||||
"pattern_match_all": False,
|
||||
"forbidden_patterns": ["^test_", "_temp$"],
|
||||
"allow_empty": False
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "enum":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "EnumConstraint",
|
||||
"possible_constraints": {
|
||||
"allowed_values": f"list[{field_config.get('allowed_values', ['value1', 'value2'])}]",
|
||||
"case_sensitive": "boolean",
|
||||
"allow_multiple": "boolean",
|
||||
"min_selections": "integer",
|
||||
"max_selections": "integer"
|
||||
},
|
||||
"example": {
|
||||
"allowed_values": field_config.get("allowed_values", ["value1", "value2"]),
|
||||
"case_sensitive": True,
|
||||
"allow_multiple": True,
|
||||
"min_selections": 1,
|
||||
"max_selections": 2
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "date":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "DateConstraint",
|
||||
"possible_constraints": {
|
||||
"min_date": "datetime",
|
||||
"max_date": "datetime",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean",
|
||||
"allowed_formats": "list[str]"
|
||||
},
|
||||
"example": {
|
||||
"min_date": "2024-01-01T00:00:00",
|
||||
"max_date": "2024-12-31T23:59:59",
|
||||
"include_min": True,
|
||||
"include_max": True,
|
||||
"allowed_formats": ["%Y-%m-%d", "%Y/%m/%d"]
|
||||
}
|
||||
})
|
||||
|
||||
return constraints
|
||||
|
||||
|
||||
def generate_field_documentation(
|
||||
tagging_fields: Dict[str, Any],
|
||||
format: str = "markdown",
|
||||
version: str = "basic"
|
||||
) -> str:
|
||||
"""
|
||||
Generate documentation for tagging fields configuration.
|
||||
|
||||
Args:
|
||||
tagging_fields: Dictionary containing tagging fields configuration
|
||||
format: Output format ("markdown", "json", or "yaml")
|
||||
version: Documentation version ("basic" or "extended")
|
||||
|
||||
Returns:
|
||||
str: Formatted documentation
|
||||
"""
|
||||
if version not in [DocumentationVersion.BASIC, DocumentationVersion.EXTENDED]:
|
||||
raise ValueError(f"Unsupported documentation version: {version}")
|
||||
|
||||
# Normalize fields configuration
|
||||
normalized_fields = {}
|
||||
|
||||
for field_name, field_config in tagging_fields.items():
|
||||
field_doc = {
|
||||
"name": field_name,
|
||||
"type": field_config["type"],
|
||||
"required": field_config.get("required", False),
|
||||
"description": field_config.get("description", "No description provided"),
|
||||
"constraints": []
|
||||
}
|
||||
|
||||
# Only include possible arguments in extended version
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
field_doc["possible_arguments"] = _generate_argument_constraints(field_config)
|
||||
|
||||
# Add type-specific constraints
|
||||
if field_config["type"] == "integer" or field_config["type"] == "float":
|
||||
if "min_value" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum value: {field_config['min_value']}")
|
||||
if "max_value" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum value: {field_config['max_value']}")
|
||||
|
||||
elif field_config["type"] == "string":
|
||||
if "min_length" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum length: {field_config['min_length']}")
|
||||
if "max_length" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum length: {field_config['max_length']}")
|
||||
if "patterns" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Must match patterns: {', '.join(field_config['patterns'])}")
|
||||
|
||||
elif field_config["type"] == "enum":
|
||||
if "allowed_values" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Allowed values: {', '.join(str(v) for v in field_config['allowed_values'])}")
|
||||
|
||||
elif field_config["type"] == "date":
|
||||
if "min_date" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Minimum date: {field_config['min_date']}")
|
||||
if "max_date" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Maximum date: {field_config['max_date']}")
|
||||
if "allowed_formats" in field_config:
|
||||
field_doc["constraints"].append(
|
||||
f"Allowed formats: {', '.join(field_config['allowed_formats'])}")
|
||||
|
||||
normalized_fields[field_name] = field_doc
|
||||
|
||||
# Generate documentation in requested format
|
||||
if format == DocumentationFormat.MARKDOWN:
|
||||
return _generate_markdown_docs(normalized_fields, version)
|
||||
elif format == DocumentationFormat.JSON:
|
||||
return _generate_json_docs(normalized_fields, version)
|
||||
elif format == DocumentationFormat.YAML:
|
||||
return _generate_yaml_docs(normalized_fields, version)
|
||||
else:
|
||||
raise ValueError(f"Unsupported documentation format: {format}")
|
||||
|
||||
|
||||
def _generate_markdown_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate markdown documentation"""
|
||||
docs = ["# Tagging Fields Documentation\n"]
|
||||
|
||||
# Add overview table
|
||||
docs.append("## Fields Overview\n")
|
||||
docs.append("| Field Name | Type | Required | Description |")
|
||||
docs.append("|------------|------|----------|-------------|")
|
||||
|
||||
for field_name, field in fields.items():
|
||||
docs.append(
|
||||
f"| {field_name} | {field['type']} | "
|
||||
f"{'Yes' if field['required'] else 'No'} | {field['description']} |"
|
||||
)
|
||||
|
||||
# Add detailed field specifications
|
||||
docs.append("\n## Detailed Field Specifications\n")
|
||||
|
||||
for field_name, field in fields.items():
|
||||
docs.append(f"### {field_name}\n")
|
||||
docs.append(f"**Type:** {field['type']}")
|
||||
docs.append(f"**Required:** {'Yes' if field['required'] else 'No'}")
|
||||
docs.append(f"**Description:** {field['description']}\n")
|
||||
|
||||
if field["constraints"]:
|
||||
docs.append("**Field Constraints:**")
|
||||
for constraint in field["constraints"]:
|
||||
docs.append(f"- {constraint}")
|
||||
docs.append("")
|
||||
|
||||
# Add retriever argument documentation only in extended version
|
||||
if version == DocumentationVersion.EXTENDED and "possible_arguments" in field:
|
||||
docs.append("**Possible Retriever Arguments:**")
|
||||
for arg_constraint in field["possible_arguments"]:
|
||||
docs.append(f"\n*{arg_constraint['type']}*")
|
||||
docs.append(f"Description: {arg_constraint['description']}")
|
||||
docs.append("\nPossible constraints:")
|
||||
for const_name, const_type in arg_constraint["possible_constraints"].items():
|
||||
docs.append(f"- `{const_name}`: {const_type}")
|
||||
|
||||
docs.append("\nExample:")
|
||||
docs.append("```python")
|
||||
docs.append(json.dumps(arg_constraint["example"], indent=2))
|
||||
docs.append("```\n")
|
||||
|
||||
# Add example retriever configuration only in extended version
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
docs.append("\n## Example Retriever Configuration\n")
|
||||
docs.append("```python")
|
||||
example_config = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
docs.append(json.dumps(example_config, indent=2))
|
||||
docs.append("```")
|
||||
|
||||
return "\n".join(docs)
|
||||
|
||||
|
||||
def _generate_json_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate JSON documentation"""
|
||||
doc = {
|
||||
"tagging_fields_documentation": {
|
||||
"version": version,
|
||||
"fields": fields
|
||||
}
|
||||
}
|
||||
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
doc["tagging_fields_documentation"]["example_retriever_config"] = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
|
||||
return json.dumps(doc, indent=2)
|
||||
|
||||
|
||||
def _generate_yaml_docs(fields: Dict[str, Any], version: str) -> str:
|
||||
"""Generate YAML documentation"""
|
||||
doc = {
|
||||
"tagging_fields_documentation": {
|
||||
"version": version,
|
||||
"fields": fields
|
||||
}
|
||||
}
|
||||
|
||||
if version == DocumentationVersion.EXTENDED:
|
||||
doc["tagging_fields_documentation"]["example_retriever_config"] = {
|
||||
"metadata_filters": {
|
||||
field_name: field["possible_arguments"][0]["example"]
|
||||
for field_name, field in fields.items()
|
||||
if "possible_arguments" in field
|
||||
}
|
||||
}
|
||||
|
||||
return yaml.dump(doc, sort_keys=False, default_flow_style=False)
|
||||
@@ -5,10 +5,8 @@ from common.models.user import Tenant, TenantDomain
|
||||
def get_allowed_origins(tenant_id):
|
||||
session_key = f"allowed_origins_{tenant_id}"
|
||||
if session_key in session:
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
|
||||
return session[session_key]
|
||||
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
|
||||
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
|
||||
allowed_origins = [domain.domain for domain in tenant_domains]
|
||||
|
||||
@@ -18,14 +16,8 @@ def get_allowed_origins(tenant_id):
|
||||
|
||||
|
||||
def cors_after_request(response, prefix):
|
||||
current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
|
||||
current_app.logger.debug(f'request.headers: {request.headers}')
|
||||
current_app.logger.debug(f'request.args: {request.args}')
|
||||
current_app.logger.debug(f'request is json?: {request.is_json}')
|
||||
|
||||
# Exclude health checks from checks
|
||||
if request.path.startswith('/healthz') or request.path.startswith('/_healthz'):
|
||||
current_app.logger.debug('Skipping CORS headers for health checks')
|
||||
response.headers.add('Access-Control-Allow-Origin', '*')
|
||||
response.headers.add('Access-Control-Allow-Headers', '*')
|
||||
response.headers.add('Access-Control-Allow-Methods', '*')
|
||||
@@ -36,7 +28,6 @@ def cors_after_request(response, prefix):
|
||||
|
||||
# Try to get tenant_id from JSON payload
|
||||
json_data = request.get_json(silent=True)
|
||||
current_app.logger.debug(f'request.get_json(silent=True): {json_data}')
|
||||
|
||||
if json_data and 'tenant_id' in json_data:
|
||||
tenant_id = json_data['tenant_id']
|
||||
@@ -44,23 +35,17 @@ def cors_after_request(response, prefix):
|
||||
# Fallback to get tenant_id from query parameters or headers if JSON is not available
|
||||
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
|
||||
|
||||
current_app.logger.debug(f'Identified tenant_id: {tenant_id}')
|
||||
|
||||
if tenant_id:
|
||||
allowed_origins = get_allowed_origins(tenant_id)
|
||||
current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
|
||||
else:
|
||||
current_app.logger.warning('tenant_id not found in request')
|
||||
|
||||
origin = request.headers.get('Origin')
|
||||
current_app.logger.debug(f'Origin: {origin}')
|
||||
|
||||
if origin in allowed_origins:
|
||||
response.headers.add('Access-Control-Allow-Origin', origin)
|
||||
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
|
||||
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
|
||||
response.headers.add('Access-Control-Allow-Credentials', 'true')
|
||||
current_app.logger.debug(f'CORS headers set for origin: {origin}')
|
||||
else:
|
||||
current_app.logger.warning(f'Origin {origin} not allowed')
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def log_request_middleware(app):
|
||||
|
||||
@app.before_request
|
||||
def log_session_state_before():
|
||||
app.logger.debug(f'Session state before request: {session.items()}')
|
||||
pass
|
||||
|
||||
# @app.after_request
|
||||
# def log_response_info(response):
|
||||
@@ -58,5 +58,4 @@ def log_request_middleware(app):
|
||||
|
||||
@app.after_request
|
||||
def log_session_state_after(response):
|
||||
app.logger.debug(f'Session state after request: {session.items()}')
|
||||
return response
|
||||
|
||||
@@ -24,6 +24,7 @@ def create_document_stack(api_input, file, filename, extension, tenant_id):
|
||||
# Create the DocumentVersion
|
||||
new_doc_vers = create_version_for_document(new_doc, tenant_id,
|
||||
api_input.get('url', ''),
|
||||
api_input.get('sub_file_type', ''),
|
||||
api_input.get('language', 'en'),
|
||||
api_input.get('user_context', ''),
|
||||
api_input.get('user_metadata'),
|
||||
@@ -64,7 +65,7 @@ def create_document(form, filename, catalog_id):
|
||||
return new_doc
|
||||
|
||||
|
||||
def create_version_for_document(document, tenant_id, url, language, user_context, user_metadata, catalog_properties):
|
||||
def create_version_for_document(document, tenant_id, url, sub_file_type, language, user_context, user_metadata, catalog_properties):
|
||||
new_doc_vers = DocumentVersion()
|
||||
if url != '':
|
||||
new_doc_vers.url = url
|
||||
@@ -83,6 +84,9 @@ def create_version_for_document(document, tenant_id, url, language, user_context
|
||||
if catalog_properties != '' and catalog_properties is not None:
|
||||
new_doc_vers.catalog_properties = catalog_properties
|
||||
|
||||
if sub_file_type != '':
|
||||
new_doc_vers.sub_file_type = sub_file_type
|
||||
|
||||
new_doc_vers.document = document
|
||||
|
||||
set_logging_information(new_doc_vers, dt.now(tz.utc))
|
||||
@@ -237,8 +241,6 @@ def start_embedding_task(tenant_id, doc_vers_id):
|
||||
|
||||
|
||||
def validate_file_type(extension):
|
||||
current_app.logger.debug(f'Validating file type {extension}')
|
||||
current_app.logger.debug(f'Supported file types: {current_app.config["SUPPORTED_FILE_TYPES"]}')
|
||||
if extension not in current_app.config['SUPPORTED_FILE_TYPES']:
|
||||
raise EveAIUnsupportedFileType(f"Filetype {extension} is currently not supported. "
|
||||
f"Supported filetypes: {', '.join(current_app.config['SUPPORTED_FILE_TYPES'])}")
|
||||
|
||||
@@ -10,6 +10,7 @@ class EveAIException(Exception):
|
||||
def to_dict(self):
|
||||
rv = dict(self.payload or ())
|
||||
rv['message'] = self.message
|
||||
rv['error'] = self.__class__.__name__
|
||||
return rv
|
||||
|
||||
|
||||
@@ -41,3 +42,9 @@ class EveAINoLicenseForTenant(EveAIException):
|
||||
super().__init__(message, status_code, payload)
|
||||
|
||||
|
||||
class EveAITenantNotFound(EveAIException):
|
||||
"""Raised when a tenant is not found"""
|
||||
|
||||
def __init__(self, message="Tenant not found", status_code=400, payload=None):
|
||||
super().__init__(message, status_code, payload)
|
||||
|
||||
|
||||
@@ -24,9 +24,6 @@ def mw_before_request():
|
||||
if not tenant_id:
|
||||
raise Exception('Cannot switch schema for tenant: no tenant defined in session')
|
||||
|
||||
for role in current_user.roles:
|
||||
current_app.logger.debug(f'In middleware: User {current_user.email} has role {role.name}')
|
||||
|
||||
# user = User.query.get(current_user.id)
|
||||
if current_user.has_role('Super User') or current_user.tenant_id == tenant_id:
|
||||
Database(tenant_id).switch_schema()
|
||||
|
||||
@@ -1,249 +1,36 @@
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import langcodes
|
||||
from flask import current_app
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from typing import List, Any, Iterator
|
||||
from collections.abc import MutableMapping
|
||||
from openai import OpenAI
|
||||
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
|
||||
from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCallbackHandler
|
||||
|
||||
from common.langchain.llm_metrics_handler import LLMMetricsHandler
|
||||
from common.langchain.templates.template_manager import TemplateManager
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from flask import current_app
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
|
||||
from common.langchain.tracked_transcribe import tracked_transcribe
|
||||
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
|
||||
from common.langchain.tracked_transcription import TrackedOpenAITranscription
|
||||
from common.models.user import Tenant
|
||||
from common.utils.cache.base import CacheHandler
|
||||
from config.model_config import MODEL_CONFIG
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.extensions import template_manager, cache_manager
|
||||
from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI
|
||||
from common.utils.eveai_exceptions import EveAITenantNotFound
|
||||
|
||||
|
||||
class CitedAnswer(BaseModel):
|
||||
"""Default docstring - to be replaced with actual prompt"""
|
||||
def create_language_template(template: str, language: str) -> str:
|
||||
"""
|
||||
Replace language placeholder in template with specified language
|
||||
|
||||
answer: str = Field(
|
||||
...,
|
||||
description="The answer to the user question, based on the given sources",
|
||||
)
|
||||
citations: List[int] = Field(
|
||||
...,
|
||||
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
|
||||
)
|
||||
insufficient_info: bool = Field(
|
||||
False, # Default value is set to False
|
||||
description="A boolean indicating wether given sources were sufficient or not to generate the answer"
|
||||
)
|
||||
Args:
|
||||
template: Template string with {language} placeholder
|
||||
language: Language code to insert
|
||||
|
||||
|
||||
def set_language_prompt_template(cls, language_prompt):
|
||||
cls.__doc__ = language_prompt
|
||||
|
||||
|
||||
class ModelVariables(MutableMapping):
|
||||
def __init__(self, tenant: Tenant, catalog_id=None):
|
||||
self.tenant = tenant
|
||||
self.catalog_id = catalog_id
|
||||
self._variables = self._initialize_variables()
|
||||
self._embedding_model = None
|
||||
self._llm = None
|
||||
self._llm_no_rag = None
|
||||
self._transcription_client = None
|
||||
self._prompt_templates = {}
|
||||
self._embedding_db_model = None
|
||||
self.llm_metrics_handler = LLMMetricsHandler()
|
||||
self._transcription_client = None
|
||||
|
||||
def _initialize_variables(self):
|
||||
variables = {}
|
||||
|
||||
# Get the Catalog if catalog_id is passed
|
||||
if self.catalog_id:
|
||||
catalog = Catalog.query.get_or_404(self.catalog_id)
|
||||
|
||||
# We initialize the variables that are available knowing the tenant.
|
||||
variables['embed_tuning'] = catalog.embed_tuning or False
|
||||
|
||||
# Set HTML Chunking Variables
|
||||
variables['html_tags'] = catalog.html_tags
|
||||
variables['html_end_tags'] = catalog.html_end_tags
|
||||
variables['html_included_elements'] = catalog.html_included_elements
|
||||
variables['html_excluded_elements'] = catalog.html_excluded_elements
|
||||
variables['html_excluded_classes'] = catalog.html_excluded_classes
|
||||
|
||||
# Set Chunk Size variables
|
||||
variables['min_chunk_size'] = catalog.min_chunk_size
|
||||
variables['max_chunk_size'] = catalog.max_chunk_size
|
||||
|
||||
# Set the RAG Context (will have to change once specialists are defined
|
||||
variables['rag_context'] = self.tenant.rag_context or " "
|
||||
# Temporary setting until we have Specialists
|
||||
variables['rag_tuning'] = False
|
||||
variables['RAG_temperature'] = 0.3
|
||||
variables['no_RAG_temperature'] = 0.5
|
||||
variables['k'] = 8
|
||||
variables['similarity_threshold'] = 0.4
|
||||
|
||||
# Set model providers
|
||||
variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
|
||||
variables['llm_provider'], variables['llm_model'] = self.tenant.llm_model.rsplit('.', 1)
|
||||
variables["templates"] = current_app.config['PROMPT_TEMPLATES'][(f"{variables['llm_provider']}."
|
||||
f"{variables['llm_model']}")]
|
||||
current_app.logger.info(f"Loaded prompt templates: \n")
|
||||
current_app.logger.info(f"{variables['templates']}")
|
||||
|
||||
# Set model-specific configurations
|
||||
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
|
||||
variables.update(model_config)
|
||||
|
||||
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][self.tenant.llm_model]
|
||||
|
||||
if variables['tool_calling_supported']:
|
||||
variables['cited_answer_cls'] = CitedAnswer
|
||||
|
||||
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
|
||||
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
|
||||
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
|
||||
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
|
||||
|
||||
return variables
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
model = self._variables['embedding_model']
|
||||
self._embedding_model = TrackedOpenAIEmbeddings(api_key=api_key,
|
||||
model=model,
|
||||
)
|
||||
self._embedding_db_model = EmbeddingSmallOpenAI \
|
||||
if model == 'text-embedding-3-small' \
|
||||
else EmbeddingLargeOpenAI
|
||||
|
||||
return self._embedding_model
|
||||
|
||||
@property
|
||||
def llm(self):
|
||||
api_key = self.get_api_key_for_llm()
|
||||
self._llm = ChatOpenAI(api_key=api_key,
|
||||
model=self._variables['llm_model'],
|
||||
temperature=self._variables['RAG_temperature'],
|
||||
callbacks=[self.llm_metrics_handler])
|
||||
return self._llm
|
||||
|
||||
@property
|
||||
def llm_no_rag(self):
|
||||
api_key = self.get_api_key_for_llm()
|
||||
self._llm_no_rag = ChatOpenAI(api_key=api_key,
|
||||
model=self._variables['llm_model'],
|
||||
temperature=self._variables['RAG_temperature'],
|
||||
callbacks=[self.llm_metrics_handler])
|
||||
return self._llm_no_rag
|
||||
|
||||
def get_api_key_for_llm(self):
|
||||
if self._variables['llm_provider'] == 'openai':
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
else: # self._variables['llm_provider'] == 'anthropic'
|
||||
api_key = os.getenv('ANTHROPIC_API_KEY')
|
||||
|
||||
return api_key
|
||||
|
||||
@property
|
||||
def transcription_client(self):
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._transcription_client = OpenAI(api_key=api_key, )
|
||||
self._variables['transcription_model'] = 'whisper-1'
|
||||
return self._transcription_client
|
||||
|
||||
def transcribe(self, *args, **kwargs):
|
||||
return tracked_transcribe(self._transcription_client, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def embedding_db_model(self):
|
||||
if self._embedding_db_model is None:
|
||||
self._embedding_db_model = self.get_embedding_db_model()
|
||||
return self._embedding_db_model
|
||||
|
||||
def get_embedding_db_model(self):
|
||||
current_app.logger.debug("In get_embedding_db_model")
|
||||
if self._embedding_db_model is None:
|
||||
self._embedding_db_model = EmbeddingSmallOpenAI \
|
||||
if self._variables['embedding_model'] == 'text-embedding-3-small' \
|
||||
else EmbeddingLargeOpenAI
|
||||
current_app.logger.debug(f"Embedding DB Model: {self._embedding_db_model}")
|
||||
return self._embedding_db_model
|
||||
|
||||
def get_prompt_template(self, template_name: str) -> str:
|
||||
current_app.logger.info(f"Getting prompt template for {template_name}")
|
||||
if template_name not in self._prompt_templates:
|
||||
self._prompt_templates[template_name] = self._load_prompt_template(template_name)
|
||||
return self._prompt_templates[template_name]
|
||||
|
||||
def _load_prompt_template(self, template_name: str) -> str:
|
||||
# In the future, this method will make an API call to Portkey
|
||||
# For now, we'll simulate it with a placeholder implementation
|
||||
# You can replace this with your current prompt loading logic
|
||||
return self._variables['templates'][template_name]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
current_app.logger.debug(f"ModelVariables: Getting {key}")
|
||||
# Support older template names (suffix = _template)
|
||||
if key.endswith('_template'):
|
||||
key = key[:-len('_template')]
|
||||
current_app.logger.debug(f"ModelVariables: Getting modified {key}")
|
||||
if key == 'embedding_model':
|
||||
return self.embedding_model
|
||||
elif key == 'embedding_db_model':
|
||||
return self.embedding_db_model
|
||||
elif key == 'llm':
|
||||
return self.llm
|
||||
elif key == 'llm_no_rag':
|
||||
return self.llm_no_rag
|
||||
elif key == 'transcription_client':
|
||||
return self.transcription_client
|
||||
elif key in self._variables.get('prompt_templates', []):
|
||||
return self.get_prompt_template(key)
|
||||
else:
|
||||
value = self._variables.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
else:
|
||||
raise KeyError(f'Variable {key} does not exist in ModelVariables')
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self._variables[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._variables[key]
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._variables)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._variables)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self.__getitem__(key) or default
|
||||
|
||||
def update(self, **kwargs) -> None:
|
||||
self._variables.update(kwargs)
|
||||
|
||||
def items(self):
|
||||
return self._variables.items()
|
||||
|
||||
def keys(self):
|
||||
return self._variables.keys()
|
||||
|
||||
def values(self):
|
||||
return self._variables.values()
|
||||
|
||||
|
||||
def select_model_variables(tenant, catalog_id=None):
|
||||
model_variables = ModelVariables(tenant=tenant, catalog_id=catalog_id)
|
||||
return model_variables
|
||||
|
||||
|
||||
def create_language_template(template, language):
|
||||
Returns:
|
||||
str: Template with language placeholder replaced
|
||||
"""
|
||||
try:
|
||||
full_language = langcodes.Language.make(language=language)
|
||||
language_template = template.replace('{language}', full_language.display_name())
|
||||
@@ -253,5 +40,249 @@ def create_language_template(template, language):
|
||||
return language_template
|
||||
|
||||
|
||||
def replace_variable_in_template(template, variable, value):
|
||||
return template.replace(variable, value)
|
||||
def replace_variable_in_template(template: str, variable: str, value: str) -> str:
|
||||
"""
|
||||
Replace a variable placeholder in template with specified value
|
||||
|
||||
Args:
|
||||
template: Template string with variable placeholder
|
||||
variable: Variable placeholder to replace (e.g. "{tenant_context}")
|
||||
value: Value to insert
|
||||
|
||||
Returns:
|
||||
str: Template with variable placeholder replaced
|
||||
"""
|
||||
return template.replace(variable, value or "")
|
||||
|
||||
|
||||
class ModelVariables:
|
||||
"""Manages model-related variables and configurations"""
|
||||
|
||||
def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
|
||||
"""
|
||||
Initialize ModelVariables with tenant and optional template manager
|
||||
|
||||
Args:
|
||||
tenant: Tenant instance
|
||||
template_manager: Optional TemplateManager instance
|
||||
"""
|
||||
current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
|
||||
self.tenant_id = tenant_id
|
||||
self._variables = variables if variables is not None else self._initialize_variables()
|
||||
current_app.logger.info(f'Model _variables initialized to {self._variables}')
|
||||
self._embedding_model = None
|
||||
self._embedding_model_class = None
|
||||
self._llm_instances = {}
|
||||
self.llm_metrics_handler = LLMMetricsHandler()
|
||||
self._transcription_model = None
|
||||
|
||||
def _initialize_variables(self) -> Dict[str, Any]:
|
||||
"""Initialize the variables dictionary"""
|
||||
variables = {}
|
||||
|
||||
tenant = Tenant.query.get(self.tenant_id)
|
||||
if not tenant:
|
||||
raise EveAITenantNotFound(f"Tenant {self.tenant_id} not found")
|
||||
|
||||
# Set model providers
|
||||
variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.')
|
||||
variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
|
||||
variables['llm_full_model'] = tenant.llm_model
|
||||
|
||||
# Set model-specific configurations
|
||||
model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
|
||||
variables.update(model_config)
|
||||
|
||||
# Additional configurations
|
||||
variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
|
||||
variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
|
||||
variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
|
||||
variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
|
||||
variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
|
||||
|
||||
return variables
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""Get the embedding model instance"""
|
||||
if self._embedding_model is None:
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._embedding_model = TrackedOpenAIEmbeddings(
|
||||
api_key=api_key,
|
||||
model=self._variables['embedding_model']
|
||||
)
|
||||
return self._embedding_model
|
||||
|
||||
@property
|
||||
def embedding_model_class(self):
|
||||
"""Get the embedding model class"""
|
||||
if self._embedding_model_class is None:
|
||||
if self._variables['embedding_model'] == 'text-embedding-3-large':
|
||||
self._embedding_model_class = EmbeddingLargeOpenAI
|
||||
else: # text-embedding-3-small
|
||||
self._embedding_model_class = EmbeddingSmallOpenAI
|
||||
|
||||
return self._embedding_model_class
|
||||
|
||||
@property
|
||||
def annotation_chunk_length(self):
|
||||
return self._variables['annotation_chunk_length']
|
||||
|
||||
@property
|
||||
def max_compression_duration(self):
|
||||
return self._variables['max_compression_duration']
|
||||
|
||||
@property
|
||||
def max_transcription_duration(self):
|
||||
return self._variables['max_transcription_duration']
|
||||
|
||||
@property
|
||||
def compression_cpu_limit(self):
|
||||
return self._variables['compression_cpu_limit']
|
||||
|
||||
@property
|
||||
def compression_process_delay(self):
|
||||
return self._variables['compression_process_delay']
|
||||
|
||||
def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
|
||||
"""
|
||||
Get an LLM instance with specific configuration
|
||||
|
||||
Args:
|
||||
temperature: The temperature for the LLM
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
An instance of the configured LLM
|
||||
"""
|
||||
cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
|
||||
|
||||
if cache_key not in self._llm_instances:
|
||||
provider = self._variables['llm_provider']
|
||||
model = self._variables['llm_model']
|
||||
|
||||
if provider == 'openai':
|
||||
self._llm_instances[cache_key] = ChatOpenAI(
|
||||
api_key=os.getenv('OPENAI_API_KEY'),
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
callbacks=[self.llm_metrics_handler],
|
||||
**kwargs
|
||||
)
|
||||
elif provider == 'anthropic':
|
||||
self._llm_instances[cache_key] = ChatAnthropic(
|
||||
api_key=os.getenv('ANTHROPIC_API_KEY'),
|
||||
model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
|
||||
temperature=temperature,
|
||||
callbacks=[self.llm_metrics_handler],
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
return self._llm_instances[cache_key]
|
||||
|
||||
@property
|
||||
def transcription_model(self) -> TrackedOpenAITranscription:
|
||||
"""Get the transcription model instance"""
|
||||
if self._transcription_model is None:
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
self._transcription_model = TrackedOpenAITranscription(
|
||||
api_key=api_key,
|
||||
model='whisper-1'
|
||||
)
|
||||
return self._transcription_model
|
||||
|
||||
# Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
|
||||
@property
|
||||
def transcription_client(self):
|
||||
raise DeprecationWarning("Use transcription_model instead")
|
||||
|
||||
def transcribe(self, *args, **kwargs):
|
||||
raise DeprecationWarning("Use transcription_model.transcribe() instead")
|
||||
|
||||
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get a template for the tenant's configured LLM
|
||||
|
||||
Args:
|
||||
template_name: Name of the template to retrieve
|
||||
version: Optional specific version to retrieve
|
||||
|
||||
Returns:
|
||||
The template content
|
||||
"""
|
||||
try:
|
||||
template = template_manager.get_template(
|
||||
self._variables['llm_full_model'],
|
||||
template_name,
|
||||
version
|
||||
)
|
||||
return template.content
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
|
||||
# Fall back to old template loading if template_manager fails
|
||||
if template_name in self._variables.get('templates', {}):
|
||||
return self._variables['templates'][template_name]
|
||||
raise
|
||||
|
||||
|
||||
class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
|
||||
handler_name = 'model_vars_cache' # Used to access handler instance from cache_manager
|
||||
|
||||
def __init__(self, region):
|
||||
super().__init__(region, 'model_variables')
|
||||
self.configure_keys('tenant_id')
|
||||
self.subscribe_to_model('Tenant', ['tenant_id'])
|
||||
|
||||
def to_cache_data(self, instance: ModelVariables) -> Dict[str, Any]:
|
||||
return {
|
||||
'tenant_id': instance.tenant_id,
|
||||
'variables': instance._variables,
|
||||
'last_updated': dt.now(tz=tz.utc).isoformat()
|
||||
}
|
||||
|
||||
def from_cache_data(self, data: Dict[str, Any], tenant_id: int, **kwargs) -> ModelVariables:
|
||||
instance = ModelVariables(tenant_id, data.get('variables'))
|
||||
return instance
|
||||
|
||||
def should_cache(self, value: Dict[str, Any]) -> bool:
|
||||
required_fields = {'tenant_id', 'variables'}
|
||||
return all(field in value for field in required_fields)
|
||||
|
||||
|
||||
# Register the handler with the cache manager
|
||||
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
|
||||
|
||||
|
||||
# Helper function to get cached model variables
|
||||
def get_model_variables(tenant_id: int) -> ModelVariables:
|
||||
return cache_manager.model_vars_cache.get(
|
||||
lambda tenant_id: ModelVariables(tenant_id), # function to create ModelVariables if required
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Written in a long format, without lambda
|
||||
# def get_model_variables(tenant_id: int) -> ModelVariables:
|
||||
# """
|
||||
# Get ModelVariables instance, either from cache or newly created
|
||||
#
|
||||
# Args:
|
||||
# tenant_id: The tenant's ID
|
||||
#
|
||||
# Returns:
|
||||
# ModelVariables: Instance with either cached or fresh data
|
||||
#
|
||||
# Raises:
|
||||
# TenantNotFoundError: If tenant doesn't exist
|
||||
# CacheStateError: If cached data is invalid
|
||||
# """
|
||||
#
|
||||
# def create_new_instance(tenant_id: int) -> ModelVariables:
|
||||
# """Creator function that's called when cache miss occurs"""
|
||||
# return ModelVariables(tenant_id) # This will initialize fresh variables
|
||||
#
|
||||
# return cache_manager.model_vars_cache.get(
|
||||
# create_new_instance, # Function to create new instance if needed
|
||||
# tenant_id=tenant_id # Parameters passed to both get() and create_new_instance
|
||||
# )
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gevent
|
||||
import time
|
||||
from flask import current_app
|
||||
@@ -28,3 +30,17 @@ def sync_folder(file_path):
|
||||
dir_fd = os.open(file_path, os.O_RDONLY)
|
||||
os.fsync(dir_fd)
|
||||
os.close(dir_fd)
|
||||
|
||||
|
||||
def get_project_root():
|
||||
"""Get the root directory of the project."""
|
||||
# Use the module that's actually running (not this file)
|
||||
module = sys.modules['__main__']
|
||||
if hasattr(module, '__file__'):
|
||||
# Get the path to the main module
|
||||
main_path = os.path.abspath(module.__file__)
|
||||
# Get the root directory (where the main module is located)
|
||||
return os.path.dirname(main_path)
|
||||
else:
|
||||
# Fallback: use current working directory
|
||||
return os.getcwd()
|
||||
|
||||
@@ -4,7 +4,6 @@ from common.models.user import Tenant
|
||||
|
||||
# Definition of Trigger Handlers
|
||||
def set_tenant_session_data(sender, user, **kwargs):
|
||||
current_app.logger.debug(f"Setting tenant session data for user {user.id}")
|
||||
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
|
||||
session['tenant'] = tenant.to_dict()
|
||||
session['default_language'] = tenant.default_language
|
||||
|
||||
@@ -11,7 +11,7 @@ def confirm_token(token, expiration=3600):
|
||||
try:
|
||||
email = serializer.loads(token, salt=current_app.config['SECURITY_PASSWORD_SALT'], max_age=expiration)
|
||||
except Exception as e:
|
||||
current_app.logger.debug(f'Error confirming token: {e}')
|
||||
current_app.logger.error(f'Error confirming token: {e}')
|
||||
raise
|
||||
return email
|
||||
|
||||
@@ -35,14 +35,11 @@ def generate_confirmation_token(email):
|
||||
|
||||
|
||||
def send_confirmation_email(user):
|
||||
current_app.logger.debug(f'Sending confirmation email to {user.email}')
|
||||
|
||||
if not test_smtp_connection():
|
||||
raise Exception("Failed to connect to SMTP server")
|
||||
|
||||
token = generate_confirmation_token(user.email)
|
||||
confirm_url = prefixed_url_for('security_bp.confirm_email', token=token, _external=True)
|
||||
current_app.logger.debug(f'Confirmation URL: {confirm_url}')
|
||||
|
||||
html = render_template('email/activate.html', confirm_url=confirm_url)
|
||||
subject = "Please confirm your email"
|
||||
@@ -56,10 +53,8 @@ def send_confirmation_email(user):
|
||||
|
||||
|
||||
def send_reset_email(user):
|
||||
current_app.logger.debug(f'Sending reset email to {user.email}')
|
||||
token = generate_reset_token(user.email)
|
||||
reset_url = prefixed_url_for('security_bp.reset_password', token=token, _external=True)
|
||||
current_app.logger.debug(f'Reset URL: {reset_url}')
|
||||
|
||||
html = render_template('email/reset_password.html', reset_url=reset_url)
|
||||
subject = "Reset Your Password"
|
||||
|
||||
112
common/utils/string_list_converter.py
Normal file
112
common/utils/string_list_converter.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import List, Union
|
||||
import re
|
||||
|
||||
|
||||
class StringListConverter:
|
||||
"""Utility class for converting between comma-separated strings and lists"""
|
||||
|
||||
@staticmethod
|
||||
def string_to_list(input_string: Union[str, None], allow_empty: bool = True) -> List[str]:
|
||||
"""
|
||||
Convert a comma-separated string to a list of strings.
|
||||
|
||||
Args:
|
||||
input_string: Comma-separated string to convert
|
||||
allow_empty: If True, returns empty list for None/empty input
|
||||
If False, raises ValueError for None/empty input
|
||||
|
||||
Returns:
|
||||
List of stripped strings
|
||||
|
||||
Raises:
|
||||
ValueError: If input is None/empty and allow_empty is False
|
||||
"""
|
||||
if not input_string:
|
||||
if allow_empty:
|
||||
return []
|
||||
raise ValueError("Input string cannot be None or empty")
|
||||
|
||||
return [item.strip() for item in input_string.split(',') if item.strip()]
|
||||
|
||||
@staticmethod
|
||||
def list_to_string(input_list: Union[List[str], None], allow_empty: bool = True) -> str:
|
||||
"""
|
||||
Convert a list of strings to a comma-separated string.
|
||||
|
||||
Args:
|
||||
input_list: List of strings to convert
|
||||
allow_empty: If True, returns empty string for None/empty input
|
||||
If False, raises ValueError for None/empty input
|
||||
|
||||
Returns:
|
||||
Comma-separated string
|
||||
|
||||
Raises:
|
||||
ValueError: If input is None/empty and allow_empty is False
|
||||
"""
|
||||
if not input_list:
|
||||
if allow_empty:
|
||||
return ''
|
||||
raise ValueError("Input list cannot be None or empty")
|
||||
|
||||
return ', '.join(str(item).strip() for item in input_list)
|
||||
|
||||
@staticmethod
|
||||
def validate_format(input_string: str,
|
||||
allowed_chars: str = r'a-zA-Z0-9_\-',
|
||||
min_length: int = 1,
|
||||
max_length: int = 50) -> bool:
|
||||
"""
|
||||
Validate the format of items in a comma-separated string.
|
||||
|
||||
Args:
|
||||
input_string: String to validate
|
||||
allowed_chars: String of allowed characters (for regex pattern)
|
||||
min_length: Minimum length for each item
|
||||
max_length: Maximum length for each item
|
||||
|
||||
Returns:
|
||||
bool: True if format is valid, False otherwise
|
||||
"""
|
||||
if not input_string:
|
||||
return False
|
||||
|
||||
# Create regex pattern for individual items
|
||||
pattern = f'^[{allowed_chars}]{{{min_length},{max_length}}}$'
|
||||
|
||||
try:
|
||||
# Convert to list and check each item
|
||||
items = StringListConverter.string_to_list(input_string)
|
||||
return all(bool(re.match(pattern, item)) for item in items)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def validate_and_convert(input_string: str,
|
||||
allowed_chars: str = r'a-zA-Z0-9_\-',
|
||||
min_length: int = 1,
|
||||
max_length: int = 50) -> List[str]:
|
||||
"""
|
||||
Validate and convert a comma-separated string to a list.
|
||||
|
||||
Args:
|
||||
input_string: String to validate and convert
|
||||
allowed_chars: String of allowed characters (for regex pattern)
|
||||
min_length: Minimum length for each item
|
||||
max_length: Maximum length for each item
|
||||
|
||||
Returns:
|
||||
List of validated and converted strings
|
||||
|
||||
Raises:
|
||||
ValueError: If input string format is invalid
|
||||
"""
|
||||
if not StringListConverter.validate_format(
|
||||
input_string, allowed_chars, min_length, max_length
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid format. Items must be {min_length}-{max_length} characters "
|
||||
f"long and contain only these characters: {allowed_chars}"
|
||||
)
|
||||
|
||||
return StringListConverter.string_to_list(input_string)
|
||||
@@ -44,7 +44,7 @@ def form_validation_failed(request, form):
|
||||
for fieldName, errorMessages in form.errors.items():
|
||||
for err in errorMessages:
|
||||
flash(f"Error in {fieldName}: {err}", 'danger')
|
||||
current_app.logger.debug(f"Error in {fieldName}: {err}")
|
||||
current_app.logger.error(f"Error in {fieldName}: {err}")
|
||||
|
||||
|
||||
def form_to_dict(form):
|
||||
|
||||
Reference in New Issue
Block a user