# 154 lines, 6.1 KiB, Python
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc, cast, JSON
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict, List, Optional
from flask import current_app
from contextlib import contextmanager
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables


class EveAIDossierRetriever(BaseRetriever, BaseModel):
    """Similarity retriever over the embedded chunks of a single catalog.

    Only chunks that belong to the latest version (highest ``DocumentVersion.id``)
    of each document are considered, and only for documents whose
    ``valid_from``/``valid_to`` window contains the current date in the
    tenant's timezone. Results can additionally be narrowed with exact-match
    filters on ``DocumentVersion.user_metadata`` via :meth:`filtering`.
    """

    _catalog_id: int = PrivateAttr()
    _model_variables: ModelVariables = PrivateAttr()
    _tenant_info: Dict[str, Any] = PrivateAttr()
    # Metadata filters applied by _get_relevant_documents; None means "no filtering".
    _active_filters: Optional[Dict[str, Any]] = PrivateAttr()

    def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
        """Store the catalog, model configuration and tenant settings.

        Args:
            catalog_id: ID of the catalog whose documents are searched.
            model_variables: Mapping read for the keys ``embedding_db_model``,
                ``embedding_model``, ``similarity_threshold``, ``k`` and
                ``rag_tuning``.
            tenant_info: Mapping read for the key ``timezone``.
        """
        super().__init__()
        self._catalog_id = catalog_id
        self._model_variables = model_variables
        self._tenant_info = tenant_info
        self._active_filters = None

    @contextmanager
    def filtering(self, metadata_filters: Dict[str, Any]):
        """Context manager for temporarily setting metadata filters.

        The previously active filters are restored on exit, even if the body
        raises, so nested ``filtering`` blocks behave like a stack.

        Args:
            metadata_filters: ``{field: value}`` pairs. A scalar value must
                match exactly; a list/tuple value matches any of its items;
                ``None`` values are ignored.

        Yields:
            This retriever, with ``metadata_filters`` active.
        """
        previous_filters = self._active_filters
        self._active_filters = metadata_filters
        try:
            yield self
        finally:
            self._active_filters = previous_filters

    @staticmethod
    def _metadata_value_as_text(value: Any) -> str:
        """Render a filter value the way PostgreSQL ``->>`` renders a JSON scalar."""
        # ``->>`` yields JSON booleans as 'true'/'false', whereas str(True) is 'True'.
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return str(value)

    def _build_metadata_filter_conditions(self, query):
        """Build SQL conditions for metadata filtering.

        Args:
            query: SQLAlchemy query to narrow.

        Returns:
            ``query`` unchanged when no filters are active (or all filter
            values are ``None``/empty), otherwise ``query`` with one AND-ed
            condition per filter field.
        """
        if not self._active_filters:
            return query

        conditions = []
        for field, value in self._active_filters.items():
            if value is None:
                continue

            # Compare the extracted value as text (``->>``). Casting it to JSON
            # does not work: plain strings are not valid JSON input, and
            # PostgreSQL defines no equality operator for the ``json`` type.
            field_text = DocumentVersion.user_metadata[field].astext

            if isinstance(value, (list, tuple)):
                # Multiple values - match any of them; an empty list adds no condition.
                if value:
                    conditions.append(
                        field_text.in_([self._metadata_value_as_text(v) for v in value])
                    )
            else:
                # Single value - direct comparison
                conditions.append(field_text == self._metadata_value_as_text(value))

        if conditions:
            query = query.filter(and_(*conditions))

        return query

    def _get_relevant_documents(self, query: str):
        """Retrieve the catalog chunks most similar to ``query``.

        A chunk qualifies when all of the following hold:
        - it belongs to the latest version of its document;
        - the document is valid on the current tenant-local date;
        - its cosine similarity exceeds ``model_variables['similarity_threshold']``;
        - it matches the active metadata filters.

        At most ``model_variables['k']`` chunks are returned.

        Args:
            query: Natural-language query text to embed and search for.

        Returns:
            A list of strings, each ``'SOURCE: <chunk id>'`` followed by a blank
            line and the chunk text, ordered by descending similarity. If the
            database query fails, the error is logged, the session is rolled
            back and an empty list is returned.
        """
        current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
        if self._active_filters:
            current_app.logger.debug(f'Using metadata filters: {self._active_filters}')

        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        try:
            current_date = get_date_in_timezone(self.tenant_info['timezone'])

            # Subquery to find the latest version (highest id) of each document
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Same expression is selected (labelled) and used in the threshold filter.
            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Build base query
            query_obj = (
                db.session.query(db_class, similarity.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    similarity > similarity_threshold,
                    Document.catalog_id == self._catalog_id
                )
            )

            # Apply metadata filters
            query_obj = self._build_metadata_filter_conditions(query_obj)

            # Order and limit results
            query_obj = query_obj.order_by(desc('similarity')).limit(k)

            rag_tuning = self.model_variables['rag_tuning']

            # Debug logging for RAG tuning if enabled
            if rag_tuning:
                self._log_rag_tuning(query_obj, query_embedding)

            res = query_obj.all()

            result = []
            for doc in res:
                if rag_tuning:
                    current_app.logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
                    current_app.logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
                result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')

        except SQLAlchemyError as e:
            # exception() logs at ERROR level and keeps the traceback.
            current_app.logger.exception(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []

        return result

    def _log_rag_tuning(self, query_obj, query_embedding):
        """Log debug information for RAG tuning.

        Writes the SQL statement and any active metadata filters to
        ``current_app.rag_tuning_logger``. ``query_embedding`` is currently
        not used.
        """
        current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
        current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
        if self._active_filters:
            current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
            current_app.rag_tuning_logger.debug(f"{self._active_filters}")

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with ``model_variables['embedding_model'].embed_query``."""
        embedding_model = self.model_variables['embedding_model']
        query_embedding = embedding_model.embed_query(query)
        return query_embedding

    @property
    def model_variables(self) -> ModelVariables:
        """Model configuration supplied at construction."""
        return self._model_variables

    @property
    def tenant_info(self) -> Dict[str, Any]:
        """Tenant settings supplied at construction."""
        return self._tenant_info