Files
eveAI/eveai_chat_workers/retrievers/retriever_typing.py
Josako 25213f2004 - Implementation of specialist execution api, including SSE protocol
- eveai_chat becomes deprecated and should be replaced with SSE
- Adaptation of STANDARD_RAG specialist
- Base class definition that allows specialists to be realised with the CrewAI framework
- Implementation of SPIN_SPECIALIST
- Implementation of test app for testing specialists (test_specialist_client). Also serves as an example for future SSE-based client
- Improvements to startup scripts to better handle and scale multiple connections
- Small improvements to the interaction forms and views
- Caching implementation improved and augmented with additional caches
2025-02-20 05:50:16 +01:00

64 lines
2.7 KiB
Python

from typing import Dict, Any

from flask import current_app
from pydantic import BaseModel, Field, field_validator, model_validator

from common.extensions import cache_manager
class RetrieverMetadata(BaseModel):
    """Metadata structure for retrieved documents.

    ``user_metadata`` becomes an empty dict when it is omitted or passed
    explicitly as ``None``.
    """
    document_id: int = Field(..., description="ID of the source document")
    version_id: int = Field(..., description="Version ID of the source document")
    document_name: str = Field(..., description="Name of the source document")
    user_metadata: Dict[str, Any] = Field(
        default_factory=dict,  # used only when the field is omitted entirely
        description="User-defined metadata"
    )

    @field_validator('user_metadata', mode='before')
    @classmethod
    def default_user_metadata(cls, value: Any) -> Any:
        """Coerce an explicit ``None`` to ``{}`` before type validation.

        ``default_factory`` does not cover an explicitly passed ``None``.
        Without this hook, ``None`` would be rejected by the
        ``Dict[str, Any]`` type check.
        """
        return {} if value is None else value
class RetrieverResult(BaseModel):
    """Standard result format shared by all retrievers.

    All four fields are required; none has a default.
    """
    # Identifier of the embedding that produced this hit.
    id: int = Field(description="ID of the retrieved embedding")
    # The text of the retrieved chunk itself.
    chunk: str = Field(description="Retrieved text chunk")
    similarity: float = Field(description="Similarity score (0-1)")
    # Document-level information about where the chunk came from.
    metadata: RetrieverMetadata = Field(description="Associated metadata")
class RetrieverArguments(BaseModel):
    """
    Dynamic arguments for retrievers, allowing arbitrary fields but validating required ones
    based on RETRIEVER_TYPES configuration.

    Only ``type`` and ``type_version`` are declared fields. Every other keyword
    is kept as an extra field because ``model_config`` sets ``extra="allow"``.
    After construction, ``validate_required_arguments`` looks up the retriever
    configuration for (type, type_version). It then checks that every argument
    marked ``required`` there was supplied. Required arguments declared as
    ``'str'`` or ``'int'`` must also have the matching Python type.
    """
    type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")
    type_version: str = Field(..., description="Version of retriever type (e.g. 1.0)")

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'RetrieverArguments':
        """Validate that all required arguments for this retriever type are present.

        Returns:
            The validated instance itself, unchanged.

        Raises:
            ValueError: in any of these cases (pydantic surfaces them to the
                caller as a ``ValidationError``):

                - no configuration exists for (type, type_version);
                - a required argument is missing;
                - a required ``'str'``/``'int'`` argument has the wrong type.
        """
        retriever_config = cache_manager.retrievers_config_cache.get_config(self.type, self.type_version)
        if not retriever_config:
            raise ValueError(f"Unknown retriever type: {self.type}")

        # A config with a missing or empty 'arguments' section declares no
        # required arguments, rather than crashing with KeyError/AttributeError.
        arguments = retriever_config.get('arguments') or {}

        # Names actually supplied: declared fields plus extras. A hasattr()
        # check would also match BaseModel attributes/methods (e.g. 'copy').
        # That would let a genuinely missing argument pass unnoticed.
        supplied = set(type(self).model_fields) | set(self.model_extra or {})

        # Maps an expected-type name to (predicate, wording for the error message).
        # bool is a subclass of int, so True/False are rejected explicitly for 'int'.
        type_checks = {
            'str': (lambda v: isinstance(v, str), "a string"),
            'int': (lambda v: isinstance(v, int) and not isinstance(v, bool), "an integer"),
        }

        # Check required arguments from configuration
        for arg_name, arg_config in arguments.items():
            if not arg_config.get('required', False):
                continue
            if arg_name not in supplied:
                raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

            # Absent or unrecognised type names are not checked; add entries
            # to type_checks as needed.
            check = type_checks.get(arg_config.get('type'))
            if check is None:
                continue
            is_valid, expected = check
            if not is_valid(getattr(self, arg_name)):
                raise ValueError(f"Argument '{arg_name}' must be {expected}")

        return self