202 lines
6.1 KiB
Python
202 lines
6.1 KiB
Python
import traceback
|
|
|
|
from flask import Flask, jsonify, request, redirect
|
|
from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from werkzeug.exceptions import HTTPException
|
|
|
|
from common.extensions import db, api_rest, jwt, minio_client, simple_encryption, cors
|
|
import os
|
|
import logging.config
|
|
|
|
from common.models.user import TenantDomain
|
|
from common.utils.cors_utils import get_allowed_origins
|
|
from common.utils.database import Database
|
|
from config.logging_config import LOGGING
|
|
from .api.document_api import document_ns
|
|
from .api.auth import auth_ns
|
|
from config.config import get_config
|
|
from common.utils.celery_utils import make_celery, init_celery
|
|
from common.utils.eveai_exceptions import EveAIException
|
|
from common.utils.debug_utils import register_request_debugger
|
|
|
|
|
|
def create_app(config_file=None):
    """Build and configure the eveai_api Flask application.

    The configuration object is selected from the ``FLASK_ENV`` environment
    variable: ``'production'`` uses ``get_config('prod')``; any other value
    (including the default ``'development'``) uses ``get_config('dev')``.

    Args:
        config_file: Currently unused; kept for backward compatibility with
            existing callers.

    Returns:
        The configured Flask application, with a Celery instance attached as
        ``app.celery``.
    """
    app = Flask(__name__)

    environment = os.getenv('FLASK_ENV', 'development')
    # 'development' and any unrecognised environment both fall back to dev.
    config_name = 'prod' if environment == 'production' else 'dev'
    app.config.from_object(get_config(config_name))

    app.config['SESSION_KEY_PREFIX'] = 'eveai_api_'

    app.celery = make_celery(app.name, app.config)
    init_celery(app.celery, app)

    logging.config.dictConfig(LOGGING)
    logger = logging.getLogger(__name__)
    logger.info("eveai_api starting up")

    # Register Necessary Extensions
    register_extensions(app)
    # register Namespaces
    register_namespaces(api_rest)
    # Register Blueprints
    register_blueprints(app)
    # Register Error Handlers
    register_error_handlers(app)
    # Register Request Debugger
    register_request_debugger(app)

    @app.before_request
    def check_cors():
        """Reject cross-origin requests whose Origin is not allowed for the tenant.

        Returns ``('', 200)`` for OPTIONS requests, a 403 JSON body when the
        Origin is not in the tenant's allowed origins, or ``None`` to let the
        request continue (no Origin header, or no tenant identity in a JWT).
        """
        if request.method == 'OPTIONS':
            app.logger.debug("Handling OPTIONS request")
            return '', 200  # Allow OPTIONS to pass through

        origin = request.headers.get('Origin')
        if not origin:
            return  # Not a CORS request

        # Get tenant ID from request. optional=True so that requests carrying no
        # JWT at all (e.g. the token endpoint itself) are not rejected here;
        # without it verify_jwt_in_request() raises when the token is missing.
        # optional=True only tolerates a *missing* token - a present but invalid
        # token is still rejected by flask_jwt_extended, as before.
        verify_jwt_in_request(optional=True)
        tenant_id = get_jwt_identity()
        if not tenant_id:
            return

        # Check if origin is allowed for this tenant
        allowed_origins = get_allowed_origins(tenant_id)
        if origin not in allowed_origins:
            app.logger.warning(f'Origin {origin} not allowed for tenant {tenant_id}')
            return {'error': 'Origin not allowed'}, 403

    @app.before_request
    def set_tenant_schema():
        """Switch the database to the schema of the tenant identified by the JWT.

        Health-check paths are skipped. Any failure is logged and swallowed so
        the request continues; endpoints do their own error handling.
        """
        # Health checks carry no tenant context.
        if request.path.startswith(('/_healthz', '/healthz')):
            return

        try:
            verify_jwt_in_request(optional=True)
            tenant_id = get_jwt_identity()
            if tenant_id:
                Database(tenant_id).switch_schema()
        except Exception as e:
            app.logger.error(f'Error in before_request: {str(e)}')
            # Don't raise the exception here, let the request continue.
            # The appropriate error handling will be done in the specific endpoints.
            # NOTE(review): if switch_schema() fails the request proceeds on
            # whatever schema is current - confirm that is acceptable.

    @app.route('/api/v1')
    def swagger():
        """Redirect the slash-less API root to ``/api/v1/``."""
        return redirect('/api/v1/')

    return app
|
|
|
|
|
|
def register_extensions(app):
    """Bind the shared extension instances from ``common.extensions`` to *app*.

    Initialises, in order: the database, the REST API (``api_rest``), JWT,
    the MinIO client, simple encryption, and CORS for ``/api/v1/*``.
    """
    db.init_app(app)
    # NOTE(review): flask-restx's Api.init_app() appears to read only the
    # documentation kwargs (title, description, contact, license, ...);
    # version/doc/prefix are Api() constructor arguments and may be ignored
    # here - confirm against how api_rest is constructed in common.extensions.
    api_rest.init_app(app,
                      title='EveAI API',
                      version='1.0',
                      description='EveAI API',
                      doc='/api/v1/',
                      prefix='/api/v1')
    jwt.init_app(app)
    minio_client.init_app(app)
    simple_encryption.init_app(app)
    cors.init_app(app, resources={
        r"/api/v1/*": {
            # NOTE(review): with supports_credentials=True, Flask-CORS reflects the
            # request's Origin for "*" rather than sending a literal wildcard, so
            # any origin is accepted with credentials at this layer; per-tenant
            # origin filtering is done by the check_cors hook in create_app.
            # Confirm against the installed Flask-CORS version.
            "origins": "*",
            "methods": ["GET", "POST", "PUT", "OPTIONS"],
            "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
            "expose_headers": ["Content-Length", "Content-Range"],
            # Flask-CORS's credentials option; the former "allow_credentials"
            # key is not a Flask-CORS option and was ignored.
            "supports_credentials": True,
            "max_age": 1728000,  # 20 days
        }
    })
|
|
|
|
|
|
def register_namespaces(app):
    """Attach the document and auth namespaces to the shared ``api_rest`` API.

    The *app* argument is not used: namespaces are always added to the
    module-level ``api_rest`` instance, regardless of what is passed in.
    """
    namespace_routes = (
        (document_ns, '/api/v1/documents'),
        (auth_ns, '/api/v1/auth'),
    )
    for namespace, route in namespace_routes:
        api_rest.add_namespace(namespace, path=route)
|
|
|
|
|
|
def register_blueprints(app):
    """Register this application's Flask blueprints (currently the health checks)."""
    # Imported at registration time rather than at module load
    # (presumably to avoid an import cycle - confirm).
    from .views.healthz_views import healthz_bp as health_blueprint

    app.register_blueprint(health_blueprint)
|
|
|
|
|
|
def register_error_handlers(app):
    """Register JSON error handlers on *app*.

    Installs a catch-all ``Exception`` handler plus dedicated 404 and 400
    handlers. Each handler returns a JSON body with ``error``, ``message`` and
    ``type`` keys and the matching HTTP status code.
    """

    @app.errorhandler(Exception)
    def handle_exception(e):
        """Handle all unhandled exceptions with detailed error responses.

        Maps HTTP exceptions to their own status code, ``ValueError`` to 400,
        and everything else (including ``SQLAlchemyError``) to 500. In debug
        mode the formatted traceback is included under ``"debug"``.
        """
        # Format the traceback of *e* itself instead of relying on the ambient
        # sys.exc_info(), which is only populated inside an active except block.
        exc_info = ''.join(traceback.format_exception(e))

        # Start with a default error response
        response = {
            "error": "Internal Server Error",
            "message": str(e),
            "type": e.__class__.__name__
        }
        status_code = 500

        if isinstance(e, HTTPException):
            # A bare HTTPException can have code=None; returning None as the
            # status would make Flask send 200, so fall back to 500.
            status_code = e.code or 500
            response["error"] = e.name
            # HTTP errors (401, 405, ...) are expected outcomes, not crashes:
            # log them briefly rather than as unhandled exceptions.
            app.logger.warning(f"HTTP exception {status_code}: {str(e)}")
        else:
            # Log the full exception details
            app.logger.error(f"Unhandled exception: {str(e)}\n{exc_info}")

            if isinstance(e, SQLAlchemyError):
                response["error"] = "Database Error"
                # NOTE(review): str() of SQLAlchemy errors typically embeds the
                # SQL statement and bound parameters, and both "message" and
                # "details" are returned to the client even outside debug mode.
                response["details"] = str(e.__cause__ or e)
            elif isinstance(e, ValueError):
                status_code = 400
                response["error"] = "Invalid Input"

        # In development, include additional debug information
        if app.debug:
            response["debug"] = {
                "exception": exc_info,
                "class": e.__class__.__name__,
                "module": e.__class__.__module__
            }

        return jsonify(response), status_code

    @app.errorhandler(404)
    def not_found_error(e):
        """Return a JSON 404 response."""
        return jsonify({
            "error": "Not Found",
            "message": str(e),
            "type": "NotFoundError"
        }), 404

    @app.errorhandler(400)
    def bad_request_error(e):
        """Return a JSON 400 response."""
        return jsonify({
            "error": "Bad Request",
            "message": str(e),
            "type": "BadRequestError"
        }), 400
|