199 lines
7.0 KiB
Python
199 lines
7.0 KiB
Python
import hmac
from datetime import timedelta, datetime as dt, timezone as tz
from functools import wraps

from flask import current_app, jsonify, request
from flask_jwt_extended import create_access_token, verify_jwt_in_request, get_jwt, get_jwt_identity, jwt_required
from flask_restx import Namespace, Resource, fields

from common.extensions import simple_encryption
from common.models.user import Tenant, TenantProject
|
|
|
|
# Every endpoint in this module is registered on this namespace.
auth_ns = Namespace('auth', description='Authentication related operations')

# Request body accepted by POST /token.
token_model = auth_ns.model('Token', {
    'tenant_id': fields.Integer(required=True, description='Tenant ID'),
    'api_key': fields.String(required=True, description='API Key'),
})

# Successful response of POST /token and POST /refresh.
token_response = auth_ns.model('TokenResponse', {
    'access_token': fields.String(description='JWT access token'),
    'expires_in': fields.Integer(description='Token expiration time in seconds'),
})

# Successful response of GET /verify.
token_verification = auth_ns.model('TokenVerification', {
    'is_valid': fields.Boolean(description='Token validity status'),
    'expires_in': fields.Integer(description='Seconds until token expiration'),
    'tenant_id': fields.Integer(description='Tenant ID from token'),
})
|
|
|
|
|
|
@auth_ns.route('/token')
class Token(Resource):
    """Exchange a tenant ID plus a project API key for a JWT access token."""

    @auth_ns.expect(token_model)
    @auth_ns.response(200, 'Success', token_response)
    @auth_ns.response(400, 'Validation Error')
    @auth_ns.response(401, 'Unauthorized')
    @auth_ns.response(404, 'Tenant Not Found')
    def post(self):
        """
        Get JWT token
        """
        # @auth_ns.expect only documents the body (no validate=True), so the
        # payload may be absent, the wrong JSON type, or carry wrong field types.
        payload = auth_ns.payload
        if not isinstance(payload, dict):
            current_app.logger.error("Request body is not a JSON object")
            return {'message': "Request body must be a JSON object"}, 400

        try:
            tenant_id = int(payload['tenant_id'])
            api_key = payload['api_key']
        except KeyError as e:
            current_app.logger.error(f"Missing required field: {e}")
            return {'message': f"Missing required field: {e}"}, 400
        except (TypeError, ValueError) as e:
            # e.g. "tenant_id": "abc" or null - previously an unhandled 500.
            current_app.logger.error(f"Invalid tenant_id: {e}")
            return {'message': "Invalid tenant_id: must be an integer"}, 400

        if not isinstance(api_key, str):
            current_app.logger.error("Invalid api_key: not a string")
            return {'message': "Invalid api_key: must be a string"}, 400

        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            current_app.logger.error(f"Tenant not found: {tenant_id}")
            return {'message': f"Authentication invalid for tenant {tenant_id}"}, 404

        matching_project = self._find_project_for_key(tenant_id, api_key)
        if not matching_project:
            current_app.logger.error(f"Project for given API key not found for Tenant: {tenant_id}")
            return {'message': "Invalid API key"}, 401

        # Get the JWT_ACCESS_TOKEN_EXPIRES setting from the app config
        expires_delta = current_app.config.get('JWT_ACCESS_TOKEN_EXPIRES', timedelta(minutes=15))

        try:
            # NOTE(review): identity is an int here; PyJWT >= 2.10 rejects a
            # non-string 'sub' claim on decode unless flask-jwt-extended's
            # JWT_VERIFY_SUB is disabled - confirm against pinned versions.
            access_token = create_access_token(
                identity=tenant_id,
                expires_delta=expires_delta,
                additional_claims={'services': matching_project.services},
            )
            # Never log the token itself: log access would allow replaying it.
            current_app.logger.debug(f"Created access token for tenant {tenant_id}")
            return {
                'access_token': access_token,
                # int() to match the Integer field of TokenResponse and /refresh.
                'expires_in': int(expires_delta.total_seconds()),
            }, 200
        except Exception as e:
            current_app.logger.error(f"Error creating access token: {e}")
            return {'message': "Internal server error"}, 500

    @staticmethod
    def _find_project_for_key(tenant_id, api_key):
        """Return the tenant's active project whose stored API key equals *api_key*, else None.

        Keys that fail to decrypt are logged and skipped, so one unreadable row
        does not block authentication against the tenant's other projects.
        """
        projects = TenantProject.query.filter_by(
            tenant_id=tenant_id,
            active=True,
        ).all()

        # 'surrogatepass' keeps encoding total for any str json can produce
        # (JSON permits lone \ud800-style escapes).
        supplied = api_key.encode('utf-8', 'surrogatepass')

        for project in projects:
            try:
                decrypted_key = simple_encryption.decrypt_api_key(project.encrypted_api_key)
            except Exception as e:
                current_app.logger.error(f"Error decrypting API key for project {project.id}: {e}")
                continue
            # Constant-time comparison so response timing does not leak how much
            # of a guessed key matched. Presumably decrypt_api_key returns str;
            # any other type can never equal the supplied str key anyway.
            if isinstance(decrypted_key, str) and hmac.compare_digest(
                decrypted_key.encode('utf-8', 'surrogatepass'), supplied
            ):
                return project
        return None
|
|
|
|
|
|
@auth_ns.route('/verify')
class TokenVerification(Resource):
    """Report whether the request's bearer token is valid and how long it has left."""

    @auth_ns.doc('verify_token')
    @auth_ns.response(200, 'Token verification result', token_verification)
    @auth_ns.response(401, 'Invalid token')
    def get(self):
        """Verify a token's validity and get expiration information"""
        try:
            verify_jwt_in_request()
            jwt_data = get_jwt()

            # 'exp' is a POSIX timestamp; compare it with an aware UTC "now"
            # instead of a naive local-time datetime.
            seconds_left = int(jwt_data['exp'] - dt.now(tz.utc).timestamp())

            return {
                'is_valid': True,
                # Clamp at 0: with a decode leeway configured, a token slightly
                # past 'exp' can still pass verification.
                'expires_in': max(0, seconds_left),
                # get_jwt_identity() honours a custom JWT_IDENTITY_CLAIM, whereas a
                # hard-coded 'sub' lookup would fail under one.
                'tenant_id': get_jwt_identity(),
            }, 200
        except Exception as e:
            # Any verification/decoding failure is reported as an invalid token.
            current_app.logger.error(f"Token verification failed: {str(e)}")
            return {
                'is_valid': False,
                'message': 'Invalid token'
            }, 401
|
|
|
|
|
|
@auth_ns.route('/refresh')
class TokenRefresh(Resource):
    """Re-issue an access token for the tenant identified by the current token."""

    @auth_ns.doc('refresh_token')
    @auth_ns.response(200, 'New token', token_response)
    @auth_ns.response(401, 'Invalid token')
    def post(self):
        """Get a new token before the current one expires"""
        # NOTE(review): a plain access token is accepted here and the tenant /
        # project are not re-checked as active, so a holder can keep extending
        # access without the API key. Confirm this is the intended design.
        try:
            verify_jwt_in_request()
            jwt_data = get_jwt()
            tenant_id = get_jwt_identity()

            expires_delta = current_app.config.get('JWT_ACCESS_TOKEN_EXPIRES', timedelta(minutes=15))
            new_token = create_access_token(
                identity=tenant_id,
                expires_delta=expires_delta,
                # Carry the service grants forward; without them the refreshed
                # token would fail every requires_service(...) check.
                additional_claims={'services': jwt_data.get('services', [])},
            )

            return {
                'access_token': new_token,
                'expires_in': int(expires_delta.total_seconds())
            }, 200
        except Exception as e:
            current_app.logger.error(f"Token refresh failed: {str(e)}")
            return {'message': 'Token refresh failed'}, 401
|
|
|
|
|
|
# Flask-RESTX can only document registered models; a raw dict passed to
# @auth_ns.response breaks Swagger generation.
services_response = auth_ns.model('ServicesResponse', {
    'services': fields.List(fields.String, description='List of allowed services for this token'),
    'tenant_id': fields.Integer(description='Tenant ID associated with this token'),
})


@auth_ns.route('/services')
class Services(Resource):
    """Expose the service grants carried by the caller's access token."""

    @jwt_required()
    @auth_ns.doc(security='Bearer')
    @auth_ns.response(200, 'Success', services_response)
    @auth_ns.response(401, 'Invalid or expired token')
    def get(self):
        """
        Get allowed services for the current token
        """
        # Log only the auth scheme, never the credential: the raw header holds a
        # bearer token that anyone with log access could replay.
        auth_header = request.headers.get('Authorization')
        scheme = auth_header.split(' ', 1)[0] if auth_header else None
        current_app.logger.debug(f"Received Authorization header with scheme: {scheme}")

        claims = get_jwt()
        tenant_id = get_jwt_identity()

        return {
            # A null claim (project with no services set) is reported as [].
            'services': claims.get('services') or [],
            'tenant_id': tenant_id
        }, 200
|
|
|
|
|
|
# Decorator factory: restrict an endpoint to tokens granted a particular service.
def requires_service(service_name):
    """Return a decorator that rejects requests whose JWT lacks *service_name*.

    The check reads the token's ``services`` claim (set from the project's
    services when the token is issued). If the claim does not contain
    ``service_name``, the wrapped view is not called and a 403 response is
    returned instead.

    Args:
        service_name: Service identifier that must appear in the ``services`` claim.

    Returns:
        A decorator for view functions / resource methods.
    """
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                claims = get_jwt()
            except RuntimeError:
                # get_jwt() raises RuntimeError when no JWT has been verified
                # for this request yet (e.g. this decorator is placed outside
                # @jwt_required()); verify here instead of failing with a 500.
                verify_jwt_in_request()
                claims = get_jwt()

            # Treat a missing or null claim as "no services granted".
            services = claims.get('services') or []

            if service_name not in services:
                return {
                    'message': f'This endpoint requires the {service_name} service',
                    'error': 'Insufficient permissions'
                }, 403
            return fn(*args, **kwargs)

        return wrapper

    return decorator
|