from typing import Any, Dict

from flask import current_app
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone


class EveAIRetriever(BaseRetriever):
    """Retriever that selects embedded chunks by cosine similarity.

    Only chunks belonging to the latest version of a document whose validity
    window contains the current date (in the tenant's timezone) are returned.

    ``model_variables`` is read for the keys ``embedding_model``,
    ``embedding_db_model``, ``similarity_threshold`` and ``k``.
    ``tenant_info`` is read for the keys ``rag_tuning`` and ``timezone``.
    """

    model_variables: Dict[str, Any] = Field(...)
    tenant_info: Dict[str, Any] = Field(...)

    def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
        """Create a retriever bound to one tenant and one model configuration.

        Args:
            model_variables: Model configuration (embedding model, embedding
                table class, similarity threshold, ``k``).
            tenant_info: Tenant settings (``rag_tuning`` flag, ``timezone``).
        """
        # Both fields are declared required, so they have to be handed to the
        # base constructor; calling it without them fails validation.
        super().__init__(model_variables=model_variables, tenant_info=tenant_info)
        # Validation may store copies of the dicts. Rebind the caller's own
        # objects so the instance keeps referencing them, as it did before.
        self.model_variables = model_variables
        self.tenant_info = tenant_info

    def _get_relevant_documents(self, query: str) -> list:
        """Return the chunks most similar to *query*.

        Args:
            query: Free-text query; it is embedded with the configured
                embedding model before the database is searched.

        Returns:
            A list of strings of the form ``'SOURCE: <id>\\n\\n<chunk>\\n\\n'``,
            ordered by descending similarity and limited to ``k`` entries.
            An empty list is returned when the database query raises a
            ``SQLAlchemyError`` (the session is rolled back in that case).
        """
        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']
        rag_tuning = self.tenant_info['rag_tuning']

        # Computed once so the tuning overview and the real query agree on the date.
        current_date = get_date_in_timezone(self.tenant_info['timezone'])

        if rag_tuning:
            tuning_log = current_app.rag_tuning_logger
            self._log_similarity_overview(db_class, query_embedding, current_date)
            tuning_log.debug('Parameters for Retrieval of documents: \n')
            tuning_log.debug(f'Similarity Threshold: {similarity_threshold}\n')
            tuning_log.debug(f'K: {k}\n')
            tuning_log.debug('---------------------------------------\n')

        try:
            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)

            # Highest version id per document is treated as its latest version.
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id'),
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            query_obj = (
                db.session.query(db_class, similarity.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    *self._validity_filters(current_date),
                    similarity > similarity_threshold,
                )
                .order_by(desc('similarity'))
                .limit(k)
            )

            if rag_tuning:
                tuning_log.debug('Query executed for Retrieval of documents: \n')
                tuning_log.debug(f'{query_obj.statement}\n')
                tuning_log.debug('---------------------------------------\n')

            rows = query_obj.all()

            if rag_tuning:
                tuning_log.debug(f'Retrieved {len(rows)} relevant documents \n')
                tuning_log.debug('Data retrieved: \n')
                tuning_log.debug(f'{rows}\n')
                tuning_log.debug('---------------------------------------\n')

            result = []
            for embedding, score in rows:
                if rag_tuning:
                    # The selected value is 1 - cosine distance, i.e. a similarity.
                    tuning_log.debug(f'Document ID: {embedding.id} - Similarity: {score}\n')
                    tuning_log.debug(f'Chunk: \n {embedding.chunk}\n\n')
                result.append(f'SOURCE: {embedding.id}\n\n{embedding.chunk}\n\n')
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []

        return result

    @staticmethod
    def _validity_filters(current_date):
        """Return the filter clauses keeping documents valid on *current_date*.

        A missing ``valid_from`` / ``valid_to`` is treated as an open bound.
        """
        return (
            or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
            or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
        )

    def _log_similarity_overview(self, db_class, query_embedding, current_date):
        """Log the similarity of every embedding of every currently valid document.

        Tuning aid only: no threshold, no limit and no latest-version
        restriction is applied, and chunk text is not included. A
        ``SQLAlchemyError`` is logged and the session rolled back; it never
        prevents the actual retrieval from running afterwards.
        """
        tuning_log = current_app.rag_tuning_logger
        try:
            tuning_log.debug(f'Current date: {current_date}\n')

            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)
            debug_query = (
                db.session.query(
                    Document.id.label('document_id'),
                    DocumentVersion.id.label('version_id'),
                    db_class.id.label('embedding_id'),
                    similarity.label('similarity'),
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .filter(*self._validity_filters(current_date))
                .order_by(desc('similarity'))
            )
            debug_results = debug_query.all()

            # Header goes to the same logger as the rows it introduces.
            tuning_log.debug("Debug: Similarity for all valid documents:")
            for row in debug_results:
                tuning_log.debug(f"Doc ID: {row.document_id}, "
                                 f"Version ID: {row.version_id}, "
                                 f"Embedding ID: {row.embedding_id}, "
                                 f"Similarity: {row.similarity}")
            tuning_log.debug('---------------------------------------\n')
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error generating overview: {e}')
            db.session.rollback()

    def _get_query_embedding(self, query: str):
        """Embed *query* with the configured ``embedding_model``."""
        embedding_model = self.model_variables['embedding_model']
        return embedding_model.embed_query(query)