44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
from langchain_core.retrievers import BaseRetriever
|
|
from sqlalchemy import asc
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from pydantic import BaseModel, Field
|
|
from typing import Any, Dict
|
|
from flask import current_app
|
|
|
|
from common.extensions import db
|
|
from common.models.interaction import ChatSession, Interaction
|
|
from common.utils.datetime_utils import get_date_in_timezone
|
|
|
|
|
|
class EveAIHistoryRetriever(BaseRetriever):
    """Retriever that returns the complete interaction history of one chat session.

    Unlike a similarity-based retriever, the ``query`` argument is only logged:
    every ``Interaction`` joined to the ``ChatSession`` whose ``session_id``
    equals ``self.session_id`` is returned, ordered by ascending
    ``Interaction.id``.
    """

    # Caller-supplied model configuration; not read anywhere in this class.
    model_variables: Dict[str, Any] = Field(...)
    # Compared against ``ChatSession.session_id`` to select the history.
    session_id: str = Field(...)

    def __init__(self, model_variables: Dict[str, Any], session_id: str):
        # Both fields are declared required (``Field(...)``), so they must be
        # handed to the pydantic initializer; calling ``super().__init__()``
        # with no arguments fails validation for missing required fields.
        super().__init__(model_variables=model_variables, session_id=session_id)
        # Re-bind the caller's objects directly so the stored values are the
        # exact objects passed in (validation may copy containers such as dicts).
        self.model_variables = model_variables
        self.session_id = session_id

    def _get_relevant_documents(self, query: str):
        """Return this session's interaction history as formatted strings.

        Args:
            query: The incoming user query. It is only logged and does not
                filter the history.

        Returns:
            A list with one string per ``Interaction`` of the session, ordered
            by ascending ``Interaction.id``. Each string holds a "HUMAN:"
            section with the interaction's ``detailed_question`` followed by
            an "AI:" section with its ``answer``. Returns an empty list if a
            ``SQLAlchemyError`` occurs.
        """
        current_app.logger.debug(f'Retrieving history of interactions for query: {query}')

        try:
            interactions = (
                db.session.query(Interaction)
                .join(ChatSession, Interaction.chat_session_id == ChatSession.id)
                .filter(ChatSession.session_id == self.session_id)
                .order_by(asc(Interaction.id))
                .all()
            )
            # Formatting stays inside the try: attribute access on ORM objects
            # may hit the database (e.g. lazy/deferred loads) and raise too.
            result = [
                f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n'
                for interaction in interactions
            ]
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving history of interactions: {e}')
            # Roll back so the session is not left in a failed transaction state.
            db.session.rollback()
            return []

        return result