diff --git a/common/utils/key_encryption.py b/common/utils/key_encryption.py index e5c852f..5ee1bec 100644 --- a/common/utils/key_encryption.py +++ b/common/utils/key_encryption.py @@ -6,6 +6,7 @@ import random import time from flask import Flask import os +import ast def generate_api_key(prefix="EveAI-Chat"): @@ -72,6 +73,8 @@ class JosKMSClient(kms_v1.KeyManagementServiceClient): def decrypt_api_key(self, encrypted_data): """Decrypts the API key using the specified key version.""" + if isinstance(encrypted_data, str): + encrypted_data = ast.literal_eval(encrypted_data) key_version = encrypted_data['key_version'] key_name = self.key_name encrypted_dek = b64decode(encrypted_data['encrypted_dek'].encode('utf-8')) diff --git a/eveai_chat/socket_handlers/chat_handler.py b/eveai_chat/socket_handlers/chat_handler.py index 7943107..91b1c7b 100644 --- a/eveai_chat/socket_handlers/chat_handler.py +++ b/eveai_chat/socket_handlers/chat_handler.py @@ -1,27 +1,38 @@ -from flask_jwt_extended import verify_jwt_in_request, get_jwt_identity, verify_jwt_in_request, decode_token +from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token from flask_socketio import emit, disconnect from flask import current_app, request -from common.extensions import socketio + +from common.extensions import socketio, kms_client +from common.models.user import Tenant @socketio.on('connect') def handle_connect(): try: - # Extract token from the auth object - token = request.args.get('token') - if not token: - raise Exception("Missing Authorization Token") - current_app.logger.debug(f'SocketIO: Received token: {token}') - # Verify token - decoded_token = decode_token(token.split(" ")[1]) # Split to remove "Bearer " prefix - tenant_id = decoded_token["identity"]["tenant_id"] - current_app.logger.info(f'SocketIO: Tenant {decoded_token["identity"]["tenant_id"]} connected') - # communicate connection to client + current_app.logger.debug(f'SocketIO: Connection handling started using {request.args}') + tenant_id = request.args.get('tenantId') + if not tenant_id: + raise Exception("Missing Tenant ID") + api_key = request.args.get('apiKey') + if not api_key: + raise Exception("Missing API Key") + current_app.logger.info(f'SocketIO: Connection handling found Tenant {tenant_id} with API Key {api_key}') + + if not validate_api_key(tenant_id, api_key): + raise Exception("Invalid tenant_id - api_key combination") + + # Create JWT token + token = create_access_token(identity={"tenant_id": tenant_id, "api_key": api_key}) + current_app.logger.debug(f'SocketIO: Connection handling created token: {token} for tenant {tenant_id}') + + # Communicate connection to client emit('connect', {'status': 'Connected', 'tenant_id': tenant_id}) + emit('authenticated', {'token': token}) # Emit custom event with the token + current_app.logger.debug(f'SocketIO: Connection handling sent token to client for tenant {tenant_id}') except Exception as e: current_app.logger.error(f'SocketIO: Connection failed: {e}') # communicate connection problem to client - emit('connect', {'status': 'Connection Failed'}) + emit('connect_error', {'status': 'Connection Failed'}) disconnect() @@ -33,20 +44,46 @@ def handle_disconnect(): @socketio.on('user_message') def handle_message(data): try: - current_app.logger.debug(f"SocketIO: Received message from tenant {data['tenantId']}: {data['message']}") - verify_jwt_in_request() - current_tenant = get_jwt_identity() - print(f'Tenant {current_tenant["tenant_id"]} sent a message: {data}') + current_app.logger.debug(f"SocketIO: Message handling received message from tenant {data['tenantId']}: " + f"{data['message']} with token {data['token']}") + token = data.get('token') + if not token: + raise Exception("Missing token") + + # decoded_token = decode_token(token.split(" ")[1]) # remove "Bearer " + decoded_token = decode_token(token) + if not decoded_token: + raise Exception("Invalid token") + current_app.logger.debug(f"SocketIO: Message handling decoded token: {decoded_token}") + + token_sub = decoded_token.get('sub') + + current_tenant_id = token_sub.get('tenant_id') + if not current_tenant_id: + raise Exception("Missing tenant_id") + + current_api_key = token_sub.get('api_key') + if not current_api_key: + raise Exception("Missing api_key") + # Store interaction in the database + response = { 'tenantId': data['tenantId'], - 'message': 'This is a bot response. Actual implementation still required.', + 'message': f'This is a bot response. Responding to message {data['message']} ' + f'from tenant {current_tenant_id}', 'messageId': 'bot-message-id', 'algorithm': 'alg1' } - current_app.logger.debug(f"SocketIO: Bot response: {response}") + current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}") emit('bot_response', response, broadcast=True) except Exception as e: current_app.logger.error(f'SocketIO: Message handling failed: {e}') disconnect() + +def validate_api_key(tenant_id, api_key): + tenant = Tenant.query.get_or_404(tenant_id) + decrypted_api_key = kms_client.decrypt_api_key(tenant.encrypted_chat_api_key) + + return decrypted_api_key == api_key diff --git a/public/chat.html b/public/chat.html index 895491e..0fb59f1 100644 --- a/public/chat.html +++ b/public/chat.html @@ -12,7 +12,7 @@