- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors - Introduction of caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start adaptation of chat client
This commit is contained in:
289
eveai_chat_workers/specialists/rag_specialist.py
Normal file
289
eveai_chat_workers/specialists/rag_specialist.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from flask import current_app
|
||||
from langchain_core.exceptions import LangChainException
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||
|
||||
from common.langchain.outputs.base import OutputRegistry
|
||||
from common.langchain.outputs.rag import RAGOutput
|
||||
from common.utils.business_event_context import current_event
|
||||
from .specialist_typing import SpecialistArguments, SpecialistResult
|
||||
from ..chat_session_cache import CachedInteraction, get_chat_history
|
||||
from ..retrievers.registry import RetrieverRegistry
|
||||
from ..retrievers.base import BaseRetriever
|
||||
from common.models.interaction import SpecialistRetriever, Specialist
|
||||
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
|
||||
from .base import BaseSpecialist
|
||||
from .registry import SpecialistRegistry
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult
|
||||
|
||||
|
||||
class RAGSpecialist(BaseSpecialist):
|
||||
"""
|
||||
Standard Q&A RAG Specialist implementation that combines retriever results
|
||||
with LLM processing to generate answers.
|
||||
"""
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
|
||||
super().__init__(tenant_id, specialist_id, session_id)
|
||||
|
||||
# Check and load the specialist
|
||||
specialist = Specialist.query.get_or_404(specialist_id)
|
||||
# Set the specific configuration for the RAG Specialist
|
||||
self.specialist_context = specialist.configuration.get('specialist_context', '')
|
||||
self.temperature = specialist.configuration.get('temperature', 0.3)
|
||||
self.tuning = specialist.tuning
|
||||
|
||||
# Initialize retrievers
|
||||
self.retrievers = self._initialize_retrievers()
|
||||
|
||||
# Initialize model variables
|
||||
self.model_variables = get_model_variables(tenant_id)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "STANDARD_RAG"
|
||||
|
||||
def _initialize_retrievers(self) -> List[BaseRetriever]:
|
||||
"""Initialize all retrievers associated with this specialist"""
|
||||
retrievers = []
|
||||
|
||||
# Get retriever associations from database
|
||||
specialist_retrievers = (
|
||||
SpecialistRetriever.query
|
||||
.filter_by(specialist_id=self.specialist_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})
|
||||
|
||||
for spec_retriever in specialist_retrievers:
|
||||
# Get retriever configuration from database
|
||||
retriever = spec_retriever.retriever
|
||||
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
|
||||
self._log_tuning("_initialize_retrievers", {
|
||||
"Retriever id": spec_retriever.retriever_id,
|
||||
"Retriever Type": retriever.type,
|
||||
"Retriever Class": str(retriever_class),
|
||||
})
|
||||
|
||||
# Initialize retriever with its configuration
|
||||
retrievers.append(
|
||||
retriever_class(
|
||||
tenant_id=self.tenant_id,
|
||||
retriever_id=retriever.id,
|
||||
)
|
||||
)
|
||||
|
||||
return retrievers
|
||||
|
||||
@property
|
||||
def required_templates(self) -> List[str]:
|
||||
"""List of required templates for this specialist"""
|
||||
return ['rag', 'history']
|
||||
|
||||
# def _detail_question(question, language, model_variables, session_id):
|
||||
# retriever = EveAIHistoryRetriever(model_variables=model_variables, session_id=session_id)
|
||||
# llm = model_variables['llm']
|
||||
# template = model_variables['history_template']
|
||||
# language_template = create_language_template(template, language)
|
||||
# full_template = replace_variable_in_template(language_template, "{tenant_context}",
|
||||
# model_variables['rag_context'])
|
||||
# history_prompt = ChatPromptTemplate.from_template(full_template)
|
||||
# setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
|
||||
# output_parser = StrOutputParser()
|
||||
#
|
||||
# chain = setup_and_retrieval | history_prompt | llm | output_parser
|
||||
#
|
||||
# try:
|
||||
# answer = chain.invoke(question)
|
||||
# return answer
|
||||
# except LangChainException as e:
|
||||
# current_app.logger.error(f'Error detailing question: {e}')
|
||||
# raise
|
||||
|
||||
def _detail_question(self, language: str, question: str) -> str:
|
||||
"""Detail question based on conversation history"""
|
||||
try:
|
||||
# Get cached session history
|
||||
cached_session = get_chat_history(self.session_id)
|
||||
|
||||
# Format history for the prompt
|
||||
formatted_history = "\n\n".join([
|
||||
f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
|
||||
f"AI:\n{interaction.specialist_results.get('answer')}"
|
||||
for interaction in cached_session.interactions
|
||||
])
|
||||
|
||||
# Get LLM and template
|
||||
llm = self.model_variables.get_llm(temperature=0.3)
|
||||
template = self.model_variables.get_template('history')
|
||||
language_template = create_language_template(template, language)
|
||||
|
||||
# Create prompt
|
||||
history_prompt = ChatPromptTemplate.from_template(language_template)
|
||||
|
||||
# Create chain
|
||||
chain = (
|
||||
history_prompt |
|
||||
llm |
|
||||
StrOutputParser()
|
||||
)
|
||||
|
||||
# Execute chain
|
||||
detailed_question = chain.invoke({
|
||||
"history": formatted_history,
|
||||
"question": question
|
||||
})
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("_detail_question", {
|
||||
"cached_session_id": cached_session.session_id,
|
||||
"cached_session.interactions": str(cached_session.interactions),
|
||||
"original_question": question,
|
||||
"history_used": formatted_history,
|
||||
"detailed_question": detailed_question,
|
||||
})
|
||||
|
||||
return detailed_question
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error detailing question: {e}")
|
||||
return question # Fallback to original question
|
||||
|
||||
def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
|
||||
"""
|
||||
Execute the RAG specialist to generate an answer
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
with current_event.create_span("Specialist Detail Question"):
|
||||
# Get required arguments
|
||||
language = arguments.language
|
||||
query = arguments.query
|
||||
detailed_question = self._detail_question(language, query)
|
||||
|
||||
# Log the start of retrieval process if tuning is enabled
|
||||
with current_event.create_span("Specialist Retrieval"):
|
||||
self._log_tuning("Starting context retrieval", {
|
||||
"num_retrievers": len(self.retrievers),
|
||||
"all arguments": arguments.model_dump(),
|
||||
})
|
||||
|
||||
# Get retriever-specific arguments
|
||||
retriever_arguments = arguments.retriever_arguments
|
||||
|
||||
# Collect context from all retrievers
|
||||
all_context = []
|
||||
for retriever in self.retrievers:
|
||||
# Get arguments for this specific retriever
|
||||
retriever_id = str(retriever.retriever_id)
|
||||
if retriever_id not in retriever_arguments:
|
||||
current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
|
||||
continue
|
||||
|
||||
# Get the retriever's arguments and update the query
|
||||
current_retriever_args = retriever_arguments[retriever_id]
|
||||
if isinstance(retriever_arguments[retriever_id], RetrieverArguments):
|
||||
updated_args = current_retriever_args.model_dump()
|
||||
updated_args['query'] = detailed_question
|
||||
retriever_args = RetrieverArguments(**updated_args)
|
||||
else:
|
||||
# Create a new RetrieverArguments instance from the dictionary
|
||||
current_retriever_args['query'] = detailed_question
|
||||
retriever_args = RetrieverArguments(**current_retriever_args)
|
||||
|
||||
# Each retriever gets its own specific arguments
|
||||
retriever_result = retriever.retrieve(retriever_args)
|
||||
all_context.extend(retriever_result)
|
||||
|
||||
# Sort by similarity if available and get unique contexts
|
||||
all_context.sort(key=lambda x: x.similarity, reverse=True)
|
||||
unique_contexts = []
|
||||
seen_chunks = set()
|
||||
for ctx in all_context:
|
||||
if ctx.chunk not in seen_chunks:
|
||||
unique_contexts.append(ctx)
|
||||
seen_chunks.add(ctx.chunk)
|
||||
|
||||
self._log_tuning("Context retrieval completed", {
|
||||
"total_contexts": len(all_context),
|
||||
"unique_contexts": len(unique_contexts),
|
||||
"average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
|
||||
unique_contexts) if unique_contexts else 0
|
||||
})
|
||||
|
||||
# Prepare context for LLM
|
||||
formatted_context = "\n\n".join([
|
||||
f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
|
||||
for ctx in unique_contexts
|
||||
])
|
||||
|
||||
with current_event.create_span("Specialist RAG invocation"):
|
||||
try:
|
||||
# Get LLM with specified temperature
|
||||
llm = self.model_variables.get_llm(temperature=self.temperature)
|
||||
|
||||
# Get template
|
||||
template = self.model_variables.get_template('rag')
|
||||
language_template = create_language_template(template, language)
|
||||
full_template = replace_variable_in_template(
|
||||
language_template,
|
||||
"{tenant_context}",
|
||||
self.specialist_context
|
||||
)
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("Template preparation completed", {
|
||||
"template": full_template,
|
||||
"context": formatted_context,
|
||||
"tenant_context": self.specialist_context,
|
||||
})
|
||||
|
||||
# Create prompt
|
||||
rag_prompt = ChatPromptTemplate.from_template(full_template)
|
||||
|
||||
# Setup chain components
|
||||
setup_and_retrieval = RunnableParallel({
|
||||
"context": lambda x: formatted_context,
|
||||
"question": lambda x: x
|
||||
})
|
||||
|
||||
# Get output schema for structured output
|
||||
output_schema = OutputRegistry.get_schema(self.type)
|
||||
structured_llm = llm.with_structured_output(output_schema)
|
||||
chain = setup_and_retrieval | rag_prompt | structured_llm
|
||||
|
||||
raw_result = chain.invoke(query)
|
||||
result = SpecialistResult.create_for_type(
|
||||
"STANDARD_RAG",
|
||||
detailed_query=detailed_question,
|
||||
answer=raw_result.answer,
|
||||
citations=[ctx.metadata.document_id for ctx in unique_contexts
|
||||
if ctx.id in raw_result.citations],
|
||||
insufficient_info=raw_result.insufficient_info
|
||||
)
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("LLM chain execution completed", {
|
||||
"Result": result.model_dump()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error in LLM processing: {e}")
|
||||
if self.tuning:
|
||||
self._log_tuning("LLM processing error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'Error in RAG specialist execution: {str(e)}')
|
||||
raise
|
||||
|
||||
|
||||
# Register the specialist type
|
||||
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
|
||||
OutputRegistry.register("STANDARD_RAG", RAGOutput)
|
||||
Reference in New Issue
Block a user