Include history to improve query if required.

This commit is contained in:
Josako
2024-06-13 10:35:15 +02:00
parent 50851dc51c
commit 24a3747b99
5 changed files with 77 additions and 4 deletions

View File

@@ -0,0 +1,44 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import asc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.interaction import ChatSession, Interaction
from common.utils.datetime_utils import get_date_in_timezone
class EveAIHistoryRetriever(BaseRetriever):
model_variables: Dict[str, Any] = Field(...)
session_id: str = Field(...)
def __init__(self, model_variables: Dict[str, Any], session_id: str):
super().__init__()
self.model_variables = model_variables
self.session_id = session_id
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
try:
query_obj = (
db.session.query(Interaction)
.join(ChatSession, Interaction.chat_session_id == ChatSession.id)
.filter(ChatSession.session_id == self.session_id)
.order_by(asc(Interaction.id))
)
interactions = query_obj.all()
result = []
for interaction in interactions:
result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving history of interactions: {e}')
db.session.rollback()
return []
return result

View File

@@ -16,10 +16,8 @@ class EveAIRetriever(BaseRetriever):
def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
super().__init__()
current_app.logger.debug('Initializing EveAIRetriever')
self.model_variables = model_variables
self.tenant_info = tenant_info
current_app.logger.debug('EveAIRetriever initialized')
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')

View File

@@ -21,6 +21,7 @@ class Interaction(db.Model):
id = db.Column(db.Integer, primary_key=True)
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
question = db.Column(db.Text, nullable=False)
detailed_question = db.Column(db.Text, nullable=True)
answer = db.Column(db.Text, nullable=True)
algorithm_used = db.Column(db.String(20), nullable=True)
language = db.Column(db.String(2), nullable=False)
@@ -28,6 +29,7 @@ class Interaction(db.Model):
# Timing information
question_at = db.Column(db.DateTime, nullable=False)
detailed_question_at = db.Column(db.DateTime, nullable=True)
answer_at = db.Column(db.DateTime, nullable=True)
# Relations

View File

@@ -103,15 +103,18 @@ def select_model_variables(tenant):
case 'gpt-4-turbo' | 'gpt-4o':
summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE')
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
history_template = current_app.config.get('GPT4_HISTORY_TEMPLATE')
tool_calling_supported = True
case 'gpt-3-5-turbo':
summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE')
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
history_template = current_app.config.get('GPT3_5_HISTORY_TEMPLATE')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model')
model_variables['summary_template'] = summary_template
model_variables['rag_template'] = rag_template
model_variables['history_template'] = history_template
if tool_calling_supported:
model_variables['cited_answer_cls'] = CitedAnswer
case _: