Files
eveAI/eveai_chat/__init__.py
2024-11-29 11:24:32 +01:00

103 lines
3.5 KiB
Python

import logging
import logging.config
from flask import Flask, jsonify, request
import os
from flask_jwt_extended import verify_jwt_in_request, get_jwt_identity
from common.extensions import db, socketio, jwt, cors, session, simple_encryption, metrics
from config.logging_config import LOGGING
from eveai_chat.socket_handlers import chat_handler
from common.utils.cors_utils import create_cors_after_request, get_allowed_origins
from common.utils.celery_utils import make_celery, init_celery
from config.config import get_config
def create_app(config_file=None):
    """Build and configure the EveAI chat Flask application.

    The configuration is chosen from the ``FLASK_ENV`` environment variable
    (default ``'development'``): ``'production'`` loads ``get_config('prod')``,
    any other value loads ``get_config('dev')``. Logging is configured from
    ``LOGGING``, the extensions are registered, a Celery instance is attached
    as ``app.celery``, and a ``before_request`` hook enforces per-tenant
    allowed origins for cross-origin requests.

    Args:
        config_file: Currently unused; kept so existing callers keep working.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    environment = os.getenv('FLASK_ENV', 'development')
    # Only an explicit 'production' selects the production config; every other
    # value (including unknown ones) falls back to the development config.
    config_name = 'prod' if environment == 'production' else 'dev'
    app.config.from_object(get_config(config_name))

    app.config['SESSION_KEY_PREFIX'] = 'eveai_chat_'

    logging.config.dictConfig(LOGGING)

    register_extensions(app)
    # NOTE(review): register_blueprints() is defined in this module but is not
    # called here, so its blueprint is not registered by create_app - confirm
    # whether that is intentional.

    app.celery = make_celery(app.name, app.config)
    init_celery(app.celery, app)

    @app.before_request
    def check_cors():
        """Reject cross-origin requests whose Origin is not allowed for the tenant.

        Returns ``None`` to let the request continue, or a ``(body, status)``
        pair that short-circuits the request.
        """
        app.logger.debug('Checking CORS')

        if request.method == 'OPTIONS':
            app.logger.debug("Handling OPTIONS request")
            return '', 200  # Allow OPTIONS to pass through

        origin = request.headers.get('Origin')
        if not origin:
            return None  # Not a CORS request

        # optional=True makes a request *without* a JWT yield None instead of
        # raising NoAuthorizationError, so unauthenticated requests are passed
        # through (the intent of the original else-branch, which was
        # unreachable without it). An invalid or expired JWT still raises.
        if verify_jwt_in_request(optional=True) is None:
            return None

        # Get tenant ID from request
        tenant_id = get_jwt_identity()
        if not tenant_id:
            return None

        # Check if origin is allowed for this tenant
        allowed_origins = get_allowed_origins(tenant_id)
        if origin not in allowed_origins:
            app.logger.warning(f'Origin {origin} not allowed for tenant {tenant_id}')
            return {'error': 'Origin not allowed'}, 403
        return None

    app.logger.info("EveAI Chat Server Started Successfully")
    app.logger.info("-------------------------------------------------------------------------------------------------")
    return app
def register_extensions(app):
    """Initialise every Flask extension used by the chat service on *app*.

    Socket.IO settings are read from ``app.config`` (``SOCKETIO_*`` keys) and
    the Socket.IO endpoint is served under ``/socket.io/``. JWT, encryption,
    metrics, CORS and server-side sessions are then bound to the app.

    Args:
        app: The Flask application being configured.
    """
    db.init_app(app)

    socketio.init_app(
        app,
        message_queue=app.config.get('SOCKETIO_MESSAGE_QUEUE'),
        cors_allowed_origins=app.config.get('SOCKETIO_CORS_ALLOWED_ORIGINS'),
        async_mode=app.config.get('SOCKETIO_ASYNC_MODE'),
        logger=app.config.get('SOCKETIO_LOGGER'),
        engineio_logger=app.config.get('SOCKETIO_ENGINEIO_LOGGER'),
        path='/socket.io/',
        ping_timeout=app.config.get('SOCKETIO_PING_TIMEOUT'),
        ping_interval=app.config.get('SOCKETIO_PING_INTERVAL'),
    )

    jwt.init_app(app)
    simple_encryption.init_app(app)
    metrics.init_app(app)

    # Cors setup.
    # NOTE(review): with origins="*" and supports_credentials=True, Flask-CORS
    # echoes back whatever Origin the browser sends, so any site may make
    # credentialed requests at the CORS layer; the per-tenant origin check
    # registered in create_app is what narrows this down - confirm that is
    # the intended design.
    # The former "allow_credentials" key was removed: it is not a Flask-CORS
    # option and was silently ignored; "supports_credentials" is the real one.
    cors.init_app(app, resources={
        r"/*": {  # Make sure this matches your setup
            "origins": "*",
            "methods": ["GET", "POST", "PUT", "OPTIONS"],
            "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
            "expose_headers": ["Content-Length", "Content-Range"],
            "supports_credentials": True,
            "max_age": 1728000,  # preflight cache lifetime in seconds (20 days)
        }
    })

    session.init_app(app)
def register_blueprints(app):
    """Attach the health-check blueprint to the given Flask *app*.

    The blueprint module is imported lazily, at call time, rather than at
    module import.
    """
    from views.healthz_views import healthz_bp as health_blueprint

    app.register_blueprint(health_blueprint)