- Started addition of Assets (to e.g. handle document templates).

- To be continued (Models, first views are ready)
This commit is contained in:
Josako
2025-03-17 17:40:42 +01:00
parent a6402524ce
commit cf2201a1f7
13 changed files with 778 additions and 39 deletions

View File

@@ -0,0 +1,62 @@
from datetime import datetime as dt, timezone as tz
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from common.extensions import cache_manager, minio_client, db
from common.models.interaction import EveAIAsset, EveAIAssetVersion
from common.utils.document_utils import mark_tenant_storage_dirty
from common.utils.model_logging_utils import set_logging_information
def create_asset_stack(api_input, tenant_id):
    """Create and persist an asset together with its first version.

    The latest version of the requested asset type is looked up in the
    version-tree cache and written into ``api_input['type_version']`` before
    the asset is built, so the caller's mapping is updated in place.

    Args:
        api_input: Mapping describing the asset; must contain ``'type'`` plus
            the keys read by ``create_asset``.
        tenant_id: Tenant the asset belongs to.

    Returns:
        Tuple ``(asset, asset_version)`` of the committed objects.

    Raises:
        SQLAlchemyError: Re-raised after logging and rollback when the commit fails.
    """
    api_input['type_version'] = cache_manager.assets_version_tree_cache.get_latest_version(api_input['type'])

    asset = create_asset(api_input, tenant_id)
    asset_version = create_version_for_asset(asset, tenant_id)

    for record in (asset, asset_version):
        db.session.add(record)

    try:
        db.session.commit()
    except SQLAlchemyError as exc:
        current_app.logger.error(f"Could not add asset for tenant {tenant_id}: {str(exc)}")
        db.session.rollback()
        raise exc
    else:
        return asset, asset_version
def create_asset(api_input, tenant_id):
    """Build (without persisting) an ``EveAIAsset`` from an input mapping.

    Args:
        api_input: Mapping with the required keys ``'name'``, ``'type'`` and
            ``'type_version'``. The keys ``'description'``, ``'valid_from'``
            and ``'valid_to'`` are optional; a missing or empty
            ``'valid_from'`` defaults to the current UTC time and a missing or
            empty ``'valid_to'`` is stored as ``None``.
        tenant_id: Not used by this function; kept for signature parity with
            the other asset helpers.

    Returns:
        The new, not yet added or committed, ``EveAIAsset``.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    now = dt.now(tz.utc)

    new_asset = EveAIAsset()
    new_asset.name = api_input['name']
    new_asset.description = api_input.get('description')
    new_asset.type = api_input['type']
    new_asset.type_version = api_input['type_version']
    # Falsy covers None, '' and an absent key alike.
    new_asset.valid_from = api_input.get('valid_from') or now
    # Normalise '' to None so an empty form/API value never reaches the DateTime column.
    new_asset.valid_to = api_input.get('valid_to') or None
    set_logging_information(new_asset, now)
    return new_asset
def create_version_for_asset(asset, tenant_id):
    """Build (without persisting) a new ``EveAIAssetVersion`` linked to ``asset``.

    The version's bucket name is whatever ``minio_client.create_tenant_bucket``
    returns for the tenant.

    Args:
        asset: The asset the version belongs to.
        tenant_id: Tenant whose bucket will hold the version's files.

    Returns:
        The new, not yet added or committed, asset version.
    """
    version = EveAIAssetVersion()
    version.asset = asset

    tenant_bucket = minio_client.create_tenant_bucket(tenant_id)
    version.bucket_name = tenant_bucket

    created_at = dt.now(tz.utc)
    set_logging_information(version, created_at)
    return version
def add_asset_version_file(asset_version, field_name, file, tenant_id):
    """Upload one file belonging to an asset version and return its object name.

    Args:
        asset_version: Asset version the file belongs to; its ``bucket_name``
            and ``id`` determine where the object is stored.
        field_name: Name of the configuration field the file was uploaded for;
            used as the base file name of the stored object.
        file: Uploaded file object exposing ``content_type``; it is handed to
            the storage client as the file data.
        tenant_id: Tenant whose storage usage is marked dirty after the upload.

    Returns:
        The object name under which the file was stored.
    """
    # The asset upload entry point is ``upload_asset_file`` and it needs the file
    # data itself as its last argument; the reported size is not needed here.
    object_name, _file_size = minio_client.upload_asset_file(
        asset_version.bucket_name,
        asset_version.id,
        field_name,
        file.content_type,
        file,
    )
    mark_tenant_storage_dirty(tenant_id)
    return object_name

View File

@@ -1,4 +1,5 @@
import os
import time
import uuid
from contextlib import contextmanager
from datetime import datetime
@@ -82,15 +83,21 @@ class BusinessEvent:
self.span_name = span_name
self.parent_span_id = parent_span_id
# Track start time for the span
span_start_time = time.time()
self.log(f"Start")
try:
yield
finally:
# Calculate total time for this span
span_total_time = time.time() - span_start_time
if self.llm_metrics['call_count'] > 0:
self.log_final_metrics()
self.reset_llm_metrics()
self.log(f"End")
self.log(f"End", extra_fields={'span_duration': span_total_time})
# Restore the previous span info
if self.spans:
self.span_id, self.span_name, self.parent_span_id = self.spans.pop()
@@ -99,7 +106,7 @@ class BusinessEvent:
self.span_name = None
self.parent_span_id = None
def log(self, message: str, level: str = 'info'):
def log(self, message: str, level: str = 'info', extra_fields: Dict[str, Any] = None):
log_data = {
'timestamp': dt.now(tz=tz.utc),
'event_type': self.event_type,
@@ -115,6 +122,15 @@ class BusinessEvent:
'environment': self.environment,
'message': message,
}
# Add any extra fields
if extra_fields:
for key, value in extra_fields.items():
# For span/trace duration, use the llm_metrics_total_time field
if key == 'span_duration' or key == 'trace_duration':
log_data['llm_metrics_total_time'] = value
else:
log_data[key] = value
self._log_buffer.append(log_data)
def log_llm_metrics(self, metrics: dict, level: str = 'info'):
@@ -226,13 +242,17 @@ class BusinessEvent:
self._log_buffer = []
def __enter__(self):
    """Start the trace: record the start time, log the start message and enter the event context."""
    # Wall-clock start; __exit__ subtracts it from time.time() to get the trace duration.
    self.trace_start_time = time.time()
    self.log(f'Starting Trace for {self.event_type}')
    return BusinessEventContext(self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
trace_total_time = time.time() - self.trace_start_time
if self.llm_metrics['call_count'] > 0:
self.log_final_metrics()
self.reset_llm_metrics()
self.log(f'Ending Trace for {self.event_type}')
self.log(f'Ending Trace for {self.event_type}', extra_fields={'trace_duration': trace_total_time})
self._flush_log_buffer()
return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb)

View File

@@ -33,6 +33,9 @@ class MinioClient:
def generate_object_name(self, document_id, language, version_id, filename):
    """Return the object key ``<document_id>/<language>/<version_id>/<filename>``."""
    object_name = f"{document_id}/{language}/{version_id}/{filename}"
    return object_name
def generate_asset_name(self, asset_version_id, file_name, content_type):
    """Return the object key for an asset file: ``assets/<asset_version_id>/<file_name>[.<ext>]``.

    Args:
        asset_version_id: Id of the asset version owning the file.
        file_name: Base name of the object (without extension).
        content_type: Either a plain extension (e.g. ``"md"``) or a MIME type
            (e.g. ``"application/pdf"``, optionally with parameters). MIME
            types are translated to a file extension so that the ``/`` they
            contain does not introduce an extra path level in the object key.
            When empty or ``None`` no extension is appended.

    Returns:
        The object key as a string.
    """
    import mimetypes  # local import: only this helper needs it

    # Drop MIME parameters such as "; charset=utf-8".
    extension = (content_type or '').split(';', 1)[0].strip()
    if '/' in extension:
        guessed = mimetypes.guess_extension(extension)
        if guessed:
            extension = guessed.lstrip('.')
        else:
            # Unknown MIME type: fall back to its subtype, which contains no slash.
            extension = extension.rsplit('/', 1)[-1]

    if extension:
        return f"assets/{asset_version_id}/{file_name}.{extension}"
    return f"assets/{asset_version_id}/{file_name}"
def upload_document_file(self, tenant_id, document_id, language, version_id, filename, file_data):
bucket_name = self.generate_bucket_name(tenant_id)
object_name = self.generate_object_name(document_id, language, version_id, filename)
@@ -54,6 +57,26 @@ class MinioClient:
except S3Error as err:
raise Exception(f"Error occurred while uploading file: {err}")
def upload_asset_file(self, bucket_name, asset_version_id, file_name, file_type, file_data):
    """Store an asset file in ``bucket_name`` and return its object name and size.

    Args:
        bucket_name: Bucket to write to.
        asset_version_id: Id of the asset version owning the file.
        file_name: Base name used to build the object name.
        file_type: Type suffix used to build the object name.
        file_data: Content as ``FileStorage``, ``io.BytesIO``, ``str``
            (encoded as UTF-8) or ``bytes``.

    Returns:
        Tuple ``(object_name, size_in_bytes)``.

    Raises:
        TypeError: If ``file_data`` is of an unsupported type.
        Exception: If the storage client reports an ``S3Error``.
    """
    object_name = self.generate_asset_name(asset_version_id, file_name, file_type)
    try:
        # Normalise every supported input type to raw bytes.
        if isinstance(file_data, FileStorage):
            payload = file_data.read()
        elif isinstance(file_data, io.BytesIO):
            payload = file_data.getvalue()
        elif isinstance(file_data, str):
            payload = file_data.encode('utf-8')
        elif isinstance(file_data, bytes):
            payload = file_data
        else:
            raise TypeError('Unsupported file type. Expected FileStorage, BytesIO, str, or bytes.')

        size = len(payload)
        self.client.put_object(bucket_name, object_name, io.BytesIO(payload), size)
        return object_name, size
    except S3Error as err:
        raise Exception(f"Error occurred while uploading asset: {err}")
def download_document_file(self, tenant_id, bucket_name, object_name):
try:
response = self.client.get_object(bucket_name, object_name)

View File

@@ -0,0 +1,18 @@
version: "1.0.0"
name: "Document Template"
configuration:
argument_definition:
name: "variable_definition"
type: "file"
description: "Yaml file defining the arguments in the Document Template."
required: True
content_markdown:
name: "content_markdown"
type: "str"
description: "Actual template file in markdown format."
required: True
metadata:
author: "Josako"
date_added: "2025-03-12"
description: "Asset that defines a template in markdown a specialist can process"
changes: "Initial version"

View File

@@ -0,0 +1,7 @@
# Asset Types
# NOTE(review): the constant is named AGENT_TYPES although its only entry describes an
# asset type (a document template) — looks copied from the agent type definitions;
# confirm which name the asset types cache imports before renaming.
AGENT_TYPES = {
    # Key is the type identifier; "name" and "description" are the display texts.
    "DOCUMENT_TEMPLATE": {
        "name": "Document Template",
        "description": "Asset that defines a template in markdown a specialist can process",
    },
}

View File

@@ -1,5 +1,6 @@
from flask_wtf import FlaskForm
from wtforms import IntegerField, FloatField, BooleanField, StringField, TextAreaField, validators, ValidationError
from wtforms import (IntegerField, FloatField, BooleanField, StringField, TextAreaField, FileField,
validators, ValidationError)
from flask import current_app
import json
@@ -264,6 +265,7 @@ class DynamicFormBase(FlaskForm):
'string': StringField,
'text': TextAreaField,
'date': DateField,
'file': FileField,
}.get(field_type, StringField)
field_kwargs = {}

View File

@@ -1,5 +1,6 @@
from flask_wtf import FlaskForm
from wtforms import (StringField, BooleanField, SelectField, TextAreaField)
from wtforms.fields.datetime import DateField
from wtforms.validators import DataRequired, Length, Optional
from wtforms_sqlalchemy.fields import QuerySelectMultipleField
@@ -94,3 +95,32 @@ class EditEveAITaskForm(BaseEditComponentForm):
class EditEveAIToolForm(BaseEditComponentForm):
    """Edit form for an EveAI tool; adds no fields beyond ``BaseEditComponentForm``."""
    pass
class AddEveAIAssetForm(FlaskForm):
    """Form for creating a new asset: name, description, type and validity dates."""
    name = StringField('Name', validators=[DataRequired(), Length(max=50)])
    description = TextAreaField('Description', validators=[Optional()])
    # Choices are filled in __init__ from the asset types cache.
    type = SelectField('Type', validators=[DataRequired()])
    # NOTE(review): both date fields receive the same ``id`` value, and that value looks
    # like a CSS class list ("form-control datepicker") — confirm whether
    # ``render_kw={'class': ...}`` was intended.
    valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()])
    valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()])

    def __init__(self, *args, **kwargs):
        """Build the form and populate the ``type`` choices as ``(key, name)`` pairs."""
        super().__init__(*args, **kwargs)
        types_dict = cache_manager.assets_types_cache.get_types()
        self.type.choices = [(key, value['name']) for key, value in types_dict.items()]
class EditEveAIAssetForm(FlaskForm):
    """Form for editing an existing asset; type and type version are rendered read-only."""
    name = StringField('Name', validators=[DataRequired(), Length(max=50)])
    description = TextAreaField('Description', validators=[Optional()])
    # NOTE(review): no choices are assigned to this SelectField in this class — presumably
    # the view sets them; verify against the caller.
    type = SelectField('Type', validators=[DataRequired()], render_kw={'readonly': True})
    type_version = StringField('Type Version', validators=[DataRequired()], render_kw={'readonly': True})
    # NOTE(review): both date fields share the same ``id`` value ("form-control datepicker"),
    # which looks like a CSS class list — confirm the intent.
    valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()])
    valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()])
class EditEveAIAssetVersionForm(DynamicFormBase):
    """Form for editing an asset version.

    The static fields below are all required and rendered read-only; further
    fields are presumably added at runtime through ``DynamicFormBase`` — TODO confirm.
    """
    asset_name = StringField('Asset Name', validators=[DataRequired()], render_kw={'readonly': True})
    asset_type = StringField('Asset Type', validators=[DataRequired()], render_kw={'readonly': True})
    asset_type_version = StringField('Asset Type Version', validators=[DataRequired()], render_kw={'readonly': True})
    bucket_name = StringField('Bucket Name', validators=[DataRequired()], render_kw={'readonly': True})

View File

@@ -6,12 +6,15 @@ from flask_security import roles_accepted
from langchain.agents import Agent
from sqlalchemy import desc
from sqlalchemy.exc import SQLAlchemyError
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename
from common.models.document import Embedding, DocumentVersion, Retriever
from common.models.interaction import (ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever,
EveAIAgent, EveAITask, EveAITool)
EveAIAgent, EveAITask, EveAITool, EveAIAssetVersion)
from common.extensions import db, cache_manager
from common.utils.asset_utils import create_asset_stack, add_asset_version_file
from common.utils.model_logging_utils import set_logging_information, update_logging_information
from common.utils.middleware import mw_before_request
@@ -22,7 +25,7 @@ from common.utils.specialist_utils import initialize_specialist
from config.type_defs.specialist_types import SPECIALIST_TYPES
from .interaction_forms import (SpecialistForm, EditSpecialistForm, EditEveAIAgentForm, EditEveAITaskForm,
EditEveAIToolForm)
EditEveAIToolForm, AddEveAIAssetForm, EditEveAIAssetVersionForm)
interaction_bp = Blueprint('interaction_bp', __name__, url_prefix='/interaction')
@@ -37,6 +40,7 @@ def log_after_request(response):
return response
# Routes for Chat Session Management --------------------------------------------------------------
@interaction_bp.before_request
def before_request():
try:
@@ -129,6 +133,7 @@ def show_chat_session(chat_session):
return render_template('interaction/view_chat_session.html', chat_session=chat_session, interactions=interactions)
# Routes for Specialist Management ----------------------------------------------------------------
@interaction_bp.route('/specialist', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def specialist():
@@ -142,7 +147,8 @@ def specialist():
# Populate fields individually instead of using populate_obj (gives problem with QueryMultipleSelectField)
new_specialist.name = form.name.data
new_specialist.type = form.type.data
new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version(new_specialist.type)
new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version(
new_specialist.type)
new_specialist.tuning = form.tuning.data
set_logging_information(new_specialist, dt.now(tz.utc))
@@ -252,7 +258,7 @@ def edit_specialist(specialist_id):
task_rows=task_rows,
tool_rows=tool_rows,
prefixed_url_for=prefixed_url_for,
svg_path=svg_path,)
svg_path=svg_path, )
else:
form_validation_failed(request, form)
@@ -263,7 +269,7 @@ def edit_specialist(specialist_id):
task_rows=task_rows,
tool_rows=tool_rows,
prefixed_url_for=prefixed_url_for,
svg_path=svg_path,)
svg_path=svg_path, )
@interaction_bp.route('/specialists', methods=['GET', 'POST'])
@@ -298,7 +304,7 @@ def handle_specialist_selection():
return redirect(prefixed_url_for('interaction_bp.specialists'))
# Routes for Agent management
# Routes for Agent management ---------------------------------------------------------------------
@interaction_bp.route('/agent/<int:agent_id>/edit', methods=['GET'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_agent(agent_id):
@@ -338,7 +344,7 @@ def save_agent(agent_id):
return jsonify({'success': False, 'message': 'Validation failed'})
# Routes for Task management
# Routes for Task management ----------------------------------------------------------------------
@interaction_bp.route('/task/<int:task_id>/edit', methods=['GET'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_task(task_id):
@@ -374,7 +380,7 @@ def save_task(task_id):
return jsonify({'success': False, 'message': 'Validation failed'})
# Routes for Tool management
# Routes for Tool management ----------------------------------------------------------------------
@interaction_bp.route('/tool/<int:tool_id>/edit', methods=['GET'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_tool(tool_id):
@@ -410,7 +416,7 @@ def save_tool(tool_id):
return jsonify({'success': False, 'message': 'Validation failed'})
# Component selection handlers
# Component selection handlers --------------------------------------------------------------------
@interaction_bp.route('/handle_agent_selection', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_agent_selection():
@@ -448,3 +454,92 @@ def handle_tool_selection():
return redirect(prefixed_url_for('interaction_bp.edit_tool', tool_id=tool_id))
return redirect(prefixed_url_for('interaction_bp.edit_specialist'))
# Routes for Asset management ---------------------------------------------------------------------
@interaction_bp.route('/add_asset', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def add_asset():
    """Show the add-asset form and, on a valid submit, create the asset and its first version.

    On success the user is redirected to the edit page of the new asset
    version. On any failure the error is logged and flashed, and the form is
    rendered again.
    """
    form = AddEveAIAssetForm(request.form)
    tenant_id = session.get('tenant').get('id')

    if form.validate_on_submit():
        try:
            current_app.logger.info(f"Adding asset for tenant {tenant_id}")

            api_input = {
                'name': form.name.data,
                'description': form.description.data,
                'type': form.type.data,
                'valid_from': form.valid_from.data,
                'valid_to': form.valid_to.data,
            }
            new_asset, new_asset_version = create_asset_stack(api_input, tenant_id)

            return redirect(prefixed_url_for('interaction_bp.edit_asset_version',
                                             asset_version_id=new_asset_version.id))
        except Exception as e:
            current_app.logger.error(f'Failed to add asset for tenant {tenant_id}: {str(e)}')
            flash('An error occurred while adding asset', 'error')

    # The template needs the form instance to render its fields and validation errors.
    return render_template('interaction/add_asset.html', form=form)
@interaction_bp.route('/edit_asset_version/<int:asset_version_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_asset_version(asset_version_id):
    """Edit the configuration of an asset version, uploading any submitted files.

    The dynamic "configuration" fields are built from the cached configuration
    of the asset's type and type version. For every file field that carries an
    upload, the file is stored through ``add_asset_version_file`` and the
    resulting object name is written to the configuration instead of the file.
    A file field submitted without a new file keeps the previously stored
    object name, so an uploaded file is not lost on re-save.

    Args:
        asset_version_id: Id of the asset version to edit (404 when unknown).
    """
    asset_version = EveAIAssetVersion.query.get_or_404(asset_version_id)
    # NOTE(review): the model instance is passed as the first positional argument —
    # confirm this matches the DynamicFormBase constructor (formdata vs. obj).
    form = EditEveAIAssetVersionForm(asset_version)
    asset_config = cache_manager.assets_config_cache.get_config(asset_version.asset.type,
                                                                asset_version.asset.type_version)
    configuration_config = asset_config.get('configuration')
    form.add_dynamic_fields("configuration", configuration_config, asset_version.configuration)

    if form.validate_on_submit():
        # Update the configuration dynamic fields
        configuration = form.get_dynamic_data("configuration")
        previous_configuration = asset_version.configuration or {}
        processed_configuration = {}
        tenant_id = session.get('tenant').get('id')

        # If files are returned, store the object names in the configuration and add the
        # files to the appropriate bucket, in the appropriate location.
        for field_name, field_value in configuration.items():
            if isinstance(field_value, FileStorage):
                if not field_value.filename:
                    # No new upload for this file field: keep the stored object reference
                    # rather than writing a non-serialisable FileStorage into the JSON column.
                    if field_name in previous_configuration:
                        processed_configuration[field_name] = previous_configuration[field_name]
                    continue
                try:
                    # Upload file and retrieve object_name for the file
                    object_name = add_asset_version_file(asset_version, field_name, field_value, tenant_id)
                    # Store object reference in configuration instead of file content
                    processed_configuration[field_name] = object_name
                except Exception as e:
                    current_app.logger.error(f"Failed to upload file for asset version {asset_version.id}: {str(e)}")
                    flash(f"Failed to upload file '{field_value.filename}': {str(e)}", "danger")
                    return render_template('interaction/edit_asset_version.html', form=form,
                                           asset_version=asset_version)
            else:
                # Handle normal fields
                processed_configuration[field_name] = field_value

        # Update the asset version with processed configuration
        asset_version.configuration = processed_configuration
        # Update logging information
        update_logging_information(asset_version, dt.now(tz.utc))

        try:
            db.session.commit()
            flash('Asset uploaded successfully!', 'success')
            current_app.logger.info(f'Asset Version {asset_version.id} updated successfully')
            return redirect(prefixed_url_for('interaction_bp.assets'))
        except SQLAlchemyError as e:
            db.session.rollback()
            flash(f'Failed to upload asset. Error: {str(e)}', 'danger')
            current_app.logger.error(f'Failed to update asset version {asset_version.id}. Error: {str(e)}')
            return render_template('interaction/edit_asset_version.html', form=form,
                                   asset_version=asset_version)
    else:
        form_validation_failed(request, form)

    # Every render path passes the same context (form + asset_version).
    return render_template('interaction/edit_asset_version.html', form=form,
                           asset_version=asset_version)

View File

@@ -1,5 +1,6 @@
# Import all specialist implementations here to ensure registration
from . import standard_rag
from . import dossier_retriever
# List of all available specialist implementations
__all__ = ['standard_rag']
__all__ = ['standard_rag', 'dossier_retriever']

View File

@@ -1,8 +1,14 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional, Tuple
from flask import current_app
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
from config.logging_config import TuningLogger
@@ -12,6 +18,7 @@ class BaseRetriever(ABC):
def __init__(self, tenant_id: int, retriever_id: int):
self.tenant_id = tenant_id
self.retriever_id = retriever_id
self.retriever = Retriever.query.get_or_404(retriever_id)
self.tuning = False
self.tuning_logger = None
self._setup_tuning_logger()
@@ -43,6 +50,31 @@ class BaseRetriever(ABC):
except Exception as e:
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
    """
    Set up common parameters needed for standard retrieval functionality

    Returns:
        Tuple containing:
            - embedding_model: The model to use for embeddings
            - embedding_model_class: The class for storing embeddings
            - catalog_id: ID of the catalog
            - similarity_threshold: Threshold for similarity matching
              (``es_similarity_threshold`` in the retriever configuration, default 0.3)
            - k: Maximum number of results to return
              (``es_k`` in the retriever configuration, default 8)
    """
    catalog_id = self.retriever.catalog_id
    # The lookup result is not used; the call is kept because it fails early
    # when the retriever points at a catalog that does not exist.
    Catalog.query.get_or_404(catalog_id)

    embedding_model, embedding_model_class = get_embedding_model_and_class(
        self.tenant_id, catalog_id, "mistral.mistral-embed"
    )

    # A retriever without stored configuration falls back to the defaults.
    configuration = self.retriever.configuration or {}
    similarity_threshold = configuration.get('es_similarity_threshold', 0.3)
    k = configuration.get('es_k', 8)

    return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k
@abstractmethod
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
"""

View File

@@ -0,0 +1,374 @@
"""
DossierRetriever implementation that adds metadata filtering to retrieval
"""
import json
from datetime import datetime as dt, date, timezone as tz
from typing import Dict, Any, List, Optional, Union, Tuple
from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
from sqlalchemy.sql import expression
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.utils.model_utils import get_embedding_model_and_class
from .base import BaseRetriever
from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
class DossierRetriever(BaseRetriever):
"""
Dossier Retriever implementation that adds metadata filtering
to standard retrieval functionality
"""
def __init__(self, tenant_id: int, retriever_id: int):
    """Initialise the retriever with the standard retrieval parameters and the dossier filter settings.

    Args:
        tenant_id: Tenant the retriever runs for.
        retriever_id: Id of the retriever whose configuration is used.
    """
    super().__init__(tenant_id, retriever_id)

    # Set up standard retrieval parameters
    (self.embedding_model,
     self.embedding_model_class,
     self.catalog_id,
     self.similarity_threshold,
     self.k) = self.setup_standard_retrieval_params()

    # Dossier-specific configuration. ``or {}`` also covers a missing configuration and
    # keys stored with an explicit null, so later ``.items()`` calls never hit None.
    configuration = self.retriever.configuration or {}
    self.tagging_fields_filter = configuration.get('tagging_fields_filter') or {}
    self.dynamic_arguments = configuration.get('dynamic_arguments') or {}

    self.log_tuning("Dossier retriever initialized", {
        "tagging_fields_filter": self.tagging_fields_filter,
        "dynamic_arguments": self.dynamic_arguments,
        "similarity_threshold": self.similarity_threshold,
        "k": self.k
    })
@property
def type(self) -> str:
    """Type identifier of this retriever; the same string it is registered under in ``RetrieverRegistry``."""
    return "DOSSIER_RETRIEVER"
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
"""
Parse metadata ensuring it's a dictionary
Args:
metadata: Input metadata which could be string, dict, or None
Returns:
Dict[str, Any]: Parsed metadata as dictionary
"""
if metadata is None:
return {}
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, str):
try:
return json.loads(metadata)
except json.JSONDecodeError:
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
return {}
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
return {}
def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
"""
Apply metadata filters to the query based on tagging_fields_filter configuration
Args:
query_obj: SQLAlchemy query object
arguments: Retriever arguments (for variable substitution)
Returns:
Modified SQLAlchemy query object
"""
if not self.tagging_fields_filter:
return query_obj
# Get dynamic argument values
dynamic_values = {}
for arg_name, arg_config in self.dynamic_arguments.items():
if hasattr(arguments, arg_name):
dynamic_values[arg_name] = getattr(arguments, arg_name)
# Build the filter
filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
if filter_condition is not None:
query_obj = query_obj.filter(filter_condition)
self.log_tuning("Applied metadata filter", {
"filter_sql": str(filter_condition),
"dynamic_values": dynamic_values
})
return query_obj
def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[
expression.BinaryExpression]:
"""
Recursively build SQLAlchemy filter condition from filter definition
Args:
filter_def: Filter definition (logical group or field condition)
dynamic_values: Values for dynamic variable substitution
Returns:
SQLAlchemy expression or None if invalid
"""
# Handle logical groups (AND, OR, NOT)
if 'logical' in filter_def:
logical_op = filter_def['logical'].lower()
subconditions = [
self._build_filter_condition(cond, dynamic_values)
for cond in filter_def.get('conditions', [])
]
# Remove None values
subconditions = [c for c in subconditions if c is not None]
if not subconditions:
return None
if logical_op == 'and':
return and_(*subconditions)
elif logical_op == 'or':
return or_(*subconditions)
elif logical_op == 'not':
if len(subconditions) == 1:
return ~subconditions[0]
else:
# NOT should have exactly one condition
current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
return None
else:
current_app.logger.warning(f"Unknown logical operator: {logical_op}")
return None
# Handle field conditions
elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
field_name = filter_def['field']
operator = filter_def['operator'].lower()
value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))
# Skip if we couldn't resolve the value
if value is None and operator not in ['is_null', 'is_not_null']:
return None
# Create the field expression to match JSON data
field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)
# Apply the appropriate operator
return self._apply_operator(field_expr, operator, value, filter_def)
else:
current_app.logger.warning(f"Invalid filter definition: {filter_def}")
return None
def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
"""
Resolve a value definition, handling variables and defaults
Args:
value_def: Value definition (could be literal, variable reference, or list)
dynamic_values: Values for dynamic variable substitution
default: Default value if variable not found
Returns:
Resolved value
"""
# Handle lists (recursively resolve each item)
if isinstance(value_def, list):
return [self._resolve_value(item, dynamic_values) for item in value_def]
# Handle variable references (strings starting with $)
if isinstance(value_def, str) and value_def.startswith('$'):
var_name = value_def[1:] # Remove $ prefix
if var_name in dynamic_values:
return dynamic_values[var_name]
else:
# Use default if provided
return default
# Return literal values as-is
return value_def
def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[
expression.BinaryExpression]:
"""
Apply the specified operator to create a filter condition
Args:
field_expr: SQLAlchemy field expression
operator: Operator to apply
value: Value to compare against
filter_def: Original filter definition (for additional options)
Returns:
SQLAlchemy expression
"""
try:
# String operators
if operator == 'eq':
return field_expr == str(value)
elif operator == 'neq':
return field_expr != str(value)
elif operator == 'contains':
return field_expr.contains(str(value))
elif operator == 'not_contains':
return ~field_expr.contains(str(value))
elif operator == 'starts_with':
return field_expr.startswith(str(value))
elif operator == 'ends_with':
return field_expr.endswith(str(value))
elif operator == 'in':
return field_expr.in_([str(v) for v in value])
elif operator == 'not_in':
return ~field_expr.in_([str(v) for v in value])
elif operator == 'regex' or operator == 'not_regex':
# PostgreSQL regex using ~ or !~ operator
case_insensitive = filter_def.get('case_insensitive', False)
regex_op = '~*' if case_insensitive else '~'
if operator == 'not_regex':
regex_op = '!~*' if case_insensitive else '!~'
return text(
f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams(
regex_value=str(value))
# Numeric/Date operators
elif operator == 'gt':
return cast(field_expr, Float) > float(value)
elif operator == 'gte':
return cast(field_expr, Float) >= float(value)
elif operator == 'lt':
return cast(field_expr, Float) < float(value)
elif operator == 'lte':
return cast(field_expr, Float) <= float(value)
elif operator == 'between':
if len(value) == 2:
return cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
return None
elif operator == 'not_between':
if len(value) == 2:
return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
else:
current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
return None
# Null checking
elif operator == 'is_null':
return field_expr.is_(None)
elif operator == 'is_not_null':
return field_expr.isnot(None)
else:
current_app.logger.warning(f"Unknown operator: {operator}")
return None
except (ValueError, TypeError) as e:
current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
return None
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
    """
    Retrieve documents based on query with added metadata filtering

    Only embeddings of the latest version of each document are considered,
    restricted to this retriever's catalog, to documents valid on the current
    UTC date, and to a cosine similarity above ``self.similarity_threshold``.
    At most ``self.k`` results are returned, most similar first.

    Args:
        arguments: Validated RetrieverArguments containing at minimum:
            - query: str - The search query

    Returns:
        List[RetrieverResult]: List of retrieved documents with similarity scores

    Raises:
        SQLAlchemyError: Re-raised after logging and rolling back the session.
        Exception: Any other error is logged and re-raised.
    """
    try:
        query = arguments.query

        # Get query embedding
        query_embedding = self.embedding_model.embed_query(query)

        # Get the appropriate embedding database model
        db_class = self.embedding_model_class

        # Get current date for validity checks
        current_date = dt.now(tz=tz.utc).date()

        # Create subquery for latest versions
        # (the highest DocumentVersion.id per document is taken as the latest version)
        subquery = (
            db.session.query(
                DocumentVersion.doc_id,
                func.max(DocumentVersion.id).label('latest_version_id')
            )
            .group_by(DocumentVersion.doc_id)
            .subquery()
        )

        # Main query
        # similarity = 1 - cosine distance between the stored and the query embedding
        query_obj = (
            db.session.query(
                db_class,
                (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
            )
            .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
            .join(Document, DocumentVersion.doc_id == Document.id)
            # Joining on the subquery keeps only embeddings of the latest versions
            .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
            .filter(
                # A missing valid_from / valid_to means "no bound" on that side
                or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
                Document.catalog_id == self.catalog_id
            )
        )

        # Apply metadata filtering
        query_obj = self._apply_metadata_filter(query_obj, arguments)

        # Apply ordering and limit
        query_obj = query_obj.order_by(desc('similarity')).limit(self.k)

        # Execute query
        results = query_obj.all()

        # Transform results into standard format
        processed_results = []
        for doc, similarity in results:
            # Parse user_metadata to ensure it's a dictionary
            user_metadata = self._parse_metadata(doc.document_version.user_metadata)
            processed_results.append(
                RetrieverResult(
                    id=doc.id,
                    chunk=doc.chunk,
                    similarity=float(similarity),
                    metadata=RetrieverMetadata(
                        document_id=doc.document_version.doc_id,
                        version_id=doc.document_version.id,
                        document_name=doc.document_version.document.name,
                        user_metadata=user_metadata,
                    )
                )
            )

        # Log the retrieval
        # (the SQL is only compiled when tuning is enabled)
        if self.tuning:
            compiled_query = str(query_obj.statement.compile(
                compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
            ))
            self.log_tuning('retrieve', {
                "arguments": arguments.model_dump(),
                "similarity_threshold": self.similarity_threshold,
                "k": self.k,
                "query": compiled_query,
                "results_count": len(results),
                "processed_results_count": len(processed_results),
            })

        return processed_results

    except SQLAlchemyError as e:
        current_app.logger.error(f'Error in Dossier retrieval: {e}')
        # Roll back so the session is usable again after the failed statement
        db.session.rollback()
        raise
    except Exception as e:
        current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
        raise
# Register the retriever type
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)

View File

@@ -23,20 +23,12 @@ class StandardRAGRetriever(BaseRetriever):
def __init__(self, tenant_id: int, retriever_id: int):
super().__init__(tenant_id, retriever_id)
retriever = Retriever.query.get_or_404(retriever_id)
self.catalog_id = retriever.catalog_id
self.tenant_id = tenant_id
catalog = Catalog.query.get_or_404(self.catalog_id)
embedding_model = "mistral.mistral-embed"
self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id,
self.catalog_id,
embedding_model)
self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
self.k = retriever.configuration.get('es_k', 8)
self.tuning = retriever.tuning
self.log_tuning("Standard RAG retriever initialized")
# Set up standard retrieval parameters
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
self.log_tuning("Standard RAG retriever initialized", {
"similarity_threshold": self.similarity_threshold,
"k": self.k
})
@property
def type(self) -> str:

View File

@@ -0,0 +1,83 @@
"""Add EveAIAsset & associated models
Revision ID: 4a9f7a6285cc
Revises: ed981a641be9
Create Date: 2025-03-12 09:43:05.390895
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '4a9f7a6285cc'
down_revision = 'ed981a641be9'
branch_labels = None
depends_on = None
def upgrade():
    """Create the asset tables.

    Creates, in dependency order: ``eve_ai_asset``, ``eve_ai_asset_version``
    (references the asset), then ``eve_ai_asset_instruction`` and
    ``eve_ai_processed_asset`` (both reference the asset version).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # The asset itself: name, type, type version and validity window.
    # created_by / updated_by reference ``public.user.id``.
    op.create_table('eve_ai_asset',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=50), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('type', sa.String(length=50), nullable=False),
    sa.Column('type_version', sa.String(length=20), nullable=True),
    sa.Column('valid_from', sa.DateTime(), nullable=True),
    sa.Column('valid_to', sa.DateTime(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
    sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # A version of an asset: bucket name plus JSONB ``configuration`` and ``arguments``.
    op.create_table('eve_ai_asset_version',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('asset_id', sa.Integer(), nullable=False),
    sa.Column('bucket_name', sa.String(length=255), nullable=True),
    sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('arguments', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['asset_id'], ['eve_ai_asset.id'], ),
    sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
    sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # Named text instructions attached to an asset version.
    op.create_table('eve_ai_asset_instruction',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('asset_version_id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('content', sa.Text(), nullable=True),
    sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # A stored object (bucket + object name) linked to an asset version and,
    # optionally, to a specialist and a chat session.
    op.create_table('eve_ai_processed_asset',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('asset_version_id', sa.Integer(), nullable=False),
    sa.Column('specialist_id', sa.Integer(), nullable=True),
    sa.Column('chat_session_id', sa.Integer(), nullable=True),
    sa.Column('bucket_name', sa.String(length=255), nullable=True),
    sa.Column('object_name', sa.String(length=255), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
    sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ),
    sa.ForeignKeyConstraint(['chat_session_id'], ['chat_session.id'], ),
    sa.ForeignKeyConstraint(['specialist_id'], ['specialist.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###
def downgrade():
    """Drop the asset tables in reverse dependency order (referencing tables first)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('eve_ai_processed_asset')
    op.drop_table('eve_ai_asset_instruction')
    op.drop_table('eve_ai_asset_version')
    op.drop_table('eve_ai_asset')
    # ### end Alembic commands ###