162 lines
6.0 KiB
Python
162 lines
6.0 KiB
Python
# retrievers/standard_rag.py
|
|
import json
|
|
from datetime import datetime as dt, timezone as tz
|
|
from typing import Dict, Any, List
|
|
from sqlalchemy import func, or_, desc
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from flask import current_app
|
|
|
|
from common.extensions import db
|
|
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
|
from common.models.user import Tenant
|
|
from common.utils.datetime_utils import get_date_in_timezone
|
|
from common.utils.model_utils import get_embedding_model_and_class
|
|
from .base import BaseRetriever
|
|
|
|
from .registry import RetrieverRegistry
|
|
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
|
|
|
|
|
|
class StandardRAGRetriever(BaseRetriever):
    """Standard RAG retriever implementation.

    Embeds the incoming query and ranks stored chunks by cosine similarity
    (``1 - cosine_distance``) against that embedding. Only chunks that belong
    to the highest-id version of each document, whose document is inside its
    ``valid_from`` / ``valid_to`` window for the current UTC date, and whose
    document is in this retriever's catalog are considered.
    """

    def __init__(self, tenant_id: int, retriever_id: int):
        """
        Initialise the retriever and load its retrieval parameters.

        Args:
            tenant_id: Passed through to ``BaseRetriever.__init__``.
            retriever_id: Passed through to ``BaseRetriever.__init__``.
        """
        super().__init__(tenant_id, retriever_id)

        # Set up standard retrieval parameters. The helper (inherited, not
        # defined in this file) returns these five values in this order.
        (
            self.embedding_model,
            self.embedding_model_class,
            self.catalog_id,
            self.similarity_threshold,
            self.k,
        ) = self.setup_standard_retrieval_params()
        self.log_tuning("Standard RAG retriever initialized", {
            "similarity_threshold": self.similarity_threshold,
            "k": self.k
        })

    @property
    def type(self) -> str:
        """Return the type key of this retriever, ``"STANDARD_RAG"``."""
        return "STANDARD_RAG"

    def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
        """
        Parse metadata ensuring it's a dictionary

        Args:
            metadata: Input metadata which could be string, dict, or None

        Returns:
            Dict[str, Any]: Parsed metadata as dictionary. An empty dict is
            returned for ``None``, for strings that are not valid JSON, for
            JSON that decodes to something other than an object, and for any
            other input type (the last three cases also log a warning).
        """
        if metadata is None:
            return {}

        if isinstance(metadata, dict):
            return metadata

        if isinstance(metadata, str):
            try:
                parsed = json.loads(metadata)
            except json.JSONDecodeError:
                current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
                return {}

            # Valid JSON is not necessarily an object: "[1, 2]" or "3" decode
            # to a list / int, which would break the Dict return contract.
            if isinstance(parsed, dict):
                return parsed

            # NOTE: inside a method body `type` is the builtin, not the
            # `type` property defined on this class.
            current_app.logger.warning(
                f"Metadata JSON string did not decode to an object: {type(parsed)}"
            )
            return {}

        current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
        return {}

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: List of retrieved documents with similarity
            scores, ordered by descending similarity and limited to ``self.k``.

        Raises:
            SQLAlchemyError: Re-raised after logging and rolling back the
                session.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            query = arguments.query

            # Get query embedding
            query_embedding = self.embedding_model.embed_query(query)

            # Get the appropriate embedding database model
            db_class = self.embedding_model_class

            # Get current date (UTC) for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Subquery: highest DocumentVersion.id per document, treated as
            # the latest version of that document.
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Similarity is 1 - cosine distance. Build the expression once so
            # the selected column and the threshold filter cannot drift apart.
            similarity_expr = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Main query
            query_obj = (
                db.session.query(
                    db_class,
                    similarity_expr.label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                # Inner join on the subquery keeps only latest-version chunks.
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    # A NULL bound means "no limit" on that side of the window.
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    # Strictly greater: a score equal to the threshold is excluded.
                    similarity_expr > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
                .order_by(desc('similarity'))
                .limit(self.k)
            )

            results = query_obj.all()

            # Transform results into standard format
            processed_results = []
            for doc, similarity in results:
                version = doc.document_version
                # Parse user_metadata to ensure it's a dictionary
                user_metadata = self._parse_metadata(version.user_metadata)
                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=version.doc_id,
                            version_id=version.id,
                            document_name=version.document.name,
                            user_metadata=user_metadata,
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "Raw Results": str(results),
                    "Processed Results": [r.model_dump() for r in processed_results],
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in RAG retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
            raise
|
|
|
|
|
|
# Register the retriever type: bind the "STANDARD_RAG" key (the same string
# returned by StandardRAGRetriever.type) to the class in the registry.
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)