Files
eveAI/eveai_chat_workers/retrievers/standard_rag.py
2025-03-10 08:31:15 +01:00

171 lines
6.5 KiB
Python

# retrievers/standard_rag.py
import json
from datetime import datetime as dt, timezone as tz
from typing import Dict, Any, List
from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.models.user import Tenant
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class StandardRAGRetriever(BaseRetriever):
    """Standard RAG retriever implementation.

    Embeds the incoming query and runs a cosine-similarity search over the
    stored embeddings that belong to the retriever's catalog, restricted to
    the highest-id version of each document and to documents whose validity
    window covers the current UTC date.
    """

    # Embedding model identifier passed to get_embedding_model_and_class.
    DEFAULT_EMBEDDING_MODEL = "mistral.mistral-embed"
    # Fallbacks used when the retriever configuration omits the setting.
    DEFAULT_SIMILARITY_THRESHOLD = 0.3
    DEFAULT_K = 8

    def __init__(self, tenant_id: int, retriever_id: int):
        """
        Load the retriever configuration and resolve the embedding model

        Args:
            tenant_id: Identifier of the tenant the retriever belongs to
            retriever_id: Primary key of the Retriever row to load
        """
        super().__init__(tenant_id, retriever_id)
        retriever = Retriever.query.get_or_404(retriever_id)
        self.catalog_id = retriever.catalog_id
        self.tenant_id = tenant_id

        # Called only for its existence check: get_or_404 fails when the
        # catalog referenced by the retriever is missing.
        Catalog.query.get_or_404(self.catalog_id)

        self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(
            self.tenant_id,
            self.catalog_id,
            self.DEFAULT_EMBEDDING_MODEL,
        )

        # A retriever stored without configuration must not break construction.
        configuration = retriever.configuration or {}
        self.similarity_threshold = configuration.get(
            'es_similarity_threshold', self.DEFAULT_SIMILARITY_THRESHOLD)
        self.k = configuration.get('es_k', self.DEFAULT_K)

        self.tuning = retriever.tuning
        self.log_tuning("Standard RAG retriever initialized")

    @property
    def type(self) -> str:
        """Type key of this retriever ("STANDARD_RAG")"""
        return "STANDARD_RAG"

    def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
        """
        Parse metadata ensuring it's a dictionary

        Args:
            metadata: Input metadata which could be string, dict, or None

        Returns:
            Dict[str, Any]: Parsed metadata as dictionary. An empty dictionary
            is returned for None, for strings that are not valid JSON, for
            JSON that does not decode to an object, and for any other type.
        """
        if metadata is None:
            return {}
        if isinstance(metadata, dict):
            return metadata
        if isinstance(metadata, str):
            try:
                parsed = json.loads(metadata)
            except json.JSONDecodeError:
                current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
                return {}
            # Valid JSON can still be a list, number, string, bool or null;
            # only a JSON object satisfies the dictionary contract.
            if isinstance(parsed, dict):
                return parsed
            current_app.logger.warning(f"Metadata JSON is not an object: {type(parsed)}")
            return {}
        current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
        return {}

    def _to_result(self, doc: Any, similarity: Any) -> RetrieverResult:
        """Convert one (embedding row, similarity) pair into a RetrieverResult"""
        version = doc.document_version
        return RetrieverResult(
            id=doc.id,
            chunk=doc.chunk,
            similarity=float(similarity),
            metadata=RetrieverMetadata(
                document_id=version.doc_id,
                version_id=version.id,
                document_name=version.document.name,
                # user_metadata may be stored as a JSON string; normalise it
                user_metadata=self._parse_metadata(version.user_metadata),
            )
        )

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: At most ``self.k`` chunks whose similarity
            exceeds ``self.similarity_threshold``, ordered by descending
            similarity

        Raises:
            SQLAlchemyError: Database failure; the session is rolled back
                before the error is re-raised
            Exception: Any other failure is logged and re-raised unchanged
        """
        try:
            # Get query embedding
            query_embedding = self.embedding_model.embed_query(arguments.query)

            # Get the appropriate embedding database model
            db_class = self.embedding_model_class

            # Get current date (UTC) for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Subquery selecting, per document, the version with the highest id
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Built once and reused for both the selected column and the filter
            similarity_expr = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Main query
            query_obj = (
                db.session.query(db_class, similarity_expr.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    # An open-ended validity bound (NULL) never excludes a document
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    similarity_expr > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
                .order_by(desc('similarity'))
                .limit(self.k)
            )
            results = query_obj.all()

            # Transform results into standard format
            processed_results = [self._to_result(doc, score) for doc, score in results]

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self.log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "Raw Results": str(results),
                    "Processed Results": [r.model_dump() for r in processed_results],
                })

            return processed_results
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in RAG retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
            raise
# Register this class with the RetrieverRegistry under the "STANDARD_RAG" key
# (the same string its `type` property returns). Runs at import time.
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)