Extra commit for files in 'common'

- Add functionality to add a default dictionary for configuration fields
- Correct entitlement processing
- Remove get_template functionality from ModelVariables, define it directly with LLM model definition in configuration file.
This commit is contained in:
Josako
2025-05-19 14:12:38 +02:00
parent 28aea85b10
commit d2bb51a4a8
6 changed files with 128 additions and 290 deletions

View File

@@ -10,7 +10,6 @@ from flask_wtf import CSRFProtect
from flask_restx import Api
from prometheus_flask_exporter import PrometheusMetrics
from .langchain.templates.template_manager import TemplateManager
from .utils.cache.eveai_cache_manager import EveAICacheManager
from .utils.simple_encryption import SimpleEncryption
from .utils.minio_utils import MinioClient
@@ -30,6 +29,5 @@ api_rest = Api()
simple_encryption = SimpleEncryption()
minio_client = MinioClient()
metrics = PrometheusMetrics.for_app_factory()
template_manager = TemplateManager()
cache_manager = EveAICacheManager()

View File

@@ -1,153 +0,0 @@
import os
import yaml
from typing import Dict, Optional, Any
from packaging import version
from dataclasses import dataclass
from flask import current_app, Flask
from common.utils.os_utils import get_project_root
@dataclass
class PromptTemplate:
"""Represents a versioned prompt template"""
content: str
version: str
metadata: Dict[str, Any]
class TemplateManager:
"""Manages versioned prompt templates"""
def __init__(self):
self.templates_dir = None
self._templates = None
self.app = None
def init_app(self, app: Flask) -> None:
# Initialize template manager
base_dir = "/app"
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
self.app = app
self._templates = self._load_templates()
# Log available templates for each supported model
for llm in app.config['SUPPORTED_LLMS']:
try:
available_templates = self.list_templates(llm)
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
except ValueError:
app.logger.warning(f"No templates found for {llm}")
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
"""
Load all template versions from the templates directory.
Structure: {provider.model -> {template_name -> {version -> template}}}
Directory structure:
prompts/
├── provider/
│ └── model/
│ └── template_name/
│ └── version.yaml
"""
templates = {}
# Iterate through providers (anthropic, openai)
for provider in os.listdir(self.templates_dir):
provider_path = os.path.join(self.templates_dir, provider)
if not os.path.isdir(provider_path):
continue
# Iterate through models (claude-3, gpt-4o)
for model in os.listdir(provider_path):
model_path = os.path.join(provider_path, model)
if not os.path.isdir(model_path):
continue
provider_model = f"{provider}.{model}"
templates[provider_model] = {}
# Iterate through template types (rag, summary, etc.)
for template_name in os.listdir(model_path):
template_path = os.path.join(model_path, template_name)
if not os.path.isdir(template_path):
continue
template_versions = {}
# Load all version files for this template
for version_file in os.listdir(template_path):
if not version_file.endswith('.yaml'):
continue
version_str = version_file[:-5] # Remove .yaml
if not self._is_valid_version(version_str):
current_app.logger.warning(
f"Invalid version format for {template_name}: {version_str}")
continue
try:
with open(os.path.join(template_path, version_file)) as f:
template_data = yaml.safe_load(f)
# Verify required fields
if not template_data.get('content'):
raise ValueError("Template content is required")
template_versions[version_str] = PromptTemplate(
content=template_data['content'],
version=version_str,
metadata=template_data.get('metadata', {})
)
except Exception as e:
current_app.logger.error(
f"Error loading template {template_name} version {version_str}: {e}")
continue
if template_versions:
templates[provider_model][template_name] = template_versions
return templates
def _is_valid_version(self, version_str: str) -> bool:
"""Validate semantic versioning string"""
try:
version.parse(version_str)
return True
except version.InvalidVersion:
return False
def get_template(self,
provider_model: str,
template_name: str,
template_version: Optional[str] = None) -> PromptTemplate:
"""
Get a specific template version. If version not specified,
returns the latest version.
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
if template_name not in self._templates[provider_model]:
raise ValueError(f"Unknown template: {template_name}")
versions = self._templates[provider_model][template_name]
if template_version:
if template_version not in versions:
raise ValueError(f"Template version {template_version} not found")
return versions[template_version]
# Return latest version
latest = max(versions.keys(), key=version.parse)
return versions[latest]
def list_templates(self, provider_model: str) -> Dict[str, list]:
"""
List all available templates and their versions for a provider.model
Returns: {template_name: [version1, version2, ...]}
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
return {
template_name: sorted(versions.keys(), key=version.parse)
for template_name, versions in self._templates[provider_model].items()
}

View File

@@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import text
from common.extensions import db
from datetime import datetime as dt, timezone as tz
from enum import Enum
from sqlalchemy import event
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.hybrid import hybrid_property
from dateutil.relativedelta import relativedelta
@@ -50,6 +51,7 @@ class License(db.Model):
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom'
start_date = db.Column(db.Date, nullable=False)
end_date = db.Column(db.Date, nullable=True)
nr_of_periods = db.Column(db.Integer, nullable=False)
currency = db.Column(db.String(20), nullable=False)
yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
@@ -81,54 +83,25 @@ class License(db.Model):
order_by='LicensePeriod.period_number',
cascade='all, delete-orphan')
@hybrid_property
def end_date(self):
"""
Berekent de einddatum van de licentie op basis van start_date en nr_of_periods.
Elke periode is 1 maand, dus einddatum = startdatum + nr_of_periods maanden - 1 dag
"""
if self.start_date and self.nr_of_periods:
return self.start_date + relativedelta(months=self.nr_of_periods) - relativedelta(days=1)
def calculate_end_date(start_date, nr_of_periods):
"""Utility functie om einddatum te berekenen"""
if start_date and nr_of_periods:
return start_date + relativedelta(months=nr_of_periods) - relativedelta(days=1)
return None
@end_date.expression
def end_date(cls):
"""
SQL expressie versie van de end_date property voor gebruik in queries
"""
return db.func.date_add(
db.func.date_add(
cls.start_date,
db.text(f'INTERVAL cls.nr_of_periods MONTH')
),
db.text('INTERVAL -1 DAY')
)
# Luister naar start_date wijzigingen
@event.listens_for(License.start_date, 'set')
def set_start_date(target, value, oldvalue, initiator):
"""Bijwerken van end_date wanneer start_date wordt aangepast"""
if value and target.nr_of_periods:
target.end_date = calculate_end_date(value, target.nr_of_periods)
def update_configuration(self, **changes):
"""
Update license configuration
These changes will only apply to future periods, not existing ones
Args:
**changes: Dictionary of changes to apply to the license
Returns:
None
"""
allowed_fields = [
'tier_id', 'currency', 'basic_fee', 'max_storage_mb',
'additional_storage_price', 'additional_storage_bucket',
'included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket',
'included_interaction_tokens', 'additional_interaction_token_price',
'additional_interaction_bucket', 'overage_embedding', 'overage_interaction',
'additional_storage_allowed', 'additional_embedding_allowed',
'additional_interaction_allowed'
]
# Apply only allowed changes
for key, value in changes.items():
if key in allowed_fields:
setattr(self, key, value)
# Luister naar nr_of_periods wijzigingen
@event.listens_for(License.nr_of_periods, 'set')
def set_nr_of_periods(target, value, oldvalue, initiator):
"""Bijwerken van end_date wanneer nr_of_periods wordt aangepast"""
if value and target.start_date:
target.end_date = calculate_end_date(target.start_date, value)
class LicenseTier(db.Model):
@@ -209,22 +182,22 @@ class LicensePeriod(db.Model):
period_end = db.Column(db.Date, nullable=False)
# License configuration snapshot - copied from license when period is created
currency = db.Column(db.String(20), nullable=False)
basic_fee = db.Column(db.Float, nullable=False)
max_storage_mb = db.Column(db.Integer, nullable=False)
additional_storage_price = db.Column(db.Float, nullable=False)
additional_storage_bucket = db.Column(db.Integer, nullable=False)
included_embedding_mb = db.Column(db.Integer, nullable=False)
additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
additional_embedding_bucket = db.Column(db.Integer, nullable=False)
included_interaction_tokens = db.Column(db.Integer, nullable=False)
additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
additional_interaction_bucket = db.Column(db.Integer, nullable=False)
currency = db.Column(db.String(20), nullable=True)
basic_fee = db.Column(db.Float, nullable=True)
max_storage_mb = db.Column(db.Integer, nullable=True)
additional_storage_price = db.Column(db.Float, nullable=True)
additional_storage_bucket = db.Column(db.Integer, nullable=True)
included_embedding_mb = db.Column(db.Integer, nullable=True)
additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True)
additional_embedding_bucket = db.Column(db.Integer, nullable=True)
included_interaction_tokens = db.Column(db.Integer, nullable=True)
additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True)
additional_interaction_bucket = db.Column(db.Integer, nullable=True)
# Allowance flags - can be changed from False to True within a period
additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)
additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False)
additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False)
additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False)
# Status tracking
status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)

View File

@@ -24,29 +24,43 @@ class LicensePeriodServices:
Raises:
EveAIException: and derived classes
"""
try:
current_app.logger.debug(f"Finding current license period for tenant {tenant_id}")
current_date = dt.now(tz.utc).date()
license_period = (db.session.query(LicensePeriod)
.filter_by(tenant_id=tenant_id)
.filter(and_(LicensePeriod.period_start_date <= current_date,
LicensePeriod.period_end_date >= current_date))
.filter(and_(LicensePeriod.period_start <= current_date,
LicensePeriod.period_end >= current_date))
.first())
current_app.logger.debug(f"End searching for license period for tenant {tenant_id} ")
if not license_period:
current_app.logger.debug(f"No license period found for tenant {tenant_id} on date {current_date}")
license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id)
current_app.logger.debug(f"Created license period {license_period.id} for tenant {tenant_id}")
if license_period:
current_app.logger.debug(f"Found license period {license_period.id} for tenant {tenant_id} "
f"with status {license_period.status}")
match license_period.status:
case PeriodStatus.UPCOMING:
LicensePeriodServices._complete_last_license_period()
LicensePeriodServices._activate_license_period(license_period)
current_app.logger.debug(f"In upcoming state")
LicensePeriodServices._complete_last_license_period(tenant_id=tenant_id)
current_app.logger.debug(f"Completed last license period for tenant {tenant_id}")
LicensePeriodServices._activate_license_period(license_period=license_period)
current_app.logger.debug(f"Activated license period {license_period.id} for tenant {tenant_id}")
if not license_period.license_usage:
new_license_usage = LicenseUsage()
new_license_usage = LicenseUsage(
tenant_id=tenant_id,
)
new_license_usage.license_period = license_period
try:
db.session.add(new_license_usage)
db.session.commit()
except SQLAlchemyError as e:
db.session.rollback()
current_app.logger.error(
f"Error creating new license usage for license period {license_period.id}: {str(e)}")
f"Error creating new license usage for license period "
f"{license_period.id}: {str(e)}")
raise e
if license_period.status == PeriodStatus.ACTIVE:
return license_period
@@ -63,6 +77,12 @@ class LicensePeriodServices:
return license_period
else:
raise EveAILicensePeriodsExceeded(license_id=None)
except SQLAlchemyError as e:
db.session.rollback()
current_app.logger.error(f"Error finding current license period for tenant {tenant_id}: {str(e)}")
raise e
except Exception as e:
raise e
@staticmethod
def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod:
@@ -87,13 +107,17 @@ class LicensePeriodServices:
if not the_license:
current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}")
raise EveAINoActiveLicense(tenant_id=tenant_id)
else:
current_app.logger.debug(f"Found active license {the_license.id} for tenant {tenant_id} "
f"on date {current_date}")
next_period_number = 1
if the_license.periods:
# If there are existing periods, get the next sequential number
next_period_number = max(p.period_number for p in the_license.periods) + 1
current_app.logger.debug(f"Next period number for tenant {tenant_id} is {next_period_number}")
if next_period_number > the_license.max_periods:
if next_period_number > the_license.nr_of_periods:
raise EveAILicensePeriodsExceeded(license_id=the_license.id)
new_license_period = LicensePeriod(
@@ -103,18 +127,16 @@ class LicensePeriodServices:
period_start=the_license.start_date + relativedelta(months=next_period_number-1),
period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1),
status=PeriodStatus.UPCOMING,
upcoming_at=dt.now(tz.utc),
)
set_logging_information(new_license_period, dt.now(tz.utc))
new_license_usage = LicenseUsage(
license_period=new_license_period,
tenant_id=tenant_id,
)
set_logging_information(new_license_usage, dt.now(tz.utc))
try:
current_app.logger.debug(f"Creating next license period for tenant {tenant_id} ")
db.session.add(new_license_period)
db.session.add(new_license_usage)
db.session.commit()
current_app.logger.info(f"Created next license period for tenant {tenant_id} "
f"with id {new_license_period.id}")
return new_license_period
except SQLAlchemyError as e:
db.session.rollback()
@@ -136,15 +158,17 @@ class LicensePeriodServices:
Raises:
ValueError: If neither license_period_id nor license_period is provided
"""
current_app.logger.debug(f"Activating license period")
if license_period is None and license_period_id is None:
raise ValueError("Either license_period_id or license_period must be provided")
# Get a license period object if only ID was provided
if license_period is None:
current_app.logger.debug(f"Getting license period {license_period_id} to activate")
license_period = LicensePeriod.query.get_or_404(license_period_id)
if license_period.upcoming_at is not None:
license_period.pending_at.upcoming_at = dt.now(tz.utc)
if license_period.pending_at is not None:
license_period.pending_at = dt.now(tz.utc)
license_period.status = PeriodStatus.PENDING
if license_period.prepaid_payment:
# There is a payment received for the given period
@@ -173,7 +197,7 @@ class LicensePeriodServices:
if not license_period.license_usage:
license_period.license_usage = LicenseUsage(
tenant_id=license_period.tenant_id,
license_period=license_period,
license_period_id=license_period.id,
)
license_period.license_usage.recalculate_storage()

View File

@@ -14,7 +14,7 @@ from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbedd
from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant
from config.model_config import MODEL_CONFIG
from common.extensions import template_manager
from common.extensions import cache_manager
from common.models.document import EmbeddingMistral
from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
from crewai import LLM
@@ -139,6 +139,19 @@ def process_pdf():
full_model_name = 'mistral-ocr-latest'
def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
"""
Get a prompt template
"""
prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
if "llm_model" in prompt:
llm = get_embedding_llm(full_model_name=prompt["llm_model"])
else:
llm = get_embedding_llm()
return prompt["content"], llm
class ModelVariables:
"""Manages model-related variables and configurations"""
@@ -261,31 +274,6 @@ class ModelVariables:
def transcribe(self, *args, **kwargs):
raise DeprecationWarning("Use transcription_model.transcribe() instead")
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
"""
Get a template for the tenant's configured LLM
Args:
template_name: Name of the template to retrieve
version: Optional specific version to retrieve
Returns:
The template content
"""
try:
template = template_manager.get_template(
self._variables['llm_full_model'],
template_name,
version
)
return template.content
except Exception as e:
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
# Fall back to old template loading if template_manager fails
if template_name in self._variables.get('templates', {}):
return self._variables['templates'][template_name]
raise
# Helper function to get cached model variables
def get_model_variables(tenant_id: int) -> ModelVariables:

View File

@@ -2,6 +2,7 @@
import pytz
from datetime import datetime
from common.utils.nginx_utils import prefixed_url_for as puf
def to_local_time(utc_dt, timezone_str):
@@ -42,6 +43,10 @@ def status_color(status_name):
return colors.get(status_name, 'secondary')
def prefixed_url_for(endpoint):
return puf(endpoint)
def register_filters(app):
"""
Registers custom filters with the Flask app.
@@ -49,3 +54,6 @@ def register_filters(app):
app.jinja_env.filters['to_local_time'] = to_local_time
app.jinja_env.filters['time_difference'] = time_difference
app.jinja_env.filters['status_color'] = status_color
app.jinja_env.filters['prefixed_url_for'] = prefixed_url_for
app.jinja_env.globals['prefixed_url_for'] = prefixed_url_for