- Introduction of dynamic Retrievers & Specialists

- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
This commit is contained in:
Josako
2024-11-15 10:00:53 +01:00
parent 55a8a95f79
commit 1807435339
101 changed files with 4181 additions and 1764 deletions

View File

@@ -0,0 +1,5 @@
# Import all specialist implementations here to ensure registration
# NOTE(review): presumably importing the module is what triggers its module-level
# register(...) calls - confirm against rag_specialist.
from . import rag_specialist
# List of all available specialist implementations
__all__ = ['rag_specialist']

View File

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from flask import current_app

from config.logging_config import TuningLogger
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
class BaseSpecialist(ABC):
    """Base class for all specialists.

    Stores the tenant, specialist and session identifiers, creates a
    ``TuningLogger`` and provides ``_log_tuning`` for optional tuning output.
    Concrete specialists must implement ``type`` and ``execute``.
    """

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        """Store the identifiers and set up the tuning logger.

        Args:
            tenant_id: Identifier of the tenant the specialist runs for.
            specialist_id: Identifier of the specialist.
            session_id: Identifier of the chat session.

        Raises:
            Exception: Whatever ``TuningLogger`` creation raises is logged and
                re-raised by ``_setup_tuning_logger``.
        """
        self.tenant_id = tenant_id
        self.specialist_id = specialist_id
        self.session_id = session_id
        # Tuning is off by default; subclasses may switch it on after calling
        # super().__init__().
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type of the specialist"""
        pass

    def _setup_tuning_logger(self):
        """Create the tuning logger for this tenant/specialist.

        Raises:
            Exception: Any error raised while creating the logger is logged to
                the Flask application logger and re-raised.
        """
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                specialist_id=self.specialist_id,
            )
            # Verify logger is working with a test message.
            # NOTE(review): __init__ sets self.tuning to False right before calling
            # this method, so this branch does not run when invoked from __init__.
            if self.tuning:
                self.tuning_logger.log_tuning('specialist', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def _log_tuning(self, message: str, data: Optional[Dict[str, Any]] = None) -> None:
        """Write a tuning message when tuning is enabled.

        Logging is best-effort: a failure inside the tuning logger is reported
        to the Flask application logger and not propagated.

        Args:
            message: Text of the tuning entry.
            data: Optional structured data passed along with the message.
        """
        if self.tuning and self.tuning_logger:
            try:
                self.tuning_logger.log_tuning('specialist', message, data)
            except Exception as e:
                current_app.logger.error(f"Specialist: Error in tuning logging: {e}")

    @abstractmethod
    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """Execute the specialist's logic"""
        pass

View File

@@ -0,0 +1,289 @@
from datetime import datetime
from typing import Dict, Any, List
from flask import current_app
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from common.langchain.outputs.base import OutputRegistry
from common.langchain.outputs.rag import RAGOutput
from common.utils.business_event_context import current_event
from .specialist_typing import SpecialistArguments, SpecialistResult
from ..chat_session_cache import CachedInteraction, get_chat_history
from ..retrievers.registry import RetrieverRegistry
from ..retrievers.base import BaseRetriever
from common.models.interaction import SpecialistRetriever, Specialist
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
from .base import BaseSpecialist
from .registry import SpecialistRegistry
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult
class RAGSpecialist(BaseSpecialist):
    """
    Standard Q&A RAG Specialist implementation that combines retriever results
    with LLM processing to generate answers.
    """

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        """Load the specialist record and prepare retrievers and model variables.

        Args:
            tenant_id: Identifier of the tenant the specialist runs for.
            specialist_id: Identifier of the specialist record to load.
            session_id: Identifier of the chat session.
        """
        super().__init__(tenant_id, specialist_id, session_id)

        # Check and load the specialist
        specialist = Specialist.query.get_or_404(specialist_id)

        # Set the specific configuration for the RAG Specialist
        self.specialist_context = specialist.configuration.get('specialist_context', '')
        self.temperature = specialist.configuration.get('temperature', 0.3)
        # Must be set before _initialize_retrievers so its tuning output is emitted
        self.tuning = specialist.tuning

        # Initialize retrievers
        self.retrievers = self._initialize_retrievers()

        # Initialize model variables
        self.model_variables = get_model_variables(tenant_id)

    @property
    def type(self) -> str:
        """Type key of this specialist."""
        return "STANDARD_RAG"

    def _initialize_retrievers(self) -> List[BaseRetriever]:
        """Initialize all retrievers associated with this specialist"""
        retrievers = []

        # Get retriever associations from database
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=self.specialist_id)
            .all()
        )
        self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})

        for spec_retriever in specialist_retrievers:
            # Get retriever configuration from database
            retriever = spec_retriever.retriever
            retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
            self._log_tuning("_initialize_retrievers", {
                "Retriever id": spec_retriever.retriever_id,
                "Retriever Type": retriever.type,
                "Retriever Class": str(retriever_class),
            })

            # Initialize retriever with its configuration
            retrievers.append(
                retriever_class(
                    tenant_id=self.tenant_id,
                    retriever_id=retriever.id,
                )
            )

        return retrievers

    @property
    def required_templates(self) -> List[str]:
        """List of required templates for this specialist"""
        return ['rag', 'history']

    def _detail_question(self, language: str, question: str) -> str:
        """Detail question based on conversation history.

        Args:
            language: Language passed to ``create_language_template``.
            question: The question as asked by the user.

        Returns:
            The question rewritten by the LLM using the cached session history,
            or the original question when anything in this step fails.
        """
        try:
            # Get cached session history
            cached_session = get_chat_history(self.session_id)

            # Format history for the prompt
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer')}"
                for interaction in cached_session.interactions
            ])

            # Get LLM and template
            llm = self.model_variables.get_llm(temperature=0.3)
            template = self.model_variables.get_template('history')
            language_template = create_language_template(template, language)

            # Create prompt
            history_prompt = ChatPromptTemplate.from_template(language_template)

            # Create chain
            chain = (
                history_prompt |
                llm |
                StrOutputParser()
            )

            # Execute chain
            detailed_question = chain.invoke({
                "history": formatted_history,
                "question": question
            })

            if self.tuning:
                self._log_tuning("_detail_question", {
                    "cached_session_id": cached_session.session_id,
                    "cached_session.interactions": str(cached_session.interactions),
                    "original_question": question,
                    "history_used": formatted_history,
                    "detailed_question": detailed_question,
                })

            return detailed_question

        except Exception as e:
            current_app.logger.error(f"Error detailing question: {e}")
            return question  # Fallback to original question

    @staticmethod
    def _build_retriever_arguments(raw_args, detailed_question: str) -> RetrieverArguments:
        """Return new RetrieverArguments whose query is the detailed question."""
        if isinstance(raw_args, RetrieverArguments):
            updated_args = raw_args.model_dump()
        else:
            # Work on a copy so the caller's dictionary is left untouched
            updated_args = dict(raw_args)
        updated_args['query'] = detailed_question
        return RetrieverArguments(**updated_args)

    @staticmethod
    def _unique_contexts(contexts: list) -> list:
        """Return the contexts in order, keeping only the first one per chunk text."""
        unique_contexts = []
        seen_chunks = set()
        for ctx in contexts:
            if ctx.chunk not in seen_chunks:
                unique_contexts.append(ctx)
                seen_chunks.add(ctx.chunk)
        return unique_contexts

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """
        Execute the RAG specialist to generate an answer.

        The question is first detailed using the chat history, then every
        configured retriever is queried with the detailed question, and finally
        the LLM is invoked with the collected context.

        Args:
            arguments: Specialist arguments; ``language``, ``query`` and
                ``retriever_arguments`` are read from it.

        Returns:
            A ``STANDARD_RAG`` SpecialistResult with detailed query, answer,
            citations and the insufficient-info flag.

        Raises:
            Exception: Any error during retrieval or LLM processing is logged
                and re-raised.
        """
        try:
            with current_event.create_span("Specialist Detail Question"):
                # Get required arguments
                language = arguments.language
                query = arguments.query
                detailed_question = self._detail_question(language, query)

            # Log the start of retrieval process if tuning is enabled
            with current_event.create_span("Specialist Retrieval"):
                self._log_tuning("Starting context retrieval", {
                    "num_retrievers": len(self.retrievers),
                    "all arguments": arguments.model_dump(),
                })

                # Get retriever-specific arguments
                retriever_arguments = arguments.retriever_arguments

                # Collect context from all retrievers
                all_context = []
                for retriever in self.retrievers:
                    # Get arguments for this specific retriever
                    retriever_id = str(retriever.retriever_id)
                    if retriever_id not in retriever_arguments:
                        current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
                        continue

                    # Each retriever gets its own specific arguments, with the
                    # query replaced by the detailed question
                    retriever_args = self._build_retriever_arguments(
                        retriever_arguments[retriever_id],
                        detailed_question,
                    )
                    retriever_result = retriever.retrieve(retriever_args)
                    all_context.extend(retriever_result)

                # Sort by similarity and get unique contexts
                all_context.sort(key=lambda x: x.similarity, reverse=True)
                unique_contexts = self._unique_contexts(all_context)

                self._log_tuning("Context retrieval completed", {
                    "total_contexts": len(all_context),
                    "unique_contexts": len(unique_contexts),
                    "average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
                        unique_contexts) if unique_contexts else 0
                })

                # Prepare context for LLM
                formatted_context = "\n\n".join([
                    f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
                    for ctx in unique_contexts
                ])

            with current_event.create_span("Specialist RAG invocation"):
                try:
                    # Get LLM with specified temperature
                    llm = self.model_variables.get_llm(temperature=self.temperature)

                    # Get template
                    template = self.model_variables.get_template('rag')
                    language_template = create_language_template(template, language)
                    full_template = replace_variable_in_template(
                        language_template,
                        "{tenant_context}",
                        self.specialist_context
                    )

                    if self.tuning:
                        self._log_tuning("Template preparation completed", {
                            "template": full_template,
                            "context": formatted_context,
                            "tenant_context": self.specialist_context,
                        })

                    # Create prompt
                    rag_prompt = ChatPromptTemplate.from_template(full_template)

                    # Setup chain components
                    setup_and_retrieval = RunnableParallel({
                        "context": lambda x: formatted_context,
                        "question": lambda x: x
                    })

                    # Get output schema for structured output
                    output_schema = OutputRegistry.get_schema(self.type)
                    structured_llm = llm.with_structured_output(output_schema)

                    chain = setup_and_retrieval | rag_prompt | structured_llm

                    # NOTE(review): the chain receives the original query, while
                    # retrieval used the detailed question - confirm this is intended.
                    raw_result = chain.invoke(query)

                    # NOTE(review): citations are matched on ctx.id, whereas the prompt
                    # exposes ctx.metadata.document_id as SOURCE - verify these agree.
                    result = SpecialistResult.create_for_type(
                        "STANDARD_RAG",
                        detailed_query=detailed_question,
                        answer=raw_result.answer,
                        citations=[ctx.metadata.document_id for ctx in unique_contexts
                                   if ctx.id in raw_result.citations],
                        insufficient_info=raw_result.insufficient_info
                    )

                    if self.tuning:
                        self._log_tuning("LLM chain execution completed", {
                            "Result": result.model_dump()
                        })

                except Exception as e:
                    current_app.logger.error(f"Error in LLM processing: {e}")
                    if self.tuning:
                        self._log_tuning("LLM processing error", {"error": str(e)})
                    raise

            return result

        except Exception as e:
            current_app.logger.error(f'Error in RAG specialist execution: {str(e)}')
            raise
# Register the specialist type
# Both registries are keyed with "STANDARD_RAG", the value RAGSpecialist.type returns;
# execute() looks up its output schema with OutputRegistry.get_schema(self.type).
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
OutputRegistry.register("STANDARD_RAG", RAGOutput)

View File

@@ -0,0 +1,21 @@
from typing import Dict, Type
from .base import BaseSpecialist
class SpecialistRegistry:
    """Class-level table mapping specialist type names to specialist classes."""

    # Shared by all users of the registry; filled through register()
    _registry: Dict[str, Type[BaseSpecialist]] = {}

    @classmethod
    def register(cls, specialist_type: str, specialist_class: Type[BaseSpecialist]):
        """Store ``specialist_class`` under ``specialist_type``, replacing any earlier entry."""
        cls._registry.update({specialist_type: specialist_class})

    @classmethod
    def get_specialist_class(cls, specialist_type: str) -> Type[BaseSpecialist]:
        """Return the class registered for ``specialist_type``.

        Raises:
            ValueError: If no class was registered under that type name.
        """
        known_types = cls._registry
        if specialist_type in known_types:
            return known_types[specialist_type]
        raise ValueError(f"Unknown specialist type: {specialist_type}")

View File

@@ -0,0 +1,144 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field, model_validator
from config.specialist_types import SPECIALIST_TYPES
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
class SpecialistArguments(BaseModel):
    """
    Dynamic arguments for specialists, allowing arbitrary fields but validating required ones
    based on SPECIALIST_TYPES configuration.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")
    retriever_arguments: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for each retriever, keyed by retriever ID"
    )

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'SpecialistArguments':
        """Validate that all required arguments for this specialist type are present.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: If the type is not in SPECIALIST_TYPES, a required
                argument is missing, or a required 'str'/'int' argument has
                the wrong Python type.
        """
        specialist_config = SPECIALIST_TYPES.get(self.type)
        if not specialist_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Check required arguments from configuration; a configuration without an
        # 'arguments' section simply has nothing to check
        for arg_name, arg_config in specialist_config.get('arguments', {}).items():
            if not arg_config.get('required', False):
                continue

            if not hasattr(self, arg_name):
                raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

            # Type validation
            value = getattr(self, arg_name)
            expected_type = arg_config['type']
            if expected_type == 'str' and not isinstance(value, str):
                raise ValueError(f"Argument '{arg_name}' must be a string")
            # bool is a subclass of int, so it has to be excluded explicitly
            elif expected_type == 'int' and (isinstance(value, bool) or not isinstance(value, int)):
                raise ValueError(f"Argument '{arg_name}' must be an integer")

        return self

    @classmethod
    def create(cls, type_name: str, specialist_args: Dict[str, Any],
               retriever_args: Dict[str, Dict[str, Any]]) -> 'SpecialistArguments':
        """
        Factory method to create SpecialistArguments with validated retriever arguments

        Args:
            type_name: The specialist type (e.g., 'STANDARD_RAG')
            specialist_args: Arguments specific to the specialist
            retriever_args: Dictionary of retriever arguments keyed by retriever ID

        Returns:
            Validated SpecialistArguments instance

        Raises:
            ValueError: If the arguments of a retriever do not include 'type'.
        """
        # Convert raw retriever arguments to RetrieverArguments instances
        validated_retriever_args = {}
        for retriever_id, args in retriever_args.items():
            # Ensure type is included in retriever arguments
            if 'type' not in args:
                raise ValueError(f"Retriever arguments for {retriever_id} must include 'type'")
            validated_retriever_args[retriever_id] = RetrieverArguments(**args)

        # Combine everything into the specialist arguments
        return cls(
            type=type_name,
            **specialist_args,
            retriever_arguments=validated_retriever_args
        )
class SpecialistResult(BaseModel):
    """
    Result container for specialists. Arbitrary extra fields are accepted; the
    fields marked as required in the SPECIALIST_TYPES configuration of the
    given type are checked after model creation.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")

    # Extra fields carry the type-specific results
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_results(self) -> 'SpecialistResult':
        """Check presence and basic type of every required result field.

        Raises:
            ValueError: If the type is unknown, a required result is missing,
                or a required 'str' / 'bool' / 'List[str]' result has the
                wrong Python type.
        """
        type_config = SPECIALIST_TYPES.get(self.type)
        if not type_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Walk over the configured results; only required ones are checked
        for result_name, result_config in type_config.get('results', {}).items():
            if not result_config.get('required', False):
                continue

            if not hasattr(self, result_name):
                raise ValueError(f"Missing required result '{result_name}' for {self.type}")

            actual = getattr(self, result_name)
            declared = result_config['type']

            # Only the declared types below are verified; others pass unchecked
            if declared == 'str':
                if not isinstance(actual, str):
                    raise ValueError(f"Result '{result_name}' must be a string")
            elif declared == 'bool':
                if not isinstance(actual, bool):
                    raise ValueError(f"Result '{result_name}' must be a boolean")
            elif declared == 'List[str]':
                is_string_list = isinstance(actual, list) and all(isinstance(x, str) for x in actual)
                if not is_string_list:
                    raise ValueError(f"Result '{result_name}' must be a list of strings")

        return self

    @classmethod
    def create_for_type(cls, specialist_type: str, **results) -> 'SpecialistResult':
        """
        Build a validated result for the given specialist type.

        Args:
            specialist_type: The type of specialist (e.g., 'STANDARD_RAG')
            **results: The result values to include

        Returns:
            Validated SpecialistResult instance

        Example:
            For STANDARD_RAG:
            result = SpecialistResult.create_for_type(
                'STANDARD_RAG',
                answer="The answer text",
                citations=["doc1", "doc2"],
                insufficient_info=False
            )
        """
        # The explicit specialist_type wins over any 'type' key in results
        payload = {**results, 'type': specialist_type}
        return cls(**payload)