from typing import Any, Dict

from flask import current_app
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from config.logging_config import LOGGING


class EveAIRetriever(BaseRetriever):
    """Retriever that looks up stored embeddings by cosine distance.

    All configuration is read from ``model_variables``; the keys this class
    uses are:

    * ``'embedding_model'``    - object exposing ``embed_query(query)``.
    * ``'embedding_db_model'`` - mapped class with ``embedding``, ``id`` and
      ``chunk`` attributes.
    * ``'similarity_threshold'`` - upper bound (exclusive) on the cosine
      distance of a returned row.
    * ``'k'`` - maximum number of rows to return.
    """

    # Required field: must be supplied when the instance is constructed.
    model_variables: Dict[str, Any] = Field(...)

    def __init__(self, model_variables: Dict[str, Any]):
        """Create the retriever.

        Args:
            model_variables: Configuration dictionary (see class docstring).
        """
        current_app.logger.debug('Initializing EveAIRetriever')
        # The field is declared required, so it has to be handed to the base
        # initializer; calling it with no arguments fails validation.
        super().__init__(model_variables=model_variables)
        # Re-assign so the instance holds the caller's dict object itself (as
        # the previous implementation did) rather than whatever copy the base
        # initializer may have stored.
        self.model_variables = model_variables
        current_app.logger.debug('EveAIRetriever initialized')

    def _get_relevant_documents(self, query: str):
        """Return the stored rows closest to ``query``.

        The query is embedded, then rows of the configured model are selected
        whose cosine distance to that embedding is strictly below
        ``similarity_threshold``, ordered by ascending distance and limited to
        ``k`` results. Each hit is also written to ``rag_tuning_logger``.

        Args:
            query: Text to search for.

        Returns:
            A list of result rows; element ``0`` of a row is the model
            instance and element ``1`` is its cosine distance. An empty list
            is returned when a ``SQLAlchemyError`` occurs.

        NOTE(review): LangChain retrievers conventionally return ``Document``
        objects; this returns raw rows - confirm callers expect that.
        """
        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        # Build the distance expression once and reuse it for both the
        # selected column and the filter.
        distance = db_class.embedding.cosine_distance(query_embedding)

        try:
            res = (
                db.session.query(db_class, distance.label('distance'))
                .filter(distance < similarity_threshold)
                .order_by('distance')
                .limit(k)
                .all()
            )

            tuning_logger = current_app.rag_tuning_logger
            tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
            tuning_logger.debug('---------------------------------------')
            for doc in res:
                tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
                tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
        except SQLAlchemyError as e:
            # Best effort: a database failure yields "no documents" rather
            # than propagating to the caller.
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            return []

        return res

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with the configured ``'embedding_model'``.

        Args:
            query: Text to embed.

        Returns:
            Whatever ``embedding_model.embed_query(query)`` returns.
        """
        embedding_model = self.model_variables['embedding_model']
        return embedding_model.embed_query(query)