90 lines
3.3 KiB
Python
90 lines
3.3 KiB
Python
from abc import ABC, abstractmethod, abstractproperty
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
|
|
from flask import current_app
|
|
from sqlalchemy import func, or_, desc
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from common.extensions import db
|
|
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
|
from common.utils.model_utils import get_embedding_model_and_class
|
|
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
|
|
from config.logging_config import TuningLogger
|
|
|
|
|
|
class BaseRetriever(ABC):
    """Base class for all retrievers.

    Loads the ``Retriever`` row identified by ``retriever_id`` and creates a
    tuning logger. Subclasses must implement :attr:`type` and
    :meth:`retrieve`; :meth:`setup_standard_retrieval_params` gathers the
    parameters shared by the standard embedding-based retrieval flow.
    """

    # Embedding model name passed to get_embedding_model_and_class() by
    # setup_standard_retrieval_params(). Subclasses may override.
    DEFAULT_EMBEDDING_MODEL = "mistral.mistral-embed"
    # Fallbacks used when the retriever configuration lacks
    # 'es_similarity_threshold' / 'es_k'.
    DEFAULT_SIMILARITY_THRESHOLD = 0.3
    DEFAULT_K = 8

    def __init__(self, tenant_id: int, retriever_id: int):
        """Load the retriever definition and set up tuning support.

        Args:
            tenant_id: ID of the tenant this retriever runs for.
            retriever_id: Primary key of the ``Retriever`` row to load.

        Raises:
            Aborts with HTTP 404 (Flask-SQLAlchemy ``get_or_404``) when no
            retriever with ``retriever_id`` exists; re-raises any error from
            creating the tuning logger.
        """
        self.tenant_id = tenant_id
        self.retriever_id = retriever_id
        # NOTE(review): get_or_404 aborts with an HTTP 404 exception; if this
        # runs outside a request (e.g. in a worker) confirm that is intended.
        self.retriever = Retriever.query.get_or_404(retriever_id)
        # Tuning output is off by default; log_tuning() is a no-op until
        # self.tuning is set truthy (presumably by a subclass or caller).
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type identifier of the retriever (provided by subclasses)."""

    def _setup_tuning_logger(self):
        """Create the TuningLogger used by :meth:`log_tuning`.

        Raises:
            Exception: any error while constructing the logger is logged on
                the Flask application logger and re-raised.
        """
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                retriever_id=self.retriever_id,
            )
            # Verify the logger works with a test message. When called from
            # __init__, self.tuning is still False, so this only fires if the
            # method is re-run after tuning has been enabled.
            if self.tuning:
                self.tuning_logger.log_tuning('retriever', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def log_tuning(self, message: str, data: Optional[Dict[str, Any]] = None) -> None:
        """Write a tuning message when tuning is enabled.

        Does nothing unless ``self.tuning`` is truthy and a tuning logger
        exists. Errors raised by the tuning logger are logged on the Flask
        application logger and not propagated.

        Args:
            message: Human-readable tuning message.
            data: Optional structured payload forwarded to the tuning logger.
        """
        if not (self.tuning and self.tuning_logger):
            return
        try:
            self.tuning_logger.log_tuning('retriever', message, data)
        except Exception as e:
            current_app.logger.error(f"Retriever: Error in tuning logging: {e}")

    def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
        """Set up common parameters needed for standard retrieval functionality.

        Returns:
            Tuple containing:
                - embedding_model: the model to use for embeddings
                - embedding_model_class: the class for storing embeddings
                - catalog_id: ID of the catalog this retriever is bound to
                - similarity_threshold: ``es_similarity_threshold`` from the
                  retriever configuration, else DEFAULT_SIMILARITY_THRESHOLD
                - k: ``es_k`` from the retriever configuration, else
                  DEFAULT_K; the maximum number of results to return

        Raises:
            Aborts with HTTP 404 (Flask-SQLAlchemy ``get_or_404``) when the
            retriever's catalog does not exist.
        """
        catalog_id = self.retriever.catalog_id
        # Called only to validate that the catalog exists; the row is unused.
        Catalog.query.get_or_404(catalog_id)

        embedding_model, embedding_model_class = get_embedding_model_and_class(
            self.tenant_id, catalog_id, self.DEFAULT_EMBEDDING_MODEL
        )

        # NOTE(review): configuration is presumably a nullable JSON column;
        # treat a missing configuration as "no overrides" instead of crashing.
        configuration = self.retriever.configuration or {}
        similarity_threshold = configuration.get(
            'es_similarity_threshold', self.DEFAULT_SIMILARITY_THRESHOLD
        )
        k = configuration.get('es_k', self.DEFAULT_K)

        return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k

    @abstractmethod
    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """Retrieve relevant content based on the provided arguments.

        Args:
            arguments: Arguments for the retrieval operation.

        Returns:
            The retrieved results, as a list of RetrieverResult objects.
        """
|