- Started the addition of Assets (e.g. to handle document templates).

- To be continued (Models and the first views are ready)
Josako
2025-03-17 17:40:42 +01:00
parent a6402524ce
commit cf2201a1f7
13 changed files with 778 additions and 39 deletions

View File

@@ -1,5 +1,6 @@
 # Import all specialist implementations here to ensure registration
 from . import standard_rag
+from . import dossier_retriever

 # List of all available specialist implementations
-__all__ = ['standard_rag']
+__all__ = ['standard_rag', 'dossier_retriever']

View File

@@ -1,8 +1,14 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional, Tuple
 from flask import current_app
-from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
+from sqlalchemy import func, or_, desc
+from sqlalchemy.exc import SQLAlchemyError
+from common.extensions import db
+from common.models.document import Document, DocumentVersion, Catalog, Retriever
+from common.utils.model_utils import get_embedding_model_and_class
+from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
 from config.logging_config import TuningLogger
@@ -12,6 +18,7 @@ class BaseRetriever(ABC):
     def __init__(self, tenant_id: int, retriever_id: int):
         self.tenant_id = tenant_id
         self.retriever_id = retriever_id
+        self.retriever = Retriever.query.get_or_404(retriever_id)
         self.tuning = False
         self.tuning_logger = None
         self._setup_tuning_logger()
@@ -43,6 +50,31 @@ class BaseRetriever(ABC):
         except Exception as e:
             current_app.logger.error(f"Processor: Error in tuning logging: {e}")

+    def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
+        """
+        Set up common parameters needed for standard retrieval functionality
+
+        Returns:
+            Tuple containing:
+            - embedding_model: The model to use for embeddings
+            - embedding_model_class: The class for storing embeddings
+            - catalog_id: ID of the catalog
+            - similarity_threshold: Threshold for similarity matching
+            - k: Maximum number of results to return
+        """
+        catalog_id = self.retriever.catalog_id
+        catalog = Catalog.query.get_or_404(catalog_id)
+        embedding_model = "mistral.mistral-embed"
+        embedding_model, embedding_model_class = get_embedding_model_and_class(
+            self.tenant_id, catalog_id, embedding_model
+        )
+        similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
+        k = self.retriever.configuration.get('es_k', 8)
+
+        return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k
+
     @abstractmethod
     def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
         """

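For reference, the two tuning knobs the new helper reads live in the retriever's configuration dict. A minimal illustrative fragment follows; the override values are invented, and only the key names and the 0.3 / 8 defaults come from the code above.

# Illustrative only: the keys setup_standard_retrieval_params() looks up,
# with hypothetical overrides of the 0.3 / 8 defaults shown above.
retriever_configuration = {
    "es_similarity_threshold": 0.35,
    "es_k": 5,
}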
View File

@@ -0,0 +1,374 @@
"""
DossierRetriever implementation that adds metadata filtering to retrieval
"""
import json
from datetime import datetime as dt, date, timezone as tz
from typing import Dict, Any, List, Optional, Union, Tuple

from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
from sqlalchemy.sql import expression
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app

from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class

from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata


class DossierRetriever(BaseRetriever):
    """
    Dossier Retriever implementation that adds metadata filtering
    to standard retrieval functionality
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        super().__init__(tenant_id, retriever_id)

        # Set up standard retrieval parameters
        self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()

        # Dossier-specific configuration
        self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {})
        self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {})

        self.log_tuning("Dossier retriever initialized", {
            "tagging_fields_filter": self.tagging_fields_filter,
            "dynamic_arguments": self.dynamic_arguments,
            "similarity_threshold": self.similarity_threshold,
            "k": self.k
        })

    @property
    def type(self) -> str:
        return "DOSSIER_RETRIEVER"

    def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
        """
        Parse metadata ensuring it's a dictionary

        Args:
            metadata: Input metadata which could be string, dict, or None

        Returns:
            Dict[str, Any]: Parsed metadata as dictionary
        """
        if metadata is None:
            return {}
        if isinstance(metadata, dict):
            return metadata
        if isinstance(metadata, str):
            try:
                return json.loads(metadata)
            except json.JSONDecodeError:
                current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
                return {}
        current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
        return {}

    def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
        """
        Apply metadata filters to the query based on tagging_fields_filter configuration

        Args:
            query_obj: SQLAlchemy query object
            arguments: Retriever arguments (for variable substitution)

        Returns:
            Modified SQLAlchemy query object
        """
        if not self.tagging_fields_filter:
            return query_obj

        # Get dynamic argument values
        dynamic_values = {}
        for arg_name, arg_config in self.dynamic_arguments.items():
            if hasattr(arguments, arg_name):
                dynamic_values[arg_name] = getattr(arguments, arg_name)

        # Build the filter
        filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
        if filter_condition is not None:
            query_obj = query_obj.filter(filter_condition)
            self.log_tuning("Applied metadata filter", {
                "filter_sql": str(filter_condition),
                "dynamic_values": dynamic_values
            })

        return query_obj

    def _build_filter_condition(self, filter_def: Dict[str, Any],
                                dynamic_values: Dict[str, Any]) -> Optional[expression.BinaryExpression]:
        """
        Recursively build SQLAlchemy filter condition from filter definition

        Args:
            filter_def: Filter definition (logical group or field condition)
            dynamic_values: Values for dynamic variable substitution

        Returns:
            SQLAlchemy expression or None if invalid
        """
        # Handle logical groups (AND, OR, NOT)
        if 'logical' in filter_def:
            logical_op = filter_def['logical'].lower()
            subconditions = [
                self._build_filter_condition(cond, dynamic_values)
                for cond in filter_def.get('conditions', [])
            ]
            # Remove None values
            subconditions = [c for c in subconditions if c is not None]
            if not subconditions:
                return None

            if logical_op == 'and':
                return and_(*subconditions)
            elif logical_op == 'or':
                return or_(*subconditions)
            elif logical_op == 'not':
                if len(subconditions) == 1:
                    return ~subconditions[0]
                else:
                    # NOT should have exactly one condition
                    current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
                    return None
            else:
                current_app.logger.warning(f"Unknown logical operator: {logical_op}")
                return None

        # Handle field conditions
        elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
            field_name = filter_def['field']
            operator = filter_def['operator'].lower()
            value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))

            # Skip if we couldn't resolve the value
            if value is None and operator not in ['is_null', 'is_not_null']:
                return None

            # Create the field expression to match JSON data
            field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)

            # Apply the appropriate operator
            return self._apply_operator(field_expr, operator, value, filter_def)

        else:
            current_app.logger.warning(f"Invalid filter definition: {filter_def}")
            return None

    def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
        """
        Resolve a value definition, handling variables and defaults

        Args:
            value_def: Value definition (could be literal, variable reference, or list)
            dynamic_values: Values for dynamic variable substitution
            default: Default value if variable not found

        Returns:
            Resolved value
        """
        # Handle lists (recursively resolve each item)
        if isinstance(value_def, list):
            return [self._resolve_value(item, dynamic_values) for item in value_def]

        # Handle variable references (strings starting with $)
        if isinstance(value_def, str) and value_def.startswith('$'):
            var_name = value_def[1:]  # Remove $ prefix
            if var_name in dynamic_values:
                return dynamic_values[var_name]
            else:
                # Use default if provided
                return default

        # Return literal values as-is
        return value_def

    def _apply_operator(self, field_expr, operator: str, value: Any,
                        filter_def: Dict[str, Any]) -> Optional[expression.BinaryExpression]:
        """
        Apply the specified operator to create a filter condition

        Args:
            field_expr: SQLAlchemy field expression
            operator: Operator to apply
            value: Value to compare against
            filter_def: Original filter definition (for additional options)

        Returns:
            SQLAlchemy expression
        """
        try:
            # String operators
            if operator == 'eq':
                return field_expr == str(value)
            elif operator == 'neq':
                return field_expr != str(value)
            elif operator == 'contains':
                return field_expr.contains(str(value))
            elif operator == 'not_contains':
                return ~field_expr.contains(str(value))
            elif operator == 'starts_with':
                return field_expr.startswith(str(value))
            elif operator == 'ends_with':
                return field_expr.endswith(str(value))
            elif operator == 'in':
                return field_expr.in_([str(v) for v in value])
            elif operator == 'not_in':
                return ~field_expr.in_([str(v) for v in value])
            elif operator == 'regex' or operator == 'not_regex':
                # PostgreSQL regex using ~ or !~ operator
                case_insensitive = filter_def.get('case_insensitive', False)
                regex_op = '~*' if case_insensitive else '~'
                if operator == 'not_regex':
                    regex_op = '!~*' if case_insensitive else '!~'
                return text(
                    f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value"
                ).bindparams(regex_value=str(value))

            # Numeric/Date operators
            elif operator == 'gt':
                return cast(field_expr, Float) > float(value)
            elif operator == 'gte':
                return cast(field_expr, Float) >= float(value)
            elif operator == 'lt':
                return cast(field_expr, Float) < float(value)
            elif operator == 'lte':
                return cast(field_expr, Float) <= float(value)
            elif operator == 'between':
                if len(value) == 2:
                    return cast(field_expr, Float).between(float(value[0]), float(value[1]))
                else:
                    current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
                    return None
            elif operator == 'not_between':
                if len(value) == 2:
                    return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
                else:
                    current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
                    return None

            # Null checking
            elif operator == 'is_null':
                return field_expr.is_(None)
            elif operator == 'is_not_null':
                return field_expr.isnot(None)
            else:
                current_app.logger.warning(f"Unknown operator: {operator}")
                return None
        except (ValueError, TypeError) as e:
            current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
            return None

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query with added metadata filtering

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: List of retrieved documents with similarity scores
        """
        try:
            query = arguments.query

            # Get query embedding
            query_embedding = self.embedding_model.embed_query(query)

            # Get the appropriate embedding database model
            db_class = self.embedding_model_class

            # Get current date for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Create subquery for latest versions
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Main query
            query_obj = (
                db.session.query(
                    db_class,
                    (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
            )

            # Apply metadata filtering
            query_obj = self._apply_metadata_filter(query_obj, arguments)

            # Apply ordering and limit
            query_obj = query_obj.order_by(desc('similarity')).limit(self.k)

            # Execute query
            results = query_obj.all()

            # Transform results into standard format
            processed_results = []
            for doc, similarity in results:
                # Parse user_metadata to ensure it's a dictionary
                user_metadata = self._parse_metadata(doc.document_version.user_metadata)

                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=doc.document_version.doc_id,
                            version_id=doc.document_version.id,
                            document_name=doc.document_version.document.name,
                            user_metadata=user_metadata,
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "results_count": len(results),
                    "processed_results_count": len(processed_results),
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in Dossier retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
            raise


# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)

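To make the filter grammar implemented above concrete, here is an illustrative configuration fragment a DossierRetriever might carry. The field names (dossier_id, document_type, title) and the shape of the dynamic_arguments values are assumptions made for the example; only the logical/conditions and field/operator/value/default keys, the case_insensitive flag, and the $-prefixed variable references come from the code above.

# Illustrative configuration fragment for a DossierRetriever (field names are hypothetical).
example_retriever_configuration = {
    "tagging_fields_filter": {
        "logical": "and",
        "conditions": [
            # $dossier_id is resolved from RetrieverArguments.dossier_id at query time;
            # "default" is used when the argument is missing.
            {"field": "dossier_id", "operator": "eq", "value": "$dossier_id", "default": None},
            {
                "logical": "or",
                "conditions": [
                    {"field": "document_type", "operator": "in", "value": ["contract", "invoice"]},
                    {"field": "title", "operator": "regex", "value": "^annex", "case_insensitive": True},
                ],
            },
        ],
    },
    # Keys name the RetrieverArguments attributes read for $-substitution; the value
    # shape is not used by the code shown in this commit, so an empty dict is assumed.
    "dynamic_arguments": {"dossier_id": {}},
}

Nesting logical groups inside conditions is what lets _build_filter_condition recurse to arbitrary depth before combining the leaves with and_/or_.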
View File

@@ -23,20 +23,12 @@ class StandardRAGRetriever(BaseRetriever):
     def __init__(self, tenant_id: int, retriever_id: int):
         super().__init__(tenant_id, retriever_id)
-        retriever = Retriever.query.get_or_404(retriever_id)
-        self.catalog_id = retriever.catalog_id
-        self.tenant_id = tenant_id
-        catalog = Catalog.query.get_or_404(self.catalog_id)
-        embedding_model = "mistral.mistral-embed"
-        self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id,
-                                                                                         self.catalog_id,
-                                                                                         embedding_model)
-        self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
-        self.k = retriever.configuration.get('es_k', 8)
-        self.tuning = retriever.tuning
-        self.log_tuning("Standard RAG retriever initialized")
+        # Set up standard retrieval parameters
+        self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
+        self.log_tuning("Standard RAG retriever initialized", {
+            "similarity_threshold": self.similarity_threshold,
+            "k": self.k
+        })

     @property
     def type(self) -> str:
@property
def type(self) -> str:
@@ -167,4 +159,4 @@ class StandardRAGRetriever(BaseRetriever):
 # Register the retriever type
-RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
+RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
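The registration call at the end of both retriever modules is what the __init__.py change at the top relies on: importing the module runs RetrieverRegistry.register(...) as a side effect. A minimal sketch of adding another specialist, assuming only the pieces visible in this commit (the constructor signature, the type property, retrieve, and RetrieverRegistry.register); the name MyCustomRetriever and its type string are hypothetical.

from typing import List

from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult


class MyCustomRetriever(BaseRetriever):
    """Hypothetical specialist; reuses the shared setup added to BaseRetriever in this commit."""

    def __init__(self, tenant_id: int, retriever_id: int):
        super().__init__(tenant_id, retriever_id)
        # Same shared setup the StandardRAGRetriever now uses.
        self.embedding_model, self.embedding_model_class, self.catalog_id, \
            self.similarity_threshold, self.k = self.setup_standard_retrieval_params()

    @property
    def type(self) -> str:
        return "MY_CUSTOM_RETRIEVER"

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        return []  # real retrieval logic would go here


# Register the retriever type, then add `from . import my_custom_retriever` (and the
# __all__ entry) in the package __init__, as done for dossier_retriever above.
RetrieverRegistry.register("MY_CUSTOM_RETRIEVER", MyCustomRetriever)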