diff --git a/migrations/tenant/env.py b/migrations/tenant/env.py index 7ca7306..b8d8317 100644 --- a/migrations/tenant/env.py +++ b/migrations/tenant/env.py @@ -1,3 +1,4 @@ +import inspect import logging from logging.config import fileConfig @@ -5,12 +6,12 @@ from flask import current_app from alembic import context from sqlalchemy import NullPool, engine_from_config, text - -from common.models.user import Tenant - +import sqlalchemy as sa +from sqlalchemy.sql import schema import pgvector from pgvector.sqlalchemy import Vector +from common.models import document, interaction, user # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -46,17 +47,34 @@ def get_engine_url(): config.set_main_option('sqlalchemy.url', get_engine_url()) target_db = current_app.extensions['migrate'].db -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. + +def get_public_table_names(): + # TODO: This function should include the necessary functionality to automatically retrieve table names + return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain'] + + +PUBLIC_TABLES = get_public_table_names() +logger.info(f"Public tables: {PUBLIC_TABLES}") + + +def include_object(object, name, type_, reflected, compare_to): + if type_ == "table" and name in PUBLIC_TABLES: + logger.info(f"Excluding public table: {name}") + return False + logger.info(f"Including object: {type_} {name}") + return True + # List of Tenants -tenants_ = Tenant.query.all() -if tenants_: - tenants = [tenant.id for tenant in tenants_] -else: - tenants = [] +def get_tenant_ids(): + from common.models.user import Tenant + + tenants_ = Tenant.query.all() + if tenants_: + tenants = [tenant.id for tenant in tenants_] + else: + tenants = [] + return tenants def get_metadata(): @@ -79,7 +97,10 @@ def run_migrations_offline(): """ url = config.get_main_option("sqlalchemy.url") context.configure( - url=url, target_metadata=get_metadata(), literal_binds=True + url=url, + target_metadata=get_metadata(), + # literal_binds=True, + include_object=include_object, ) with context.begin_transaction(): @@ -99,18 +120,8 @@ def run_migrations_online(): poolclass=NullPool, ) - def process_revision_directives(context, revision, directives): - if config.cmd_opts.autogenerate: - script = directives[0] - if script.upgrade_ops is not None: - # Add import for pgvector and set search path - script.upgrade_ops.ops.insert(0, sa.schema.ExecuteSQLOp( - "SET search_path TO CURRENT_SCHEMA(), public; IMPORT pgvector", - execution_options=None - )) - with connectable.connect() as connection: - print(tenants) + tenants = get_tenant_ids() for tenant in tenants: logger.info(f"Migrating tenant: {tenant}") # set search path on the connection, which ensures that @@ -127,7 +138,8 @@ def run_migrations_online(): context.configure( connection=connection, target_metadata=get_metadata(), - process_revision_directives=process_revision_directives, + # literal_binds=True, + include_object=include_object, ) with context.begin_transaction(): diff --git a/migrations/tenant/versions/43eac8a7a00b_ensure_logging_information_in_document_.py b/migrations/tenant/versions/43eac8a7a00b_ensure_logging_information_in_document_.py new file mode 100644 index 0000000..d7dc517 --- /dev/null +++ b/migrations/tenant/versions/43eac8a7a00b_ensure_logging_information_in_document_.py @@ -0,0 +1,27 @@ +"""Ensure logging information in document domain does not require user + +Revision ID: 43eac8a7a00b +Revises: 5d5437d81041 +Create Date: 2024-09-03 09:36:06.541938 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector + + +# revision identifiers, used by Alembic. +revision = '43eac8a7a00b' +down_revision = '5d5437d81041' +branch_labels = None +depends_on = None + + +def upgrade(): + # Manual upgrade commands + op.execute('ALTER TABLE document ALTER COLUMN created_by DROP NOT NULL') + + +def downgrade(): + # Manual downgrade commands + op.execute('ALTER TABLE document ALTER COLUMN created_by SET NOT NULL')