from typing import Dict, Any, Optional

from flask import current_app
from pydantic import BaseModel, Field, field_validator, model_validator

from common.extensions import cache_manager


class RetrieverMetadata(BaseModel):
    """Metadata structure for retrieved documents"""
    document_id: int = Field(..., description="ID of the source document")
    version_id: int = Field(..., description="Version ID of the source document")
    document_name: str = Field(..., description="Name of the source document")
    # The key is required (Field(...)), but its value may be None.
    url: Optional[str] = Field(..., description="URL of the source document")
    user_metadata: Dict[str, Any] = Field(
        # Used only when the field is omitted; an explicit None is coerced
        # to {} by _none_to_empty_dict below.
        default_factory=dict,
        description="User-defined metadata"
    )

    @field_validator('user_metadata', mode='before')
    @classmethod
    def _none_to_empty_dict(cls, value: Any) -> Any:
        """Coerce an explicit ``None`` for ``user_metadata`` to an empty dict.

        ``default_factory`` does not apply when a value is passed explicitly,
        so without this hook ``user_metadata=None`` would fail dict validation.
        """
        return {} if value is None else value


class RetrieverResult(BaseModel):
    """Standard result format for all retrievers"""
    id: int = Field(..., description="ID of the retrieved embedding")
    chunk: str = Field(..., description="Retrieved text chunk")
    similarity: float = Field(..., description="Similarity score (0-1)")
    metadata: RetrieverMetadata = Field(..., description="Associated metadata")


class RetrieverArguments(BaseModel):
    """
    Dynamic arguments for retrievers, allowing arbitrary fields but validating
    required ones based on RETRIEVER_TYPES configuration.
    """
    type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")
    type_version: str = Field(..., description="Version of retriever type (e.g. 1.0)")
    question: str = Field(..., description="Question to retrieve answers for")

    # Allow any additional fields; they are stored in ``model_extra``.
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'RetrieverArguments':
        """Validate that all required arguments for this retriever type are present.

        The retriever configuration is fetched with
        ``cache_manager.retrievers_config_cache.get_config(type, type_version)``.
        For every entry in its ``'arguments'`` mapping flagged ``required``,
        the argument must have been supplied (as a declared field or an extra
        field). Its value must also match the configured ``'type'``. Only
        ``'str'`` and ``'int'`` are checked; other type names are accepted
        as-is.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If the retriever type/version is unknown, a required
                argument is missing, or a required argument has the wrong
                type. Pydantic reports these as a ``ValidationError``.
        """
        retriever_config = cache_manager.retrievers_config_cache.get_config(self.type, self.type_version)
        if not retriever_config:
            raise ValueError(f"Unknown retriever type: {self.type}")

        # Names that were actually supplied: declared fields plus extras.
        # hasattr() is not used because it is also true for BaseModel
        # attributes/methods (e.g. 'copy', 'json'), which would let a missing
        # argument with such a name pass unnoticed.
        provided = set(self.__class__.model_fields) | set(self.model_extra or {})

        # Check required arguments from configuration
        for arg_name, arg_config in (retriever_config.get('arguments') or {}).items():
            if not arg_config.get('required', False):
                continue
            if arg_name not in provided:
                raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

            # Type validation
            value = getattr(self, arg_name)
            expected_type = arg_config.get('type')
            if expected_type == 'str' and not isinstance(value, str):
                raise ValueError(f"Argument '{arg_name}' must be a string")
            # bool is a subclass of int, so reject True/False explicitly.
            elif expected_type == 'int' and (isinstance(value, bool) or not isinstance(value, int)):
                raise ValueError(f"Argument '{arg_name}' must be an integer")
            # Add other type validations as needed

        return self