from sqlalchemy.sql.expression import text
from common.extensions import db
from datetime import datetime as dt, timezone as tz
from enum import Enum
from sqlalchemy import event
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.hybrid import hybrid_property
from dateutil.relativedelta import relativedelta
from common.utils.database import Database


class BusinessEventLog(db.Model):
    """Event record (tracing, document and LLM metrics) stored in the public schema."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    timestamp = db.Column(db.DateTime, nullable=False)
    event_type = db.Column(db.String(50), nullable=False)
    tenant_id = db.Column(db.Integer, nullable=False)
    trace_id = db.Column(db.String(50), nullable=False)
    span_id = db.Column(db.String(50))
    span_name = db.Column(db.String(255))
    parent_span_id = db.Column(db.String(50))
    document_version_id = db.Column(db.Integer)
    document_version_file_size = db.Column(db.Float)
    specialist_id = db.Column(db.Integer)
    specialist_type = db.Column(db.String(50))
    specialist_type_version = db.Column(db.String(20))
    chat_session_id = db.Column(db.String(50))
    interaction_id = db.Column(db.Integer)
    environment = db.Column(db.String(20))
    llm_metrics_total_tokens = db.Column(db.Integer)
    llm_metrics_prompt_tokens = db.Column(db.Integer)
    llm_metrics_completion_tokens = db.Column(db.Integer)
    llm_metrics_total_time = db.Column(db.Float)
    llm_metrics_nr_of_pages = db.Column(db.Integer)
    llm_metrics_call_count = db.Column(db.Integer)
    llm_interaction_type = db.Column(db.String(20))
    message = db.Column(db.Text)
    license_usage_id = db.Column(db.Integer, db.ForeignKey('public.license_usage.id'), nullable=True)

    # Also creates the reverse attribute `LicenseUsage.events`.
    license_usage = db.relationship('LicenseUsage', backref='events')


class License(db.Model):
    """A tenant's license: tier, term (start date + number of periods) and pricing terms."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'), nullable=False)  # tier, e.g. 'small', 'medium', 'custom'
    start_date = db.Column(db.Date, nullable=False)
    # Derived from start_date + nr_of_periods by the attribute listeners below.
    end_date = db.Column(db.Date, nullable=True)
    nr_of_periods = db.Column(db.Integer, nullable=False)
    currency = db.Column(db.String(20), nullable=False)
    yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
    basic_fee = db.Column(db.Float, nullable=False)
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price = db.Column(db.Float, nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    overage_embedding = db.Column(db.Float, nullable=False, default=0)
    overage_interaction = db.Column(db.Float, nullable=False, default=0)
    additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    tenant = db.relationship('Tenant', back_populates='licenses')
    license_tier = db.relationship('LicenseTier', back_populates='licenses')
    # Periods are owned by the license: deleting the license (or removing a
    # period from this collection) deletes the period rows.
    periods = db.relationship('LicensePeriod', back_populates='license',
                              order_by='LicensePeriod.period_number',
                              cascade='all, delete-orphan')


def calculate_end_date(start_date, nr_of_periods):
    """Return the last day of a license running ``nr_of_periods`` months from ``start_date``.

    The result is ``start_date + nr_of_periods months - 1 day``, i.e. an
    inclusive end date (2024-01-01 with 12 periods -> 2024-12-31).
    Returns None when either argument is falsy (None or 0).
    """
    if start_date and nr_of_periods:
        return start_date + relativedelta(months=nr_of_periods) - relativedelta(days=1)
    return None


# Listen for start_date changes
@event.listens_for(License.start_date, 'set')
def set_start_date(target, value, oldvalue, initiator):
    """Recompute ``end_date`` when ``start_date`` is assigned.

    The incoming value is used rather than ``target.start_date`` because the
    listener receives the value being assigned. Nothing happens until
    ``nr_of_periods`` is also known.
    """
    if value and target.nr_of_periods:
        target.end_date = calculate_end_date(value, target.nr_of_periods)


# Listen for nr_of_periods changes
@event.listens_for(License.nr_of_periods, 'set')
def set_nr_of_periods(target, value, oldvalue, initiator):
    """Recompute ``end_date`` when ``nr_of_periods`` is assigned (requires ``start_date``)."""
    if value and target.start_date:
        target.end_date = calculate_end_date(target.start_date, value)


class LicenseTier(db.Model):
    """A versioned, time-bounded license tier definition with default quotas and prices."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    version = db.Column(db.String(50), nullable=False)
    start_date = db.Column(db.Date, nullable=False)
    end_date = db.Column(db.Date, nullable=True)
    # NOTE(review): the _d / _e suffixes look like two currency variants
    # (presumably dollar / euro) — confirm.
    basic_fee_d = db.Column(db.Float, nullable=True)
    basic_fee_e = db.Column(db.Float, nullable=True)
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_token_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    standard_overage_embedding = db.Column(db.Float, nullable=False, default=0)
    standard_overage_interaction = db.Column(db.Float, nullable=False, default=0)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    licenses = db.relationship('License', back_populates='license_tier')
    partner_services = db.relationship('PartnerServiceLicenseTier', back_populates='license_tier')


class PartnerServiceLicenseTier(db.Model):
    """Association between a partner service and a license tier (composite primary key)."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    partner_service_id = db.Column(db.Integer, db.ForeignKey('public.partner_service.id'), primary_key=True, nullable=False)
    license_tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'), primary_key=True, nullable=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    license_tier = db.relationship('LicenseTier', back_populates='partner_services')
    partner_service = db.relationship('PartnerService', back_populates='license_tiers')


class PeriodStatus(Enum):
    """Lifecycle states of a LicensePeriod (see LicensePeriod.can_transition_to)."""
    UPCOMING = "UPCOMING"  # The period is still in the future
    PENDING = "PENDING"  # The period is active, but prepaid is not yet received
    ACTIVE = "ACTIVE"  # The period is active and prepaid has been received
    COMPLETED = "COMPLETED"  # The period has been completed, but not yet invoiced
    INVOICED = "INVOICED"  # The period has been completed and invoiced, but overage payment still pending
    CLOSED = "CLOSED"  # The period has been closed, invoiced and fully paid


class LicensePeriod(db.Model):
    """One billing period of a License.

    Holds a snapshot of the license configuration, its status, and a
    timestamp for each status transition.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Period identification
    period_number = db.Column(db.Integer, nullable=False)
    period_start = db.Column(db.Date, nullable=False)
    period_end = db.Column(db.Date, nullable=False)

    # License configuration snapshot - copied from license when period is created
    currency = db.Column(db.String(20), nullable=True)
    basic_fee = db.Column(db.Float, nullable=True)
    max_storage_mb = db.Column(db.Integer, nullable=True)
    additional_storage_price = db.Column(db.Float, nullable=True)
    additional_storage_bucket = db.Column(db.Integer, nullable=True)
    included_embedding_mb = db.Column(db.Integer, nullable=True)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True)
    additional_embedding_bucket = db.Column(db.Integer, nullable=True)
    included_interaction_tokens = db.Column(db.Integer, nullable=True)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True)
    additional_interaction_bucket = db.Column(db.Integer, nullable=True)

    # Allowance flags - can be changed from False to True within a period
    additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False)

    # Status tracking
    status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)

    # State transition timestamps
    upcoming_at = db.Column(db.DateTime, nullable=True)
    pending_at = db.Column(db.DateTime, nullable=True)
    active_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    invoiced_at = db.Column(db.DateTime, nullable=True)
    closed_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license = db.relationship('License', back_populates='periods')
    license_usage = db.relationship('LicenseUsage',
                                    uselist=False,  # This makes it one-to-one
                                    back_populates='license_period',
                                    cascade='all, delete-orphan')
    payments = db.relationship('Payment', back_populates='license_period')
    invoices = db.relationship('Invoice', back_populates='license_period',
                               cascade='all, delete-orphan')

    def update_allowance(self, allowance_type, allow_value, user_id=None):
        """
        Update an allowance flag within a period.
        Only allows transitioning from False to True.

        Args:
            allowance_type: One of 'storage', 'embedding', or 'interaction'
                (selects the ``additional_<type>_allowed`` column).
            allow_value: The new value (must be True).
            user_id: User ID performing the update; recorded in updated_by when truthy.

        Raises:
            ValueError: If allow_value is False (regardless of the current value),
                or the allowance type does not name an existing attribute.
        """
        field_name = f"additional_{allowance_type}_allowed"

        # Verify valid field
        if not hasattr(self, field_name):
            raise ValueError(f"Invalid allowance type: {allowance_type}")

        current_value = getattr(self, field_name)

        # Only allow False -> True transition
        if current_value is True and allow_value is True:
            # Already True, no change needed
            return
        elif allow_value is False:
            raise ValueError(f"Cannot change {field_name} from {current_value} to False")

        setattr(self, field_name, True)
        self.updated_at = dt.now(tz.utc)
        if user_id:
            self.updated_by = user_id

    @property
    def prepaid_invoice(self):
        """Get the prepaid invoice for this period (first match, or None)."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_invoice(self):
        """Get the overage (postpaid) invoice for this period (first match, or None)."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.POSTPAID
        ).first()

    @property
    def prepaid_payment(self):
        """Get the prepaid payment for this period (first match, or None)."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_payment(self):
        """Get the overage (postpaid) payment for this period (first match, or None)."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.POSTPAID
        ).first()

    @property
    def all_invoices(self):
        """Get all invoices for this period"""
        return self.invoices

    @property
    def all_payments(self):
        """Get all payments for this period"""
        return self.payments

    def transition_status(self, new_status: PeriodStatus, user_id: int = None):
        """Move the period to ``new_status`` and stamp the matching transition timestamp.

        Args:
            new_status: Target status; must be allowed by ``can_transition_to``.
            user_id: User performing the transition; recorded in updated_by when truthy.

        Raises:
            ValueError: If the transition from the current status is not allowed.
        """
        if not self.can_transition_to(new_status):
            raise ValueError(f"Invalid status transition from {self.status} to {new_status}")

        now = dt.now(tz.utc)
        self.status = new_status
        self.updated_at = now
        if user_id:
            self.updated_by = user_id

        # Stamp the matching transition column. (Previously this read and wrote a
        # non-existent `prepaid_received_at` attribute, so every ACTIVE transition
        # raised AttributeError; `active_at` is the column for that state.)
        if new_status == PeriodStatus.PENDING:
            self.pending_at = now
        elif new_status == PeriodStatus.ACTIVE:
            # Keep the first activation time if one was already recorded.
            if not self.active_at:
                self.active_at = now
        elif new_status == PeriodStatus.COMPLETED:
            self.completed_at = now
        elif new_status == PeriodStatus.INVOICED:
            self.invoiced_at = now
        elif new_status == PeriodStatus.CLOSED:
            self.closed_at = now

    @property
    def is_overdue(self):
        """True when the period is PENDING (prepaid not received) and has already started."""
        return (self.status == PeriodStatus.PENDING and
                self.period_start <= dt.now(tz.utc).date())

    def can_transition_to(self, new_status: PeriodStatus) -> bool:
        """Check whether moving from the current status to ``new_status`` is allowed."""
        valid_transitions = {
            PeriodStatus.UPCOMING: [PeriodStatus.ACTIVE, PeriodStatus.PENDING],
            PeriodStatus.PENDING: [PeriodStatus.ACTIVE],
            PeriodStatus.ACTIVE: [PeriodStatus.COMPLETED],
            PeriodStatus.COMPLETED: [PeriodStatus.INVOICED, PeriodStatus.CLOSED],
            PeriodStatus.INVOICED: [PeriodStatus.CLOSED],
            PeriodStatus.CLOSED: []
        }
        return new_status in valid_transitions.get(self.status, [])

    def __repr__(self):
        return f'<LicensePeriod {self.id}: License {self.license_id}, Period {self.period_number}>'


class LicenseUsage(db.Model):
    """Usage counters (storage, embedding, interaction tokens) for one LicensePeriod."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    storage_mb_used = db.Column(db.Float, default=0)
    embedding_mb_used = db.Column(db.Float, default=0)
    embedding_prompt_tokens_used = db.Column(db.Integer, default=0)
    embedding_completion_tokens_used = db.Column(db.Integer, default=0)
    embedding_total_tokens_used = db.Column(db.Integer, default=0)
    interaction_prompt_tokens_used = db.Column(db.Integer, default=0)
    interaction_completion_tokens_used = db.Column(db.Integer, default=0)
    interaction_total_tokens_used = db.Column(db.Integer, default=0)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    license_period = db.relationship('LicensePeriod', back_populates='license_usage')

    def recalculate_storage(self):
        """Set ``storage_mb_used`` to SUM(file_size) over the tenant's document_version rows.

        The sum is coalesced to 0 when the tenant has no document versions
        (SQL SUM over zero rows yields NULL).
        """
        # NOTE(review): Database(...).switch_schema() is not shown here; presumably it
        # points the session at the tenant's schema and does not switch back — confirm.
        Database(self.tenant_id).switch_schema()

        # NOTE(review): assumes file_size is stored in MB, matching storage_mb_used — confirm.
        total_storage = db.session.execute(text("""
            SELECT SUM(file_size)
            FROM document_version
        """)).scalar()

        self.storage_mb_used = total_storage if total_storage is not None else 0


class PaymentType(Enum):
    PREPAID = "PREPAID"
    POSTPAID = "POSTPAID"


class PaymentStatus(Enum):
    PENDING = "PENDING"
    PAID = "PAID"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


class Payment(db.Model):
    """A payment (prepaid or postpaid) by a tenant, optionally tied to a LicensePeriod."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Payment details
    payment_type = db.Column(db.Enum(PaymentType), nullable=False)
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    description = db.Column(db.Text, nullable=True)

    # Status tracking
    status = db.Column(db.Enum(PaymentStatus), nullable=False, default=PaymentStatus.PENDING)

    # External provider information
    external_payment_id = db.Column(db.String(255), nullable=True)
    payment_method = db.Column(db.String(50), nullable=True)  # credit_card, bank_transfer, etc.
    provider_data = db.Column(JSONB, nullable=True)  # Provider-specific data

    # Payment information
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='payments')
    invoice = db.relationship('Invoice', back_populates='payment', uselist=False)

    @property
    def is_overdue(self):
        """Check if payment is overdue.

        Only PENDING prepaid payments linked to a period can be overdue: they
        are overdue once the period's start date has been reached. Postpaid
        payments always report False here.
        """
        if self.status != PaymentStatus.PENDING:
            return False

        # For prepaid payments, check if period start has passed
        if self.payment_type == PaymentType.PREPAID and self.license_period_id:
            return self.license_period.period_start <= dt.now(tz.utc).date()

        # For postpaid, check against due date (would be on invoice)
        return False

    def __repr__(self):
        return f'<Payment {self.id}: {self.payment_type} {self.amount} {self.currency} ({self.status})>'


class InvoiceStatus(Enum):
    DRAFT = "DRAFT"
    SENT = "SENT"
    PAID = "PAID"
    OVERDUE = "OVERDUE"
    CANCELLED = "CANCELLED"


class Invoice(db.Model):
    """An invoice for a LicensePeriod; invoice_type distinguishes prepaid vs postpaid."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)
    payment_id = db.Column(db.Integer, db.ForeignKey('public.payment.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Invoice details
    invoice_type = db.Column(db.Enum(PaymentType), nullable=False)
    invoice_number = db.Column(db.String(50), unique=True, nullable=False)
    invoice_date = db.Column(db.Date, nullable=False)
    due_date = db.Column(db.Date, nullable=False)

    # Financial details
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    tax_amount = db.Column(db.Numeric(10, 2), default=0)

    # Descriptive fields
    description = db.Column(db.Text, nullable=True)
    status = db.Column(db.Enum(InvoiceStatus), nullable=False, default=InvoiceStatus.DRAFT)

    # Timestamps
    sent_at = db.Column(db.DateTime, nullable=True)
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='invoices')
    payment = db.relationship('Payment', back_populates='invoice')

    def __repr__(self):
        return f'<Invoice {self.invoice_number}: {self.amount} {self.currency} ({self.status})>'


class LicenseChangeLog(db.Model):
    """
    Log of changes to license configurations.
    Used for auditing and tracking when/why license details changed.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    changed_at = db.Column(db.DateTime, nullable=False, default=lambda: dt.now(tz.utc))

    # What changed
    field_name = db.Column(db.String(100), nullable=False)
    old_value = db.Column(db.String(255), nullable=True)
    new_value = db.Column(db.String(255), nullable=False)

    # Why it changed
    reason = db.Column(db.Text, nullable=True)

    # Standard audit fields
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Relationships (also creates License.change_logs, ordered by changed_at)
    license = db.relationship('License', backref=db.backref('change_logs', order_by='LicenseChangeLog.changed_at'))

    def __repr__(self):
        return f'<LicenseChangeLog {self.license_id} {self.field_name}: {self.old_value} -> {self.new_value}>'