- Revisiting RAG_SPECIALIST
- Adapt Catalogs & Retrievers to use specific types, removing tagging_fields - Adding CrewAI Implementation Guide
This commit is contained in:
136
eveai_chat_workers/retrievers/globals/STANDARD_RAG/1_0.py
Normal file
136
eveai_chat_workers/retrievers/globals/STANDARD_RAG/1_0.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# retrievers/standard_rag.py
|
||||
import json
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from typing import Dict, Any, List
|
||||
from sqlalchemy import func, or_, desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||
from common.models.user import Tenant
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import get_embedding_model_and_class
|
||||
from eveai_chat_workers.retrievers.base_retriever import BaseRetriever
|
||||
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
|
||||
|
||||
|
||||
class RetrieverExecutor(BaseRetriever):
|
||||
"""Standard RAG retriever implementation"""
|
||||
|
||||
def __init__(self, tenant_id: int, retriever_id: int):
|
||||
super().__init__(tenant_id, retriever_id)
|
||||
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "STANDARD_RAG"
|
||||
|
||||
@property
|
||||
def type_version(self) -> str:
|
||||
return "1.0"
|
||||
|
||||
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
|
||||
"""
|
||||
Retrieve documents based on query
|
||||
|
||||
Args:
|
||||
arguments: Validated RetrieverArguments containing at minimum:
|
||||
- query: str - The search query
|
||||
|
||||
Returns:
|
||||
List[RetrieverResult]: List of retrieved documents with similarity scores
|
||||
"""
|
||||
try:
|
||||
question = arguments.question
|
||||
|
||||
# Get query embedding
|
||||
query_embedding = self.embedding_model.embed_query(question)
|
||||
|
||||
# Get the appropriate embedding database model
|
||||
db_class = self.embedding_model_class
|
||||
|
||||
# Get current date for validity checks
|
||||
current_date = dt.now(tz=tz.utc).date()
|
||||
|
||||
# Create subquery for latest versions
|
||||
subquery = (
|
||||
db.session.query(
|
||||
DocumentVersion.doc_id,
|
||||
func.max(DocumentVersion.id).label('latest_version_id')
|
||||
)
|
||||
.group_by(DocumentVersion.doc_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
|
||||
k = self.retriever.configuration.get('es_k', 8)
|
||||
|
||||
# Main query
|
||||
query_obj = (
|
||||
db.session.query(
|
||||
db_class,
|
||||
DocumentVersion.url,
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
|
||||
)
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
|
||||
Document.catalog_id == self.catalog_id
|
||||
)
|
||||
.order_by(desc('similarity'))
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
results = query_obj.all()
|
||||
|
||||
# Transform results into standard format
|
||||
processed_results = []
|
||||
for doc, url, similarity in results:
|
||||
# Parse user_metadata to ensure it's a dictionary
|
||||
user_metadata = self._parse_metadata(doc.document_version.user_metadata)
|
||||
processed_results.append(
|
||||
RetrieverResult(
|
||||
id=doc.id,
|
||||
chunk=doc.chunk,
|
||||
similarity=float(similarity),
|
||||
metadata=RetrieverMetadata(
|
||||
document_id=doc.document_version.doc_id,
|
||||
version_id=doc.document_version.id,
|
||||
document_name=doc.document_version.document.name,
|
||||
url=url or "",
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Log the retrieval
|
||||
if self.tuning:
|
||||
compiled_query = str(query_obj.statement.compile(
|
||||
compile_kwargs={"literal_binds": True} # This will include the actual values in the SQL
|
||||
))
|
||||
self.log_tuning('retrieve', {
|
||||
"arguments": arguments.model_dump(),
|
||||
"similarity_threshold": self.similarity_threshold,
|
||||
"k": self.k,
|
||||
"query": compiled_query,
|
||||
"Raw Results": str(results),
|
||||
"Processed Results": [r.model_dump() for r in processed_results],
|
||||
})
|
||||
|
||||
return processed_results
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error in RAG retrieval: {e}')
|
||||
db.session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
|
||||
raise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user