- Correcting mistakes in tenant schema migrations

This commit is contained in:
Josako
2024-09-03 11:50:25 +02:00
parent bcf7d439f3
commit 1fa33c029b
2 changed files with 64 additions and 25 deletions

View File

@@ -1,3 +1,4 @@
import inspect
import logging import logging
from logging.config import fileConfig from logging.config import fileConfig
@@ -5,12 +6,12 @@ from flask import current_app
from alembic import context from alembic import context
from sqlalchemy import NullPool, engine_from_config, text from sqlalchemy import NullPool, engine_from_config, text
import sqlalchemy as sa
from common.models.user import Tenant from sqlalchemy.sql import schema
import pgvector import pgvector
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from common.models import document, interaction, user
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@@ -46,17 +47,34 @@ def get_engine_url():
config.set_main_option('sqlalchemy.url', get_engine_url()) config.set_main_option('sqlalchemy.url', get_engine_url())
target_db = current_app.extensions['migrate'].db target_db = current_app.extensions['migrate'].db
# other values from the config, defined by the needs of env.py,
# can be acquired: def get_public_table_names():
# my_important_option = config.get_main_option("my_important_option") # TODO: This function should include the necessary functionality to automatically retrieve table names
# ... etc. return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain']
PUBLIC_TABLES = get_public_table_names()
logger.info(f"Public tables: {PUBLIC_TABLES}")
def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and name in PUBLIC_TABLES:
logger.info(f"Excluding public table: {name}")
return False
logger.info(f"Including object: {type_} {name}")
return True
# List of Tenants # List of Tenants
tenants_ = Tenant.query.all() def get_tenant_ids():
if tenants_: from common.models.user import Tenant
tenants = [tenant.id for tenant in tenants_]
else: tenants_ = Tenant.query.all()
tenants = [] if tenants_:
tenants = [tenant.id for tenant in tenants_]
else:
tenants = []
return tenants
def get_metadata(): def get_metadata():
@@ -79,7 +97,10 @@ def run_migrations_offline():
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure( context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True url=url,
target_metadata=get_metadata(),
# literal_binds=True,
include_object=include_object,
) )
with context.begin_transaction(): with context.begin_transaction():
@@ -99,18 +120,8 @@ def run_migrations_online():
poolclass=NullPool, poolclass=NullPool,
) )
def process_revision_directives(context, revision, directives):
if config.cmd_opts.autogenerate:
script = directives[0]
if script.upgrade_ops is not None:
# Add import for pgvector and set search path
script.upgrade_ops.ops.insert(0, sa.schema.ExecuteSQLOp(
"SET search_path TO CURRENT_SCHEMA(), public; IMPORT pgvector",
execution_options=None
))
with connectable.connect() as connection: with connectable.connect() as connection:
print(tenants) tenants = get_tenant_ids()
for tenant in tenants: for tenant in tenants:
logger.info(f"Migrating tenant: {tenant}") logger.info(f"Migrating tenant: {tenant}")
# set search path on the connection, which ensures that # set search path on the connection, which ensures that
@@ -127,7 +138,8 @@ def run_migrations_online():
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=get_metadata(), target_metadata=get_metadata(),
process_revision_directives=process_revision_directives, # literal_binds=True,
include_object=include_object,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@@ -0,0 +1,27 @@
"""Ensure logging information in document domain does not require user
Revision ID: 43eac8a7a00b
Revises: 5d5437d81041
Create Date: 2024-09-03 09:36:06.541938
"""
from alembic import op
import sqlalchemy as sa
import pgvector
# revision identifiers, used by Alembic.
revision = '43eac8a7a00b'
down_revision = '5d5437d81041'
branch_labels = None
depends_on = None
def upgrade():
# Manual upgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by DROP NOT NULL')
def downgrade():
# Manual downgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by SET NOT NULL')