Refactoring part 1
Some changes for workers, but stopped due to refactoring
This commit is contained in:
@@ -5,14 +5,13 @@ from flask_security import SQLAlchemyUserDatastore
|
||||
from flask_security.signals import user_authenticated
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
import logging.config
|
||||
from celery import Celery
|
||||
|
||||
from .extensions import db, migrate, bootstrap, security, mail, login_manager, cors
|
||||
from .models.user import User, Tenant, Role
|
||||
from .models.document import Document, DocumentLanguage, DocumentVersion
|
||||
from .logging_config import LOGGING
|
||||
from .utils.security import set_tenant_session_data
|
||||
from .worker.celery_utils import init_celery
|
||||
from common.extensions import db, migrate, bootstrap, security, mail, login_manager, cors
|
||||
from common.models.user import User, Role
|
||||
from config.logging_config import LOGGING
|
||||
from common.utils.security import set_tenant_session_data
|
||||
from .errors import register_error_handlers
|
||||
from eveai_workers.celery_utils import init_celery
|
||||
|
||||
|
||||
def create_app(config_file=None):
|
||||
@@ -20,7 +19,7 @@ def create_app(config_file=None):
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1)
|
||||
|
||||
if config_file is None:
|
||||
app.config.from_object('config.DevConfig')
|
||||
app.config.from_object('config.config.DevConfig')
|
||||
else:
|
||||
app.config.from_object(config_file)
|
||||
|
||||
@@ -31,6 +30,10 @@ def create_app(config_file=None):
|
||||
|
||||
logging.config.dictConfig(LOGGING)
|
||||
register_extensions(app)
|
||||
|
||||
# Initialize celery
|
||||
init_celery(app)
|
||||
|
||||
# Setup Flask-Security-Too
|
||||
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
|
||||
security.init_app(app, user_datastore)
|
||||
@@ -39,6 +42,10 @@ def create_app(config_file=None):
|
||||
# Register Blueprints
|
||||
register_blueprints(app)
|
||||
|
||||
# Register Error Handlers
|
||||
register_error_handlers(app)
|
||||
|
||||
# Debugging settings
|
||||
if app.config['DEBUG'] is True:
|
||||
app.logger.setLevel(logging.DEBUG)
|
||||
mail_logger = logging.getLogger('flask_mailman')
|
||||
@@ -79,17 +86,3 @@ def register_api(app):
|
||||
# from . import api
|
||||
# app.register_blueprint(api.bp, url_prefix='/api')
|
||||
|
||||
|
||||
def create_celery_app(config_file=None):
|
||||
app = Flask(__name__)
|
||||
if config_file is None:
|
||||
app.config.from_object('config.DevConfig')
|
||||
else:
|
||||
app.config.from_object(config_file)
|
||||
|
||||
celery = Celery(app.import_name)
|
||||
init_celery(celery, app)
|
||||
return celery
|
||||
|
||||
|
||||
celery = create_celery_app()
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# from flask import Blueprint, request
|
||||
#
|
||||
# public_api_bp = Blueprint("public", __name__, url_prefix="/api/v1")
|
||||
# tenant_api_bp = Blueprint("tenant", __name__, url_prefix="/api/v1/tenant")
|
||||
@@ -1,7 +0,0 @@
|
||||
from flask import request
|
||||
from flask.views import MethodView
|
||||
|
||||
class RegisterAPI(MethodView):
|
||||
def post(self):
|
||||
username = request.json['username']
|
||||
|
||||
22
eveai_app/errors.py
Normal file
22
eveai_app/errors.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from flask import render_template, request, jsonify
|
||||
|
||||
|
||||
def not_found_error(error):
|
||||
if request.accept_mimetypes.accept_json and not request.accept_mimetypes.accept_html:
|
||||
response = jsonify({'error': 'Not found'})
|
||||
response.status_code = 404
|
||||
return response
|
||||
return render_template('error/404.html'), 404
|
||||
|
||||
|
||||
def internal_server_error(error):
|
||||
if request.accept_mimetypes.accept_json and not request.accept_mimetypes.accept_html:
|
||||
response = jsonify({'error': 'Internal server error'})
|
||||
response.status_code = 500
|
||||
return response
|
||||
return render_template('error/500.html'), 500
|
||||
|
||||
|
||||
def register_error_handlers(app):
|
||||
app.register_error_handler(404, not_found_error)
|
||||
app.register_error_handler(500, internal_server_error)
|
||||
@@ -1,17 +0,0 @@
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_bootstrap import Bootstrap
|
||||
from flask_security import Security
|
||||
from flask_mailman import Mail
|
||||
from flask_login import LoginManager
|
||||
from flask_cors import CORS
|
||||
|
||||
|
||||
# Create extensions
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
bootstrap = Bootstrap()
|
||||
security = Security()
|
||||
mail = Mail()
|
||||
login_manager = LoginManager()
|
||||
cors = CORS()
|
||||
@@ -1,31 +0,0 @@
|
||||
LOGGING = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'handlers': {
|
||||
'file': {
|
||||
'level': 'DEBUG',
|
||||
'class': 'logging.handlers.RotatingFileHandler',
|
||||
'filename': 'app.log',
|
||||
'maxBytes': 1024*1024*5, # 5MB
|
||||
'backupCount': 10,
|
||||
'formatter': 'standard',
|
||||
},
|
||||
'console': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': 'DEBUG',
|
||||
'formatter': 'standard',
|
||||
},
|
||||
},
|
||||
'formatters': {
|
||||
'standard': {
|
||||
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
|
||||
},
|
||||
},
|
||||
'loggers': {
|
||||
'': { # root logger
|
||||
'handlers': ['file', 'console'],
|
||||
'level': 'DEBUG',
|
||||
'propagate': True
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
|
||||
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
from ..extensions import db
|
||||
from .user import User, Tenant
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
|
||||
class Document(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
name = db.Column(db.String(100), nullable=False)
|
||||
tenant_id = db.Column(db.Integer, db.ForeignKey(Tenant.id), nullable=False)
|
||||
valid_from = db.Column(db.DateTime, nullable=True)
|
||||
valid_to = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
# Relations
|
||||
languages = db.relationship('DocumentLanguage', backref='document', lazy=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Document {self.id}: {self.name}>"
|
||||
|
||||
|
||||
class DocumentLanguage(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
document_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False)
|
||||
language = db.Column(db.String(2), nullable=False)
|
||||
latest_version_id = db.Column(db.Integer, db.ForeignKey('document_version.id'), nullable=True)
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
# Relations
|
||||
versions = db.relationship(
|
||||
'DocumentVersion',
|
||||
backref='document_language',
|
||||
lazy='joined',
|
||||
foreign_keys='DocumentVersion.doc_lang_id'
|
||||
)
|
||||
latest_version = db.relationship(
|
||||
'DocumentVersion',
|
||||
uselist=False,
|
||||
foreign_keys=[latest_version_id]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DocumentLanguage {self.document_id}.{self.language}>"
|
||||
|
||||
|
||||
class DocumentVersion(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
doc_lang_id = db.Column(db.Integer, db.ForeignKey(DocumentLanguage.id), nullable=False)
|
||||
url = db.Column(db.String(200), nullable=True)
|
||||
file_location = db.Column(db.String(255), nullable=True)
|
||||
file_name = db.Column(db.String(200), nullable=True)
|
||||
file_type = db.Column(db.String(20), nullable=True)
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
created_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
# Processing Information
|
||||
processing = db.Column(db.Boolean, nullable=False, default=False)
|
||||
processing_started_at = db.Column(db.DateTime, nullable=True)
|
||||
processing_finished_at = db.Column(db.DateTime, nullable=True)
|
||||
processing_error = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relations
|
||||
embeddings = db.relationship('EmbeddingMistral', backref='document_version', lazy=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DocumentVersion {self.document_language.document_id}.{self.document_language.language}>.{self.id}>"
|
||||
|
||||
def calc_file_location(self):
|
||||
return f"{self.document_language.document.tenant_id}/{self.document_language.document.id}/{self.document_language.language}"
|
||||
|
||||
def calc_file_name(self):
|
||||
return f"{self.id}.{self.file_type}"
|
||||
|
||||
|
||||
class EmbeddingMistral(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
doc_vers_id = db.Column(db.Integer, db.ForeignKey(DocumentVersion.id), nullable=False)
|
||||
active = db.Column(db.Boolean, nullable=False, default=True)
|
||||
|
||||
# 1024 is the MISTRAL Embedding dimension.
|
||||
# If another embedding model is chosen, this dimension may need to be changed.
|
||||
embedding = db.Column(Vector(1024), nullable=False)
|
||||
|
||||
|
||||
class EmbeddingSmallOpenAI(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
doc_vers_id = db.Column(db.Integer, db.ForeignKey(DocumentVersion.id), nullable=False)
|
||||
active = db.Column(db.Boolean, nullable=False, default=True)
|
||||
|
||||
# 1536 is the OpenAI Small Embedding dimension.
|
||||
# If another embedding model is chosen, this dimension may need to be changed.
|
||||
embedding = db.Column(Vector(1536), nullable=False)
|
||||
@@ -1,115 +0,0 @@
|
||||
from ..extensions import db
|
||||
from flask_security import UserMixin, RoleMixin
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class Tenant(db.Model):
|
||||
"""Tenant model"""
|
||||
|
||||
__bind_key__ = 'public'
|
||||
__table_args__ = {'schema': 'public'}
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
|
||||
# company Information
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
name = db.Column(db.String(80), unique=True, nullable=False)
|
||||
website = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# language information
|
||||
default_language = db.Column(db.String(2), nullable=True)
|
||||
allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
|
||||
|
||||
# LLM specific choices
|
||||
default_embedding_model = db.Column(db.String(50), nullable=True)
|
||||
allowed_embedding_models = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
default_llm_model = db.Column(db.String(50), nullable=True)
|
||||
allowed_llm_models = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
|
||||
# Licensing Information
|
||||
license_start_date = db.Column(db.Date, nullable=True)
|
||||
license_end_date = db.Column(db.Date, nullable=True)
|
||||
allowed_monthly_interactions = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Relations
|
||||
users = db.relationship('User', backref='tenant')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tenant {self.id}: {self.name}>"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'website': self.website,
|
||||
'default_language': self.default_language,
|
||||
'allowed_languages': self.allowed_languages,
|
||||
'default_embedding_model': self.default_embedding_model,
|
||||
'allowed_embedding_models': self.allowed_embedding_models,
|
||||
'default_llm_model': self.default_llm_model,
|
||||
'allowed_llm_models': self.allowed_llm_models,
|
||||
'license_start_date': self.license_start_date,
|
||||
'license_end_date': self.license_end_date,
|
||||
'allowed_monthly_interactions': self.allowed_monthly_interactions
|
||||
}
|
||||
|
||||
|
||||
class Role(db.Model, RoleMixin):
|
||||
__bind_key__ = 'public'
|
||||
__table_args__ = {'schema': 'public'}
|
||||
|
||||
id = db.Column(db.Integer(), primary_key=True)
|
||||
name = db.Column(db.String(80), unique=True)
|
||||
description = db.Column(db.String(255))
|
||||
|
||||
|
||||
class RolesUsers(db.Model):
|
||||
__bind_key__ = 'public'
|
||||
__table_args__ = {'schema': 'public'}
|
||||
|
||||
user_id = db.Column(db.Integer(), db.ForeignKey('public.user.id', ondelete='CASCADE'), primary_key=True)
|
||||
role_id = db.Column(db.Integer(), db.ForeignKey('public.role.id', ondelete='CASCADE'), primary_key=True)
|
||||
|
||||
|
||||
class User(db.Model, UserMixin):
|
||||
"""User model"""
|
||||
|
||||
__bind_key__ = 'public'
|
||||
__table_args__ = {'schema': 'public'}
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
|
||||
# User Information
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_name = db.Column(db.String(80), unique=True, nullable=False)
|
||||
email = db.Column(db.String(255), unique=True, nullable=False)
|
||||
password = db.Column(db.String(255), nullable=False)
|
||||
first_name = db.Column(db.String(80), nullable=False)
|
||||
last_name = db.Column(db.String(80), nullable=False)
|
||||
is_active = db.Column(db.Boolean, default=True)
|
||||
active = db.Column(db.Boolean)
|
||||
fs_uniquifier = db.Column(db.String(255), unique=True, nullable=False)
|
||||
confirmed_at = db.Column(db.DateTime, nullable=True)
|
||||
valid_to = db.Column(db.Date, nullable=True)
|
||||
|
||||
# Security Trackable Information
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
current_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(255), nullable=True)
|
||||
current_login_ip = db.Column(db.String(255), nullable=True)
|
||||
login_count = db.Column(db.Integer, nullable=False, default=0)
|
||||
|
||||
# Relations
|
||||
roles = db.relationship('Role', secondary=RolesUsers.__table__, backref=db.backref('users', lazy='dynamic'))
|
||||
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return '<User %r>' % self.user_name
|
||||
|
||||
def has_roles(self, *args):
|
||||
return any(role.name in args for role in self.roles)
|
||||
@@ -20,7 +20,7 @@
|
||||
<span>
|
||||
<div class="container mb-4">
|
||||
<div class="row mt-lg-n12 mt-md-n12 mt-n12 justify-content-center">
|
||||
<div class="col-xl-8 col-lg-5 col-md-7 mx-auto">
|
||||
{% block content_class %}<div class="col-xl-8 col-lg-5 col-md-7 mx-auto">{% endblock %}
|
||||
<div class="card mt-8">
|
||||
<div class="card-header p-0 position-relative mt-n4 mx-3 z-index-2">
|
||||
<div class="bg-gradient-success shadow-success border-radius-lg py-3 pe-1 text-center py-4">
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
{% block content_title %}Documents{% endblock %}
|
||||
{% block content_description %}View Documents for Tenant{% endblock %}
|
||||
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto">{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="container">
|
||||
|
||||
9
eveai_app/templates/error/404.html
Normal file
9
eveai_app/templates/error/404.html
Normal file
@@ -0,0 +1,9 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Error 404{% endblock %}
|
||||
|
||||
{% block content_title %}File not Found{% endblock %}
|
||||
{% block content_description %}Something unexpected happened!{% endblock %}
|
||||
{% block content %}
|
||||
<p><a href="{{ url_for('basic_bp.index') }}">Return home</a></p>
|
||||
{% endblock %}
|
||||
9
eveai_app/templates/error/500.html
Normal file
9
eveai_app/templates/error/500.html
Normal file
@@ -0,0 +1,9 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Error 500{% endblock %}
|
||||
|
||||
{% block content_title %}Internal Server error{% endblock %}
|
||||
{% block content_description %}Something unexpected happened! The administrator has been notified.{% endblock %}
|
||||
{% block content %}
|
||||
<p><a href="{{ url_for('basic_bp.index') }}">Return home</a></p>
|
||||
{% endblock %}
|
||||
@@ -1,12 +0,0 @@
|
||||
from flask import session
|
||||
|
||||
|
||||
def make_session_permanent():
|
||||
session.permanent = True # Refresh the session timeout on every request
|
||||
# if 'user_id' in session and session.get('was_authenticated'):
|
||||
# if session.modified: # Check if the session was modified
|
||||
# session['was_authenticated'] = True
|
||||
# else:
|
||||
# session.pop('user_id', None) # Clear session
|
||||
# session.pop('was_authenticated', None)
|
||||
# return redirect(url_for('login')) # Redirect to login page if session expired
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Database related functions"""
|
||||
from os import popen
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
from sqlalchemy.exc import InternalError
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from flask import current_app
|
||||
|
||||
from ..extensions import db, migrate
|
||||
|
||||
|
||||
class Database:
|
||||
"""used for managing tenant databases related operations"""
|
||||
|
||||
def __init__(self, tenant: str) -> None:
|
||||
self.schema = str(tenant)
|
||||
|
||||
def get_engine(self):
|
||||
"""create new schema engine"""
|
||||
return db.engine.execution_options(
|
||||
schema_translate_map={None: self.schema}
|
||||
)
|
||||
|
||||
def get_session(self):
|
||||
"""To get session of tenant/public database schema for quick use
|
||||
|
||||
returns:
|
||||
session: session of tenant/public database schema
|
||||
"""
|
||||
return scoped_session(
|
||||
sessionmaker(bind=self.get_engine(), expire_on_commit=True)
|
||||
)
|
||||
|
||||
def create_schema(self):
|
||||
"""create new database schema, mostly used on tenant creation"""
|
||||
try:
|
||||
db.session.execute(CreateSchema(self.schema))
|
||||
db.session.execute(text(f"CREATE EXTENSION IF NOT EXISTS pgvector SCHEMA {self.schema}"))
|
||||
db.session.commit()
|
||||
except InternalError as e:
|
||||
db.session.rollback()
|
||||
db.session.close()
|
||||
current_app.logger.error(f"Error creating schema {self.schema}: {e.args}")
|
||||
|
||||
def create_tables(self):
|
||||
"""create tables in for schema"""
|
||||
db.metadata.create_all(self.get_engine())
|
||||
|
||||
def switch_schema(self):
|
||||
"""switch between tenant/public database schema"""
|
||||
db.session.execute(text(f'set search_path to "{self.schema}"'))
|
||||
db.session.commit()
|
||||
|
||||
def migrate_tenant_schema(self):
|
||||
"""migrate tenant database schema for new tenant"""
|
||||
# Get the current revision for a database.
|
||||
# NOTE: using popen may have a minor performance impact on the application
|
||||
# you can store it in a different table in public schema and use it from there
|
||||
# may be a faster approach
|
||||
# last_revision = heads(directory="migrations/tenant", verbose=True, resolve_dependencies=False)
|
||||
last_revision = popen(".venv/bin/flask db heads -d migrations/tenant").read()
|
||||
print("LAST REVISION")
|
||||
print(last_revision)
|
||||
last_revision = last_revision.splitlines()[-1].split(" ")[0]
|
||||
|
||||
# creating revision table in tenant schema
|
||||
session = self.get_session()
|
||||
session.execute(
|
||||
text(
|
||||
f'CREATE TABLE "{self.schema}".alembic_version (version_num '
|
||||
"VARCHAR(32) NOT NULL)"
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Insert last revision to alembic_version table
|
||||
session.execute(
|
||||
text(
|
||||
f'INSERT INTO "{self.schema}".alembic_version (version_num) '
|
||||
"VALUES (:version)"
|
||||
),
|
||||
{"version": last_revision},
|
||||
)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def create_tenant_schema(self):
|
||||
"""create tenant used for creating new schema and its tables"""
|
||||
self.create_schema()
|
||||
self.create_tables()
|
||||
self.migrate_tenant_schema()
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Middleware for the API
|
||||
|
||||
for handling tenant requests
|
||||
"""
|
||||
|
||||
from flask_security import current_user
|
||||
from flask import session
|
||||
|
||||
from ..models.user import User, Tenant
|
||||
from .database import Database
|
||||
|
||||
|
||||
def mw_before_request():
|
||||
"""Before request
|
||||
|
||||
switch tenant schema
|
||||
"""
|
||||
|
||||
tenant_id = session['tenant']['id']
|
||||
if not tenant_id:
|
||||
return {"message": "You are not logged into any tenant"}, 403
|
||||
|
||||
# user = User.query.get(current_user.id)
|
||||
if current_user.has_roles(['Super User']) or current_user.tenant_id == tenant_id:
|
||||
Database(tenant_id).switch_schema()
|
||||
else:
|
||||
return {"message": "You are not a member of this tenant"}, 403
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from flask import session
|
||||
from ..models.user import User, Tenant
|
||||
|
||||
|
||||
# Definition of Trigger Handlers
|
||||
def set_tenant_session_data(sender, user, **kwargs):
|
||||
tenant = Tenant.query.filter_by(id=user.tenant_id).first()
|
||||
session['tenant'] = tenant.to_dict()
|
||||
session['default_language'] = tenant.default_language
|
||||
session['default_embedding_model'] = tenant.default_embedding_model
|
||||
session['default_llm_model'] = tenant.default_llm_model
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import os
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import request, redirect, url_for, flash, render_template, Blueprint, session, current_app
|
||||
from flask_security import hash_password, roles_required, roles_accepted, current_user
|
||||
from flask_security import roles_accepted, current_user
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from ..models.document import Document, DocumentLanguage, DocumentVersion
|
||||
from ..extensions import db
|
||||
from common.models import Document, DocumentLanguage, DocumentVersion
|
||||
from common.extensions import db
|
||||
from .document_forms import AddDocumentForm
|
||||
from ..utils.middleware import mw_before_request
|
||||
from common.utils.middleware import mw_before_request
|
||||
from eveai_workers.tasks import create_embeddings
|
||||
|
||||
|
||||
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
|
||||
|
||||
@@ -68,11 +70,17 @@ def add_document():
|
||||
if error is None:
|
||||
flash('Document added successfully.', 'success')
|
||||
upload_file_for_version(new_doc_vers, file, extension)
|
||||
create_embeddings.delay(tenant_id=session['tenant']['id'],
|
||||
document_version_id=new_doc_vers.id,
|
||||
default_embedding_model=session['default_embedding_model'])
|
||||
current_app.logger.info(f'Document processing started for tenant {session["tenant"]["id"]}, '
|
||||
f'Document Version {new_doc_vers.id}')
|
||||
print('Processing should start soon')
|
||||
else:
|
||||
flash('Error adding document.', 'error')
|
||||
current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {error}')
|
||||
|
||||
# return render_template('document/add_document.html', form=form)
|
||||
return render_template('document/add_document.html', form=form)
|
||||
|
||||
|
||||
@document_bp.route('/documents', methods=['GET', 'POST'])
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_wtf import FlaskForm
|
||||
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
||||
SelectField, SelectMultipleField, FieldList, FormField)
|
||||
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
||||
from ..models.user import User, Role
|
||||
from common.models import Role
|
||||
|
||||
|
||||
class TenantForm(FlaskForm):
|
||||
|
||||
@@ -4,10 +4,10 @@ from datetime import datetime as dt, timezone as tz
|
||||
from flask import request, redirect, url_for, flash, render_template, Blueprint, session, current_app
|
||||
from flask_security import hash_password, roles_required, roles_accepted
|
||||
|
||||
from ..models.user import User, Tenant, Role
|
||||
from ..extensions import db
|
||||
from common.models import User, Tenant, Role
|
||||
from common.extensions import db
|
||||
from .user_forms import TenantForm, CreateUserForm, EditUserForm
|
||||
from ..utils.database import Database
|
||||
from common.utils.database import Database
|
||||
|
||||
user_bp = Blueprint('user_bp', __name__, url_prefix='/user')
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
def init_celery(celery, app):
|
||||
celery.conf.update(app.config) # Load all configurations form Flask app including Queue settings
|
||||
|
||||
class ContextTask(celery.Task):
|
||||
def __call__(self, *args, **kwargs):
|
||||
with app.app_context():
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
celery.Task = ContextTask
|
||||
@@ -1,59 +0,0 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import current_app
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.document_loaders.pdf import PyPDFLoader
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
import os
|
||||
|
||||
from eveai_app import celery
|
||||
from ..utils.database import Database
|
||||
from ..models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
|
||||
from .. import db
|
||||
|
||||
|
||||
@celery.task(name='create_embeddings', queue='embeddings')
|
||||
def create_embeddings(tenant_id, document_version_id, embedding_model_def):
|
||||
current_app.logger.info(f'Creating embeddings for tenant {tenant_id} on document version {document_version_id} '
|
||||
f'with model {embedding_model_def}')
|
||||
Database(tenant_id).switch_schema()
|
||||
document_version = DocumentVersion.query.get(document_version_id)
|
||||
if document_version is None:
|
||||
current_app.logger.error(f'Cannot create embeddings for tenant {tenant_id}. '
|
||||
f'Document version {document_version_id} not found')
|
||||
return
|
||||
db.session.add(document_version)
|
||||
|
||||
# start processing
|
||||
document_version.processing = True
|
||||
document_version.processing_started_at = dt.now(tz.utc)
|
||||
db.session.commit()
|
||||
|
||||
embedding_provider = embedding_model_def.rsplit('.', 1)[0]
|
||||
embedding_model = embedding_model_def.rsplit('.', 1)[1]
|
||||
# define embedding variables
|
||||
match (embedding_provider, embedding_model):
|
||||
case ('openai', 'text-embedding-3-small'):
|
||||
embedding_model = EmbeddingSmallOpenAI()
|
||||
case ('mistral', 'text-embedding-3-small'):
|
||||
embedding_model = EmbeddingMistral()
|
||||
|
||||
match document_version.file_type:
|
||||
case 'pdf':
|
||||
pdf_file = os.path.join(current_app.config['UPLOAD_FOLDER'],
|
||||
document_version.file_location,
|
||||
document_version.file_path)
|
||||
loader = PyPDFLoader(pdf_file)
|
||||
|
||||
# We
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
documents = text_splitter.split_documents(loader.load())
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@celery.task(name='ask_eveAI', queue='llm_interactions')
|
||||
def ask_eve_ai(query):
|
||||
# Interaction logic with LLMs like GPT (Langchain API calls, etc.)
|
||||
pass
|
||||
Reference in New Issue
Block a user