# retrievers/standard_rag.py
import json
from datetime import datetime as dt, timezone as tz
from typing import Dict, Any, List

from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app

from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.models.user import Tenant
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import get_embedding_model_and_class

from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata


class StandardRAGRetriever(BaseRetriever):
    """Standard RAG retriever implementation.

    Embeds the incoming query and returns the chunks of the retriever's
    catalog whose cosine similarity to the query exceeds a configured
    threshold, restricted to the latest version of each document and to
    documents that are valid on the current UTC date.
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        """
        Load the retriever record and cache its search settings.

        Args:
            tenant_id: Identifier of the tenant the retriever runs for
            retriever_id: Primary key of the Retriever row to load
                (``get_or_404`` aborts if it does not exist)
        """
        super().__init__(tenant_id, retriever_id)

        retriever = Retriever.query.get_or_404(retriever_id)
        # Tolerate a retriever stored without any configuration: fall back
        # to the defaults instead of failing on ``None.get``.
        configuration = retriever.configuration or {}

        self.catalog_id = retriever.catalog_id
        self.tenant_id = tenant_id
        self.similarity_threshold = configuration.get('es_similarity_threshold', 0.3)
        self.k = configuration.get('es_k', 8)
        self.tuning = retriever.tuning

        self.log_tuning("Standard RAG retriever initialized")

    @property
    def type(self) -> str:
        """Identifier under which this retriever is registered."""
        return "STANDARD_RAG"

    def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
        """
        Parse metadata ensuring it's a dictionary

        Args:
            metadata: Input metadata which could be string, dict, or None

        Returns:
            Dict[str, Any]: Parsed metadata as dictionary. An empty dict is
            returned for None, for unparsable JSON, for JSON that does not
            decode to an object, and for any other input type.
        """
        if metadata is None:
            return {}
        if isinstance(metadata, dict):
            return metadata
        if isinstance(metadata, str):
            try:
                parsed = json.loads(metadata)
            except json.JSONDecodeError:
                current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
                return {}
            # Valid JSON is not necessarily an object ("[1, 2]", "3", ...);
            # honour the documented dict-only contract.
            if isinstance(parsed, dict):
                return parsed
            current_app.logger.warning(f"Metadata JSON is not an object: {type(parsed)}")
            return {}

        current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
        return {}

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: List of retrieved documents with similarity scores,
            ordered by descending similarity and capped at ``self.k`` entries

        Raises:
            SQLAlchemyError: On database failure (the session is rolled back first)
            Exception: Any other error is logged and re-raised unchanged
        """
        try:
            query = arguments.query

            # Get query embedding
            query_embedding = self._get_query_embedding(query)

            # Get the appropriate embedding database model
            # NOTE(review): ``model_variables`` is not set in this class;
            # presumably provided by BaseRetriever - confirm.
            db_class = self.model_variables.embedding_model_class

            # Get current date for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Create subquery for latest versions (highest version id per document)
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Build the similarity expression once; it is used both as the
            # selected column and in the threshold filter.
            similarity_expr = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Main query
            query_obj = (
                db.session.query(db_class, similarity_expr.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    similarity_expr > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
                .order_by(desc('similarity'))
                .limit(self.k)
            )

            results = query_obj.all()

            # Transform results into standard format
            processed_results = []
            for doc, similarity in results:
                doc_version = doc.document_version
                # Parse user_metadata to ensure it's a dictionary
                user_metadata = self._parse_metadata(doc_version.user_metadata)
                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=doc_version.doc_id,
                            version_id=doc_version.id,
                            document_name=doc_version.document.name,
                            user_metadata=user_metadata,
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "Raw Results": str(results),
                    "Processed Results": [r.model_dump() for r in processed_results],
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in RAG retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
            raise

    def _get_query_embedding(self, query: str):
        """
        Get embedding for the query text

        Args:
            query: The search query to embed

        Returns:
            The embedding of ``query`` produced by the embedding model
            configured on this retriever's catalog
        """
        catalog = Catalog.query.get_or_404(self.catalog_id)
        embedding_model, _embedding_model_class = get_embedding_model_and_class(
            self.tenant_id, self.catalog_id, catalog.embedding_model
        )
        # Previously nothing was returned here, so callers received None.
        # Assumes the model exposes a LangChain-style ``embed_query(text)``
        # method - TODO confirm against get_embedding_model_and_class.
        return embedding_model.embed_query(query)


# Register the retriever type
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)