from contextlib import contextmanager
from typing import Any, Dict, List, Optional

from flask import current_app
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field, PrivateAttr
from sqlalchemy import func, and_, or_, desc, cast, JSON
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables


def _metadata_value_as_text(value: Any) -> str:
    """Return *value* as the text it is compared against in a metadata filter.

    Booleans are rendered as ``true``/``false`` because that is how the
    JSON ``->>`` text extraction renders them; ``str(True)`` would give
    ``'True'`` and never match. Every other value goes through ``str()``.
    """
    if isinstance(value, bool):
        return 'true' if value else 'false'
    return str(value)


class EveAIDossierRetriever(BaseRetriever, BaseModel):
    """Retriever returning the best-matching embedding chunks of one catalog.

    Only chunks belonging to the latest version of a document are considered,
    and only documents whose ``valid_from``/``valid_to`` window contains the
    current date in the tenant's timezone. Results can be narrowed further
    with metadata filters, set temporarily through :meth:`filtering`.
    """

    _catalog_id: int = PrivateAttr()
    _model_variables: ModelVariables = PrivateAttr()
    _tenant_info: Dict[str, Any] = PrivateAttr()
    _active_filters: Optional[Dict[str, Any]] = PrivateAttr()

    def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
        """Store the catalog, model settings and tenant info; no filters are active."""
        super().__init__()
        self._catalog_id = catalog_id
        self._model_variables = model_variables
        self._tenant_info = tenant_info
        self._active_filters = None

    @contextmanager
    def filtering(self, metadata_filters: Dict[str, Any]):
        """Context manager for temporarily setting metadata filters.

        The filters that were active before entering are restored on exit,
        also when the body raises. Yields the retriever itself.
        """
        previous_filters = self._active_filters
        self._active_filters = metadata_filters
        try:
            yield self
        finally:
            self._active_filters = previous_filters

    def _build_metadata_filter_conditions(self, query):
        """Apply the active metadata filters to *query* and return it.

        Each filter entry maps a ``user_metadata`` key to the value(s) it must
        have. A ``None`` value or an empty list/tuple contributes no
        condition. A list/tuple matches when the field equals any of its
        members. All resulting conditions are combined with AND.
        """
        if not self._active_filters:
            return query

        conditions = []
        for field, value in self._active_filters.items():
            if value is None:
                continue

            # Compare on the text extraction itself. Casting that text back to
            # JSON breaks on plain strings (not valid JSON documents), and
            # PostgreSQL's json type has no equality operator anyway.
            field_text = DocumentVersion.user_metadata[field].astext

            if isinstance(value, (list, tuple)):
                # Multiple accepted values: match any of them.
                accepted = [_metadata_value_as_text(val) for val in value]
                if accepted:
                    conditions.append(field_text.in_(accepted))
            else:
                # Single value: direct comparison.
                conditions.append(field_text == _metadata_value_as_text(value))

        if conditions:
            query = query.filter(and_(*conditions))

        return query

    def _get_relevant_documents(self, query: str) -> List[str]:
        """Return the chunks most similar to *query* as formatted strings.

        Embeds the query, then selects at most ``k`` chunks whose cosine
        similarity exceeds ``similarity_threshold``, ordered by descending
        similarity. Each result has the form
        ``'SOURCE: <chunk id>\\n\\n<chunk text>\\n\\n'``.

        On a database error the error is logged, the session is rolled back
        and an empty list is returned.
        """
        current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
        if self._active_filters:
            current_app.logger.debug(f'Using metadata filters: {self._active_filters}')

        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        try:
            current_date = get_date_in_timezone(self.tenant_info['timezone'])

            # Subquery to find the latest version of each document
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Similarity is one minus the cosine distance; the same expression
            # is used both as a selected column and as a threshold filter.
            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Build base query
            query_obj = (
                db.session.query(db_class, similarity.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    similarity > similarity_threshold,
                    Document.catalog_id == self._catalog_id
                )
            )

            # Apply metadata filters
            query_obj = self._build_metadata_filter_conditions(query_obj)

            # Order and limit results
            query_obj = query_obj.order_by(desc('similarity')).limit(k)

            # Debug logging for RAG tuning if enabled
            rag_tuning = self.model_variables['rag_tuning']
            if rag_tuning:
                self._log_rag_tuning(query_obj, query_embedding)

            result = []
            for chunk_row, chunk_similarity in query_obj.all():
                if rag_tuning:
                    # NOTE(review): the logged value is the similarity score,
                    # although the message labels it "Distance".
                    current_app.logger.debug(f'Document ID: {chunk_row.id} - Distance: {chunk_similarity}\n')
                    current_app.logger.debug(f'Chunk: \n {chunk_row.chunk}\n\n')
                result.append(f'SOURCE: {chunk_row.id}\n\n{chunk_row.chunk}\n\n')

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []

        return result

    def _log_rag_tuning(self, query_obj, query_embedding):
        """Log debug information for RAG tuning.

        Writes the query's SQL statement and, when set, the active metadata
        filters to the RAG tuning logger. ``query_embedding`` is accepted but
        not used.
        """
        current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
        current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
        if self._active_filters:
            current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
            current_app.rag_tuning_logger.debug(f"{self._active_filters}")

    def _get_query_embedding(self, query: str):
        """Get embedding for the query text using the configured embedding model."""
        embedding_model = self.model_variables['embedding_model']
        return embedding_model.embed_query(query)

    @property
    def model_variables(self) -> ModelVariables:
        """Model settings passed to the constructor."""
        return self._model_variables

    @property
    def tenant_info(self) -> Dict[str, Any]:
        """Tenant information passed to the constructor."""
        return self._tenant_info