- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
This commit is contained in:
5
eveai_chat_workers/retrievers/__init__.py
Normal file
5
eveai_chat_workers/retrievers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Import all retriever implementations here to ensure registration
|
||||
from . import standard_rag
|
||||
|
||||
# List of all available retriever implementations
|
||||
__all__ = ['standard_rag']
|
||||
57
eveai_chat_workers/retrievers/base.py
Normal file
57
eveai_chat_workers/retrievers/base.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from flask import current_app
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
|
||||
from config.logging_config import TuningLogger
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
    """Abstract base class for all retrievers.

    Stores the tenant and retriever identifiers, creates a ``TuningLogger``
    for the instance, and offers ``_log_tuning`` as a guarded logging helper.
    Concrete retrievers must provide the ``type`` property and ``retrieve``.
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        """Record the identifiers and create the tuning logger.

        Args:
            tenant_id: Identifier of the tenant, passed on to the tuning logger.
            retriever_id: Identifier of the retriever, passed on to the tuning logger.

        Raises:
            Exception: Whatever ``_setup_tuning_logger`` re-raises when the
                logger cannot be created.
        """
        self.tenant_id = tenant_id
        self.retriever_id = retriever_id
        # Tuning output is off by default; subclasses switch it on after
        # calling this constructor.
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type of the retriever"""
        pass

    def _setup_tuning_logger(self):
        """Create the ``TuningLogger`` for this retriever.

        Any failure is reported through ``current_app.logger`` and re-raised.
        """
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                retriever_id=self.retriever_id,
            )
            # Verify logger is working with a test message.
            # NOTE: when reached from __init__, self.tuning has just been set
            # to False, so this message is only emitted if the method is
            # called again after tuning has been enabled.
            if self.tuning:
                self.tuning_logger.log_tuning('retriever', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def _log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
        """Write a tuning message when tuning is enabled and a logger exists.

        Logging failures are reported through ``current_app.logger`` and are
        not re-raised, so tuning output cannot interrupt the caller.

        Args:
            message: Text handed to the tuning logger.
            data: Optional payload handed to the tuning logger with the message.
        """
        if self.tuning and self.tuning_logger:
            try:
                self.tuning_logger.log_tuning('retriever', message, data)
            except Exception as e:
                # The prefix names this component so the failure can be traced
                # to the retriever rather than to a processor.
                current_app.logger.error(f"Retriever: Error in tuning logging: {e}")

    @abstractmethod
    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve relevant documents based on provided arguments

        Args:
            arguments: RetrieverArguments instance for the retrieval operation

        Returns:
            List[RetrieverResult]: List of retrieved documents/content
        """
        pass
|
||||
20
eveai_chat_workers/retrievers/registry.py
Normal file
20
eveai_chat_workers/retrievers/registry.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Dict, Type
|
||||
from .base import BaseRetriever
|
||||
|
||||
|
||||
class RetrieverRegistry:
    """Class-level lookup table from retriever type name to retriever class."""

    # One mapping held on the class itself, so every caller sees the same entries.
    _registry: Dict[str, Type[BaseRetriever]] = {}

    @classmethod
    def register(cls, retriever_type: str, retriever_class: Type[BaseRetriever]):
        """Store ``retriever_class`` under ``retriever_type``.

        An existing entry with the same type name is replaced.
        """
        cls._registry.update({retriever_type: retriever_class})

    @classmethod
    def get_retriever_class(cls, retriever_type: str) -> Type[BaseRetriever]:
        """Return the class registered under ``retriever_type``.

        Raises:
            ValueError: If no class was registered under that type name.
        """
        known = cls._registry
        if retriever_type in known:
            return known[retriever_type]
        raise ValueError(f"Unknown retriever type: {retriever_type}")
|
||||
60
eveai_chat_workers/retrievers/retriever_typing.py
Normal file
60
eveai_chat_workers/retrievers/retriever_typing.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from common.utils.config_field_types import ArgumentDefinition, TaggingFields
|
||||
from config.retriever_types import RETRIEVER_TYPES
|
||||
|
||||
|
||||
class RetrieverMetadata(BaseModel):
    """Source-document details attached to a retrieved chunk."""

    document_id: int = Field(..., description="ID of the source document")
    version_id: int = Field(..., description="Version ID of the source document")
    document_name: str = Field(..., description="Name of the source document")
    # default_factory supplies a fresh empty dict when the field is omitted.
    # NOTE(review): an explicit None is presumably still rejected by
    # validation, so callers should pass `value or {}` — confirm.
    user_metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="User-defined metadata",
    )
|
||||
|
||||
|
||||
class RetrieverResult(BaseModel):
    """Result record that every retriever returns for a matched chunk.

    All four fields are required.
    """

    # Identifier of the embedding row that matched.
    id: int = Field(
        ...,
        description="ID of the retrieved embedding",
    )
    # Text of the matched chunk.
    chunk: str = Field(
        ...,
        description="Retrieved text chunk",
    )
    # Score assigned by the retriever.
    similarity: float = Field(
        ...,
        description="Similarity score (0-1)",
    )
    # Source-document details for the chunk.
    metadata: RetrieverMetadata = Field(
        ...,
        description="Associated metadata",
    )
|
||||
|
||||
|
||||
class RetrieverArguments(BaseModel):
    """
    Argument container for retrievers.

    Only ``type`` is declared; any other field is accepted as an extra. After
    construction, the arguments marked as required for this retriever type in
    RETRIEVER_TYPES are checked for presence and basic type.
    """

    type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")

    # Accept arbitrary extra fields in addition to the declared ones
    model_config = {"extra": "allow"}

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'RetrieverArguments':
        """Check the required arguments configured for ``self.type``.

        Raises:
            ValueError: If the retriever type has no (or an empty)
                RETRIEVER_TYPES entry, if a required argument is missing, or
                if a required argument declared as 'str' or 'int' has a value
                of another type.
        """
        retriever_config = RETRIEVER_TYPES.get(self.type)
        if not retriever_config:
            raise ValueError(f"Unknown retriever type: {self.type}")

        argument_specs = retriever_config['arguments']
        for arg_name, arg_config in argument_specs.items():
            # Arguments not flagged as required are neither demanded nor type-checked
            if not arg_config.get('required', False):
                continue

            if not hasattr(self, arg_name):
                raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

            # Only 'str' and 'int' declarations are enforced; other declared
            # types pass through unchecked
            value = getattr(self, arg_name)
            expected_type = arg_config['type']
            if expected_type == 'str':
                if not isinstance(value, str):
                    raise ValueError(f"Argument '{arg_name}' must be a string")
            elif expected_type == 'int':
                if not isinstance(value, int):
                    raise ValueError(f"Argument '{arg_name}' must be an integer")

        return self
|
||||
140
eveai_chat_workers/retrievers/standard_rag.py
Normal file
140
eveai_chat_workers/retrievers/standard_rag.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# retrievers/standard_rag.py
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from typing import Dict, Any, List
|
||||
from sqlalchemy import func, or_, desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||
from common.models.user import Tenant
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import get_model_variables
|
||||
from .base import BaseRetriever
|
||||
|
||||
from .registry import RetrieverRegistry
|
||||
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
|
||||
|
||||
|
||||
class StandardRAGRetriever(BaseRetriever):
    """Retriever that ranks a catalog's embedded chunks by similarity to a query.

    Catalog, similarity threshold, result limit and tuning flag are read from
    the stored ``Retriever`` record.
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        """Load the retriever record and the tenant's model variables.

        Args:
            tenant_id: Tenant whose model variables are loaded.
            retriever_id: Primary key of the ``Retriever`` record to load.
        """
        super().__init__(tenant_id, retriever_id)

        record = Retriever.query.get_or_404(retriever_id)
        self.catalog_id = record.catalog_id
        # Fall back to 0.3 / 8 when the configuration omits these keys
        self.similarity_threshold = record.configuration.get('es_similarity_threshold', 0.3)
        self.k = record.configuration.get('es_k', 8)
        self.tuning = record.tuning
        self.model_variables = get_model_variables(self.tenant_id)

        self._log_tuning("Standard RAG retriever initialized")

    @property
    def type(self) -> str:
        """Type name of this retriever."""
        return "STANDARD_RAG"

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Return the chunks most similar to ``arguments.query``.

        Only chunks of each document's highest-id version are considered, and
        only for documents in this retriever's catalog whose validity window
        includes today's UTC date. Matches must score above the similarity
        threshold; at most ``k`` are returned, highest similarity first.

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: Retrieved chunks with similarity scores

        Raises:
            SQLAlchemyError: On database errors, after logging and rolling
                back the session.
            Exception: Any other error, after logging.
        """
        try:
            search_text = arguments.query
            query_embedding = self._get_query_embedding(search_text)

            # Embedding table/model to search
            db_class = self.model_variables.embedding_model_class

            # Reference date for the valid_from / valid_to checks
            current_date = dt.now(tz=tz.utc).date()

            def similarity_expression():
                # Similarity is 1 minus cosine distance; a fresh expression is
                # built for each place it is used in the statement
                return 1 - db_class.embedding.cosine_distance(query_embedding)

            # Highest version id per document
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            valid_from_ok = or_(
                Document.valid_from.is_(None),
                func.date(Document.valid_from) <= current_date,
            )
            valid_to_ok = or_(
                Document.valid_to.is_(None),
                func.date(Document.valid_to) >= current_date,
            )

            query_obj = (
                db.session.query(
                    db_class,
                    similarity_expression().label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    valid_from_ok,
                    valid_to_ok,
                    similarity_expression() > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
                .order_by(desc('similarity'))
                .limit(self.k)
            )

            results = query_obj.all()

            # Convert each (embedding row, score) pair to the standard result format
            processed_results = [
                RetrieverResult(
                    id=row.id,
                    chunk=row.chunk,
                    similarity=float(score),
                    metadata=RetrieverMetadata(
                        document_id=row.document_version.doc_id,
                        version_id=row.document_version.id,
                        document_name=row.document_version.document.name,
                        user_metadata=row.document_version.user_metadata or {},
                    )
                )
                for row, score in results
            ]

            if self.tuning:
                # literal_binds renders the actual parameter values into the SQL text
                compiled_query = str(
                    query_obj.statement.compile(compile_kwargs={"literal_binds": True})
                )
                self._log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "Raw Results": str(results),
                    "Processed Results": [item.model_dump() for item in processed_results],
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in RAG retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
            raise

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with the embedding model from the model variables."""
        return self.model_variables.embedding_model.embed_query(query)


# Make this implementation available under its type name
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
|
||||
Reference in New Issue
Block a user