- Correcting mistakes in tenant schema migrations

This commit is contained in:
Josako
2024-09-03 11:50:25 +02:00
parent bcf7d439f3
commit 1fa33c029b
2 changed files with 64 additions and 25 deletions

View File

@@ -1,3 +1,4 @@
import inspect
import logging
from logging.config import fileConfig
@@ -5,12 +6,12 @@ from flask import current_app
from alembic import context
from sqlalchemy import NullPool, engine_from_config, text
from common.models.user import Tenant
import sqlalchemy as sa
from sqlalchemy.sql import schema
import pgvector
from pgvector.sqlalchemy import Vector
from common.models import document, interaction, user
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -46,17 +47,34 @@ def get_engine_url():
config.set_main_option('sqlalchemy.url', get_engine_url())
target_db = current_app.extensions['migrate'].db
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_public_table_names():
# TODO: This function should include the necessary functionality to automatically retrieve table names
return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain']
PUBLIC_TABLES = get_public_table_names()
logger.info(f"Public tables: {PUBLIC_TABLES}")
def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and name in PUBLIC_TABLES:
logger.info(f"Excluding public table: {name}")
return False
logger.info(f"Including object: {type_} {name}")
return True
# List of Tenants
tenants_ = Tenant.query.all()
if tenants_:
tenants = [tenant.id for tenant in tenants_]
else:
tenants = []
def get_tenant_ids():
from common.models.user import Tenant
tenants_ = Tenant.query.all()
if tenants_:
tenants = [tenant.id for tenant in tenants_]
else:
tenants = []
return tenants
def get_metadata():
@@ -79,7 +97,10 @@ def run_migrations_offline():
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
url=url,
target_metadata=get_metadata(),
# literal_binds=True,
include_object=include_object,
)
with context.begin_transaction():
@@ -99,18 +120,8 @@ def run_migrations_online():
poolclass=NullPool,
)
def process_revision_directives(context, revision, directives):
if config.cmd_opts.autogenerate:
script = directives[0]
if script.upgrade_ops is not None:
# Add import for pgvector and set search path
script.upgrade_ops.ops.insert(0, sa.schema.ExecuteSQLOp(
"SET search_path TO CURRENT_SCHEMA(), public; IMPORT pgvector",
execution_options=None
))
with connectable.connect() as connection:
print(tenants)
tenants = get_tenant_ids()
for tenant in tenants:
logger.info(f"Migrating tenant: {tenant}")
# set search path on the connection, which ensures that
@@ -127,7 +138,8 @@ def run_migrations_online():
context.configure(
connection=connection,
target_metadata=get_metadata(),
process_revision_directives=process_revision_directives,
# literal_binds=True,
include_object=include_object,
)
with context.begin_transaction():

View File

@@ -0,0 +1,27 @@
"""Ensure logging information in document domain does not require user
Revision ID: 43eac8a7a00b
Revises: 5d5437d81041
Create Date: 2024-09-03 09:36:06.541938
"""
from alembic import op
import sqlalchemy as sa
import pgvector
# revision identifiers, used by Alembic.
revision = '43eac8a7a00b'
down_revision = '5d5437d81041'
branch_labels = None
depends_on = None
def upgrade():
# Manual upgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by DROP NOT NULL')
def downgrade():
# Manual downgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by SET NOT NULL')