from datetime import datetime
from typing import List

from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel

from common.langchain.outputs.base import OutputRegistry
from common.langchain.outputs.rag import RAGOutput
from common.utils.business_event_context import current_event
from .specialist_typing import SpecialistArguments, SpecialistResult
from ..chat_session_cache import get_chat_history
from ..retrievers.registry import RetrieverRegistry
from ..retrievers.base import BaseRetriever
from common.models.interaction import SpecialistRetriever, Specialist
from common.utils.model_utils import (
    get_model_variables,
    create_language_template,
    replace_variable_in_template,
)
from .base import BaseSpecialist
from .registry import SpecialistRegistry
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments


class RAGSpecialist(BaseSpecialist):
    """
    Standard Q&A RAG specialist.

    Combines the results of one or more retrievers with LLM processing to
    generate an answer grounded in the retrieved context.
    """

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        super().__init__(tenant_id, specialist_id, session_id)

        # Load the specialist or abort with a 404
        specialist = Specialist.query.get_or_404(specialist_id)

        # Set the RAG-specific configuration
        self.specialist_context = specialist.configuration.get('specialist_context', '')
        self.temperature = specialist.configuration.get('temperature', 0.3)
        self.tuning = specialist.tuning

        # Initialize retrievers
        self.retrievers = self._initialize_retrievers()

        # Initialize model variables
        self.model_variables = get_model_variables(tenant_id)

    @property
    def type(self) -> str:
        return "STANDARD_RAG"

    def _initialize_retrievers(self) -> List[BaseRetriever]:
        """Initialize all retrievers associated with this specialist."""
        retrievers = []

        # Get retriever associations from the database
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=self.specialist_id)
            .all()
        )
        self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})

        for spec_retriever in specialist_retrievers:
            # Get the retriever configuration from the database
            retriever = spec_retriever.retriever
            retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
            self._log_tuning("_initialize_retrievers", {
                "Retriever id": spec_retriever.retriever_id,
                "Retriever Type": retriever.type,
                "Retriever Class": str(retriever_class),
            })

            # Instantiate the retriever with its configuration
            retrievers.append(
                retriever_class(
                    tenant_id=self.tenant_id,
                    retriever_id=retriever.id,
                )
            )

        return retrievers

    @property
    def required_templates(self) -> List[str]:
        """Templates this specialist requires."""
        return ['rag', 'history']
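    # Illustrative sketch only: the actual 'history' template is tenant-specific
    # and loaded via get_template('history'). _detail_question below fills the
    # {history} and {question} variables, so the template is assumed to look
    # roughly like:
    #
    #     Given the following conversation history:
    #     {history}
    #
    #     Rewrite the follow-up question below as a standalone question:
    #     {question}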
    def _detail_question(self, language: str, question: str) -> str:
        """Rewrite the question as a standalone query using the conversation history."""
        try:
            # Get the cached session history
            cached_session = get_chat_history(self.session_id)

            # Format the history for the prompt
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer')}"
                for interaction in cached_session.interactions
            ])

            # Get the LLM and template
            llm = self.model_variables.get_llm(temperature=0.3)
            template = self.model_variables.get_template('history')
            language_template = create_language_template(template, language)

            # Create the prompt and chain
            history_prompt = ChatPromptTemplate.from_template(language_template)
            chain = history_prompt | llm | StrOutputParser()

            # Execute the chain
            detailed_question = chain.invoke({
                "history": formatted_history,
                "question": question,
            })

            if self.tuning:
                self._log_tuning("_detail_question", {
                    "cached_session_id": cached_session.session_id,
                    "cached_session.interactions": str(cached_session.interactions),
                    "original_question": question,
                    "history_used": formatted_history,
                    "detailed_question": detailed_question,
                })

            return detailed_question

        except Exception as e:
            current_app.logger.error(f"Error detailing question: {e}")
            return question  # Fall back to the original question

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """Execute the RAG specialist to generate an answer."""
        start_time = datetime.now()

        try:
            with current_event.create_span("Specialist Detail Question"):
                # Get the required arguments
                language = arguments.language
                query = arguments.query
                detailed_question = self._detail_question(language, query)

            with current_event.create_span("Specialist Retrieval"):
                # Log the start of the retrieval process if tuning is enabled
                self._log_tuning("Starting context retrieval", {
                    "num_retrievers": len(self.retrievers),
                    "all arguments": arguments.model_dump(),
                })

                # Get retriever-specific arguments
                retriever_arguments = arguments.retriever_arguments

                # Collect context from all retrievers
                all_context = []
                for retriever in self.retrievers:
                    # Each retriever gets its own specific arguments
                    retriever_id = str(retriever.retriever_id)
                    if retriever_id not in retriever_arguments:
                        current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
                        continue

                    # Substitute the detailed question for the original query
                    current_retriever_args = retriever_arguments[retriever_id]
                    if isinstance(current_retriever_args, RetrieverArguments):
                        updated_args = current_retriever_args.model_dump()
                        updated_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**updated_args)
                    else:
                        # Create a new RetrieverArguments instance from the dictionary
                        current_retriever_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**current_retriever_args)

                    retriever_result = retriever.retrieve(retriever_args)
                    all_context.extend(retriever_result)

                # Sort by similarity (highest first) and deduplicate on chunk text
                all_context.sort(key=lambda x: x.similarity, reverse=True)
                unique_contexts = []
                seen_chunks = set()
                for ctx in all_context:
                    if ctx.chunk not in seen_chunks:
                        unique_contexts.append(ctx)
                        seen_chunks.add(ctx.chunk)
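                # Note (assumption): each context item is expected to expose at
                # least .chunk (text), .similarity (float), .id, and
                # .metadata.document_id -- the attributes this method relies on
                # below when formatting context and resolving citations.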
completed", { "total_contexts": len(all_context), "unique_contexts": len(unique_contexts), "average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len( unique_contexts) if unique_contexts else 0 }) # Prepare context for LLM formatted_context = "\n\n".join([ f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}" for ctx in unique_contexts ]) with current_event.create_span("Specialist RAG invocation"): try: # Get LLM with specified temperature llm = self.model_variables.get_llm(temperature=self.temperature) # Get template template = self.model_variables.get_template('rag') language_template = create_language_template(template, language) full_template = replace_variable_in_template( language_template, "{tenant_context}", self.specialist_context ) if self.tuning: self._log_tuning("Template preparation completed", { "template": full_template, "context": formatted_context, "tenant_context": self.specialist_context, }) # Create prompt rag_prompt = ChatPromptTemplate.from_template(full_template) # Setup chain components setup_and_retrieval = RunnableParallel({ "context": lambda x: formatted_context, "question": lambda x: x }) # Get output schema for structured output output_schema = OutputRegistry.get_schema(self.type) structured_llm = llm.with_structured_output(output_schema) chain = setup_and_retrieval | rag_prompt | structured_llm raw_result = chain.invoke(query) result = SpecialistResult.create_for_type( "STANDARD_RAG", detailed_query=detailed_question, answer=raw_result.answer, citations=[ctx.metadata.document_id for ctx in unique_contexts if ctx.id in raw_result.citations], insufficient_info=raw_result.insufficient_info ) if self.tuning: self._log_tuning("LLM chain execution completed", { "Result": result.model_dump() }) except Exception as e: current_app.logger.error(f"Error in LLM processing: {e}") if self.tuning: self._log_tuning("LLM processing error", {"error": str(e)}) raise return result except Exception as e: current_app.logger.error(f'Error in RAG specialist execution: {str(e)}') raise # Register the specialist type SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist) OutputRegistry.register("STANDARD_RAG", RAGOutput)