- Started adding Assets (e.g. to handle document templates).
- To be continued (Models and first views are ready).
@@ -1,5 +1,6 @@
 # Import all specialist implementations here to ensure registration
 from . import standard_rag
+from . import dossier_retriever
 
 # List of all available specialist implementations
-__all__ = ['standard_rag']
+__all__ = ['standard_rag', 'dossier_retriever']
@@ -1,8 +1,14 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional, Tuple
 
 from flask import current_app
-from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
+from sqlalchemy import func, or_, desc
+from sqlalchemy.exc import SQLAlchemyError
+
+from common.extensions import db
+from common.models.document import Document, DocumentVersion, Catalog, Retriever
+from common.utils.model_utils import get_embedding_model_and_class
+from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
 from config.logging_config import TuningLogger
 
 
@@ -12,6 +18,7 @@ class BaseRetriever(ABC):
     def __init__(self, tenant_id: int, retriever_id: int):
         self.tenant_id = tenant_id
         self.retriever_id = retriever_id
+        self.retriever = Retriever.query.get_or_404(retriever_id)
         self.tuning = False
         self.tuning_logger = None
         self._setup_tuning_logger()
@@ -43,6 +50,31 @@ class BaseRetriever(ABC):
         except Exception as e:
             current_app.logger.error(f"Processor: Error in tuning logging: {e}")
 
+    def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
+        """
+        Set up common parameters needed for standard retrieval functionality
+
+        Returns:
+            Tuple containing:
+            - embedding_model: The model to use for embeddings
+            - embedding_model_class: The class for storing embeddings
+            - catalog_id: ID of the catalog
+            - similarity_threshold: Threshold for similarity matching
+            - k: Maximum number of results to return
+        """
+        catalog_id = self.retriever.catalog_id
+        catalog = Catalog.query.get_or_404(catalog_id)
+        embedding_model = "mistral.mistral-embed"
+
+        embedding_model, embedding_model_class = get_embedding_model_and_class(
+            self.tenant_id, catalog_id, embedding_model
+        )
+
+        similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
+        k = self.retriever.configuration.get('es_k', 8)
+
+        return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k
+
     @abstractmethod
     def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
         """
eveai_chat_workers/retrievers/dossier_retriever.py (new file, 374 lines)
@@ -0,0 +1,374 @@
"""
|
||||
DossierRetriever implementation that adds metadata filtering to retrieval
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime as dt, date, timezone as tz
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||
from common.utils.model_utils import get_embedding_model_and_class
|
||||
from .base import BaseRetriever
|
||||
from .registry import RetrieverRegistry
|
||||
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
|
||||
|
||||
|
||||
class DossierRetriever(BaseRetriever):
    """
    Dossier Retriever implementation that adds metadata filtering
    to standard retrieval functionality
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        super().__init__(tenant_id, retriever_id)

        # Set up standard retrieval parameters
        self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()

        # Dossier-specific configuration
        self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {})
        self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {})

        self.log_tuning("Dossier retriever initialized", {
            "tagging_fields_filter": self.tagging_fields_filter,
            "dynamic_arguments": self.dynamic_arguments,
            "similarity_threshold": self.similarity_threshold,
            "k": self.k
        })

    @property
    def type(self) -> str:
        return "DOSSIER_RETRIEVER"

    def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
        """
        Parse metadata ensuring it's a dictionary

        Args:
            metadata: Input metadata which could be string, dict, or None

        Returns:
            Dict[str, Any]: Parsed metadata as dictionary
        """
        if metadata is None:
            return {}

        if isinstance(metadata, dict):
            return metadata

        if isinstance(metadata, str):
            try:
                return json.loads(metadata)
            except json.JSONDecodeError:
                current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
                return {}

        current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
        return {}

    def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
        """
        Apply metadata filters to the query based on tagging_fields_filter configuration

        Args:
            query_obj: SQLAlchemy query object
            arguments: Retriever arguments (for variable substitution)

        Returns:
            Modified SQLAlchemy query object
        """
        if not self.tagging_fields_filter:
            return query_obj

        # Get dynamic argument values
        dynamic_values = {}
        for arg_name, arg_config in self.dynamic_arguments.items():
            if hasattr(arguments, arg_name):
                dynamic_values[arg_name] = getattr(arguments, arg_name)

        # Build the filter
        filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
        if filter_condition is not None:
            query_obj = query_obj.filter(filter_condition)
            self.log_tuning("Applied metadata filter", {
                "filter_sql": str(filter_condition),
                "dynamic_values": dynamic_values
            })

        return query_obj

    def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[
            expression.BinaryExpression]:
        """
        Recursively build SQLAlchemy filter condition from filter definition

        Args:
            filter_def: Filter definition (logical group or field condition)
            dynamic_values: Values for dynamic variable substitution

        Returns:
            SQLAlchemy expression or None if invalid
        """
        # Handle logical groups (AND, OR, NOT)
        if 'logical' in filter_def:
            logical_op = filter_def['logical'].lower()
            subconditions = [
                self._build_filter_condition(cond, dynamic_values)
                for cond in filter_def.get('conditions', [])
            ]
            # Remove None values
            subconditions = [c for c in subconditions if c is not None]

            if not subconditions:
                return None

            if logical_op == 'and':
                return and_(*subconditions)
            elif logical_op == 'or':
                return or_(*subconditions)
            elif logical_op == 'not':
                if len(subconditions) == 1:
                    return ~subconditions[0]
                else:
                    # NOT should have exactly one condition
                    current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
                    return None
            else:
                current_app.logger.warning(f"Unknown logical operator: {logical_op}")
                return None

        # Handle field conditions
        elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
            field_name = filter_def['field']
            operator = filter_def['operator'].lower()
            value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))

            # Skip if we couldn't resolve the value
            if value is None and operator not in ['is_null', 'is_not_null']:
                return None

            # Create the field expression to match JSON data
            field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)

            # Apply the appropriate operator
            return self._apply_operator(field_expr, operator, value, filter_def)

        else:
            current_app.logger.warning(f"Invalid filter definition: {filter_def}")
            return None

    def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
        """
        Resolve a value definition, handling variables and defaults

        Args:
            value_def: Value definition (could be literal, variable reference, or list)
            dynamic_values: Values for dynamic variable substitution
            default: Default value if variable not found

        Returns:
            Resolved value
        """
        # Handle lists (recursively resolve each item)
        if isinstance(value_def, list):
            return [self._resolve_value(item, dynamic_values) for item in value_def]

        # Handle variable references (strings starting with $)
        if isinstance(value_def, str) and value_def.startswith('$'):
            var_name = value_def[1:]  # Remove $ prefix
            if var_name in dynamic_values:
                return dynamic_values[var_name]
            else:
                # Use default if provided
                return default

        # Return literal values as-is
        return value_def

    def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[
            expression.BinaryExpression]:
        """
        Apply the specified operator to create a filter condition

        Args:
            field_expr: SQLAlchemy field expression
            operator: Operator to apply
            value: Value to compare against
            filter_def: Original filter definition (for additional options)

        Returns:
            SQLAlchemy expression
        """
        try:
            # String operators
            if operator == 'eq':
                return field_expr == str(value)
            elif operator == 'neq':
                return field_expr != str(value)
            elif operator == 'contains':
                return field_expr.contains(str(value))
            elif operator == 'not_contains':
                return ~field_expr.contains(str(value))
            elif operator == 'starts_with':
                return field_expr.startswith(str(value))
            elif operator == 'ends_with':
                return field_expr.endswith(str(value))
            elif operator == 'in':
                return field_expr.in_([str(v) for v in value])
            elif operator == 'not_in':
                return ~field_expr.in_([str(v) for v in value])
            elif operator == 'regex' or operator == 'not_regex':
                # PostgreSQL regex using ~ or !~ operator
                case_insensitive = filter_def.get('case_insensitive', False)
                regex_op = '~*' if case_insensitive else '~'
                if operator == 'not_regex':
                    regex_op = '!~*' if case_insensitive else '!~'
                return text(
                    f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams(
                    regex_value=str(value))

            # Numeric/Date operators
            elif operator == 'gt':
                return cast(field_expr, Float) > float(value)
            elif operator == 'gte':
                return cast(field_expr, Float) >= float(value)
            elif operator == 'lt':
                return cast(field_expr, Float) < float(value)
            elif operator == 'lte':
                return cast(field_expr, Float) <= float(value)
            elif operator == 'between':
                if len(value) == 2:
                    return cast(field_expr, Float).between(float(value[0]), float(value[1]))
                else:
                    current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
                    return None
            elif operator == 'not_between':
                if len(value) == 2:
                    return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
                else:
                    current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
                    return None

            # Null checking
            elif operator == 'is_null':
                return field_expr.is_(None)
            elif operator == 'is_not_null':
                return field_expr.isnot(None)

            else:
                current_app.logger.warning(f"Unknown operator: {operator}")
                return None

        except (ValueError, TypeError) as e:
            current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
            return None

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query with added metadata filtering

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: List of retrieved documents with similarity scores
        """
        try:
            query = arguments.query

            # Get query embedding
            query_embedding = self.embedding_model.embed_query(query)

            # Get the appropriate embedding database model
            db_class = self.embedding_model_class

            # Get current date for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Create subquery for latest versions
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Main query
            query_obj = (
                db.session.query(
                    db_class,
                    (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
            )

            # Apply metadata filtering
            query_obj = self._apply_metadata_filter(query_obj, arguments)

            # Apply ordering and limit
            query_obj = query_obj.order_by(desc('similarity')).limit(self.k)

            # Execute query
            results = query_obj.all()

            # Transform results into standard format
            processed_results = []
            for doc, similarity in results:
                # Parse user_metadata to ensure it's a dictionary
                user_metadata = self._parse_metadata(doc.document_version.user_metadata)
                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=doc.document_version.doc_id,
                            version_id=doc.document_version.id,
                            document_name=doc.document_version.document.name,
                            user_metadata=user_metadata,
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "results_count": len(results),
                    "processed_results_count": len(processed_results),
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in Dossier retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
            raise

# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)
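
For reference, the filter definition format consumed by _build_filter_condition and _resolve_value can be sketched as a retriever configuration like the one below. The keys (logical, conditions, field, operator, value, default, case_insensitive) and the $-prefixed references into the retriever arguments follow from the code above; the concrete field names and the per-argument schema under dynamic_arguments are illustrative assumptions, not values taken from this commit.

# Illustrative retriever configuration (hypothetical field names and argument schema):
example_configuration = {
    "dynamic_arguments": {
        # Only the keys are used by _apply_metadata_filter; the per-argument
        # schema shown here is an assumption for illustration.
        "dossier_id": {"type": "string"},
    },
    "tagging_fields_filter": {
        "logical": "and",
        "conditions": [
            # "$dossier_id" is resolved from the matching RetrieverArguments
            # attribute; "default" is used when the argument is absent.
            {"field": "dossier_id", "operator": "eq", "value": "$dossier_id", "default": None},
            {"field": "status", "operator": "in", "value": ["open", "review"]},
            {"field": "title", "operator": "regex", "value": "contract", "case_insensitive": True},
        ],
    },
}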
@@ -23,20 +23,12 @@ class StandardRAGRetriever(BaseRetriever):
     def __init__(self, tenant_id: int, retriever_id: int):
         super().__init__(tenant_id, retriever_id)
 
-        retriever = Retriever.query.get_or_404(retriever_id)
-        self.catalog_id = retriever.catalog_id
-        self.tenant_id = tenant_id
-        catalog = Catalog.query.get_or_404(self.catalog_id)
-        embedding_model = "mistral.mistral-embed"
-
-        self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id,
-                                                                                         self.catalog_id,
-                                                                                         embedding_model)
-        self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
-        self.k = retriever.configuration.get('es_k', 8)
-        self.tuning = retriever.tuning
-
-        self.log_tuning("Standard RAG retriever initialized")
+        # Set up standard retrieval parameters
+        self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
+        self.log_tuning("Standard RAG retriever initialized", {
+            "similarity_threshold": self.similarity_threshold,
+            "k": self.k
+        })
 
     @property
     def type(self) -> str:
@@ -167,4 +159,4 @@ class StandardRAGRetriever(BaseRetriever):
 
 
 # Register the retriever type
-RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
+RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
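
A minimal usage sketch of the new DossierRetriever, assuming a Flask application context, an existing Retriever row with a dossier configuration, and that RetrieverArguments accepts the extra dossier field shown; the IDs and the dossier_id field are hypothetical.

# Minimal usage sketch (hypothetical IDs and argument fields):
from eveai_chat_workers.retrievers.dossier_retriever import DossierRetriever
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments

retriever = DossierRetriever(tenant_id=1, retriever_id=42)  # loads its configuration via BaseRetriever
arguments = RetrieverArguments(query="termination clauses", dossier_id="D-2024-001")  # assumed argument shape
for result in retriever.retrieve(arguments):
    print(result.similarity, result.metadata.document_name)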