Files
eveAI/eveai_chat/socket_handlers/chat_handler.py

359 lines
13 KiB
Python

import uuid
from functools import wraps
from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token
from flask_socketio import emit, disconnect, join_room, leave_room
from flask import current_app, request, session
from sqlalchemy.exc import SQLAlchemyError
from datetime import datetime, timedelta
from prometheus_client import Counter, Histogram
from time import time
import re
from common.extensions import socketio, db, simple_encryption
from common.models.user import Tenant
from common.models.interaction import Interaction
from common.utils.celery_utils import current_celery
from common.utils.database import Database
from common.utils.token_validation import TokenValidator
from common.utils.eveai_exceptions import EveAISocketInputException
# Prometheus metrics recorded by track_socketio_event for every decorated
# Socket.IO handler; both are labelled with the handler's function name.
socketio_message_counter = Counter(
    name='socketio_message_count',
    documentation='Count of SocketIO messages',
    labelnames=['event_type'],
)
socketio_message_latency = Histogram(
    name='socketio_message_latency_seconds',
    documentation='Latency of SocketIO message processing',
    labelnames=['event_type'],
)
class RoomManager:
    """In-memory registry of Socket.IO rooms created for authenticated clients.

    ``active_rooms`` maps a room id of the form
    ``<tenant_id>_<sid>_<unix_timestamp>`` to a metadata dict with the keys
    ``tenant_id``, ``token``, ``created_at`` and ``last_activity``. The registry
    is a plain dict on this instance, so every process holds its own copy.
    """

    # Format of ids produced by create_room. The middle part is the Socket.IO
    # session id. python-engineio builds those from base64 with '+'/'/' replaced
    # by '-'/'_', so both characters must be accepted there. The match
    # backtracks to find the final '_<digits>' suffix.
    _ROOM_ID_PATTERN = re.compile(r'[0-9]+_[A-Za-z0-9_-]+_[0-9]+')

    def __init__(self):
        self.active_rooms = {}  # room_id -> room metadata dict (see class docstring)

    def validate_room_format(self, room_id: str) -> bool:
        """Return True if ``room_id`` has the ``<tenant_id>_<sid>_<timestamp>`` shape.

        Non-string input yields False instead of raising. ``fullmatch`` is used
        because a ``$`` anchor would also accept a trailing newline.
        """
        if not isinstance(room_id, str):
            return False
        return self._ROOM_ID_PATTERN.fullmatch(room_id) is not None

    def is_room_active(self, room_id: str) -> bool:
        """Return True if ``room_id`` is currently registered."""
        return room_id in self.active_rooms

    def validate_room_ownership(self, room_id: str, tenant_id: int, token: str) -> bool:
        """Return True if ``room_id`` is registered for exactly this tenant and token."""
        if not self.is_room_active(room_id):
            return False
        room_data = self.active_rooms[room_id]
        return room_data['tenant_id'] == tenant_id and room_data['token'] == token

    def create_room(self, tenant_id: int, token: str, sid=None) -> str:
        """Register a new room and return its id.

        Args:
            tenant_id: Tenant that owns the room.
            token: Token the room is bound to; ownership checks compare against it.
            sid: Session id embedded in the room id. Defaults to ``request.sid``
                of the current Flask-SocketIO request.

        Returns:
            The new room id, ``<tenant_id>_<sid>_<unix_timestamp>``.
        """
        if sid is None:
            sid = request.sid
        # A single clock reading keeps the id timestamp, created_at and
        # last_activity consistent with each other.
        now = datetime.now()
        room_id = f"{tenant_id}_{sid}_{int(now.timestamp())}"
        self.active_rooms[room_id] = {
            'tenant_id': tenant_id,
            'token': token,
            'created_at': now,
            'last_activity': now,
        }
        return room_id

    def update_room_activity(self, room_id: str):
        """Refresh ``last_activity`` of a registered room; unknown ids are ignored."""
        room_data = self.active_rooms.get(room_id)
        if room_data is not None:
            room_data['last_activity'] = datetime.now()

    def cleanup_inactive_rooms(self, max_age_hours: int = 1):
        """Remove rooms whose last activity is older than ``max_age_hours`` hours."""
        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        stale_rooms = [
            room_id for room_id, data in self.active_rooms.items()
            if data['last_activity'] < cutoff
        ]
        # Collect first, then delete, so the dict is not mutated while iterated.
        for room_id in stale_rooms:
            del self.active_rooms[room_id]


# Module-level registry shared by the Socket.IO handlers in this file.
room_manager = RoomManager()
# Decorator to measure SocketIO events
def track_socketio_event(func):
    """Wrap a Socket.IO handler to count its calls and record its latency.

    Both metrics are labelled with the wrapped function's ``__name__``. The
    latency observation sits in a ``finally`` block, so calls that raise are
    timed as well. The exception itself still propagates to the caller.
    """
    event_type = func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        socketio_message_counter.labels(event_type=event_type).inc()
        start_time = time()
        try:
            return func(*args, **kwargs)
        finally:
            socketio_message_latency.labels(event_type=event_type).observe(time() - start_time)

    return wrapper
@socketio.on('connect')
@track_socketio_event
def handle_connect():
    """Authenticate a new Socket.IO connection and place it in a fresh room.

    The handler:

    1. Reads the token from the ``token`` query-string argument.
    2. Validates it with ``TokenValidator``.
    3. Creates a room bound to the tenant and token, and joins it.
    4. Initialises ``session_id``, ``last_activity`` and ``room`` in the session.

    On success the client receives ``connect`` and ``authenticated`` events. On
    any failure it receives ``connect_error`` and is disconnected. The raw token
    is never written to the logs.
    """
    try:
        current_app.logger.debug('Handle Connection')
        token = request.args.get('token')
        if not token:
            raise ValueError("Missing token")

        validation_result = TokenValidator().validate_token(token)
        if not validation_result.is_valid:
            current_app.logger.error(f"Socket connection failed: {validation_result.error_message}")
            emit('connect_error', {'error': validation_result.error_message})
            disconnect()
            return

        # Create room and setup session
        room = room_manager.create_room(validation_result.tenant_id, token)
        join_room(room)
        session['session_id'] = str(uuid.uuid4())
        session['last_activity'] = datetime.now()
        session['room'] = room

        # Emit success events
        emit('connect', {
            'status': 'Connected',
            'tenant_id': validation_result.tenant_id,
            'room': room
        })
        emit('authenticated', {'token': token, 'room': room})
        current_app.logger.info(
            f"Socket connection succeeded: tenant {validation_result.tenant_id} / {room}"
        )
    except Exception as e:
        current_app.logger.error(f"Socket connection failed: {str(e)}")
        emit('connect_error', {'status': 'Connection Failed'})
        disconnect()
@socketio.on('rejoin_room')
def handle_room_rejoin(data):
    """Re-attach a reconnecting client to the room it held before.

    Expects ``data`` with ``token``, ``tenant_id`` and ``previousRoom``. The
    required fields are checked first. The token is then validated with
    ``require_session=True``. Finally, the room must be registered for exactly
    this tenant and token.

    The outcome is reported through a ``room_rejoin_result`` event.
    """
    try:
        token = data.get('token')
        tenant_id = data.get('tenant_id')
        previous_room = data.get('previousRoom')

        # Check presence before validation so a missing token is reported as
        # missing data rather than being passed to the validator.
        if not all([token, tenant_id, previous_room]):
            raise ValueError("Missing required rejoin data")

        validation_result = TokenValidator().validate_token(token, require_session=True)
        if not validation_result.is_valid:
            emit('room_rejoin_result', {'success': False, 'error': validation_result.error_message})
            return

        # Validate room ownership
        if not room_manager.validate_room_ownership(previous_room, tenant_id, token):
            raise ValueError("Invalid room ownership")

        # Rejoin room
        join_room(previous_room)
        session['room'] = previous_room
        room_manager.update_room_activity(previous_room)
        emit('room_rejoin_result', {
            'success': True,
            'room': previous_room
        })
    except ValueError as e:
        # Validation failures raised above carry messages intended for the client.
        current_app.logger.error(f'Room rejoin failed: {e}')
        emit('room_rejoin_result', {
            'success': False,
            'error': str(e)
        })
    except Exception as e:
        # Unexpected errors are logged with traceback but not echoed to the client.
        current_app.logger.exception(f'Room rejoin failed: {e}')
        emit('room_rejoin_result', {
            'success': False,
            'error': 'Room rejoin failed'
        })
@socketio.on('disconnect')
@track_socketio_event
def handle_disconnect():
    """Leave the room recorded in the session when the client disconnects.

    The room is not removed from ``room_manager`` here. That is what allows
    ``rejoin_room`` to find it again after a reconnect.
    """
    current_room = session.get('room')
    if not current_room:
        return
    leave_room(current_room)
@socketio.on('heartbeat')
def handle_heartbeat():
    """Disconnect the client if it has been idle longer than the configured limit.

    Idle time is measured from ``session['last_activity']``. That value is set on
    connect and refreshed by ``user_message``; heartbeats do not refresh it.

    ``SOCKETIO_MAX_IDLE_TIME`` may be a ``timedelta`` or a number of seconds.
    If either the limit or the session timestamp is missing, the check is
    skipped instead of raising ``TypeError``.
    """
    last_activity = session.get('last_activity')
    max_idle = current_app.config.get('SOCKETIO_MAX_IDLE_TIME')
    if last_activity is None or max_idle is None:
        return
    if isinstance(max_idle, (int, float)):
        max_idle = timedelta(seconds=max_idle)
    if datetime.now() - last_activity > max_idle:
        disconnect()
@socketio.on('user_message')
def handle_message(data):
    """Validate a user message and hand it to the ``execute_specialist`` Celery task.

    Checks, in order:

    1. The token is valid.
    2. The session's room is registered.
    3. The room belongs to ``data['tenant_id']`` / ``data['token']``.

    It then refreshes room and session activity and resolves the tenant from
    the token. Finally it queues the task on the ``llm_interactions`` queue and
    emits ``bot_response`` with the task id to the room. Any failure is logged
    and reported to the client as a generic ``error`` event.
    """
    # Log the payload without the bearer token.
    if isinstance(data, dict):
        loggable = {key: value for key, value in data.items() if key != 'token'}
    else:
        loggable = data
    current_app.logger.debug(f"SocketIO: Received message: {loggable}")

    # Bound before the try so the error handler below can always reference it.
    room = None
    try:
        validation_result = TokenValidator().validate_token(data.get('token'))
        if not validation_result.is_valid:
            emit('error', {'message': validation_result.error_message})
            return
        current_app.logger.debug(f"SocketIO: token validated: {validation_result}")

        room = session.get('room')
        current_app.logger.debug(f"SocketIO: Room in session: {room}, Room in arguments: {data.get('room')}")
        if not room or not room_manager.is_room_active(room):
            raise Exception("Invalid or inactive room")
        current_app.logger.debug(f"SocketIO: Room active: {room}")

        if not room_manager.validate_room_ownership(room, data['tenant_id'], data['token']):
            raise Exception("Room ownership validation failed")
        current_app.logger.debug(f"SocketIO: Room ownership validated: {room}")

        room_manager.update_room_activity(room)
        current_app.logger.debug(f"SocketIO: Room activity updated: {room}")
        session['last_activity'] = datetime.now()

        current_tenant_id = validate_incoming_data(data)
        current_app.logger.debug(f"SocketIO: Incoming data validated: {current_tenant_id}")

        # Offload actual processing of question
        task = current_celery.send_task('execute_specialist',
                                        queue='llm_interactions',
                                        args=[
                                            current_tenant_id,
                                            data['specialist_id'],
                                            data['arguments'],
                                            session['session_id'],
                                            data['timezone'],
                                            room
                                        ])
        response = {
            'tenantId': current_tenant_id,
            'message': f'Processing question ... Session ID = {session["session_id"]}',
            'taskId': task.id,
            'room': room,
        }
        current_app.logger.debug(f"SocketIO: Sent response {response}")
        emit('bot_response', response, room=room)
    except Exception as e:
        current_app.logger.error(f'SocketIO: Message handling failed: {str(e)}')
        emit('error', {'message': 'Failed to process message'}, room=room)
@socketio.on('check_task_status')
def check_task_status(data):
    """Report the state of a Celery task to the client via ``task_status`` events.

    Expects ``data`` with ``token`` and ``task_id``. An invalid token is
    answered to the caller only. Every other reply goes to the session's room:

    * ``{'status': 'error', ...}`` for a missing task id, or for a successful
      task whose result lacks the expected shape;
    * ``{'status': 'pending', 'taskId': ...}`` while the task is ``PENDING``;
    * the answer, citations and interaction id once the task is ``SUCCESS``;
    * ``{'status': <celery state>, 'message': str(info)}`` for any other state.
    """
    # Log the payload without the bearer token.
    if isinstance(data, dict):
        loggable = {key: value for key, value in data.items() if key != 'token'}
    else:
        loggable = data
    current_app.logger.debug(f'SocketIO: Checking Task Status ... {loggable}')

    validation_result = TokenValidator().validate_token(data.get('token'))
    if not validation_result.is_valid:
        # Reply on 'task_status', the event the caller listens to for this request.
        emit('task_status', {'status': 'error', 'message': validation_result.error_message})
        return

    task_id = data.get('task_id')
    room = session.get('room')
    if not task_id:
        emit('task_status', {'status': 'error', 'message': 'Missing task ID'}, room=room)
        return

    task_result = current_celery.AsyncResult(task_id)
    # Read the state once. Each `.state` access may query the result backend,
    # and the task can change state between reads.
    state = task_result.state

    if state == 'PENDING':
        emit('task_status', {'status': 'pending', 'taskId': task_id}, room=room)
        return

    if state == 'SUCCESS':
        result = task_result.result
        current_app.logger.debug(f'SocketIO: Task {task_id} returned: {result}')
        try:
            # result['result'] holds the SpecialistResult model_dump
            specialist_result = result['result']
            response = {
                'status': 'success',
                'taskId': task_id,
                'results': {
                    'answer': specialist_result.get('answer'),
                    'citations': specialist_result.get('citations', []),
                    'insufficient_info': specialist_result.get('insufficient_info', False)
                },
                'interaction_id': result['interaction_id'],
                'room': room
            }
        except (KeyError, TypeError, AttributeError) as e:
            current_app.logger.error(f'SocketIO: Task {task_id} returned an unexpected result: {e!r}')
            emit('task_status', {'status': 'error', 'taskId': task_id, 'message': 'Unexpected task result'},
                 room=room)
            return
        emit('task_status', response, room=room)
        return

    info = task_result.info
    if state in ('FAILURE', 'REVOKED'):
        current_app.logger.error(f'SocketIO: Task {task_id} has failed. Error: {info}')
    else:
        # e.g. STARTED or RETRY: still in progress, not a failure.
        current_app.logger.debug(f'SocketIO: Task {task_id} is in state {state}: {info}')
    emit('task_status', {'status': state, 'message': str(info)}, room=room)
@socketio.on('feedback')
def handle_feedback(data):
    """Store a user's rating on an ``Interaction`` and acknowledge it.

    Expects ``data`` with ``token``, ``interactionId`` and ``feedback``. The
    comment below names the expected values as 'up' or 'down'.
    ``appreciation`` is set to 0 for ``'down'`` and to 100 for any other value.

    Outcomes:

    * Success: ``feedback_received`` is emitted to the session's room.
    * Invalid token or no active room: an error ``feedback_received`` is sent
      to the caller.
    * Commit failure: the change is rolled back, the caller is told, and the
      error is re-raised.
    * The outer handler logs that error and any other failure (including
      whatever ``get_or_404`` raises for an unknown id), then disconnects.
    """
    # Log the payload without the bearer token.
    if isinstance(data, dict):
        loggable = {key: value for key, value in data.items() if key != 'token'}
    else:
        loggable = data
    current_app.logger.debug(f'SocketIO: Received feedback: {loggable}')
    try:
        validation_result = TokenValidator().validate_token(data.get('token'))
        if not validation_result.is_valid:
            emit('feedback_received', {'status': 'error', 'error': validation_result.error_message})
            return

        current_tenant_id = validate_incoming_data(data)
        interaction_id = data.get('interactionId')
        feedback = data.get('feedback')  # 'up' or 'down'

        # Check the room before touching the database, so returning early never
        # leaves an uncommitted modification in the session.
        room = session.get('room')
        if not room:
            emit('feedback_received', {'status': 'error', 'message': 'No active room'})
            return

        Database(current_tenant_id).switch_schema()
        interaction = Interaction.query.get_or_404(interaction_id)
        interaction.appreciation = 0 if feedback == 'down' else 100

        try:
            db.session.commit()
        except SQLAlchemyError as e:
            current_app.logger.error(f'SocketIO: Feedback handling failed: {e}')
            db.session.rollback()
            emit('feedback_received', {'status': 'Could not register feedback', 'interaction_id': interaction_id})
            raise
        emit('feedback_received', {'status': 'success', 'interaction_id': interaction_id, 'room': room}, room=room)
    except Exception as e:
        current_app.logger.error(f'SocketIO: Feedback handling failed: {e}')
        disconnect()
def validate_api_key(tenant_id, api_key):
    """Return True if ``api_key`` equals the tenant's decrypted chat API key.

    The comparison is constant-time (``hmac.compare_digest``), so response
    timing does not reveal how much of the key matched. It returns False when
    the stored key is empty or missing. It also returns False when the two keys
    are not both ``str`` or both ``bytes``.

    Args:
        tenant_id: Id passed to ``Tenant.query.get_or_404``.
        api_key: Key supplied by the client.
    """
    import hmac  # constant-time comparison of secrets

    tenant = Tenant.query.get_or_404(tenant_id)
    decrypted_api_key = simple_encryption.decrypt_api_key(tenant.encrypted_chat_api_key)
    if not decrypted_api_key:
        # Never let an unset key match an absent/empty client key.
        return False
    if isinstance(decrypted_api_key, str) and isinstance(api_key, str):
        # Encode so non-ASCII strings do not make compare_digest raise TypeError.
        return hmac.compare_digest(decrypted_api_key.encode('utf-8'), api_key.encode('utf-8'))
    if isinstance(decrypted_api_key, bytes) and isinstance(api_key, bytes):
        return hmac.compare_digest(decrypted_api_key, api_key)
    return False
def validate_incoming_data(data):
    """Extract the tenant id from the JWT carried in ``data['token']``.

    Args:
        data: Incoming Socket.IO payload. Only its ``token`` entry is read.

    Returns:
        The ``sub`` claim of the decoded token, used as the current tenant id.

    Raises:
        EveAISocketInputException: if the token is missing, decodes to an empty
            payload, or has no ``sub`` claim. Anything ``decode_token`` raises
            for a malformed or rejected token propagates unchanged.
    """
    # Log the payload without the bearer token.
    if isinstance(data, dict):
        loggable = {key: value for key, value in data.items() if key != 'token'}
    else:
        loggable = data
    current_app.logger.debug(f'SocketIO: Validating incoming data: {loggable}')

    token = data.get('token')
    if not token:
        raise EveAISocketInputException("SocketIO: Missing token in input")

    decoded_token = decode_token(token)
    if not decoded_token:
        raise EveAISocketInputException("SocketIO: Invalid token in input")

    current_tenant_id = decoded_token.get('sub')
    if not current_tenant_id:
        raise EveAISocketInputException("SocketIO: Missing tenant_id (sub) in input")

    # Log only the resolved tenant, not the full set of token claims.
    current_app.logger.debug(f'SocketIO: Token resolved to tenant: {current_tenant_id}')
    return current_tenant_id