import hmac
from datetime import timedelta, datetime as dt, timezone as tz
from flask_restx import Namespace, Resource, fields
from flask_jwt_extended import create_access_token, verify_jwt_in_request, get_jwt, get_jwt_identity, jwt_required
from common.models.user import Tenant, TenantProject
from common.extensions import simple_encryption
from flask import current_app, jsonify, request
from functools import wraps

auth_ns = Namespace('auth', description='Authentication related operations')

token_model = auth_ns.model('Token', {
    'tenant_id': fields.Integer(required=True, description='Tenant ID'),
    'api_key': fields.String(required=True, description='API Key')
})

token_response = auth_ns.model('TokenResponse', {
    'access_token': fields.String(description='JWT access token'),
    'expires_in': fields.Integer(description='Token expiration time in seconds')
})

token_verification = auth_ns.model('TokenVerification', {
    'is_valid': fields.Boolean(description='Token validity status'),
    'expires_in': fields.Integer(description='Seconds until token expiration'),
    'tenant_id': fields.Integer(description='Tenant ID from token')
})

# Registered as a named model: Flask-RESTX's Swagger generator cannot
# serialize a plain dict passed to @response().
services_response = auth_ns.model('ServicesResponse', {
    'services': fields.List(fields.String, description='List of allowed services for this token'),
    'tenant_id': fields.Integer(description='Tenant ID associated with this token')
})


def _access_token_expires():
    """Return the configured access-token lifetime.

    Reads ``JWT_ACCESS_TOKEN_EXPIRES`` from the app config and falls back to
    15 minutes when it is not set.
    """
    return current_app.config.get('JWT_ACCESS_TOKEN_EXPIRES', timedelta(minutes=15))


def _expires_in_seconds(expires_delta):
    """Convert a token lifetime into whole seconds for the ``expires_in`` field.

    Returns None when the lifetime is not a ``timedelta`` (for example ``False``,
    which flask_jwt_extended documents as "token never expires"), instead of
    failing on ``.total_seconds()``.
    """
    if isinstance(expires_delta, timedelta):
        return int(expires_delta.total_seconds())
    return None


def _api_keys_match(expected, provided):
    """Compare two API keys in constant time.

    Returns False unless both values are strings; comparison is done on the
    UTF-8 encoded bytes so non-ASCII keys do not make ``compare_digest`` raise.
    """
    if not isinstance(expected, str) or not isinstance(provided, str):
        return False
    return hmac.compare_digest(expected.encode('utf-8'), provided.encode('utf-8'))


@auth_ns.route('/token')
class Token(Resource):
    @auth_ns.expect(token_model)
    @auth_ns.response(200, 'Success', token_response)
    @auth_ns.response(400, 'Validation Error')
    @auth_ns.response(401, 'Unauthorized')
    @auth_ns.response(404, 'Tenant Not Found')
    def post(self):
        """
        Get JWT token

        Exchanges a tenant_id / api_key pair for an access token. The API key is
        matched against the decrypted keys of the tenant's active projects, and
        the matching project's ``services`` are embedded as a token claim.
        """
        payload = auth_ns.payload
        try:
            tenant_id = int(payload['tenant_id'])
            api_key = payload['api_key']
        except KeyError as e:
            current_app.logger.error(f"Missing required field: {e}")
            return {'message': f"Missing required field: {e}"}, 400
        except (TypeError, ValueError) as e:
            # Non-object JSON body (TypeError) or a tenant_id that is not an
            # integer (ValueError / TypeError) is a client error, not a 500.
            current_app.logger.error(f"Invalid token request payload: {e}")
            return {'message': "Invalid request payload"}, 400

        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            current_app.logger.error(f"Tenant not found: {tenant_id}")
            return {'message': f"Authentication invalid for tenant {tenant_id}"}, 404

        projects = TenantProject.query.filter_by(
            tenant_id=tenant_id,
            active=True
        ).all()

        # Find project with matching API key
        matching_project = None
        for project in projects:
            try:
                decrypted_key = simple_encryption.decrypt_api_key(project.encrypted_api_key)
            except Exception as e:
                # One undecryptable project must not block the others.
                current_app.logger.error(f"Error decrypting API key for project {project.id}: {e}")
                continue
            # Constant-time comparison: the key comes from an untrusted client.
            if _api_keys_match(decrypted_key, api_key):
                matching_project = project
                break

        if not matching_project:
            current_app.logger.error(f"Project for given API key not found for Tenant: {tenant_id}")
            return {'message': "Invalid API key"}, 401

        expires_delta = _access_token_expires()

        try:
            # NOTE(review): identity is an int; newer flask_jwt_extended/PyJWT
            # releases may require the 'sub' claim to be a string — confirm
            # against the installed versions before upgrading.
            access_token = create_access_token(
                identity=tenant_id,
                expires_delta=expires_delta,
                additional_claims={'services': matching_project.services},
            )
        except Exception as e:
            current_app.logger.error(f"Error creating access token: {e}")
            return {'message': "Internal server error"}, 500

        return {
            'access_token': access_token,
            'expires_in': _expires_in_seconds(expires_delta)
        }, 200


@auth_ns.route('/verify')
class TokenVerification(Resource):
    @auth_ns.doc('verify_token')
    @auth_ns.response(200, 'Token verification result', token_verification)
    @auth_ns.response(401, 'Invalid token')
    def get(self):
        """Verify a token's validity and get expiration information"""
        try:
            verify_jwt_in_request()
            jwt_data = get_jwt()

            # 'exp' is a POSIX timestamp; compare against an aware UTC "now".
            # A token without 'exp' is still valid, it just never expires.
            exp_timestamp = jwt_data.get('exp')
            if exp_timestamp is None:
                expires_in = None
            else:
                expires_in = int(exp_timestamp - dt.now(tz.utc).timestamp())

            return {
                'is_valid': True,
                'expires_in': expires_in,
                'tenant_id': jwt_data['sub']  # tenant_id is stored in 'sub' claim
            }, 200
        except Exception as e:
            current_app.logger.error(f"Token verification failed: {str(e)}")
            return {
                'is_valid': False,
                'message': 'Invalid token'
            }, 401


@auth_ns.route('/refresh')
class TokenRefresh(Resource):
    @auth_ns.doc('refresh_token')
    @auth_ns.response(200, 'New token', token_response)
    @auth_ns.response(401, 'Invalid token')
    def post(self):
        """Get a new token before the current one expires"""
        try:
            verify_jwt_in_request()
            jwt_data = get_jwt()
            tenant_id = jwt_data['sub']

            # Optional: Add additional verification here if needed
            # (e.g. re-check that the tenant/project is still active).

            expires_delta = _access_token_expires()
            # Carry the 'services' claim forward; without it a refreshed token
            # would fail every requires_service() check.
            new_token = create_access_token(
                identity=tenant_id,
                expires_delta=expires_delta,
                additional_claims={'services': jwt_data.get('services', [])},
            )
        except Exception as e:
            current_app.logger.error(f"Token refresh failed: {str(e)}")
            return {'message': 'Token refresh failed'}, 401

        return {
            'access_token': new_token,
            'expires_in': _expires_in_seconds(expires_delta)
        }, 200


@auth_ns.route('/services')
class Services(Resource):
    @jwt_required()
    @auth_ns.doc(security='Bearer')
    @auth_ns.response(200, 'Success', services_response)
    @auth_ns.response(401, 'Invalid or expired token')
    def get(self):
        """
        Get allowed services for the current token
        """
        claims = get_jwt()
        tenant_id = get_jwt_identity()

        return {
            'services': claims.get('services', []),
            'tenant_id': tenant_id
        }, 200


def requires_service(service_name):
    """Decorator factory restricting an endpoint to tokens granted ``service_name``.

    The wrapped view returns a 403 payload unless ``service_name`` appears in
    the token's ``services`` claim. It reads claims with ``get_jwt()``, so it
    must be applied beneath ``@jwt_required()`` (or after
    ``verify_jwt_in_request()``) — it does not verify the token itself.
    """
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            claims = get_jwt()
            # `or []` guards against a null claim, which would otherwise make
            # the membership test raise TypeError (a 500) instead of a 403.
            services = claims.get('services') or []
            if service_name not in services:
                return {
                    'message': f'This endpoint requires the {service_name} service',
                    'error': 'Insufficient permissions'
                }, 403
            return fn(*args, **kwargs)
        return wrapper
    return decorator