from flask import request, current_app, session
from flask_jwt_extended import decode_token, verify_jwt_in_request, get_jwt_identity
from common.models.user import Tenant, TenantDomain

# Request paths that receive permissive CORS headers without any tenant check.
_HEALTH_CHECK_PREFIXES = ('/healthz', '/_healthz')
_ALLOWED_METHODS = 'GET,POST,PUT,DELETE,OPTIONS'
_PREFLIGHT_ALLOWED_HEADERS = 'Content-Type,Authorization,X-Tenant-ID'
_TENANT_ALLOWED_HEADERS = 'Content-Type,Authorization'


def get_allowed_origins(tenant_id):
    """Return the list of origins registered for ``tenant_id``.

    The list is built from ``TenantDomain.domain`` for all rows whose
    ``tenant_id`` matches, and is cached in the Flask ``session`` under the
    key ``allowed_origins_<tenant_id>`` so later requests in the same session
    skip the query.

    NOTE(review): the result is compared verbatim against the browser's
    ``Origin`` header, which assumes ``domain`` values are stored as full
    origins (``scheme://host[:port]``) — confirm against the data.
    NOTE(review): because the cache lives in the client's session, changes to
    a tenant's domains are not seen until that session is cleared.

    Args:
        tenant_id: Tenant identifier; must be convertible with ``int()``.

    Returns:
        A list of origin strings, or ``[]`` if ``tenant_id`` is not a valid
        integer.
    """
    session_key = f"allowed_origins_{tenant_id}"
    cached = session.get(session_key)
    if cached is not None:
        return cached

    try:
        tenant_pk = int(tenant_id)
    except (TypeError, ValueError):
        # Raising here would turn the whole response into a 500 from inside
        # an after-request hook; treat it as "no allowed origins" instead.
        current_app.logger.warning(f'Invalid tenant id for CORS lookup: {tenant_id!r}')
        return []

    tenant_domains = TenantDomain.query.filter_by(tenant_id=tenant_pk).all()
    allowed_origins = [domain.domain for domain in tenant_domains]
    # Cache the result in the session
    session[session_key] = allowed_origins
    return allowed_origins


def _set_cors_headers(response, origin, allow_headers, allow_methods, allow_credentials):
    """Set (replacing, never appending) the CORS headers on ``response``.

    Assignment rather than ``headers.add`` keeps this idempotent. When several
    CORS hooks run on the same response, they must not produce duplicate
    ``Access-Control-Allow-Origin`` headers, which browsers reject.
    """
    response.headers['Access-Control-Allow-Origin'] = origin
    response.headers['Access-Control-Allow-Headers'] = allow_headers
    response.headers['Access-Control-Allow-Methods'] = allow_methods
    if allow_credentials:
        response.headers['Access-Control-Allow-Credentials'] = 'true'
    if origin != '*':
        # The header value depends on the request Origin; tell caches so.
        response.vary.add('Origin')


def _resolve_tenant_id():
    """Return the JWT identity of the current request, or ``None``.

    For Socket.IO paths the token is read from the ``token`` query parameter
    and its ``sub`` claim is used. For other requests the JWT is verified
    optionally via flask_jwt_extended. Decoding or verification errors are
    logged and yield ``None``.
    """
    if 'socket.io' in request.path:
        token = request.args.get('token')
        if not token:
            return None
        try:
            return decode_token(token)['sub']
        except Exception as e:
            current_app.logger.error(f'Error decoding token: {e}')
            return None

    # Regular API requests
    try:
        if verify_jwt_in_request(optional=True):
            return get_jwt_identity()
    except Exception as e:
        current_app.logger.error(f'Error verifying JWT: {e}')
    return None


def cors_after_request(response, prefix):
    """Add CORS headers to ``response`` based on the requesting tenant.

    - Health-check paths always get wildcard CORS headers.
    - OPTIONS preflights echo the request ``Origin`` with credentials allowed.
      A preflight carries no Authorization header, so the tenant cannot be
      known yet. Access is enforced on the actual request instead. ``*`` is
      not used with credentials because browsers reject that combination.
    - Other requests get CORS headers only when the JWT identifies a tenant
      whose registered origins include the request ``Origin``.

    Args:
        response: The Flask response being finalized.
        prefix: Accepted for interface compatibility; currently unused.

    Returns:
        The same ``response`` object, possibly with CORS headers set.
    """
    # Exclude health checks from tenant checks
    if request.path.startswith(_HEALTH_CHECK_PREFIXES):
        _set_cors_headers(response, '*', '*', '*', allow_credentials=False)
        return response

    origin = request.headers.get('Origin')

    # Handle OPTIONS preflight requests
    if request.method == 'OPTIONS':
        if origin:
            _set_cors_headers(response, origin, _PREFLIGHT_ALLOWED_HEADERS,
                              _ALLOWED_METHODS, allow_credentials=True)
        else:
            _set_cors_headers(response, '*', _PREFLIGHT_ALLOWED_HEADERS,
                              _ALLOWED_METHODS, allow_credentials=False)
        return response

    # Without an Origin header this is not a cross-origin request; skip the
    # JWT/DB work entirely.
    if not origin:
        return response

    response.vary.add('Origin')
    tenant_id = _resolve_tenant_id()
    if tenant_id and origin in get_allowed_origins(tenant_id):
        _set_cors_headers(response, origin, _TENANT_ALLOWED_HEADERS,
                          _ALLOWED_METHODS, allow_credentials=True)
    return response


def create_cors_after_request(prefix):
    """Return an after-request hook that applies ``cors_after_request``.

    ``prefix`` is forwarded but is not currently used by ``cors_after_request``.
    """
    def wrapped_cors_after_request(response):
        return cors_after_request(response, prefix)
    return wrapped_cors_after_request


def create_multiple_cors_after_requests(prefixes):
    """Return an after-request hook that runs several CORS hooks in order.

    Args:
        prefixes: Iterable of ``(prefix, cors_function)`` pairs. Each
            ``cors_function`` is applied to the response in turn. The prefix
            element is ignored.
    """
    def wrapped_cors_after_requests(response):
        for _prefix, cors_function in prefixes:
            response = cors_function(response)
        return response
    return wrapped_cors_after_requests