diff --git a/common/extensions.py b/common/extensions.py index ee43bfd..16941eb 100644 --- a/common/extensions.py +++ b/common/extensions.py @@ -10,7 +10,6 @@ from flask_wtf import CSRFProtect from flask_restx import Api from prometheus_flask_exporter import PrometheusMetrics -from .langchain.templates.template_manager import TemplateManager from .utils.cache.eveai_cache_manager import EveAICacheManager from .utils.simple_encryption import SimpleEncryption from .utils.minio_utils import MinioClient @@ -30,6 +29,5 @@ api_rest = Api() simple_encryption = SimpleEncryption() minio_client = MinioClient() metrics = PrometheusMetrics.for_app_factory() -template_manager = TemplateManager() cache_manager = EveAICacheManager() diff --git a/common/langchain/templates/template_manager.py b/common/langchain/templates/template_manager.py deleted file mode 100644 index 6b3c064..0000000 --- a/common/langchain/templates/template_manager.py +++ /dev/null @@ -1,153 +0,0 @@ -import os -import yaml -from typing import Dict, Optional, Any -from packaging import version -from dataclasses import dataclass -from flask import current_app, Flask - -from common.utils.os_utils import get_project_root - - -@dataclass -class PromptTemplate: - """Represents a versioned prompt template""" - content: str - version: str - metadata: Dict[str, Any] - - -class TemplateManager: - """Manages versioned prompt templates""" - - def __init__(self): - self.templates_dir = None - self._templates = None - self.app = None - - def init_app(self, app: Flask) -> None: - # Initialize template manager - base_dir = "/app" - self.templates_dir = os.path.join(base_dir, 'config', 'prompts') - self.app = app - self._templates = self._load_templates() - # Log available templates for each supported model - for llm in app.config['SUPPORTED_LLMS']: - try: - available_templates = self.list_templates(llm) - app.logger.info(f"Loaded templates for {llm}: {available_templates}") - except ValueError: - app.logger.warning(f"No templates found for {llm}") - - def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]: - """ - Load all template versions from the templates directory. - Structure: {provider.model -> {template_name -> {version -> template}}} - Directory structure: - prompts/ - ├── provider/ - │ └── model/ - │ └── template_name/ - │ └── version.yaml - """ - templates = {} - - # Iterate through providers (anthropic, openai) - for provider in os.listdir(self.templates_dir): - provider_path = os.path.join(self.templates_dir, provider) - if not os.path.isdir(provider_path): - continue - - # Iterate through models (claude-3, gpt-4o) - for model in os.listdir(provider_path): - model_path = os.path.join(provider_path, model) - if not os.path.isdir(model_path): - continue - - provider_model = f"{provider}.{model}" - templates[provider_model] = {} - - # Iterate through template types (rag, summary, etc.) - for template_name in os.listdir(model_path): - template_path = os.path.join(model_path, template_name) - if not os.path.isdir(template_path): - continue - - template_versions = {} - # Load all version files for this template - for version_file in os.listdir(template_path): - if not version_file.endswith('.yaml'): - continue - - version_str = version_file[:-5] # Remove .yaml - if not self._is_valid_version(version_str): - current_app.logger.warning( - f"Invalid version format for {template_name}: {version_str}") - continue - - try: - with open(os.path.join(template_path, version_file)) as f: - template_data = yaml.safe_load(f) - # Verify required fields - if not template_data.get('content'): - raise ValueError("Template content is required") - - template_versions[version_str] = PromptTemplate( - content=template_data['content'], - version=version_str, - metadata=template_data.get('metadata', {}) - ) - except Exception as e: - current_app.logger.error( - f"Error loading template {template_name} version {version_str}: {e}") - continue - - if template_versions: - templates[provider_model][template_name] = template_versions - - return templates - - def _is_valid_version(self, version_str: str) -> bool: - """Validate semantic versioning string""" - try: - version.parse(version_str) - return True - except version.InvalidVersion: - return False - - def get_template(self, - provider_model: str, - template_name: str, - template_version: Optional[str] = None) -> PromptTemplate: - """ - Get a specific template version. If version not specified, - returns the latest version. - """ - if provider_model not in self._templates: - raise ValueError(f"Unknown provider.model: {provider_model}") - - if template_name not in self._templates[provider_model]: - raise ValueError(f"Unknown template: {template_name}") - - versions = self._templates[provider_model][template_name] - - if template_version: - if template_version not in versions: - raise ValueError(f"Template version {template_version} not found") - return versions[template_version] - - # Return latest version - latest = max(versions.keys(), key=version.parse) - return versions[latest] - - def list_templates(self, provider_model: str) -> Dict[str, list]: - """ - List all available templates and their versions for a provider.model - Returns: {template_name: [version1, version2, ...]} - """ - if provider_model not in self._templates: - raise ValueError(f"Unknown provider.model: {provider_model}") - - return { - template_name: sorted(versions.keys(), key=version.parse) - for template_name, versions in self._templates[provider_model].items() - } diff --git a/common/models/entitlements.py b/common/models/entitlements.py index 24d6955..5f947ce 100644 --- a/common/models/entitlements.py +++ b/common/models/entitlements.py @@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import text from common.extensions import db from datetime import datetime as dt, timezone as tz from enum import Enum +from sqlalchemy import event from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.hybrid import hybrid_property from dateutil.relativedelta import relativedelta @@ -50,6 +51,7 @@ class License(db.Model): tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False) tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom' start_date = db.Column(db.Date, nullable=False) + end_date = db.Column(db.Date, nullable=True) nr_of_periods = db.Column(db.Integer, nullable=False) currency = db.Column(db.String(20), nullable=False) yearly_payment = db.Column(db.Boolean, nullable=False, default=False) @@ -80,55 +82,26 @@ class License(db.Model): periods = db.relationship('LicensePeriod', back_populates='license', order_by='LicensePeriod.period_number', cascade='all, delete-orphan') - - @hybrid_property - def end_date(self): - """ - Berekent de einddatum van de licentie op basis van start_date en nr_of_periods. - Elke periode is 1 maand, dus einddatum = startdatum + nr_of_periods maanden - 1 dag - """ - if self.start_date and self.nr_of_periods: - return self.start_date + relativedelta(months=self.nr_of_periods) - relativedelta(days=1) - return None - @end_date.expression - def end_date(cls): - """ - SQL expressie versie van de end_date property voor gebruik in queries - """ - return db.func.date_add( - db.func.date_add( - cls.start_date, - db.text(f'INTERVAL cls.nr_of_periods MONTH') - ), - db.text('INTERVAL -1 DAY') - ) +def calculate_end_date(start_date, nr_of_periods): + """Utility functie om einddatum te berekenen""" + if start_date and nr_of_periods: + return start_date + relativedelta(months=nr_of_periods) - relativedelta(days=1) + return None - def update_configuration(self, **changes): - """ - Update license configuration - These changes will only apply to future periods, not existing ones - - Args: - **changes: Dictionary of changes to apply to the license - - Returns: - None - """ - allowed_fields = [ - 'tier_id', 'currency', 'basic_fee', 'max_storage_mb', - 'additional_storage_price', 'additional_storage_bucket', - 'included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket', - 'included_interaction_tokens', 'additional_interaction_token_price', - 'additional_interaction_bucket', 'overage_embedding', 'overage_interaction', - 'additional_storage_allowed', 'additional_embedding_allowed', - 'additional_interaction_allowed' - ] - - # Apply only allowed changes - for key, value in changes.items(): - if key in allowed_fields: - setattr(self, key, value) +# Luister naar start_date wijzigingen +@event.listens_for(License.start_date, 'set') +def set_start_date(target, value, oldvalue, initiator): + """Bijwerken van end_date wanneer start_date wordt aangepast""" + if value and target.nr_of_periods: + target.end_date = calculate_end_date(value, target.nr_of_periods) + +# Luister naar nr_of_periods wijzigingen +@event.listens_for(License.nr_of_periods, 'set') +def set_nr_of_periods(target, value, oldvalue, initiator): + """Bijwerken van end_date wanneer nr_of_periods wordt aangepast""" + if value and target.start_date: + target.end_date = calculate_end_date(target.start_date, value) class LicenseTier(db.Model): @@ -209,22 +182,22 @@ class LicensePeriod(db.Model): period_end = db.Column(db.Date, nullable=False) # License configuration snapshot - copied from license when period is created - currency = db.Column(db.String(20), nullable=False) - basic_fee = db.Column(db.Float, nullable=False) - max_storage_mb = db.Column(db.Integer, nullable=False) - additional_storage_price = db.Column(db.Float, nullable=False) - additional_storage_bucket = db.Column(db.Integer, nullable=False) - included_embedding_mb = db.Column(db.Integer, nullable=False) - additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False) - additional_embedding_bucket = db.Column(db.Integer, nullable=False) - included_interaction_tokens = db.Column(db.Integer, nullable=False) - additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False) - additional_interaction_bucket = db.Column(db.Integer, nullable=False) + currency = db.Column(db.String(20), nullable=True) + basic_fee = db.Column(db.Float, nullable=True) + max_storage_mb = db.Column(db.Integer, nullable=True) + additional_storage_price = db.Column(db.Float, nullable=True) + additional_storage_bucket = db.Column(db.Integer, nullable=True) + included_embedding_mb = db.Column(db.Integer, nullable=True) + additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True) + additional_embedding_bucket = db.Column(db.Integer, nullable=True) + included_interaction_tokens = db.Column(db.Integer, nullable=True) + additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True) + additional_interaction_bucket = db.Column(db.Integer, nullable=True) # Allowance flags - can be changed from False to True within a period - additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False) - additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False) - additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False) + additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False) + additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False) + additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False) # Status tracking status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING) diff --git a/common/services/entitlements/license_period_services.py b/common/services/entitlements/license_period_services.py index ef5dc0c..b87596e 100644 --- a/common/services/entitlements/license_period_services.py +++ b/common/services/entitlements/license_period_services.py @@ -24,45 +24,65 @@ class LicensePeriodServices: Raises: EveAIException: and derived classes """ - current_date = dt.now(tz.utc).date() - license_period = (db.session.query(LicensePeriod) - .filter_by(tenant_id=tenant_id) - .filter(and_(LicensePeriod.period_start_date <= current_date, - LicensePeriod.period_end_date >= current_date)) - .first()) - if not license_period: - license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id) - if license_period: - match license_period.status: - case PeriodStatus.UPCOMING: - LicensePeriodServices._complete_last_license_period() - LicensePeriodServices._activate_license_period(license_period) - if not license_period.license_usage: - new_license_usage = LicenseUsage() - new_license_usage.license_period = license_period - try: - db.session.add(new_license_usage) - db.session.commit() - except SQLAlchemyError as e: - db.session.rollback() - current_app.logger.error( - f"Error creating new license usage for license period {license_period.id}: {str(e)}") - raise e - if license_period.status == PeriodStatus.ACTIVE: + try: + current_app.logger.debug(f"Finding current license period for tenant {tenant_id}") + current_date = dt.now(tz.utc).date() + license_period = (db.session.query(LicensePeriod) + .filter_by(tenant_id=tenant_id) + .filter(and_(LicensePeriod.period_start <= current_date, + LicensePeriod.period_end >= current_date)) + .first()) + current_app.logger.debug(f"End searching for license period for tenant {tenant_id} ") + if not license_period: + current_app.logger.debug(f"No license period found for tenant {tenant_id} on date {current_date}") + license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id) + current_app.logger.debug(f"Created license period {license_period.id} for tenant {tenant_id}") + if license_period: + current_app.logger.debug(f"Found license period {license_period.id} for tenant {tenant_id} " + f"with status {license_period.status}") + match license_period.status: + case PeriodStatus.UPCOMING: + current_app.logger.debug(f"In upcoming state") + LicensePeriodServices._complete_last_license_period(tenant_id=tenant_id) + current_app.logger.debug(f"Completed last license period for tenant {tenant_id}") + LicensePeriodServices._activate_license_period(license_period=license_period) + current_app.logger.debug(f"Activated license period {license_period.id} for tenant {tenant_id}") + if not license_period.license_usage: + new_license_usage = LicenseUsage( + tenant_id=tenant_id, + ) + new_license_usage.license_period = license_period + try: + db.session.add(new_license_usage) + db.session.commit() + + except SQLAlchemyError as e: + db.session.rollback() + current_app.logger.error( + f"Error creating new license usage for license period " + f"{license_period.id}: {str(e)}") + raise e + if license_period.status == PeriodStatus.ACTIVE: + return license_period + else: + # Status is PENDING, so no prepaid payment received. There is no license period we can use. + # We allow for a delay of 5 days before raising an exception. + current_date = dt.now(tz.utc).date() + delta = abs(current_date - license_period.period_start_date) + if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)): + raise EveAIPendingLicensePeriod() + case PeriodStatus.ACTIVE: return license_period - else: - # Status is PENDING, so no prepaid payment received. There is no license period we can use. - # We allow for a delay of 5 days before raising an exception. - current_date = dt.now(tz.utc).date() - delta = abs(current_date - license_period.period_start_date) - if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)): - raise EveAIPendingLicensePeriod() - case PeriodStatus.ACTIVE: - return license_period - case PeriodStatus.PENDING: - return license_period - else: - raise EveAILicensePeriodsExceeded(license_id=None) + case PeriodStatus.PENDING: + return license_period + else: + raise EveAILicensePeriodsExceeded(license_id=None) + except SQLAlchemyError as e: + db.session.rollback() + current_app.logger.error(f"Error finding current license period for tenant {tenant_id}: {str(e)}") + raise e + except Exception as e: + raise e @staticmethod def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod: @@ -87,13 +107,17 @@ class LicensePeriodServices: if not the_license: current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}") raise EveAINoActiveLicense(tenant_id=tenant_id) + else: + current_app.logger.debug(f"Found active license {the_license.id} for tenant {tenant_id} " + f"on date {current_date}") next_period_number = 1 if the_license.periods: # If there are existing periods, get the next sequential number next_period_number = max(p.period_number for p in the_license.periods) + 1 + current_app.logger.debug(f"Next period number for tenant {tenant_id} is {next_period_number}") - if next_period_number > the_license.max_periods: + if next_period_number > the_license.nr_of_periods: raise EveAILicensePeriodsExceeded(license_id=the_license.id) new_license_period = LicensePeriod( @@ -103,18 +127,16 @@ class LicensePeriodServices: period_start=the_license.start_date + relativedelta(months=next_period_number-1), period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1), status=PeriodStatus.UPCOMING, + upcoming_at=dt.now(tz.utc), ) set_logging_information(new_license_period, dt.now(tz.utc)) - new_license_usage = LicenseUsage( - license_period=new_license_period, - tenant_id=tenant_id, - ) - set_logging_information(new_license_usage, dt.now(tz.utc)) try: + current_app.logger.debug(f"Creating next license period for tenant {tenant_id} ") db.session.add(new_license_period) - db.session.add(new_license_usage) db.session.commit() + current_app.logger.info(f"Created next license period for tenant {tenant_id} " + f"with id {new_license_period.id}") return new_license_period except SQLAlchemyError as e: db.session.rollback() @@ -136,15 +158,17 @@ class LicensePeriodServices: Raises: ValueError: If neither license_period_id nor license_period is provided """ + current_app.logger.debug(f"Activating license period") if license_period is None and license_period_id is None: raise ValueError("Either license_period_id or license_period must be provided") # Get a license period object if only ID was provided if license_period is None: + current_app.logger.debug(f"Getting license period {license_period_id} to activate") license_period = LicensePeriod.query.get_or_404(license_period_id) - if license_period.upcoming_at is not None: - license_period.pending_at.upcoming_at = dt.now(tz.utc) + if license_period.pending_at is not None: + license_period.pending_at = dt.now(tz.utc) license_period.status = PeriodStatus.PENDING if license_period.prepaid_payment: # There is a payment received for the given period @@ -173,7 +197,7 @@ class LicensePeriodServices: if not license_period.license_usage: license_period.license_usage = LicenseUsage( tenant_id=license_period.tenant_id, - license_period=license_period, + license_period_id=license_period.id, ) license_period.license_usage.recalculate_storage() diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py index 13abde7..7ff18d7 100644 --- a/common/utils/model_utils.py +++ b/common/utils/model_utils.py @@ -14,7 +14,7 @@ from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbedd from common.langchain.tracked_transcription import TrackedOpenAITranscription from common.models.user import Tenant from config.model_config import MODEL_CONFIG -from common.extensions import template_manager +from common.extensions import cache_manager from common.models.document import EmbeddingMistral from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel from crewai import LLM @@ -139,6 +139,19 @@ def process_pdf(): full_model_name = 'mistral-ocr-latest' +def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[ + Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]: + """ + Get a prompt template + """ + prompt = cache_manager.prompts_config_cache.get_config(template_name, version) + if "llm_model" in prompt: + llm = get_embedding_llm(full_model_name=prompt["llm_model"]) + else: + llm = get_embedding_llm() + + return prompt["content"], llm + class ModelVariables: """Manages model-related variables and configurations""" @@ -261,31 +274,6 @@ class ModelVariables: def transcribe(self, *args, **kwargs): raise DeprecationWarning("Use transcription_model.transcribe() instead") - def get_template(self, template_name: str, version: Optional[str] = None) -> str: - """ - Get a template for the tenant's configured LLM - - Args: - template_name: Name of the template to retrieve - version: Optional specific version to retrieve - - Returns: - The template content - """ - try: - template = template_manager.get_template( - self._variables['llm_full_model'], - template_name, - version - ) - return template.content - except Exception as e: - current_app.logger.error(f"Error getting template {template_name}: {str(e)}") - # Fall back to old template loading if template_manager fails - if template_name in self._variables.get('templates', {}): - return self._variables['templates'][template_name] - raise - # Helper function to get cached model variables def get_model_variables(tenant_id: int) -> ModelVariables: diff --git a/common/utils/template_filters.py b/common/utils/template_filters.py index f17ed3e..d42b47b 100644 --- a/common/utils/template_filters.py +++ b/common/utils/template_filters.py @@ -2,6 +2,7 @@ import pytz from datetime import datetime +from common.utils.nginx_utils import prefixed_url_for as puf def to_local_time(utc_dt, timezone_str): @@ -42,6 +43,10 @@ def status_color(status_name): return colors.get(status_name, 'secondary') +def prefixed_url_for(endpoint): + return puf(endpoint) + + def register_filters(app): """ Registers custom filters with the Flask app. @@ -49,3 +54,6 @@ def register_filters(app): app.jinja_env.filters['to_local_time'] = to_local_time app.jinja_env.filters['time_difference'] = time_difference app.jinja_env.filters['status_color'] = status_color + app.jinja_env.filters['prefixed_url_for'] = prefixed_url_for + + app.jinja_env.globals['prefixed_url_for'] = prefixed_url_for