# File: eveAI/common/langchain/retrievers/eveai_dossier_retriever.py
# Snapshot: 2024-11-01 11:19:34 +01:00 (154 lines, 6.1 KiB, Python)

from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc, cast, JSON
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict, List, Optional
from flask import current_app
from contextlib import contextmanager
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDossierRetriever(BaseRetriever, BaseModel):
    """Retriever returning the chunks of one catalog that best match a query.

    Candidate rows come from the embedding model class stored under
    ``model_variables['embedding_db_model']``. Only rows belonging to the
    highest-id ``DocumentVersion`` of each document are considered, the owning
    ``Document`` must be inside its ``valid_from``/``valid_to`` window, and the
    similarity (1 - cosine distance) must exceed
    ``model_variables['similarity_threshold']``.

    Optional filters on ``DocumentVersion.user_metadata`` can be activated
    temporarily through the :meth:`filtering` context manager.

    NOTE: ``_get_relevant_documents`` returns plain strings
    (``'SOURCE: <id>\\n\\n<chunk>\\n\\n'``), not LangChain document objects.
    """

    _catalog_id: int = PrivateAttr()
    _model_variables: ModelVariables = PrivateAttr()
    _tenant_info: Dict[str, Any] = PrivateAttr()
    _active_filters: Optional[Dict[str, Any]] = PrivateAttr()

    def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
        """Create a retriever bound to one catalog.

        Args:
            catalog_id: Value matched against ``Document.catalog_id``.
            model_variables: Mapping read for ``embedding_model``,
                ``embedding_db_model``, ``similarity_threshold``, ``k`` and
                ``rag_tuning``.
            tenant_info: Mapping read for ``timezone``.
        """
        super().__init__()
        self._catalog_id = catalog_id
        self._model_variables = model_variables
        self._tenant_info = tenant_info
        self._active_filters = None

    @contextmanager
    def filtering(self, metadata_filters: Dict[str, Any]):
        """Temporarily activate metadata filters for queries run inside the block.

        The previously active filters are restored on exit, even when the body
        raises, so nested ``with`` blocks unwind correctly.

        Args:
            metadata_filters: Mapping of ``user_metadata`` field name to a
                single value or a list/tuple of accepted values. ``None``
                values are ignored.

        Yields:
            This retriever instance.
        """
        previous_filters = self._active_filters
        self._active_filters = metadata_filters
        try:
            yield self
        finally:
            self._active_filters = previous_filters

    def _build_metadata_filter_conditions(self, query):
        """Add the active ``user_metadata`` filters to ``query``.

        Each field is compared through its text representation
        (``user_metadata[field].astext``). The text is deliberately NOT cast
        back to JSON: that cast fails for ordinary string values (the extracted
        text ``abc`` is not valid JSON), and PostgreSQL's ``json`` type has no
        equality operator to compare with anyway.

        Args:
            query: SQLAlchemy query that already joins ``DocumentVersion``.

        Returns:
            ``query`` unchanged when no filters apply, otherwise ``query`` with
            all field conditions AND-ed together. A list/tuple value matches
            when the field equals any of its elements; an empty list/tuple
            adds no condition.
        """
        if not self._active_filters:
            return query

        def as_json_text(value: Any) -> str:
            # JSON booleans are extracted as 'true'/'false', whereas str(True)
            # would give 'True' and never match.
            if isinstance(value, bool):
                return 'true' if value else 'false'
            return str(value)

        conditions = []
        for field, value in self._active_filters.items():
            if value is None:
                continue
            field_text = DocumentVersion.user_metadata[field].astext
            if isinstance(value, (list, tuple)):
                # Multiple accepted values: equivalent to OR-ing the equalities.
                accepted = [as_json_text(val) for val in value]
                if accepted:
                    conditions.append(field_text.in_(accepted))
            else:
                conditions.append(field_text == as_json_text(value))

        if conditions:
            query = query.filter(and_(*conditions))
        return query

    def _get_relevant_documents(self, query: str) -> List[str]:
        """Return up to ``k`` formatted chunks most similar to ``query``.

        Args:
            query: Free-text query; embedded with
                ``model_variables['embedding_model']``.

        Returns:
            A list of strings ``'SOURCE: <row id>\\n\\n<chunk>\\n\\n'`` ordered by
            descending similarity. An empty list is returned when the database
            raises ``SQLAlchemyError`` (the session is rolled back and the
            error is logged).
        """
        current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
        if self._active_filters:
            current_app.logger.debug(f'Using metadata filters: {self._active_filters}')

        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']
        rag_tuning = self.model_variables['rag_tuning']

        try:
            # presumably "today" in the tenant's timezone - confirm against
            # get_date_in_timezone
            current_date = get_date_in_timezone(self.tenant_info['timezone'])

            # Subquery to find the latest (highest id) version of each document
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Built once so the selected column and the threshold filter are
            # guaranteed to use the same expression.
            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Build base query
            query_obj = (
                db.session.query(db_class, similarity.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    similarity > similarity_threshold,
                    Document.catalog_id == self._catalog_id
                )
            )

            # Apply metadata filters
            query_obj = self._build_metadata_filter_conditions(query_obj)

            # Order and limit results
            query_obj = query_obj.order_by(desc('similarity')).limit(k)

            # Debug logging for RAG tuning if enabled
            if rag_tuning:
                self._log_rag_tuning(query_obj, query_embedding)

            result = []
            for row, score in query_obj.all():
                if rag_tuning:
                    # NOTE: the value logged as "Distance" is the similarity
                    # score (1 - cosine distance); the message text is kept
                    # unchanged for existing log consumers.
                    current_app.logger.debug(f'Document ID: {row.id} - Distance: {score}\n')
                    current_app.logger.debug(f'Chunk: \n {row.chunk}\n\n')
                result.append(f'SOURCE: {row.id}\n\n{row.chunk}\n\n')

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []

        return result

    def _log_rag_tuning(self, query_obj, query_embedding):
        """Log the SQL statement and the active filters to the RAG tuning logger.

        Args:
            query_obj: Query whose ``statement`` is logged.
            query_embedding: Currently unused; kept for interface stability.
        """
        current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
        current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
        if self._active_filters:
            current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
            current_app.rag_tuning_logger.debug(f"{self._active_filters}")

    def _get_query_embedding(self, query: str):
        """Embed ``query`` using ``model_variables['embedding_model'].embed_query``."""
        embedding_model = self.model_variables['embedding_model']
        return embedding_model.embed_query(query)

    @property
    def model_variables(self) -> ModelVariables:
        """Model variables passed to the constructor."""
        return self._model_variables

    @property
    def tenant_info(self) -> Dict[str, Any]:
        """Tenant information passed to the constructor."""
        return self._tenant_info