import inspect
import logging
import sys
import os
from logging.config import fileConfig

from flask import current_app

from alembic import context
from sqlalchemy import NullPool, engine_from_config, text

import sqlalchemy as sa
from sqlalchemy.sql import schema
import pgvector
from pgvector.sqlalchemy import Vector

# NOTE(review): these model/pgvector imports are unused by name here; they are
# presumably kept so the models (and the Vector type) are registered for
# autogenerate — verify before removing any of them.
from common.models import document, interaction, user

# This is the Alembic Config object, which provides access to the values
# within the .ini file in use.
config = context.config

# Logging setup. Alembic's default ``fileConfig(config.config_file_name)`` is
# deliberately not used; instead the root logger is configured to emit a
# single formatted stream to stdout at INFO level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

logger = logging.getLogger()

# Reset handlers to avoid issues with Alembic overriding them. Iterate over a
# copy: removeHandler() mutates logger.handlers in place, so walking the live
# list would skip every other handler.
for handler in list(logger.handlers):
    logger.removeHandler(handler)

# Add a stream handler to output logs to the console
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


def get_engine():
    """Return the SQLAlchemy engine registered with Flask-Migrate.

    Tries the Flask-SQLAlchemy<3 / Alchemical API (``db.get_engine()``) first
    and falls back to the Flask-SQLAlchemy>=3 attribute (``db.engine``).
    """
    try:
        # this works with Flask-SQLAlchemy<3 and Alchemical
        return current_app.extensions['migrate'].db.get_engine()
    except (TypeError, AttributeError):
        # this works with Flask-SQLAlchemy>=3
        return current_app.extensions['migrate'].db.engine


def get_engine_url():
    """Return the engine URL as a string suitable for Alembic's config.

    The password is rendered in clear text, and every ``%`` is doubled because
    Alembic's ``Config.set_main_option`` applies ``%``-interpolation to values.
    Older SQLAlchemy URL objects without ``render_as_string`` fall back to
    ``str(url)``.
    """
    try:
        return get_engine().url.render_as_string(hide_password=False).replace(
            '%', '%%')
    except AttributeError:
        return str(get_engine().url).replace('%', '%%')


# Point Alembic at the application's database and keep a handle on the
# Flask-SQLAlchemy object whose metadata is used for autogenerate.
config.set_main_option('sqlalchemy.url', get_engine_url())
target_db = current_app.extensions['migrate'].db


def get_public_table_names():
    """Return the names of shared tables excluded from tenant migrations.

    These tables are presumably kept in the ``public`` schema rather than in
    each tenant schema; ``include_object`` filters them out.
    """
    # TODO: This function should include the necessary functionality
    # to automatically retrieve table names
    return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain',
            'license_tier', 'license', 'license_usage', 'business_event_log',
            'tenant_project']


PUBLIC_TABLES = get_public_table_names()
logger.info(f"Public tables: {PUBLIC_TABLES}")


def include_object(object, name, type_, reflected, compare_to):
    """Alembic ``include_object`` hook that skips shared public tables.

    Returns False for any table whose name is in ``PUBLIC_TABLES`` and True for
    every other object. The parameter name ``object`` (which shadows the
    builtin) is kept to match Alembic's hook signature.
    """
    if type_ == "table" and name in PUBLIC_TABLES:
        logger.info(f"Excluding public table: {name}")
        return False

    logger.info(f"Including object: {type_} {name}")
    return True


def get_tenant_ids():
    """Return the ``id`` of every ``Tenant`` row, or an empty list if none."""
    from common.models.user import Tenant
    return [tenant.id for tenant in Tenant.query.all()]


def get_metadata():
    """Return the MetaData object used as the autogenerate target.

    If the db object exposes a ``metadatas`` mapping (Alchemical style), the
    entry under key ``None`` (presumably the default bind) is returned;
    otherwise ``db.metadata`` is used.
    """
    if hasattr(target_db, 'metadatas'):
        return target_db.metadatas[None]
    return target_db.metadata


def run_migrations_offline():
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    Note: not currently reachable; the dispatch at the bottom of this file
    raises in offline mode instead of calling this function.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=get_metadata(),
        # literal_binds=True,
        include_object=include_object,
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online():
    """Run the migration scripts once per tenant schema.

    For each tenant id, ``TENANT_ID`` is exported to the environment, the
    connection's ``search_path`` is set to the tenant's schema (then
    ``public``), and the migrations are run. A tenant whose migration raises
    is logged with its traceback and skipped, and processing continues with the
    next tenant. During ``--autogenerate`` the loop stops after the first
    tenant that completes without error.

    ref:
    - https://alembic.sqlalchemy.org/en/latest/cookbook.html#rudimental-schema-level-multi-tenancy-for-postgresql-databases # noqa
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=NullPool,
    )

    with connectable.connect() as connection:
        tenants = get_tenant_ids()
        if not tenants:
            logger.warning("No tenants found; no tenant schemas will be migrated")

        for tenant in tenants:
            try:
                os.environ['TENANT_ID'] = str(tenant)
                logger.info(f"Migrating tenant: {tenant}")
                # set search path on the connection, which ensures that
                # PostgreSQL will emit all CREATE / ALTER / DROP statements
                # in terms of this schema by default
                connection.execute(text(f'SET search_path TO "{tenant}", public'))
                # in SQLAlchemy v2+ the search path change needs to be committed
                connection.commit()
                # make use of non-supported SQLAlchemy attribute to ensure
                # the dialect reflects tables in terms of the current tenant name
                connection.dialect.default_schema_name = str(tenant)

                context.configure(
                    connection=connection,
                    target_metadata=get_metadata(),
                    # literal_binds=True,
                    include_object=include_object,
                )

                with context.begin_transaction():
                    context.run_migrations()

                # for checking migrate or upgrade is running
                if getattr(config.cmd_opts, "autogenerate", False):
                    break
            except Exception:
                # Previously swallowed silently; surface the failure so a
                # broken tenant is visible, but keep migrating the others.
                logger.exception(f"Migration failed for tenant: {tenant}")
                # Clear any aborted transaction left on the shared connection
                # (a no-op when none is open) so later tenants are not
                # poisoned by this failure.
                connection.rollback()
                continue


if context.is_offline_mode():
    raise Exception("Offline migrations are not supported")
else:
    run_migrations_online()