"""Licensing, usage, payment and invoicing models (all stored in the ``public`` schema)."""
from datetime import datetime as dt, timezone as tz
from enum import Enum

from dateutil.relativedelta import relativedelta
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.sql.expression import text

from common.extensions import db
from common.utils.database import Database


class BusinessEventLog(db.Model):
    """A single business event recorded for a tenant, with trace/span ids and LLM metrics."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    timestamp = db.Column(db.DateTime, nullable=False)
    event_type = db.Column(db.String(50), nullable=False)
    tenant_id = db.Column(db.Integer, nullable=False)
    trace_id = db.Column(db.String(50), nullable=False)
    span_id = db.Column(db.String(50))
    span_name = db.Column(db.String(255))
    parent_span_id = db.Column(db.String(50))
    document_version_id = db.Column(db.Integer)
    document_version_file_size = db.Column(db.Float)
    specialist_id = db.Column(db.Integer)
    specialist_type = db.Column(db.String(50))
    specialist_type_version = db.Column(db.String(20))
    chat_session_id = db.Column(db.String(50))
    interaction_id = db.Column(db.Integer)
    environment = db.Column(db.String(20))
    llm_metrics_total_tokens = db.Column(db.Integer)
    llm_metrics_prompt_tokens = db.Column(db.Integer)
    llm_metrics_completion_tokens = db.Column(db.Integer)
    llm_metrics_total_time = db.Column(db.Float)
    llm_metrics_nr_of_pages = db.Column(db.Integer)
    llm_metrics_call_count = db.Column(db.Integer)
    llm_interaction_type = db.Column(db.String(20))
    message = db.Column(db.Text)
    license_usage_id = db.Column(db.Integer, db.ForeignKey('public.license_usage.id'), nullable=True)

    license_usage = db.relationship('LicenseUsage', backref='events')


class License(db.Model):
    """A tenant's license: tier, term (start date + number of monthly periods) and pricing/quota configuration.

    The term is divided into ``LicensePeriod`` rows. Each period holds its own
    snapshot of the configuration, copied when the period is created.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'), nullable=False)  # LicenseTier, e.g. 'small', 'medium', 'custom'
    start_date = db.Column(db.Date, nullable=False)
    nr_of_periods = db.Column(db.Integer, nullable=False)
    currency = db.Column(db.String(20), nullable=False)
    yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
    basic_fee = db.Column(db.Float, nullable=False)
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price = db.Column(db.Float, nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    overage_embedding = db.Column(db.Float, nullable=False, default=0)
    overage_interaction = db.Column(db.Float, nullable=False, default=0)
    additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    tenant = db.relationship('Tenant', back_populates='licenses')
    license_tier = db.relationship('LicenseTier', back_populates='licenses')
    periods = db.relationship('LicensePeriod', back_populates='license',
                              order_by='LicensePeriod.period_number',
                              cascade='all, delete-orphan')

    @hybrid_property
    def end_date(self):
        """Return the last day of the license term.

        Each period is one month, so the end date is
        ``start_date + nr_of_periods months - 1 day``. Month arithmetic uses
        ``relativedelta``, which clamps to the end of shorter months.

        Returns:
            The end date, or ``None`` when ``start_date`` or ``nr_of_periods``
            is unset (or ``nr_of_periods`` is 0).
        """
        if self.start_date and self.nr_of_periods:
            return self.start_date + relativedelta(months=self.nr_of_periods) - relativedelta(days=1)
        return None

    @end_date.expression
    def end_date(cls):
        """SQL (PostgreSQL) version of ``end_date`` for use in queries.

        Builds ``CAST(start_date + make_interval(0, nr_of_periods, 0, -1) AS DATE)``.
        ``make_interval`` takes (years, months, weeks, days). PostgreSQL applies
        the months field before the days field when adding an interval, which
        matches ``relativedelta(months=n) - relativedelta(days=1)`` above.
        A NULL ``nr_of_periods`` yields NULL.
        """
        return db.cast(
            cls.start_date + db.func.make_interval(0, cls.nr_of_periods, 0, -1, type_=db.Interval),
            db.Date,
        )

    def update_configuration(self, **changes):
        """Update this license's configuration fields.

        Only the whitelisted configuration fields below are written. Any other
        keys in ``changes`` are silently ignored. Existing ``LicensePeriod`` rows
        hold their own snapshot of the configuration, so these changes only apply
        to future periods, not existing ones.

        Args:
            **changes: Mapping of field name to new value.

        Returns:
            None
        """
        allowed_fields = frozenset({
            'tier_id', 'currency', 'basic_fee', 'max_storage_mb',
            'additional_storage_price', 'additional_storage_bucket',
            'included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket',
            'included_interaction_tokens', 'additional_interaction_token_price',
            'additional_interaction_bucket', 'overage_embedding', 'overage_interaction',
            'additional_storage_allowed', 'additional_embedding_allowed',
            'additional_interaction_allowed',
        })

        for key, value in changes.items():
            if key in allowed_fields:
                setattr(self, key, value)


class LicenseTier(db.Model):
    """A versioned, dated license tier definition with default pricing and quotas.

    NOTE(review): the ``_d`` / ``_e`` column suffixes presumably denote two
    currency variants of the same price — confirm.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    version = db.Column(db.String(50), nullable=False)
    start_date = db.Column(db.Date, nullable=False)
    end_date = db.Column(db.Date, nullable=True)
    basic_fee_d = db.Column(db.Float, nullable=True)
    basic_fee_e = db.Column(db.Float, nullable=True)
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_token_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    standard_overage_embedding = db.Column(db.Float, nullable=False, default=0)
    standard_overage_interaction = db.Column(db.Float, nullable=False, default=0)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    licenses = db.relationship('License', back_populates='license_tier')
    partner_services = db.relationship('PartnerServiceLicenseTier', back_populates='license_tier')


class PartnerServiceLicenseTier(db.Model):
    """Association between a partner service and a license tier (composite primary key)."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    partner_service_id = db.Column(db.Integer, db.ForeignKey('public.partner_service.id'), primary_key=True,
                                   nullable=False)
    license_tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'), primary_key=True,
                                nullable=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    license_tier = db.relationship('LicenseTier', back_populates='partner_services')
    partner_service = db.relationship('PartnerService', back_populates='license_tiers')


class PeriodStatus(Enum):
    UPCOMING = "UPCOMING"    # The period is still in the future
    PENDING = "PENDING"      # The period is active, but prepaid is not yet received
    ACTIVE = "ACTIVE"        # The period is active and prepaid has been received
    COMPLETED = "COMPLETED"  # The period has been completed, but not yet invoiced
    INVOICED = "INVOICED"    # The period has been completed and invoiced, but overage payment still pending
    CLOSED = "CLOSED"        # The period has been closed, invoiced and fully paid


# Allowed LicensePeriod status transitions: current status -> statuses it may move to.
_VALID_PERIOD_TRANSITIONS = {
    PeriodStatus.UPCOMING: (PeriodStatus.ACTIVE, PeriodStatus.PENDING),
    PeriodStatus.PENDING: (PeriodStatus.ACTIVE,),
    PeriodStatus.ACTIVE: (PeriodStatus.COMPLETED,),
    PeriodStatus.COMPLETED: (PeriodStatus.INVOICED, PeriodStatus.CLOSED),
    PeriodStatus.INVOICED: (PeriodStatus.CLOSED,),
    PeriodStatus.CLOSED: (),
}

# Timestamp column stamped when a LicensePeriod enters each status.
_PERIOD_STATUS_TIMESTAMP_FIELDS = {
    PeriodStatus.UPCOMING: 'upcoming_at',
    PeriodStatus.PENDING: 'pending_at',
    PeriodStatus.ACTIVE: 'active_at',
    PeriodStatus.COMPLETED: 'completed_at',
    PeriodStatus.INVOICED: 'invoiced_at',
    PeriodStatus.CLOSED: 'closed_at',
}


class LicensePeriod(db.Model):
    """One period of a License, with a snapshot of the license configuration and a status lifecycle."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Period identification
    period_number = db.Column(db.Integer, nullable=False)
    period_start = db.Column(db.Date, nullable=False)
    period_end = db.Column(db.Date, nullable=False)

    # License configuration snapshot - copied from license when period is created
    currency = db.Column(db.String(20), nullable=False)
    basic_fee = db.Column(db.Float, nullable=False)
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price = db.Column(db.Float, nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)

    # Allowance flags - can be changed from False to True within a period
    additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)

    # Status tracking
    status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)

    # State transition timestamps
    upcoming_at = db.Column(db.DateTime, nullable=True)
    pending_at = db.Column(db.DateTime, nullable=True)
    active_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    invoiced_at = db.Column(db.DateTime, nullable=True)
    closed_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license = db.relationship('License', back_populates='periods')
    license_usage = db.relationship('LicenseUsage',
                                    uselist=False,  # This makes it one-to-one
                                    back_populates='license_period',
                                    cascade='all, delete-orphan')
    payments = db.relationship('Payment', back_populates='license_period')
    invoices = db.relationship('Invoice', back_populates='license_period',
                               cascade='all, delete-orphan')

    def update_allowance(self, allowance_type, allow_value, user_id=None):
        """Switch one of the ``additional_*_allowed`` flags on for this period.

        Within a period a flag may only go from False to True, never back.

        Args:
            allowance_type: One of 'storage', 'embedding' or 'interaction'.
            allow_value: The requested value; must be truthy (True).
            user_id: Optional id of the user performing the update.

        Raises:
            ValueError: If ``allowance_type`` is unknown, or ``allow_value`` is
                falsy (turning an allowance off is not allowed).
        """
        if allowance_type not in ('storage', 'embedding', 'interaction'):
            raise ValueError(f"Invalid allowance type: {allowance_type}")
        field_name = f"additional_{allowance_type}_allowed"

        current_value = getattr(self, field_name)

        # Reject every falsy value (False, None, 0), not only the literal False,
        # so a falsy argument can never end up enabling the flag.
        if not allow_value:
            raise ValueError(f"Cannot change {field_name} from {current_value} to False")

        if current_value is True:
            return  # Already enabled; nothing to change.

        setattr(self, field_name, True)
        self.updated_at = dt.now(tz.utc)
        if user_id is not None:
            self.updated_by = user_id

    @property
    def prepaid_invoice(self):
        """Get the prepaid invoice for this period (first match, or None)."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_invoice(self):
        """Get the overage (postpaid) invoice for this period (first match, or None)."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.POSTPAID
        ).first()

    @property
    def prepaid_payment(self):
        """Get the prepaid payment for this period (first match, or None)."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_payment(self):
        """Get the overage (postpaid) payment for this period (first match, or None)."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.POSTPAID
        ).first()

    @property
    def all_invoices(self):
        """Get all invoices for this period."""
        return self.invoices

    @property
    def all_payments(self):
        """Get all payments for this period."""
        return self.payments

    def transition_status(self, new_status: PeriodStatus, user_id: int = None):
        """Move this period to ``new_status`` after validating the transition.

        Also stamps the matching ``<status>_at`` timestamp column and the audit
        fields (``updated_at`` and, when given, ``updated_by``).

        Args:
            new_status: Target status.
            user_id: Optional id of the user performing the transition.

        Raises:
            ValueError: If the transition is not allowed from the current status.
        """
        if not self.can_transition_to(new_status):
            raise ValueError(f"Invalid status transition from {self.status} to {new_status}")

        now = dt.now(tz.utc)
        self.status = new_status
        self.updated_at = now
        if user_id is not None:
            self.updated_by = user_id

        # Record when the period entered the new status.
        setattr(self, _PERIOD_STATUS_TIMESTAMP_FIELDS[new_status], now)

    @property
    def is_overdue(self):
        """True when the period is still PENDING (prepaid not received) and its start date has passed."""
        return (self.status == PeriodStatus.PENDING and
                self.period_start <= dt.now(tz.utc).date())

    def can_transition_to(self, new_status: PeriodStatus) -> bool:
        """Check whether moving from the current status to ``new_status`` is allowed."""
        return new_status in _VALID_PERIOD_TRANSITIONS.get(self.status, ())

    def __repr__(self):
        return (f'<LicensePeriod {self.id}: license={self.license_id}, '
                f'period={self.period_number}, status={self.status}>')


class LicenseUsage(db.Model):
    """Usage counters (storage, embedding and interaction tokens) for one LicensePeriod."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    storage_mb_used = db.Column(db.Float, default=0)
    embedding_mb_used = db.Column(db.Float, default=0)
    embedding_prompt_tokens_used = db.Column(db.Integer, default=0)
    embedding_completion_tokens_used = db.Column(db.Integer, default=0)
    embedding_total_tokens_used = db.Column(db.Integer, default=0)
    interaction_prompt_tokens_used = db.Column(db.Integer, default=0)
    interaction_completion_tokens_used = db.Column(db.Integer, default=0)
    interaction_total_tokens_used = db.Column(db.Integer, default=0)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    license_period = db.relationship('LicensePeriod', back_populates='license_usage')

    def recalculate_storage(self):
        """Recompute ``storage_mb_used`` as the sum of ``document_version.file_size``.

        Switches the session to this tenant's schema via
        ``Database(self.tenant_id).switch_schema()`` before summing. When there
        are no document versions, the result is 0 rather than NULL/None.

        NOTE(review): the schema switch is not undone here; presumably safe
        because the models in this module are explicitly schema-qualified
        ('public') — confirm. Also assumes ``file_size`` is stored in MB — confirm.
        """
        Database(self.tenant_id).switch_schema()

        # COALESCE: SUM over zero rows is NULL.
        total_storage = db.session.execute(text(
            "SELECT COALESCE(SUM(file_size), 0) FROM document_version"
        )).scalar()

        self.storage_mb_used = float(total_storage)


class PaymentType(Enum):
    PREPAID = "PREPAID"
    POSTPAID = "POSTPAID"


class PaymentStatus(Enum):
    PENDING = "PENDING"
    PAID = "PAID"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


class Payment(db.Model):
    """A prepaid or postpaid payment by a tenant, optionally linked to a LicensePeriod."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Payment details
    payment_type = db.Column(db.Enum(PaymentType), nullable=False)
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    description = db.Column(db.Text, nullable=True)

    # Status tracking
    status = db.Column(db.Enum(PaymentStatus), nullable=False, default=PaymentStatus.PENDING)

    # External provider information
    external_payment_id = db.Column(db.String(255), nullable=True)
    payment_method = db.Column(db.String(50), nullable=True)  # credit_card, bank_transfer, etc.
    provider_data = db.Column(JSONB, nullable=True)  # Provider-specific data

    # Payment information
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='payments')
    invoice = db.relationship('Invoice', back_populates='payment', uselist=False)

    @property
    def is_overdue(self):
        """Check whether this payment is overdue.

        Only PENDING prepaid payments linked to a period can be overdue: they are
        overdue once the period's start date has passed. Postpaid payments always
        return False here (their due date lives on the invoice).
        """
        if self.status != PaymentStatus.PENDING:
            return False

        # For prepaid payments, check if period start has passed
        if (self.payment_type == PaymentType.PREPAID and
                self.license_period_id):
            return self.license_period.period_start <= dt.now(tz.utc).date()

        # For postpaid, check against due date (would be on invoice)
        return False

    def __repr__(self):
        return (f'<Payment {self.id}: {self.payment_type} {self.amount} '
                f'{self.currency} ({self.status})>')


class InvoiceStatus(Enum):
    DRAFT = "DRAFT"
    SENT = "SENT"
    PAID = "PAID"
    OVERDUE = "OVERDUE"
    CANCELLED = "CANCELLED"


class Invoice(db.Model):
    """An invoice for a LicensePeriod, optionally linked to the Payment that settles it."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)
    payment_id = db.Column(db.Integer, db.ForeignKey('public.payment.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Invoice details
    invoice_type = db.Column(db.Enum(PaymentType), nullable=False)
    invoice_number = db.Column(db.String(50), unique=True, nullable=False)
    invoice_date = db.Column(db.Date, nullable=False)
    due_date = db.Column(db.Date, nullable=False)

    # Financial details
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    tax_amount = db.Column(db.Numeric(10, 2), default=0)

    # Descriptive fields
    description = db.Column(db.Text, nullable=True)
    status = db.Column(db.Enum(InvoiceStatus), nullable=False, default=InvoiceStatus.DRAFT)

    # Timestamps
    sent_at = db.Column(db.DateTime, nullable=True)
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='invoices')
    payment = db.relationship('Payment', back_populates='invoice')

    def __repr__(self):
        return (f'<Invoice {self.invoice_number}: {self.invoice_type} {self.amount} '
                f'{self.currency} ({self.status})>')


class LicenseChangeLog(db.Model):
    """
    Log of changes to license configurations.
    Used for auditing and tracking when/why license details changed.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    changed_at = db.Column(db.DateTime, nullable=False, default=lambda: dt.now(tz.utc))

    # What changed
    field_name = db.Column(db.String(100), nullable=False)
    old_value = db.Column(db.String(255), nullable=True)
    new_value = db.Column(db.String(255), nullable=False)

    # Why it changed
    reason = db.Column(db.Text, nullable=True)

    # Standard audit fields
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Relationships
    license = db.relationship('License', backref=db.backref('change_logs', order_by='LicenseChangeLog.changed_at'))

    def __repr__(self):
        return f'<LicenseChangeLog {self.id}: {self.field_name} {self.old_value} -> {self.new_value}>'