- Introduction of dynamic Processors - Introduction of a caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start of the chat client adaptation
290 lines
12 KiB
Python
290 lines
12 KiB
Python
from datetime import datetime
|
|
from typing import Dict, Any, List
|
|
from flask import current_app
|
|
from langchain_core.exceptions import LangChainException
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
|
|
|
from common.langchain.outputs.base import OutputRegistry
|
|
from common.langchain.outputs.rag import RAGOutput
|
|
from common.utils.business_event_context import current_event
|
|
from .specialist_typing import SpecialistArguments, SpecialistResult
|
|
from ..chat_session_cache import CachedInteraction, get_chat_history
|
|
from ..retrievers.registry import RetrieverRegistry
|
|
from ..retrievers.base import BaseRetriever
|
|
from common.models.interaction import SpecialistRetriever, Specialist
|
|
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
|
|
from .base import BaseSpecialist
|
|
from .registry import SpecialistRegistry
|
|
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult
|
|
|
|
|
|
class RAGSpecialist(BaseSpecialist):
    """
    Standard Q&A RAG Specialist implementation that combines retriever results
    with LLM processing to generate answers.

    ``execute`` runs three stages, each inside its own ``current_event`` span:
      1. Detail the user's query into a standalone question, using the cached
         chat history of the session (best-effort, see ``_detail_question``).
      2. Run every retriever linked to this specialist with the detailed
         question and merge the results into unique chunks ordered by
         descending similarity.
      3. Ask the LLM, with structured output (the ``OutputRegistry`` schema
         registered for this specialist type), to answer the detailed
         question from the retrieved context.
    """

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        """
        Load the specialist configuration, its retrievers and the tenant's model variables.

        Args:
            tenant_id: Tenant that owns the specialist.
            specialist_id: Primary key of the ``Specialist`` row to load.
            session_id: Chat session whose cached history is used to detail questions.

        Raises:
            Whatever ``Specialist.query.get_or_404`` raises when the specialist
            does not exist (presumably werkzeug's ``NotFound`` - confirm).
        """
        super().__init__(tenant_id, specialist_id, session_id)

        # Check and load the specialist
        specialist = Specialist.query.get_or_404(specialist_id)

        # Set the specific configuration for the RAG Specialist. A missing (None)
        # configuration falls back to the defaults instead of raising AttributeError.
        configuration = specialist.configuration or {}
        self.specialist_context = configuration.get('specialist_context', '')
        self.temperature = configuration.get('temperature', 0.3)
        # Set before _initialize_retrievers(), which already emits tuning logs.
        self.tuning = specialist.tuning

        # Initialize retrievers
        self.retrievers = self._initialize_retrievers()

        # Initialize model variables (LLM factory and prompt templates)
        self.model_variables = get_model_variables(tenant_id)

    @property
    def type(self) -> str:
        """Registry key of this specialist (also used to look up its output schema)."""
        return "STANDARD_RAG"

    def _initialize_retrievers(self) -> List[BaseRetriever]:
        """
        Instantiate every retriever linked to this specialist.

        Reads the ``SpecialistRetriever`` associations for ``self.specialist_id``,
        resolves each retriever's class via ``RetrieverRegistry`` from its ``type``,
        and constructs it with the tenant id and the retriever id.

        Returns:
            The instantiated retrievers, in the order returned by the query.
        """
        retrievers = []

        # Get retriever associations from database
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=self.specialist_id)
            .all()
        )

        self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})

        for spec_retriever in specialist_retrievers:
            # Get retriever configuration from database
            retriever = spec_retriever.retriever
            retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
            self._log_tuning("_initialize_retrievers", {
                "Retriever id": spec_retriever.retriever_id,
                "Retriever Type": retriever.type,
                "Retriever Class": str(retriever_class),
            })

            # Initialize retriever with its configuration
            retrievers.append(
                retriever_class(
                    tenant_id=self.tenant_id,
                    retriever_id=retriever.id,
                )
            )

        return retrievers

    @property
    def required_templates(self) -> List[str]:
        """List of required templates for this specialist"""
        return ['rag', 'history']

    def _detail_question(self, language: str, question: str) -> str:
        """
        Detail question based on conversation history.

        Runs the 'history' template (localized to ``language``) through the LLM,
        feeding it the session's cached interactions as HUMAN/AI turns. The goal
        is presumably to make follow-up questions self-contained. This step is
        best-effort: on any error it logs and returns ``question`` unchanged.

        Args:
            language: Language passed to ``create_language_template``.
            question: The user's raw query.

        Returns:
            The detailed question produced by the LLM, or ``question`` on failure.
        """
        try:
            # Get cached session history
            cached_session = get_chat_history(self.session_id)

            # Format history for the prompt; earlier turns use their *detailed*
            # query as the human message.
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer')}"
                for interaction in cached_session.interactions
            ])

            # Get LLM and template. The temperature is fixed at 0.3 here,
            # independent of the configured self.temperature.
            llm = self.model_variables.get_llm(temperature=0.3)
            template = self.model_variables.get_template('history')
            language_template = create_language_template(template, language)

            # Create prompt and chain
            history_prompt = ChatPromptTemplate.from_template(language_template)
            chain = history_prompt | llm | StrOutputParser()

            # Execute chain
            detailed_question = chain.invoke({
                "history": formatted_history,
                "question": question
            })

            if self.tuning:
                self._log_tuning("_detail_question", {
                    "cached_session_id": cached_session.session_id,
                    "cached_session.interactions": str(cached_session.interactions),
                    "original_question": question,
                    "history_used": formatted_history,
                    "detailed_question": detailed_question,
                })

            return detailed_question

        except Exception as e:
            current_app.logger.error(f"Error detailing question: {e}")
            return question  # Fallback to original question

    @staticmethod
    def _with_query(retriever_args: Any, query: str) -> RetrieverArguments:
        """
        Return new ``RetrieverArguments`` equal to ``retriever_args`` with ``query`` replaced.

        ``retriever_args`` may be a ``RetrieverArguments`` instance or a plain dict.
        The input is never modified, so the caller's ``SpecialistArguments`` stay intact.
        """
        if isinstance(retriever_args, RetrieverArguments):
            updated_args = retriever_args.model_dump()
        else:
            # Copy rather than mutate the caller's dict.
            updated_args = dict(retriever_args)
        updated_args['query'] = query
        return RetrieverArguments(**updated_args)

    def _retrieve_unique_contexts(self, retriever_arguments: Dict[str, Any],
                                  detailed_question: str) -> List[RetrieverResult]:
        """
        Run every retriever with ``detailed_question`` and merge their results.

        Retrievers without an entry in ``retriever_arguments`` (keyed by the
        retriever id as a string) are skipped and an error is logged. The merged
        results are sorted by descending ``similarity`` and de-duplicated on the
        chunk text, keeping the most similar occurrence of each chunk.
        """
        all_context = []
        for retriever in self.retrievers:
            # Get arguments for this specific retriever
            retriever_id = str(retriever.retriever_id)
            if retriever_id not in retriever_arguments:
                current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
                continue

            # Each retriever gets its own arguments, with the query replaced by
            # the detailed question.
            retriever_args = self._with_query(retriever_arguments[retriever_id], detailed_question)
            all_context.extend(retriever.retrieve(retriever_args))

        # Sort by similarity and keep the first (most similar) copy of each chunk
        all_context.sort(key=lambda x: x.similarity, reverse=True)
        unique_contexts = []
        seen_chunks = set()
        for ctx in all_context:
            if ctx.chunk not in seen_chunks:
                unique_contexts.append(ctx)
                seen_chunks.add(ctx.chunk)

        self._log_tuning("Context retrieval completed", {
            "total_contexts": len(all_context),
            "unique_contexts": len(unique_contexts),
            "average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
                unique_contexts) if unique_contexts else 0
        })

        return unique_contexts

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """
        Execute the RAG specialist to generate an answer.

        Args:
            arguments: Specialist arguments. ``language``, ``query`` and
                ``retriever_arguments`` (keyed by retriever id as string) are used.
                They are not modified.

        Returns:
            A ``SpecialistResult`` for this specialist type. It holds the detailed
            query, the answer, the cited document ids and the insufficient-info flag.

        Raises:
            Re-raises any exception from retrieval or LLM processing after logging it.
        """
        try:
            with current_event.create_span("Specialist Detail Question"):
                # Get required arguments
                language = arguments.language
                query = arguments.query
                detailed_question = self._detail_question(language, query)

            with current_event.create_span("Specialist Retrieval"):
                self._log_tuning("Starting context retrieval", {
                    "num_retrievers": len(self.retrievers),
                    "all arguments": arguments.model_dump(),
                })

                unique_contexts = self._retrieve_unique_contexts(
                    arguments.retriever_arguments, detailed_question
                )

                # Prepare context for LLM, labelling each chunk with its document id
                formatted_context = "\n\n".join([
                    f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
                    for ctx in unique_contexts
                ])

            with current_event.create_span("Specialist RAG invocation"):
                try:
                    # Get LLM with specified temperature
                    llm = self.model_variables.get_llm(temperature=self.temperature)

                    # Get template and inject the specialist context
                    template = self.model_variables.get_template('rag')
                    language_template = create_language_template(template, language)
                    full_template = replace_variable_in_template(
                        language_template,
                        "{tenant_context}",
                        self.specialist_context
                    )

                    if self.tuning:
                        self._log_tuning("Template preparation completed", {
                            "template": full_template,
                            "context": formatted_context,
                            "tenant_context": self.specialist_context,
                        })

                    # Create prompt
                    rag_prompt = ChatPromptTemplate.from_template(full_template)

                    # The context is fixed; the chain input is passed through as the question.
                    setup_and_retrieval = RunnableParallel({
                        "context": lambda _: formatted_context,
                        "question": RunnablePassthrough(),
                    })

                    # Get output schema for structured output
                    output_schema = OutputRegistry.get_schema(self.type)
                    structured_llm = llm.with_structured_output(output_schema)
                    chain = setup_and_retrieval | rag_prompt | structured_llm

                    # Ask the detailed (history-aware) question. The context was
                    # retrieved for it, and this prompt receives no chat history,
                    # so the raw query may be ambiguous (e.g. a follow-up).
                    raw_result = chain.invoke(detailed_question)

                    # NOTE(review): the LLM sees chunks labelled by document id, but
                    # citations are matched against ctx.id - confirm that RAGOutput's
                    # citations refer to ctx.id.
                    result = SpecialistResult.create_for_type(
                        self.type,
                        detailed_query=detailed_question,
                        answer=raw_result.answer,
                        citations=[ctx.metadata.document_id for ctx in unique_contexts
                                   if ctx.id in raw_result.citations],
                        insufficient_info=raw_result.insufficient_info
                    )

                    if self.tuning:
                        self._log_tuning("LLM chain execution completed", {
                            "Result": result.model_dump()
                        })

                except Exception as e:
                    current_app.logger.error(f"Error in LLM processing: {e}")
                    if self.tuning:
                        self._log_tuning("LLM processing error", {"error": str(e)})
                    raise

            return result

        except Exception as e:
            current_app.logger.error(f'Error in RAG specialist execution: {str(e)}')
            raise
|
|
|
|
|
|
# Register the specialist type: expose RAGSpecialist under the "STANDARD_RAG" key
# in SpecialistRegistry, and its structured-output model RAGOutput under the same
# key in OutputRegistry (RAGSpecialist.execute looks it up via OutputRegistry.get_schema).
# These calls run at import time, so importing this module is what registers it.
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
OutputRegistry.register("STANDARD_RAG", RAGOutput)
|