151 lines
6.1 KiB
Python
151 lines
6.1 KiB
Python
import uuid
|
|
|
|
from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token
|
|
from flask_socketio import emit, disconnect
|
|
from flask import current_app, request, session
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from common.extensions import socketio, kms_client, db
|
|
from common.models.user import Tenant
|
|
from common.models.interaction import Interaction
|
|
from common.utils.celery_utils import current_celery
|
|
|
|
|
|
@socketio.on('connect')
def handle_connect():
    """Authenticate a new Socket.IO connection and hand the client a JWT.

    The client must pass ``tenantId`` and ``apiKey`` as query-string arguments.
    On success this emits ``connect`` (status payload) and ``authenticated``
    (carrying the access token) to the connecting client, and makes sure a
    ``session_id`` exists in the Socket.IO session. On any failure it logs the
    reason, emits ``connect_error`` and disconnects the client.
    """
    logger = current_app.logger
    try:
        # Do not log request.args wholesale: it contains the plaintext API key.
        logger.debug('SocketIO: Connection handling started')

        tenant_id = request.args.get('tenantId')
        if not tenant_id:
            raise ValueError("Missing Tenant ID")

        api_key = request.args.get('apiKey')
        if not api_key:
            raise ValueError("Missing API Key")

        logger.info(f'SocketIO: Connection handling found Tenant {tenant_id}')

        if not validate_api_key(tenant_id, api_key):
            raise ValueError("Invalid tenant_id - api_key combination")

        # NOTE(review): the raw API key travels inside the JWT payload, which by
        # default is signed but not encrypted. handle_message currently requires
        # 'api_key' in the token subject, so it is kept here for compatibility.
        token = create_access_token(identity={"tenant_id": tenant_id, "api_key": api_key})
        # The token is a bearer credential; never write it to the logs.
        logger.debug(f'SocketIO: Connection handling created token for tenant {tenant_id}')

        # Only create a session id if this session does not already have one.
        if 'session_id' not in session:
            session['session_id'] = str(uuid.uuid4())

        # Communicate connection to client
        emit('connect', {'status': 'Connected', 'tenant_id': tenant_id})
        emit('authenticated', {'token': token})  # custom event carrying the token
        logger.debug(f'SocketIO: Connection handling sent token to client for tenant {tenant_id}')
    except Exception as e:
        # Broad catch is deliberate: this is the handler's top-level boundary and
        # any failure (including exceptions raised by validate_api_key) must end
        # in a rejected connection.
        logger.error(f'SocketIO: Connection failed: {e}')
        # NOTE(review): Flask-SocketIO documents returning False / raising
        # ConnectionRefusedError to reject a connection; the explicit
        # emit + disconnect is kept because clients may rely on this payload.
        emit('connect_error', {'status': 'Connection Failed'})
        disconnect()
|
|
|
|
|
|
@socketio.on('disconnect')
def handle_disconnect():
    """Log that a Socket.IO client went away; no other cleanup is performed."""
    log = current_app.logger
    log.debug('SocketIO: Client disconnected')
|
|
|
|
|
|
@socketio.on('user_message')
def handle_message(data):
    """Authenticate a chat message via its JWT and queue it for answering.

    Expected payload keys: ``token`` (as issued by handle_connect),
    ``message``, ``language`` and optionally ``tenantId``. The question is
    dispatched to the ``ask_question`` Celery task on the ``llm_interactions``
    queue, and a ``bot_response`` event carrying the task id is sent back to
    the requesting client only. On any failure the error is logged and the
    client is disconnected.
    """
    logger = current_app.logger
    try:
        # Validate before touching any other field: the original indexed
        # data['token'] in a log line, raising KeyError before this check.
        token = data.get('token')
        if not token:
            raise ValueError("Missing token")

        # decode_token raises on an invalid token; the falsy check is defensive.
        decoded_token = decode_token(token)
        if not decoded_token:
            raise ValueError("Invalid token")

        # handle_connect issues tokens whose identity ('sub') is a dict.
        token_sub = decoded_token.get('sub')
        if not isinstance(token_sub, dict):
            raise ValueError("Invalid token subject")

        current_tenant_id = token_sub.get('tenant_id')
        if not current_tenant_id:
            raise ValueError("Missing tenant_id")

        current_api_key = token_sub.get('api_key')
        if not current_api_key:
            raise ValueError("Missing api_key")

        message = data['message']
        language = data['language']
        session_id = session['session_id']
        # Never log the token or the decoded claims: both carry credentials.
        logger.debug(f"SocketIO: Message handling received message from tenant {current_tenant_id}: "
                     f"{message}")

        # Offload actual processing of question
        task = current_celery.send_task('ask_question', queue='llm_interactions', args=[
            current_tenant_id,
            message,
            language,
            session_id,
        ])
        logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, '
                     f'Question: {task.id}')

        response = {
            'tenantId': data.get('tenantId', current_tenant_id),
            'message': f'Processing question ... Session ID = {session_id}',
            'taskId': task.id,
        }
        logger.debug(f"SocketIO: Message handling sent bot response: {response}")
        # Reply to the sender only: broadcast=True leaked session and task ids
        # to every connected client, including those of other tenants.
        emit('bot_response', response)
    except Exception as e:
        logger.error(f'SocketIO: Message handling failed: {e}')
        disconnect()
|
|
|
|
|
|
@socketio.on('check_task_status')
def check_task_status(data):
    """Report the state of a previously queued question task to the client.

    Expects ``data['task_id']`` and emits one ``task_status`` event:
    ``pending`` while the Celery task has not finished, ``success`` with the
    answer payload once it has, an ``error`` status for a missing id or a
    malformed result, or the Celery state plus error text otherwise.
    """
    # NOTE(review): no token is checked here, so any client that knows a task
    # id can read that task's answer. Consider requiring the JWT as in handle_message.

    # Celery states in which the task has not finished yet
    # (mirrors celery.states.UNREADY_STATES).
    unfinished_states = {'PENDING', 'RECEIVED', 'STARTED', 'RETRY', 'REJECTED'}
    logger = current_app.logger

    task_id = data.get('task_id')
    logger.debug(f'SocketIO: Check task status for task_id: {task_id}')
    if not task_id:
        emit('task_status', {'status': 'error', 'message': 'Missing task ID'})
        return

    task_result = current_celery.AsyncResult(task_id)
    # Read the state once: each access may query the result backend, and the
    # state can change between reads.
    state = task_result.state

    if state in unfinished_states:
        logger.debug(f'SocketIO: Task {task_id} is pending (state: {state})')
        emit('task_status', {'status': 'pending', 'taskId': task_id})
    elif state == 'SUCCESS':
        result = task_result.result
        logger.debug(f'SocketIO: Task {task_id} has finished. Status: {state}, '
                     f'Result: {result}')
        try:
            response = {
                'status': 'success',
                'taskId': task_id,
                'answer': result['answer'],
                'citations': result['citations'],
                'algorithm': result['algorithm'],
                'interaction_id': result['interaction_id'],
            }
        except (KeyError, TypeError) as e:
            # The task returned something other than the expected dict.
            logger.error(f'SocketIO: Task {task_id} returned a malformed result: {e}')
            emit('task_status', {'status': 'error', 'taskId': task_id, 'message': 'Malformed task result'})
            return
        emit('task_status', response)
    else:
        # Any other state (e.g. FAILURE, REVOKED, or a custom state);
        # task_result.info typically holds the exception for failures.
        logger.error(f'SocketIO: Task {task_id} has failed. Error: {task_result.info}')
        emit('task_status', {'status': state, 'message': str(task_result.info)})
|
|
|
|
|
|
@socketio.on('feedback')
def handle_feedback(data):
    """Store thumbs-up/down feedback for an interaction.

    Expects ``data['interaction_id']`` and ``data['feedback']`` set to
    ``'up'`` or ``'down'``; stores ``1`` for up and ``0`` for down on the
    Interaction row. Always answers with a ``feedback_received`` event.
    Database errors are rolled back, reported to the client, and re-raised.
    """
    interaction_id = data.get('interaction_id')
    feedback = data.get('feedback')  # 'up' or 'down'

    # Reject anything else: previously every non-'down' value counted as 'up'.
    if feedback not in ('up', 'down'):
        current_app.logger.warning(f'SocketIO: Invalid feedback {feedback!r} for interaction {interaction_id}')
        emit('feedback_received', {'status': 'Invalid feedback', 'interaction_id': interaction_id})
        return

    # get_or_404 would raise an HTTP NotFound in a socket context, with no reply to the client.
    interaction = Interaction.query.get(interaction_id) if interaction_id is not None else None
    if interaction is None:
        current_app.logger.warning(f'SocketIO: Feedback for unknown interaction {interaction_id}')
        emit('feedback_received', {'status': 'Interaction not found', 'interaction_id': interaction_id})
        return

    # Store feedback in the database associated with the interaction_id
    interaction.feedback = 1 if feedback == 'up' else 0
    try:
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'SocketIO: Feedback handling failed: {e}')
        db.session.rollback()
        emit('feedback_received', {'status': 'Could not register feedback', 'interaction_id': interaction_id})
        raise  # bare raise keeps the original traceback
    emit('feedback_received', {'status': 'success', 'interaction_id': interaction_id})
|
|
|
|
|
|
def validate_api_key(tenant_id, api_key):
    """Return True if *api_key* matches the stored chat API key of *tenant_id*.

    Returns False when the tenant does not exist, has no stored key, or the
    keys differ. The comparison is constant-time, so response timing does not
    reveal how much of a guessed key was correct.
    """
    import hmac  # local import keeps this helper self-contained

    # Query.get returns None for an unknown tenant instead of aborting with an
    # HTTP 404, which has no meaning inside a Socket.IO handler.
    tenant = Tenant.query.get(tenant_id)
    if tenant is None or not tenant.encrypted_chat_api_key:
        return False

    decrypted_api_key = kms_client.decrypt_api_key(tenant.encrypted_chat_api_key)
    # The original '==' could only succeed for str == str; preserve that and
    # encode so compare_digest also accepts non-ASCII keys.
    if not isinstance(decrypted_api_key, str) or not isinstance(api_key, str):
        return False
    return hmac.compare_digest(decrypted_api_key.encode('utf-8'), api_key.encode('utf-8'))
|