Refactoring part 1

Some changes for workers, but stopped due to refactoring
This commit is contained in:
Josako
2024-05-06 21:30:07 +02:00
parent d925477e68
commit 8e5ad5f312
34 changed files with 193 additions and 109 deletions

View File

@@ -5,14 +5,13 @@ from flask_security import SQLAlchemyUserDatastore
from flask_security.signals import user_authenticated
from werkzeug.middleware.proxy_fix import ProxyFix
import logging.config
from celery import Celery
from .extensions import db, migrate, bootstrap, security, mail, login_manager, cors
from .models.user import User, Tenant, Role
from .models.document import Document, DocumentLanguage, DocumentVersion
from .logging_config import LOGGING
from .utils.security import set_tenant_session_data
from .worker.celery_utils import init_celery
from common.extensions import db, migrate, bootstrap, security, mail, login_manager, cors
from common.models.user import User, Role
from config.logging_config import LOGGING
from common.utils.security import set_tenant_session_data
from .errors import register_error_handlers
from eveai_workers.celery_utils import init_celery
def create_app(config_file=None):
@@ -20,7 +19,7 @@ def create_app(config_file=None):
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1)
if config_file is None:
app.config.from_object('config.DevConfig')
app.config.from_object('config.config.DevConfig')
else:
app.config.from_object(config_file)
@@ -31,6 +30,10 @@ def create_app(config_file=None):
logging.config.dictConfig(LOGGING)
register_extensions(app)
# Initialize celery
init_celery(app)
# Setup Flask-Security-Too
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
security.init_app(app, user_datastore)
@@ -39,6 +42,10 @@ def create_app(config_file=None):
# Register Blueprints
register_blueprints(app)
# Register Error Handlers
register_error_handlers(app)
# Debugging settings
if app.config['DEBUG'] is True:
app.logger.setLevel(logging.DEBUG)
mail_logger = logging.getLogger('flask_mailman')
@@ -79,17 +86,3 @@ def register_api(app):
# from . import api
# app.register_blueprint(api.bp, url_prefix='/api')
def create_celery_app(config_file=None):
    """Build a Celery instance backed by a freshly created Flask app.

    Args:
        config_file: Object or import path handed to
            ``app.config.from_object``. When ``None``, ``'config.DevConfig'``
            is loaded instead.

    Returns:
        The Celery instance after ``init_celery`` has been applied to it.
    """
    flask_app = Flask(__name__)

    # Fall back to the development configuration only when nothing was given.
    config_source = 'config.DevConfig' if config_file is None else config_file
    flask_app.config.from_object(config_source)

    celery_app = Celery(flask_app.import_name)
    init_celery(celery_app, flask_app)
    return celery_app
celery = create_celery_app()

View File

@@ -1,4 +0,0 @@
# from flask import Blueprint, request
#
# public_api_bp = Blueprint("public", __name__, url_prefix="/api/v1")
# tenant_api_bp = Blueprint("tenant", __name__, url_prefix="/api/v1/tenant")

View File

@@ -1,7 +0,0 @@
from flask import request
from flask.views import MethodView
class RegisterAPI(MethodView):
    """Method view exposing a POST handler (presumably for user registration — confirm)."""

    def post(self):
        """Read ``username`` from the JSON body of the current request.

        NOTE(review): the value is neither used nor returned, so this handler
        returns ``None`` — looks like an unfinished stub; confirm.
        """
        username = request.json['username']

22
eveai_app/errors.py Normal file
View File

@@ -0,0 +1,22 @@
from flask import render_template, request, jsonify
def not_found_error(error):
    """Produce the 404 response as JSON or as the HTML error page.

    Args:
        error: The error object supplied by Flask; not inspected here.

    Returns:
        A JSON response with status 404 when the client accepts JSON but not
        HTML; otherwise a ``(rendered 'error/404.html', 404)`` tuple.
    """
    accepted = request.accept_mimetypes
    json_only = accepted.accept_json and not accepted.accept_html
    if not json_only:
        return render_template('error/404.html'), 404

    payload = jsonify({'error': 'Not found'})
    payload.status_code = 404
    return payload
def internal_server_error(error):
    """Produce the 500 response as JSON or as the HTML error page.

    Args:
        error: The error object supplied by Flask; not inspected here.

    Returns:
        A JSON response with status 500 when the client accepts JSON but not
        HTML; otherwise a ``(rendered 'error/500.html', 500)`` tuple.
    """
    accepted = request.accept_mimetypes
    json_only = accepted.accept_json and not accepted.accept_html
    if not json_only:
        return render_template('error/500.html'), 500

    payload = jsonify({'error': 'Internal server error'})
    payload.status_code = 500
    return payload
def register_error_handlers(app):
    """Attach the 404 and 500 handlers defined in this module to ``app``.

    Args:
        app: The Flask application to register the handlers on.
    """
    handlers_by_status = (
        (404, not_found_error),
        (500, internal_server_error),
    )
    for status_code, handler in handlers_by_status:
        app.register_error_handler(status_code, handler)

View File

@@ -1,17 +0,0 @@
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_bootstrap import Bootstrap
from flask_security import Security
from flask_mailman import Mail
from flask_login import LoginManager
from flask_cors import CORS
# Create extensions
# Every extension is instantiated here without an application object;
# presumably each one is bound to the Flask app later (e.g. via init_app) —
# the binding is not part of this module.
db = SQLAlchemy()
migrate = Migrate()
bootstrap = Bootstrap()
security = Security()
mail = Mail()
login_manager = LoginManager()
cors = CORS()

View File

@@ -1,31 +0,0 @@
# dictConfig-style logging setup: the root logger writes DEBUG and above to
# both a size-rotated file and the console, using one shared format.
LOGGING = {
    'version': 1,
    'disable_existing_loggers': False,
    'handlers': {
        # Rotating file output: up to 10 backups of 5 MB each.
        'file': {
            'level': 'DEBUG',
            'class': 'logging.handlers.RotatingFileHandler',
            'filename': 'app.log',
            'maxBytes': 5 * 1024 * 1024,
            'backupCount': 10,
            'formatter': 'standard',
        },
        # Stream output on the console.
        'console': {
            'class': 'logging.StreamHandler',
            'level': 'DEBUG',
            'formatter': 'standard',
        },
    },
    'formatters': {
        'standard': {
            'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
        },
    },
    'loggers': {
        # The empty name addresses the root logger.
        '': {
            'handlers': ['file', 'console'],
            'level': 'DEBUG',
            'propagate': True,
        },
    },
}

View File

@@ -1,3 +0,0 @@

View File

@@ -1,105 +0,0 @@
from ..extensions import db
from .user import User, Tenant
from pgvector.sqlalchemy import Vector
class Document(db.Model):
    """A named document owned by a tenant, with an optional validity window.

    Each document has zero or more ``DocumentLanguage`` rows, reachable through
    ``languages`` (and back through the ``document`` backref).
    """

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(100), nullable=False)
    tenant_id = db.Column(
        db.Integer,
        db.ForeignKey(Tenant.id),
        nullable=False,
    )
    valid_from = db.Column(db.DateTime, nullable=True)
    valid_to = db.Column(db.DateTime, nullable=True)

    # Versioning Information
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
    )
    created_by = db.Column(
        db.Integer,
        db.ForeignKey(User.id),
        nullable=False,
    )
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
        onupdate=db.func.now(),
    )
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    # Relations
    languages = db.relationship(
        'DocumentLanguage',
        backref='document',
        lazy=True,
    )

    def __repr__(self):
        """Return a short debug representation showing id and name."""
        return f"<Document {self.id}: {self.name}>"
class DocumentLanguage(db.Model):
    """One language variant of a ``Document``.

    Two relationships point at ``DocumentVersion``: ``versions`` (all versions
    referencing this row through ``doc_lang_id``) and ``latest_version`` (the
    single row referenced by ``latest_version_id``). Each names its foreign
    key explicitly because both tables reference each other.
    """

    id = db.Column(db.Integer, primary_key=True)
    document_id = db.Column(
        db.Integer,
        db.ForeignKey(Document.id),
        nullable=False,
    )
    language = db.Column(db.String(2), nullable=False)
    latest_version_id = db.Column(
        db.Integer,
        db.ForeignKey('document_version.id'),
        nullable=True,
    )

    # Versioning Information
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
    )
    created_by = db.Column(
        db.Integer,
        db.ForeignKey(User.id),
        nullable=False,
    )
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
        onupdate=db.func.now(),
    )
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    # Relations
    versions = db.relationship(
        'DocumentVersion',
        backref='document_language',
        lazy='joined',
        foreign_keys='DocumentVersion.doc_lang_id',
    )
    latest_version = db.relationship(
        'DocumentVersion',
        uselist=False,
        foreign_keys=[latest_version_id],
    )

    def __repr__(self):
        """Return a short debug representation: document id and language."""
        return f"<DocumentLanguage {self.document_id}.{self.language}>"
class DocumentVersion(db.Model):
    """A single stored version of a document in one language.

    Holds where the content came from (``url``) and where it is stored
    (``file_location`` / ``file_name`` / ``file_type``), plus flags and
    timestamps describing its processing state.
    """

    id = db.Column(db.Integer, primary_key=True)
    doc_lang_id = db.Column(db.Integer, db.ForeignKey(DocumentLanguage.id), nullable=False)
    url = db.Column(db.String(200), nullable=True)
    file_location = db.Column(db.String(255), nullable=True)
    file_name = db.Column(db.String(200), nullable=True)
    file_type = db.Column(db.String(20), nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    # Processing Information
    processing = db.Column(db.Boolean, nullable=False, default=False)
    processing_started_at = db.Column(db.DateTime, nullable=True)
    processing_finished_at = db.Column(db.DateTime, nullable=True)
    processing_error = db.Column(db.String(255), nullable=True)

    # Relations
    embeddings = db.relationship('EmbeddingMistral', backref='document_version', lazy=True)

    def __repr__(self):
        """Return ``<DocumentVersion {document_id}.{language}.{id}>``.

        Falls back to ``?`` placeholders while the version is not yet attached
        to a ``DocumentLanguage``, so that ``repr()`` never raises.
        """
        doc_lang = self.document_language
        if doc_lang is None:
            return f"<DocumentVersion ?.?.{self.id}>"
        return f"<DocumentVersion {doc_lang.document_id}.{doc_lang.language}.{self.id}>"

    def calc_file_location(self):
        """Return the relative storage directory: ``tenant_id/document_id/language``."""
        doc_lang = self.document_language
        return f"{doc_lang.document.tenant_id}/{doc_lang.document.id}/{doc_lang.language}"

    def calc_file_name(self):
        """Return the storage file name: ``{id}.{file_type}``."""
        return f"{self.id}.{self.file_type}"
class EmbeddingMistral(db.Model):
    """Embedding row (1024-dimensional vector) attached to a document version."""

    id = db.Column(db.Integer, primary_key=True)
    doc_vers_id = db.Column(
        db.Integer,
        db.ForeignKey(DocumentVersion.id),
        nullable=False,
    )
    active = db.Column(db.Boolean, nullable=False, default=True)

    # 1024 is the MISTRAL Embedding dimension.
    # If another embedding model is chosen, this dimension may need to be changed.
    embedding = db.Column(Vector(1024), nullable=False)
class EmbeddingSmallOpenAI(db.Model):
    """Embedding row (1536-dimensional vector) attached to a document version."""

    id = db.Column(db.Integer, primary_key=True)
    doc_vers_id = db.Column(
        db.Integer,
        db.ForeignKey(DocumentVersion.id),
        nullable=False,
    )
    active = db.Column(db.Boolean, nullable=False, default=True)

    # 1536 is the OpenAI Small Embedding dimension.
    # If another embedding model is chosen, this dimension may need to be changed.
    embedding = db.Column(Vector(1536), nullable=False)

View File

@@ -1,115 +0,0 @@
from ..extensions import db
from flask_security import UserMixin, RoleMixin
from sqlalchemy.dialects.postgresql import ARRAY
import sqlalchemy as sa
class Tenant(db.Model):
    """Tenant model"""

    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    # Versioning Information
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
    )
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.now(),
        onupdate=db.func.now(),
    )

    # company Information
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(80), unique=True, nullable=False)
    website = db.Column(db.String(255), nullable=True)

    # language information
    default_language = db.Column(db.String(2), nullable=True)
    allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)

    # LLM specific choices
    default_embedding_model = db.Column(db.String(50), nullable=True)
    allowed_embedding_models = db.Column(ARRAY(sa.String(50)), nullable=True)
    default_llm_model = db.Column(db.String(50), nullable=True)
    allowed_llm_models = db.Column(ARRAY(sa.String(50)), nullable=True)

    # Licensing Information
    license_start_date = db.Column(db.Date, nullable=True)
    license_end_date = db.Column(db.Date, nullable=True)
    allowed_monthly_interactions = db.Column(db.Integer, nullable=True)

    # Relations
    users = db.relationship('User', backref='tenant')

    def __repr__(self):
        """Return a short debug representation showing id and name."""
        return f"<Tenant {self.id}: {self.name}>"

    def to_dict(self):
        """Return the tenant's settings as a plain dict.

        The timestamps and the ``users`` relationship are not included.
        Values are returned as stored on the instance, without conversion.
        """
        exported_fields = (
            'id',
            'name',
            'website',
            'default_language',
            'allowed_languages',
            'default_embedding_model',
            'allowed_embedding_models',
            'default_llm_model',
            'allowed_llm_models',
            'license_start_date',
            'license_end_date',
            'allowed_monthly_interactions',
        )
        return {field: getattr(self, field) for field in exported_fields}
class Role(db.Model, RoleMixin):
    """Named role stored in the ``public`` schema."""

    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer(), primary_key=True)
    name = db.Column(db.String(80), unique=True)
    description = db.Column(db.String(255))
class RolesUsers(db.Model):
    """Association table linking users to roles in the ``public`` schema.

    Both columns together form the primary key, and each row is removed when
    the referenced user or role is deleted (``ondelete='CASCADE'``).
    """

    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    user_id = db.Column(
        db.Integer(),
        db.ForeignKey('public.user.id', ondelete='CASCADE'),
        primary_key=True,
    )
    role_id = db.Column(
        db.Integer(),
        db.ForeignKey('public.role.id', ondelete='CASCADE'),
        primary_key=True,
    )
class User(db.Model, UserMixin):
    """User model"""

    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())

    # User Information
    id = db.Column(db.Integer, primary_key=True)
    user_name = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(255), unique=True, nullable=False)
    password = db.Column(db.String(255), nullable=False)
    first_name = db.Column(db.String(80), nullable=False)
    last_name = db.Column(db.String(80), nullable=False)
    # NOTE(review): both ``is_active`` and ``active`` are declared as columns;
    # confirm that both are really needed.
    is_active = db.Column(db.Boolean, default=True)
    active = db.Column(db.Boolean)
    fs_uniquifier = db.Column(db.String(255), unique=True, nullable=False)
    confirmed_at = db.Column(db.DateTime, nullable=True)
    valid_to = db.Column(db.Date, nullable=True)

    # Security Trackable Information
    last_login_at = db.Column(db.DateTime, nullable=True)
    current_login_at = db.Column(db.DateTime, nullable=True)
    last_login_ip = db.Column(db.String(255), nullable=True)
    current_login_ip = db.Column(db.String(255), nullable=True)
    login_count = db.Column(db.Integer, nullable=False, default=0)

    # Relations
    roles = db.relationship('Role', secondary=RolesUsers.__table__, backref=db.backref('users', lazy='dynamic'))
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    def __repr__(self):
        """Return a short debug representation showing the user name."""
        return '<User %r>' % self.user_name

    def has_roles(self, *args):
        """Return True when the user holds at least one of the given roles.

        Role names may be passed as separate arguments
        (``has_roles('A', 'B')``) or grouped in a list, tuple or set
        (``has_roles(['A', 'B'])``). Previously a list argument was compared
        against each role name as a whole and therefore never matched.

        Args:
            *args: Role names, or collections of role names.

        Returns:
            bool: True if any of the user's roles has one of the names.
        """
        wanted = []
        for arg in args:
            if isinstance(arg, (list, tuple, set, frozenset)):
                wanted.extend(arg)
            else:
                wanted.append(arg)
        return any(role.name in wanted for role in self.roles)

View File

@@ -20,7 +20,7 @@
<span>
<div class="container mb-4">
<div class="row mt-lg-n12 mt-md-n12 mt-n12 justify-content-center">
<div class="col-xl-8 col-lg-5 col-md-7 mx-auto">
{% block content_class %}<div class="col-xl-8 col-lg-5 col-md-7 mx-auto">{% endblock %}
<div class="card mt-8">
<div class="card-header p-0 position-relative mt-n4 mx-3 z-index-2">
<div class="bg-gradient-success shadow-success border-radius-lg py-3 pe-1 text-center py-4">

View File

@@ -5,6 +5,7 @@
{% block content_title %}Documents{% endblock %}
{% block content_description %}View Documents for Tenant{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto">{% endblock %}
{% block content %}
<div class="container">

View File

@@ -0,0 +1,9 @@
{% extends "base.html" %}
{% block title %}Error 404{% endblock %}
{% block content_title %}File not Found{% endblock %}
{% block content_description %}Something unexpected happened!{% endblock %}
{% block content %}
<p><a href="{{ url_for('basic_bp.index') }}">Return home</a></p>
{% endblock %}

View File

@@ -0,0 +1,9 @@
{% extends "base.html" %}
{% block title %}Error 500{% endblock %}
{% block content_title %}Internal Server error{% endblock %}
{% block content_description %}Something unexpected happened! The administrator has been notified.{% endblock %}
{% block content %}
<p><a href="{{ url_for('basic_bp.index') }}">Return home</a></p>
{% endblock %}

View File

@@ -1,12 +0,0 @@
from flask import session
def make_session_permanent():
    """Mark the current Flask session as permanent."""
    session.permanent = True  # Refresh the session timeout on every request
# if 'user_id' in session and session.get('was_authenticated'):
# if session.modified: # Check if the session was modified
# session['was_authenticated'] = True
# else:
# session.pop('user_id', None) # Clear session
# session.pop('was_authenticated', None)
# return redirect(url_for('login')) # Redirect to login page if session expired

View File

@@ -1,91 +0,0 @@
"""Database related functions"""
from os import popen
from sqlalchemy import text
from sqlalchemy.schema import CreateSchema
from sqlalchemy.exc import InternalError
from sqlalchemy.orm import sessionmaker, scoped_session
from flask import current_app
from ..extensions import db, migrate
class Database:
    """used for managing tenant databases related operations"""

    def __init__(self, tenant: str) -> None:
        """Remember the schema name for ``tenant`` (stored as ``str(tenant)``)."""
        self.schema = str(tenant)

    def get_engine(self):
        """create new schema engine

        Returns the shared engine with a ``schema_translate_map`` that maps
        schema-less tables onto this tenant's schema.
        """
        return db.engine.execution_options(
            schema_translate_map={None: self.schema}
        )

    def get_session(self):
        """To get session of tenant/public database schema for quick use

        returns:
            session: session of tenant/public database schema
        """
        return scoped_session(
            sessionmaker(bind=self.get_engine(), expire_on_commit=True)
        )

    def create_schema(self):
        """create new database schema, mostly used on tenant creation

        An ``InternalError`` is logged and swallowed after rolling back and
        closing the session; other exceptions propagate.
        """
        try:
            db.session.execute(CreateSchema(self.schema))
            # The pgvector extension is registered under the name "vector".
            # The schema is quoted (as in switch_schema) so that names such as
            # numeric tenant ids are valid identifiers.
            db.session.execute(
                text(f'CREATE EXTENSION IF NOT EXISTS vector SCHEMA "{self.schema}"')
            )
            db.session.commit()
        except InternalError as e:
            db.session.rollback()
            db.session.close()
            current_app.logger.error(f"Error creating schema {self.schema}: {e.args}")

    def create_tables(self):
        """create tables in for schema"""
        db.metadata.create_all(self.get_engine())

    def switch_schema(self):
        """switch between tenant/public database schema"""
        db.session.execute(text(f'set search_path to "{self.schema}"'))
        db.session.commit()

    @staticmethod
    def _parse_head_revision(output):
        """Extract the revision id from the output of ``flask db heads``.

        The id is the first whitespace-separated token of the last non-blank
        line.

        Raises:
            RuntimeError: if ``output`` contains no non-blank line.
        """
        lines = [line for line in output.splitlines() if line.strip()]
        if not lines:
            raise RuntimeError(
                "Could not determine the head revision: 'flask db heads' produced no output"
            )
        return lines[-1].split()[0]

    def migrate_tenant_schema(self):
        """migrate tenant database schema for new tenant

        Creates ``alembic_version`` in the tenant schema and stamps it with
        the current head revision of ``migrations/tenant``.

        Raises:
            RuntimeError: if the head revision cannot be determined.
        """
        # Get the current revision for a database.
        # NOTE: using popen may have a minor performance impact on the application
        # you can store it in a different table in public schema and use it from there
        # may be a faster approach
        # The context manager makes sure the pipe is closed again.
        with popen(".venv/bin/flask db heads -d migrations/tenant") as pipe:
            heads_output = pipe.read()
        current_app.logger.debug(f"flask db heads output: {heads_output!r}")
        last_revision = self._parse_head_revision(heads_output)

        session = self.get_session()
        try:
            # creating revision table in tenant schema
            session.execute(
                text(
                    f'CREATE TABLE "{self.schema}".alembic_version (version_num '
                    "VARCHAR(32) NOT NULL)"
                )
            )
            session.commit()

            # Insert last revision to alembic_version table
            session.execute(
                text(
                    f'INSERT INTO "{self.schema}".alembic_version (version_num) '
                    "VALUES (:version)"
                ),
                {"version": last_revision},
            )
            session.commit()
        finally:
            # Release the session even when one of the statements fails.
            session.close()

    def create_tenant_schema(self):
        """create tenant used for creating new schema and its tables"""
        self.create_schema()
        self.create_tables()
        self.migrate_tenant_schema()

View File

@@ -1,29 +0,0 @@
"""Middleware for the API
for handling tenant requests
"""
from flask_security import current_user
from flask import session
from ..models.user import User, Tenant
from .database import Database
def mw_before_request():
    """Before request
    switch tenant schema

    Returns:
        ``None`` when the schema was switched, otherwise a
        ``(message dict, 403)`` tuple explaining why access was refused.
    """
    # Use .get() so that a session without tenant data yields the 403 below
    # instead of an unhandled KeyError.
    tenant = session.get('tenant') or {}
    tenant_id = tenant.get('id')
    if not tenant_id:
        return {"message": "You are not logged into any tenant"}, 403

    # has_roles takes role names as separate arguments; passing a list made
    # the Super User check always fail.
    if current_user.has_roles('Super User') or current_user.tenant_id == tenant_id:
        Database(tenant_id).switch_schema()
    else:
        return {"message": "You are not a member of this tenant"}, 403

View File

@@ -1,12 +0,0 @@
from flask import session
from ..models.user import User, Tenant
# Definition of Trigger Handlers
def set_tenant_session_data(sender, user, **kwargs):
    """Store the tenant of ``user`` and its defaults in the session.

    Args:
        sender: Signal sender; unused.
        user: Object whose ``tenant_id`` selects the tenant to load.
        **kwargs: Extra signal arguments; unused.
    """
    tenant = Tenant.query.filter_by(id=user.tenant_id).first()
    session['tenant'] = tenant.to_dict()

    # Copy the tenant defaults under session keys of the same name.
    default_keys = (
        'default_language',
        'default_embedding_model',
        'default_llm_model',
    )
    for key in default_keys:
        session[key] = getattr(tenant, key)

View File

@@ -1,15 +1,17 @@
import os
from datetime import datetime as dt, timezone as tz
from flask import request, redirect, url_for, flash, render_template, Blueprint, session, current_app
from flask_security import hash_password, roles_required, roles_accepted, current_user
from flask_security import roles_accepted, current_user
from sqlalchemy import desc
from sqlalchemy.orm import joinedload
from werkzeug.utils import secure_filename
from ..models.document import Document, DocumentLanguage, DocumentVersion
from ..extensions import db
from common.models import Document, DocumentLanguage, DocumentVersion
from common.extensions import db
from .document_forms import AddDocumentForm
from ..utils.middleware import mw_before_request
from common.utils.middleware import mw_before_request
from eveai_workers.tasks import create_embeddings
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
@@ -68,11 +70,17 @@ def add_document():
if error is None:
flash('Document added successfully.', 'success')
upload_file_for_version(new_doc_vers, file, extension)
create_embeddings.delay(tenant_id=session['tenant']['id'],
document_version_id=new_doc_vers.id,
default_embedding_model=session['default_embedding_model'])
current_app.logger.info(f'Document processing started for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}')
print('Processing should start soon')
else:
flash('Error adding document.', 'error')
current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {error}')
# return render_template('document/add_document.html', form=form)
return render_template('document/add_document.html', form=form)
@document_bp.route('/documents', methods=['GET', 'POST'])

View File

@@ -3,7 +3,7 @@ from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
from ..models.user import User, Role
from common.models import Role
class TenantForm(FlaskForm):

View File

@@ -4,10 +4,10 @@ from datetime import datetime as dt, timezone as tz
from flask import request, redirect, url_for, flash, render_template, Blueprint, session, current_app
from flask_security import hash_password, roles_required, roles_accepted
from ..models.user import User, Tenant, Role
from ..extensions import db
from common.models import User, Tenant, Role
from common.extensions import db
from .user_forms import TenantForm, CreateUserForm, EditUserForm
from ..utils.database import Database
from common.utils.database import Database
user_bp = Blueprint('user_bp', __name__, url_prefix='/user')

View File

@@ -1,9 +0,0 @@
def init_celery(celery, app):
    """Configure ``celery`` from ``app`` and run its tasks in an app context.

    Args:
        celery: Celery instance; its ``conf`` is updated and its ``Task``
            attribute is replaced.
        app: Flask application that provides ``config`` and ``app_context()``.
    """
    # Load all configurations from the Flask app, including queue settings.
    celery.conf.update(app.config)

    base_task = celery.Task

    # Subclass of the previous base task that wraps each call in an
    # application context of ``app``.
    class ContextTask(base_task):
        def __call__(self, *args, **kwargs):
            with app.app_context():
                return self.run(*args, **kwargs)

    celery.Task = ContextTask

View File

@@ -1,59 +0,0 @@
from datetime import datetime as dt, timezone as tz
from flask import current_app
from langchain_mistralai import MistralAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_text_splitters import CharacterTextSplitter
import os
from eveai_app import celery
from ..utils.database import Database
from ..models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
from .. import db
@celery.task(name='create_embeddings', queue='embeddings')
def create_embeddings(tenant_id, document_version_id, embedding_model_def):
current_app.logger.info(f'Creating embeddings for tenant {tenant_id} on document version {document_version_id} '
f'with model {embedding_model_def}')
Database(tenant_id).switch_schema()
document_version = DocumentVersion.query.get(document_version_id)
if document_version is None:
current_app.logger.error(f'Cannot create embeddings for tenant {tenant_id}. '
f'Document version {document_version_id} not found')
return
db.session.add(document_version)
# start processing
document_version.processing = True
document_version.processing_started_at = dt.now(tz.utc)
db.session.commit()
embedding_provider = embedding_model_def.rsplit('.', 1)[0]
embedding_model = embedding_model_def.rsplit('.', 1)[1]
# define embedding variables
match (embedding_provider, embedding_model):
case ('openai', 'text-embedding-3-small'):
embedding_model = EmbeddingSmallOpenAI()
case ('mistral', 'text-embedding-3-small'):
embedding_model = EmbeddingMistral()
match document_version.file_type:
case 'pdf':
pdf_file = os.path.join(current_app.config['UPLOAD_FOLDER'],
document_version.file_location,
document_version.file_path)
loader = PyPDFLoader(pdf_file)
# We
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(loader.load())
pass
@celery.task(name='ask_eveAI', queue='llm_interactions')
def ask_eve_ai(query):
    """Placeholder task: accepts ``query`` but currently does nothing and returns ``None``."""
    # Interaction logic with LLMs like GPT (Langchain API calls, etc.)
    pass