diff --git a/common/utils/asset_utils.py b/common/utils/asset_utils.py new file mode 100644 index 0000000..5eb3815 --- /dev/null +++ b/common/utils/asset_utils.py @@ -0,0 +1,62 @@ +from datetime import datetime as dt, timezone as tz + +from flask import current_app +from sqlalchemy.exc import SQLAlchemyError + +from common.extensions import cache_manager, minio_client, db +from common.models.interaction import EveAIAsset, EveAIAssetVersion +from common.utils.document_utils import mark_tenant_storage_dirty +from common.utils.model_logging_utils import set_logging_information + + +def create_asset_stack(api_input, tenant_id): + type_version = cache_manager.assets_version_tree_cache.get_latest_version(api_input['type']) + api_input['type_version'] = type_version + new_asset = create_asset(api_input, tenant_id) + new_asset_version = create_version_for_asset(new_asset, tenant_id) + db.session.add(new_asset) + db.session.add(new_asset_version) + + try: + db.session.commit() + except SQLAlchemyError as e: + current_app.logger.error(f"Could not add asset for tenant {tenant_id}: {str(e)}") + db.session.rollback() + raise e + + return new_asset, new_asset_version + + +def create_asset(api_input, tenant_id): + new_asset = EveAIAsset() + new_asset.name = api_input['name'] + new_asset.description = api_input['description'] + new_asset.type = api_input['type'] + new_asset.type_version = api_input['type_version'] + if api_input['valid_from'] and api_input['valid_from'] != '': + new_asset.valid_from = api_input['valid_from'] + else: + new_asset.valid_from = dt.now(tz.utc) + new_asset.valid_to = api_input['valid_to'] + set_logging_information(new_asset, dt.now(tz.utc)) + + return new_asset + + +def create_version_for_asset(asset, tenant_id): + new_asset_version = EveAIAssetVersion() + new_asset_version.asset = asset + new_asset_version.bucket_name = minio_client.create_tenant_bucket(tenant_id) + set_logging_information(new_asset_version, dt.now(tz.utc)) + + return new_asset_version + + +def add_asset_version_file(asset_version, field_name, file, tenant_id): + object_name, file_size = minio_client.upload_file(asset_version.bucket_name, asset_version.id, field_name, + file.content_type) + mark_tenant_storage_dirty(tenant_id) + return object_name + + + diff --git a/common/utils/business_event.py b/common/utils/business_event.py index e308399..317493f 100644 --- a/common/utils/business_event.py +++ b/common/utils/business_event.py @@ -1,4 +1,5 @@ import os +import time import uuid from contextlib import contextmanager from datetime import datetime @@ -82,15 +83,21 @@ class BusinessEvent: self.span_name = span_name self.parent_span_id = parent_span_id + # Track start time for the span + span_start_time = time.time() + self.log(f"Start") try: yield finally: + # Calculate total time for this span + span_total_time = time.time() - span_start_time + if self.llm_metrics['call_count'] > 0: self.log_final_metrics() self.reset_llm_metrics() - self.log(f"End") + self.log(f"End", extra_fields={'span_duration': span_total_time}) # Restore the previous span info if self.spans: self.span_id, self.span_name, self.parent_span_id = self.spans.pop() @@ -99,7 +106,7 @@ class BusinessEvent: self.span_name = None self.parent_span_id = None - def log(self, message: str, level: str = 'info'): + def log(self, message: str, level: str = 'info', extra_fields: Dict[str, Any] = None): log_data = { 'timestamp': dt.now(tz=tz.utc), 'event_type': self.event_type, @@ -115,6 +122,15 @@ class BusinessEvent: 'environment': self.environment, 'message': message, } + # Add any extra fields + if extra_fields: + for key, value in extra_fields.items(): + # For span/trace duration, use the llm_metrics_total_time field + if key == 'span_duration' or key == 'trace_duration': + log_data['llm_metrics_total_time'] = value + else: + log_data[key] = value + self._log_buffer.append(log_data) def log_llm_metrics(self, metrics: dict, level: str = 'info'): @@ -226,13 +242,17 @@ class BusinessEvent: self._log_buffer = [] def __enter__(self): + self.trace_start_time = time.time() self.log(f'Starting Trace for {self.event_type}') return BusinessEventContext(self).__enter__() def __exit__(self, exc_type, exc_val, exc_tb): + trace_total_time = time.time() - self.trace_start_time + if self.llm_metrics['call_count'] > 0: self.log_final_metrics() self.reset_llm_metrics() - self.log(f'Ending Trace for {self.event_type}') + + self.log(f'Ending Trace for {self.event_type}', extra_fields={'trace_duration': trace_total_time}) self._flush_log_buffer() return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb) diff --git a/common/utils/minio_utils.py b/common/utils/minio_utils.py index 8c750cf..9cf372c 100644 --- a/common/utils/minio_utils.py +++ b/common/utils/minio_utils.py @@ -33,6 +33,9 @@ class MinioClient: def generate_object_name(self, document_id, language, version_id, filename): return f"{document_id}/{language}/{version_id}/{filename}" + def generate_asset_name(self, asset_version_id, file_name, content_type): + return f"assets/{asset_version_id}/{file_name}.{content_type}" + def upload_document_file(self, tenant_id, document_id, language, version_id, filename, file_data): bucket_name = self.generate_bucket_name(tenant_id) object_name = self.generate_object_name(document_id, language, version_id, filename) @@ -54,6 +57,26 @@ class MinioClient: except S3Error as err: raise Exception(f"Error occurred while uploading file: {err}") + def upload_asset_file(self, bucket_name, asset_version_id, file_name, file_type, file_data): + object_name = self.generate_asset_name(asset_version_id, file_name, file_type) + + try: + if isinstance(file_data, FileStorage): + file_data = file_data.read() + elif isinstance(file_data, io.BytesIO): + file_data = file_data.getvalue() + elif isinstance(file_data, str): + file_data = file_data.encode('utf-8') + elif not isinstance(file_data, bytes): + raise TypeError('Unsupported file type. Expected FileStorage, BytesIO, str, or bytes.') + + self.client.put_object( + bucket_name, object_name, io.BytesIO(file_data), len(file_data) + ) + return object_name, len(file_data) + except S3Error as err: + raise Exception(f"Error occurred while uploading asset: {err}") + def download_document_file(self, tenant_id, bucket_name, object_name): try: response = self.client.get_object(bucket_name, object_name) diff --git a/config/assets/DOCUMENT_TEMPLATE/1.0.0.yaml b/config/assets/DOCUMENT_TEMPLATE/1.0.0.yaml new file mode 100644 index 0000000..af05d09 --- /dev/null +++ b/config/assets/DOCUMENT_TEMPLATE/1.0.0.yaml @@ -0,0 +1,18 @@ +version: "1.0.0" +name: "Document Template" +configuration: + argument_definition: + name: "variable_defition" + type: "file" + description: "Yaml file defining the arguments in the Document Template." + required: True + content_markdown: + name: "content_markdown" + type: "str" + description: "Actual template file in markdown format." + required: True +metadata: + author: "Josako" + date_added: "2025-03-12" + description: "Asset that defines a template in markdown a specialist can process" + changes: "Initial version" diff --git a/config/type_defs/asset_types.py b/config/type_defs/asset_types.py new file mode 100644 index 0000000..41c769d --- /dev/null +++ b/config/type_defs/asset_types.py @@ -0,0 +1,7 @@ +# Agent Types +AGENT_TYPES = { + "DOCUMENT_TEMPLATE": { + "name": "Document Template", + "description": "Asset that defines a template in markdown a specialist can process", + }, +} diff --git a/eveai_app/views/dynamic_form_base.py b/eveai_app/views/dynamic_form_base.py index 656bd2a..e1078bd 100644 --- a/eveai_app/views/dynamic_form_base.py +++ b/eveai_app/views/dynamic_form_base.py @@ -1,5 +1,6 @@ from flask_wtf import FlaskForm -from wtforms import IntegerField, FloatField, BooleanField, StringField, TextAreaField, validators, ValidationError +from wtforms import (IntegerField, FloatField, BooleanField, StringField, TextAreaField, FileField, + validators, ValidationError) from flask import current_app import json @@ -264,6 +265,7 @@ class DynamicFormBase(FlaskForm): 'string': StringField, 'text': TextAreaField, 'date': DateField, + 'file': FileField, }.get(field_type, StringField) field_kwargs = {} diff --git a/eveai_app/views/interaction_forms.py b/eveai_app/views/interaction_forms.py index 9452b4e..e28b7c7 100644 --- a/eveai_app/views/interaction_forms.py +++ b/eveai_app/views/interaction_forms.py @@ -1,5 +1,6 @@ from flask_wtf import FlaskForm from wtforms import (StringField, BooleanField, SelectField, TextAreaField) +from wtforms.fields.datetime import DateField from wtforms.validators import DataRequired, Length, Optional from wtforms_sqlalchemy.fields import QuerySelectMultipleField @@ -94,3 +95,32 @@ class EditEveAITaskForm(BaseEditComponentForm): class EditEveAIToolForm(BaseEditComponentForm): pass + +class AddEveAIAssetForm(FlaskForm): + name = StringField('Name', validators=[DataRequired(), Length(max=50)]) + description = TextAreaField('Description', validators=[Optional()]) + type = SelectField('Type', validators=[DataRequired()]) + valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()]) + valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()]) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + types_dict = cache_manager.assets_types_cache.get_types() + self.type.choices = [(key, value['name']) for key, value in types_dict.items()] + + +class EditEveAIAssetForm(FlaskForm): + name = StringField('Name', validators=[DataRequired(), Length(max=50)]) + description = TextAreaField('Description', validators=[Optional()]) + type = SelectField('Type', validators=[DataRequired()], render_kw={'readonly': True}) + type_version = StringField('Type Version', validators=[DataRequired()], render_kw={'readonly': True}) + valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()]) + valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()]) + + +class EditEveAIAssetVersionForm(DynamicFormBase): + asset_name = StringField('Asset Name', validators=[DataRequired()], render_kw={'readonly': True}) + asset_type = StringField('Asset Type', validators=[DataRequired()], render_kw={'readonly': True}) + asset_type_version = StringField('Asset Type Version', validators=[DataRequired()], render_kw={'readonly': True}) + bucket_name = StringField('Bucket Name', validators=[DataRequired()], render_kw={'readonly': True}) + diff --git a/eveai_app/views/interaction_views.py b/eveai_app/views/interaction_views.py index 1d7c463..76089db 100644 --- a/eveai_app/views/interaction_views.py +++ b/eveai_app/views/interaction_views.py @@ -6,12 +6,15 @@ from flask_security import roles_accepted from langchain.agents import Agent from sqlalchemy import desc from sqlalchemy.exc import SQLAlchemyError +from werkzeug.datastructures import FileStorage +from werkzeug.utils import secure_filename from common.models.document import Embedding, DocumentVersion, Retriever from common.models.interaction import (ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever, - EveAIAgent, EveAITask, EveAITool) + EveAIAgent, EveAITask, EveAITool, EveAIAssetVersion) from common.extensions import db, cache_manager +from common.utils.asset_utils import create_asset_stack, add_asset_version_file from common.utils.model_logging_utils import set_logging_information, update_logging_information from common.utils.middleware import mw_before_request @@ -22,7 +25,7 @@ from common.utils.specialist_utils import initialize_specialist from config.type_defs.specialist_types import SPECIALIST_TYPES from .interaction_forms import (SpecialistForm, EditSpecialistForm, EditEveAIAgentForm, EditEveAITaskForm, - EditEveAIToolForm) + EditEveAIToolForm, AddEveAIAssetForm, EditEveAIAssetVersionForm) interaction_bp = Blueprint('interaction_bp', __name__, url_prefix='/interaction') @@ -37,6 +40,7 @@ def log_after_request(response): return response +# Routes for Chat Session Management -------------------------------------------------------------- @interaction_bp.before_request def before_request(): try: @@ -88,13 +92,13 @@ def view_chat_session(chat_session_id): .filter_by(chat_session_id=chat_session.id) .join(Specialist, Interaction.specialist_id == Specialist.id, isouter=True) .add_columns( - Interaction.id, - Interaction.question_at, - Interaction.specialist_arguments, - Interaction.specialist_results, - Specialist.name.label('specialist_name'), - Specialist.type.label('specialist_type') - ).order_by(Interaction.question_at).all()) + Interaction.id, + Interaction.question_at, + Interaction.specialist_arguments, + Interaction.specialist_results, + Specialist.name.label('specialist_name'), + Specialist.type.label('specialist_type') + ).order_by(Interaction.question_at).all()) # Fetch all related embeddings for the interactions in this session embedding_query = (db.session.query(InteractionEmbedding.interaction_id, @@ -129,6 +133,7 @@ def show_chat_session(chat_session): return render_template('interaction/view_chat_session.html', chat_session=chat_session, interactions=interactions) +# Routes for Specialist Management ---------------------------------------------------------------- @interaction_bp.route('/specialist', methods=['GET', 'POST']) @roles_accepted('Super User', 'Tenant Admin') def specialist(): @@ -142,7 +147,8 @@ def specialist(): # Populate fields individually instead of using populate_obj (gives problem with QueryMultipleSelectField) new_specialist.name = form.name.data new_specialist.type = form.type.data - new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version(new_specialist.type) + new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version( + new_specialist.type) new_specialist.tuning = form.tuning.data set_logging_information(new_specialist, dt.now(tz.utc)) @@ -252,7 +258,7 @@ def edit_specialist(specialist_id): task_rows=task_rows, tool_rows=tool_rows, prefixed_url_for=prefixed_url_for, - svg_path=svg_path,) + svg_path=svg_path, ) else: form_validation_failed(request, form) @@ -263,7 +269,7 @@ def edit_specialist(specialist_id): task_rows=task_rows, tool_rows=tool_rows, prefixed_url_for=prefixed_url_for, - svg_path=svg_path,) + svg_path=svg_path, ) @interaction_bp.route('/specialists', methods=['GET', 'POST']) @@ -298,7 +304,7 @@ def handle_specialist_selection(): return redirect(prefixed_url_for('interaction_bp.specialists')) -# Routes for Agent management +# Routes for Agent management --------------------------------------------------------------------- @interaction_bp.route('/agent//edit', methods=['GET']) @roles_accepted('Super User', 'Tenant Admin') def edit_agent(agent_id): @@ -338,7 +344,7 @@ def save_agent(agent_id): return jsonify({'success': False, 'message': 'Validation failed'}) -# Routes for Task management +# Routes for Task management ---------------------------------------------------------------------- @interaction_bp.route('/task//edit', methods=['GET']) @roles_accepted('Super User', 'Tenant Admin') def edit_task(task_id): @@ -374,7 +380,7 @@ def save_task(task_id): return jsonify({'success': False, 'message': 'Validation failed'}) -# Routes for Tool management +# Routes for Tool management ---------------------------------------------------------------------- @interaction_bp.route('/tool//edit', methods=['GET']) @roles_accepted('Super User', 'Tenant Admin') def edit_tool(tool_id): @@ -410,7 +416,7 @@ def save_tool(tool_id): return jsonify({'success': False, 'message': 'Validation failed'}) -# Component selection handlers +# Component selection handlers -------------------------------------------------------------------- @interaction_bp.route('/handle_agent_selection', methods=['POST']) @roles_accepted('Super User', 'Tenant Admin') def handle_agent_selection(): @@ -447,4 +453,93 @@ def handle_tool_selection(): if action == "edit_tool": return redirect(prefixed_url_for('interaction_bp.edit_tool', tool_id=tool_id)) - return redirect(prefixed_url_for('interaction_bp.edit_specialist')) \ No newline at end of file + return redirect(prefixed_url_for('interaction_bp.edit_specialist')) + + +# Routes for Asset management --------------------------------------------------------------------- +@interaction_bp.route('/add_asset', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def add_asset(): + form = AddEveAIAssetForm(request.form) + tenant_id = session.get('tenant').get('id') + + if form.validate_on_submit(): + try: + current_app.logger.info(f"Adding asset for tenant {tenant_id}") + + api_input = { + 'name': form.name.data, + 'description': form.description.data, + 'type': form.type.data, + 'valid_from': form.valid_from.data, + 'valid_to': form.valid_to.data, + } + new_asset, new_asset_version = create_asset_stack(api_input, tenant_id) + + return redirect(prefixed_url_for('interaction_bp.edit_asset_version', + asset_version_id=new_asset_version.id)) + except Exception as e: + current_app.logger.error(f'Failed to add asset for tenant {tenant_id}: {str(e)}') + flash('An error occurred while adding asset', 'error') + + return render_template('interaction/add_asset.html') + + +@interaction_bp.route('/edit_asset_version/', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def edit_asset_version(asset_version_id): + asset_version = EveAIAssetVersion.query.get_or_404(asset_version_id) + form = EditEveAIAssetVersionForm(asset_version) + asset_config = cache_manager.assets_config_cache.get_config(asset_version.asset.type, + asset_version.asset.type_version) + configuration_config = asset_config.get('configuration') + form.add_dynamic_fields("configuration", configuration_config, asset_version.configuration) + + if form.validate_on_submit(): + # Update the configuration dynamic fields + configuration = form.get_dynamic_data("configuration") + processed_configuration = {} + tenant_id = session.get('tenant').get('id') + # if files are returned, we will store the file_names in the configuration, and add the file to the appropriate + # bucket, in the appropriate location + for field_name, field_value in configuration.items(): + # Handle file field - check if the value is a FileStorage instance + if isinstance(field_value, FileStorage) and field_value.filename: + try: + # Upload file and retrieve object_name for the file + object_name = add_asset_version_file(asset_version, field_name, field_value, tenant_id) + + # Store object reference in configuration instead of file content + processed_configuration[field_name] = object_name + + except Exception as e: + current_app.logger.error(f"Failed to upload file for asset version {asset_version.id}: {str(e)}") + flash(f"Failed to upload file '{field_value.filename}': {str(e)}", "danger") + return render_template('interaction/edit_asset_version.html', form=form, + asset_version=asset_version) + # Handle normal fields + else: + processed_configuration[field_name] = field_value + + # Update the asset version with processed configuration + asset_version.configuration = processed_configuration + + # Update logging information + update_logging_information(asset_version, dt.now(tz.utc)) + + try: + db.session.commit() + flash('Asset uploaded successfully!', 'success') + current_app.logger.info(f'Asset Version {asset_version.id} updated successfully') + return redirect(prefixed_url_for('interaction_bp.assets')) + except SQLAlchemyError as e: + db.session.rollback() + flash(f'Failed to upload asset. Error: {str(e)}', 'danger') + current_app.logger.error(f'Failed to update asset version {asset_version.id}. Error: {str(e)}') + return render_template('interaction/edit_asset_version.html', form=form) + else: + form_validation_failed(request, form) + + return render_template('interaction/edit_asset_version.html', form=form) + + diff --git a/eveai_chat_workers/retrievers/__init__.py b/eveai_chat_workers/retrievers/__init__.py index 37bd40e..cefd720 100644 --- a/eveai_chat_workers/retrievers/__init__.py +++ b/eveai_chat_workers/retrievers/__init__.py @@ -1,5 +1,6 @@ # Import all specialist implementations here to ensure registration from . import standard_rag +from . import dossier_retriever # List of all available specialist implementations -__all__ = ['standard_rag'] \ No newline at end of file +__all__ = ['standard_rag', 'dossier_retriever'] diff --git a/eveai_chat_workers/retrievers/base.py b/eveai_chat_workers/retrievers/base.py index 8df51bf..bbf32e1 100644 --- a/eveai_chat_workers/retrievers/base.py +++ b/eveai_chat_workers/retrievers/base.py @@ -1,8 +1,14 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional, Tuple from flask import current_app -from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments +from sqlalchemy import func, or_, desc +from sqlalchemy.exc import SQLAlchemyError + +from common.extensions import db +from common.models.document import Document, DocumentVersion, Catalog, Retriever +from common.utils.model_utils import get_embedding_model_and_class +from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata from config.logging_config import TuningLogger @@ -12,6 +18,7 @@ class BaseRetriever(ABC): def __init__(self, tenant_id: int, retriever_id: int): self.tenant_id = tenant_id self.retriever_id = retriever_id + self.retriever = Retriever.query.get_or_404(retriever_id) self.tuning = False self.tuning_logger = None self._setup_tuning_logger() @@ -43,6 +50,31 @@ class BaseRetriever(ABC): except Exception as e: current_app.logger.error(f"Processor: Error in tuning logging: {e}") + def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]: + """ + Set up common parameters needed for standard retrieval functionality + + Returns: + Tuple containing: + - embedding_model: The model to use for embeddings + - embedding_model_class: The class for storing embeddings + - catalog_id: ID of the catalog + - similarity_threshold: Threshold for similarity matching + - k: Maximum number of results to return + """ + catalog_id = self.retriever.catalog_id + catalog = Catalog.query.get_or_404(catalog_id) + embedding_model = "mistral.mistral-embed" + + embedding_model, embedding_model_class = get_embedding_model_and_class( + self.tenant_id, catalog_id, embedding_model + ) + + similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3) + k = self.retriever.configuration.get('es_k', 8) + + return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k + @abstractmethod def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]: """ diff --git a/eveai_chat_workers/retrievers/dossier_retriever.py b/eveai_chat_workers/retrievers/dossier_retriever.py new file mode 100644 index 0000000..1e46339 --- /dev/null +++ b/eveai_chat_workers/retrievers/dossier_retriever.py @@ -0,0 +1,374 @@ +""" +DossierRetriever implementation that adds metadata filtering to retrieval +""" +import json +from datetime import datetime as dt, date, timezone as tz +from typing import Dict, Any, List, Optional, Union, Tuple +from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime +from sqlalchemy.sql import expression +from sqlalchemy.exc import SQLAlchemyError +from flask import current_app + +from common.extensions import db +from common.models.document import Document, DocumentVersion, Catalog, Retriever +from common.utils.model_utils import get_embedding_model_and_class +from .base import BaseRetriever +from .registry import RetrieverRegistry +from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata + + +class DossierRetriever(BaseRetriever): + """ + Dossier Retriever implementation that adds metadata filtering + to standard retrieval functionality + """ + + def __init__(self, tenant_id: int, retriever_id: int): + super().__init__(tenant_id, retriever_id) + + # Set up standard retrieval parameters + self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params() + + # Dossier-specific configuration + self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {}) + self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {}) + + self.log_tuning("Dossier retriever initialized", { + "tagging_fields_filter": self.tagging_fields_filter, + "dynamic_arguments": self.dynamic_arguments, + "similarity_threshold": self.similarity_threshold, + "k": self.k + }) + + @property + def type(self) -> str: + return "DOSSIER_RETRIEVER" + + def _parse_metadata(self, metadata: Any) -> Dict[str, Any]: + """ + Parse metadata ensuring it's a dictionary + + Args: + metadata: Input metadata which could be string, dict, or None + + Returns: + Dict[str, Any]: Parsed metadata as dictionary + """ + if metadata is None: + return {} + + if isinstance(metadata, dict): + return metadata + + if isinstance(metadata, str): + try: + return json.loads(metadata) + except json.JSONDecodeError: + current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}") + return {} + + current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}") + return {} + + def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments): + """ + Apply metadata filters to the query based on tagging_fields_filter configuration + + Args: + query_obj: SQLAlchemy query object + arguments: Retriever arguments (for variable substitution) + + Returns: + Modified SQLAlchemy query object + """ + if not self.tagging_fields_filter: + return query_obj + + # Get dynamic argument values + dynamic_values = {} + for arg_name, arg_config in self.dynamic_arguments.items(): + if hasattr(arguments, arg_name): + dynamic_values[arg_name] = getattr(arguments, arg_name) + + # Build the filter + filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values) + if filter_condition is not None: + query_obj = query_obj.filter(filter_condition) + self.log_tuning("Applied metadata filter", { + "filter_sql": str(filter_condition), + "dynamic_values": dynamic_values + }) + + return query_obj + + def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[ + expression.BinaryExpression]: + """ + Recursively build SQLAlchemy filter condition from filter definition + + Args: + filter_def: Filter definition (logical group or field condition) + dynamic_values: Values for dynamic variable substitution + + Returns: + SQLAlchemy expression or None if invalid + """ + # Handle logical groups (AND, OR, NOT) + if 'logical' in filter_def: + logical_op = filter_def['logical'].lower() + subconditions = [ + self._build_filter_condition(cond, dynamic_values) + for cond in filter_def.get('conditions', []) + ] + # Remove None values + subconditions = [c for c in subconditions if c is not None] + + if not subconditions: + return None + + if logical_op == 'and': + return and_(*subconditions) + elif logical_op == 'or': + return or_(*subconditions) + elif logical_op == 'not': + if len(subconditions) == 1: + return ~subconditions[0] + else: + # NOT should have exactly one condition + current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}") + return None + else: + current_app.logger.warning(f"Unknown logical operator: {logical_op}") + return None + + # Handle field conditions + elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def: + field_name = filter_def['field'] + operator = filter_def['operator'].lower() + value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default')) + + # Skip if we couldn't resolve the value + if value is None and operator not in ['is_null', 'is_not_null']: + return None + + # Create the field expression to match JSON data + field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String) + + # Apply the appropriate operator + return self._apply_operator(field_expr, operator, value, filter_def) + + else: + current_app.logger.warning(f"Invalid filter definition: {filter_def}") + return None + + def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any: + """ + Resolve a value definition, handling variables and defaults + + Args: + value_def: Value definition (could be literal, variable reference, or list) + dynamic_values: Values for dynamic variable substitution + default: Default value if variable not found + + Returns: + Resolved value + """ + # Handle lists (recursively resolve each item) + if isinstance(value_def, list): + return [self._resolve_value(item, dynamic_values) for item in value_def] + + # Handle variable references (strings starting with $) + if isinstance(value_def, str) and value_def.startswith('$'): + var_name = value_def[1:] # Remove $ prefix + if var_name in dynamic_values: + return dynamic_values[var_name] + else: + # Use default if provided + return default + + # Return literal values as-is + return value_def + + def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[ + expression.BinaryExpression]: + """ + Apply the specified operator to create a filter condition + + Args: + field_expr: SQLAlchemy field expression + operator: Operator to apply + value: Value to compare against + filter_def: Original filter definition (for additional options) + + Returns: + SQLAlchemy expression + """ + try: + # String operators + if operator == 'eq': + return field_expr == str(value) + elif operator == 'neq': + return field_expr != str(value) + elif operator == 'contains': + return field_expr.contains(str(value)) + elif operator == 'not_contains': + return ~field_expr.contains(str(value)) + elif operator == 'starts_with': + return field_expr.startswith(str(value)) + elif operator == 'ends_with': + return field_expr.endswith(str(value)) + elif operator == 'in': + return field_expr.in_([str(v) for v in value]) + elif operator == 'not_in': + return ~field_expr.in_([str(v) for v in value]) + elif operator == 'regex' or operator == 'not_regex': + # PostgreSQL regex using ~ or !~ operator + case_insensitive = filter_def.get('case_insensitive', False) + regex_op = '~*' if case_insensitive else '~' + if operator == 'not_regex': + regex_op = '!~*' if case_insensitive else '!~' + return text( + f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams( + regex_value=str(value)) + + # Numeric/Date operators + elif operator == 'gt': + return cast(field_expr, Float) > float(value) + elif operator == 'gte': + return cast(field_expr, Float) >= float(value) + elif operator == 'lt': + return cast(field_expr, Float) < float(value) + elif operator == 'lte': + return cast(field_expr, Float) <= float(value) + elif operator == 'between': + if len(value) == 2: + return cast(field_expr, Float).between(float(value[0]), float(value[1])) + else: + current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}") + return None + elif operator == 'not_between': + if len(value) == 2: + return ~cast(field_expr, Float).between(float(value[0]), float(value[1])) + else: + current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}") + return None + + # Null checking + elif operator == 'is_null': + return field_expr.is_(None) + elif operator == 'is_not_null': + return field_expr.isnot(None) + + else: + current_app.logger.warning(f"Unknown operator: {operator}") + return None + + except (ValueError, TypeError) as e: + current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}") + return None + + def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]: + """ + Retrieve documents based on query with added metadata filtering + + Args: + arguments: Validated RetrieverArguments containing at minimum: + - query: str - The search query + + Returns: + List[RetrieverResult]: List of retrieved documents with similarity scores + """ + try: + query = arguments.query + + # Get query embedding + query_embedding = self.embedding_model.embed_query(query) + + # Get the appropriate embedding database model + db_class = self.embedding_model_class + + # Get current date for validity checks + current_date = dt.now(tz=tz.utc).date() + + # Create subquery for latest versions + subquery = ( + db.session.query( + DocumentVersion.doc_id, + func.max(DocumentVersion.id).label('latest_version_id') + ) + .group_by(DocumentVersion.doc_id) + .subquery() + ) + + # Main query + query_obj = ( + db.session.query( + db_class, + (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity') + ) + .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) + .join(Document, DocumentVersion.doc_id == Document.id) + .join(subquery, DocumentVersion.id == subquery.c.latest_version_id) + .filter( + or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date), + or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date), + (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold, + Document.catalog_id == self.catalog_id + ) + ) + + # Apply metadata filtering + query_obj = self._apply_metadata_filter(query_obj, arguments) + + # Apply ordering and limit + query_obj = query_obj.order_by(desc('similarity')).limit(self.k) + + # Execute query + results = query_obj.all() + + # Transform results into standard format + processed_results = [] + for doc, similarity in results: + # Parse user_metadata to ensure it's a dictionary + user_metadata = self._parse_metadata(doc.document_version.user_metadata) + processed_results.append( + RetrieverResult( + id=doc.id, + chunk=doc.chunk, + similarity=float(similarity), + metadata=RetrieverMetadata( + document_id=doc.document_version.doc_id, + version_id=doc.document_version.id, + document_name=doc.document_version.document.name, + user_metadata=user_metadata, + ) + ) + ) + + # Log the retrieval + if self.tuning: + compiled_query = str(query_obj.statement.compile( + compile_kwargs={"literal_binds": True} # This will include the actual values in the SQL + )) + self.log_tuning('retrieve', { + "arguments": arguments.model_dump(), + "similarity_threshold": self.similarity_threshold, + "k": self.k, + "query": compiled_query, + "results_count": len(results), + "processed_results_count": len(processed_results), + }) + + return processed_results + + except SQLAlchemyError as e: + current_app.logger.error(f'Error in Dossier retrieval: {e}') + db.session.rollback() + raise + except Exception as e: + current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}') + raise + + +# Register the retriever type +RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever) diff --git a/eveai_chat_workers/retrievers/standard_rag.py b/eveai_chat_workers/retrievers/standard_rag.py index ac975d9..61f66a6 100644 --- a/eveai_chat_workers/retrievers/standard_rag.py +++ b/eveai_chat_workers/retrievers/standard_rag.py @@ -23,20 +23,12 @@ class StandardRAGRetriever(BaseRetriever): def __init__(self, tenant_id: int, retriever_id: int): super().__init__(tenant_id, retriever_id) - retriever = Retriever.query.get_or_404(retriever_id) - self.catalog_id = retriever.catalog_id - self.tenant_id = tenant_id - catalog = Catalog.query.get_or_404(self.catalog_id) - embedding_model = "mistral.mistral-embed" - - self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id, - self.catalog_id, - embedding_model) - self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3) - self.k = retriever.configuration.get('es_k', 8) - self.tuning = retriever.tuning - - self.log_tuning("Standard RAG retriever initialized") + # Set up standard retrieval parameters + self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params() + self.log_tuning("Standard RAG retriever initialized", { + "similarity_threshold": self.similarity_threshold, + "k": self.k + }) @property def type(self) -> str: @@ -167,4 +159,4 @@ class StandardRAGRetriever(BaseRetriever): # Register the retriever type -RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever) +RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever) \ No newline at end of file diff --git a/migrations/tenant/versions/4a9f7a6285cc_add_eveaiasset_m_associated_models.py b/migrations/tenant/versions/4a9f7a6285cc_add_eveaiasset_m_associated_models.py new file mode 100644 index 0000000..b6146f6 --- /dev/null +++ b/migrations/tenant/versions/4a9f7a6285cc_add_eveaiasset_m_associated_models.py @@ -0,0 +1,83 @@ +"""Add EveAIAsset m& associated models + +Revision ID: 4a9f7a6285cc +Revises: ed981a641be9 +Create Date: 2025-03-12 09:43:05.390895 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4a9f7a6285cc' +down_revision = 'ed981a641be9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('eve_ai_asset', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=50), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('type', sa.String(length=50), nullable=False), + sa.Column('type_version', sa.String(length=20), nullable=True), + sa.Column('valid_from', sa.DateTime(), nullable=True), + sa.Column('valid_to', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ), + sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('eve_ai_asset_version', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('asset_id', sa.Integer(), nullable=False), + sa.Column('bucket_name', sa.String(length=255), nullable=True), + sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('arguments', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['asset_id'], ['eve_ai_asset.id'], ), + sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ), + sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('eve_ai_asset_instruction', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('asset_version_id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('content', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('eve_ai_processed_asset', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('asset_version_id', sa.Integer(), nullable=False), + sa.Column('specialist_id', sa.Integer(), nullable=True), + sa.Column('chat_session_id', sa.Integer(), nullable=True), + sa.Column('bucket_name', sa.String(length=255), nullable=True), + sa.Column('object_name', sa.String(length=255), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ), + sa.ForeignKeyConstraint(['chat_session_id'], ['chat_session.id'], ), + sa.ForeignKeyConstraint(['specialist_id'], ['specialist.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('eve_ai_processed_asset') + op.drop_table('eve_ai_asset_instruction') + op.drop_table('eve_ai_asset_version') + op.drop_table('eve_ai_asset') + # ### end Alembic commands ###