diff --git a/common/models/entitlements.py b/common/models/entitlements.py index 5eb28dd..d87d63b 100644 --- a/common/models/entitlements.py +++ b/common/models/entitlements.py @@ -95,10 +95,13 @@ class LicenseUsage(db.Model): license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False) tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False) storage_mb_used = db.Column(db.Integer, default=0) - storage_tokens_used = db.Column(db.Integer, default=0) embedding_mb_used = db.Column(db.Integer, default=0) - embedding_tokens_used = db.Column(db.Integer, default=0) - interaction_tokens_used = db.Column(db.Integer, default=0) + embedding_prompt_tokens_used = db.Column(db.Integer, default=0) + embedding_completion_tokens_used = db.Column(db.Integer, default=0) + embedding_total_tokens_used = db.Column(db.Integer, default=0) + interaction_prompt_tokens_used = db.Column(db.Integer, default=0) + interaction_completion_tokens_used = db.Column(db.Integer, default=0) + interaction_total_tokens_used = db.Column(db.Integer, default=0) period_start_date = db.Column(db.Date, nullable=False) period_end_date = db.Column(db.Date, nullable=False) diff --git a/common/models/user.py b/common/models/user.py index b698450..93b5d4e 100644 --- a/common/models/user.py +++ b/common/models/user.py @@ -64,6 +64,7 @@ class Tenant(db.Model): # Entitlements currency = db.Column(db.String(20), nullable=True) usage_email = db.Column(db.String(255), nullable=True) + storage_dirty = db.Column(db.Boolean, nullable=True, default=False) # Relations users = db.relationship('User', backref='tenant') diff --git a/common/utils/document_utils.py b/common/utils/document_utils.py index 9cdcc65..17783ae 100644 --- a/common/utils/document_utils.py +++ b/common/utils/document_utils.py @@ -12,6 +12,7 @@ import requests from urllib.parse import urlparse, unquote import os from .eveai_exceptions import EveAIInvalidLanguageException, EveAIDoubleURLException, 
EveAIUnsupportedFileType +from ..models.user import Tenant def create_document_stack(api_input, file, filename, extension, tenant_id): @@ -81,6 +82,8 @@ def create_version_for_document(document, url, language, user_context, user_meta set_logging_information(new_doc_vers, dt.now(tz.utc)) + mark_tenant_storage_dirty(document.tenant_id) + return new_doc_vers @@ -338,3 +341,10 @@ def refresh_document(doc_id): } return refresh_document_with_info(doc_id, api_input) + + +# Function triggered when a document_version is created or updated +def mark_tenant_storage_dirty(tenant_id): + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + tenant.storage_dirty = True + db.session.commit() diff --git a/common/utils/eveai_exceptions.py b/common/utils/eveai_exceptions.py index c1f1ee3..fcd54a4 100644 --- a/common/utils/eveai_exceptions.py +++ b/common/utils/eveai_exceptions.py @@ -34,3 +34,10 @@ class EveAIUnsupportedFileType(EveAIException): super().__init__(message, status_code, payload) +class EveAINoLicenseForTenant(EveAIException): + """Raised when no active license for a tenant is provided""" + + def __init__(self, message="No license for tenant found", status_code=400, payload=None): + super().__init__(message, status_code, payload) + + diff --git a/eveai_app/templates/entitlements/license.html b/eveai_app/templates/entitlements/license.html index 9cb96a1..46d50de 100644 --- a/eveai_app/templates/entitlements/license.html +++ b/eveai_app/templates/entitlements/license.html @@ -11,7 +11,7 @@ {{ form.hidden_tag() }} {% set main_fields = ['start_date', 'end_date', 'currency', 'yearly_payment', 'basic_fee'] %} {% for field in form %} - {{ render_included_field(field, disabled_fields=['currency'], include_fields=main_fields) }} + {{ render_included_field(field, disabled_fields=ext_disabled_fields + ['currency'], include_fields=main_fields) }} {% endfor %}
@@ -40,21 +40,21 @@
{% set storage_fields = ['max_storage_mb', 'additional_storage_price', 'additional_storage_bucket'] %} {% for field in form %} - {{ render_included_field(field, disabled_fields=[], include_fields=storage_fields) }} + {{ render_included_field(field, disabled_fields=ext_disabled_fields, include_fields=storage_fields) }} {% endfor %}
{% set embedding_fields = ['included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket', 'overage_embedding'] %} {% for field in form %} - {{ render_included_field(field, disabled_fields=[], include_fields=embedding_fields) }} + {{ render_included_field(field, disabled_fields=ext_disabled_fields, include_fields=embedding_fields) }} {% endfor %}
{% set interaction_fields = ['included_interaction_tokens', 'additional_interaction_token_price', 'additional_interaction_bucket', 'overage_interaction'] %} {% for field in form %} - {{ render_included_field(field, disabled_fields=[], include_fields=interaction_fields) }} + {{ render_included_field(field, disabled_fields=ext_disabled_fields, include_fields=interaction_fields) }} {% endfor %}
diff --git a/eveai_app/templates/entitlements/license_tier.html b/eveai_app/templates/entitlements/license_tier.html index 76f4991..c25f7d4 100644 --- a/eveai_app/templates/entitlements/license_tier.html +++ b/eveai_app/templates/entitlements/license_tier.html @@ -11,7 +11,7 @@ {{ form.hidden_tag() }} {% set main_fields = ['name', 'version', 'start_date', 'end_date', 'basic_fee_d', 'basic_fee_e'] %} {% for field in form %} - {{ render_included_field(field, ext_disabled_fields=[], include_fields=main_fields) }} + {{ render_included_field(field, disabled_fields=[], include_fields=main_fields) }} {% endfor %}
@@ -40,21 +40,21 @@
{% set storage_fields = ['max_storage_mb', 'additional_storage_price_d', 'additional_storage_price_e', 'additional_storage_bucket'] %} {% for field in form %} - {{ render_included_field(field, ext_disabled_fields=[], include_fields=storage_fields) }} + {{ render_included_field(field, disabled_fields=[], include_fields=storage_fields) }} {% endfor %}
{% set embedding_fields = ['included_embedding_mb', 'additional_embedding_price_d', 'additional_embedding_price_e', 'additional_embedding_bucket', 'standard_overage_embedding'] %} {% for field in form %} - {{ render_included_field(field, ext_disabled_fields=[], include_fields=embedding_fields) }} + {{ render_included_field(field, disabled_fields=[], include_fields=embedding_fields) }} {% endfor %}
{% set interaction_fields = ['included_interaction_tokens', 'additional_interaction_token_price_d', 'additional_interaction_token_price_e', 'additional_interaction_bucket', 'standard_overage_interaction'] %} {% for field in form %} - {{ render_included_field(field, ext_disabled_fields=[], include_fields=interaction_fields) }} + {{ render_included_field(field, disabled_fields=[], include_fields=interaction_fields) }} {% endfor %}
diff --git a/eveai_app/templates/entitlements/view_license_tiers.html b/eveai_app/templates/entitlements/view_license_tiers.html index 6cf0f24..ea922f6 100644 --- a/eveai_app/templates/entitlements/view_license_tiers.html +++ b/eveai_app/templates/entitlements/view_license_tiers.html @@ -7,7 +7,7 @@
- {{ render_selectable_table(headers=["Name", "Version", "Start Date", "End Date"], rows=rows, selectable=True, id="licenseTierTable") }} + {{ render_selectable_table(headers=["ID", "Name", "Version", "Start Date", "End Date"], rows=rows, selectable=True, id="licenseTierTable") }}
diff --git a/eveai_entitlements/__init__.py b/eveai_entitlements/__init__.py new file mode 100644 index 0000000..e99a80a --- /dev/null +++ b/eveai_entitlements/__init__.py @@ -0,0 +1,44 @@ +import logging +import logging.config +from flask import Flask +import os + +from common.utils.celery_utils import make_celery, init_celery +from common.extensions import db, minio_client +from config.logging_config import LOGGING +from config.config import get_config + + +def create_app(config_file=None): + app = Flask(__name__) + + environment = os.getenv('FLASK_ENV', 'development') + + match environment: + case 'development': + app.config.from_object(get_config('dev')) + case 'production': + app.config.from_object(get_config('prod')) + case _: + app.config.from_object(get_config('dev')) + + logging.config.dictConfig(LOGGING) + + register_extensions(app) + + celery = make_celery(app.name, app.config) + init_celery(celery, app) + + from . import tasks + + app.logger.info("EveAI Entitlements Server Started Successfully") + app.logger.info("-------------------------------------------------------------------------------------------------") + + return app, celery + + +def register_extensions(app): + db.init_app(app) + + +app, celery = create_app() diff --git a/eveai_entitlements/tasks.py b/eveai_entitlements/tasks.py new file mode 100644 index 0000000..f0a1c73 --- /dev/null +++ b/eveai_entitlements/tasks.py @@ -0,0 +1,226 @@ +import io +import os +from datetime import datetime as dt, timezone as tz, datetime + +from celery import states +from dateutil.relativedelta import relativedelta +from flask import current_app +from sqlalchemy import or_, and_ +from sqlalchemy.exc import SQLAlchemyError +from common.extensions import db +from common.models.user import Tenant +from common.models.entitlements import BusinessEventLog, LicenseUsage, License +from common.utils.celery_utils import current_celery +from common.utils.eveai_exceptions import EveAINoLicenseForTenant + + +# Healthcheck 
task +@current_celery.task(name='ping', queue='entitlements') +def ping(): + return 'pong' + + +@current_celery.task(name='update_usages', queue='entitlements') +def update_usages(): + current_timestamp = dt.now(tz.utc) + tenant_ids = get_all_tenant_ids() + for tenant_id in tenant_ids: + tenant = Tenant.query.get(tenant_id) + check_and_create_license_usage_for_tenant(tenant_id) + if tenant.storage_dirty: + recalculate_storage_for_tenant(tenant) + logs = get_logs_for_processing(tenant_id, current_timestamp) + if not logs: + continue # If no logs to be processed, continue to the next tenant + + # Get the min and max log dates (LicenseUsage periods are dates, not timestamps) + min_timestamp = min(log.timestamp for log in logs).date() + max_timestamp = max(log.timestamp for log in logs).date() + + # Retrieve relevant LicenseUsage records + license_usages = get_relevant_license_usages(db.session, tenant_id, min_timestamp, max_timestamp) + + # Split logs based on LicenseUsage periods + logs_by_usage = split_logs_by_license_usage(logs, license_usages) + + # Now you can process logs for each LicenseUsage + for license_usage_id, usage_logs in logs_by_usage.items(): + process_logs_for_license_usage(tenant_id, license_usage_id, usage_logs) + + +def get_all_tenant_ids(): + tenant_ids = db.session.query(Tenant.id).all() + return [tenant_id[0] for tenant_id in tenant_ids] # Extract tenant_id from tuples + + +def check_and_create_license_usage_for_tenant(tenant_id): + current_date = dt.now(tz.utc).date() + license_usages = (db.session.query(LicenseUsage) + .filter_by(tenant_id=tenant_id) + .filter(and_(LicenseUsage.period_start_date <= current_date, + LicenseUsage.period_end_date >= current_date)) + .all()) + if not license_usages: + active_license = (db.session.query(License).filter_by(tenant_id=tenant_id) + .filter(and_(License.start_date <= current_date, + License.end_date >= current_date)) + .one_or_none()) + if not active_license: + current_app.logger.error(f"No License defined for {tenant_id}. 
" + f"Impossible to calculate license usage.") + raise EveAINoLicenseForTenant(message=f"No License defined for {tenant_id}. " + f"Impossible to calculate license usage.") + + start_date, end_date = calculate_valid_period(current_date, active_license.start_date) + new_license_usage = LicenseUsage(period_start_date=start_date, + period_end_date=end_date, + license_id=active_license.id, + tenant_id=tenant_id + ) + try: + db.session.add(new_license_usage) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + current_app.logger.error(f"Error trying to create new license usage for tenant {tenant_id}. " + f"Error: {str(e)}") + raise e + + +def calculate_valid_period(given_date, original_start_date): + # Ensure both dates are of datetime.date type + if isinstance(given_date, datetime): + given_date = given_date.date() + if isinstance(original_start_date, datetime): + original_start_date = original_start_date.date() + + # Step 1: Find the most recent start_date less than or equal to given_date + start_date = original_start_date + while start_date <= given_date: + next_start_date = start_date + relativedelta(months=1) + if next_start_date > given_date: + break + start_date = next_start_date + + # Step 2: Calculate the end_date for this period + end_date = start_date + relativedelta(months=1, days=-1) + + # Ensure the given date falls within the period + if start_date <= given_date <= end_date: + return start_date, end_date + else: + raise ValueError("Given date does not fall within a valid period.") + + +def get_logs_for_processing(tenant_id, end_time_stamp): + return (db.session.query(BusinessEventLog).filter( + BusinessEventLog.tenant_id == tenant_id, + BusinessEventLog.license_usage_id.is_(None), + BusinessEventLog.timestamp <= end_time_stamp, + ).all()) + + +def get_relevant_license_usages(session, tenant_id, min_timestamp, max_timestamp): + # Fetch LicenseUsage records where the log timestamps fall between period_start_date and 
period_end_date + return session.query(LicenseUsage).filter( + LicenseUsage.tenant_id == tenant_id, + LicenseUsage.period_start_date <= max_timestamp, + LicenseUsage.period_end_date >= min_timestamp + ).order_by(LicenseUsage.period_start_date).all() + + +def split_logs_by_license_usage(logs, license_usages): + # Dictionary to hold logs categorized by LicenseUsage + logs_by_usage = {lu.id: [] for lu in license_usages} + + for log in logs: + # Find the corresponding LicenseUsage for each log based on the timestamp + for license_usage in license_usages: + if license_usage.period_start_date <= log.timestamp.date() <= license_usage.period_end_date: + logs_by_usage[license_usage.id].append(log) + break + + return logs_by_usage + + +def process_logs_for_license_usage(tenant_id, license_usage_id, logs): + # Retrieve the LicenseUsage record + license_usage = db.session.query(LicenseUsage).filter_by(id=license_usage_id).first() + + if not license_usage: + raise ValueError(f"LicenseUsage with id {license_usage_id} not found.") + + # Initialize variables to accumulate usage data + embedding_mb_used = 0 + embedding_prompt_tokens_used = 0 + embedding_completion_tokens_used = 0 + embedding_total_tokens_used = 0 + interaction_prompt_tokens_used = 0 + interaction_completion_tokens_used = 0 + interaction_total_tokens_used = 0 + + # Process each log + for log in logs: + # Case for 'Create Embeddings' event + if log.event_type == 'Create Embeddings': + if log.message == 'Starting Trace for Create Embeddings': + embedding_mb_used += log.document_version_file_size + elif log.message == 'Final LLM Metrics': + embedding_prompt_tokens_used += log.llm_metrics_prompt_tokens + embedding_completion_tokens_used += log.llm_metrics_completion_tokens + embedding_total_tokens_used += log.llm_metrics_total_tokens + + # Case for 'Ask Question' event + elif log.event_type == 'Ask Question': + if log.message == 'Final LLM Metrics': + interaction_prompt_tokens_used += log.llm_metrics_prompt_tokens + 
interaction_completion_tokens_used += log.llm_metrics_completion_tokens + interaction_total_tokens_used += log.llm_metrics_total_tokens + + # Mark the log as processed by setting the license_usage_id + log.license_usage_id = license_usage_id + + # Update the LicenseUsage record with the accumulated values + license_usage.embedding_mb_used += embedding_mb_used + license_usage.embedding_prompt_tokens_used += embedding_prompt_tokens_used + license_usage.embedding_completion_tokens_used += embedding_completion_tokens_used + license_usage.embedding_total_tokens_used += embedding_total_tokens_used + license_usage.interaction_prompt_tokens_used += interaction_prompt_tokens_used + license_usage.interaction_completion_tokens_used += interaction_completion_tokens_used + license_usage.interaction_total_tokens_used += interaction_total_tokens_used + + # Commit the updates to the LicenseUsage and log records + try: + db.session.add(license_usage) + db.session.add_all(logs) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + current_app.logger.error(f"Error trying to update license usage and logs for tenant {tenant_id}. ") + raise e + + +def recalculate_storage_for_tenant(tenant): + # Perform a SUM operation to get the total file size from the tenant's document_version table + total_storage = db.session.execute(db.text(f""" + SELECT SUM(file_size) + FROM "{tenant.id}".document_version + """)).scalar() + + # Update the most recent LicenseUsage with the recalculated storage + license_usage = db.session.query(LicenseUsage).filter_by(tenant_id=tenant.id).order_by(LicenseUsage.period_start_date.desc()).first() + license_usage.storage_mb_used = (total_storage or 0) / (1024 * 1024) # Convert bytes to MB + + # Reset the dirty flag after recalculating + tenant.storage_dirty = False + + # Commit the changes + try: + db.session.add(tenant) + db.session.add(license_usage) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + current_app.logger.error(f"Error trying to update tenant {tenant.id} for Dirty Storage. 
") + + diff --git a/migrations/public/versions/8fdd7f2965c1_licenseusage_and_tenant_updates.py b/migrations/public/versions/8fdd7f2965c1_licenseusage_and_tenant_updates.py new file mode 100644 index 0000000..a44498d --- /dev/null +++ b/migrations/public/versions/8fdd7f2965c1_licenseusage_and_tenant_updates.py @@ -0,0 +1,56 @@ +"""LicenseUsage and Tenant updates + +Revision ID: 8fdd7f2965c1 +Revises: 6a7743d08106 +Create Date: 2024-10-08 06:33:50.297396 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8fdd7f2965c1' +down_revision = '6a7743d08106' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('business_event_log', schema=None) as batch_op: + batch_op.add_column(sa.Column('document_version_file_size', sa.Float(), nullable=True)) + + with op.batch_alter_table('license_usage', schema=None) as batch_op: + batch_op.add_column(sa.Column('embedding_prompt_tokens_used', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('embedding_completion_tokens_used', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('embedding_total_tokens_used', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('interaction_prompt_tokens_used', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('interaction_completion_tokens_used', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('interaction_total_tokens_used', sa.Integer(), nullable=True)) + batch_op.drop_column('interaction_tokens_used') + + with op.batch_alter_table('tenant', schema=None) as batch_op: + batch_op.add_column(sa.Column('storage_dirty', sa.Boolean(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant', schema=None) as batch_op: + batch_op.drop_column('storage_dirty') + + with op.batch_alter_table('license_usage', schema=None) as batch_op: + batch_op.add_column(sa.Column('interaction_tokens_used', sa.INTEGER(), autoincrement=False, nullable=True)) + batch_op.drop_column('interaction_total_tokens_used') + batch_op.drop_column('interaction_completion_tokens_used') + batch_op.drop_column('interaction_prompt_tokens_used') + batch_op.drop_column('embedding_total_tokens_used') + batch_op.drop_column('embedding_completion_tokens_used') + batch_op.drop_column('embedding_prompt_tokens_used') + + with op.batch_alter_table('business_event_log', schema=None) as batch_op: + batch_op.drop_column('document_version_file_size') + + # ### end Alembic commands ### diff --git a/migrations/tenant/env.py b/migrations/tenant/env.py index b8d8317..a14f041 100644 --- a/migrations/tenant/env.py +++ b/migrations/tenant/env.py @@ -50,7 +50,8 @@ target_db = current_app.extensions['migrate'].db def get_public_table_names(): # TODO: This function should include the necessary functionality to automatically retrieve table names - return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain'] + return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain','license_tier', 'license', 'license_usage', + 'business_event_log'] PUBLIC_TABLES = get_public_table_names() diff --git a/migrations/tenant/versions/5a75fb6da7b8_remove_obsolete_fields_from_.py b/migrations/tenant/versions/5a75fb6da7b8_remove_obsolete_fields_from_.py new file mode 100644 index 0000000..6b58ebc --- /dev/null +++ b/migrations/tenant/versions/5a75fb6da7b8_remove_obsolete_fields_from_.py @@ -0,0 +1,31 @@ +"""Remove obsolete fields from DocumentVersion + +Revision ID: 5a75fb6da7b8 +Revises: 322d3cf1f17b +Create Date: 2024-10-08 06:49:57.349346 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector + + +# revision identifiers, used by Alembic. 
+revision = '5a75fb6da7b8' +down_revision = '322d3cf1f17b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('document_version', 'file_name') + op.drop_column('document_version', 'file_location') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('document_version', sa.Column('file_location', sa.VARCHAR(length=255), autoincrement=False, nullable=True)) + op.add_column('document_version', sa.Column('file_name', sa.VARCHAR(length=200), autoincrement=False, nullable=True)) + # ### end Alembic commands ###