diff --git a/eveai_app/__init__.py b/eveai_app/__init__.py
index 5abcc41..fa21eec 100644
--- a/eveai_app/__init__.py
+++ b/eveai_app/__init__.py
@@ -2,6 +2,7 @@ import os
from flask import Flask
from .extensions import db, migrate, bcrypt, bootstrap, jwt
from .models.user import User, Tenant
+from .models.document import Document, DocumentLanguage, DocumentVersion
def create_app(config_file=None):
diff --git a/eveai_app/models/document.py b/eveai_app/models/document.py
index 2d64608..2d0ea19 100644
--- a/eveai_app/models/document.py
+++ b/eveai_app/models/document.py
@@ -1,37 +1,38 @@
from ..extensions import db
+from .user import User, Tenant
class Document(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100), nullable=False)
- tenant_id = db.Column(db.Integer, db.ForeignKey('tenant.id'), nullable=False)
+ tenant_id = db.Column(db.Integer, db.ForeignKey(Tenant.id), nullable=False)
valid_from = db.Column(db.DateTime, nullable=True)
valid_to = db.Column(db.DateTime, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
- created_by = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
+ created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
- updated_by = db.Column(db.Integer, db.ForeignKey('user.id'))
+ updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
class DocumentLanguage(db.Model):
id = db.Column(db.Integer, primary_key=True)
- document_id = db.Column(db.Integer, db.ForeignKey('document.id'), nullable=False)
+ document_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False)
language = db.Column(db.String(2), nullable=False)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
- created_by = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
+ created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
class DocumentVersion(db.Model):
id = db.Column(db.Integer, primary_key=True)
- doc_lang_id = db.Column(db.Integer, db.ForeignKey('document_language.id'), nullable=False)
+ doc_lang_id = db.Column(db.Integer, db.ForeignKey(DocumentLanguage.id), nullable=False)
url = db.Column(db.String(200), nullable=True)
embeddings = db.Column(db.PickleType, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
- created_by = db.Column(db.Integer, db.ForeignKey('user.id'))
+ created_by = db.Column(db.Integer, db.ForeignKey(User.id))
diff --git a/eveai_app/templates/base.html b/eveai_app/templates/base.html
index 8352c09..5918e5d 100644
--- a/eveai_app/templates/base.html
+++ b/eveai_app/templates/base.html
@@ -24,7 +24,8 @@
{% include 'navbar.html' %}
- {% with messages = get_flashed_messages()%}
+ {% include 'header.html' %}
+ {% with messages = get_flashed_messages() %}
{% if messages%}
{% for message in messages%}
{{message}}
diff --git a/eveai_app/templates/header.html b/eveai_app/templates/header.html
new file mode 100644
index 0000000..b7e584f
--- /dev/null
+++ b/eveai_app/templates/header.html
@@ -0,0 +1,13 @@
+
\ No newline at end of file
diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html
index 5733f81..1ac83df 100644
--- a/eveai_app/templates/navbar.html
+++ b/eveai_app/templates/navbar.html
@@ -1,59 +1,46 @@
-
-
-
- EveAI
-
-
-
-
-
\ No newline at end of file
diff --git a/eveai_app/utils/database.py b/eveai_app/utils/database.py
new file mode 100644
index 0000000..fd2b22f
--- /dev/null
+++ b/eveai_app/utils/database.py
@@ -0,0 +1,89 @@
+"""Database related functions"""
+from os import popen
+from sqlalchemy import text
+from sqlalchemy.schema import CreateSchema
+from sqlalchemy.exc import InternalError
+from sqlalchemy.orm import sessionmaker, scoped_session
+from flask_migrate import heads
+
+from ..extensions import db, migrate
+
+
+class Database:
+ """used for managing tenant databases related operations"""
+
+ def __init__(self, tenant: str) -> None:
+ self.schema = str(tenant)
+
+ def get_engine(self):
+ """create new schema engine"""
+ return db.engine.execution_options(
+ schema_translate_map={None: self.schema}
+ )
+
+ def get_session(self):
+ """To get session of tenant/public database schema for quick use
+
+ returns:
+ session: session of tenant/public database schema
+ """
+ return scoped_session(
+ sessionmaker(bind=self.get_engine(), expire_on_commit=True)
+ )
+
+ def create_schema(self):
+ """create new database schema, mostly used on tenant creation"""
+ try:
+ db.session.execute(CreateSchema(self.schema))
+ db.session.commit()
+ except InternalError:
+ db.session.rollback()
+ db.session.close()
+
+ def create_tables(self):
+ """create tables in for schema"""
+ db.metadata.create_all(self.get_engine())
+
+ def switch_schema(self):
+ """switch between tenant/public database schema"""
+ db.session.execute(text(f'set search_path to "{self.schema}"'))
+ db.session.commit()
+
+ def migrate_tenant_schema(self):
+ """migrate tenant database schema for new tenant"""
+ # Get the current revision for a database.
+ # NOTE: using popen may have a minor performance impact on the application
+ # you can store it in a different table in public schema and use it from there
+ # may be a faster approach
+ # last_revision = heads(directory="migrations/tenant", verbose=True, resolve_dependencies=False)
+ last_revision = popen(".venv/bin/flask db heads -d migrations/tenant").read()
+ print("LAST REVISION")
+ print(last_revision)
+ last_revision = last_revision.splitlines()[-1].split(" ")[0]
+
+ # creating revision table in tenant schema
+ session = self.get_session()
+ session.execute(
+ text(
+ f'CREATE TABLE "{self.schema}".alembic_version (version_num '
+ "VARCHAR(32) NOT NULL)"
+ )
+ )
+ session.commit()
+
+ # Insert last revision to alembic_version table
+ session.execute(
+ text(
+ f'INSERT INTO "{self.schema}".alembic_version (version_num) '
+ "VALUES (:version)"
+ ),
+ {"version": last_revision},
+ )
+ session.commit()
+ session.close()
+
+ def create_tenant_schema(self):
+ """create tenant used for creating new schema and its tables"""
+ self.create_schema()
+ self.create_tables()
+ self.migrate_tenant_schema()
diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py
index 1b697d9..1f6c282 100644
--- a/eveai_app/views/user_views.py
+++ b/eveai_app/views/user_views.py
@@ -4,6 +4,7 @@ from flask import request, redirect, url_for, flash, render_template, Blueprint
from ..models.user import User, Tenant
from ..extensions import db, bcrypt
from .user_forms import TenantForm, UserForm
+from ..utils.database import Database
user_bp = Blueprint('user_bp', __name__, url_prefix='/user')
@@ -31,9 +32,9 @@ def tenant():
monthly = request.form.get('allowed_monthly_interactions')
if lic_start != '':
- new_tenant.license_start_date = dt.strptime(lic_start, '%d-%m-%Y')
+ new_tenant.license_start_date = dt.strptime(lic_start, '%Y-%m-%d')
if lic_end != '':
- new_tenant.license_end_date = dt.strptime(lic_end, '%d-%m-%Y')
+ new_tenant.license_end_date = dt.strptime(lic_end, '%Y-%m-%d')
if monthly != '':
new_tenant.allowed_monthly_interactions = int(monthly)
@@ -43,13 +44,17 @@ def tenant():
new_tenant.updated_at = timestamp
# Add the new tenant to the database and commit the changes
-
try:
db.session.add(new_tenant)
db.session.commit()
except Exception as e:
error = e.args
+ # Create schema for new tenant
+ if error is None:
+ print(new_tenant.id)
+ Database(new_tenant.id).create_tenant_schema()
+
flash(error) if error else flash('Tenant added successfully.')
form = TenantForm()
@@ -81,7 +86,8 @@ def user():
password_hash = bcrypt.generate_password_hash(password).decode('utf-8')
# Create new user if there is no error
- new_user = User(user_name=username, email=email, password=password_hash, first_name=first_name, last_name=last_name)
+ new_user = User(user_name=username, email=email, password=password_hash, first_name=first_name,
+ last_name=last_name)
# Handle optional attributes
new_user.is_active = bool(request.form.get('is_active'))
diff --git a/migrations/tenant/env.py b/migrations/tenant/env.py
index 66d89ce..a810cc2 100644
--- a/migrations/tenant/env.py
+++ b/migrations/tenant/env.py
@@ -4,6 +4,7 @@ from logging.config import fileConfig
from flask import current_app
from alembic import context
+from sqlalchemy import NullPool, engine_from_config, text
from eveai_app.models.user import Tenant
@@ -92,6 +93,7 @@ def run_migrations_online():
)
with connectable.connect() as connection:
+ print(tenants)
for tenant in tenants:
logger.info(f"Migrating tenant: {tenant}")
# set search path on the connection, which ensures that
@@ -99,7 +101,7 @@ def run_migrations_online():
# in terms of this schema by default
connection.execute(text(f'SET search_path TO "{tenant}"'))
# in SQLAlchemy v2+ the search path change needs to be committed
- # connection.commit()
+ connection.commit()
# make use of non-supported SQLAlchemy attribute to ensure
# the dialect reflects tables in terms of the current tenant name