- Revisiting RAG_SPECIALIST

- Adapt Catalogs & Retrievers to use specific types, removing tagging_fields
- Adding CrewAI Implementation Guide
This commit is contained in:
Josako
2025-07-08 15:54:16 +02:00
parent 33b5742d2f
commit 509ee95d81
32 changed files with 997 additions and 825 deletions

View File

@@ -11,6 +11,7 @@ class Catalog(db.Model):
name = db.Column(db.String(50), nullable=False, unique=True)
description = db.Column(db.Text, nullable=True)
type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
min_chunk_size = db.Column(db.Integer, nullable=True, default=1500)
max_chunk_size = db.Column(db.Integer, nullable=True, default=2500)
@@ -26,6 +27,20 @@ class Catalog(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'description': self.description,
'type': self.type,
'type_version': self.type_version,
'min_chunk_size': self.min_chunk_size,
'max_chunk_size': self.max_chunk_size,
'user_metadata': self.user_metadata,
'system_metadata': self.system_metadata,
'configuration': self.configuration,
}
class Processor(db.Model):
id = db.Column(db.Integer, primary_key=True)

View File

@@ -47,101 +47,101 @@ class TenantServices:
current_app.logger.error(f"Error associating tenant {tenant_id} with partner: {str(e)}")
raise e
@staticmethod
def get_available_types_for_tenant(tenant_id: int, config_type: str) -> Dict[str, Dict[str, str]]:
"""
Get available configuration types for a tenant based on partner relationships
@staticmethod
def get_available_types_for_tenant(tenant_id: int, config_type: str) -> Dict[str, Dict[str, str]]:
"""
Get available configuration types for a tenant based on partner relationships
Args:
tenant_id: The tenant ID
config_type: The configuration type ('specialists', 'agents', 'tasks', etc.)
Args:
tenant_id: The tenant ID
config_type: The configuration type ('specialists', 'agents', 'tasks', etc.)
Returns:
Dictionary of available types for the tenant
"""
# Get the appropriate cache handler based on config_type
cache_handler = None
if config_type == 'specialists':
cache_handler = cache_manager.specialists_types_cache
elif config_type == 'agents':
cache_handler = cache_manager.agents_types_cache
elif config_type == 'tasks':
cache_handler = cache_manager.tasks_types_cache
elif config_type == 'tools':
cache_handler = cache_manager.tools_types_cache
else:
raise ValueError(f"Unsupported config type: {config_type}")
Returns:
Dictionary of available types for the tenant
"""
# Get the appropriate cache handler based on config_type
cache_handler = None
if config_type == 'specialists':
cache_handler = cache_manager.specialists_types_cache
elif config_type == 'agents':
cache_handler = cache_manager.agents_types_cache
elif config_type == 'tasks':
cache_handler = cache_manager.tasks_types_cache
elif config_type == 'tools':
cache_handler = cache_manager.tools_types_cache
elif config_type == 'catalogs':
cache_handler = cache_manager.catalogs_types_cache
elif config_type == 'retrievers':
cache_handler = cache_manager.retrievers_types_cache
else:
raise ValueError(f"Unsupported config type: {config_type}")
# Get all types with their metadata (including partner info)
all_types = cache_handler.get_types()
# Get all types with their metadata (including partner info)
all_types = cache_handler.get_types()
# Filter to include:
# 1. Types with no partner (global)
# 2. Types with partners that have a SPECIALIST_SERVICE relationship with this tenant
available_partners = TenantServices.get_tenant_partner_names(tenant_id)
# Filter to include:
# 1. Types with no partner (global)
# 2. Types with partners that have a SPECIALIST_SERVICE relationship with this tenant
available_partners = TenantServices.get_tenant_partner_specialist_denominators(tenant_id)
available_types = {
type_id: info for type_id, info in all_types.items()
if info.get('partner') is None or info.get('partner') in available_partners
}
available_types = {
type_id: info for type_id, info in all_types.items()
if info.get('partner') is None or info.get('partner') in available_partners
}
return available_types
return available_types
@staticmethod
def get_tenant_partner_names(tenant_id: int) -> List[str]:
"""
Get names of partners that have a SPECIALIST_SERVICE relationship with this tenant
@staticmethod
def get_tenant_partner_specialist_denominators(tenant_id: int) -> List[str]:
"""
Get names of partners that have a SPECIALIST_SERVICE relationship with this tenant, which can be used for
filtering configurations.
Args:
tenant_id: The tenant ID
Args:
tenant_id: The tenant ID
Returns:
List of partner names (tenant names)
"""
# Find all PartnerTenant relationships for this tenant
partner_names = []
try:
# Get all partner services of type SPECIALIST_SERVICE
specialist_services = (
Returns:
List of partner specialist denominators (taken from the partner service configuration)
"""
# Find all PartnerTenant relationships for this tenant
partner_service_denominators = []
try:
# Get all partner services of type SPECIALIST_SERVICE
specialist_services = (
PartnerService.query
.filter_by(type='SPECIALIST_SERVICE')
.all()
)
if not specialist_services:
return []
# Find tenant relationships with these services
partner_tenants = (
PartnerTenant.query
.filter_by(tenant_id=tenant_id)
.filter(PartnerTenant.partner_service_id.in_([svc.id for svc in specialist_services]))
.all()
)
# Get the partner names (their tenant names)
for pt in partner_tenants:
partner_service = (
PartnerService.query
.filter_by(type='SPECIALIST_SERVICE')
.all()
.filter_by(id=pt.partner_service_id)
.first()
)
if not specialist_services:
return []
if partner_service:
partner_service_denominators.append(partner_service.configuration.get("specialist_denominator", ""))
# Find tenant relationships with these services
partner_tenants = (
PartnerTenant.query
.filter_by(tenant_id=tenant_id)
.filter(PartnerTenant.partner_service_id.in_([svc.id for svc in specialist_services]))
.all()
)
except SQLAlchemyError as e:
current_app.logger.error(f"Database error retrieving partner names: {str(e)}")
# Get the partner names (their tenant names)
for pt in partner_tenants:
partner_service = (
PartnerService.query
.filter_by(id=pt.partner_service_id)
.first()
)
return partner_service_denominators
if partner_service:
partner = Partner.query.get(partner_service.partner_id)
if partner:
# Get the tenant associated with this partner
partner_tenant = Tenant.query.get(partner.tenant_id)
if partner_tenant:
partner_names.append(partner_tenant.name)
except SQLAlchemyError as e:
current_app.logger.error(f"Database error retrieving partner names: {str(e)}")
return partner_names
@staticmethod
def can_use_specialist_type(tenant_id: int, specialist_type: str) -> bool:
@staticmethod
def can_use_specialist_type(tenant_id: int, specialist_type: str) -> bool:
"""
Check if a tenant can use a specific specialist type
@@ -166,7 +166,7 @@ class TenantServices:
# If it's a partner-specific specialist, check if tenant has access
partner_name = specialist_def.get('partner')
available_partners = TenantServices.get_tenant_partner_names(tenant_id)
available_partners = TenantServices.get_tenant_partner_specialist_denominators(tenant_id)
return partner_name in available_partners

View File

@@ -332,24 +332,22 @@ class BaseConfigTypesCacheHandler(CacheHandler[Dict[str, Any]]):
"""
return isinstance(value, dict) # Cache all dictionaries
def _load_type_definitions(self) -> Dict[str, Dict[str, str]]:
def _load_type_definitions(self) -> Dict[str, Dict[str, Any]]:
"""Load type definitions from the corresponding type_defs module"""
if not self._types_module:
raise ValueError("_types_module must be set by subclass")
type_definitions = {
type_id: {
'name': info['name'],
'description': info['description'],
'partner': info.get('partner') # Include partner info if available
}
for type_id, info in self._types_module.items()
}
type_definitions = {}
for type_id, info in self._types_module.items():
# Copy all fields from the type definition
type_definitions[type_id] = {}
for key, value in info.items():
type_definitions[type_id][key] = value
return type_definitions
def get_types(self) -> Dict[str, Dict[str, str]]:
"""Get dictionary of available types with name and description"""
def get_types(self) -> Dict[str, Dict[str, Any]]:
"""Get dictionary of available types with all defined properties"""
result = self.get(
lambda type_name: self._load_type_definitions(),
type_name=f'{self.config_type}_types',

View File

@@ -0,0 +1,19 @@
version: "1.0.0"
name: "Role Definition Catalog"
description: "A Catalog containing information specific to a specific role"
configuration:
tagging_fields:
role_identification:
type: "string"
required: true
description: "A unique identification for the role"
document_type:
type: "enum"
required: true
description: "Type of document"
allowed_values: [ "Intake", "Vacancy Text", "Additional Information" ]
document_version_configurations: ["tagging_fields"]
metadata:
author: "Josako"
date_added: "2025-07-07"
description: "A Catalog containing information specific to a specific role"

View File

@@ -3,7 +3,7 @@ name: "Standard RAG Retriever"
configuration:
es_k:
name: "es_k"
type: "int"
type: "integer"
description: "K-value to retrieve embeddings (max embeddings retrieved)"
required: true
default: 8
@@ -13,12 +13,7 @@ configuration:
description: "Similarity threshold for retrieving embeddings"
required: true
default: 0.3
arguments:
query:
name: "query"
type: "str"
description: "Query to retrieve embeddings"
required: True
arguments: {}
metadata:
author: "Josako"
date_added: "2025-01-24"

View File

@@ -0,0 +1,26 @@
version: "1.0.0"
name: "Retrieves role information for a specific role"
configuration:
es_k:
name: "es_k"
type: "integer"
description: "K-value to retrieve embeddings (max embeddings retrieved)"
required: true
default: 8
es_similarity_threshold:
name: "es_similarity_threshold"
type: "float"
description: "Similarity threshold for retrieving embeddings"
required: true
default: 0.3
arguments:
role_identification:
name: "Role Identification"
type: "string"
description: "The role information needs to be retrieved for"
required: true
metadata:
author: "Josako"
date_added: "2025-07-07"
changes: "Initial version"
description: "Retrieves role information for a specific role"

View File

@@ -1,9 +1,9 @@
version: "1.0.0"
name: "DOSSIER Retriever"
name: "Retrieves vacancy text for a specific role"
configuration:
es_k:
name: "es_k"
type: "int"
type: "integer"
description: "K-value to retrieve embeddings (max embeddings retrieved)"
required: true
default: 8
@@ -13,24 +13,19 @@ configuration:
description: "Similarity threshold for retrieving embeddings"
required: true
default: 0.3
tagging_fields_filter:
name: "Tagging Fields Filter"
type: "tagging_fields_filter"
description: "Filter JSON to retrieve a subset of documents"
required: true
dynamic_arguments:
name: "Dynamic Arguments"
type: "dynamic_arguments"
description: "dynamic arguments used in the filter"
required: false
arguments:
query:
name: "query"
type: "str"
type: "string"
description: "Query to retrieve embeddings"
required: True
required: true
role_identification:
name: "Role Identification"
type: "string"
description: "The role information needs to be retrieved for"
required: true
metadata:
author: "Josako"
date_added: "2025-03-11"
date_added: "2025-07-07"
changes: "Initial version"
description: "Retrieving all embeddings conform the query and the tagging fields filter"
description: "Retrieves vacancy text for a specific role"

View File

@@ -19,11 +19,6 @@ arguments:
type: "str"
description: "Language code to be used for receiving questions and giving answers"
required: true
query:
name: "query"
type: "str"
description: "Query or response to process"
required: true
results:
rag_output:
answer:

View File

@@ -0,0 +1,49 @@
version: "1.0.0"
name: "RAG Specialist"
framework: "crewai"
chat: true
configuration:
name:
name: "name"
type: "str"
description: "The name the specialist is called upon."
required: true
welcome_message:
name: "Welcome Message"
type: "string"
description: "Welcome Message to be given to the end user"
required: false
arguments:
language:
name: "Language"
type: "str"
description: "Language code to be used for receiving questions and giving answers"
required: true
results:
rag_output:
answer:
name: "answer"
type: "str"
description: "Answer to the query"
required: true
citations:
name: "citations"
type: "List[str]"
description: "List of citations"
required: false
insufficient_info:
name: "insufficient_info"
type: "bool"
description: "Whether or not the query is insufficient info"
required: true
agents:
- type: "RAG_AGENT"
version: "1.1"
tasks:
- type: "RAG_TASK"
version: "1.1"
metadata:
author: "Josako"
date_added: "2025-01-08"
changes: "Initial version"
description: "A Specialist that performs Q&A activities"

View File

@@ -4,8 +4,9 @@ CATALOG_TYPES = {
"name": "Standard Catalog",
"description": "A Catalog with information in Evie's Library, to be considered as a whole",
},
"DOSSIER_CATALOG": {
"name": "Dossier Catalog",
"description": "A Catalog with information in Evie's Library in which several Dossiers can be stored",
"TRAICIE_ROLE_DEFINITION_CATALOG": {
"name": "Role Definition Catalog",
"description": "A Catalog with information about roles, to be considered as a whole",
"partner": "traicie"
},
}

View File

@@ -4,8 +4,16 @@ RETRIEVER_TYPES = {
"name": "Standard RAG Retriever",
"description": "Retrieving all embeddings from the catalog conform the query",
},
"DOSSIER_RETRIEVER": {
"name": "Retriever for managing DOSSIER catalogs",
"description": "Retrieving filtered embeddings from the catalog conform the query",
"TRAICIE_ROLE_DEFINITION_BY_ROLE_IDENTIFICATION": {
"name": "Traicie Role Definition Retriever by Role Identification",
"description": "Retrieves relevant role information for a given role",
"partner": "traicie",
"valid_catalog_types": ["TRAICIE_ROLE_DEFINITION_CATALOG"]
},
"TRAICIE_ROLE_DEFINITION_VACANCY_TEXT": {
"name": "Traicie Role Definition Vacancy Text Retriever",
"description": "Retrieves vacancy text for a given role",
"partner": "traicie",
"valid_catalog_types": ["TRAICIE_ROLE_DEFINITION_CATALOG"]
}
}

View File

@@ -1,7 +1,7 @@
# Specialist Types
SPECIALIST_TYPES = {
"STANDARD_RAG_SPECIALIST": {
"name": "Q&A RAG Specialist",
"name": "Standard RAG Specialist",
"description": "Standard Q&A through RAG Specialist",
},
"RAG_SPECIALIST": {

View File

@@ -2,31 +2,91 @@
## Name Sensitivity
A lot of the required functionality to implement specialists has been automated. This automation is based on naming
conventions. So ... names of variables, attributes, ... need to be precise, or you'll get problems.
conventions. So ... names of variables, attributes, ... need to be precise, or you'll run into problems.
## Base Class: CrewAIBaseSpecialistExecutor
Inherit your SpecialistExecutor class from the base class CrewAIBaseSpecialistExecutor
The base class for defining new CrewAI based Specialists is CrewAIBaseSpecialistExecutor. This class implements a lot of
functionality out of the box, making the full development process easier to manage:
### Conventions
- tasks are referenced by using the lower case name of the configured task
- agents likewise
- Before the specialist execution
- Retrieval of context (RAG Retrieval)
- Build up of historic context (memory)
- Initialisation of Agents, Tasks and tools (defined in the specialist configuration)
- Initialisation of specialist state, based on historic context
- During specialist execution
- Updates to the history (ChatSession and Interaction)
- Formatting of the results
### Specialise the __init__ method
It enables the following functionality:
- Define the crews you want to use in your specialist implementation
- Do other initialisations you require
- Call super
- Logging when requested (_log_tuning)
- Sending progress updates (ept)
- ...
### Type and Typeversion properties
### Naming Conventions
- tasks are referenced by using the lower-case name of the configured task. Their names should always end in "_task"
- agents likewise, but their names should end in "_agent"
- tools likewise, but their names will end in "_tools" (see the sketch below)
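As a hedged illustration of these conventions, the helper below derives the reference name for a configured type; it is a standalone sketch, not the base class's actual resolution logic:
```python
def reference_name(configured_type: str, kind: str) -> str:
    """Illustrative only: derive the attribute name used to reference a configured item."""
    name = configured_type.lower()
    suffix = {"task": "_task", "agent": "_agent", "tool": "_tools"}[kind]
    return name if name.endswith(suffix) else f"{name}{suffix}"


# A specialist configured with RAG_TASK / RAG_AGENT is referenced as:
assert reference_name("RAG_TASK", "task") == "rag_task"
assert reference_name("RAG_AGENT", "agent") == "rag_agent"
```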
## Implementation
### Step 1 - Code location
The implementation of a specialist should be placed in the specialists folder. If the specialist is a global specialist
(i.e. it can be used by all tenants), it will be in the globals folder. If it is created for a specific partner, we will
place it in the folder for that partner (lower_case)
The naming of the implementation is dependent on the version of the specialist we are creating. If we implement a
version 1.0, the implementation will be called "1_0.py". This implementation is also used for specialists with different
patch versions (e.g. 1.0.1, 1.0.4, ...)
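A hedged sketch of how the implementation module path could be derived from these rules; the folder names are assumptions modelled on the retriever loader shown later in this commit, not the actual specialist loader:
```python
from typing import Optional


def specialist_module_path(specialist_type: str, type_version: str, partner: Optional[str]) -> str:
    """Illustrative only: versions 1.0.1 and 1.0.4 both map to the '1_0' implementation."""
    major_minor = "_".join(type_version.split(".")[:2])
    scope = partner.lower() if partner else "globals"
    return f"specialists.{scope}.{specialist_type}.{major_minor}"


assert specialist_module_path("RAG_SPECIALIST", "1.0.4", None).endswith("globals.RAG_SPECIALIST.1_0")
```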
### Step 2: type and type_version properties
- Adapt the type and type_version properties to define the correct specialist. This refers to the actual specialist configuration! (See the sketch below.)
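A minimal sketch of Step 2, assuming type and type_version are exposed as read-only properties on the executor subclass; the class name below is a placeholder, and in the real code you would inherit from CrewAIBaseSpecialistExecutor:
```python
class RagSpecialistExecutor:  # placeholder; inherit from CrewAIBaseSpecialistExecutor in real code
    @property
    def type(self) -> str:
        # Must match the specialist type in the configuration
        return "RAG_SPECIALIST"

    @property
    def type_version(self) -> str:
        # Major.minor(.patch) of the specialist configuration this implementation targets
        return "1.1.0"
```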
### Implement _config_task_agents
### Step 3: Specialist Setup
This method links the tasks to the agents that will perform LLM interactions. You can call the method _add_task_agent
to link both.
#### Specialising init
### Implement _config_pydantic_outputs
Use specialisation of the __init__ method to define the crews you require and, if needed, additional initialisation code.
This method links the tasks to their pydantic outputs, and their name in the state class.
#### configuring the specialist
- _config_task_agents
- Link each task to the agent that will perform the task
- use _add_task_agent for each of the tasks
- _config_pydantic_outputs
- Link each task to a specific output (Pydantic)
- use _add_pydantic_output for each of the tasks
- _config_state_result_relations
- Ensure that state can be transferred automatically to the result
- use _add_state_result_relation to add such a relation
- when you give the attributes in the state and the result the same names, this becomes quite obvious and easy to maintain (see the sketch below)
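The sketch below shows how these hooks could fit together. The registration helpers (_add_task_agent, _add_pydantic_output, _add_state_result_relation) are named in this guide, but their signatures are assumptions, so a small stub base class is used to keep the example self-contained:
```python
class _StubBase:
    """Stand-in for CrewAIBaseSpecialistExecutor so this sketch runs on its own."""

    def __init__(self):
        self._task_agents, self._outputs, self._relations = {}, {}, {}

    def _add_task_agent(self, task_name, agent_name):
        self._task_agents[task_name] = agent_name

    def _add_pydantic_output(self, task_name, output_cls, state_attr):
        self._outputs[task_name] = (output_cls, state_attr)

    def _add_state_result_relation(self, state_attr, result_attr):
        self._relations[state_attr] = result_attr


class RagSpecialistExecutor(_StubBase):
    def __init__(self):
        super().__init__()  # always call super
        # define the crews you need here, e.g. self.rag_crew = ...

    def _config_task_agents(self):
        self._add_task_agent("rag_task", "rag_agent")                 # link task to agent

    def _config_pydantic_outputs(self):
        self._add_pydantic_output("rag_task", dict, "rag_output")     # output class + name in state

    def _config_state_result_relations(self):
        self._add_state_result_relation("rag_output", "rag_output")   # same name in state and result
```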
### Step 4: Implement specialist execution
This is the entry method invoked to actually implement specialist logic.
#### Available data
- arguments: the arguments passed in the specialist invocation
- formatted_context: the documents retrieved for the specialist
- citations: the citations found in the retrieval process
- self._formatted_history: a build-up of the history of the conversation
- self._cached_session: the complete cached session
- self.flow.state: the current flow state
#### Implementation guidelines
- always use self.flow.state to update elements that need to be available in consecutive calls, or elements you want to persist in the results
- use the phase (defined in the state) to distinguish between phases in the specialist execution
- use "SelectionResult.create_for_type(self.type, self.type_version)" to return results (see the sketch below)
### Step 5: Define Implementation Classes
- Define the Pydantic Output Classes to ensure structured outputs
- Define an Input class, containing all potential inputs required for flows and crews to perform their activities
- Define a FlowState (derived from EveAIFlowState) to maintain state throughout specialist execution
- Define a Result (derived from SpecialistResult) to define all information that needs to be stored in the session
- Define the Flow (derived from EveAICrewAIFlow) with the FlowState class (see the sketch below)
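A minimal sketch of the classes listed in Step 5. The base classes EveAIFlowState, SpecialistResult and EveAICrewAIFlow are named above but not shown here, so plain Pydantic models are used as stand-ins; the field names mirror the RAG Specialist configuration earlier in this commit:
```python
from typing import List, Optional

from pydantic import BaseModel


class RagOutput(BaseModel):
    """Pydantic output class ensuring structured task output."""
    answer: str
    citations: List[str] = []
    insufficient_info: bool = False


class RagSpecialistInput(BaseModel):
    """All potential inputs the flows and crews may need."""
    language: str
    query: Optional[str] = None


class RagFlowState(BaseModel):
    """Stand-in for a FlowState derived from EveAIFlowState."""
    phase: str = "initial"
    rag_output: Optional[RagOutput] = None


class RagSpecialistResult(BaseModel):
    """Stand-in for a Result derived from SpecialistResult."""
    rag_output: Optional[RagOutput] = None


# The Flow itself would derive from EveAICrewAIFlow, parameterised with RagFlowState.
```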

View File

@@ -0,0 +1,118 @@
# Tagging Fields Utilities
## Overview
This document describes the "tagging fields" functionality in the system. This code provides a flexible way to define, validate and document metadata fields for use in various parts of the application.
## Core Concept
Tagging fields are configurable metadata fields used to categorise, filter and organise various objects in the system. The code offers an extensive set of utilities to:
1. Create and validate field definitions
2. Check values against field constraints
3. Generate documentation in various formats
## Main Classes
### TaggingField
The `TaggingField` class represents a single metadata field with:
- **type**: The data type (`string`, `integer`, `float`, `date`, `enum`, `color`)
- **required**: Indicates whether the field is mandatory
- **description**: Description of the field
- **allowed_values**: Possible values for enum types
- **min_value** / **max_value**: Bounds for numeric values
Internal validation ensures that:
- Only valid field types are used
- Enum fields always have allowed values
- Numeric fields have valid bounds
### TaggingFields
The `TaggingFields` class manages a collection of `TaggingField` objects, grouped by name. It provides:
- Conversion from/to dictionaries with `from_dict()` and `to_dict()`
- Central storage for all metadata field configurations
### Constraint Classes
There are several classes that define constraints for each field type:
- **NumericConstraint**: For integer/float values (min/max bounds)
- **StringConstraint**: For text values (length, patterns, forbidden patterns)
- **DateConstraint**: For date values (min/max dates, formats)
- **EnumConstraint**: For enumeration types (allowed values, case sensitivity)
Each constraint class contains a `validate()` method to check whether a value satisfies the constraints (see the sketch below).
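A small, self-contained sketch of how such a constraint check behaves; the real NumericConstraint class may have a different constructor, so a stand-in is used here:
```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _NumericConstraintSketch:
    """Stand-in for NumericConstraint; attributes mirror the min/max bounds above."""
    min_value: Optional[float] = None
    max_value: Optional[float] = None

    def validate(self, value: float) -> bool:
        if self.min_value is not None and value < self.min_value:
            return False
        if self.max_value is not None and value > self.max_value:
            return False
        return True


constraint = _NumericConstraintSketch(min_value=1, max_value=10)
assert constraint.validate(5) and not constraint.validate(42)
```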
### ArgumentDefinition and ArgumentDefinitions
These classes extend the concept to functional arguments:
- **ArgumentDefinition**: Defines a single argument with type and constraints
- **ArgumentDefinitions**: Manages a collection of arguments and provides validation
## Documentation Generation
The module contains functions to automatically generate documentation for tagging fields:
- `generate_field_documentation()`: Main function that generates the documentation
- Supports multiple output formats: Markdown, JSON, YAML
- Offers two version levels: basic and extended
```python
# Example usage
documentation = generate_field_documentation(
    tagging_fields=fields_config,
    format="markdown",
    version="extended"
)
```
## Helper Functions for Pattern Processing
The module also contains utilities for processing regex patterns:
- `patterns_to_json()`: Converts text-area content into a JSON pattern list
- `json_to_patterns()`: Converts JSON back to text-area format
- `json_to_pattern_list()`: Converts JSON to a Python list
## Validation in Forms
In combination with web forms, the `validate_tagging_fields()` function is used to validate form input against the tagging fields schema. This ensures that only valid configurations are stored (see the sketch below).
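A hedged sketch of this validation idea; the real `validate_tagging_fields()` signature is not shown in this document, so the stand-in below only checks required fields and enum values against a fields configuration like the one in the next section:
```python
from typing import Dict, List


def _validate_tagging_fields_sketch(form_data: Dict, fields_config: Dict) -> List[str]:
    """Stand-in for validate_tagging_fields(); returns a list of error messages."""
    errors = []
    for name, field in fields_config.items():
        value = form_data.get(name)
        if field.get("required") and value in (None, ""):
            errors.append(f"{name} is required")
        elif field.get("type") == "enum" and value is not None \
                and value not in field.get("allowed_values", []):
            errors.append(f"{name} must be one of {field['allowed_values']}")
    return errors


config = {"priority": {"type": "enum", "required": True, "allowed_values": ["low", "medium", "high"]}}
print(_validate_tagging_fields_sketch({"priority": "urgent"}, config))  # -> one error message
```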
## Example Implementation
```python
# Example tagging fields configuration
fields_config = {
    "priority": {
        "type": "enum",
        "required": True,
        "description": "Priority of the item",
        "allowed_values": ["low", "medium", "high"]
    },
    "deadline": {
        "type": "date",
        "required": False,
        "description": "Deadline for completion"
    },
    "tags": {
        "type": "string",
        "required": False,
        "description": "Search tags for the item"
    }
}

# Create a TaggingFields object
fields = TaggingFields.from_dict(fields_config)

# Generate documentation
docs = generate_field_documentation(fields_config, format="markdown")
```
## Conclusion
The tagging fields functionality provides a powerful and flexible way to manage metadata in the application. By using strong typing and validation, the system ensures that data remains consistent and reliable, while staying adaptable to different use cases.

View File

@@ -11,9 +11,9 @@ from wtforms_sqlalchemy.fields import QuerySelectField
from common.extensions import cache_manager
from common.models.document import Catalog
from common.services.user import TenantServices
from common.utils.form_assistants import validate_json
from config.type_defs.catalog_types import CATALOG_TYPES
from config.type_defs.processor_types import PROCESSOR_TYPES
from .dynamic_form_base import DynamicFormBase
@@ -45,7 +45,11 @@ class CatalogForm(FlaskForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()]
types_dict = cache_manager.catalogs_types_cache.get_types()
# Dynamically populate the 'type' field using the constructor
tenant_id = session.get('tenant').get('id')
choices = TenantServices.get_available_types_for_tenant(tenant_id, "catalogs")
self.type.choices = [(key, value['name']) for key, value in choices.items()]
class EditCatalogForm(DynamicFormBase):
@@ -55,6 +59,7 @@ class EditCatalogForm(DynamicFormBase):
# Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
type = StringField('Catalog Type', validators=[DataRequired()], render_kw={'readonly': True})
type_version = StringField('Catalog Type Version', validators=[DataRequired()], render_kw={'readonly': True})
# Selection fields for processing & creating embeddings
min_chunk_size = IntegerField('Minimum Chunk Size (2000)', validators=[NumberRange(min=0), Optional()],
@@ -130,9 +135,21 @@ class RetrieverForm(FlaskForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
types_dict = cache_manager.retrievers_types_cache.get_types()
tenant_id = session.get('tenant').get('id')
choices = TenantServices.get_available_types_for_tenant(tenant_id, "retrievers")
current_app.logger.debug(f"Potential choices: {choices}")
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in types_dict.items()]
type_choices = []
for key, value in choices.items():
valid_catalog_types = value.get('valid_catalog_types', None)
if valid_catalog_types:
catalog_type = session.get('catalog').get('type')
current_app.logger.debug(f"Check {catalog_type} in {valid_catalog_types}")
if catalog_type in valid_catalog_types:
type_choices.append((key, value['name']))
else: # Retriever type is valid for all catalog types
type_choices.append((key, value['name']))
self.type.choices = type_choices
class EditRetrieverForm(DynamicFormBase):
@@ -141,6 +158,7 @@ class EditRetrieverForm(DynamicFormBase):
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
type = StringField('Processor Type', validators=[DataRequired()], render_kw={'readonly': True})
type_version = StringField('Retriever Type Version', validators=[DataRequired()], render_kw={'readonly': True})
tuning = BooleanField('Enable Tuning', default=False)
# Metadata fields

View File

@@ -50,7 +50,7 @@ def before_request():
current_app.logger.error(f'Error switching schema in Document Blueprint: {e}')
raise
# Catalog Management ------------------------------------------------------------------------------
@document_bp.route('/catalog', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def catalog():
@@ -62,6 +62,11 @@ def catalog():
form.populate_obj(new_catalog)
set_logging_information(new_catalog, dt.now(tz.utc))
# Set Type information, including the configuration for backward compatibility
new_catalog.type_version = cache_manager.catalogs_version_tree_cache.get_latest_version(new_catalog.type)
new_catalog.configuration = (cache_manager.catalogs_config_cache
.get_config(new_catalog.type, new_catalog.type_version).get("configuration", {}))
try:
db.session.add(new_catalog)
db.session.commit()
@@ -110,6 +115,7 @@ def handle_catalog_selection():
current_app.logger.info(f'Setting session catalog to {catalog.name}')
session['catalog_id'] = catalog_id
session['catalog_name'] = catalog.name
session['catalog'] = catalog.to_dict()
elif action == 'edit_catalog':
return redirect(prefixed_url_for('document_bp.edit_catalog', catalog_id=catalog_id))
@@ -124,15 +130,14 @@ def edit_catalog(catalog_id):
form = EditCatalogForm(request.form, obj=catalog)
full_config = cache_manager.catalogs_config_cache.get_config(catalog.type)
form.add_dynamic_fields("configuration", full_config, catalog.configuration)
if request.method == 'POST' and form.validate_on_submit():
form.populate_obj(catalog)
catalog.configuration = form.get_dynamic_data('configuration')
update_logging_information(catalog, dt.now(tz.utc))
try:
db.session.add(catalog)
db.session.commit()
if session.get('catalog_id') == catalog_id:
session['catalog'] = catalog.to_dict()
flash('Catalog successfully updated!', 'success')
current_app.logger.info(f'Catalog {catalog.name} successfully updated for tenant {tenant_id}')
except SQLAlchemyError as e:
@@ -147,6 +152,7 @@ def edit_catalog(catalog_id):
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
# Processor Management ----------------------------------------------------------------------------
@document_bp.route('/processor', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def processor():
@@ -264,6 +270,7 @@ def handle_processor_selection():
return redirect(prefixed_url_for('document_bp.processors'))
# Retriever Management ----------------------------------------------------------------------------
@document_bp.route('/retriever', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def retriever():
@@ -306,8 +313,7 @@ def edit_retriever(retriever_id):
form = EditRetrieverForm(request.form, obj=retriever)
retriever_config = cache_manager.retrievers_config_cache.get_config(retriever.type, retriever.type_version)
configuration_config = retriever_config.get("configuration")
form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
form.add_dynamic_fields("configuration", retriever_config, retriever.configuration)
if form.validate_on_submit():
# Update basic fields
@@ -375,6 +381,7 @@ def handle_retriever_selection():
return redirect(prefixed_url_for('document_bp.retrievers'))
# Document Management -----------------------------------------------------------------------------
@document_bp.route('/add_document', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def add_document():
@@ -730,6 +737,53 @@ def handle_full_document_selection():
return redirect(prefixed_url_for('document_bp.full_documents'))
@document_bp.route('/document_versions_list', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def document_versions_list():
view = DocumentVersionListView(DocumentVersion, 'document/document_versions_list_view.html', per_page=20)
return view.get()
@document_bp.route('/view_document_version_markdown/<int:document_version_id>', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def view_document_version_markdown(document_version_id):
current_app.logger.debug(f'Viewing document version markdown {document_version_id}')
# Retrieve document version
document_version = DocumentVersion.query.get_or_404(document_version_id)
# retrieve tenant information
tenant_id = session.get('tenant').get('id')
try:
# Generate markdown filename
markdown_filename = f"{document_version.id}.md"
markdown_object_name = minio_client.generate_object_name(document_version.doc_id, document_version.language,
document_version.id, markdown_filename)
current_app.logger.debug(f'Markdown object name: {markdown_object_name}')
# Download actual markdown file
file_data = minio_client.download_document_file(
tenant_id,
document_version.bucket_name,
markdown_object_name,
)
# Decode the binary data to UTF-8 text
markdown_content = file_data.decode('utf-8')
current_app.logger.debug(f'Markdown content: {markdown_content}')
# Render the template with the markdown content
return render_template(
'document/view_document_version_markdown.html',
document_version=document_version,
markdown_content=markdown_content
)
except Exception as e:
current_app.logger.error(f"Error retrieving markdown for document version {document_version_id}: {str(e)}")
flash(f"Error retrieving processed document: {str(e)}", "danger")
return redirect(prefixed_url_for('document_bp.document_versions'))
# Utilities ---------------------------------------------------------------------------------------
@document_bp.route('/library_operations', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def library_operations():
@@ -834,52 +888,6 @@ def create_default_rag_library():
return redirect(prefixed_url_for('document_bp.library_operations'))
@document_bp.route('/document_versions_list', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def document_versions_list():
view = DocumentVersionListView(DocumentVersion, 'document/document_versions_list_view.html', per_page=20)
return view.get()
@document_bp.route('/view_document_version_markdown/<int:document_version_id>', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def view_document_version_markdown(document_version_id):
current_app.logger.debug(f'Viewing document version markdown {document_version_id}')
# Retrieve document version
document_version = DocumentVersion.query.get_or_404(document_version_id)
# retrieve tenant information
tenant_id = session.get('tenant').get('id')
try:
# Generate markdown filename
markdown_filename = f"{document_version.id}.md"
markdown_object_name = minio_client.generate_object_name(document_version.doc_id, document_version.language,
document_version.id, markdown_filename)
current_app.logger.debug(f'Markdown object name: {markdown_object_name}')
# Download actual markdown file
file_data = minio_client.download_document_file(
tenant_id,
document_version.bucket_name,
markdown_object_name,
)
# Decode the binary data to UTF-8 text
markdown_content = file_data.decode('utf-8')
current_app.logger.debug(f'Markdown content: {markdown_content}')
# Render the template with the markdown content
return render_template(
'document/view_document_version_markdown.html',
document_version=document_version,
markdown_content=markdown_content
)
except Exception as e:
current_app.logger.error(f"Error retrieving markdown for document version {document_version_id}: {str(e)}")
flash(f"Error retrieving processed document: {str(e)}", "danger")
return redirect(prefixed_url_for('document_bp.document_versions'))
def refresh_all_documents():
for doc in Document.query.all():
refresh_document(doc.id)

View File

@@ -40,6 +40,7 @@ class EditPartnerServiceForm(DynamicFormBase):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
type = StringField('Partner Service Type', validators=[DataRequired()], render_kw={'readonly': True})
type_version = StringField('Partner Service Type Version', validators=[DataRequired()], render_kw={'readonly': True})
active = BooleanField('Active', validators=[Optional()], default=True)
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])

View File

@@ -45,6 +45,7 @@ def handle_trigger_action():
return redirect(prefixed_url_for('partner_bp.trigger_actions'))
# Partner Management ------------------------------------------------------------------------------
@partner_bp.route('/partner/<int:partner_id>', methods=['GET', 'POST'])
@roles_accepted('Super User')
def edit_partner(partner_id):
@@ -124,6 +125,7 @@ def handle_partner_selection():
return redirect(prefixed_url_for('partner_bp.partners'))
# Partner Service Management ----------------------------------------------------------------------
@partner_bp.route('/partner_service', methods=['GET', 'POST'])
@roles_accepted('Super User')
def partner_service():
@@ -160,20 +162,12 @@ def partner_service():
@roles_accepted('Super User')
def edit_partner_service(partner_service_id):
partner_service = PartnerService.query.get_or_404(partner_service_id)
partner = session.get('partner', None)
partner_id = session['partner']['id']
current_app.logger.debug(f"Request Type: {request.method}")
form = EditPartnerServiceForm(obj=partner_service)
form = EditPartnerServiceForm(request.form, obj=partner_service)
partner_service_config = cache_manager.partner_services_config_cache.get_config(partner_service.type,
partner_service.type_version)
configuration_config = partner_service_config.get('configuration')
current_app.logger.debug(f"Configuration config for {partner_service.type} {partner_service.type_version}: "
f"{configuration_config}")
form.add_dynamic_fields("configuration", partner_service_config, partner_service.configuration)
permissions_config = partner_service_config.get('permissions')
current_app.logger.debug(f"Permissions config for {partner_service.type} {partner_service.type_version}: "
f"{permissions_config}")
form.add_dynamic_fields("permissions", partner_service_config, partner_service.permissions)
if request.method == 'POST':
@@ -188,9 +182,6 @@ def edit_partner_service(partner_service_id):
current_app.logger.debug(f"Partner Service configuration: {partner_service.configuration}")
current_app.logger.debug(f"Partner Service permissions: {partner_service.permissions}")
# update partner relationship
partner_service.partner_id = partner_id
update_logging_information(partner_service, dt.now(tz.utc))
try:
@@ -258,6 +249,7 @@ def handle_partner_service_selection():
return redirect(prefixed_url_for('partner_bp.partner_services'))
# Utility Functions
def register_partner_from_tenant(tenant_id):
# check if there is already a partner defined for the tenant
partner = Partner.query.filter_by(tenant_id=tenant_id).first()

View File

@@ -36,6 +36,7 @@ def log_after_request(response):
return response
# Tenant Management -------------------------------------------------------------------------------
@user_bp.route('/tenant', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin')
def tenant():
@@ -136,6 +137,114 @@ def edit_tenant(tenant_id):
return render_template('user/tenant.html', form=form, tenant_id=tenant_id)
@user_bp.route('/select_tenant', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin') # Allow both roles
def select_tenant():
filter_form = TenantSelectionForm(request.form)
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
# Start with a base query
query = Tenant.query
current_app.logger.debug("We proberen het scherm op te bouwen")
current_app.logger.debug(f"Session: {session}")
# Apply different filters based on user role
if current_user.has_roles('Partner Admin') and 'partner' in session:
current_app.logger.debug("We zitten in partner mode")
# Get the partner's management service
management_service = next((service for service in session['partner']['services']
if service.get('type') == 'MANAGEMENT_SERVICE'), None)
if management_service:
# Get the partner's own tenant
partner_tenant_id = session['partner']['tenant_id']
# Get tenants managed by this partner through PartnerTenant relationships
managed_tenant_ids = db.session.query(PartnerTenant.tenant_id).filter_by(
partner_service_id=management_service['id']
).all()
# Convert list of tuples to flat list
managed_tenant_ids = [tenant_id for (tenant_id,) in managed_tenant_ids]
# Include partner's own tenant in the list
allowed_tenant_ids = [partner_tenant_id] + managed_tenant_ids
# Filter query to only show allowed tenants
query = query.filter(Tenant.id.in_(allowed_tenant_ids))
current_app.logger.debug("We zitten na partner service selectie")
# Apply form filters (for both Super User and Partner Admin)
if filter_form.validate_on_submit():
if filter_form.types.data:
query = query.filter(Tenant.type.in_(filter_form.types.data))
if filter_form.search.data:
search = f"%{filter_form.search.data}%"
query = query.filter(Tenant.name.ilike(search))
# Finalize query
query = query.order_by(Tenant.name)
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
tenants = pagination.items
rows = prepare_table_for_macro(tenants, [('id', ''), ('name', ''), ('website', ''), ('type', '')])
return render_template('user/select_tenant.html', rows=rows, pagination=pagination, filter_form=filter_form)
@user_bp.route('/handle_tenant_selection', methods=['POST'])
@roles_accepted('Super User', 'Partner Admin')
def handle_tenant_selection():
action = request.form['action']
if action == 'create_tenant':
return redirect(prefixed_url_for('user_bp.tenant'))
tenant_identification = request.form['selected_row']
tenant_id = ast.literal_eval(tenant_identification).get('value')
if not UserServices.can_user_edit_tenant(tenant_id):
current_app.logger.info(f"User not authenticated to edit tenant {tenant_id}.")
flash(f"You are not authenticated to manage tenant {tenant_id}", 'danger')
return redirect(prefixed_url_for('user_bp.select_tenant'))
the_tenant = Tenant.query.get(tenant_id)
# set tenant information in the session
session['tenant'] = the_tenant.to_dict()
# remove catalog-related items from the session
session.pop('catalog_id', None)
session.pop('catalog_name', None)
match action:
case 'edit_tenant':
return redirect(prefixed_url_for('user_bp.edit_tenant', tenant_id=tenant_id))
case 'select_tenant':
return redirect(prefixed_url_for('user_bp.tenant_overview'))
# Add more conditions for other actions
return redirect(prefixed_url_for('select_tenant'))
@user_bp.route('/tenant_overview', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def tenant_overview():
tenant_id = session['tenant']['id']
tenant = Tenant.query.get_or_404(tenant_id)
form = EditTenantForm(obj=tenant)
# Set the value of default_tenant_make_id
if tenant.default_tenant_make_id:
form.default_tenant_make_id.data = str(tenant.default_tenant_make_id)
# Get the name of the default make if it exists
default_make_name = None
if tenant.default_tenant_make:
default_make_name = tenant.default_tenant_make.name
return render_template('user/tenant_overview.html', form=form, default_make_name=default_make_name)
# User Management ---------------------------------------------------------------------------------
@user_bp.route('/user', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin', 'Partner Admin')
def user():
@@ -235,94 +344,6 @@ def edit_user(user_id):
return render_template('user/edit_user.html', form=form, user_id=user_id)
@user_bp.route('/select_tenant', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin') # Allow both roles
def select_tenant():
filter_form = TenantSelectionForm(request.form)
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
# Start with a base query
query = Tenant.query
current_app.logger.debug("We proberen het scherm op te bouwen")
current_app.logger.debug(f"Session: {session}")
# Apply different filters based on user role
if current_user.has_roles('Partner Admin') and 'partner' in session:
current_app.logger.debug("We zitten in partner mode")
# Get the partner's management service
management_service = next((service for service in session['partner']['services']
if service.get('type') == 'MANAGEMENT_SERVICE'), None)
if management_service:
# Get the partner's own tenant
partner_tenant_id = session['partner']['tenant_id']
# Get tenants managed by this partner through PartnerTenant relationships
managed_tenant_ids = db.session.query(PartnerTenant.tenant_id).filter_by(
partner_service_id=management_service['id']
).all()
# Convert list of tuples to flat list
managed_tenant_ids = [tenant_id for (tenant_id,) in managed_tenant_ids]
# Include partner's own tenant in the list
allowed_tenant_ids = [partner_tenant_id] + managed_tenant_ids
# Filter query to only show allowed tenants
query = query.filter(Tenant.id.in_(allowed_tenant_ids))
current_app.logger.debug("We zitten na partner service selectie")
# Apply form filters (for both Super User and Partner Admin)
if filter_form.validate_on_submit():
if filter_form.types.data:
query = query.filter(Tenant.type.in_(filter_form.types.data))
if filter_form.search.data:
search = f"%{filter_form.search.data}%"
query = query.filter(Tenant.name.ilike(search))
# Finalize query
query = query.order_by(Tenant.name)
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
tenants = pagination.items
rows = prepare_table_for_macro(tenants, [('id', ''), ('name', ''), ('website', ''), ('type', '')])
return render_template('user/select_tenant.html', rows=rows, pagination=pagination, filter_form=filter_form)
@user_bp.route('/handle_tenant_selection', methods=['POST'])
@roles_accepted('Super User', 'Partner Admin')
def handle_tenant_selection():
action = request.form['action']
if action == 'create_tenant':
return redirect(prefixed_url_for('user_bp.tenant'))
tenant_identification = request.form['selected_row']
tenant_id = ast.literal_eval(tenant_identification).get('value')
if not UserServices.can_user_edit_tenant(tenant_id):
current_app.logger.info(f"User not authenticated to edit tenant {tenant_id}.")
flash(f"You are not authenticated to manage tenant {tenant_id}", 'danger')
return redirect(prefixed_url_for('user_bp.select_tenant'))
the_tenant = Tenant.query.get(tenant_id)
# set tenant information in the session
session['tenant'] = the_tenant.to_dict()
# remove catalog-related items from the session
session.pop('catalog_id', None)
session.pop('catalog_name', None)
match action:
case 'edit_tenant':
return redirect(prefixed_url_for('user_bp.edit_tenant', tenant_id=tenant_id))
case 'select_tenant':
return redirect(prefixed_url_for('user_bp.tenant_overview'))
# Add more conditions for other actions
return redirect(prefixed_url_for('select_tenant'))
@user_bp.route('/view_users')
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def view_users():
@@ -369,6 +390,7 @@ def handle_user_action():
return redirect(prefixed_url_for('user_bp.view_users'))
# Tenant Domain Management ------------------------------------------------------------------------
@user_bp.route('/view_tenant_domains')
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def view_tenant_domains():
@@ -461,25 +483,7 @@ def edit_tenant_domain(tenant_domain_id):
return render_template('user/edit_tenant_domain.html', form=form, tenant_domain_id=tenant_domain_id)
@user_bp.route('/tenant_overview', methods=['GET'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def tenant_overview():
tenant_id = session['tenant']['id']
tenant = Tenant.query.get_or_404(tenant_id)
form = EditTenantForm(obj=tenant)
# Set the value of default_tenant_make_id
if tenant.default_tenant_make_id:
form.default_tenant_make_id.data = str(tenant.default_tenant_make_id)
# Get the name of the default make if it exists
default_make_name = None
if tenant.default_tenant_make:
default_make_name = tenant.default_tenant_make.name
return render_template('user/tenant_overview.html', form=form, default_make_name=default_make_name)
# Tenant Project Management -----------------------------------------------------------------------
@user_bp.route('/tenant_project', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def tenant_project():
@@ -639,6 +643,7 @@ def delete_tenant_project(tenant_project_id):
return redirect(prefixed_url_for('user_bp.tenant_projects'))
# Tenant Make Management --------------------------------------------------------------------------
@user_bp.route('/tenant_make', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Partner Admin', 'Tenant Admin')
def tenant_make():

View File

@@ -1,6 +0,0 @@
# Import all specialist implementations here to ensure registration
from . import standard_rag
from . import dossier_retriever
# List of all available specialist implementations
__all__ = ['standard_rag', 'dossier_retriever']

View File

@@ -1,3 +1,5 @@
import importlib
import json
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, Any, List, Optional, Tuple
@@ -5,7 +7,7 @@ from flask import current_app
from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from common.extensions import db
from common.extensions import db, cache_manager
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
@@ -19,9 +21,12 @@ class BaseRetriever(ABC):
self.tenant_id = tenant_id
self.retriever_id = retriever_id
self.retriever = Retriever.query.get_or_404(retriever_id)
self.tuning = False
self.catalog_id = self.retriever.catalog_id
self.tuning = self.retriever.tuning
self.tuning_logger = None
self._setup_tuning_logger()
self.embedding_model, self.embedding_model_class = (
get_embedding_model_and_class(tenant_id=tenant_id, catalog_id=self.catalog_id))
@property
@abstractmethod
@@ -29,6 +34,11 @@ class BaseRetriever(ABC):
"""The type of the retriever"""
raise NotImplementedError
@abstractmethod
def type_version(self) -> str:
"""The type version of the retriever"""
raise NotImplementedError
def _setup_tuning_logger(self):
try:
self.tuning_logger = TuningLogger(
@@ -43,6 +53,32 @@ class BaseRetriever(ABC):
current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
raise
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
def log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
if self.tuning and self.tuning_logger:
try:
@@ -50,31 +86,6 @@ class BaseRetriever(ABC):
except Exception as e:
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
"""
Set up common parameters needed for standard retrieval functionality
Returns:
Tuple containing:
- embedding_model: The model to use for embeddings
- embedding_model_class: The class for storing embeddings
- catalog_id: ID of the catalog
- similarity_threshold: Threshold for similarity matching
- k: Maximum number of results to return
"""
catalog_id = self.retriever.catalog_id
catalog = Catalog.query.get_or_404(catalog_id)
embedding_model = "mistral.mistral-embed"
embedding_model, embedding_model_class = get_embedding_model_and_class(
self.tenant_id, catalog_id, embedding_model
)
similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
k = self.retriever.configuration.get('es_k', 8)
return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k
@abstractmethod
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""
@@ -87,3 +98,16 @@ class BaseRetriever(ABC):
List[Dict[str, Any]]: List of retrieved documents/content
"""
raise NotImplementedError
def get_retriever_class(retriever_type: str, type_version: str):
major_minor = '_'.join(type_version.split('.')[:2])
retriever_config = cache_manager.retrievers_config_cache.get_config(retriever_type, type_version)
partner = retriever_config.get("partner", None)
if partner:
module_path = f"eveai_chat_workers.retrievers.{partner}.{retriever_type}.{major_minor}"
else:
module_path = f"eveai_chat_workers.retrievers.globals.{retriever_type}.{major_minor}"
current_app.logger.debug(f"Importing retriever class from {module_path}")
module = importlib.import_module(module_path)
return module.RetrieverExecutor

View File

@@ -1,374 +0,0 @@
"""
DossierRetriever implementation that adds metadata filtering to retrieval
"""
import json
from datetime import datetime as dt, date, timezone as tz
from typing import Dict, Any, List, Optional, Union, Tuple
from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
from sqlalchemy.sql import expression
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class DossierRetriever(BaseRetriever):
"""
Dossier Retriever implementation that adds metadata filtering
to standard retrieval functionality
"""
def __init__(self, tenant_id: int, retriever_id: int):
super().__init__(tenant_id, retriever_id)
# Set up standard retrieval parameters
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
# Dossier-specific configuration
self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {})
self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {})
self.log_tuning("Dossier retriever initialized", {
"tagging_fields_filter": self.tagging_fields_filter,
"dynamic_arguments": self.dynamic_arguments,
"similarity_threshold": self.similarity_threshold,
"k": self.k
})
@property
def type(self) -> str:
return "DOSSIER_RETRIEVER"
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
"""
Apply metadata filters to the query based on tagging_fields_filter configuration
Args:
query_obj: SQLAlchemy query object
arguments: Retriever arguments (for variable substitution)
Returns:
Modified SQLAlchemy query object
"""
if not self.tagging_fields_filter:
return query_obj
# Get dynamic argument values
dynamic_values = {}
for arg_name, arg_config in self.dynamic_arguments.items():
if hasattr(arguments, arg_name):
dynamic_values[arg_name] = getattr(arguments, arg_name)
# Build the filter
filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
if filter_condition is not None:
query_obj = query_obj.filter(filter_condition)
self.log_tuning("Applied metadata filter", {
"filter_sql": str(filter_condition),
"dynamic_values": dynamic_values
})
return query_obj
def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[
expression.BinaryExpression]:
"""
Recursively build SQLAlchemy filter condition from filter definition
Args:
filter_def: Filter definition (logical group or field condition)
dynamic_values: Values for dynamic variable substitution
Returns:
SQLAlchemy expression or None if invalid
"""
# Handle logical groups (AND, OR, NOT)
if 'logical' in filter_def:
logical_op = filter_def['logical'].lower()
subconditions = [
self._build_filter_condition(cond, dynamic_values)
for cond in filter_def.get('conditions', [])
]
# Remove None values
subconditions = [c for c in subconditions if c is not None]
if not subconditions:
return None
if logical_op == 'and':
return and_(*subconditions)
elif logical_op == 'or':
return or_(*subconditions)
elif logical_op == 'not':
if len(subconditions) == 1:
return ~subconditions[0]
else:
# NOT should have exactly one condition
current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
return None
else:
current_app.logger.warning(f"Unknown logical operator: {logical_op}")
return None
# Handle field conditions
elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
field_name = filter_def['field']
operator = filter_def['operator'].lower()
value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))
# Skip if we couldn't resolve the value
if value is None and operator not in ['is_null', 'is_not_null']:
return None
# Create the field expression to match JSON data
field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)
# Apply the appropriate operator
return self._apply_operator(field_expr, operator, value, filter_def)
else:
current_app.logger.warning(f"Invalid filter definition: {filter_def}")
return None
def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
"""
Resolve a value definition, handling variables and defaults
Args:
value_def: Value definition (could be literal, variable reference, or list)
dynamic_values: Values for dynamic variable substitution
default: Default value if variable not found
Returns:
Resolved value
"""
# Handle lists (recursively resolve each item)
if isinstance(value_def, list):
return [self._resolve_value(item, dynamic_values) for item in value_def]
# Handle variable references (strings starting with $)
if isinstance(value_def, str) and value_def.startswith('$'):
var_name = value_def[1:] # Remove $ prefix
if var_name in dynamic_values:
return dynamic_values[var_name]
else:
# Use default if provided
return default
# Return literal values as-is
return value_def
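# Hedged usage sketch for _resolve_value (values are illustrative only):
#   dynamic_values = {"department": "finance"}
#   _resolve_value("$department", dynamic_values, default="legal")   -> "finance"
#   _resolve_value("$region", dynamic_values, default="emea")        -> "emea"  (variable missing, default used)
#   _resolve_value(["draft", "$department"], dynamic_values)         -> ["draft", "finance"]
#   _resolve_value(42, dynamic_values)                               -> 42      (literal passthrough)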
def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[
expression.BinaryExpression]:
"""
Apply the specified operator to create a filter condition
Args:
field_expr: SQLAlchemy field expression
operator: Operator to apply
value: Value to compare against
filter_def: Original filter definition (for additional options)
Returns:
SQLAlchemy expression
"""
try:
# String operators
if operator == 'eq':
return field_expr == str(value)
elif operator == 'neq':
return field_expr != str(value)
elif operator == 'contains':
return field_expr.contains(str(value))
elif operator == 'not_contains':
return ~field_expr.contains(str(value))
elif operator == 'starts_with':
return field_expr.startswith(str(value))
elif operator == 'ends_with':
return field_expr.endswith(str(value))
elif operator == 'in':
return field_expr.in_([str(v) for v in value])
elif operator == 'not_in':
return ~field_expr.in_([str(v) for v in value])
elif operator == 'regex' or operator == 'not_regex':
# PostgreSQL regex using ~ or !~ operator
case_insensitive = filter_def.get('case_insensitive', False)
regex_op = '~*' if case_insensitive else '~'
if operator == 'not_regex':
regex_op = '!~*' if case_insensitive else '!~'
return text(
f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams(
regex_value=str(value))
# Numeric/Date operators
elif operator == 'gt':
return cast(field_expr, Float) > float(value)
elif operator == 'gte':
return cast(field_expr, Float) >= float(value)
elif operator == 'lt':
return cast(field_expr, Float) < float(value)
elif operator == 'lte':
return cast(field_expr, Float) <= float(value)
elif operator == 'between':
if len(value) == 2:
return cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
return None
elif operator == 'not_between':
if len(value) == 2:
return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
return None
# Null checking
elif operator == 'is_null':
return field_expr.is_(None)
elif operator == 'is_not_null':
return field_expr.isnot(None)
else:
current_app.logger.warning(f"Unknown operator: {operator}")
return None
except (ValueError, TypeError) as e:
current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
return None
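# Example field conditions handled above (sketch; field names and values are hypothetical):
#   {"field": "doc_code", "operator": "regex", "value": "^INV-[0-9]+$", "case_insensitive": True}
#   {"field": "amount", "operator": "between", "value": [100, 500]}
#   {"field": "archived_at", "operator": "is_null", "value": None}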
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""
Retrieve documents based on query with added metadata filtering
Args:
arguments: Validated RetrieverArguments containing at minimum:
- query: str - The search query
Returns:
List[RetrieverResult]: List of retrieved documents with similarity scores
"""
try:
query = arguments.query
# Get query embedding
query_embedding = self.embedding_model.embed_query(query)
# Get the appropriate embedding database model
db_class = self.embedding_model_class
# Get current date for validity checks
current_date = dt.now(tz=tz.utc).date()
# Create subquery for latest versions
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query
query_obj = (
db.session.query(
db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
Document.catalog_id == self.catalog_id
)
)
# Apply metadata filtering
query_obj = self._apply_metadata_filter(query_obj, arguments)
# Apply ordering and limit
query_obj = query_obj.order_by(desc('similarity')).limit(self.k)
# Execute query
results = query_obj.all()
# Transform results into standard format
processed_results = []
for doc, similarity in results:
# Parse user_metadata to ensure it's a dictionary
user_metadata = self._parse_metadata(doc.document_version.user_metadata)
processed_results.append(
RetrieverResult(
id=doc.id,
chunk=doc.chunk,
similarity=float(similarity),
metadata=RetrieverMetadata(
document_id=doc.document_version.doc_id,
version_id=doc.document_version.id,
document_name=doc.document_version.document.name,
user_metadata=user_metadata,
)
)
)
# Log the retrieval
if self.tuning:
compiled_query = str(query_obj.statement.compile(
compile_kwargs={"literal_binds": True} # This will include the actual values in the SQL
))
self.log_tuning('retrieve', {
"arguments": arguments.model_dump(),
"similarity_threshold": self.similarity_threshold,
"k": self.k,
"query": compiled_query,
"results_count": len(results),
"processed_results_count": len(processed_results),
})
return processed_results
except SQLAlchemyError as e:
current_app.logger.error(f'Error in Dossier retrieval: {e}')
db.session.rollback()
raise
except Exception as e:
current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
raise
# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)

View File

@@ -11,54 +11,25 @@ from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.models.user import Tenant
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from eveai_chat_workers.retrievers.base_retriever import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class StandardRAGRetriever(BaseRetriever):
class RetrieverExecutor(BaseRetriever):
"""Standard RAG retriever implementation"""
def __init__(self, tenant_id: int, retriever_id: int):
super().__init__(tenant_id, retriever_id)
# Set up standard retrieval parameters
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
self.log_tuning("Standard RAG retriever initialized", {
"similarity_threshold": self.similarity_threshold,
"k": self.k
})
@property
def type(self) -> str:
return "STANDARD_RAG"
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
@property
def type_version(self) -> str:
return "1.0"
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""
@@ -72,10 +43,10 @@ class StandardRAGRetriever(BaseRetriever):
List[RetrieverResult]: List of retrieved documents with similarity scores
"""
try:
query = arguments.query
question = arguments.question
# Get query embedding
query_embedding = self.embedding_model.embed_query(query)
query_embedding = self.embedding_model.embed_query(question)
# Get the appropriate embedding database model
db_class = self.embedding_model_class
@@ -93,6 +64,9 @@ class StandardRAGRetriever(BaseRetriever):
.subquery()
)
similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
k = self.retriever.configuration.get('es_k', 8)
# Main query
query_obj = (
db.session.query(
@@ -106,11 +80,11 @@ class StandardRAGRetriever(BaseRetriever):
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self.catalog_id
)
.order_by(desc('similarity'))
.limit(self.k)
.limit(k)
)
results = query_obj.all()
@@ -160,5 +134,3 @@ class StandardRAGRetriever(BaseRetriever):
raise
# Register the retriever type
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)

View File

@@ -1,20 +0,0 @@
from typing import Dict, Type
from .base import BaseRetriever
class RetrieverRegistry:
"""Registry for retriever types"""
_registry: Dict[str, Type[BaseRetriever]] = {}
@classmethod
def register(cls, retriever_type: str, retriever_class: Type[BaseRetriever]):
"""Register a new retriever type"""
cls._registry[retriever_type] = retriever_class
@classmethod
def get_retriever_class(cls, retriever_type: str) -> Type[BaseRetriever]:
"""Get the retriever class for a given type"""
if retriever_type not in cls._registry:
raise ValueError(f"Unknown retriever type: {retriever_type}")
return cls._registry[retriever_type]

View File

@@ -34,6 +34,8 @@ class RetrieverArguments(BaseModel):
type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")
type_version: str = Field(..., description="Version of retriever type (e.g. 1.0)")
question: str = Field(..., description="Question to retrieve answers for")
# Allow any additional fields
model_config = {
"extra": "allow"

View File

@@ -8,8 +8,7 @@ from common.models.interaction import SpecialistRetriever, Specialist
from common.models.user import Tenant
from common.utils.execution_progress import ExecutionProgressTracker
from config.logging_config import TuningLogger
from eveai_chat_workers.retrievers.base import BaseRetriever
from eveai_chat_workers.retrievers.registry import RetrieverRegistry
from eveai_chat_workers.retrievers.base_retriever import BaseRetriever, get_retriever_class
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
@@ -56,20 +55,16 @@ class BaseSpecialistExecutor(ABC):
for spec_retriever in specialist_retrievers:
# Get retriever configuration from database
retriever = spec_retriever.retriever
retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
retriever_executor_class = get_retriever_class(retriever.type, retriever.type_version)
self.log_tuning("_initialize_retrievers", {
"Retriever id": spec_retriever.retriever_id,
"Retriever Type": retriever.type,
"Retriever Class": str(retriever_class),
"Retriever Version": retriever.type_version,
})
retriever_executor = retriever_executor_class(self.tenant_id, spec_retriever.retriever_id)
# Initialize retriever with its configuration
retrievers.append(
retriever_class(
tenant_id=self.tenant_id,
retriever_id=retriever.id,
)
)
retrievers.append(retriever_executor)
return retrievers
@@ -144,7 +139,7 @@ def get_specialist_class(specialist_type: str, type_version: str):
if partner:
module_path = f"eveai_chat_workers.specialists.{partner}.{specialist_type}.{major_minor}"
else:
module_path = f"eveai_chat_workers.specialists.{specialist_type}.{major_minor}"
module_path = f"eveai_chat_workers.specialists.globals.{specialist_type}.{major_minor}"
current_app.logger.debug(f"Importing specialist class from {module_path}")
module = importlib.import_module(module_path)
return module.SpecialistExecutor

View File

@@ -308,12 +308,12 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
current_retriever_args = retriever_arguments[retriever_id]
if isinstance(retriever_arguments[retriever_id], RetrieverArguments):
updated_args = current_retriever_args.model_dump()
updated_args['query'] = arguments.query
updated_args['question'] = arguments.question
updated_args['language'] = arguments.language
retriever_args = RetrieverArguments(**updated_args)
else:
# Create a new RetrieverArguments instance from the dictionary
current_retriever_args['query'] = arguments.query
current_retriever_args['question'] = arguments.question
retriever_args = RetrieverArguments(**current_retriever_args)
# Each retriever gets its own specific arguments

View File

@@ -6,6 +6,7 @@ from crewai.flow.flow import start, listen, and_
from flask import current_app
from pydantic import BaseModel, Field
from common.services.utils.translation_services import TranslationServices
from common.utils.business_event_context import current_event
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
from eveai_chat_workers.specialists.crewai_base_specialist import CrewAIBaseSpecialistExecutor
@@ -13,6 +14,10 @@ from eveai_chat_workers.specialists.specialist_typing import SpecialistResult, S
from eveai_chat_workers.outputs.globals.rag.rag_v1_0 import RAGOutput
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAICrew, EveAICrewAIFlow, EveAIFlowState
INSUFFICIENT_INFORMATION_MESSAGE = (
"We do not have the necessary information to provide you with the requested answers. "
"Please accept our apologies. Don't hesitate to ask other questions, and I'll do my best to answer them.")
class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
@@ -40,6 +45,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def _config_pydantic_outputs(self):
self._add_pydantic_output("rag_task", RAGOutput, "rag_output")
def _config_state_result_relations(self):
self._add_state_result_relation("rag_output")
def _instantiate_specialist(self):
verbose = self.tuning
@@ -61,40 +69,84 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def execute(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist execution started", {})
flow_inputs = {
"language": arguments.language,
"query": arguments.query,
"context": formatted_context,
"history": self.formatted_history,
"name": self.specialist.configuration.get('name', ''),
"company": self.specialist.configuration.get('company', ''),
}
# crew_results = self.rag_crew.kickoff(inputs=flow_inputs)
# current_app.logger.debug(f"Test Crew Output received: {crew_results}")
flow_results = self.flow.kickoff(inputs=flow_inputs)
current_app.logger.debug(f"Arguments: {arguments.model_dump()}")
current_app.logger.debug(f"Formatted Context: {formatted_context}")
current_app.logger.debug(f"Formatted History: {self._formatted_history}")
current_app.logger.debug(f"Cached Chat Session: {self._cached_session}")
flow_state = self.flow.state
if not self._cached_session.interactions:
specialist_phase = "initial"
else:
specialist_phase = self._cached_session.interactions[-1].specialist_results.get('phase', 'initial')
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
update_data = {}
if flow_state.rag_output: # Fallback
update_data["rag_output"] = flow_state.rag_output
results = None
current_app.logger.debug(f"Specialist Phase: {specialist_phase}")
results = results.model_copy(update=update_data)
match specialist_phase:
case "initial":
results = self.execute_initial_state(arguments, formatted_context, citations)
case "rag":
results = self.execute_rag_state(arguments, formatted_context, citations)
self.log_tuning(f"RAG Specialist execution ended", {"Results": results.model_dump()})
return results
def execute_initial_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist initial_state execution started", {})
welcome_message = self.specialist.configuration.get('welcome_message', 'Welcome! You can start asking questions')
welcome_message = TranslationServices.translate(self.tenant_id, welcome_message, arguments.language)
self.flow.state.answer = welcome_message
self.flow.state.phase = "rag"
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
return results
def execute_rag_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist rag_state execution started", {})
insufficient_info_message = TranslationServices.translate(self.tenant_id,
INSUFFICIENT_INFORMATION_MESSAGE,
arguments.language)
if formatted_context:
flow_inputs = {
"language": arguments.language,
"question": arguments.question,
"context": formatted_context,
"history": self.formatted_history,
"name": self.specialist.configuration.get('name', ''),
"welcome_message": self.specialist.configuration.get('welcome_message', '')
}
flow_results = self.flow.kickoff(inputs=flow_inputs)
if flow_results.rag_output.insufficient_info:
flow_results.rag_output.answer = insufficient_info_message
rag_output = flow_results.rag_output
else:
rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)
self.flow.state.rag_output = rag_output
self.flow.state.answer = rag_output.answer
self.flow.state.phase = "rag"
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
class RAGSpecialistInput(BaseModel):
language: Optional[str] = Field(None, alias="language")
query: Optional[str] = Field(None, alias="query")
question: Optional[str] = Field(None, alias="question")
context: Optional[str] = Field(None, alias="context")
citations: Optional[List[int]] = Field(None, alias="citations")
history: Optional[str] = Field(None, alias="history")
name: Optional[str] = Field(None, alias="name")
company: Optional[str] = Field(None, alias="company")
welcome_message: Optional[str] = Field(None, alias="welcome_message")
class RAGSpecialistResult(SpecialistResult):

View File

@@ -0,0 +1,197 @@
import json
from typing import Optional, List
from crewai.flow.flow import start, listen, and_
from flask import current_app
from pydantic import BaseModel, Field
from common.services.utils.translation_services import TranslationServices
from common.utils.business_event_context import current_event
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
from eveai_chat_workers.specialists.crewai_base_specialist import CrewAIBaseSpecialistExecutor
from eveai_chat_workers.specialists.specialist_typing import SpecialistResult, SpecialistArguments
from eveai_chat_workers.outputs.globals.rag.rag_v1_0 import RAGOutput
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAICrew, EveAICrewAIFlow, EveAIFlowState
INSUFFICIENT_INFORMATION_MESSAGE = (
"We do not have the necessary information to provide you with the requested answers. "
"Please accept our apologies. Don't hesitate to ask other questions, and I'll do my best to answer them.")
class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
type: RAG_SPECIALIST
type_version: 1.1
RAG Specialist Executor class
"""
def __init__(self, tenant_id, specialist_id, session_id, task_id, **kwargs):
self.rag_crew = None
super().__init__(tenant_id, specialist_id, session_id, task_id)
@property
def type(self) -> str:
return "RAG_SPECIALIST"
@property
def type_version(self) -> str:
return "1.1"
def _config_task_agents(self):
self._add_task_agent("rag_task", "rag_agent")
def _config_pydantic_outputs(self):
self._add_pydantic_output("rag_task", RAGOutput, "rag_output")
def _config_state_result_relations(self):
self._add_state_result_relation("rag_output")
def _instantiate_specialist(self):
verbose = self.tuning
rag_agents = [self.rag_agent]
rag_tasks = [self.rag_task]
self.rag_crew = EveAICrewAICrew(
self,
"Rag Crew",
agents=rag_agents,
tasks=rag_tasks,
verbose=verbose,
)
self.flow = RAGFlow(
self,
self.rag_crew,
)
def execute(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist execution started", {})
current_app.logger.debug(f"Arguments: {arguments.model_dump()}")
current_app.logger.debug(f"Formatted Context: {formatted_context}")
current_app.logger.debug(f"Formatted History: {self._formatted_history}")
current_app.logger.debug(f"Cached Chat Session: {self._cached_session}")
if not self._cached_session.interactions:
specialist_phase = "initial"
else:
specialist_phase = self._cached_session.interactions[-1].specialist_results.get('phase', 'initial')
results = None
current_app.logger.debug(f"Specialist Phase: {specialist_phase}")
match specialist_phase:
case "initial":
results = self.execute_initial_state(arguments, formatted_context, citations)
case "rag":
results = self.execute_rag_state(arguments, formatted_context, citations)
self.log_tuning(f"RAG Specialist execution ended", {"Results": results.model_dump()})
return results
def execute_initial_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist initial_state execution started", {})
welcome_message = self.specialist.configuration.get('welcome_message', 'Welcome! You can start asking questions')
welcome_message = TranslationServices.translate(self.tenant_id, welcome_message, arguments.language)
self.flow.state.answer = welcome_message
self.flow.state.phase = "rag"
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
return results
def execute_rag_state(self, arguments: SpecialistArguments, formatted_context, citations) -> SpecialistResult:
self.log_tuning("RAG Specialist rag_state execution started", {})
insufficient_info_message = TranslationServices.translate(self.tenant_id,
INSUFFICIENT_INFORMATION_MESSAGE,
arguments.language)
if formatted_context:
flow_inputs = {
"language": arguments.language,
"question": arguments.question,
"context": formatted_context,
"history": self.formatted_history,
"name": self.specialist.configuration.get('name', ''),
"welcome_message": self.specialist.configuration.get('welcome_message', '')
}
flow_results = self.flow.kickoff(inputs=flow_inputs)
if flow_results.rag_output.insufficient_info:
flow_results.rag_output.answer = insufficient_info_message
rag_output = flow_results.rag_output
else:
rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)
self.flow.state.rag_output = rag_output
self.flow.state.answer = rag_output.answer
self.flow.state.phase = "rag"
results = RAGSpecialistResult.create_for_type(self.type, self.type_version)
class RAGSpecialistInput(BaseModel):
language: Optional[str] = Field(None, alias="language")
question: Optional[str] = Field(None, alias="question")
context: Optional[str] = Field(None, alias="context")
history: Optional[str] = Field(None, alias="history")
name: Optional[str] = Field(None, alias="name")
welcome_message: Optional[str] = Field(None, alias="welcome_message")
class RAGSpecialistResult(SpecialistResult):
rag_output: Optional[RAGOutput] = Field(None, alias="Rag Output")
class RAGFlowState(EveAIFlowState):
"""Flow state for RAG specialist that automatically updates from task outputs"""
input: Optional[RAGSpecialistInput] = None
rag_output: Optional[RAGOutput] = None
class RAGFlow(EveAICrewAIFlow[RAGFlowState]):
def __init__(self,
specialist_executor: CrewAIBaseSpecialistExecutor,
rag_crew: EveAICrewAICrew,
**kwargs):
super().__init__(specialist_executor, "RAG Specialist Flow", **kwargs)
self.specialist_executor = specialist_executor
self.rag_crew = rag_crew
self.exception_raised = False
@start()
def process_inputs(self):
return ""
@listen(process_inputs)
async def execute_rag(self):
inputs = self.state.input.model_dump()
try:
crew_output = await self.rag_crew.kickoff_async(inputs=inputs)
self.specialist_executor.log_tuning("RAG Crew Output", crew_output.model_dump())
output_pydantic = crew_output.pydantic
if not output_pydantic:
raw_json = json.loads(crew_output.raw)
output_pydantic = RAGOutput.model_validate(raw_json)
self.state.rag_output = output_pydantic
return crew_output
except Exception as e:
current_app.logger.error(f"CREW rag_crew Kickoff Error: {str(e)}")
self.exception_raised = True
raise e
async def kickoff_async(self, inputs=None):
current_app.logger.debug(f"Async kickoff {self.name}")
current_app.logger.debug(f"Inputs: {inputs}")
self.state.input = RAGSpecialistInput.model_validate(inputs)
result = await super().kickoff_async(inputs)
return self.state
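# Hedged sketch of the state this flow hands back after a successful run; every value is
# illustrative only:
#
# RAGFlowState(
#     input=RAGSpecialistInput(language="en", question="...", context="...", history="",
#                              name="Eve", welcome_message="Welcome!"),
#     rag_output=RAGOutput(answer="...", insufficient_info=False),
# )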

View File

@@ -51,6 +51,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
def _config_pydantic_outputs(self):
self._add_pydantic_output("traicie_get_competencies_task", Competencies, "competencies")
def _config_state_result_relations(self):
self._add_state_result_relation("competencies")
def _instantiate_specialist(self):
verbose = self.tuning
@@ -83,13 +86,9 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
flow_results = self.flow.kickoff(inputs=flow_inputs)
flow_state = self.flow.state
results = RoleDefinitionSpecialistResult.create_for_type(self.type, self.type_version)
if flow_state.competencies:
results.competencies = flow_state.competencies
self.create_selection_specialist(arguments, flow_state.competencies)
self.create_selection_specialist(arguments, self.flow.state.competencies)
self.log_tuning(f"Traicie Role Definition Specialist execution ended", {"Results": results.model_dump()})

View File

@@ -55,7 +55,7 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"""
def __init__(self, tenant_id, specialist_id, session_id, task_id, **kwargs):
self.role_definition_crew = None
self.rag_crew = None
super().__init__(tenant_id, specialist_id, session_id, task_id)
@@ -407,8 +407,7 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
if rag_output.rag_output.insufficient_info:
rag_output.rag_output.answer = insufficient_info_message
else:
rag_output = RAGOutput(answer=insufficient_info_message,
insufficient_info=True)
rag_output = RAGOutput(answer=insufficient_info_message, insufficient_info=True)
self.log_tuning(f"RAG Specialist execution ended", {"Results": rag_output.model_dump()})

View File

@@ -0,0 +1,29 @@
"""Add type_version to Catalog Model
Revision ID: 26e8f0d8c143
Revises: af3d56001771
Create Date: 2025-07-07 12:15:40.046144
"""
from alembic import op
import sqlalchemy as sa
import pgvector
# revision identifiers, used by Alembic.
revision = '26e8f0d8c143'
down_revision = 'af3d56001771'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('catalog', sa.Column('type_version', sa.String(length=20), nullable=True))
# ### end Alembic commands ###
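# Optional follow-up, sketched here only and not part of this revision: backfill existing rows
# so the new column matches the model default (assumed to be '1.0.0') before any NOT NULL
# constraint is introduced.
# op.execute("UPDATE catalog SET type_version = '1.0.0' WHERE type_version IS NULL")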
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('catalog', 'type_version')
# ### end Alembic commands ###