Files
eveAI/eveai_chat_workers/tasks.py
Josako 5bfd3445bb - adding usage to specialist execution
- Correcting implementation of usage
- Removed some obsolete debug statements
2025-03-07 11:10:28 +01:00

326 lines
13 KiB
Python

from datetime import datetime as dt, timezone as tz
from typing import Dict, Any, Optional
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from common.utils.config_field_types import TaggingFields
from common.utils.database import Database
from common.models.document import Catalog
from common.models.user import Tenant
from common.models.interaction import Interaction, Specialist, SpecialistRetriever
from common.extensions import db, cache_manager
from common.utils.celery_utils import current_celery
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments
from eveai_chat_workers.specialists.base_specialist import get_specialist_class
from common.utils.execution_progress import ExecutionProgressTracker
# Healthcheck task
@current_celery.task(name='ping', queue='llm_interactions')
def ping():
    """Healthcheck task routed to the 'llm_interactions' queue; always replies 'pong'."""
    reply = 'pong'
    return reply
class ArgumentPreparationError(Exception):
    """Raised when specialist or retriever arguments cannot be prepared or fail validation."""
def validate_specialist_arguments(specialist_type: str, specialist_type_version: str, arguments: Dict[str, Any]) -> None:
    """
    Validate specialist-specific arguments against the cached specialist configuration.

    Every argument declared under the configuration's 'arguments' key is checked:
    it must be present when its config has 'required' set, and, when present, its
    value must match the declared 'type'. Only the types 'str' and 'int' are checked;
    other declared types are accepted without a check.

    Args:
        specialist_type: Type of specialist
        specialist_type_version: Version of specialist type
        arguments: Arguments to validate (excluding retriever-specific arguments)

    Raises:
        ArgumentPreparationError: If the specialist type/version has no configuration,
            a required argument is missing, or an argument has the wrong type
    """
    specialist_config = cache_manager.specialists_config_cache.get_config(specialist_type, specialist_type_version)
    if not specialist_config:
        raise ArgumentPreparationError(f"Unknown specialist type: {specialist_type}")

    # All declared arguments, both required and optional.
    declared_args = specialist_config.get('arguments', {})
    for arg_name, arg_config in declared_args.items():
        if arg_name not in arguments:
            if arg_config.get('required', False):
                raise ArgumentPreparationError(f"Missing required argument '{arg_name}' for specialist")
            continue

        value = arguments[arg_name]
        expected_type = arg_config.get('type')
        if expected_type == 'str' and not isinstance(value, str):
            raise ArgumentPreparationError(f"Argument '{arg_name}' must be a string")
        # bool is a subclass of int, so True/False would otherwise pass as integers.
        if expected_type == 'int' and (isinstance(value, bool) or not isinstance(value, int)):
            raise ArgumentPreparationError(f"Argument '{arg_name}' must be an integer")
def validate_retriever_arguments(retriever_type: str, retriever_type_version: str, arguments: Dict[str, Any],
                                 catalog_config: Optional[Dict[str, Any]] = None) -> None:
    """
    Check retriever arguments against the cached retriever configuration.

    Arguments whose config has 'required' set must be present. If a catalog
    configuration is given, it declares 'tagging_fields', and 'metadata_filters' is
    among the arguments, those filters are validated via TaggingFields.

    Args:
        retriever_type: Type of retriever
        retriever_type_version: Version of retriever type
        arguments: Arguments to validate
        catalog_config: Optional catalog configuration used for metadata-filter validation

    Raises:
        ArgumentPreparationError: If the retriever type/version has no configuration,
            a required argument is missing, or the metadata filters are invalid
    """
    config = cache_manager.retrievers_config_cache.get_config(retriever_type, retriever_type_version)
    if not config:
        raise ArgumentPreparationError(f"Unknown retriever type: {retriever_type}")

    not_found = object()
    first_missing = next(
        (name for name, spec in config.get('arguments', {}).items()
         if spec.get('required', False) and name not in arguments),
        not_found,
    )
    if first_missing is not not_found:
        raise ArgumentPreparationError(f"Missing required argument '{first_missing}' for retriever")

    # Metadata filters can only be checked against a catalog that declares tagging fields.
    if not catalog_config or 'metadata_filters' not in arguments or 'tagging_fields' not in catalog_config:
        return

    tagging = TaggingFields.from_dict(catalog_config['tagging_fields'])
    problems = tagging.validate_argument_values(arguments['metadata_filters'])
    if problems:
        raise ArgumentPreparationError(f"Invalid metadata filters: {problems}")
def is_retriever_id(key: str) -> bool:
    """
    Check if a key represents a valid retriever ID.

    Valid IDs are strings of ASCII decimal digits whose integer value is positive.
    Leading zeros are allowed ("007"). Signs ("+3"), surrounding whitespace (" 3"),
    digit-group underscores ("1_0") and non-ASCII digits are rejected, even though
    int() would accept them.

    Args:
        key: String to check

    Returns:
        bool: True if the key represents a valid retriever ID; False otherwise
            (including for non-string input)
    """
    if not isinstance(key, str):
        return False
    # isascii() excludes Unicode digits (e.g. full-width or superscript) that isdigit() allows.
    if not (key.isascii() and key.isdigit()):
        return False
    return int(key) > 0
def _split_prepare_arguments(arguments: Dict[str, Any]) -> tuple:
    """Split raw arguments into (specialist_args, retriever_args keyed by canonical retriever ID)."""
    specialist_args: Dict[str, Any] = {}
    retriever_args: Dict[str, Any] = {}
    for key, value in arguments.items():
        if isinstance(key, str) and is_retriever_id(key):
            # Canonicalise so a key like "007" matches str(retriever.id) == "7".
            # If one ID is supplied in several spellings, the last one wins.
            retriever_args[str(int(key))] = value
        else:
            specialist_args[key] = value
    return specialist_args, retriever_args


def _get_catalog_config(retriever: Any) -> Optional[Dict[str, Any]]:
    """Return the configuration of the retriever's catalog, or None if it has none or the lookup fails."""
    if not retriever.catalog_id:
        return None
    try:
        catalog = Catalog.query.get(retriever.catalog_id)
        return catalog.configuration if catalog else None
    except SQLAlchemyError:
        # Best effort: validation proceeds without metadata-filter checks.
        current_app.logger.warning(
            f"Could not fetch catalog {retriever.catalog_id} for retriever {retriever.id}"
        )
        return None


def prepare_arguments(specialist: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """
    Prepare complete argument dictionary for specialist execution with inheritance

    Specialist arguments are validated, then copied into the argument set of every
    retriever linked to the specialist. Retriever-specific overrides (keyed by
    retriever ID) are applied on top, and 'type'/'type_version' are always taken
    from the retriever itself.

    Args:
        specialist: Specialist model instance
        arguments: Dictionary containing:
            - Specialist arguments
            - Retriever-specific arguments keyed by retriever ID (leading zeros allowed)

    Returns:
        Dict containing the specialist arguments plus a 'retriever_arguments' entry
        mapping each linked retriever ID (as str) to its prepared arguments

    Raises:
        ArgumentPreparationError: If argument preparation or validation fails
    """
    try:
        specialist_args, retriever_args = _split_prepare_arguments(arguments)

        validate_specialist_arguments(specialist.type, specialist.type_version, specialist_args)

        specialist_retrievers = SpecialistRetriever.query.filter_by(specialist_id=specialist.id).all()

        prepared_retriever_args = {}
        for spec_retriever in specialist_retrievers:
            retriever = spec_retriever.retriever
            retriever_id = str(retriever.id)
            catalog_config = _get_catalog_config(retriever)

            # Start with specialist arguments (inheritance), then apply retriever overrides.
            inherited_args = specialist_args.copy()
            if retriever_id in retriever_args:
                overrides = retriever_args[retriever_id]
                if not isinstance(overrides, dict):
                    raise ArgumentPreparationError(
                        f"Arguments for retriever {retriever_id} must be a dictionary"
                    )
                inherited_args.update(overrides)

            # Always include the retriever type, overriding any supplied value.
            inherited_args['type'] = retriever.type
            inherited_args['type_version'] = retriever.type_version

            validate_retriever_arguments(
                retriever.type, retriever.type_version,
                inherited_args,
                catalog_config
            )
            prepared_retriever_args[retriever_id] = inherited_args

        unused_ids = sorted(set(retriever_args) - set(prepared_retriever_args))
        if unused_ids:
            current_app.logger.warning(
                f"Ignoring arguments for retrievers not linked to specialist {specialist.id}: {unused_ids}"
            )

        return {
            **specialist_args,
            'retriever_arguments': prepared_retriever_args
        }
    except ArgumentPreparationError as e:
        # Already the right type; re-raise as-is so the original traceback is kept.
        current_app.logger.error(f'Error during argument preparation: {e}')
        raise
    except SQLAlchemyError as e:
        current_app.logger.error(f'Database error during argument preparation: {e}')
        raise ArgumentPreparationError(f"Database error: {str(e)}") from e
    except Exception as e:
        current_app.logger.error(f'Error during argument preparation: {e}')
        raise ArgumentPreparationError(str(e)) from e
def _commit_interaction(interaction: Interaction, action: str) -> None:
    """
    Add an interaction to the session and commit, rolling the session back on failure.

    Args:
        interaction: Interaction record to persist
        action: Verb used in the error log message ('creating' or 'updating')

    Raises:
        SQLAlchemyError: Re-raised after the session has been rolled back
    """
    try:
        db.session.add(interaction)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'execute_specialist: Error {action} interaction: {e}')
        # A failed commit leaves the session unusable until it is explicitly rolled back.
        db.session.rollback()
        raise


@current_celery.task(name='execute_specialist', queue='llm_interactions', bind=True)
def execute_specialist(self, tenant_id: int, specialist_id: int, arguments: Dict[str, Any],
                       session_id: str, user_timezone: str) -> dict:
    """
    Execute a specialist with given arguments

    Progress updates are sent through ExecutionProgressTracker under this task's ID,
    and the question/answer pair is stored as an Interaction record.

    Args:
        tenant_id: ID of the tenant
        specialist_id: ID of the specialist to use
        arguments: Dictionary containing all required arguments for specialist and retrievers
        session_id: Chat session ID
        user_timezone: User's timezone

    Returns:
        dict: {
            'result': Dict - Specialist execution result
            'interaction_id': int - Created interaction ID
        }

    Raises:
        Exception: If the tenant or specialist does not exist
        ArgumentPreparationError: If the arguments cannot be prepared or validated
        ValueError: Presumably raised by SpecialistArguments.create for invalid input (caught here for logging)
        SQLAlchemyError: If the interaction cannot be stored
        Any error is also reported as an "EveAI Specialist Error" progress update before being re-raised.
    """
    task_id = self.request.id
    ept = ExecutionProgressTracker()
    ept.send_update(task_id, "EveAI Specialist Started", {})

    with BusinessEvent("Execute Specialist", tenant_id=tenant_id, chat_session_id=session_id) as event:
        current_app.logger.info(
            f'execute_specialist: Processing request for tenant {tenant_id} using specialist {specialist_id}')

        try:
            tenant = Tenant.query.get(tenant_id)
            if not tenant:
                raise Exception(f'Tenant {tenant_id} not found.')

            # Switch to correct database schema
            Database(tenant_id).switch_schema()

            # Ensure we have a session
            cached_session = cache_manager.chat_session_cache.get_cached_session(
                session_id,
                create_params={'timezone': user_timezone}
            )

            # get_or_404 would raise an HTTP NotFound ("requested URL was not found"),
            # which is misleading inside a worker; raise a descriptive error instead.
            specialist = Specialist.query.get(specialist_id)
            if not specialist:
                raise Exception(f'Specialist {specialist_id} not found.')

            try:
                raw_arguments = prepare_arguments(specialist, arguments)
                complete_arguments = SpecialistArguments.create(
                    type_name=specialist.type,
                    type_version=specialist.type_version,
                    specialist_args={k: v for k, v in raw_arguments.items() if k != 'retriever_arguments'},
                    retriever_args=raw_arguments.get('retriever_arguments', {})
                )
            except (ArgumentPreparationError, ValueError) as e:
                # prepare_arguments raises ArgumentPreparationError, which is not a ValueError.
                current_app.logger.error(f'execute_specialist: Error preparing arguments: {e}')
                raise

            # Create new interaction record
            new_interaction = Interaction()
            new_interaction.chat_session_id = cached_session.id
            new_interaction.timezone = user_timezone
            new_interaction.question_at = dt.now(tz.utc)
            new_interaction.specialist_id = specialist.id
            new_interaction.specialist_arguments = complete_arguments.model_dump(mode='json')
            _commit_interaction(new_interaction, 'creating')
            event.update_attribute('interaction_id', new_interaction.id)

            with current_event.create_span("Specialist invocation"):
                ept.send_update(task_id, "EveAI Specialist Start", {})
                specialist_class = get_specialist_class(specialist.type, specialist.type_version)
                specialist_instance = specialist_class(
                    tenant_id=tenant_id,
                    specialist_id=specialist_id,
                    session_id=session_id,
                    task_id=task_id,
                )
                result = specialist_instance.execute(complete_arguments)

            # Update interaction record with the complete result
            new_interaction.specialist_results = result.model_dump(mode='json')
            new_interaction.answer_at = dt.now(tz.utc)
            _commit_interaction(new_interaction, 'updating')

            # Now that we have a complete interaction with an answer, add it to the cache
            cache_manager.chat_session_cache.add_completed_interaction(session_id, new_interaction)

            response = {
                'result': result.model_dump(),
                'interaction_id': new_interaction.id
            }
            ept.send_update(task_id, "EveAI Specialist Complete", response)
            return response

        except Exception as e:
            ept.send_update(task_id, "EveAI Specialist Error", {'Error': str(e)})
            current_app.logger.error(f'execute_specialist: Error executing specialist: {e}')
            raise
def tasks_ping():
    """Plain (non-Celery) healthcheck helper; always replies 'pong'."""
    reply = 'pong'
    return reply