Extra commit for files in 'common'

- Add functionality to add a default dictionary for configuration fields
- Correct entitlement processing
- Remove get_template functionality from ModelVariables, define it directly with LLM model definition in configuration file.
This commit is contained in:
Josako
2025-05-19 14:12:38 +02:00
parent 28aea85b10
commit d2bb51a4a8
6 changed files with 128 additions and 290 deletions

View File

@@ -10,7 +10,6 @@ from flask_wtf import CSRFProtect
from flask_restx import Api from flask_restx import Api
from prometheus_flask_exporter import PrometheusMetrics from prometheus_flask_exporter import PrometheusMetrics
from .langchain.templates.template_manager import TemplateManager
from .utils.cache.eveai_cache_manager import EveAICacheManager from .utils.cache.eveai_cache_manager import EveAICacheManager
from .utils.simple_encryption import SimpleEncryption from .utils.simple_encryption import SimpleEncryption
from .utils.minio_utils import MinioClient from .utils.minio_utils import MinioClient
@@ -30,6 +29,5 @@ api_rest = Api()
simple_encryption = SimpleEncryption() simple_encryption = SimpleEncryption()
minio_client = MinioClient() minio_client = MinioClient()
metrics = PrometheusMetrics.for_app_factory() metrics = PrometheusMetrics.for_app_factory()
template_manager = TemplateManager()
cache_manager = EveAICacheManager() cache_manager = EveAICacheManager()

View File

@@ -1,153 +0,0 @@
import os
import yaml
from typing import Dict, Optional, Any
from packaging import version
from dataclasses import dataclass
from flask import current_app, Flask
from common.utils.os_utils import get_project_root
@dataclass
class PromptTemplate:
"""Represents a versioned prompt template"""
content: str
version: str
metadata: Dict[str, Any]
class TemplateManager:
"""Manages versioned prompt templates"""
def __init__(self):
self.templates_dir = None
self._templates = None
self.app = None
def init_app(self, app: Flask) -> None:
# Initialize template manager
base_dir = "/app"
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
self.app = app
self._templates = self._load_templates()
# Log available templates for each supported model
for llm in app.config['SUPPORTED_LLMS']:
try:
available_templates = self.list_templates(llm)
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
except ValueError:
app.logger.warning(f"No templates found for {llm}")
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
"""
Load all template versions from the templates directory.
Structure: {provider.model -> {template_name -> {version -> template}}}
Directory structure:
prompts/
├── provider/
│ └── model/
│ └── template_name/
│ └── version.yaml
"""
templates = {}
# Iterate through providers (anthropic, openai)
for provider in os.listdir(self.templates_dir):
provider_path = os.path.join(self.templates_dir, provider)
if not os.path.isdir(provider_path):
continue
# Iterate through models (claude-3, gpt-4o)
for model in os.listdir(provider_path):
model_path = os.path.join(provider_path, model)
if not os.path.isdir(model_path):
continue
provider_model = f"{provider}.{model}"
templates[provider_model] = {}
# Iterate through template types (rag, summary, etc.)
for template_name in os.listdir(model_path):
template_path = os.path.join(model_path, template_name)
if not os.path.isdir(template_path):
continue
template_versions = {}
# Load all version files for this template
for version_file in os.listdir(template_path):
if not version_file.endswith('.yaml'):
continue
version_str = version_file[:-5] # Remove .yaml
if not self._is_valid_version(version_str):
current_app.logger.warning(
f"Invalid version format for {template_name}: {version_str}")
continue
try:
with open(os.path.join(template_path, version_file)) as f:
template_data = yaml.safe_load(f)
# Verify required fields
if not template_data.get('content'):
raise ValueError("Template content is required")
template_versions[version_str] = PromptTemplate(
content=template_data['content'],
version=version_str,
metadata=template_data.get('metadata', {})
)
except Exception as e:
current_app.logger.error(
f"Error loading template {template_name} version {version_str}: {e}")
continue
if template_versions:
templates[provider_model][template_name] = template_versions
return templates
def _is_valid_version(self, version_str: str) -> bool:
"""Validate semantic versioning string"""
try:
version.parse(version_str)
return True
except version.InvalidVersion:
return False
def get_template(self,
provider_model: str,
template_name: str,
template_version: Optional[str] = None) -> PromptTemplate:
"""
Get a specific template version. If version not specified,
returns the latest version.
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
if template_name not in self._templates[provider_model]:
raise ValueError(f"Unknown template: {template_name}")
versions = self._templates[provider_model][template_name]
if template_version:
if template_version not in versions:
raise ValueError(f"Template version {template_version} not found")
return versions[template_version]
# Return latest version
latest = max(versions.keys(), key=version.parse)
return versions[latest]
def list_templates(self, provider_model: str) -> Dict[str, list]:
"""
List all available templates and their versions for a provider.model
Returns: {template_name: [version1, version2, ...]}
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
return {
template_name: sorted(versions.keys(), key=version.parse)
for template_name, versions in self._templates[provider_model].items()
}

View File

@@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import text
from common.extensions import db from common.extensions import db
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
from enum import Enum from enum import Enum
from sqlalchemy import event
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
@@ -50,6 +51,7 @@ class License(db.Model):
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False) tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom' tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom'
start_date = db.Column(db.Date, nullable=False) start_date = db.Column(db.Date, nullable=False)
end_date = db.Column(db.Date, nullable=True)
nr_of_periods = db.Column(db.Integer, nullable=False) nr_of_periods = db.Column(db.Integer, nullable=False)
currency = db.Column(db.String(20), nullable=False) currency = db.Column(db.String(20), nullable=False)
yearly_payment = db.Column(db.Boolean, nullable=False, default=False) yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
@@ -80,55 +82,26 @@ class License(db.Model):
periods = db.relationship('LicensePeriod', back_populates='license', periods = db.relationship('LicensePeriod', back_populates='license',
order_by='LicensePeriod.period_number', order_by='LicensePeriod.period_number',
cascade='all, delete-orphan') cascade='all, delete-orphan')
@hybrid_property
def end_date(self):
"""
Berekent de einddatum van de licentie op basis van start_date en nr_of_periods.
Elke periode is 1 maand, dus einddatum = startdatum + nr_of_periods maanden - 1 dag
"""
if self.start_date and self.nr_of_periods:
return self.start_date + relativedelta(months=self.nr_of_periods) - relativedelta(days=1)
return None
@end_date.expression def calculate_end_date(start_date, nr_of_periods):
def end_date(cls): """Utility functie om einddatum te berekenen"""
""" if start_date and nr_of_periods:
SQL expressie versie van de end_date property voor gebruik in queries return start_date + relativedelta(months=nr_of_periods) - relativedelta(days=1)
""" return None
return db.func.date_add(
db.func.date_add(
cls.start_date,
db.text(f'INTERVAL cls.nr_of_periods MONTH')
),
db.text('INTERVAL -1 DAY')
)
def update_configuration(self, **changes): # Luister naar start_date wijzigingen
""" @event.listens_for(License.start_date, 'set')
Update license configuration def set_start_date(target, value, oldvalue, initiator):
These changes will only apply to future periods, not existing ones """Bijwerken van end_date wanneer start_date wordt aangepast"""
if value and target.nr_of_periods:
Args: target.end_date = calculate_end_date(value, target.nr_of_periods)
**changes: Dictionary of changes to apply to the license
# Luister naar nr_of_periods wijzigingen
Returns: @event.listens_for(License.nr_of_periods, 'set')
None def set_nr_of_periods(target, value, oldvalue, initiator):
""" """Bijwerken van end_date wanneer nr_of_periods wordt aangepast"""
allowed_fields = [ if value and target.start_date:
'tier_id', 'currency', 'basic_fee', 'max_storage_mb', target.end_date = calculate_end_date(target.start_date, value)
'additional_storage_price', 'additional_storage_bucket',
'included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket',
'included_interaction_tokens', 'additional_interaction_token_price',
'additional_interaction_bucket', 'overage_embedding', 'overage_interaction',
'additional_storage_allowed', 'additional_embedding_allowed',
'additional_interaction_allowed'
]
# Apply only allowed changes
for key, value in changes.items():
if key in allowed_fields:
setattr(self, key, value)
class LicenseTier(db.Model): class LicenseTier(db.Model):
@@ -209,22 +182,22 @@ class LicensePeriod(db.Model):
period_end = db.Column(db.Date, nullable=False) period_end = db.Column(db.Date, nullable=False)
# License configuration snapshot - copied from license when period is created # License configuration snapshot - copied from license when period is created
currency = db.Column(db.String(20), nullable=False) currency = db.Column(db.String(20), nullable=True)
basic_fee = db.Column(db.Float, nullable=False) basic_fee = db.Column(db.Float, nullable=True)
max_storage_mb = db.Column(db.Integer, nullable=False) max_storage_mb = db.Column(db.Integer, nullable=True)
additional_storage_price = db.Column(db.Float, nullable=False) additional_storage_price = db.Column(db.Float, nullable=True)
additional_storage_bucket = db.Column(db.Integer, nullable=False) additional_storage_bucket = db.Column(db.Integer, nullable=True)
included_embedding_mb = db.Column(db.Integer, nullable=False) included_embedding_mb = db.Column(db.Integer, nullable=True)
additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False) additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True)
additional_embedding_bucket = db.Column(db.Integer, nullable=False) additional_embedding_bucket = db.Column(db.Integer, nullable=True)
included_interaction_tokens = db.Column(db.Integer, nullable=False) included_interaction_tokens = db.Column(db.Integer, nullable=True)
additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False) additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True)
additional_interaction_bucket = db.Column(db.Integer, nullable=False) additional_interaction_bucket = db.Column(db.Integer, nullable=True)
# Allowance flags - can be changed from False to True within a period # Allowance flags - can be changed from False to True within a period
additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False) additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False)
additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False) additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False)
additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False) additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False)
# Status tracking # Status tracking
status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING) status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)

View File

@@ -24,45 +24,65 @@ class LicensePeriodServices:
Raises: Raises:
EveAIException: and derived classes EveAIException: and derived classes
""" """
current_date = dt.now(tz.utc).date() try:
license_period = (db.session.query(LicensePeriod) current_app.logger.debug(f"Finding current license period for tenant {tenant_id}")
.filter_by(tenant_id=tenant_id) current_date = dt.now(tz.utc).date()
.filter(and_(LicensePeriod.period_start_date <= current_date, license_period = (db.session.query(LicensePeriod)
LicensePeriod.period_end_date >= current_date)) .filter_by(tenant_id=tenant_id)
.first()) .filter(and_(LicensePeriod.period_start <= current_date,
if not license_period: LicensePeriod.period_end >= current_date))
license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id) .first())
if license_period: current_app.logger.debug(f"End searching for license period for tenant {tenant_id} ")
match license_period.status: if not license_period:
case PeriodStatus.UPCOMING: current_app.logger.debug(f"No license period found for tenant {tenant_id} on date {current_date}")
LicensePeriodServices._complete_last_license_period() license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id)
LicensePeriodServices._activate_license_period(license_period) current_app.logger.debug(f"Created license period {license_period.id} for tenant {tenant_id}")
if not license_period.license_usage: if license_period:
new_license_usage = LicenseUsage() current_app.logger.debug(f"Found license period {license_period.id} for tenant {tenant_id} "
new_license_usage.license_period = license_period f"with status {license_period.status}")
try: match license_period.status:
db.session.add(new_license_usage) case PeriodStatus.UPCOMING:
db.session.commit() current_app.logger.debug(f"In upcoming state")
except SQLAlchemyError as e: LicensePeriodServices._complete_last_license_period(tenant_id=tenant_id)
db.session.rollback() current_app.logger.debug(f"Completed last license period for tenant {tenant_id}")
current_app.logger.error( LicensePeriodServices._activate_license_period(license_period=license_period)
f"Error creating new license usage for license period {license_period.id}: {str(e)}") current_app.logger.debug(f"Activated license period {license_period.id} for tenant {tenant_id}")
raise e if not license_period.license_usage:
if license_period.status == PeriodStatus.ACTIVE: new_license_usage = LicenseUsage(
tenant_id=tenant_id,
)
new_license_usage.license_period = license_period
try:
db.session.add(new_license_usage)
db.session.commit()
except SQLAlchemyError as e:
db.session.rollback()
current_app.logger.error(
f"Error creating new license usage for license period "
f"{license_period.id}: {str(e)}")
raise e
if license_period.status == PeriodStatus.ACTIVE:
return license_period
else:
# Status is PENDING, so no prepaid payment received. There is no license period we can use.
# We allow for a delay of 5 days before raising an exception.
current_date = dt.now(tz.utc).date()
delta = abs(current_date - license_period.period_start_date)
if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)):
raise EveAIPendingLicensePeriod()
case PeriodStatus.ACTIVE:
return license_period return license_period
else: case PeriodStatus.PENDING:
# Status is PENDING, so no prepaid payment received. There is no license period we can use. return license_period
# We allow for a delay of 5 days before raising an exception. else:
current_date = dt.now(tz.utc).date() raise EveAILicensePeriodsExceeded(license_id=None)
delta = abs(current_date - license_period.period_start_date) except SQLAlchemyError as e:
if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)): db.session.rollback()
raise EveAIPendingLicensePeriod() current_app.logger.error(f"Error finding current license period for tenant {tenant_id}: {str(e)}")
case PeriodStatus.ACTIVE: raise e
return license_period except Exception as e:
case PeriodStatus.PENDING: raise e
return license_period
else:
raise EveAILicensePeriodsExceeded(license_id=None)
@staticmethod @staticmethod
def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod: def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod:
@@ -87,13 +107,17 @@ class LicensePeriodServices:
if not the_license: if not the_license:
current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}") current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}")
raise EveAINoActiveLicense(tenant_id=tenant_id) raise EveAINoActiveLicense(tenant_id=tenant_id)
else:
current_app.logger.debug(f"Found active license {the_license.id} for tenant {tenant_id} "
f"on date {current_date}")
next_period_number = 1 next_period_number = 1
if the_license.periods: if the_license.periods:
# If there are existing periods, get the next sequential number # If there are existing periods, get the next sequential number
next_period_number = max(p.period_number for p in the_license.periods) + 1 next_period_number = max(p.period_number for p in the_license.periods) + 1
current_app.logger.debug(f"Next period number for tenant {tenant_id} is {next_period_number}")
if next_period_number > the_license.max_periods: if next_period_number > the_license.nr_of_periods:
raise EveAILicensePeriodsExceeded(license_id=the_license.id) raise EveAILicensePeriodsExceeded(license_id=the_license.id)
new_license_period = LicensePeriod( new_license_period = LicensePeriod(
@@ -103,18 +127,16 @@ class LicensePeriodServices:
period_start=the_license.start_date + relativedelta(months=next_period_number-1), period_start=the_license.start_date + relativedelta(months=next_period_number-1),
period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1), period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1),
status=PeriodStatus.UPCOMING, status=PeriodStatus.UPCOMING,
upcoming_at=dt.now(tz.utc),
) )
set_logging_information(new_license_period, dt.now(tz.utc)) set_logging_information(new_license_period, dt.now(tz.utc))
new_license_usage = LicenseUsage(
license_period=new_license_period,
tenant_id=tenant_id,
)
set_logging_information(new_license_usage, dt.now(tz.utc))
try: try:
current_app.logger.debug(f"Creating next license period for tenant {tenant_id} ")
db.session.add(new_license_period) db.session.add(new_license_period)
db.session.add(new_license_usage)
db.session.commit() db.session.commit()
current_app.logger.info(f"Created next license period for tenant {tenant_id} "
f"with id {new_license_period.id}")
return new_license_period return new_license_period
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.session.rollback() db.session.rollback()
@@ -136,15 +158,17 @@ class LicensePeriodServices:
Raises: Raises:
ValueError: If neither license_period_id nor license_period is provided ValueError: If neither license_period_id nor license_period is provided
""" """
current_app.logger.debug(f"Activating license period")
if license_period is None and license_period_id is None: if license_period is None and license_period_id is None:
raise ValueError("Either license_period_id or license_period must be provided") raise ValueError("Either license_period_id or license_period must be provided")
# Get a license period object if only ID was provided # Get a license period object if only ID was provided
if license_period is None: if license_period is None:
current_app.logger.debug(f"Getting license period {license_period_id} to activate")
license_period = LicensePeriod.query.get_or_404(license_period_id) license_period = LicensePeriod.query.get_or_404(license_period_id)
if license_period.upcoming_at is not None: if license_period.pending_at is not None:
license_period.pending_at.upcoming_at = dt.now(tz.utc) license_period.pending_at = dt.now(tz.utc)
license_period.status = PeriodStatus.PENDING license_period.status = PeriodStatus.PENDING
if license_period.prepaid_payment: if license_period.prepaid_payment:
# There is a payment received for the given period # There is a payment received for the given period
@@ -173,7 +197,7 @@ class LicensePeriodServices:
if not license_period.license_usage: if not license_period.license_usage:
license_period.license_usage = LicenseUsage( license_period.license_usage = LicenseUsage(
tenant_id=license_period.tenant_id, tenant_id=license_period.tenant_id,
license_period=license_period, license_period_id=license_period.id,
) )
license_period.license_usage.recalculate_storage() license_period.license_usage.recalculate_storage()

View File

@@ -14,7 +14,7 @@ from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbedd
from common.langchain.tracked_transcription import TrackedOpenAITranscription from common.langchain.tracked_transcription import TrackedOpenAITranscription
from common.models.user import Tenant from common.models.user import Tenant
from config.model_config import MODEL_CONFIG from config.model_config import MODEL_CONFIG
from common.extensions import template_manager from common.extensions import cache_manager
from common.models.document import EmbeddingMistral from common.models.document import EmbeddingMistral
from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
from crewai import LLM from crewai import LLM
@@ -139,6 +139,19 @@ def process_pdf():
full_model_name = 'mistral-ocr-latest' full_model_name = 'mistral-ocr-latest'
def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
"""
Get a prompt template
"""
prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
if "llm_model" in prompt:
llm = get_embedding_llm(full_model_name=prompt["llm_model"])
else:
llm = get_embedding_llm()
return prompt["content"], llm
class ModelVariables: class ModelVariables:
"""Manages model-related variables and configurations""" """Manages model-related variables and configurations"""
@@ -261,31 +274,6 @@ class ModelVariables:
def transcribe(self, *args, **kwargs): def transcribe(self, *args, **kwargs):
raise DeprecationWarning("Use transcription_model.transcribe() instead") raise DeprecationWarning("Use transcription_model.transcribe() instead")
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
"""
Get a template for the tenant's configured LLM
Args:
template_name: Name of the template to retrieve
version: Optional specific version to retrieve
Returns:
The template content
"""
try:
template = template_manager.get_template(
self._variables['llm_full_model'],
template_name,
version
)
return template.content
except Exception as e:
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
# Fall back to old template loading if template_manager fails
if template_name in self._variables.get('templates', {}):
return self._variables['templates'][template_name]
raise
# Helper function to get cached model variables # Helper function to get cached model variables
def get_model_variables(tenant_id: int) -> ModelVariables: def get_model_variables(tenant_id: int) -> ModelVariables:

View File

@@ -2,6 +2,7 @@
import pytz import pytz
from datetime import datetime from datetime import datetime
from common.utils.nginx_utils import prefixed_url_for as puf
def to_local_time(utc_dt, timezone_str): def to_local_time(utc_dt, timezone_str):
@@ -42,6 +43,10 @@ def status_color(status_name):
return colors.get(status_name, 'secondary') return colors.get(status_name, 'secondary')
def prefixed_url_for(endpoint):
return puf(endpoint)
def register_filters(app): def register_filters(app):
""" """
Registers custom filters with the Flask app. Registers custom filters with the Flask app.
@@ -49,3 +54,6 @@ def register_filters(app):
app.jinja_env.filters['to_local_time'] = to_local_time app.jinja_env.filters['to_local_time'] = to_local_time
app.jinja_env.filters['time_difference'] = time_difference app.jinja_env.filters['time_difference'] = time_difference
app.jinja_env.filters['status_color'] = status_color app.jinja_env.filters['status_color'] = status_color
app.jinja_env.filters['prefixed_url_for'] = prefixed_url_for
app.jinja_env.globals['prefixed_url_for'] = prefixed_url_for