eveAI/common/utils/execution_progress.py
Josako 5465dae52f - Optimisation and streamlining of messages in ExecutionProgressTracker (ept)
- Adaptation of ProgressTracker to handle these optimised messages
- Hardening SSE-streaming in eveai_chat_client
2025-10-03 08:58:44 +02:00

# common/utils/execution_progress.py
import json
import time
from datetime import datetime as dt, timezone as tz
from typing import Generator

from flask import current_app
from redis import Redis, RedisError


class ExecutionProgressTracker:
"""Tracks progress of specialist executions using Redis"""
# Normalized processing types and aliases
PT_COMPLETE = 'EVEAI_COMPLETE'
PT_ERROR = 'EVEAI_ERROR'
_COMPLETE_ALIASES = {'EveAI Specialist Complete', 'Task Complete', 'task complete'}
_ERROR_ALIASES = {'EveAI Specialist Error', 'Task Error', 'task error'}

    def __init__(self):
        try:
            # Use shared pubsub pool (lazy connect; no eager ping)
            from common.utils.redis_pubsub_pool import get_pubsub_client
            self.redis = get_pubsub_client(current_app)
            self.expiry = 3600  # 1 hour expiry
        except Exception as e:
            current_app.logger.error(f"Error initializing ExecutionProgressTracker: {str(e)}")
            raise

    def _get_key(self, execution_id: str) -> str:
        prefix = current_app.config.get('REDIS_PREFIXES', {}).get('pubsub_execution', 'pubsub:execution:')
        return f"{prefix}{execution_id}"

    def _retry(self, op, attempts: int = 3, base_delay: float = 0.1):
        """Retry wrapper for Redis operations with exponential backoff."""
        last_exc = None
        for i in range(attempts):
            try:
                return op()
            except RedisError as e:
                last_exc = e
                if i == attempts - 1:
                    break
                delay = base_delay * (3 ** i)  # 0.1, 0.3, 0.9
                current_app.logger.warning(
                    f"Redis operation failed (attempt {i + 1}/{attempts}): {e}. Retrying in {delay}s")
                time.sleep(delay)
        # Exhausted retries
        raise last_exc

    def _normalize_processing_type(self, processing_type: str) -> str:
        if not processing_type:
            return processing_type
        p = str(processing_type).strip()
        if p in self._COMPLETE_ALIASES:
            return self.PT_COMPLETE
        if p in self._ERROR_ALIASES:
            return self.PT_ERROR
        return p

    def send_update(self, ctask_id: str, processing_type: str, data: dict):
        """Send an update about execution progress"""
        try:
            current_app.logger.debug(f"Sending update for {ctask_id} with processing type {processing_type} and data:\n"
                                     f"{data}")
            key = self._get_key(ctask_id)
            processing_type = self._normalize_processing_type(processing_type)
            update = {
                'processing_type': processing_type,
                'data': data,
                'timestamp': dt.now(tz=tz.utc)
            }
            try:
                # Capture the initial list length so the write can be verified below
                orig_len = self._retry(lambda: self.redis.llen(key))

                # Serialize the update; default=str handles the datetime timestamp
                try:
                    serialized_update = json.dumps(update, default=str)
                except TypeError as e:
                    current_app.logger.error(f"Failed to serialize update: {str(e)}")
                    raise

                # Store the update, publish it and refresh the expiry in one pipeline
                def _pipeline_op():
                    with self.redis.pipeline() as pipe:
                        pipe.rpush(key, serialized_update)
                        pipe.publish(key, serialized_update)
                        pipe.expire(key, self.expiry)
                        return pipe.execute()

                self._retry(_pipeline_op)

                # Verify that the write landed
                new_len = self._retry(lambda: self.redis.llen(key))
                if new_len <= orig_len:
                    current_app.logger.error(
                        f"List length did not increase as expected. Original: {orig_len}, New: {new_len}")
            except RedisError as e:
                current_app.logger.error(f"Redis operation failed: {str(e)}")
                raise
        except Exception as e:
            current_app.logger.error(f"Unexpected error in send_update: {str(e)}, type: {type(e)}")
            raise
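
    # Illustrative usage sketch: a producer (for example a background task) might report
    # progress like this. The task id, processing types and payloads below are invented
    # examples, not values taken from the eveAI codebase:
    #
    #     tracker = ExecutionProgressTracker()
    #     tracker.send_update('ctask-123', 'Retrieving documents', {'step': 2, 'total': 5})
    #     tracker.send_update('ctask-123', 'EVEAI_COMPLETE', {'result': 'ok'})
    #
    # Each call RPUSHes and PUBLISHes one JSON message shaped like:
    #
    #     {"processing_type": "EVEAI_COMPLETE",
    #      "data": {"result": "ok"},
    #      "timestamp": "2025-10-03 06:58:44.000000+00:00"}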

    def get_updates(self, ctask_id: str) -> Generator[str, None, None]:
        key = self._get_key(ctask_id)
        pubsub = self.redis.pubsub()

        # Subscribe with retry
        self._retry(lambda: pubsub.subscribe(key))
        try:
            # Hint client reconnect interval (optional but helpful)
            yield "retry: 3000\n\n"

            # First yield any existing updates
            length = self._retry(lambda: self.redis.llen(key))
            if length > 0:
                updates = self._retry(lambda: self.redis.lrange(key, 0, -1))
                for update in updates:
                    update_data = json.loads(update.decode('utf-8'))
                    update_data['processing_type'] = self._normalize_processing_type(update_data.get('processing_type'))
                    yield f"data: {json.dumps(update_data)}\n\n"

            # Then listen for new updates
            while True:
                try:
                    message = pubsub.get_message(timeout=30)  # message['type'] is Redis pub/sub type
                except RedisError as e:
                    current_app.logger.warning(f"Redis pubsub get_message error: {e}. Continuing...")
                    time.sleep(0.3)
                    continue

                if message is None:
                    yield ": keepalive\n\n"
                    continue

                if message['type'] == 'message':  # This is Redis pub/sub type
                    update_data = json.loads(message['data'].decode('utf-8'))
                    update_data['processing_type'] = self._normalize_processing_type(update_data.get('processing_type'))
                    yield f"data: {json.dumps(update_data)}\n\n"

                    # Unified completion check
                    if update_data['processing_type'] in [self.PT_COMPLETE, self.PT_ERROR]:
                        # Give proxies/clients a chance to flush
                        yield ": closing\n\n"
                        break
        finally:
            try:
                pubsub.unsubscribe()
            except Exception:
                pass
            try:
                pubsub.close()
            except Exception:
                pass
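

# Usage sketch: get_updates() yields Server-Sent Events strings, so a Flask view can
# stream them directly. The route, response headers and wiring below are one plausible
# way eveai_chat_client could consume the tracker; they are assumptions for illustration,
# not the actual client code:
#
#     from flask import Response, stream_with_context
#
#     @app.route('/executions/<ctask_id>/events')
#     def execution_events(ctask_id):
#         tracker = ExecutionProgressTracker()
#         return Response(
#             stream_with_context(tracker.get_updates(ctask_id)),
#             mimetype='text/event-stream',
#             headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'},
#         )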