52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
from langchain_core.retrievers import BaseRetriever
|
|
from sqlalchemy import asc
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from pydantic import Field, BaseModel, PrivateAttr
|
|
from typing import Any, Dict
|
|
from flask import current_app
|
|
|
|
from common.extensions import db
|
|
from common.models.interaction import ChatSession, Interaction
|
|
from common.utils.model_utils import ModelVariables
|
|
|
|
|
|
class EveAIHistoryRetriever(BaseRetriever, BaseModel):
    """LangChain retriever that returns the stored interaction history of one chat session.

    Unlike a similarity-based retriever, the ``query`` passed in is only
    logged: every ``Interaction`` joined to a ``ChatSession`` whose
    ``session_id`` equals ``self.session_id`` is returned, ordered by
    ascending ``Interaction.id``.

    NOTE(review): ``BaseRetriever`` implementations usually return
    ``Document`` objects, but this one returns plain strings; callers are
    presumably relying on that, so confirm before changing the return type.
    """

    # Pydantic private attributes: set in __init__, not treated as model fields.
    _model_variables: ModelVariables = PrivateAttr()
    _session_id: str = PrivateAttr()

    def __init__(self, model_variables: ModelVariables, session_id: str):
        """Create a retriever bound to a single chat session.

        Args:
            model_variables: Model configuration object. It is stored and
                exposed via the ``model_variables`` property; the retrieval
                method itself does not read it.
            session_id: Value compared against ``ChatSession.session_id`` to
                select whose interactions are returned.
        """
        super().__init__()
        self._model_variables = model_variables
        self._session_id = session_id

    @property
    def model_variables(self) -> ModelVariables:
        """Model configuration passed to the constructor."""
        return self._model_variables

    @property
    def session_id(self) -> str:
        """Session identifier used to filter the interaction history."""
        return self._session_id

    def _get_relevant_documents(self, query: str):
        """Return the session's interaction history as formatted strings.

        Args:
            query: The incoming user query. It is only logged; it does not
                filter or rank the history.

        Returns:
            A list with one string per interaction, in ascending
            ``Interaction.id`` order. Each string contains a ``HUMAN:``
            section with the interaction's ``detailed_question`` followed by
            an ``AI:`` section with its ``answer``. An empty list is returned
            when the database query raises ``SQLAlchemyError``.
        """
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        current_app.logger.debug('Retrieving history of interactions for query: %s', query)

        try:
            interactions = (
                db.session.query(Interaction)
                .join(ChatSession, Interaction.chat_session_id == ChatSession.id)
                .filter(ChatSession.session_id == self.session_id)
                .order_by(asc(Interaction.id))
                .all()
            )
            # Formatting stays inside the try (as before) in case attribute
            # access triggers a lazy load that fails.
            result = [
                f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n'
                for interaction in interactions
            ]
        except SQLAlchemyError as e:
            # logger.exception logs at ERROR level like before, but also keeps
            # the traceback, which logger.error discarded.
            current_app.logger.exception('Error retrieving history of interactions: %s', e)
            # Roll back so the shared session is not left in a failed transaction.
            db.session.rollback()
            return []

        return result