76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
from flask import request, current_app, session
|
|
from common.models.user import Tenant, TenantDomain
|
|
|
|
|
|
def get_allowed_origins(tenant_id):
    """Return the domains allowed as CORS origins for *tenant_id*.

    On first use the ``domain`` values of the tenant's ``TenantDomain`` rows are
    read from the database and cached in the Flask ``session`` under
    ``allowed_origins_<id>``. Later calls in the same session are answered from
    that cache.

    Args:
        tenant_id: tenant identifier as received from the request; anything
            accepted by ``int()`` (an int or a numeric string).

    Returns:
        list: the tenant's allowed origin domains. Returns an empty list
        (nothing is queried or cached) when *tenant_id* cannot be converted
        with ``int()``.
    """
    # The id usually comes straight from a query parameter or header, so it may
    # be junk; refuse it here instead of letting int() raise inside a hook.
    try:
        tenant_key = int(tenant_id)
    except (TypeError, ValueError):
        current_app.logger.warning(f"Invalid tenant_id {tenant_id!r}; no allowed origins")
        return []

    # Key on the normalized integer so "1", " 1" and 1 share one cache entry.
    session_key = f"allowed_origins_{tenant_key}"

    cached = session.get(session_key)
    if cached is not None:
        current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
        return cached

    current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
    tenant_domains = TenantDomain.query.filter_by(tenant_id=tenant_key).all()
    allowed_origins = [domain.domain for domain in tenant_domains]

    # Cache the result in the session.
    # NOTE(review): this entry is never invalidated, so domain changes are not
    # seen by sessions that already cached it; if the session is cookie-backed
    # (Flask's default), every cached tenant also grows the session cookie.
    session[session_key] = allowed_origins
    return allowed_origins
|
|
|
|
|
|
# Lower-cased header names whose values carry credentials and must not be logged.
_SENSITIVE_HEADERS = frozenset({'authorization', 'cookie', 'proxy-authorization'})


def _redacted_headers():
    """Return the current request's headers as a dict with credential values masked."""
    return {
        name: '<redacted>' if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in request.headers.items()
    }


def _extract_tenant_id():
    """Return the raw tenant id of the current request, or a falsy value if absent.

    The JSON body's ``tenant_id`` key wins when present. Otherwise the lookup
    falls back to the ``tenant_id`` / ``tenantId`` query parameters and then the
    ``X-Tenant-ID`` header.
    """
    json_data = request.get_json(silent=True)
    current_app.logger.debug(f'request.get_json(silent=True): {json_data}')

    # Only a JSON object can carry the key. For arrays, strings or numbers the
    # membership test / lookup would misbehave or raise TypeError.
    if isinstance(json_data, dict) and 'tenant_id' in json_data:
        return json_data['tenant_id']

    # Fallback to get tenant_id from query parameters or headers if JSON is not available
    return (
        request.args.get('tenant_id')
        or request.args.get('tenantId')
        or request.headers.get('X-Tenant-ID')
    )


def cors_after_request(response, prefix):
    """Add CORS headers to *response* when the request's Origin is allowed for its tenant.

    The tenant is taken from the JSON body, the query string or the
    ``X-Tenant-ID`` header. Its allowed origins come from
    :func:`get_allowed_origins`. If the request's ``Origin`` header is one of
    them, the ``Access-Control-Allow-*`` headers are set on the response.

    Args:
        response: the outgoing Flask response object.
        prefix: label of the route prefix this hook was registered for; only
            used in log output.

    Returns:
        The same *response* object, possibly with CORS headers set.
    """
    current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
    current_app.logger.debug(f'request.headers: {_redacted_headers()}')
    current_app.logger.debug(f'request.args: {request.args}')
    current_app.logger.debug(f'request is json?: {request.is_json}')

    tenant_id = _extract_tenant_id()
    current_app.logger.debug(f'Identified tenant_id: {tenant_id}')

    allowed_origins = []
    if tenant_id:
        try:
            # get_allowed_origins() int()-converts the id for its query. Reject
            # junk here so a malformed id cannot raise out of this hook.
            int(tenant_id)
        except (TypeError, ValueError):
            current_app.logger.warning(f'Invalid tenant_id in request: {tenant_id!r}')
        else:
            allowed_origins = get_allowed_origins(tenant_id)
            current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
    else:
        current_app.logger.warning('tenant_id not found in request')

    origin = request.headers.get('Origin')
    current_app.logger.debug(f'Origin: {origin}')

    if origin is None:
        # No Origin header (typically a same-origin or non-browser request):
        # there is nothing to allow or refuse.
        return response

    # The CORS headers depend on the requesting origin, so caches must key on it.
    response.vary.add('Origin')

    if origin not in allowed_origins:
        current_app.logger.warning(f'Origin {origin} not allowed')
        return response

    # Assign instead of headers.add(). Several prefix hooks may process the same
    # response, and a duplicated Access-Control-Allow-Origin is rejected by browsers.
    response.headers['Access-Control-Allow-Origin'] = origin
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization'
    response.headers['Access-Control-Allow-Methods'] = 'GET,POST,PUT,DELETE,OPTIONS'
    response.headers['Access-Control-Allow-Credentials'] = 'true'
    current_app.logger.debug(f'CORS headers set for origin: {origin}')

    return response
|
|
|
|
|
|
def create_cors_after_request(prefix):
    """Return an ``after_request`` hook that runs :func:`cors_after_request` for *prefix*.

    Args:
        prefix: the route prefix label, captured by the returned closure.

    Returns:
        A one-argument callable ``(response) -> response``.
    """
    def wrapped_cors_after_request(response):
        # Delegate to the shared handler with the captured prefix.
        return cors_after_request(response, prefix=prefix)

    return wrapped_cors_after_request
|
|
|
|
|
|
def create_multiple_cors_after_requests(prefixes):
    """Return one ``after_request`` hook that chains several per-prefix CORS hooks.

    Args:
        prefixes: iterable of ``(prefix, cors_function)`` pairs. Each
            ``cors_function`` takes a response and returns a response (for
            example a hook built by :func:`create_cors_after_request`). The
            pairs are captured once when this function is called.

    Returns:
        A callable that passes the response through every ``cors_function`` in
        the given order and returns the final response.
    """
    # Snapshot now. Iterating a one-shot iterator (e.g. zip(...)) inside the hook
    # would exhaust it on the first request and silently drop CORS afterwards.
    pairs = tuple(prefixes)

    def wrapped_cors_after_requests(response):
        for _prefix, cors_function in pairs:
            response = cors_function(response)
        return response

    return wrapped_cors_after_requests
|