# eveAI/eveai_chat_workers/specialists/rag_specialist.py
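"""RAG specialist for standard question answering.

Details the incoming question against the cached chat history, collects
context from the configured retrievers, deduplicates it, and invokes an LLM
with structured output to produce an answer with citations and an
insufficient-information flag.
"""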

from typing import Dict, Any, List
from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from common.langchain.outputs.base import OutputRegistry
from common.langchain.outputs.rag import RAGOutput
from common.utils.business_event_context import current_event
from .specialist_typing import SpecialistArguments, SpecialistResult
from ..chat_session_cache import get_chat_history
from ..retrievers.registry import RetrieverRegistry
from ..retrievers.base import BaseRetriever
from common.models.interaction import SpecialistRetriever, Specialist
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
from .base import BaseSpecialist
from .registry import SpecialistRegistry
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
class RAGSpecialist(BaseSpecialist):
"""
Standard Q&A RAG Specialist implementation that combines retriever results
with LLM processing to generate answers.
"""
def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
super().__init__(tenant_id, specialist_id, session_id)
# Check and load the specialist
specialist = Specialist.query.get_or_404(specialist_id)
# Set the specific configuration for the RAG Specialist
self.specialist_context = specialist.configuration.get('specialist_context', '')
self.temperature = specialist.configuration.get('temperature', 0.3)
self.tuning = specialist.tuning
# Initialize retrievers
self.retrievers = self._initialize_retrievers()
# Initialize model variables
self.model_variables = get_model_variables(tenant_id)
    @property
    def type(self) -> str:
        """Specialist type key; must match the registrations at the bottom of this module."""
        return "STANDARD_RAG"
def _initialize_retrievers(self) -> List[BaseRetriever]:
"""Initialize all retrievers associated with this specialist"""
retrievers = []
# Get retriever associations from database
specialist_retrievers = (
SpecialistRetriever.query
.filter_by(specialist_id=self.specialist_id)
.all()
)
self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})
for spec_retriever in specialist_retrievers:
# Get retriever configuration from database
retriever = spec_retriever.retriever
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
self._log_tuning("_initialize_retrievers", {
"Retriever id": spec_retriever.retriever_id,
"Retriever Type": retriever.type,
"Retriever Class": str(retriever_class),
})
# Initialize retriever with its configuration
retrievers.append(
retriever_class(
tenant_id=self.tenant_id,
retriever_id=retriever.id,
)
)
return retrievers
@property
def required_templates(self) -> List[str]:
"""List of required templates for this specialist"""
return ['rag', 'history']
def _detail_question(self, language: str, question: str) -> str:
"""Detail question based on conversation history"""
try:
# Get cached session history
cached_session = get_chat_history(self.session_id)
            # Format history for the prompt as alternating HUMAN/AI turns;
            # default to '' so missing keys don't render as the string "None"
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query', '')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer', '')}"
                for interaction in cached_session.interactions
            ])
# Get LLM and template
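            # A low, fixed temperature is used here for deterministic query
            # rewriting; answer generation uses the configured self.temperature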
llm = self.model_variables.get_llm(temperature=0.3)
template = self.model_variables.get_template('history')
language_template = create_language_template(template, language)
# Create prompt
history_prompt = ChatPromptTemplate.from_template(language_template)
# Create chain
chain = (
history_prompt |
llm |
StrOutputParser()
)
# Execute chain
detailed_question = chain.invoke({
"history": formatted_history,
"question": question
})
if self.tuning:
self._log_tuning("_detail_question", {
"cached_session_id": cached_session.session_id,
"cached_session.interactions": str(cached_session.interactions),
"original_question": question,
"history_used": formatted_history,
"detailed_question": detailed_question,
})
return detailed_question
except Exception as e:
current_app.logger.error(f"Error detailing question: {e}")
return question # Fallback to original question
def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
"""
Execute the RAG specialist to generate an answer
"""
try:
with current_event.create_span("Specialist Detail Question"):
# Get required arguments
language = arguments.language
query = arguments.query
detailed_question = self._detail_question(language, query)
        # Retrieval phase: collect context from every configured retriever
        with current_event.create_span("Specialist Retrieval"):
            self._log_tuning("Starting context retrieval", {
                "num_retrievers": len(self.retrievers),
                "all_arguments": arguments.model_dump(),
            })
# Get retriever-specific arguments
retriever_arguments = arguments.retriever_arguments
# Collect context from all retrievers
all_context = []
for retriever in self.retrievers:
# Get arguments for this specific retriever
retriever_id = str(retriever.retriever_id)
if retriever_id not in retriever_arguments:
current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
continue
                # Get the retriever's arguments and swap in the detailed question
                current_retriever_args = retriever_arguments[retriever_id]
                if isinstance(current_retriever_args, RetrieverArguments):
                    updated_args = current_retriever_args.model_dump()
                else:
                    # Copy so we don't mutate the caller's dictionary in place
                    updated_args = dict(current_retriever_args)
                updated_args['query'] = detailed_question
                retriever_args = RetrieverArguments(**updated_args)
# Each retriever gets its own specific arguments
retriever_result = retriever.retrieve(retriever_args)
all_context.extend(retriever_result)
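                # Each retrieved context is expected to expose .chunk, .similarity,
                # .id and .metadata.document_id, which the steps below rely on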
            # Sort by similarity (highest first), then keep the first occurrence of each chunk
            all_context.sort(key=lambda x: x.similarity, reverse=True)
unique_contexts = []
seen_chunks = set()
for ctx in all_context:
if ctx.chunk not in seen_chunks:
unique_contexts.append(ctx)
seen_chunks.add(ctx.chunk)
self._log_tuning("Context retrieval completed", {
"total_contexts": len(all_context),
"unique_contexts": len(unique_contexts),
"average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
unique_contexts) if unique_contexts else 0
})
# Prepare context for LLM
formatted_context = "\n\n".join([
f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
for ctx in unique_contexts
])
with current_event.create_span("Specialist RAG invocation"):
try:
# Get LLM with specified temperature
llm = self.model_variables.get_llm(temperature=self.temperature)
# Get template
template = self.model_variables.get_template('rag')
language_template = create_language_template(template, language)
full_template = replace_variable_in_template(
language_template,
"{tenant_context}",
self.specialist_context
)
if self.tuning:
self._log_tuning("Template preparation completed", {
"template": full_template,
"context": formatted_context,
"tenant_context": self.specialist_context,
})
# Create prompt
rag_prompt = ChatPromptTemplate.from_template(full_template)
                    # Feed the same input to both prompt variables: the pre-formatted
                    # context is injected as-is, the question passes through unchanged
                    setup_and_retrieval = RunnableParallel({
                        "context": lambda x: formatted_context,
                        "question": lambda x: x
                    })
# Get output schema for structured output
output_schema = OutputRegistry.get_schema(self.type)
structured_llm = llm.with_structured_output(output_schema)
chain = setup_and_retrieval | rag_prompt | structured_llm
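                    # The detailed question is the chain input: it fans out through the
                    # parallel map above, fills the prompt, and the structured LLM returns
                    # an instance of the registered output schema (RAGOutput for this type)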
raw_result = chain.invoke(detailed_question)
                    result = SpecialistResult.create_for_type(
                        "STANDARD_RAG",
                        detailed_query=detailed_question,
                        answer=raw_result.answer,
                        citations=[
                            ctx.metadata.document_id
                            for ctx in unique_contexts
                            if ctx.id in raw_result.citations
                        ],
                        insufficient_info=raw_result.insufficient_info
                    )
if self.tuning:
self._log_tuning("LLM chain execution completed", {
"Result": result.model_dump()
})
except Exception as e:
current_app.logger.error(f"Error in LLM processing: {e}")
if self.tuning:
self._log_tuning("LLM processing error", {"error": str(e)})
raise
return result
except Exception as e:
        current_app.logger.error(f"Error in RAG specialist execution: {e}")
raise
# Register the specialist type
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
OutputRegistry.register("STANDARD_RAG", RAGOutput)
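
# A minimal usage sketch (illustration only, not part of the module). How a chat
# worker might resolve and run this specialist via the registry; the registry
# lookup method name, the SpecialistArguments fields shown, and all concrete
# values are assumptions for illustration:
#
#   specialist_cls = SpecialistRegistry.get_specialist_class("STANDARD_RAG")
#   specialist = specialist_cls(tenant_id=1, specialist_id=7, session_id="abc123")
#   result = specialist.execute(SpecialistArguments(
#       language="en",
#       query="What does the warranty cover?",
#       retriever_arguments={"3": {"query": ""}},  # keyed by str(retriever_id), see execute()
#   ))
#   # result.answer, result.citations, result.insufficient_info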