- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of the chat client
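The dynamic Retrievers / Specialists / Processors introduced here all follow the same registry pattern: each implementation registers itself under a type string at import time, and callers resolve the class from the type stored in the database. A minimal sketch of that pattern (generic names; the real classes below are SpecialistRegistry and RetrieverRegistry):

# Minimal registry sketch: implementations self-register under a type string,
# and callers resolve the class from configuration at runtime.
from typing import Dict, Type

class Registry:
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, type_name: str, impl: Type) -> None:
        cls._registry[type_name] = impl

    @classmethod
    def get(cls, type_name: str) -> Type:
        if type_name not in cls._registry:
            raise ValueError(f"Unknown type: {type_name}")
        return cls._registry[type_name]

class MySpecialist:  # hypothetical implementation, for illustration only
    pass

Registry.register("MY_TYPE", MySpecialist)
specialist_class = Registry.get("MY_TYPE")  # resolved from a stored type string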
eveai_chat_workers/specialists/__init__.py (new file)
@@ -0,0 +1,5 @@
# Import all specialist implementations here to ensure registration
from . import rag_specialist

# List of all available specialist implementations
__all__ = ['rag_specialist']
eveai_chat_workers/specialists/base.py (new file)
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from flask import current_app

from config.logging_config import TuningLogger
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult


class BaseSpecialist(ABC):
    """Base class for all specialists"""

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        self.tenant_id = tenant_id
        self.specialist_id = specialist_id
        self.session_id = session_id
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type of the specialist"""
        pass

    def _setup_tuning_logger(self):
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                specialist_id=self.specialist_id,
            )
            # Verify the logger is working with a test message
            if self.tuning:
                self.tuning_logger.log_tuning('specialist', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def _log_tuning(self, message: str, data: Optional[Dict[str, Any]] = None) -> None:
        if self.tuning and self.tuning_logger:
            try:
                self.tuning_logger.log_tuning('specialist', message, data)
            except Exception as e:
                current_app.logger.error(f"Specialist: Error in tuning logging: {e}")

    @abstractmethod
    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """Execute the specialist's logic"""
        pass
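For orientation, a minimal sketch of what a concrete subclass has to supply, given only the abstract interface above (EchoSpecialist and its behaviour are hypothetical, not part of this commit):

# Hypothetical subclass, for illustration only: the two abstract members a
# concrete specialist must implement. Assumes an 'ECHO' entry exists in
# SPECIALIST_TYPES so that SpecialistResult validation passes.
class EchoSpecialist(BaseSpecialist):

    @property
    def type(self) -> str:
        return "ECHO"

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        # A real specialist would run retrieval and an LLM chain here.
        return SpecialistResult.create_for_type("ECHO", answer=arguments.query)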
eveai_chat_workers/specialists/rag_specialist.py (new file)
@@ -0,0 +1,289 @@
from datetime import datetime
from typing import Dict, Any, List

from flask import current_app
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from common.langchain.outputs.base import OutputRegistry
from common.langchain.outputs.rag import RAGOutput
from common.utils.business_event_context import current_event
from .specialist_typing import SpecialistArguments, SpecialistResult
from ..chat_session_cache import CachedInteraction, get_chat_history
from ..retrievers.registry import RetrieverRegistry
from ..retrievers.base import BaseRetriever
from common.models.interaction import SpecialistRetriever, Specialist
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
from .base import BaseSpecialist
from .registry import SpecialistRegistry
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult


class RAGSpecialist(BaseSpecialist):
    """
    Standard Q&A RAG Specialist implementation that combines retriever results
    with LLM processing to generate answers.
    """

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        super().__init__(tenant_id, specialist_id, session_id)

        # Check and load the specialist
        specialist = Specialist.query.get_or_404(specialist_id)
        # Set the specific configuration for the RAG Specialist
        self.specialist_context = specialist.configuration.get('specialist_context', '')
        self.temperature = specialist.configuration.get('temperature', 0.3)
        self.tuning = specialist.tuning

        # Initialize retrievers
        self.retrievers = self._initialize_retrievers()

        # Initialize model variables
        self.model_variables = get_model_variables(tenant_id)

    @property
    def type(self) -> str:
        return "STANDARD_RAG"

    def _initialize_retrievers(self) -> List[BaseRetriever]:
        """Initialize all retrievers associated with this specialist"""
        retrievers = []

        # Get retriever associations from the database
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=self.specialist_id)
            .all()
        )

        self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})

        for spec_retriever in specialist_retrievers:
            # Get the retriever configuration from the database
            retriever = spec_retriever.retriever
            retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
            self._log_tuning("_initialize_retrievers", {
                "Retriever id": spec_retriever.retriever_id,
                "Retriever Type": retriever.type,
                "Retriever Class": str(retriever_class),
            })

            # Initialize the retriever with its configuration
            retrievers.append(
                retriever_class(
                    tenant_id=self.tenant_id,
                    retriever_id=retriever.id,
                )
            )

        return retrievers

    @property
    def required_templates(self) -> List[str]:
        """List of required templates for this specialist"""
        return ['rag', 'history']

    # def _detail_question(question, language, model_variables, session_id):
    #     retriever = EveAIHistoryRetriever(model_variables=model_variables, session_id=session_id)
    #     llm = model_variables['llm']
    #     template = model_variables['history_template']
    #     language_template = create_language_template(template, language)
    #     full_template = replace_variable_in_template(language_template, "{tenant_context}",
    #                                                  model_variables['rag_context'])
    #     history_prompt = ChatPromptTemplate.from_template(full_template)
    #     setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
    #     output_parser = StrOutputParser()
    #
    #     chain = setup_and_retrieval | history_prompt | llm | output_parser
    #
    #     try:
    #         answer = chain.invoke(question)
    #         return answer
    #     except LangChainException as e:
    #         current_app.logger.error(f'Error detailing question: {e}')
    #         raise

    def _detail_question(self, language: str, question: str) -> str:
        """Detail the question based on conversation history"""
        try:
            # Get the cached session history
            cached_session = get_chat_history(self.session_id)

            # Format the history for the prompt
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer')}"
                for interaction in cached_session.interactions
            ])

            # Get the LLM and template
            llm = self.model_variables.get_llm(temperature=0.3)
            template = self.model_variables.get_template('history')
            language_template = create_language_template(template, language)

            # Create the prompt
            history_prompt = ChatPromptTemplate.from_template(language_template)

            # Create the chain
            chain = (
                history_prompt
                | llm
                | StrOutputParser()
            )

            # Execute the chain
            detailed_question = chain.invoke({
                "history": formatted_history,
                "question": question
            })

            if self.tuning:
                self._log_tuning("_detail_question", {
                    "cached_session_id": cached_session.session_id,
                    "cached_session.interactions": str(cached_session.interactions),
                    "original_question": question,
                    "history_used": formatted_history,
                    "detailed_question": detailed_question,
                })

            return detailed_question

        except Exception as e:
            current_app.logger.error(f"Error detailing question: {e}")
            return question  # Fall back to the original question

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """
        Execute the RAG specialist to generate an answer
        """
        start_time = datetime.now()

        try:
            with current_event.create_span("Specialist Detail Question"):
                # Get the required arguments
                language = arguments.language
                query = arguments.query
                detailed_question = self._detail_question(language, query)

            # Log the start of the retrieval process if tuning is enabled
            with current_event.create_span("Specialist Retrieval"):
                self._log_tuning("Starting context retrieval", {
                    "num_retrievers": len(self.retrievers),
                    "all arguments": arguments.model_dump(),
                })

                # Get retriever-specific arguments
                retriever_arguments = arguments.retriever_arguments

                # Collect context from all retrievers
                all_context = []
                for retriever in self.retrievers:
                    # Get the arguments for this specific retriever
                    retriever_id = str(retriever.retriever_id)
                    if retriever_id not in retriever_arguments:
                        current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
                        continue

                    # Get the retriever's arguments and update the query
                    current_retriever_args = retriever_arguments[retriever_id]
                    if isinstance(current_retriever_args, RetrieverArguments):
                        updated_args = current_retriever_args.model_dump()
                        updated_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**updated_args)
                    else:
                        # Create a new RetrieverArguments instance from the dictionary
                        current_retriever_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**current_retriever_args)

                    # Each retriever gets its own specific arguments
                    retriever_result = retriever.retrieve(retriever_args)
                    all_context.extend(retriever_result)

                # Sort by similarity if available and keep unique contexts
                all_context.sort(key=lambda x: x.similarity, reverse=True)
                unique_contexts = []
                seen_chunks = set()
                for ctx in all_context:
                    if ctx.chunk not in seen_chunks:
                        unique_contexts.append(ctx)
                        seen_chunks.add(ctx.chunk)

                self._log_tuning("Context retrieval completed", {
                    "total_contexts": len(all_context),
                    "unique_contexts": len(unique_contexts),
                    "average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(unique_contexts) if unique_contexts else 0
                })

                # Prepare the context for the LLM
                formatted_context = "\n\n".join([
                    f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
                    for ctx in unique_contexts
                ])

            with current_event.create_span("Specialist RAG invocation"):
                try:
                    # Get the LLM with the specified temperature
                    llm = self.model_variables.get_llm(temperature=self.temperature)

                    # Get the template
                    template = self.model_variables.get_template('rag')
                    language_template = create_language_template(template, language)
                    full_template = replace_variable_in_template(
                        language_template,
                        "{tenant_context}",
                        self.specialist_context
                    )

                    if self.tuning:
                        self._log_tuning("Template preparation completed", {
                            "template": full_template,
                            "context": formatted_context,
                            "tenant_context": self.specialist_context,
                        })

                    # Create the prompt
                    rag_prompt = ChatPromptTemplate.from_template(full_template)

                    # Set up the chain components
                    setup_and_retrieval = RunnableParallel({
                        "context": lambda x: formatted_context,
                        "question": lambda x: x
                    })

                    # Get the output schema for structured output
                    output_schema = OutputRegistry.get_schema(self.type)
                    structured_llm = llm.with_structured_output(output_schema)
                    chain = setup_and_retrieval | rag_prompt | structured_llm

                    raw_result = chain.invoke(query)
                    result = SpecialistResult.create_for_type(
                        "STANDARD_RAG",
                        detailed_query=detailed_question,
                        answer=raw_result.answer,
                        citations=[ctx.metadata.document_id for ctx in unique_contexts
                                   if ctx.id in raw_result.citations],
                        insufficient_info=raw_result.insufficient_info
                    )

                    if self.tuning:
                        self._log_tuning("LLM chain execution completed", {
                            "Result": result.model_dump()
                        })

                except Exception as e:
                    current_app.logger.error(f"Error in LLM processing: {e}")
                    if self.tuning:
                        self._log_tuning("LLM processing error", {"error": str(e)})
                    raise

            return result

        except Exception as e:
            current_app.logger.error(f'Error in RAG specialist execution: {str(e)}')
            raise


# Register the specialist type
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
OutputRegistry.register("STANDARD_RAG", RAGOutput)
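A hedged sketch of how a caller might drive this specialist end to end, assuming the registry below and a single retriever already associated with the specialist in the database (the IDs, the 'VECTOR' retriever type, and all argument values are illustrative, not taken from this commit):

# Illustrative call site, not part of this commit; IDs and values are made up.
specialist_class = SpecialistRegistry.get_specialist_class("STANDARD_RAG")
specialist = specialist_class(tenant_id=1, specialist_id=42, session_id="abc-123")

arguments = SpecialistArguments.create(
    type_name="STANDARD_RAG",
    specialist_args={"language": "en", "query": "What is our refund policy?"},
    retriever_args={"7": {"type": "VECTOR", "query": "What is our refund policy?"}},
)
result = specialist.execute(arguments)
print(result.answer, result.citations)

Note that execute() overwrites each retriever's query with the history-detailed question, so the query passed in retriever_args only serves as the initial value.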
eveai_chat_workers/specialists/registry.py (new file)
@@ -0,0 +1,21 @@
from typing import Dict, Type
from .base import BaseSpecialist


class SpecialistRegistry:
    """Registry for specialist types"""

    _registry: Dict[str, Type[BaseSpecialist]] = {}

    @classmethod
    def register(cls, specialist_type: str, specialist_class: Type[BaseSpecialist]):
        """Register a new specialist type"""
        cls._registry[specialist_type] = specialist_class

    @classmethod
    def get_specialist_class(cls, specialist_type: str) -> Type[BaseSpecialist]:
        """Get the specialist class for a given type"""
        if specialist_type not in cls._registry:
            raise ValueError(f"Unknown specialist type: {specialist_type}")
        return cls._registry[specialist_type]
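The registry is populated as an import side effect: specialists/__init__.py imports rag_specialist, and the module-level SpecialistRegistry.register(...) call at the bottom of that module runs at import time. A short sketch of the resulting contract:

# Looking up a type before its module has been imported raises ValueError;
# importing the package runs the module-level register() calls.
import eveai_chat_workers.specialists  # noqa: F401  (triggers registration)

from eveai_chat_workers.specialists.registry import SpecialistRegistry
rag_class = SpecialistRegistry.get_specialist_class("STANDARD_RAG")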
eveai_chat_workers/specialists/specialist_typing.py (new file)
@@ -0,0 +1,144 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field, model_validator
from config.specialist_types import SPECIALIST_TYPES
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments


class SpecialistArguments(BaseModel):
    """
    Dynamic arguments for specialists, allowing arbitrary fields but validating required ones
    based on the SPECIALIST_TYPES configuration.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")
    retriever_arguments: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for each retriever, keyed by retriever ID"
    )

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'SpecialistArguments':
        """Validate that all required arguments for this specialist type are present"""
        specialist_config = SPECIALIST_TYPES.get(self.type)
        if not specialist_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Check the required arguments from the configuration
        for arg_name, arg_config in specialist_config['arguments'].items():
            if arg_config.get('required', False):
                if not hasattr(self, arg_name):
                    raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

                # Type validation
                value = getattr(self, arg_name)
                expected_type = arg_config['type']
                if expected_type == 'str' and not isinstance(value, str):
                    raise ValueError(f"Argument '{arg_name}' must be a string")
                elif expected_type == 'int' and not isinstance(value, int):
                    raise ValueError(f"Argument '{arg_name}' must be an integer")

        return self

    @classmethod
    def create(cls, type_name: str, specialist_args: Dict[str, Any],
               retriever_args: Dict[str, Dict[str, Any]]) -> 'SpecialistArguments':
        """
        Factory method to create SpecialistArguments with validated retriever arguments

        Args:
            type_name: The specialist type (e.g., 'STANDARD_RAG')
            specialist_args: Arguments specific to the specialist
            retriever_args: Dictionary of retriever arguments keyed by retriever ID

        Returns:
            Validated SpecialistArguments instance
        """
        # Convert the raw retriever arguments to RetrieverArguments instances
        validated_retriever_args = {}
        for retriever_id, args in retriever_args.items():
            # Ensure the type is included in the retriever arguments
            if 'type' not in args:
                raise ValueError(f"Retriever arguments for {retriever_id} must include 'type'")

            validated_retriever_args[retriever_id] = RetrieverArguments(**args)

        # Combine everything into the specialist arguments
        return cls(
            type=type_name,
            **specialist_args,
            retriever_arguments=validated_retriever_args
        )


class SpecialistResult(BaseModel):
    """
    Dynamic results from specialists, validating required fields based on the
    SPECIALIST_TYPES configuration.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_results(self) -> 'SpecialistResult':
        """Validate that all required result fields for this specialist type are present"""
        specialist_config = SPECIALIST_TYPES.get(self.type)
        if not specialist_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Check the required results from the configuration
        required_results = specialist_config.get('results', {})
        for result_name, result_config in required_results.items():
            if result_config.get('required', False):
                if not hasattr(self, result_name):
                    raise ValueError(f"Missing required result '{result_name}' for {self.type}")

                # Type validation
                value = getattr(self, result_name)
                expected_type = result_config['type']

                # Validate based on the type annotation
                if expected_type == 'str' and not isinstance(value, str):
                    raise ValueError(f"Result '{result_name}' must be a string")
                elif expected_type == 'bool' and not isinstance(value, bool):
                    raise ValueError(f"Result '{result_name}' must be a boolean")
                elif expected_type == 'List[str]' and not (
                        isinstance(value, list) and all(isinstance(x, str) for x in value)):
                    raise ValueError(f"Result '{result_name}' must be a list of strings")
                # Add other type validations as needed

        return self

    @classmethod
    def create_for_type(cls, specialist_type: str, **results) -> 'SpecialistResult':
        """
        Factory method to create a type-specific result

        Args:
            specialist_type: The type of specialist (e.g., 'STANDARD_RAG')
            **results: The result values to include

        Returns:
            Validated SpecialistResult instance

        Example:
            For STANDARD_RAG:
                result = SpecialistResult.create_for_type(
                    'STANDARD_RAG',
                    answer="The answer text",
                    citations=["doc1", "doc2"],
                    insufficient_info=False
                )
        """
        # Add the type to the results
        results['type'] = specialist_type

        # Create and validate the result
        return cls(**results)
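The validators above assume that each SPECIALIST_TYPES entry declares its arguments and results as dicts carrying a 'type' string and a 'required' flag. A hedged sketch of what the STANDARD_RAG entry plausibly looks like, inferred from the fields RAGSpecialist reads and produces (config/specialist_types.py itself is not part of this commit):

# Assumed shape of config/specialist_types.py, inferred from the validators
# and from the fields used in RAGSpecialist.execute(); not authoritative.
SPECIALIST_TYPES = {
    "STANDARD_RAG": {
        "arguments": {
            "language": {"type": "str", "required": True},
            "query": {"type": "str", "required": True},
        },
        "results": {
            "detailed_query": {"type": "str", "required": True},
            "answer": {"type": "str", "required": True},
            "citations": {"type": "List[str]", "required": True},
            "insufficient_info": {"type": "bool", "required": True},
        },
    },
}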