""" DossierRetriever implementation that adds metadata filtering to retrieval """ import json from datetime import datetime as dt, date, timezone as tz from typing import Dict, Any, List, Optional, Union, Tuple from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime from sqlalchemy.sql import expression from sqlalchemy.exc import SQLAlchemyError from flask import current_app from common.extensions import db from common.models.document import Document, DocumentVersion, Catalog, Retriever from common.utils.model_utils import get_embedding_model_and_class from .base import BaseRetriever from .registry import RetrieverRegistry from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata class DossierRetriever(BaseRetriever): """ Dossier Retriever implementation that adds metadata filtering to standard retrieval functionality """ def __init__(self, tenant_id: int, retriever_id: int): super().__init__(tenant_id, retriever_id) # Set up standard retrieval parameters self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params() # Dossier-specific configuration self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {}) self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {}) self.log_tuning("Dossier retriever initialized", { "tagging_fields_filter": self.tagging_fields_filter, "dynamic_arguments": self.dynamic_arguments, "similarity_threshold": self.similarity_threshold, "k": self.k }) @property def type(self) -> str: return "DOSSIER_RETRIEVER" def _parse_metadata(self, metadata: Any) -> Dict[str, Any]: """ Parse metadata ensuring it's a dictionary Args: metadata: Input metadata which could be string, dict, or None Returns: Dict[str, Any]: Parsed metadata as dictionary """ if metadata is None: return {} if isinstance(metadata, dict): return metadata if isinstance(metadata, str): try: return json.loads(metadata) except json.JSONDecodeError: current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}") return {} current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}") return {} def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments): """ Apply metadata filters to the query based on tagging_fields_filter configuration Args: query_obj: SQLAlchemy query object arguments: Retriever arguments (for variable substitution) Returns: Modified SQLAlchemy query object """ if not self.tagging_fields_filter: return query_obj # Get dynamic argument values dynamic_values = {} for arg_name, arg_config in self.dynamic_arguments.items(): if hasattr(arguments, arg_name): dynamic_values[arg_name] = getattr(arguments, arg_name) # Build the filter filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values) if filter_condition is not None: query_obj = query_obj.filter(filter_condition) self.log_tuning("Applied metadata filter", { "filter_sql": str(filter_condition), "dynamic_values": dynamic_values }) return query_obj def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[ expression.BinaryExpression]: """ Recursively build SQLAlchemy filter condition from filter definition Args: filter_def: Filter definition (logical group or field condition) dynamic_values: Values for dynamic variable substitution Returns: SQLAlchemy expression or None if invalid """ # Handle logical groups (AND, 
        # Handle logical groups (AND, OR, NOT)
        if 'logical' in filter_def:
            logical_op = filter_def['logical'].lower()
            subconditions = [
                self._build_filter_condition(cond, dynamic_values)
                for cond in filter_def.get('conditions', [])
            ]
            # Drop subconditions that could not be built
            subconditions = [c for c in subconditions if c is not None]

            if not subconditions:
                return None

            if logical_op == 'and':
                return and_(*subconditions)
            elif logical_op == 'or':
                return or_(*subconditions)
            elif logical_op == 'not':
                if len(subconditions) == 1:
                    return ~subconditions[0]
                # NOT must wrap exactly one condition
                current_app.logger.warning(
                    f"NOT operator requires exactly one condition, got {len(subconditions)}")
                return None
            else:
                current_app.logger.warning(f"Unknown logical operator: {logical_op}")
                return None

        # Handle field conditions
        elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
            field_name = filter_def['field']
            operator = filter_def['operator'].lower()
            value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))

            # Skip the condition if the value could not be resolved
            if value is None and operator not in ['is_null', 'is_not_null']:
                return None

            # Create the field expression that reads the tagging field from the JSON column
            field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)

            # Apply the appropriate operator
            return self._apply_operator(field_expr, operator, value, filter_def)

        else:
            current_app.logger.warning(f"Invalid filter definition: {filter_def}")
            return None

    def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
        """
        Resolve a value definition, handling variables and defaults.

        Args:
            value_def: Value definition (a literal, a variable reference, or a list)
            dynamic_values: Values for dynamic variable substitution
            default: Default value to use when a variable is not found

        Returns:
            Resolved value
        """
        # Handle lists (resolve each item recursively)
        if isinstance(value_def, list):
            return [self._resolve_value(item, dynamic_values) for item in value_def]

        # Handle variable references (strings starting with $)
        if isinstance(value_def, str) and value_def.startswith('$'):
            var_name = value_def[1:]  # Strip the $ prefix
            if var_name in dynamic_values:
                return dynamic_values[var_name]
            # Fall back to the default, if one was provided
            return default

        # Return literal values as-is
        return value_def

    def _apply_operator(
        self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]
    ) -> Optional[expression.BinaryExpression]:
        """
        Apply the specified operator to create a filter condition.

        Args:
            field_expr: SQLAlchemy field expression
            operator: Operator to apply
            value: Value to compare against
            filter_def: Original filter definition (for additional options)

        Returns:
            SQLAlchemy expression, or None if the operator cannot be applied
        """
        try:
            # String operators
            if operator == 'eq':
                return field_expr == str(value)
            elif operator == 'neq':
                return field_expr != str(value)
            elif operator == 'contains':
                return field_expr.contains(str(value))
            elif operator == 'not_contains':
                return ~field_expr.contains(str(value))
            elif operator == 'starts_with':
                return field_expr.startswith(str(value))
            elif operator == 'ends_with':
                return field_expr.endswith(str(value))
            elif operator == 'in':
                return field_expr.in_([str(v) for v in value])
            elif operator == 'not_in':
                return ~field_expr.in_([str(v) for v in value])
            elif operator in ('regex', 'not_regex'):
                case_insensitive = filter_def.get('case_insensitive', False)
                if operator == 'not_regex':
                    regex_op = '!~*' if case_insensitive else '!~'
                else:
                    regex_op = '~*' if case_insensitive else '~'
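                # PostgreSQL POSIX regex operators: ~ (match) and !~ (no match),
                # with ~* / !~* as the case-insensitive variants.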
                # Use SQLAlchemy's generic op() support instead of hand-assembled
                # SQL text, so the pattern is passed as a bound parameter
                return field_expr.op(regex_op)(str(value))

            # Numeric/date operators
            elif operator == 'gt':
                return cast(field_expr, Float) > float(value)
            elif operator == 'gte':
                return cast(field_expr, Float) >= float(value)
            elif operator == 'lt':
                return cast(field_expr, Float) < float(value)
            elif operator == 'lte':
                return cast(field_expr, Float) <= float(value)
            elif operator == 'between':
                if len(value) == 2:
                    return cast(field_expr, Float).between(float(value[0]), float(value[1]))
                current_app.logger.warning(
                    f"BETWEEN operator requires exactly two values, got {len(value)}")
                return None
            elif operator == 'not_between':
                if len(value) == 2:
                    return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
                current_app.logger.warning(
                    f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
                return None

            # Null checks
            elif operator == 'is_null':
                return field_expr.is_(None)
            elif operator == 'is_not_null':
                return field_expr.isnot(None)

            else:
                current_app.logger.warning(f"Unknown operator: {operator}")
                return None

        except (ValueError, TypeError) as e:
            current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
            return None
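    # Illustrative retriever configuration for this class (hypothetical values).
    # Only the two keys read in __init__ matter here, and for 'dynamic_arguments'
    # only the argument names are used by _apply_metadata_filter:
    #
    #   {
    #       "tagging_fields_filter": {"field": "country", "operator": "eq",
    #                                 "value": "$country", "default": "US"},
    #       "dynamic_arguments": {"country": {}}
    #   }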
    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents for a query, with added metadata filtering.

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: Retrieved documents with similarity scores
        """
        try:
            query = arguments.query

            # Get the query embedding
            query_embedding = self.embedding_model.embed_query(query)

            # Get the appropriate embedding database model
            db_class = self.embedding_model_class

            # Current date for document validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Subquery selecting the latest version of each document
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Main similarity query, restricted to valid documents in this catalog
            query_obj = (
                db.session.query(
                    db_class,
                    (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
            )

            # Apply metadata filtering
            query_obj = self._apply_metadata_filter(query_obj, arguments)

            # Apply ordering and limit
            query_obj = query_obj.order_by(desc('similarity')).limit(self.k)

            # Execute the query
            results = query_obj.all()

            # Transform results into the standard format
            processed_results = []
            for doc, similarity in results:
                # Parse user_metadata so it is always a dictionary
                user_metadata = self._parse_metadata(doc.document_version.user_metadata)

                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=doc.document_version.doc_id,
                            version_id=doc.document_version.id,
                            document_name=doc.document_version.document.name,
                            user_metadata=user_metadata,
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    # literal_binds renders the actual values into the SQL
                    compile_kwargs={"literal_binds": True}
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "results_count": len(results),
                    "processed_results_count": len(processed_results),
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in Dossier retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
            raise


# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)
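# Minimal usage sketch, assuming the host application's wiring (the identifiers
# and the extra 'country' argument below are hypothetical, not part of this module):
#
#   retriever = DossierRetriever(tenant_id=1, retriever_id=42)
#   results = retriever.retrieve(RetrieverArguments(query="liability clauses", country="DE"))
#   for r in results:
#       print(r.similarity, r.metadata.document_name)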