from typing import Any, Dict, List

from flask import current_app
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field
from sqlalchemy import func, and_, or_
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone


class EveAIRetriever(BaseRetriever):
    """Retriever that runs a vector-distance search over stored embedding chunks.

    The embedding model, the SQLAlchemy model holding the embeddings and the
    search parameters all come from ``model_variables``; tenant-specific
    settings (timezone, RAG-tuning logging) come from ``tenant_info``.
    """

    # Model/search configuration. Keys read by this class:
    # 'embedding_db_model', 'embedding_model', 'similarity_threshold', 'k'.
    model_variables: Dict[str, Any] = Field(...)
    # Tenant settings. Keys read by this class: 'timezone', 'rag_tuning'.
    tenant_info: Dict[str, Any] = Field(...)

    def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
        """Create a retriever bound to one tenant's model configuration.

        Args:
            model_variables: Model and search configuration (see field comment).
            tenant_info: Tenant settings (see field comment).
        """
        # Both fields are declared required (``Field(...)``), so they must be
        # handed to the pydantic initializer; calling ``super().__init__()``
        # without them fails validation before any assignment below runs.
        super().__init__(model_variables=model_variables, tenant_info=tenant_info)
        # Re-bind the caller's original dict objects: pydantic may store a
        # validated copy. NOTE(review): confirm whether callers rely on sharing
        # these dicts; if not, these two lines can be dropped.
        self.model_variables = model_variables
        self.tenant_info = tenant_info

    def _get_relevant_documents(self, query: str) -> List[str]:
        """Return the stored chunks closest to ``query``.

        Only embeddings attached to the latest version of each document (the
        highest ``DocumentVersion.id`` per ``doc_id``) are considered, and the
        owning document must be valid on the current date in the tenant's
        timezone (a ``None`` ``valid_from``/``valid_to`` means unbounded).
        Rows whose cosine distance to the query embedding is below
        ``similarity_threshold`` are returned, nearest first, at most ``k``.

        NOTE(review): LangChain's ``BaseRetriever`` contract is to return
        ``Document`` objects; this returns plain strings. Confirm downstream
        consumers expect strings before changing either side.

        Args:
            query: Natural-language query to embed and search for.

        Returns:
            One string per matching row: a ``SOURCE: <row id>`` header, a blank
            line, then the row's ``chunk`` text. An empty list if the database
            query raises ``SQLAlchemyError`` (the session is rolled back).

        Raises:
            KeyError: If a required key is missing from ``model_variables`` or
                ``tenant_info``.
        """
        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        # Despite its name, this value is an upper bound on cosine *distance*
        # (smaller distance = more similar); see the filter below.
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        try:
            current_date = get_date_in_timezone(self.tenant_info['timezone'])

            # Latest version of each document, taken as its highest version id.
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id'),
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Build the distance expression once; it is both selected (labelled
            # 'distance' for ordering) and used in the threshold filter.
            distance = db_class.embedding.cosine_distance(query_embedding)

            query_obj = (
                db.session.query(db_class, distance.label('distance'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), Document.valid_from <= current_date),
                    or_(Document.valid_to.is_(None), Document.valid_to >= current_date),
                    distance < similarity_threshold,
                )
                .order_by('distance')
                .limit(k)
            )
            res = query_obj.all()

            rag_tuning = self.tenant_info['rag_tuning']
            if rag_tuning:
                current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
                current_app.rag_tuning_logger.debug('---------------------------------------')

            result = []
            # Each row is (embedding model instance, distance).
            for embedding_row, distance_value in res:
                if rag_tuning:
                    current_app.rag_tuning_logger.debug(
                        f'Document ID: {embedding_row.id} - Distance: {distance_value}\n'
                    )
                    current_app.rag_tuning_logger.debug(f'Chunk: \n {embedding_row.chunk}\n\n')
                result.append(f'SOURCE: {embedding_row.id}\n\n{embedding_row.chunk}\n\n')
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            # Leave the shared session usable for subsequent requests.
            db.session.rollback()
            return []

        return result

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with the configured ``embedding_model``.

        Args:
            query: Text to embed.

        Returns:
            Whatever ``embedding_model.embed_query`` returns for ``query``.
        """
        embedding_model = self.model_variables['embedding_model']
        query_embedding = embedding_model.embed_query(query)
        return query_embedding