- Revisiting RAG_SPECIALIST

- Adapt Catalogs & Retrievers to use specific types, removing tagging_fields
- Adding CrewAI Implementation Guide
This commit is contained in:
Josako
2025-07-08 15:54:16 +02:00
parent 33b5742d2f
commit 509ee95d81
32 changed files with 997 additions and 825 deletions

View File

@@ -1,6 +0,0 @@
# Import all specialist implementations here to ensure registration
from . import standard_rag
from . import dossier_retriever
# List of all available specialist implementations
__all__ = ['standard_rag', 'dossier_retriever']

View File

@@ -1,3 +1,5 @@
import importlib
import json
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, Any, List, Optional, Tuple
@@ -5,7 +7,7 @@ from flask import current_app
from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from common.extensions import db
from common.extensions import db, cache_manager
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
@@ -19,15 +21,23 @@ class BaseRetriever(ABC):
self.tenant_id = tenant_id
self.retriever_id = retriever_id
self.retriever = Retriever.query.get_or_404(retriever_id)
self.tuning = False
self.catalog_id = self.retriever.catalog_id
self.tuning = self.retriever.tuning
self.tuning_logger = None
self._setup_tuning_logger()
self.embedding_model, self.embedding_model_class = (
get_embedding_model_and_class(tenant_id=tenant_id, catalog_id=self.catalog_id))
@property
@abstractmethod
def type(self) -> str:
    """Retriever type identifier; concrete retrievers must override this property."""
    raise NotImplementedError
@property
@abstractmethod
def type_version(self) -> str:
    """
    The type version of the retriever.

    Declared as an abstract property so that it matches the sibling ``type``
    property; concrete retrievers override it with an ``@property`` returning
    the version string.
    """
    raise NotImplementedError
def _setup_tuning_logger(self):
try:
@@ -43,6 +53,32 @@ class BaseRetriever(ABC):
current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
raise
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
def log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
if self.tuning and self.tuning_logger:
try:
@@ -50,31 +86,6 @@ class BaseRetriever(ABC):
except Exception as e:
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
    """
    Collect the parameters shared by the standard retrieval implementations.

    Returns:
        A 5-tuple ``(embedding_model, embedding_model_class, catalog_id,
        similarity_threshold, k)`` where the threshold and ``k`` come from the
        retriever configuration keys ``es_similarity_threshold`` (default 0.3)
        and ``es_k`` (default 8).
    """
    catalog_id = self.retriever.catalog_id
    # The looked-up catalog is not used further; the call is kept for its
    # get_or_404 behaviour (presumably aborts when the catalog is missing).
    Catalog.query.get_or_404(catalog_id)
    model, model_class = get_embedding_model_and_class(
        self.tenant_id, catalog_id, "mistral.mistral-embed"
    )
    threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
    max_results = self.retriever.configuration.get('es_k', 8)
    return model, model_class, catalog_id, threshold, max_results
@abstractmethod
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""
@@ -87,3 +98,16 @@ class BaseRetriever(ABC):
List[Dict[str, Any]]: List of retrieved documents/content
"""
raise NotImplementedError
def get_retriever_class(retriever_type: str, type_version: str):
    """
    Import and return the ``RetrieverExecutor`` class for a retriever type/version.

    The module is looked up under ``eveai_chat_workers.retrievers.<package>``,
    where ``<package>`` is the ``partner`` entry of the cached retriever
    configuration when one is set, and ``globals`` otherwise. The module name
    is the first two components of ``type_version`` joined with ``_``.

    Args:
        retriever_type: Retriever type name, used as a package name in the module path
        type_version: Dotted version string (e.g. "1.0")

    Returns:
        The ``RetrieverExecutor`` attribute of the imported module.
    """
    version_parts = type_version.split('.')
    major_minor = '_'.join(version_parts[:2])
    retriever_config = cache_manager.retrievers_config_cache.get_config(retriever_type, type_version)
    package = retriever_config.get("partner", None) or "globals"
    module_path = f"eveai_chat_workers.retrievers.{package}.{retriever_type}.{major_minor}"
    current_app.logger.debug(f"Importing retriever class from {module_path}")
    return importlib.import_module(module_path).RetrieverExecutor

View File

@@ -1,374 +0,0 @@
"""
DossierRetriever implementation that adds metadata filtering to retrieval
"""
import json
from datetime import datetime as dt, date, timezone as tz
from typing import Dict, Any, List, Optional, Union, Tuple
from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
from sqlalchemy.sql import expression
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class DossierRetriever(BaseRetriever):
"""
Dossier Retriever implementation that adds metadata filtering
to standard retrieval functionality
"""
def __init__(self, tenant_id: int, retriever_id: int):
    """
    Initialise the retriever with the standard retrieval parameters plus the
    dossier-specific filter configuration.

    Args:
        tenant_id: Tenant the retriever runs for
        retriever_id: Identifier of the retriever record
    """
    super().__init__(tenant_id, retriever_id)

    # Parameters shared with the standard retrieval flow
    (
        self.embedding_model,
        self.embedding_model_class,
        self.catalog_id,
        self.similarity_threshold,
        self.k,
    ) = self.setup_standard_retrieval_params()

    # Filter definition and the names of arguments that may be substituted into it
    self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {})
    self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {})

    init_details = {
        "tagging_fields_filter": self.tagging_fields_filter,
        "dynamic_arguments": self.dynamic_arguments,
        "similarity_threshold": self.similarity_threshold,
        "k": self.k
    }
    self.log_tuning("Dossier retriever initialized", init_details)
@property
def type(self) -> str:
    """Type identifier of this retriever implementation."""
    return "DOSSIER_RETRIEVER"
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
    """
    Restrict the query with the configured ``tagging_fields_filter``.

    Args:
        query_obj: SQLAlchemy query object to restrict
        arguments: Retriever arguments; attributes named in ``dynamic_arguments``
            are made available for variable substitution

    Returns:
        The query unchanged when no filter is configured or no condition could
        be built, otherwise the query with the condition applied.
    """
    if not self.tagging_fields_filter:
        return query_obj

    # Only arguments that are both declared as dynamic and present on the call
    dynamic_values = {
        name: getattr(arguments, name)
        for name, _config in self.dynamic_arguments.items()
        if hasattr(arguments, name)
    }

    condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
    if condition is None:
        return query_obj

    filtered_query = query_obj.filter(condition)
    self.log_tuning("Applied metadata filter", {
        "filter_sql": str(condition),
        "dynamic_values": dynamic_values
    })
    return filtered_query
def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[
        expression.BinaryExpression]:
    """
    Translate a filter definition into a SQLAlchemy condition, recursing
    through logical groups.

    Args:
        filter_def: Either a logical group (``logical`` plus ``conditions``) or
            a field condition (``field``, ``operator`` and ``value``)
        dynamic_values: Values available for ``$variable`` substitution

    Returns:
        SQLAlchemy expression, or None when the definition is invalid or
        yields no usable condition
    """
    # Logical groups: and / or / not over the nested conditions
    if 'logical' in filter_def:
        logical_op = filter_def['logical'].lower()

        subconditions = []
        for nested_def in filter_def.get('conditions', []):
            nested_condition = self._build_filter_condition(nested_def, dynamic_values)
            # Nested definitions that could not be built are dropped
            if nested_condition is not None:
                subconditions.append(nested_condition)

        if not subconditions:
            return None

        if logical_op == 'and':
            return and_(*subconditions)
        if logical_op == 'or':
            return or_(*subconditions)
        if logical_op == 'not':
            if len(subconditions) != 1:
                # NOT should have exactly one condition
                current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
                return None
            return ~subconditions[0]

        current_app.logger.warning(f"Unknown logical operator: {logical_op}")
        return None

    # Anything else must be a complete field condition
    if not all(key in filter_def for key in ('field', 'operator', 'value')):
        current_app.logger.warning(f"Invalid filter definition: {filter_def}")
        return None

    field_name = filter_def['field']
    operator = filter_def['operator'].lower()
    value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))

    # An unresolved value is only acceptable for the null-check operators
    if value is None and operator not in ['is_null', 'is_not_null']:
        return None

    # The field is read from the JSON tagging fields and compared as a string
    field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)
    return self._apply_operator(field_expr, operator, value, filter_def)
def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
"""
Resolve a value definition, handling variables and defaults
Args:
value_def: Value definition (could be literal, variable reference, or list)
dynamic_values: Values for dynamic variable substitution
default: Default value if variable not found
Returns:
Resolved value
"""
# Handle lists (recursively resolve each item)
if isinstance(value_def, list):
return [self._resolve_value(item, dynamic_values) for item in value_def]
# Handle variable references (strings starting with $)
if isinstance(value_def, str) and value_def.startswith('$'):
var_name = value_def[1:] # Remove $ prefix
if var_name in dynamic_values:
return dynamic_values[var_name]
else:
# Use default if provided
return default
# Return literal values as-is
return value_def
def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[
expression.BinaryExpression]:
"""
Apply the specified operator to create a filter condition
Args:
field_expr: SQLAlchemy field expression
operator: Operator to apply
value: Value to compare against
filter_def: Original filter definition (for additional options)
Returns:
SQLAlchemy expression
"""
try:
# String operators
if operator == 'eq':
return field_expr == str(value)
elif operator == 'neq':
return field_expr != str(value)
elif operator == 'contains':
return field_expr.contains(str(value))
elif operator == 'not_contains':
return ~field_expr.contains(str(value))
elif operator == 'starts_with':
return field_expr.startswith(str(value))
elif operator == 'ends_with':
return field_expr.endswith(str(value))
elif operator == 'in':
return field_expr.in_([str(v) for v in value])
elif operator == 'not_in':
return ~field_expr.in_([str(v) for v in value])
elif operator == 'regex' or operator == 'not_regex':
# PostgreSQL regex using ~ or !~ operator
case_insensitive = filter_def.get('case_insensitive', False)
regex_op = '~*' if case_insensitive else '~'
if operator == 'not_regex':
regex_op = '!~*' if case_insensitive else '!~'
return text(
f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams(
regex_value=str(value))
# Numeric/Date operators
elif operator == 'gt':
return cast(field_expr, Float) > float(value)
elif operator == 'gte':
return cast(field_expr, Float) >= float(value)
elif operator == 'lt':
return cast(field_expr, Float) < float(value)
elif operator == 'lte':
return cast(field_expr, Float) <= float(value)
elif operator == 'between':
if len(value) == 2:
return cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
return None
elif operator == 'not_between':
if len(value) == 2:
return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
return None
# Null checking
elif operator == 'is_null':
return field_expr.is_(None)
elif operator == 'is_not_null':
return field_expr.isnot(None)
else:
current_app.logger.warning(f"Unknown operator: {operator}")
return None
except (ValueError, TypeError) as e:
current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
return None
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
    """
    Run a similarity search for the query and narrow it with the metadata filter.

    Only embeddings of the latest version of each document are considered
    (highest ``DocumentVersion.id`` per ``doc_id``), restricted to documents of
    this retriever's catalog that are valid today and whose similarity exceeds
    ``self.similarity_threshold``. At most ``self.k`` results are returned,
    ordered by descending similarity.

    Args:
        arguments: Validated RetrieverArguments; ``arguments.query`` is embedded
            and used as the search vector

    Returns:
        List[RetrieverResult]: Retrieved chunks with similarity scores and metadata

    Raises:
        SQLAlchemyError: Re-raised after logging and rolling back the session
        Exception: Any other error is logged and re-raised
    """
    try:
        search_text = arguments.query
        query_embedding = self.embedding_model.embed_query(search_text)

        # Database model that stores the embeddings for the configured model
        db_class = self.embedding_model_class

        # Validity is checked against today's date in UTC
        today = dt.now(tz=tz.utc).date()

        # Latest version per document
        latest_versions = (
            db.session.query(
                DocumentVersion.doc_id,
                func.max(DocumentVersion.id).label('latest_version_id')
            )
            .group_by(DocumentVersion.doc_id)
            .subquery()
        )

        similarity_column = (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
        query_obj = db.session.query(db_class, similarity_column)
        query_obj = query_obj.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
        query_obj = query_obj.join(Document, DocumentVersion.doc_id == Document.id)
        query_obj = query_obj.join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
        query_obj = query_obj.filter(
            or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= today),
            or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= today),
            (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
            Document.catalog_id == self.catalog_id
        )

        # Metadata filter first, then ordering and limit
        query_obj = self._apply_metadata_filter(query_obj, arguments)
        query_obj = query_obj.order_by(desc('similarity')).limit(self.k)

        results = query_obj.all()

        # Convert rows into the standard result format
        processed_results = []
        for embedding_row, score in results:
            # user_metadata may be stored as a JSON string; normalise it to a dict
            user_metadata = self._parse_metadata(embedding_row.document_version.user_metadata)
            result_metadata = RetrieverMetadata(
                document_id=embedding_row.document_version.doc_id,
                version_id=embedding_row.document_version.id,
                document_name=embedding_row.document_version.document.name,
                user_metadata=user_metadata,
            )
            processed_results.append(
                RetrieverResult(
                    id=embedding_row.id,
                    chunk=embedding_row.chunk,
                    similarity=float(score),
                    metadata=result_metadata,
                )
            )

        if self.tuning:
            # literal_binds includes the actual values in the logged SQL
            compiled_query = str(query_obj.statement.compile(
                compile_kwargs={"literal_binds": True}
            ))
            self.log_tuning('retrieve', {
                "arguments": arguments.model_dump(),
                "similarity_threshold": self.similarity_threshold,
                "k": self.k,
                "query": compiled_query,
                "results_count": len(results),
                "processed_results_count": len(processed_results),
            })

        return processed_results

    except SQLAlchemyError as e:
        current_app.logger.error(f'Error in Dossier retrieval: {e}')
        db.session.rollback()
        raise
    except Exception as e:
        current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
        raise
# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)

View File

@@ -11,54 +11,25 @@ from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.models.user import Tenant
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from eveai_chat_workers.retrievers.base_retriever import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class StandardRAGRetriever(BaseRetriever):
class RetrieverExecutor(BaseRetriever):
"""Standard RAG retriever implementation"""
def __init__(self, tenant_id: int, retriever_id: int):
super().__init__(tenant_id, retriever_id)
# Set up standard retrieval parameters
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
self.log_tuning("Standard RAG retriever initialized", {
"similarity_threshold": self.similarity_threshold,
"k": self.k
})
@property
def type(self) -> str:
return "STANDARD_RAG"
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
@property
def type_version(self) -> str:
return "1.0"
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""
@@ -72,10 +43,10 @@ class StandardRAGRetriever(BaseRetriever):
List[RetrieverResult]: List of retrieved documents with similarity scores
"""
try:
query = arguments.query
question = arguments.question
# Get query embedding
query_embedding = self.embedding_model.embed_query(query)
query_embedding = self.embedding_model.embed_query(question)
# Get the appropriate embedding database model
db_class = self.embedding_model_class
@@ -93,6 +64,9 @@ class StandardRAGRetriever(BaseRetriever):
.subquery()
)
similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
k = self.retriever.configuration.get('es_k', 8)
# Main query
query_obj = (
db.session.query(
@@ -106,11 +80,11 @@ class StandardRAGRetriever(BaseRetriever):
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self.catalog_id
)
.order_by(desc('similarity'))
.limit(self.k)
.limit(k)
)
results = query_obj.all()
@@ -160,5 +134,3 @@ class StandardRAGRetriever(BaseRetriever):
raise
# Register the retriever type
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)

View File

@@ -1,20 +0,0 @@
from typing import Dict, Type
from .base import BaseRetriever
class RetrieverRegistry:
    """Class-level lookup table from retriever type names to retriever classes."""

    # Shared by all users of the registry; filled through register()
    _registry: Dict[str, Type[BaseRetriever]] = {}

    @classmethod
    def register(cls, retriever_type: str, retriever_class: Type[BaseRetriever]):
        """Store ``retriever_class`` under ``retriever_type``, replacing any earlier entry."""
        cls._registry.update({retriever_type: retriever_class})

    @classmethod
    def get_retriever_class(cls, retriever_type: str) -> Type[BaseRetriever]:
        """
        Look up the class registered for ``retriever_type``.

        Raises:
            ValueError: If no class was registered under that type
        """
        if retriever_type in cls._registry:
            return cls._registry[retriever_type]
        raise ValueError(f"Unknown retriever type: {retriever_type}")

View File

@@ -34,6 +34,8 @@ class RetrieverArguments(BaseModel):
type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")
type_version: str = Field(..., description="Version of retriever type (e.g. 1.0)")
question: str = Field(..., description="Question to retrieve answers for")
# Allow any additional fields
model_config = {
"extra": "allow"

View File

@@ -8,8 +8,7 @@ from common.models.interaction import SpecialistRetriever, Specialist
from common.models.user import Tenant
from common.utils.execution_progress import ExecutionProgressTracker
from config.logging_config import TuningLogger
from eveai_chat_workers.retrievers.base import BaseRetriever
from eveai_chat_workers.retrievers.registry import RetrieverRegistry
from eveai_chat_workers.retrievers.base_retriever import BaseRetriever, get_retriever_class
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
@@ -56,20 +55,16 @@ class BaseSpecialistExecutor(ABC):
for spec_retriever in specialist_retrievers:
# Get retriever configuration from database
retriever = spec_retriever.retriever
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
retriever_executor_class = get_retriever_class(retriever.type, retriever.type_version)
self.log_tuning("_initialize_retrievers", {
"Retriever id": spec_retriever.retriever_id,
"Retriever Type": retriever.type,
"Retriever Class": str(retriever_class),
"Retriever Version": retriever.type_version,
})
retriever_executor = retriever_executor_class(self.tenant_id, spec_retriever.retriever_id)
# Initialize retriever with its configuration
retrievers.append(
retriever_class(
tenant_id=self.tenant_id,
retriever_id=retriever.id,
)
)
retrievers.append(retriever_executor)
return retrievers
@@ -144,7 +139,7 @@ def get_specialist_class(specialist_type: str, type_version: str):
if partner:
module_path = f"eveai_chat_workers.specialists.{partner}.{specialist_type}.{major_minor}"
else:
module_path = f"eveai_chat_workers.specialists.{specialist_type}.{major_minor}"
module_path = f"eveai_chat_workers.specialists.globals.{specialist_type}.{major_minor}"
current_app.logger.debug(f"Importing specialist class from {module_path}")
module = importlib.import_module(module_path)
return module.SpecialistExecutor

View File

@@ -308,12 +308,12 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
current_retriever_args = retriever_arguments[retriever_id]
if isinstance(retriever_arguments[retriever_id], RetrieverArguments):
updated_args = current_retriever_args.model_dump()
updated_args['query'] = arguments.query
updated_args['question'] = arguments.question
updated_args['language'] = arguments.language
retriever_args = RetrieverArguments(**updated_args)
else:
# Create a new RetrieverArguments instance from the dictionary
current_retriever_args['query'] = arguments.query
current_retriever_args['query'] = arguments.question
retriever_args = RetrieverArguments(**current_retriever_args)
# Each retriever gets its own specific arguments

View File

@@ -6,6 +6,7 @@ from crewai.flow.flow import start, listen, and_
from flask import current_app
from pydantic import BaseModel, Field
from common.services.utils.translation_services import TranslationServices
from common.utils.business_event_context import current_event
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
from eveai_chat_workers.specialists.crewai_base_specialist import CrewAIBaseSpecialistExecutor
@@ -13,6 +14,10 @@ from eveai_chat_workers.specialists.specialist_typing import SpecialistResult, S
from eveai_chat_workers.outputs.globals.rag.rag_v1_0 import RAGOutput
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAICrew, EveAICrewAIFlow, EveAIFlowState
INSUFFICIENT_INFORMATION_MESSAGE = (
"We do not have the necessary information to provide you with the requested answers. "
"Please accept our apologies. Don't hesitate to ask other questions, and I'll do my best to answer them.")
class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
@@ -40,6 +45,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def _config_pydantic_outputs(self):
self._add_pydantic_output("rag_task", RAGOutput, "rag_output")
def _config_state_result_relations(self):
self._add_state_result_relation("rag_output")
def _instantiate_specialist(self):
verbose = self.tuning
@@ -61,40 +69,84 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def execute(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist execution started", {})
flow_inputs = {
"language": arguments.language,
"query": arguments.query,
"context": formatted_context,
"history": self.formatted_history,
"name": self.specialist.configuration.get('name', ''),
"company": self.specialist.configuration.get('company', ''),
}
# crew_results = self.rag_crew.kickoff(inputs=flow_inputs)
# current_app.logger.debug(f"Test Crew Output received: {crew_results}")
flow_results = self.flow.kickoff(inputs=flow_inputs)
current_app.logger.debug(f"Arguments: {arguments.model_dump()}")
current_app.logger.debug(f"Formatted Context: {formatted_context}")
current_app.logger.debug(f"Formatted History: {self._formatted_history}")
current_app.logger.debug(f"Cached Chat Session: {self._cached_session}")
flow_state = self.flow.state
if not self._cached_session.interactions:
specialist_phase = "initial"
else:
specialist_phase = self._cached_session.interactions[-1].specialist_results.get('phase', 'initial')
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
update_data = {}
if flow_state.rag_output: # Fallback
update_data["rag_output"] = flow_state.rag_output
results = None
current_app.logger.debug(f"Specialist Phase: {specialist_phase}")
results = results.model_copy(update=update_data)
match specialist_phase:
case "initial":
results = self.execute_initial_state(arguments, formatted_context, citations)
case "rag":
results = self.execute_rag_state(arguments, formatted_context, citations)
self.log_tuning(f"RAG Specialist execution ended", {"Results": results.model_dump()})
return results
def execute_initial_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
    """
    Handle the first interaction: answer with the translated welcome message
    and move the flow state to the "rag" phase.

    Args:
        arguments: Specialist arguments; ``arguments.language`` is the translation target
        formatted_context: Not used in this phase
        citations: Not used in this phase

    Returns:
        A result object created for this specialist's type and type version
    """
    self.log_tuning("RAG Specialist initial_state execution started", {})

    configured_welcome = self.specialist.configuration.get(
        'welcome_message', 'Welcome! You can start asking questions')
    translated_welcome = TranslationServices.translate(self.tenant_id, configured_welcome, arguments.language)

    self.flow.state.answer = translated_welcome
    self.flow.state.phase = "rag"

    return RAGSpecialistResult.create_for_type(self.type, self.type_version)
def execute_rag_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
    """
    Answer a question in the "rag" phase.

    When retrieval produced context, the flow is kicked off with the question
    and context; if the flow reports insufficient information, its answer is
    replaced by the translated standard message. Without any context the flow
    is skipped and the standard message is used directly.

    Args:
        arguments: Specialist arguments (``language`` and ``question`` are used)
        formatted_context: Retrieved context; falsy when nothing was retrieved
        citations: Not used in this phase

    Returns:
        A result object created for this specialist's type and type version
    """
    self.log_tuning("RAG Specialist rag_state execution started", {})
    insufficient_info_message = TranslationServices.translate(self.tenant_id,
                                                              INSUFFICIENT_INFORMATION_MESSAGE,
                                                              arguments.language)
    if formatted_context:
        flow_inputs = {
            "language": arguments.language,
            "question": arguments.question,
            "context": formatted_context,
            "history": self.formatted_history,
            "name": self.specialist.configuration.get('name', ''),
            "welcome_message": self.specialist.configuration.get('welcome_message', '')
        }
        flow_results = self.flow.kickoff(inputs=flow_inputs)
        if flow_results.rag_output.insufficient_info:
            flow_results.rag_output.answer = insufficient_info_message
        rag_output = flow_results.rag_output
    else:
        rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)

    self.flow.state.rag_output = rag_output
    self.flow.state.answer = rag_output.answer
    self.flow.state.phase = "rag"

    results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
    # The result was previously built but never returned, so execute() received
    # None and failed on results.model_dump().
    return results
class RAGSpecialistInput(BaseModel):
language: Optional[str] = Field(None, alias="language")
query: Optional[str] = Field(None, alias="query")
question: Optional[str] = Field(None, alias="question")
context: Optional[str] = Field(None, alias="context")
citations: Optional[List[int]] = Field(None, alias="citations")
history: Optional[str] = Field(None, alias="history")
name: Optional[str] = Field(None, alias="name")
company: Optional[str] = Field(None, alias="company")
welcome_message: Optional[str] = Field(None, alias="welcome_message")
class RAGSpecialistResult(SpecialistResult):

View File

@@ -0,0 +1,197 @@
import json
from os import wait
from typing import Optional, List
from crewai.flow.flow import start, listen, and_
from flask import current_app
from pydantic import BaseModel, Field
from common.services.utils.translation_services import TranslationServices
from common.utils.business_event_context import current_event
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
from eveai_chat_workers.specialists.crewai_base_specialist import CrewAIBaseSpecialistExecutor
from eveai_chat_workers.specialists.specialist_typing import SpecialistResult, SpecialistArguments
from eveai_chat_workers.outputs.globals.rag.rag_v1_0 import RAGOutput
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAICrew, EveAICrewAIFlow, EveAIFlowState
INSUFFICIENT_INFORMATION_MESSAGE = (
"We do not have the necessary information to provide you with the requested answers. "
"Please accept our apologies. Don't hesitate to ask other questions, and I'll do my best to answer them.")
class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
type: RAG_SPECIALIST
type_version: 1.0
RAG Specialist Executor class
"""
def __init__(self, tenant_id, specialist_id, session_id, task_id, **kwargs):
self.rag_crew = None
super().__init__(tenant_id, specialist_id, session_id, task_id)
@property
def type(self) -> str:
return "RAG_SPECIALIST"
@property
def type_version(self) -> str:
return "1.1"
def _config_task_agents(self):
self._add_task_agent("rag_task", "rag_agent")
def _config_pydantic_outputs(self):
self._add_pydantic_output("rag_task", RAGOutput, "rag_output")
def _config_state_result_relations(self):
self._add_state_result_relation("rag_output")
def _instantiate_specialist(self):
verbose = self.tuning
rag_agents = [self.rag_agent]
rag_tasks = [self.rag_task]
self.rag_crew = EveAICrewAICrew(
self,
"Rag Crew",
agents=rag_agents,
tasks=rag_tasks,
verbose=verbose,
)
self.flow = RAGFlow(
self,
self.rag_crew,
)
def execute(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist execution started", {})
current_app.logger.debug(f"Arguments: {arguments.model_dump()}")
current_app.logger.debug(f"Formatted Context: {formatted_context}")
current_app.logger.debug(f"Formatted History: {self._formatted_history}")
current_app.logger.debug(f"Cached Chat Session: {self._cached_session}")
if not self._cached_session.interactions:
specialist_phase = "initial"
else:
specialist_phase = self._cached_session.interactions[-1].specialist_results.get('phase', 'initial')
results = None
current_app.logger.debug(f"Specialist Phase: {specialist_phase}")
match specialist_phase:
case "initial":
results = self.execute_initial_state(arguments, formatted_context, citations)
case "rag":
results = self.execute_rag_state(arguments, formatted_context, citations)
self.log_tuning(f"RAG Specialist execution ended", {"Results": results.model_dump()})
return results
def execute_initial_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist initial_state execution started", {})
welcome_message = self.specialist.configuration.get('welcome_message', 'Welcome! You can start asking questions')
welcome_message = TranslationServices.translate(self.tenant_id, welcome_message, arguments.language)
self.flow.state.answer = welcome_message
self.flow.state.phase = "rag"
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
return results
def execute_rag_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
    """Answer a question in the "rag" phase.

    When ``formatted_context`` is non-empty the RAG flow is kicked off with
    the question, context, history and specialist configuration. If the flow
    reports insufficient information, its answer is replaced by the
    translated "insufficient information" message. Without context the flow
    is skipped and that message is returned directly. The outcome is written
    to the flow state and the phase stays ``"rag"``. ``citations`` is not
    used here.

    Args:
        arguments: Specialist arguments; ``language`` and ``question`` are read.
        formatted_context: Retrieved context handed to the flow; falsy means
            nothing was retrieved.
        citations: Unused.

    Returns:
        A fresh ``RAGSpecialistResult`` for this specialist type and version.
    """
    self.log_tuning("RAG Specialist rag_state execution started", {})

    insufficient_info_message = TranslationServices.translate(self.tenant_id,
                                                              INSUFFICIENT_INFORMATION_MESSAGE,
                                                              arguments.language)

    if formatted_context:
        flow_inputs = {
            "language": arguments.language,
            "question": arguments.question,
            "context": formatted_context,
            "history": self.formatted_history,
            "name": self.specialist.configuration.get('name', ''),
            "welcome_message": self.specialist.configuration.get('welcome_message', '')
        }

        flow_results = self.flow.kickoff(inputs=flow_inputs)
        rag_output = flow_results.rag_output

        # Replace the model's own wording with the translated standard message.
        if rag_output.insufficient_info:
            rag_output.answer = insufficient_info_message
    else:
        rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)

    self.flow.state.rag_output = rag_output
    self.flow.state.answer = rag_output.answer
    self.flow.state.phase = "rag"

    results = RAGSpecialistResult.create_for_type(self.type, self.type_version)

    # The result was previously built but never returned, so execute() received None.
    return results
class RAGSpecialistInput(BaseModel):
    """Inputs handed to the RAG flow; every field is optional and defaults to None."""

    language: Optional[str] = Field(default=None, alias="language")
    question: Optional[str] = Field(default=None, alias="question")
    context: Optional[str] = Field(default=None, alias="context")
    history: Optional[str] = Field(default=None, alias="history")
    name: Optional[str] = Field(default=None, alias="name")
    welcome_message: Optional[str] = Field(default=None, alias="welcome_message")
class RAGSpecialistResult(SpecialistResult):
    """Specialist result extended with the optional RAG output (alias "Rag Output")."""

    rag_output: Optional[RAGOutput] = Field(default=None, alias="Rag Output")
class RAGFlowState(EveAIFlowState):
    """State of the RAG specialist flow: the validated input and the RAG output."""

    # Validated flow inputs; unset until a kickoff assigns them.
    input: Optional[RAGSpecialistInput] = None

    # Output of the RAG task; unset until the crew has produced a result.
    rag_output: Optional[RAGOutput] = None
class RAGFlow(EveAICrewAIFlow[RAGFlowState]):
    """Flow that runs the RAG crew once and stores its output on the flow state."""

    def __init__(self,
                 specialist_executor: CrewAIBaseSpecialistExecutor,
                 rag_crew: EveAICrewAICrew,
                 **kwargs):
        """Create the flow.

        Args:
            specialist_executor: Executor owning this flow; used for tuning logs.
            rag_crew: Crew that is kicked off to answer the question.
            **kwargs: Forwarded unchanged to the base flow constructor.
        """
        super().__init__(specialist_executor, "RAG Specialist Flow", **kwargs)
        self.specialist_executor = specialist_executor
        self.rag_crew = rag_crew
        self.exception_raised = False

    @start()
    def process_inputs(self):
        """Entry step of the flow; does no work of its own and returns an empty string."""
        return ""

    @listen(process_inputs)
    async def execute_rag(self):
        """Kick off the RAG crew with the flow input and store its output on the state.

        When the crew output has no pydantic payload, the raw output is parsed
        as JSON and validated as ``RAGOutput`` instead.

        Returns:
            The crew output.

        Raises:
            Exception: Any error from the crew or from parsing its output is
                logged, flagged via ``exception_raised`` and re-raised.
        """
        inputs = self.state.input.model_dump()
        try:
            crew_output = await self.rag_crew.kickoff_async(inputs=inputs)
            self.specialist_executor.log_tuning("RAG Crew Output", crew_output.model_dump())
            output_pydantic = crew_output.pydantic
            if not output_pydantic:
                # Fall back to the raw text, which is expected to be a JSON document.
                raw_json = json.loads(crew_output.raw)
                output_pydantic = RAGOutput.model_validate(raw_json)
            self.state.rag_output = output_pydantic
            return crew_output
        except Exception as e:
            current_app.logger.error(f"CREW rag_crew Kickoff Error: {str(e)}")
            self.exception_raised = True
            # Bare raise keeps the original traceback intact.
            raise

    async def kickoff_async(self, inputs=None):
        """Validate ``inputs`` into the flow state, run the flow and return the state.

        Args:
            inputs: Mapping validated as ``RAGSpecialistInput``.

        Returns:
            The flow state after the run (not the value returned by the base kickoff).
        """
        current_app.logger.debug(f"Async kickoff {self.name}")
        current_app.logger.debug(f"Inputs: {inputs}")
        self.state.input = RAGSpecialistInput.model_validate(inputs)
        await super().kickoff_async(inputs)
        return self.state

View File

@@ -51,6 +51,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def _config_pydantic_outputs(self):
    """Declare the pydantic output wiring for the competencies task.

    Calls ``_add_pydantic_output`` with the task name
    ``"traicie_get_competencies_task"``, the ``Competencies`` model and the
    key ``"competencies"`` (presumably the flow-state attribute that receives
    the output -- confirm against the base class).
    """
    task_name = "traicie_get_competencies_task"
    output_key = "competencies"
    self._add_pydantic_output(task_name, Competencies, output_key)
def _config_state_result_relations(self):
    """Register ``"competencies"`` as a state/result relation.

    Delegates to ``_add_state_result_relation``; presumably this tells the
    base executor to copy that flow-state field into the specialist result
    (confirm against the base class).
    """
    relation_key = "competencies"
    self._add_state_result_relation(relation_key)
def _instantiate_specialist(self):
verbose = self.tuning
@@ -83,13 +86,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
flow_results = self.flow.kickoff(inputs=flow_inputs)
flow_state = self.flow.state
results = RoleDefinitionSpecialistResult.create_for_type(self.type, self.type_version)
if flow_state.competencies:
results.competencies = flow_state.competencies
self.create_selection_specialist(arguments, flow_state.competencies)
self.create_selection_specialist(arguments, self.flow.state.competencies)
self.log_tuning(f"Traicie Role Definition Specialist execution ended", {"Results": results.model_dump()})

View File

@@ -55,7 +55,7 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
def __init__(self, tenant_id, specialist_id, session_id, task_id, **kwargs):
self.role_definition_crew = None
self.rag_crew = None
super().__init__(tenant_id, specialist_id, session_id, task_id)
@@ -407,8 +407,7 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
if rag_output.rag_output.insufficient_info:
rag_output.rag_output.answer = insufficient_info_message
else:
rag_output = RAGOutput(answer=insufficient_info_message,
insufficient_info=True)
rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)
self.log_tuning(f"RAG Specialist execution ended", {"Results": rag_output.model_dump()})