Files
eveAI/eveai_chat_workers/retrievers/retriever_typing.py
Josako 1807435339 - Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
2024-11-15 10:00:53 +01:00

61 lines
2.6 KiB
Python

from typing import List, Dict, Any

from pydantic import BaseModel, Field, field_validator, model_validator

from common.utils.config_field_types import ArgumentDefinition, TaggingFields
from config.retriever_types import RETRIEVER_TYPES
class RetrieverMetadata(BaseModel):
    """
    Metadata structure for retrieved documents.

    Identifies the source document and version a retrieved chunk belongs to,
    together with any user-defined metadata attached to that document.
    """
    document_id: int = Field(..., description="ID of the source document")
    version_id: int = Field(..., description="Version ID of the source document")
    document_name: str = Field(..., description="Name of the source document")
    # default_factory only covers the case where the field is omitted entirely;
    # an explicit None is normalised to {} by the validator below.
    user_metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="User-defined metadata"
    )

    @field_validator('user_metadata', mode='before')
    @classmethod
    def coerce_none_user_metadata(cls, value: Any) -> Any:
        """Substitute an empty dict when ``user_metadata`` is explicitly ``None``.

        Any other value is passed through unchanged to normal dict validation.
        """
        return {} if value is None else value
class RetrieverResult(BaseModel):
    """
    Uniform shape of a single hit produced by any retriever.

    All four fields are required. The 0-1 range mentioned for ``similarity``
    is documentation only; no bounds are enforced on it here.
    """

    # Identifier of the embedding that matched.
    id: int = Field(default=..., description="ID of the retrieved embedding")

    # The matched text itself.
    chunk: str = Field(default=..., description="Retrieved text chunk")

    # Score reported by the retriever for this hit.
    similarity: float = Field(default=..., description="Similarity score (0-1)")

    # Source document / version / user metadata for the chunk.
    metadata: RetrieverMetadata = Field(default=..., description="Associated metadata")
class RetrieverArguments(BaseModel):
    """
    Dynamic arguments for retrievers, allowing arbitrary fields but validating required ones
    based on RETRIEVER_TYPES configuration.

    Only ``type`` is a declared field; every other keyword is kept as a pydantic
    "extra" value. After construction the validator looks ``type`` up in
    ``RETRIEVER_TYPES`` and, for each entry of that config's ``arguments``
    mapping flagged ``required``, checks that a value was supplied and, for
    arguments typed ``'str'`` or ``'int'``, that it has the matching Python type.
    Optional arguments are not type-checked.

    Raises (pydantic wraps these in a ``ValidationError``):
        ValueError: unknown retriever type, a missing required argument, or a
            required argument of the wrong type.
    """
    type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")

    # Allow any additional fields (stored as pydantic extras)
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'RetrieverArguments':
        """Validate that all required arguments for this retriever type are present"""
        retriever_config = RETRIEVER_TYPES.get(self.type)
        if not retriever_config:
            raise ValueError(f"Unknown retriever type: {self.type}")

        # Values the caller actually supplied: extras plus declared fields.
        # hasattr()/getattr() would also match BaseModel members such as `copy`
        # or `json` and yield a bound method instead of the argument value.
        supplied = dict(self.model_extra or {})
        supplied.update(
            {name: getattr(self, name) for name in self.__class__.model_fields}
        )

        # Check required arguments from configuration
        for arg_name, arg_config in retriever_config.get('arguments', {}).items():
            if not arg_config.get('required', False):
                continue
            if arg_name not in supplied:
                raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

            # Type validation
            value = supplied[arg_name]
            expected_type = arg_config['type']
            if expected_type == 'str' and not isinstance(value, str):
                raise ValueError(f"Argument '{arg_name}' must be a string")
            # bool is a subclass of int, so True/False must be rejected explicitly.
            elif expected_type == 'int' and (
                isinstance(value, bool) or not isinstance(value, int)
            ):
                raise ValueError(f"Argument '{arg_name}' must be an integer")
            # Add other type validations as needed

        return self