import io
import os
from datetime import datetime as dt, timezone as tz, datetime

from celery import states
from dateutil.relativedelta import relativedelta
from flask import current_app
from sqlalchemy import or_, and_, text
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from common.models.user import Tenant
from common.models.entitlements import BusinessEventLog, LicenseUsage, License
from common.utils.celery_utils import current_celery
from common.utils.eveai_exceptions import EveAINoLicenseForTenant, EveAIException, EveAINoActiveLicense
from common.utils.database import Database


# Healthcheck task
@current_celery.task(name='ping', queue='entitlements')
def ping():
    """Return 'pong' so callers can check that the entitlements worker responds."""
    return 'pong'


@current_celery.task(name='update_usages', queue='entitlements')
def update_usages():
    """Fold unprocessed business event logs into LicenseUsage records, tenant by tenant.

    For every tenant except tenant 1 this:
      * switches to the tenant's schema,
      * makes sure a LicenseUsage exists for the current period,
      * recalculates storage when the tenant is flagged ``storage_dirty``,
      * assigns every BusinessEventLog without a license_usage_id (and with a
        timestamp up to the task start time) to the LicenseUsage whose period
        contains it, accumulating its usage.

    Errors are collected per tenant so one failing tenant does not stop the
    others.

    Returns:
        str: a success message when all tenants were processed without error.

    Raises:
        Exception: one exception whose message joins all per-tenant errors.
    """
    # Fixed cut-off: logs written while this task runs are left for the next run.
    current_timestamp = dt.now(tz.utc)
    tenant_ids = get_all_tenant_ids()

    # List to collect all errors
    error_list = []

    for tenant_id in tenant_ids:
        # NOTE(review): tenant 1 is always skipped — presumably a system/admin
        # tenant without a license; confirm.
        if tenant_id == 1:
            continue
        try:
            Database(tenant_id).switch_schema()
            check_and_create_license_usage_for_tenant(tenant_id)
            tenant = Tenant.query.get(tenant_id)
            if tenant.storage_dirty:
                recalculate_storage_for_tenant(tenant)

            logs = get_logs_for_processing(tenant_id, current_timestamp)
            if not logs:
                continue  # No logs to be processed, continue to the next tenant

            # Get the min and max timestamp from the logs
            min_timestamp = min(log.timestamp for log in logs)
            max_timestamp = max(log.timestamp for log in logs)

            # Retrieve relevant LicenseUsage records
            license_usages = get_relevant_license_usages(db.session, tenant_id, min_timestamp, max_timestamp)

            # Split logs based on LicenseUsage periods
            logs_by_usage = split_logs_by_license_usage(logs, license_usages)

            # Process the logs of each LicenseUsage period that actually received logs
            for license_usage_id, usage_logs in logs_by_usage.items():
                if usage_logs:
                    process_logs_for_license_usage(tenant_id, license_usage_id, usage_logs)
        except Exception as e:
            # A failed statement leaves the session unusable until it is rolled
            # back; without this every following tenant would fail as well.
            db.session.rollback()
            error = f"Usage Calculation error for Tenant {tenant_id}: {e}"
            error_list.append(error)
            current_app.logger.error(error)

    if error_list:
        raise Exception('\n'.join(error_list))

    return "Update Usages task completed successfully"


@current_celery.task(name='persist_business_events', queue='entitlements')
def persist_business_events(log_entries):
    """
    Persist multiple business event logs to the database in a single transaction.

    Args:
        log_entries: List of log event dictionaries to persist. The keys
            'timestamp', 'event_type', 'tenant_id' and 'trace_id' are required;
            all other fields default to None. Keys are popped, so each
            dictionary is consumed.

    Any failure (including a missing required key) is logged and the
    transaction is rolled back; no exception is propagated.
    """
    try:
        db_entries = []
        for entry in log_entries:
            event_log = BusinessEventLog(
                timestamp=entry.pop('timestamp'),
                event_type=entry.pop('event_type'),
                tenant_id=entry.pop('tenant_id'),
                trace_id=entry.pop('trace_id'),
                span_id=entry.pop('span_id', None),
                span_name=entry.pop('span_name', None),
                parent_span_id=entry.pop('parent_span_id', None),
                document_version_id=entry.pop('document_version_id', None),
                document_version_file_size=entry.pop('document_version_file_size', None),
                chat_session_id=entry.pop('chat_session_id', None),
                interaction_id=entry.pop('interaction_id', None),
                environment=entry.pop('environment', None),
                llm_metrics_total_tokens=entry.pop('llm_metrics_total_tokens', None),
                llm_metrics_prompt_tokens=entry.pop('llm_metrics_prompt_tokens', None),
                llm_metrics_completion_tokens=entry.pop('llm_metrics_completion_tokens', None),
                llm_metrics_total_time=entry.pop('llm_metrics_total_time', None),
                llm_metrics_call_count=entry.pop('llm_metrics_call_count', None),
                llm_interaction_type=entry.pop('llm_interaction_type', None),
                message=entry.pop('message', None)
            )
            db_entries.append(event_log)

        # Perform a bulk insert of all entries
        db.session.bulk_save_objects(db_entries)
        db.session.commit()
        current_app.logger.info(f"Successfully persisted {len(db_entries)} business event logs")
    except Exception as e:
        current_app.logger.error(f"Failed to persist business event logs: {e}")
        db.session.rollback()


def get_all_tenant_ids():
    """Return the ids of all tenants as a flat list of scalars."""
    # The query yields one-element row tuples; unpack each to its id.
    return [tenant_id for (tenant_id,) in db.session.query(Tenant.id).all()]


def check_and_create_license_usage_for_tenant(tenant_id):
    """Ensure the tenant has a LicenseUsage covering today's (UTC) date.

    If none exists, the tenant's active License (start_date <= today <= end_date)
    is looked up and a new LicenseUsage is created for the monthly period,
    anchored on the license start date, that contains today.

    Raises:
        EveAINoActiveLicense: if no active License exists for the tenant.
        SQLAlchemyError: if creating the LicenseUsage fails (after rollback).
    """
    current_date = dt.now(tz.utc).date()
    license_usages = (db.session.query(LicenseUsage)
                      .filter_by(tenant_id=tenant_id)
                      .filter(and_(LicenseUsage.period_start_date <= current_date,
                                   LicenseUsage.period_end_date >= current_date))
                      .all())
    if license_usages:
        return

    active_license = (db.session.query(License).filter_by(tenant_id=tenant_id)
                      .filter(and_(License.start_date <= current_date,
                                   License.end_date >= current_date))
                      .one_or_none())
    if not active_license:
        current_app.logger.error(f"No License defined for {tenant_id}. "
                                 f"Impossible to calculate license usage.")
        raise EveAINoActiveLicense(tenant_id)

    start_date, end_date = calculate_valid_period(current_date, active_license.start_date)
    new_license_usage = LicenseUsage(period_start_date=start_date,
                                     period_end_date=end_date,
                                     license_id=active_license.id,
                                     tenant_id=tenant_id
                                     )
    try:
        db.session.add(new_license_usage)
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"Error trying to create new license usage for tenant {tenant_id}. "
                                 f"Error: {str(e)}")
        raise


def calculate_valid_period(given_date, original_start_date):
    """Return the (start_date, end_date) of the monthly period containing given_date.

    Periods begin at original_start_date and advance one calendar month at a
    time; each period ends the day before the next one starts.

    Args:
        given_date: date (or datetime) that must fall inside the returned period.
        original_start_date: date (or datetime) of the first period's start.

    Raises:
        ValueError: if given_date lies before original_start_date.

    NOTE(review): each step adds one month to the *previous* start, so an anchor
    day near month end drifts (e.g. Jan 31 -> Feb 29 -> Mar 29 -> ...). Existing
    LicenseUsage rows were created with this rule, so changing it needs a data
    migration plan.
    """
    # Ensure both dates are of datetime.date type
    if isinstance(given_date, datetime):
        given_date = given_date.date()
    if isinstance(original_start_date, datetime):
        original_start_date = original_start_date.date()

    # Step 1: Find the most recent start_date less than or equal to given_date
    start_date = original_start_date
    while start_date <= given_date:
        next_start_date = start_date + relativedelta(months=1)
        if next_start_date > given_date:
            break
        start_date = next_start_date

    # Step 2: The period ends the day before the next start
    end_date = start_date + relativedelta(months=1, days=-1)

    # Ensure the given date falls within the period
    if start_date <= given_date <= end_date:
        return start_date, end_date
    raise ValueError("Given date does not fall within a valid period.")


def get_logs_for_processing(tenant_id, end_time_stamp):
    """Return the tenant's logs not yet linked to a LicenseUsage, up to end_time_stamp (inclusive)."""
    return (db.session.query(BusinessEventLog).filter(
        BusinessEventLog.tenant_id == tenant_id,
        BusinessEventLog.license_usage_id.is_(None),
        BusinessEventLog.timestamp <= end_time_stamp,
    ).all())


def get_relevant_license_usages(session, tenant_id, min_timestamp, max_timestamp):
    """Return the tenant's LicenseUsage records whose period overlaps [min_timestamp, max_timestamp].

    Comparison is on calendar dates; results are ordered by period_start_date.
    """
    return session.query(LicenseUsage).filter(
        LicenseUsage.tenant_id == tenant_id,
        LicenseUsage.period_start_date <= max_timestamp.date(),
        LicenseUsage.period_end_date >= min_timestamp.date()
    ).order_by(LicenseUsage.period_start_date).all()


def split_logs_by_license_usage(logs, license_usages):
    """Group logs by the id of the LicenseUsage whose period contains the log's date.

    Every LicenseUsage gets a key (possibly with an empty list). A log matching
    no period is omitted, so it keeps license_usage_id NULL and is picked up
    again by the next run.
    """
    logs_by_usage = {lu.id: [] for lu in license_usages}

    for log in logs:
        log_date = log.timestamp.date()
        # First period (in the given order) containing the log wins
        for license_usage in license_usages:
            if license_usage.period_start_date <= log_date <= license_usage.period_end_date:
                logs_by_usage[license_usage.id].append(log)
                break

    return logs_by_usage


def _llm_tokens(log):
    """Return a log's (prompt, completion, total) LLM token counts, treating missing values as 0."""
    return (log.llm_metrics_prompt_tokens or 0,
            log.llm_metrics_completion_tokens or 0,
            log.llm_metrics_total_tokens or 0)


def process_logs_for_license_usage(tenant_id, license_usage_id, logs):
    """Add the usage recorded in logs to a LicenseUsage and mark the logs as processed.

    Accounting rules:
      * 'Create Embeddings' / 'Starting Trace for Create Embeddings' adds
        document_version_file_size to embedding_mb_used.
      * 'Create Embeddings' / 'Final LLM Metrics' adds tokens to embedding usage.
      * 'Ask Question' / 'Final LLM Metrics' adds tokens to interaction usage.
      * 'Execute Specialist' / 'Final LLM Metrics' adds tokens to embedding usage
        for the 'Specialist Retrieval' span, otherwise to interaction usage.
    Missing (None) metric values count as 0. Every log, counted or not, gets
    license_usage_id set so it is not processed again.

    Raises:
        ValueError: if no LicenseUsage with license_usage_id exists.
        SQLAlchemyError: if the commit fails (after rollback).
    """
    # Retrieve the LicenseUsage record
    license_usage = db.session.query(LicenseUsage).filter_by(id=license_usage_id).first()
    if not license_usage:
        raise ValueError(f"LicenseUsage with id {license_usage_id} not found.")

    embedding_mb_used = 0
    # Running (prompt, completion, total) token sums per usage category
    embedding_tokens = [0, 0, 0]
    interaction_tokens = [0, 0, 0]

    for log in logs:
        target = None  # token bucket this log contributes to, if any
        if log.event_type == 'Create Embeddings':
            if log.message == 'Starting Trace for Create Embeddings':
                embedding_mb_used += log.document_version_file_size or 0
            elif log.message == 'Final LLM Metrics':
                target = embedding_tokens
        elif log.event_type == 'Ask Question':
            if log.message == 'Final LLM Metrics':
                target = interaction_tokens
        elif log.event_type == 'Execute Specialist':
            if log.message == 'Final LLM Metrics':
                # Retrieval spans count as embedding usage; all other spans as interaction usage
                target = embedding_tokens if log.span_name == 'Specialist Retrieval' else interaction_tokens

        if target is not None:
            for index, count in enumerate(_llm_tokens(log)):
                target[index] += count

        # Mark the log as processed by setting the license_usage_id
        log.license_usage_id = license_usage_id

    # Update the LicenseUsage counters; a counter still None (e.g. on a freshly
    # created row) is treated as 0.
    increments = {
        'embedding_mb_used': embedding_mb_used,
        'embedding_prompt_tokens_used': embedding_tokens[0],
        'embedding_completion_tokens_used': embedding_tokens[1],
        'embedding_total_tokens_used': embedding_tokens[2],
        'interaction_prompt_tokens_used': interaction_tokens[0],
        'interaction_completion_tokens_used': interaction_tokens[1],
        'interaction_total_tokens_used': interaction_tokens[2],
    }
    for field_name, increment in increments.items():
        setattr(license_usage, field_name, (getattr(license_usage, field_name) or 0) + increment)

    # Commit the updates to the LicenseUsage and log records
    try:
        db.session.add(license_usage)
        db.session.add_all(logs)
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"Error trying to update license usage and logs for tenant {tenant_id}: {e}")
        raise


def recalculate_storage_for_tenant(tenant):
    """Recompute the tenant's document storage and store it on its current LicenseUsage.

    Sums document_version.file_size in the currently selected schema (the caller
    is expected to have switched to the tenant's schema), writes it to the
    LicenseUsage whose period contains today's UTC date, and clears
    tenant.storage_dirty. Failures are logged (and rolled back) rather than
    raised, so the caller can continue processing usage.
    """
    # Perform a SUM operation to get the total file size from document_versions
    total_storage = db.session.execute(text("""
        SELECT SUM(file_size)
        FROM document_version
    """)).scalar()
    # SUM over zero rows is NULL: no documents means zero storage
    if total_storage is None:
        total_storage = 0

    # Storage is a point-in-time figure: record it on the period covering today
    current_date = dt.now(tz.utc).date()
    license_usage = (db.session.query(LicenseUsage)
                     .filter_by(tenant_id=tenant.id)
                     .filter(and_(LicenseUsage.period_start_date <= current_date,
                                  LicenseUsage.period_end_date >= current_date))
                     .order_by(LicenseUsage.period_start_date.desc())
                     .first())
    if license_usage is None:
        # Leave storage_dirty set so the recalculation is retried on a later run
        current_app.logger.error(f"Error trying to update tenant {tenant.id} for Dirty Storage. "
                                 f"No current license usage found.")
        return

    license_usage.storage_mb_used = total_storage

    # Reset the dirty flag after recalculating
    tenant.storage_dirty = False

    # Commit the changes
    try:
        db.session.add(tenant)
        db.session.add(license_usage)
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"Error trying to update tenant {tenant.id} for Dirty Storage. Error: {e}")