103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
import logging
|
|
import logging.config
|
|
from flask import Flask, jsonify, request
|
|
import os
|
|
|
|
from flask_jwt_extended import verify_jwt_in_request, get_jwt_identity
|
|
|
|
from common.extensions import db, socketio, jwt, cors, session, simple_encryption, metrics
|
|
from config.logging_config import LOGGING
|
|
from eveai_chat.socket_handlers import chat_handler
|
|
from common.utils.cors_utils import create_cors_after_request, get_allowed_origins
|
|
from common.utils.celery_utils import make_celery, init_celery
|
|
from config.config import get_config
|
|
|
|
|
|
def create_app(config_file=None):
    """Build and configure the EveAI chat Flask application.

    The configuration profile is picked from the ``FLASK_ENV`` environment
    variable (default ``'development'``): ``'production'`` loads
    ``get_config('prod')``; every other value loads ``get_config('dev')``.

    Args:
        config_file: Currently unused; kept so existing callers that pass it
            keep working.

    Returns:
        The configured Flask app, with extensions registered, a Celery
        instance attached as ``app.celery``, and a ``before_request`` hook
        that enforces per-tenant allowed origins.
    """
    app = Flask(__name__)

    environment = os.getenv('FLASK_ENV', 'development')
    # Only 'production' selects the prod profile; 'development' and any
    # unrecognised value fall back to dev.
    config_name = 'prod' if environment == 'production' else 'dev'
    app.config.from_object(get_config(config_name))

    app.config['SESSION_KEY_PREFIX'] = 'eveai_chat_'

    logging.config.dictConfig(LOGGING)
    register_extensions(app)

    app.celery = make_celery(app.name, app.config)
    init_celery(app.celery, app)

    # NOTE(review): register_blueprints() is defined in this module but is not
    # called here, so its blueprints are not registered — confirm intent.

    @app.before_request
    def check_cors():
        """Reject cross-origin requests whose Origin is not allowed for the caller's tenant.

        Returns None to let the request continue, or a ``(body, status)``
        tuple that short-circuits it:
        - OPTIONS requests get ``('', 200)``.
        - Requests without an ``Origin`` header, without a JWT, or whose JWT
          identity is empty pass through unchecked.
        - Otherwise the Origin must appear in ``get_allowed_origins(tenant_id)``
          or the request is answered with a 403.
        """
        app.logger.debug('Checking CORS')
        if request.method == 'OPTIONS':
            app.logger.debug("Handling OPTIONS request")
            return '', 200  # Allow OPTIONS to pass through

        origin = request.headers.get('Origin')
        if not origin:
            return  # Not a CORS request

        # optional=True makes verify_jwt_in_request return None when no JWT is
        # present instead of raising; without it the "no JWT -> skip" path
        # below was unreachable and such requests failed inside this hook.
        if not verify_jwt_in_request(optional=True):
            return
        tenant_id = get_jwt_identity()
        if not tenant_id:
            return

        # A missing origin list is treated as empty, i.e. the request is
        # rejected rather than crashing on `in None`.
        allowed_origins = get_allowed_origins(tenant_id) or ()
        if origin not in allowed_origins:
            app.logger.warning(f'Origin {origin} not allowed for tenant {tenant_id}')
            return {'error': 'Origin not allowed'}, 403

    app.logger.info("EveAI Chat Server Started Successfully")
    app.logger.info("-------------------------------------------------------------------------------------------------")
    return app
|
|
|
|
|
|
def register_extensions(app):
    """Initialise the shared Flask extensions on *app*.

    Socket.IO options are read from the ``SOCKETIO_*`` keys of ``app.config``.
    Missing keys are passed through as ``None``. The Socket.IO path is fixed
    to ``/socket.io/``.

    Args:
        app: The Flask application to bind the extensions to.
    """
    db.init_app(app)
    socketio.init_app(
        app,
        message_queue=app.config.get('SOCKETIO_MESSAGE_QUEUE'),
        cors_allowed_origins=app.config.get('SOCKETIO_CORS_ALLOWED_ORIGINS'),
        async_mode=app.config.get('SOCKETIO_ASYNC_MODE'),
        logger=app.config.get('SOCKETIO_LOGGER'),
        engineio_logger=app.config.get('SOCKETIO_ENGINEIO_LOGGER'),
        path='/socket.io/',
        ping_timeout=app.config.get('SOCKETIO_PING_TIMEOUT'),
        ping_interval=app.config.get('SOCKETIO_PING_INTERVAL'),
    )
    jwt.init_app(app)
    simple_encryption.init_app(app)
    metrics.init_app(app)

    # CORS setup. This is deliberately permissive: origins "*" combined with
    # credentials. Presumably Flask-CORS echoes the request Origin here
    # (send_wildcard defaults to False) — verify. Per-tenant origin
    # restriction has to be enforced separately.
    # The stray "allow_credentials" key was removed: it is not a Flask-CORS
    # option; "supports_credentials" is the real one.
    cors.init_app(app, resources={
        r"/*": {
            "origins": "*",
            "methods": ["GET", "POST", "PUT", "OPTIONS"],
            "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
            "expose_headers": ["Content-Length", "Content-Range"],
            "supports_credentials": True,
            "max_age": 1728000,  # seconds (20 days); browsers may cap preflight caching lower
        }
    })

    session.init_app(app)
|
|
|
|
|
|
def register_blueprints(app):
    """Attach this service's blueprints to *app*.

    The blueprint module is imported inside the function, so it is only
    loaded when this function runs.

    Args:
        app: The Flask application to register the blueprints on.
    """
    from views.healthz_views import healthz_bp as health_blueprint

    for blueprint in (health_blueprint,):
        app.register_blueprint(blueprint)
|