- Implementation of specialist execution api, including SSE protocol
- eveai_chat becomes deprecated and should be replaced with SSE - Adaptation of STANDARD_RAG specialist - Base class definition allowing to realise specialists with crewai framework - Implementation of SPIN_SPECIALIST - Implementation of test app for testing specialists (test_specialist_client). Also serves as an example for future SSE-based client - Improvements to startup scripts to better handle and scale multiple connections - Small improvements to the interaction forms and views - Caching implementation improved and augmented with additional caches
This commit is contained in:
296
eveai_chat_workers/specialists/SPIN_SPECIALIST/1_0.py
Normal file
296
eveai_chat_workers/specialists/SPIN_SPECIALIST/1_0.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import json
|
||||
from os import wait
|
||||
from typing import Optional, List
|
||||
|
||||
from crewai.flow.flow import start, listen, and_
|
||||
from flask import current_app
|
||||
from gevent import sleep
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from common.extensions import cache_manager
|
||||
from common.models.user import Tenant
|
||||
from common.utils.business_event_context import current_event
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
|
||||
from eveai_chat_workers.specialists.crewai_base_specialist import CrewAIBaseSpecialistExecutor
|
||||
from eveai_chat_workers.specialists.specialist_typing import SpecialistResult, SpecialistArguments
|
||||
from eveai_chat_workers.outputs.identification.identification_v1_0 import LeadInfoOutput
|
||||
from eveai_chat_workers.outputs.spin.spin_v1_0 import SPINOutput
|
||||
from eveai_chat_workers.outputs.rag.rag_v1_0 import RAGOutput
|
||||
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAICrew, EveAICrewAIFlow, EveAIFlowState
|
||||
from common.utils.pydantic_utils import flatten_pydantic_model
|
||||
|
||||
|
||||
class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
|
||||
"""
|
||||
type: SPIN_SPECIALIST
|
||||
type_version: 1.0
|
||||
SPIN Specialist Executor class
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id, specialist_id, session_id, task_id, **kwargs):
|
||||
self.rag_crew = None
|
||||
self.spin_crew = None
|
||||
self.identification_crew = None
|
||||
self.rag_consolidation_crew = None
|
||||
|
||||
super().__init__(tenant_id, specialist_id, session_id, task_id)
|
||||
|
||||
# Load the Tenant & set language
|
||||
self.tenant = Tenant.query.get_or_404(tenant_id)
|
||||
if self.specialist.configuration['tenant_language'] is None:
|
||||
self.specialist.configuration['tenant_language'] = self.tenant.language
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "SPIN_SPECIALIST"
|
||||
|
||||
@property
|
||||
def type_version(self) -> str:
|
||||
return "1.0"
|
||||
|
||||
def _config_task_agents(self):
|
||||
self._add_task_agent("rag_task", "rag_agent")
|
||||
self._add_task_agent("spin_detect_task", "spin_detection_agent")
|
||||
self._add_task_agent("spin_questions_task", "spin_sales_specialist_agent")
|
||||
self._add_task_agent("identification_detection_task", "identification_agent")
|
||||
self._add_task_agent("identification_questions_task", "identification_agent")
|
||||
self._add_task_agent("email_lead_drafting_task", "email_content_agent")
|
||||
self._add_task_agent("email_lead_engagement_task", "email_engagement_agent")
|
||||
self._add_task_agent("email_lead_retrieval_task", "email_engagement_agent")
|
||||
self._add_task_agent("rag_consolidation_task", "rag_communication_agent")
|
||||
|
||||
def _config_pydantic_outputs(self):
|
||||
self._add_pydantic_output("rag_task", RAGOutput, "rag_output")
|
||||
self._add_pydantic_output("spin_questions_task", SPINOutput, "spin_questions")
|
||||
self._add_pydantic_output("identification_questions_task", LeadInfoOutput, "lead_identification_questions")
|
||||
self._add_pydantic_output("rag_consolidation_task", RAGOutput, "rag_output")
|
||||
|
||||
def _instantiate_specialist(self):
|
||||
verbose = self.tuning
|
||||
|
||||
rag_agents = [self.rag_agent]
|
||||
rag_tasks = [self.rag_task]
|
||||
self.rag_crew = EveAICrewAICrew(
|
||||
self,
|
||||
"Rag Crew",
|
||||
agents=rag_agents,
|
||||
tasks=rag_tasks,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
spin_agents = [self.spin_detection_agent, self.spin_sales_specialist_agent]
|
||||
spin_tasks = [self.spin_detect_task, self.spin_questions_task]
|
||||
self.spin_crew = EveAICrewAICrew(
|
||||
self,
|
||||
"SPIN Crew",
|
||||
agents=spin_agents,
|
||||
tasks=spin_tasks,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
identification_agents = [self.identification_agent]
|
||||
identification_tasks = [self.identification_detection_task, self.identification_questions_task]
|
||||
self.identification_crew = EveAICrewAICrew(
|
||||
self,
|
||||
"Identification Crew",
|
||||
agents=identification_agents,
|
||||
tasks=identification_tasks,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
consolidation_agents = [self.rag_communication_agent]
|
||||
consolidation_tasks = [self.rag_consolidation_task]
|
||||
self.rag_consolidation_crew = EveAICrewAICrew(
|
||||
self,
|
||||
"Rag Consolidation Crew",
|
||||
agents=consolidation_agents,
|
||||
tasks=consolidation_tasks,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
self.flow = SPINFlow(
|
||||
self,
|
||||
self.rag_crew,
|
||||
self.spin_crew,
|
||||
self.identification_crew,
|
||||
self.rag_consolidation_crew
|
||||
)
|
||||
|
||||
def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
|
||||
formatted_context, citations = self.retrieve_context(arguments)
|
||||
|
||||
self.log_tuning("SPIN Specialist execution started", {})
|
||||
|
||||
flow_inputs = {
|
||||
"language": arguments.language,
|
||||
"query": arguments.query,
|
||||
"context": formatted_context,
|
||||
"citations": citations,
|
||||
"history": self._formatted_history,
|
||||
"name": self.specialist.configuration.get('name', ''),
|
||||
"company": self.specialist.configuration.get('company', ''),
|
||||
"products": self.specialist.configuration.get('products', ''),
|
||||
"product_information": self.specialist.configuration.get('product_information', ''),
|
||||
"engagement_options": self.specialist.configuration.get('engagement_options', ''),
|
||||
"tenant_language": self.specialist.configuration.get('tenant_language', ''),
|
||||
"nr_of_questions": self.specialist.configuration.get('nr_of_questions', ''),
|
||||
}
|
||||
# crew_results = self.rag_crew.kickoff(inputs=flow_inputs)
|
||||
# current_app.logger.debug(f"Test Crew Output received: {crew_results}")
|
||||
flow_results = self.flow.kickoff(inputs=flow_inputs)
|
||||
|
||||
flow_state = self.flow.state
|
||||
|
||||
results = SPINSpecialistResult.create_for_type(self.type, self.type_version)
|
||||
update_data = {}
|
||||
if flow_state.final_output:
|
||||
update_data["rag_output"] = flow_state.final_output
|
||||
elif flow_state.rag_output: # Fallback
|
||||
update_data["rag_output"] = flow_state.rag_output
|
||||
if flow_state.spin:
|
||||
update_data["spin"] = flow_state.spin
|
||||
if flow_state.lead_info:
|
||||
update_data["lead_info"] = flow_state.lead_info
|
||||
|
||||
results = results.model_copy(update=update_data)
|
||||
|
||||
self.log_tuning(f"SPIN Specialist execution ended", {"Results": results.model_dump()})
|
||||
|
||||
return results
|
||||
|
||||
# TODO: metrics
|
||||
|
||||
|
||||
class SPINSpecialistInput(BaseModel):
|
||||
language: Optional[str] = Field(None, alias="language")
|
||||
query: Optional[str] = Field(None, alias="query")
|
||||
context: Optional[str] = Field(None, alias="context")
|
||||
citations: Optional[List[int]] = Field(None, alias="citations")
|
||||
history: Optional[str] = Field(None, alias="history")
|
||||
name: Optional[str] = Field(None, alias="name")
|
||||
company: Optional[str] = Field(None, alias="company")
|
||||
products: Optional[str] = Field(None, alias="products")
|
||||
product_information: Optional[str] = Field(None, alias="product_information")
|
||||
engagement_options: Optional[str] = Field(None, alias="engagement_options")
|
||||
tenant_language: Optional[str] = Field(None, alias="tenant_language")
|
||||
nr_of_questions: Optional[int] = Field(None, alias="nr_of_questions")
|
||||
|
||||
|
||||
class SPINSpecialistResult(SpecialistResult):
|
||||
rag_output: Optional[RAGOutput] = Field(None, alias="Rag Output")
|
||||
spin: Optional[SPINOutput] = Field(None, alias="Spin Output")
|
||||
lead_info: Optional[LeadInfoOutput] = Field(None, alias="Lead Info Output")
|
||||
|
||||
|
||||
class SPINFlowState(EveAIFlowState):
|
||||
"""Flow state for SPIN specialist that automatically updates from task outputs"""
|
||||
input: Optional[SPINSpecialistInput] = None
|
||||
rag_output: Optional[RAGOutput] = None
|
||||
lead_info: Optional[LeadInfoOutput] = None
|
||||
spin: Optional[SPINOutput] = None
|
||||
final_output: Optional[RAGOutput] = None
|
||||
|
||||
|
||||
class SPINFlow(EveAICrewAIFlow[SPINFlowState]):
|
||||
def __init__(self, specialist_executor, rag_crew, spin_crew, identification_crew, rag_consolidation_crew, **kwargs):
|
||||
super().__init__(specialist_executor, "SPIN Specialist Flow", **kwargs)
|
||||
self.specialist_executor = specialist_executor
|
||||
self.rag_crew = rag_crew
|
||||
self.spin_crew = spin_crew
|
||||
self.identification_crew = identification_crew
|
||||
self.rag_consolidation_crew = rag_consolidation_crew
|
||||
self.exception_raised = False
|
||||
|
||||
@start()
|
||||
def process_inputs(self):
|
||||
return ""
|
||||
|
||||
@listen(process_inputs)
|
||||
def execute_rag(self):
|
||||
inputs = self.state.input.model_dump()
|
||||
try:
|
||||
crew_output = self.rag_crew.kickoff(inputs=inputs)
|
||||
self.specialist_executor.log_tuning("RAG Crew Output", crew_output.model_dump())
|
||||
output_pydantic = crew_output.pydantic
|
||||
if not output_pydantic:
|
||||
raw_json = json.loads(crew_output.raw)
|
||||
output_pydantic = RAGOutput.model_validate(raw_json)
|
||||
self.state.rag_output = output_pydantic
|
||||
return crew_output
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"CREW rag_crew Kickoff Error: {str(e)}")
|
||||
self.exception_raised = True
|
||||
raise e
|
||||
|
||||
@listen(process_inputs)
|
||||
def execute_spin(self):
|
||||
inputs = self.state.input.model_dump()
|
||||
try:
|
||||
crew_output = self.spin_crew.kickoff(inputs=inputs)
|
||||
self.specialist_executor.log_tuning("Spin Crew Output", crew_output.model_dump())
|
||||
output_pydantic = crew_output.pydantic
|
||||
if not output_pydantic:
|
||||
raw_json = json.loads(crew_output.raw)
|
||||
output_pydantic = SPINOutput.model_validate(raw_json)
|
||||
self.state.spin = output_pydantic
|
||||
return crew_output
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"CREW spin_crew Kickoff Error: {str(e)}")
|
||||
self.exception_raised = True
|
||||
raise e
|
||||
|
||||
@listen(process_inputs)
|
||||
def execute_identification(self):
|
||||
inputs = self.state.input.model_dump()
|
||||
try:
|
||||
crew_output = self.identification_crew.kickoff(inputs=inputs)
|
||||
self.specialist_executor.log_tuning("Identification Crew Output", crew_output.model_dump())
|
||||
output_pydantic = crew_output.pydantic
|
||||
if not output_pydantic:
|
||||
raw_json = json.loads(crew_output.raw)
|
||||
output_pydantic = LeadInfoOutput.model_validate(raw_json)
|
||||
self.state.lead_info = output_pydantic
|
||||
return crew_output
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"CREW identification_crew Kickoff Error: {str(e)}")
|
||||
self.exception_raised = True
|
||||
raise e
|
||||
|
||||
@listen(and_(execute_rag, execute_spin, execute_identification))
|
||||
def consolidate(self):
|
||||
inputs = self.state.input.model_dump()
|
||||
if self.state.rag_output:
|
||||
inputs["prepared_answers"] = self.state.rag_output.answer
|
||||
additional_questions = ""
|
||||
if self.state.lead_info:
|
||||
additional_questions = self.state.lead_info.questions + "\n"
|
||||
if self.state.spin:
|
||||
additional_questions = additional_questions + self.state.spin.questions
|
||||
inputs["additional_questions"] = additional_questions
|
||||
try:
|
||||
crew_output = self.rag_consolidation_crew.kickoff(inputs=inputs)
|
||||
self.specialist_executor.log_tuning("RAG Consolidation Crew Output", crew_output.model_dump())
|
||||
output_pydantic = crew_output.pydantic
|
||||
if not output_pydantic:
|
||||
raw_json = json.loads(crew_output.raw)
|
||||
output_pydantic = LeadInfoOutput.model_validate(raw_json)
|
||||
self.state.final_output = output_pydantic
|
||||
return crew_output
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"CREW rag_consolidation_crew Kickoff Error: {str(e)}")
|
||||
self.exception_raised = True
|
||||
raise e
|
||||
|
||||
def kickoff(self, inputs=None):
|
||||
with current_event.create_span("SPIN Specialist Execution"):
|
||||
self.specialist_executor.log_tuning("Inputs retrieved", inputs)
|
||||
self.state.input = SPINSpecialistInput.model_validate(inputs)
|
||||
self.specialist.update_progress("EveAI Flow Start", {"name": "SPIN"})
|
||||
try:
|
||||
result = super().kickoff()
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error kicking of Flow: {str(e)}")
|
||||
|
||||
self.specialist.update_progress("EveAI Flow End", {"name": "SPIN"})
|
||||
|
||||
return self.state
|
||||
@@ -9,24 +9,24 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||
from common.langchain.outputs.base import OutputRegistry
|
||||
from common.langchain.outputs.rag import RAGOutput
|
||||
from common.utils.business_event_context import current_event
|
||||
from .specialist_typing import SpecialistArguments, SpecialistResult
|
||||
from ..chat_session_cache import CachedInteraction, get_chat_history
|
||||
from ..retrievers.registry import RetrieverRegistry
|
||||
from ..retrievers.base import BaseRetriever
|
||||
from common.models.interaction import SpecialistRetriever, Specialist
|
||||
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
|
||||
from eveai_chat_workers.chat_session_cache import get_chat_history
|
||||
from common.models.interaction import Specialist
|
||||
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
|
||||
from .base import BaseSpecialist
|
||||
from .registry import SpecialistRegistry
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult
|
||||
from eveai_chat_workers.specialists.base_specialist import BaseSpecialistExecutor
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
|
||||
|
||||
|
||||
class RAGSpecialist(BaseSpecialist):
|
||||
class SpecialistExecutor(BaseSpecialistExecutor):
|
||||
"""
|
||||
type: STANDARD_RAG
|
||||
type_version: 1.0
|
||||
Standard Q&A RAG Specialist implementation that combines retriever results
|
||||
with LLM processing to generate answers.
|
||||
"""
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
|
||||
super().__init__(tenant_id, specialist_id, session_id)
|
||||
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str, task_id: str):
|
||||
super().__init__(tenant_id, specialist_id, session_id, task_id)
|
||||
|
||||
# Check and load the specialist
|
||||
specialist = Specialist.query.get_or_404(specialist_id)
|
||||
@@ -43,66 +43,17 @@ class RAGSpecialist(BaseSpecialist):
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "STANDARD_RAG"
|
||||
return "STANDARD_RAG_SPECIALIST"
|
||||
|
||||
def _initialize_retrievers(self) -> List[BaseRetriever]:
|
||||
"""Initialize all retrievers associated with this specialist"""
|
||||
retrievers = []
|
||||
|
||||
# Get retriever associations from database
|
||||
specialist_retrievers = (
|
||||
SpecialistRetriever.query
|
||||
.filter_by(specialist_id=self.specialist_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})
|
||||
|
||||
for spec_retriever in specialist_retrievers:
|
||||
# Get retriever configuration from database
|
||||
retriever = spec_retriever.retriever
|
||||
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
|
||||
self._log_tuning("_initialize_retrievers", {
|
||||
"Retriever id": spec_retriever.retriever_id,
|
||||
"Retriever Type": retriever.type,
|
||||
"Retriever Class": str(retriever_class),
|
||||
})
|
||||
|
||||
# Initialize retriever with its configuration
|
||||
retrievers.append(
|
||||
retriever_class(
|
||||
tenant_id=self.tenant_id,
|
||||
retriever_id=retriever.id,
|
||||
)
|
||||
)
|
||||
|
||||
return retrievers
|
||||
@property
|
||||
def type_version(self) -> str:
|
||||
return "1.0"
|
||||
|
||||
@property
|
||||
def required_templates(self) -> List[str]:
|
||||
"""List of required templates for this specialist"""
|
||||
return ['rag', 'history']
|
||||
|
||||
# def _detail_question(question, language, model_variables, session_id):
|
||||
# retriever = EveAIHistoryRetriever(model_variables=model_variables, session_id=session_id)
|
||||
# llm = model_variables['llm']
|
||||
# template = model_variables['history_template']
|
||||
# language_template = create_language_template(template, language)
|
||||
# full_template = replace_variable_in_template(language_template, "{tenant_context}",
|
||||
# model_variables['rag_context'])
|
||||
# history_prompt = ChatPromptTemplate.from_template(full_template)
|
||||
# setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
|
||||
# output_parser = StrOutputParser()
|
||||
#
|
||||
# chain = setup_and_retrieval | history_prompt | llm | output_parser
|
||||
#
|
||||
# try:
|
||||
# answer = chain.invoke(question)
|
||||
# return answer
|
||||
# except LangChainException as e:
|
||||
# current_app.logger.error(f'Error detailing question: {e}')
|
||||
# raise
|
||||
|
||||
def _detail_question(self, language: str, question: str) -> str:
|
||||
"""Detail question based on conversation history"""
|
||||
try:
|
||||
@@ -138,7 +89,7 @@ class RAGSpecialist(BaseSpecialist):
|
||||
})
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("_detail_question", {
|
||||
self.log_tuning("_detail_question", {
|
||||
"cached_session_id": cached_session.session_id,
|
||||
"cached_session.interactions": str(cached_session.interactions),
|
||||
"original_question": question,
|
||||
@@ -160,17 +111,20 @@ class RAGSpecialist(BaseSpecialist):
|
||||
|
||||
try:
|
||||
with current_event.create_span("Specialist Detail Question"):
|
||||
self.update_progress("Detail Question Start", {})
|
||||
# Get required arguments
|
||||
language = arguments.language
|
||||
query = arguments.query
|
||||
detailed_question = self._detail_question(language, query)
|
||||
self.update_progress("Detail Question End", {})
|
||||
|
||||
# Log the start of retrieval process if tuning is enabled
|
||||
with current_event.create_span("Specialist Retrieval"):
|
||||
self._log_tuning("Starting context retrieval", {
|
||||
self.log_tuning("Starting context retrieval", {
|
||||
"num_retrievers": len(self.retrievers),
|
||||
"all arguments": arguments.model_dump(),
|
||||
})
|
||||
self.update_progress("EveAI Retriever Start", {})
|
||||
|
||||
# Get retriever-specific arguments
|
||||
retriever_arguments = arguments.retriever_arguments
|
||||
@@ -208,12 +162,13 @@ class RAGSpecialist(BaseSpecialist):
|
||||
unique_contexts.append(ctx)
|
||||
seen_chunks.add(ctx.chunk)
|
||||
|
||||
self._log_tuning("Context retrieval completed", {
|
||||
self.log_tuning("Context retrieval completed", {
|
||||
"total_contexts": len(all_context),
|
||||
"unique_contexts": len(unique_contexts),
|
||||
"average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
|
||||
unique_contexts) if unique_contexts else 0
|
||||
})
|
||||
self.update_progress("EveAI Retriever Complete", {})
|
||||
|
||||
# Prepare context for LLM
|
||||
formatted_context = "\n\n".join([
|
||||
@@ -223,6 +178,7 @@ class RAGSpecialist(BaseSpecialist):
|
||||
|
||||
with current_event.create_span("Specialist RAG invocation"):
|
||||
try:
|
||||
self.update_progress(self.task_id, "EveAI Chain Start", {})
|
||||
# Get LLM with specified temperature
|
||||
llm = self.model_variables.get_llm(temperature=self.temperature)
|
||||
|
||||
@@ -236,7 +192,7 @@ class RAGSpecialist(BaseSpecialist):
|
||||
)
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("Template preparation completed", {
|
||||
self.log_tuning("Template preparation completed", {
|
||||
"template": full_template,
|
||||
"context": formatted_context,
|
||||
"tenant_context": self.specialist_context,
|
||||
@@ -258,7 +214,8 @@ class RAGSpecialist(BaseSpecialist):
|
||||
|
||||
raw_result = chain.invoke(detailed_question)
|
||||
result = SpecialistResult.create_for_type(
|
||||
"STANDARD_RAG",
|
||||
self.type,
|
||||
self.type_version,
|
||||
detailed_query=detailed_question,
|
||||
answer=raw_result.answer,
|
||||
citations=[ctx.metadata.document_id for ctx in unique_contexts
|
||||
@@ -267,14 +224,15 @@ class RAGSpecialist(BaseSpecialist):
|
||||
)
|
||||
|
||||
if self.tuning:
|
||||
self._log_tuning("LLM chain execution completed", {
|
||||
self.log_tuning("LLM chain execution completed", {
|
||||
"Result": result.model_dump()
|
||||
})
|
||||
self.update_progress("EveAI Chain Complete", {})
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error in LLM processing: {e}")
|
||||
if self.tuning:
|
||||
self._log_tuning("LLM processing error", {"error": str(e)})
|
||||
self.log_tuning("LLM processing error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
return result
|
||||
@@ -285,5 +243,4 @@ class RAGSpecialist(BaseSpecialist):
|
||||
|
||||
|
||||
# Register the specialist type
|
||||
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
|
||||
OutputRegistry.register("STANDARD_RAG", RAGOutput)
|
||||
OutputRegistry.register("STANDARD_RAG_SPECIALIST", RAGOutput)
|
||||
@@ -1,5 +0,0 @@
|
||||
# Import all specialist implementations here to ensure registration
|
||||
from . import rag_specialist
|
||||
|
||||
# List of all available specialist implementations
|
||||
__all__ = ['rag_specialist']
|
||||
@@ -1,50 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
from flask import current_app
|
||||
|
||||
from config.logging_config import TuningLogger
|
||||
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
|
||||
|
||||
|
||||
class BaseSpecialist(ABC):
|
||||
"""Base class for all specialists"""
|
||||
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.specialist_id = specialist_id
|
||||
self.session_id = session_id
|
||||
self.tuning = False
|
||||
self.tuning_logger = None
|
||||
self._setup_tuning_logger()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""The type of the specialist"""
|
||||
pass
|
||||
|
||||
def _setup_tuning_logger(self):
|
||||
try:
|
||||
self.tuning_logger = TuningLogger(
|
||||
'tuning',
|
||||
tenant_id=self.tenant_id,
|
||||
specialist_id=self.specialist_id,
|
||||
)
|
||||
# Verify logger is working with a test message
|
||||
if self.tuning:
|
||||
self.tuning_logger.log_tuning('specialist', "Tuning logger initialized")
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
|
||||
raise
|
||||
|
||||
def _log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
|
||||
if self.tuning and self.tuning_logger:
|
||||
try:
|
||||
self.tuning_logger.log_tuning('specialist', message, data)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
|
||||
"""Execute the specialist's logic"""
|
||||
pass
|
||||
106
eveai_chat_workers/specialists/base_specialist.py
Normal file
106
eveai_chat_workers/specialists/base_specialist.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
from flask import current_app
|
||||
|
||||
from common.models.interaction import SpecialistRetriever
|
||||
from common.utils.execution_progress import ExecutionProgressTracker
|
||||
from config.logging_config import TuningLogger
|
||||
from eveai_chat_workers.retrievers.base import BaseRetriever
|
||||
from eveai_chat_workers.retrievers.registry import RetrieverRegistry
|
||||
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
|
||||
|
||||
|
||||
class BaseSpecialistExecutor(ABC):
|
||||
"""Base class for all specialists"""
|
||||
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str, task_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.specialist_id = specialist_id
|
||||
self.session_id = session_id
|
||||
self.task_id = task_id
|
||||
self.tuning = False
|
||||
self.tuning_logger = None
|
||||
self._setup_tuning_logger()
|
||||
self.ept = ExecutionProgressTracker()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""The type of the specialist"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type_version(self) -> str:
|
||||
"""The type version of the specialist"""
|
||||
pass
|
||||
|
||||
def _initialize_retrievers(self) -> List[BaseRetriever]:
|
||||
"""Initialize all retrievers associated with this specialist"""
|
||||
retrievers = []
|
||||
|
||||
# Get retriever associations from database
|
||||
specialist_retrievers = (
|
||||
SpecialistRetriever.query
|
||||
.filter_by(specialist_id=self.specialist_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
self.log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})
|
||||
|
||||
for spec_retriever in specialist_retrievers:
|
||||
# Get retriever configuration from database
|
||||
retriever = spec_retriever.retriever
|
||||
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
|
||||
self.log_tuning("_initialize_retrievers", {
|
||||
"Retriever id": spec_retriever.retriever_id,
|
||||
"Retriever Type": retriever.type,
|
||||
"Retriever Class": str(retriever_class),
|
||||
})
|
||||
|
||||
# Initialize retriever with its configuration
|
||||
retrievers.append(
|
||||
retriever_class(
|
||||
tenant_id=self.tenant_id,
|
||||
retriever_id=retriever.id,
|
||||
)
|
||||
)
|
||||
|
||||
return retrievers
|
||||
|
||||
def _setup_tuning_logger(self):
|
||||
try:
|
||||
self.tuning_logger = TuningLogger(
|
||||
'tuning',
|
||||
tenant_id=self.tenant_id,
|
||||
specialist_id=self.specialist_id,
|
||||
)
|
||||
# Verify logger is working with a test message
|
||||
if self.tuning:
|
||||
self.tuning_logger.log_tuning('specialist', "Tuning logger initialized")
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
|
||||
raise
|
||||
|
||||
def log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
|
||||
if self.tuning and self.tuning_logger:
|
||||
try:
|
||||
self.tuning_logger.log_tuning('specialist', message, data)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
|
||||
|
||||
def update_progress(self, processing_type, data) -> None:
|
||||
self.ept.send_update(self.task_id, processing_type, data)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
|
||||
"""Execute the specialist's logic"""
|
||||
pass
|
||||
|
||||
|
||||
def get_specialist_class(specialist_type: str, type_version: str):
|
||||
major_minor = '_'.join(type_version.split('.')[:2])
|
||||
module_path = f"eveai_chat_workers.specialists.{specialist_type}.{major_minor}"
|
||||
module = importlib.import_module(module_path)
|
||||
return module.SpecialistExecutor
|
||||
129
eveai_chat_workers/specialists/crewai_base_classes.py
Normal file
129
eveai_chat_workers/specialists/crewai_base_classes.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import json
|
||||
|
||||
from crewai import Agent, Task, Crew, Flow
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
from crewai.tools import BaseTool
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, create_model, Field, ConfigDict
|
||||
from typing import Dict, Type, get_type_hints, Optional, List, Any, Callable
|
||||
|
||||
|
||||
class EveAICrewAIAgent(Agent):
|
||||
specialist: Any = Field(default=None, exclude=True)
|
||||
name: str = Field(default=None, exclude=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, specialist, name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.specialist = specialist
|
||||
self.name = name
|
||||
self.specialist.log_tuning("Initializing EveAICrewAIAgent", {"name": name})
|
||||
self.specialist.update_progress("EveAI Agent Initialisation", {"name": self.name})
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent. Performs AskEveAI specific fuctionality on top of task execution
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
context: Context to execute the task in.
|
||||
tools: Tools to use for the task.
|
||||
|
||||
Returns:
|
||||
Output of the agent
|
||||
"""
|
||||
self.specialist.log_tuning("EveAI Agent Task Start",
|
||||
{"name": self.name,
|
||||
'task': task.name,
|
||||
})
|
||||
self.specialist.update_progress("EveAI Agent Task Start",
|
||||
{"name": self.name,
|
||||
'task': task.name,
|
||||
})
|
||||
|
||||
result = super().execute_task(task, context, tools)
|
||||
|
||||
self.specialist.log_tuning("EveAI Agent Task Complete",
|
||||
{"name": self.name,
|
||||
'task': task.name,
|
||||
'result': result,
|
||||
})
|
||||
self.specialist.update_progress("EveAI Agent Task Complete",
|
||||
{"name": self.name,
|
||||
'task': task.name,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class EveAICrewAITask(Task):
|
||||
specialist: Any = Field(default=None, exclude=True)
|
||||
name: str = Field(default=None, exclude=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, specialist, name: str, **kwargs):
|
||||
# kwargs.update({"callback": create_task_callback(self)})
|
||||
super().__init__(**kwargs)
|
||||
# current_app.logger.debug(f"Task pydantic class for {name}: {"class", self.output_pydantic}")
|
||||
self.specialist = specialist
|
||||
self.name = name
|
||||
self.specialist.log_tuning("Initializing EveAICrewAITask", {"name": name})
|
||||
self.specialist.update_progress("EveAI Task Initialisation", {"name": name})
|
||||
|
||||
|
||||
# def create_task_callback(task: EveAICrewAITask):
|
||||
# def task_callback(output):
|
||||
# # Todo Check if required with new version of crewai
|
||||
# if isinstance(output, BaseModel):
|
||||
# task.specialist.log_tuning(f"TASK CALLBACK: EveAICrewAITask {task.name} Output:",
|
||||
# {'output': output.model_dump()})
|
||||
# if output.output_format == "pydantic" and not output.pydantic:
|
||||
# try:
|
||||
# raw_json = json.loads(output.raw)
|
||||
# output_pydantic = task.output_pydantic(**raw_json)
|
||||
# output.pydantic = output_pydantic
|
||||
# task.specialist.log_tuning(f"TASK CALLBACK: EveAICrewAITask {task.name} Converted Output",
|
||||
# {'output': output_pydantic.model_dump()})
|
||||
# except Exception as e:
|
||||
# task.specialist.log_tuning(f"TASK CALLBACK: EveAICrewAITask {task.name} Output Conversion Error: "
|
||||
# f"{str(e)}", {})
|
||||
#
|
||||
# return task_callback
|
||||
|
||||
|
||||
class EveAICrewAICrew(Crew):
|
||||
specialist: Any = Field(default=None, exclude=True)
|
||||
name: str = Field(default=None, exclude=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, specialist, name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.specialist = specialist
|
||||
self.name = name
|
||||
self.specialist.log_tuning("Initializing EveAICrewAICrew", {"name": self.name})
|
||||
self.specialist.update_progress("EveAI Crew Initialisation", {"name": self.name})
|
||||
|
||||
|
||||
class EveAICrewAIFlow(Flow):
|
||||
specialist: Any = Field(default=None, exclude=True)
|
||||
name: str = Field(default=None, exclude=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, specialist, name: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.specialist = specialist
|
||||
self.name = name
|
||||
self.specialist.log_tuning("Initializing EveAICrewAIFlow", {"name": self.name})
|
||||
self.specialist.update_progress("EveAI Flow Initialisation", {"name": self.name})
|
||||
|
||||
|
||||
class EveAIFlowState(BaseModel):
|
||||
"""Base class for all EveAI flow states"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
243
eveai_chat_workers/specialists/crewai_base_specialist.py
Normal file
243
eveai_chat_workers/specialists/crewai_base_specialist.py
Normal file
@@ -0,0 +1,243 @@
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Type, TypeVar, List, Tuple
|
||||
|
||||
from crewai.flow.flow import FlowState
|
||||
from flask import current_app
|
||||
|
||||
from common.models.interaction import Specialist
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.utils.model_utils import get_model_variables
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
|
||||
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAIAgent, EveAICrewAITask
|
||||
from crewai.tools import BaseTool
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.extensions import cache_manager
|
||||
from eveai_chat_workers.specialists.base_specialist import BaseSpecialistExecutor
|
||||
from common.utils.cache.crewai_configuration import (
|
||||
ProcessedAgentConfig, ProcessedTaskConfig, ProcessedToolConfig,
|
||||
SpecialistProcessedConfig
|
||||
)
|
||||
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments
|
||||
|
||||
T = TypeVar('T') # For generic type hints
|
||||
|
||||
|
||||
class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
|
||||
"""Base class for all CrewAI-based specialists"""
|
||||
|
||||
def __init__(self, tenant_id: int, specialist_id: int, session_id: str, task_id):
|
||||
super().__init__(tenant_id, specialist_id, session_id, task_id)
|
||||
|
||||
# Check and load the specialist
|
||||
self.specialist = Specialist.query.get_or_404(specialist_id)
|
||||
# Set the specific configuration for the SPIN Specialist
|
||||
# self.specialist_configuration = json.loads(self.specialist.configuration)
|
||||
self.tuning = self.specialist.tuning
|
||||
# Initialize retrievers
|
||||
self.retrievers = self._initialize_retrievers()
|
||||
|
||||
# Initialize model variables
|
||||
self.model_variables = get_model_variables(tenant_id)
|
||||
|
||||
# initialize the Flow
|
||||
self.flow = None
|
||||
|
||||
# Runtime instances
|
||||
self._agents: Dict[str, EveAICrewAIAgent] = {}
|
||||
self._tasks: Dict[str, EveAICrewAITask] = {}
|
||||
self._tools: Dict[str, BaseTool] = {}
|
||||
|
||||
# Crew configuration
|
||||
self._task_agents: Dict[str, str] = {}
|
||||
self._task_pydantic_outputs: Dict[str, Type[BaseModel]] = {}
|
||||
self._task_state_names: Dict[str, str] = {}
|
||||
|
||||
# Processed configurations
|
||||
self._config = cache_manager.crewai_processed_config_cache.get_specialist_config(tenant_id, specialist_id)
|
||||
self._config_task_agents()
|
||||
self._config_pydantic_outputs()
|
||||
self._instantiate_crew_assets()
|
||||
self._instantiate_specialist()
|
||||
|
||||
# Retrieve history
|
||||
self._cached_session = cache_manager.chat_session_cache.get_cached_session(self.session_id)
|
||||
# Format history for the prompt
|
||||
self._formatted_history = "\n\n".join([
|
||||
f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
|
||||
f"AI:\n{interaction.specialist_results.get('answer')}"
|
||||
for interaction in self._cached_session.interactions
|
||||
])
|
||||
|
||||
def _add_task_agent(self, task_name: str, agent_name: str):
|
||||
self._task_agents[task_name.lower()] = agent_name
|
||||
|
||||
@abstractmethod
|
||||
def _config_task_agents(self):
|
||||
"""Configure the task agents by adding task-agent combinations. Use _add_task_agent()
|
||||
"""
|
||||
|
||||
@property
|
||||
def task_agents(self) -> Dict[str, str]:
|
||||
return self._task_agents
|
||||
|
||||
def _add_pydantic_output(self, task_name: str, output: Type[BaseModel], state_name: str is None):
|
||||
self._task_pydantic_outputs[task_name.lower()] = output
|
||||
if state_name is not None:
|
||||
self._task_state_names[task_name.lower()] = state_name
|
||||
|
||||
@abstractmethod
|
||||
def _config_pydantic_outputs(self):
|
||||
"""Configure the task pydantic outputs by adding task-output combinations. Use _add_pydantic_output()"""
|
||||
|
||||
@property
|
||||
def task_pydantic_outputs(self):
|
||||
return self._task_pydantic_outputs
|
||||
|
||||
@property
|
||||
def task_state_names(self):
|
||||
return self._task_state_names
|
||||
|
||||
def _instantiate_crew_assets(self):
|
||||
self._instantiate_crew_agents()
|
||||
self._instantiate_tasks()
|
||||
self._instantiate_tools()
|
||||
|
||||
def _instantiate_crew_agents(self):
|
||||
for agent in self.specialist.agents:
|
||||
agent_config = cache_manager.agents_config_cache.get_config(agent.type, agent.type_version)
|
||||
agent_role = agent_config.get('role', '').replace('{custom_role}', agent.role or '')
|
||||
agent_goal = agent_config.get('goal', '').replace('{custom_goal}', agent.goal or '')
|
||||
agent_backstory = agent_config.get('backstory', '').replace('{custom_backstory}', agent.backstory or '')
|
||||
new_agent = EveAICrewAIAgent(
|
||||
self,
|
||||
agent.type.lower(),
|
||||
role=agent_role,
|
||||
goal=agent_goal,
|
||||
backstory=agent_backstory,
|
||||
verbose=agent.tuning,
|
||||
)
|
||||
agent_name = agent.type.lower()
|
||||
self.log_tuning(f"CrewAI Agent {agent_name} initialized", agent_config)
|
||||
self._agents[agent_name] = new_agent
|
||||
|
||||
def _instantiate_tasks(self):
|
||||
for task in self.specialist.tasks:
|
||||
task_config = cache_manager.tasks_config_cache.get_config(task.type, task.type_version)
|
||||
task_description = (task_config.get('task_description', '')
|
||||
.replace('{custom_description}', task.task_description or ''))
|
||||
task_expected_output = (task_config.get('expected_output', '')
|
||||
.replace('{custom_expected_output}', task.expected_output or ''))
|
||||
# dynamically build the arguments
|
||||
task_kwargs = {
|
||||
"description": task_description,
|
||||
"expected_output": task_expected_output,
|
||||
"verbose": task.tuning
|
||||
}
|
||||
task_name = task.type.lower()
|
||||
if task_name in self._task_pydantic_outputs:
|
||||
task_kwargs["output_pydantic"] = self._task_pydantic_outputs[task_name]
|
||||
if task_name in self._task_agents:
|
||||
task_kwargs["agent"] = self._agents[self._task_agents[task_name]]
|
||||
|
||||
# Instantiate the task with dynamic arguments
|
||||
new_task = EveAICrewAITask(self, task_name, **task_kwargs)
|
||||
|
||||
# Logging and storing the task
|
||||
self.log_tuning(f"CrewAI Task {task_name} initialized", task_config)
|
||||
self._tasks[task_name] = new_task
|
||||
|
||||
def _instantiate_tools(self):
|
||||
# This currently is not implemented
|
||||
# TODO: complete Tool instantiation
|
||||
pass
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Enable dynamic access to agents as attributes"""
|
||||
try:
|
||||
if name.endswith('_agent'):
|
||||
return self._agents[name]
|
||||
|
||||
if name.endswith('_task'):
|
||||
return self._tasks[name]
|
||||
|
||||
if name.endswith('_tool'):
|
||||
return self._tools[name]
|
||||
|
||||
# Not a known component request
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
except KeyError:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
|
||||
@abstractmethod
|
||||
def _instantiate_specialist(self):
|
||||
"""Instantiate a crew (or flow) to set up the complete specialist, using the assets (agents, tasks, tools).
|
||||
The assets can be retrieved using their type name in lower case, e.g. rag_agent"""
|
||||
|
||||
def retrieve_context(self, arguments: SpecialistArguments) -> Tuple[str, List[int]]:
|
||||
with current_event.create_span("Specialist Retrieval"):
|
||||
self.log_tuning("Starting context retrieval", {
|
||||
"num_retrievers": len(self.retrievers),
|
||||
"all arguments": arguments.model_dump(),
|
||||
})
|
||||
|
||||
# Get retriever-specific arguments
|
||||
retriever_arguments = arguments.retriever_arguments
|
||||
|
||||
# Collect context from all retrievers
|
||||
all_context = []
|
||||
for retriever in self.retrievers:
|
||||
# Get arguments for this specific retriever
|
||||
retriever_id = str(retriever.retriever_id)
|
||||
if retriever_id not in retriever_arguments:
|
||||
current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
|
||||
continue
|
||||
|
||||
# Get the retriever's arguments and update the query
|
||||
current_retriever_args = retriever_arguments[retriever_id]
|
||||
if isinstance(retriever_arguments[retriever_id], RetrieverArguments):
|
||||
updated_args = current_retriever_args.model_dump()
|
||||
updated_args['query'] = arguments.query
|
||||
updated_args['language'] = arguments.language
|
||||
retriever_args = RetrieverArguments(**updated_args)
|
||||
else:
|
||||
# Create a new RetrieverArguments instance from the dictionary
|
||||
current_retriever_args['query'] = arguments.query
|
||||
retriever_args = RetrieverArguments(**current_retriever_args)
|
||||
|
||||
# Each retriever gets its own specific arguments
|
||||
retriever_result = retriever.retrieve(retriever_args)
|
||||
all_context.extend(retriever_result)
|
||||
|
||||
# Sort by similarity if available and get unique contexts
|
||||
all_context.sort(key=lambda x: x.similarity, reverse=True)
|
||||
unique_contexts = []
|
||||
seen_chunks = set()
|
||||
for ctx in all_context:
|
||||
if ctx.chunk not in seen_chunks:
|
||||
unique_contexts.append(ctx)
|
||||
seen_chunks.add(ctx.chunk)
|
||||
|
||||
self.log_tuning("Context retrieval completed", {
|
||||
"total_contexts": len(all_context),
|
||||
"unique_contexts": len(unique_contexts),
|
||||
"average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
|
||||
unique_contexts) if unique_contexts else 0
|
||||
})
|
||||
|
||||
# Prepare context for LLM
|
||||
formatted_context = "\n\n".join([
|
||||
f"SOURCE: {ctx.metadata.document_id}\n{ctx.chunk}\n\n"
|
||||
for ctx in unique_contexts
|
||||
])
|
||||
|
||||
# Return document_ids for citations
|
||||
citations = [ctx.metadata.document_id for ctx in unique_contexts]
|
||||
|
||||
self.log_tuning("Context Retrieval Results",
|
||||
{"Formatted Context": formatted_context,
|
||||
"Citations": citations})
|
||||
|
||||
return formatted_context, citations
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import Dict, Type
|
||||
from .base import BaseSpecialist
|
||||
|
||||
|
||||
class SpecialistRegistry:
|
||||
"""Registry for specialist types"""
|
||||
|
||||
_registry: Dict[str, Type[BaseSpecialist]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, specialist_type: str, specialist_class: Type[BaseSpecialist]):
|
||||
"""Register a new specialist type"""
|
||||
cls._registry[specialist_type] = specialist_class
|
||||
|
||||
@classmethod
|
||||
def get_specialist_class(cls, specialist_type: str) -> Type[BaseSpecialist]:
|
||||
"""Get the specialist class for a given type"""
|
||||
if specialist_type not in cls._registry:
|
||||
raise ValueError(f"Unknown specialist type: {specialist_type}")
|
||||
return cls._registry[specialist_type]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from config.type_defs.specialist_types import SPECIALIST_TYPES
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
|
||||
from common.extensions import cache_manager
|
||||
|
||||
|
||||
class SpecialistArguments(BaseModel):
|
||||
@@ -10,6 +10,7 @@ class SpecialistArguments(BaseModel):
|
||||
based on SPECIALIST_TYPES configuration.
|
||||
"""
|
||||
type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")
|
||||
type_version: str = Field(..., description="Type version (e.g. 1.0)")
|
||||
retriever_arguments: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Arguments for each retriever, keyed by retriever ID"
|
||||
@@ -23,7 +24,7 @@ class SpecialistArguments(BaseModel):
|
||||
@model_validator(mode='after')
|
||||
def validate_required_arguments(self) -> 'SpecialistArguments':
|
||||
"""Validate that all required arguments for this specialist type are present"""
|
||||
specialist_config = SPECIALIST_TYPES.get(self.type)
|
||||
specialist_config = cache_manager.specialists_config_cache.get_config(self.type, self.type_version)
|
||||
if not specialist_config:
|
||||
raise ValueError(f"Unknown specialist type: {self.type}")
|
||||
|
||||
@@ -44,7 +45,7 @@ class SpecialistArguments(BaseModel):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create(cls, type_name: str, specialist_args: Dict[str, Any],
|
||||
def create(cls, type_name: str, type_version: str, specialist_args: Dict[str, Any],
|
||||
retriever_args: Dict[str, Dict[str, Any]]) -> 'SpecialistArguments':
|
||||
"""
|
||||
Factory method to create SpecialistArguments with validated retriever arguments
|
||||
@@ -63,12 +64,15 @@ class SpecialistArguments(BaseModel):
|
||||
# Ensure type is included in retriever arguments
|
||||
if 'type' not in args:
|
||||
raise ValueError(f"Retriever arguments for {retriever_id} must include 'type'")
|
||||
if 'type_version' not in args:
|
||||
raise ValueError(f"Retriever arguments for {retriever_id} must include 'type_version'")
|
||||
|
||||
validated_retriever_args[retriever_id] = RetrieverArguments(**args)
|
||||
|
||||
# Combine everything into the specialist arguments
|
||||
return cls(
|
||||
type=type_name,
|
||||
type_version=type_version,
|
||||
**specialist_args,
|
||||
retriever_arguments=validated_retriever_args
|
||||
)
|
||||
@@ -80,6 +84,7 @@ class SpecialistResult(BaseModel):
|
||||
SPECIALIST_TYPES configuration.
|
||||
"""
|
||||
type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")
|
||||
type_version: str = Field(..., description="Type version (e.g. 1.0)")
|
||||
|
||||
# Allow any additional fields
|
||||
model_config = {
|
||||
@@ -89,9 +94,9 @@ class SpecialistResult(BaseModel):
|
||||
@model_validator(mode='after')
|
||||
def validate_required_results(self) -> 'SpecialistResult':
|
||||
"""Validate that all required result fields for this specialist type are present"""
|
||||
specialist_config = SPECIALIST_TYPES.get(self.type)
|
||||
specialist_config = cache_manager.specialists_config_cache.get_config(self.type, self.type_version)
|
||||
if not specialist_config:
|
||||
raise ValueError(f"Unknown specialist type: {self.type}")
|
||||
raise ValueError(f"Unknown specialist type: {self.type}, {self.type_version}")
|
||||
|
||||
# Check required results from configuration
|
||||
required_results = specialist_config.get('results', {})
|
||||
@@ -117,12 +122,13 @@ class SpecialistResult(BaseModel):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_for_type(cls, specialist_type: str, **results) -> 'SpecialistResult':
|
||||
def create_for_type(cls, specialist_type: str, specialist_type_version: str, **results) -> 'SpecialistResult':
|
||||
"""
|
||||
Factory method to create a type-specific result
|
||||
|
||||
Args:
|
||||
specialist_type: The type of specialist (e.g., 'STANDARD_RAG')
|
||||
specialist_type_version: The type of specialist (e.g., '1.0')
|
||||
**results: The result values to include
|
||||
|
||||
Returns:
|
||||
@@ -132,6 +138,7 @@ class SpecialistResult(BaseModel):
|
||||
For STANDARD_RAG:
|
||||
result = SpecialistResult.create_for_type(
|
||||
'STANDARD_RAG',
|
||||
'1.0',
|
||||
answer="The answer text",
|
||||
citations=["doc1", "doc2"],
|
||||
insufficient_info=False
|
||||
@@ -139,6 +146,7 @@ class SpecialistResult(BaseModel):
|
||||
"""
|
||||
# Add the type to the results
|
||||
results['type'] = specialist_type
|
||||
results['type_version'] = specialist_type_version
|
||||
|
||||
# Create and validate the result
|
||||
return cls(**results)
|
||||
return cls(**results)
|
||||
|
||||
Reference in New Issue
Block a user