- Started addition of Assets (to e.g. handle document templates).
- To be continued (Models, first views are ready)
This commit is contained in:
62
common/utils/asset_utils.py
Normal file
62
common/utils/asset_utils.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from datetime import datetime as dt, timezone as tz
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from common.extensions import cache_manager, minio_client, db
|
||||||
|
from common.models.interaction import EveAIAsset, EveAIAssetVersion
|
||||||
|
from common.utils.document_utils import mark_tenant_storage_dirty
|
||||||
|
from common.utils.model_logging_utils import set_logging_information
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset_stack(api_input, tenant_id):
|
||||||
|
type_version = cache_manager.assets_version_tree_cache.get_latest_version(api_input['type'])
|
||||||
|
api_input['type_version'] = type_version
|
||||||
|
new_asset = create_asset(api_input, tenant_id)
|
||||||
|
new_asset_version = create_version_for_asset(new_asset, tenant_id)
|
||||||
|
db.session.add(new_asset)
|
||||||
|
db.session.add(new_asset_version)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.session.commit()
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
current_app.logger.error(f"Could not add asset for tenant {tenant_id}: {str(e)}")
|
||||||
|
db.session.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return new_asset, new_asset_version
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset(api_input, tenant_id):
|
||||||
|
new_asset = EveAIAsset()
|
||||||
|
new_asset.name = api_input['name']
|
||||||
|
new_asset.description = api_input['description']
|
||||||
|
new_asset.type = api_input['type']
|
||||||
|
new_asset.type_version = api_input['type_version']
|
||||||
|
if api_input['valid_from'] and api_input['valid_from'] != '':
|
||||||
|
new_asset.valid_from = api_input['valid_from']
|
||||||
|
else:
|
||||||
|
new_asset.valid_from = dt.now(tz.utc)
|
||||||
|
new_asset.valid_to = api_input['valid_to']
|
||||||
|
set_logging_information(new_asset, dt.now(tz.utc))
|
||||||
|
|
||||||
|
return new_asset
|
||||||
|
|
||||||
|
|
||||||
|
def create_version_for_asset(asset, tenant_id):
|
||||||
|
new_asset_version = EveAIAssetVersion()
|
||||||
|
new_asset_version.asset = asset
|
||||||
|
new_asset_version.bucket_name = minio_client.create_tenant_bucket(tenant_id)
|
||||||
|
set_logging_information(new_asset_version, dt.now(tz.utc))
|
||||||
|
|
||||||
|
return new_asset_version
|
||||||
|
|
||||||
|
|
||||||
|
def add_asset_version_file(asset_version, field_name, file, tenant_id):
|
||||||
|
object_name, file_size = minio_client.upload_file(asset_version.bucket_name, asset_version.id, field_name,
|
||||||
|
file.content_type)
|
||||||
|
mark_tenant_storage_dirty(tenant_id)
|
||||||
|
return object_name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -82,15 +83,21 @@ class BusinessEvent:
|
|||||||
self.span_name = span_name
|
self.span_name = span_name
|
||||||
self.parent_span_id = parent_span_id
|
self.parent_span_id = parent_span_id
|
||||||
|
|
||||||
|
# Track start time for the span
|
||||||
|
span_start_time = time.time()
|
||||||
|
|
||||||
self.log(f"Start")
|
self.log(f"Start")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
# Calculate total time for this span
|
||||||
|
span_total_time = time.time() - span_start_time
|
||||||
|
|
||||||
if self.llm_metrics['call_count'] > 0:
|
if self.llm_metrics['call_count'] > 0:
|
||||||
self.log_final_metrics()
|
self.log_final_metrics()
|
||||||
self.reset_llm_metrics()
|
self.reset_llm_metrics()
|
||||||
self.log(f"End")
|
self.log(f"End", extra_fields={'span_duration': span_total_time})
|
||||||
# Restore the previous span info
|
# Restore the previous span info
|
||||||
if self.spans:
|
if self.spans:
|
||||||
self.span_id, self.span_name, self.parent_span_id = self.spans.pop()
|
self.span_id, self.span_name, self.parent_span_id = self.spans.pop()
|
||||||
@@ -99,7 +106,7 @@ class BusinessEvent:
|
|||||||
self.span_name = None
|
self.span_name = None
|
||||||
self.parent_span_id = None
|
self.parent_span_id = None
|
||||||
|
|
||||||
def log(self, message: str, level: str = 'info'):
|
def log(self, message: str, level: str = 'info', extra_fields: Dict[str, Any] = None):
|
||||||
log_data = {
|
log_data = {
|
||||||
'timestamp': dt.now(tz=tz.utc),
|
'timestamp': dt.now(tz=tz.utc),
|
||||||
'event_type': self.event_type,
|
'event_type': self.event_type,
|
||||||
@@ -115,6 +122,15 @@ class BusinessEvent:
|
|||||||
'environment': self.environment,
|
'environment': self.environment,
|
||||||
'message': message,
|
'message': message,
|
||||||
}
|
}
|
||||||
|
# Add any extra fields
|
||||||
|
if extra_fields:
|
||||||
|
for key, value in extra_fields.items():
|
||||||
|
# For span/trace duration, use the llm_metrics_total_time field
|
||||||
|
if key == 'span_duration' or key == 'trace_duration':
|
||||||
|
log_data['llm_metrics_total_time'] = value
|
||||||
|
else:
|
||||||
|
log_data[key] = value
|
||||||
|
|
||||||
self._log_buffer.append(log_data)
|
self._log_buffer.append(log_data)
|
||||||
|
|
||||||
def log_llm_metrics(self, metrics: dict, level: str = 'info'):
|
def log_llm_metrics(self, metrics: dict, level: str = 'info'):
|
||||||
@@ -226,13 +242,17 @@ class BusinessEvent:
|
|||||||
self._log_buffer = []
|
self._log_buffer = []
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
self.trace_start_time = time.time()
|
||||||
self.log(f'Starting Trace for {self.event_type}')
|
self.log(f'Starting Trace for {self.event_type}')
|
||||||
return BusinessEventContext(self).__enter__()
|
return BusinessEventContext(self).__enter__()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
trace_total_time = time.time() - self.trace_start_time
|
||||||
|
|
||||||
if self.llm_metrics['call_count'] > 0:
|
if self.llm_metrics['call_count'] > 0:
|
||||||
self.log_final_metrics()
|
self.log_final_metrics()
|
||||||
self.reset_llm_metrics()
|
self.reset_llm_metrics()
|
||||||
self.log(f'Ending Trace for {self.event_type}')
|
|
||||||
|
self.log(f'Ending Trace for {self.event_type}', extra_fields={'trace_duration': trace_total_time})
|
||||||
self._flush_log_buffer()
|
self._flush_log_buffer()
|
||||||
return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb)
|
return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ class MinioClient:
|
|||||||
def generate_object_name(self, document_id, language, version_id, filename):
|
def generate_object_name(self, document_id, language, version_id, filename):
|
||||||
return f"{document_id}/{language}/{version_id}/{filename}"
|
return f"{document_id}/{language}/{version_id}/{filename}"
|
||||||
|
|
||||||
|
def generate_asset_name(self, asset_version_id, file_name, content_type):
|
||||||
|
return f"assets/{asset_version_id}/{file_name}.{content_type}"
|
||||||
|
|
||||||
def upload_document_file(self, tenant_id, document_id, language, version_id, filename, file_data):
|
def upload_document_file(self, tenant_id, document_id, language, version_id, filename, file_data):
|
||||||
bucket_name = self.generate_bucket_name(tenant_id)
|
bucket_name = self.generate_bucket_name(tenant_id)
|
||||||
object_name = self.generate_object_name(document_id, language, version_id, filename)
|
object_name = self.generate_object_name(document_id, language, version_id, filename)
|
||||||
@@ -54,6 +57,26 @@ class MinioClient:
|
|||||||
except S3Error as err:
|
except S3Error as err:
|
||||||
raise Exception(f"Error occurred while uploading file: {err}")
|
raise Exception(f"Error occurred while uploading file: {err}")
|
||||||
|
|
||||||
|
def upload_asset_file(self, bucket_name, asset_version_id, file_name, file_type, file_data):
|
||||||
|
object_name = self.generate_asset_name(asset_version_id, file_name, file_type)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(file_data, FileStorage):
|
||||||
|
file_data = file_data.read()
|
||||||
|
elif isinstance(file_data, io.BytesIO):
|
||||||
|
file_data = file_data.getvalue()
|
||||||
|
elif isinstance(file_data, str):
|
||||||
|
file_data = file_data.encode('utf-8')
|
||||||
|
elif not isinstance(file_data, bytes):
|
||||||
|
raise TypeError('Unsupported file type. Expected FileStorage, BytesIO, str, or bytes.')
|
||||||
|
|
||||||
|
self.client.put_object(
|
||||||
|
bucket_name, object_name, io.BytesIO(file_data), len(file_data)
|
||||||
|
)
|
||||||
|
return object_name, len(file_data)
|
||||||
|
except S3Error as err:
|
||||||
|
raise Exception(f"Error occurred while uploading asset: {err}")
|
||||||
|
|
||||||
def download_document_file(self, tenant_id, bucket_name, object_name):
|
def download_document_file(self, tenant_id, bucket_name, object_name):
|
||||||
try:
|
try:
|
||||||
response = self.client.get_object(bucket_name, object_name)
|
response = self.client.get_object(bucket_name, object_name)
|
||||||
|
|||||||
18
config/assets/DOCUMENT_TEMPLATE/1.0.0.yaml
Normal file
18
config/assets/DOCUMENT_TEMPLATE/1.0.0.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
version: "1.0.0"
|
||||||
|
name: "Document Template"
|
||||||
|
configuration:
|
||||||
|
argument_definition:
|
||||||
|
name: "variable_defition"
|
||||||
|
type: "file"
|
||||||
|
description: "Yaml file defining the arguments in the Document Template."
|
||||||
|
required: True
|
||||||
|
content_markdown:
|
||||||
|
name: "content_markdown"
|
||||||
|
type: "str"
|
||||||
|
description: "Actual template file in markdown format."
|
||||||
|
required: True
|
||||||
|
metadata:
|
||||||
|
author: "Josako"
|
||||||
|
date_added: "2025-03-12"
|
||||||
|
description: "Asset that defines a template in markdown a specialist can process"
|
||||||
|
changes: "Initial version"
|
||||||
7
config/type_defs/asset_types.py
Normal file
7
config/type_defs/asset_types.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Agent Types
|
||||||
|
AGENT_TYPES = {
|
||||||
|
"DOCUMENT_TEMPLATE": {
|
||||||
|
"name": "Document Template",
|
||||||
|
"description": "Asset that defines a template in markdown a specialist can process",
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from wtforms import IntegerField, FloatField, BooleanField, StringField, TextAreaField, validators, ValidationError
|
from wtforms import (IntegerField, FloatField, BooleanField, StringField, TextAreaField, FileField,
|
||||||
|
validators, ValidationError)
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -264,6 +265,7 @@ class DynamicFormBase(FlaskForm):
|
|||||||
'string': StringField,
|
'string': StringField,
|
||||||
'text': TextAreaField,
|
'text': TextAreaField,
|
||||||
'date': DateField,
|
'date': DateField,
|
||||||
|
'file': FileField,
|
||||||
}.get(field_type, StringField)
|
}.get(field_type, StringField)
|
||||||
field_kwargs = {}
|
field_kwargs = {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from wtforms import (StringField, BooleanField, SelectField, TextAreaField)
|
from wtforms import (StringField, BooleanField, SelectField, TextAreaField)
|
||||||
|
from wtforms.fields.datetime import DateField
|
||||||
from wtforms.validators import DataRequired, Length, Optional
|
from wtforms.validators import DataRequired, Length, Optional
|
||||||
|
|
||||||
from wtforms_sqlalchemy.fields import QuerySelectMultipleField
|
from wtforms_sqlalchemy.fields import QuerySelectMultipleField
|
||||||
@@ -94,3 +95,32 @@ class EditEveAITaskForm(BaseEditComponentForm):
|
|||||||
class EditEveAIToolForm(BaseEditComponentForm):
|
class EditEveAIToolForm(BaseEditComponentForm):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AddEveAIAssetForm(FlaskForm):
|
||||||
|
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||||
|
description = TextAreaField('Description', validators=[Optional()])
|
||||||
|
type = SelectField('Type', validators=[DataRequired()])
|
||||||
|
valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()])
|
||||||
|
valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()])
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
types_dict = cache_manager.assets_types_cache.get_types()
|
||||||
|
self.type.choices = [(key, value['name']) for key, value in types_dict.items()]
|
||||||
|
|
||||||
|
|
||||||
|
class EditEveAIAssetForm(FlaskForm):
|
||||||
|
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||||
|
description = TextAreaField('Description', validators=[Optional()])
|
||||||
|
type = SelectField('Type', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
type_version = StringField('Type Version', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
valid_from = DateField('Valid From', id='form-control datepicker', validators=[Optional()])
|
||||||
|
valid_to = DateField('Valid To', id='form-control datepicker', validators=[Optional()])
|
||||||
|
|
||||||
|
|
||||||
|
class EditEveAIAssetVersionForm(DynamicFormBase):
|
||||||
|
asset_name = StringField('Asset Name', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
asset_type = StringField('Asset Type', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
asset_type_version = StringField('Asset Type Version', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
bucket_name = StringField('Bucket Name', validators=[DataRequired()], render_kw={'readonly': True})
|
||||||
|
|
||||||
|
|||||||
@@ -6,12 +6,15 @@ from flask_security import roles_accepted
|
|||||||
from langchain.agents import Agent
|
from langchain.agents import Agent
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from werkzeug.datastructures import FileStorage
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
from common.models.document import Embedding, DocumentVersion, Retriever
|
from common.models.document import Embedding, DocumentVersion, Retriever
|
||||||
from common.models.interaction import (ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever,
|
from common.models.interaction import (ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever,
|
||||||
EveAIAgent, EveAITask, EveAITool)
|
EveAIAgent, EveAITask, EveAITool, EveAIAssetVersion)
|
||||||
|
|
||||||
from common.extensions import db, cache_manager
|
from common.extensions import db, cache_manager
|
||||||
|
from common.utils.asset_utils import create_asset_stack, add_asset_version_file
|
||||||
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
||||||
|
|
||||||
from common.utils.middleware import mw_before_request
|
from common.utils.middleware import mw_before_request
|
||||||
@@ -22,7 +25,7 @@ from common.utils.specialist_utils import initialize_specialist
|
|||||||
from config.type_defs.specialist_types import SPECIALIST_TYPES
|
from config.type_defs.specialist_types import SPECIALIST_TYPES
|
||||||
|
|
||||||
from .interaction_forms import (SpecialistForm, EditSpecialistForm, EditEveAIAgentForm, EditEveAITaskForm,
|
from .interaction_forms import (SpecialistForm, EditSpecialistForm, EditEveAIAgentForm, EditEveAITaskForm,
|
||||||
EditEveAIToolForm)
|
EditEveAIToolForm, AddEveAIAssetForm, EditEveAIAssetVersionForm)
|
||||||
|
|
||||||
interaction_bp = Blueprint('interaction_bp', __name__, url_prefix='/interaction')
|
interaction_bp = Blueprint('interaction_bp', __name__, url_prefix='/interaction')
|
||||||
|
|
||||||
@@ -37,6 +40,7 @@ def log_after_request(response):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# Routes for Chat Session Management --------------------------------------------------------------
|
||||||
@interaction_bp.before_request
|
@interaction_bp.before_request
|
||||||
def before_request():
|
def before_request():
|
||||||
try:
|
try:
|
||||||
@@ -129,6 +133,7 @@ def show_chat_session(chat_session):
|
|||||||
return render_template('interaction/view_chat_session.html', chat_session=chat_session, interactions=interactions)
|
return render_template('interaction/view_chat_session.html', chat_session=chat_session, interactions=interactions)
|
||||||
|
|
||||||
|
|
||||||
|
# Routes for Specialist Management ----------------------------------------------------------------
|
||||||
@interaction_bp.route('/specialist', methods=['GET', 'POST'])
|
@interaction_bp.route('/specialist', methods=['GET', 'POST'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def specialist():
|
def specialist():
|
||||||
@@ -142,7 +147,8 @@ def specialist():
|
|||||||
# Populate fields individually instead of using populate_obj (gives problem with QueryMultipleSelectField)
|
# Populate fields individually instead of using populate_obj (gives problem with QueryMultipleSelectField)
|
||||||
new_specialist.name = form.name.data
|
new_specialist.name = form.name.data
|
||||||
new_specialist.type = form.type.data
|
new_specialist.type = form.type.data
|
||||||
new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version(new_specialist.type)
|
new_specialist.type_version = cache_manager.specialists_version_tree_cache.get_latest_version(
|
||||||
|
new_specialist.type)
|
||||||
new_specialist.tuning = form.tuning.data
|
new_specialist.tuning = form.tuning.data
|
||||||
|
|
||||||
set_logging_information(new_specialist, dt.now(tz.utc))
|
set_logging_information(new_specialist, dt.now(tz.utc))
|
||||||
@@ -252,7 +258,7 @@ def edit_specialist(specialist_id):
|
|||||||
task_rows=task_rows,
|
task_rows=task_rows,
|
||||||
tool_rows=tool_rows,
|
tool_rows=tool_rows,
|
||||||
prefixed_url_for=prefixed_url_for,
|
prefixed_url_for=prefixed_url_for,
|
||||||
svg_path=svg_path,)
|
svg_path=svg_path, )
|
||||||
else:
|
else:
|
||||||
form_validation_failed(request, form)
|
form_validation_failed(request, form)
|
||||||
|
|
||||||
@@ -263,7 +269,7 @@ def edit_specialist(specialist_id):
|
|||||||
task_rows=task_rows,
|
task_rows=task_rows,
|
||||||
tool_rows=tool_rows,
|
tool_rows=tool_rows,
|
||||||
prefixed_url_for=prefixed_url_for,
|
prefixed_url_for=prefixed_url_for,
|
||||||
svg_path=svg_path,)
|
svg_path=svg_path, )
|
||||||
|
|
||||||
|
|
||||||
@interaction_bp.route('/specialists', methods=['GET', 'POST'])
|
@interaction_bp.route('/specialists', methods=['GET', 'POST'])
|
||||||
@@ -298,7 +304,7 @@ def handle_specialist_selection():
|
|||||||
return redirect(prefixed_url_for('interaction_bp.specialists'))
|
return redirect(prefixed_url_for('interaction_bp.specialists'))
|
||||||
|
|
||||||
|
|
||||||
# Routes for Agent management
|
# Routes for Agent management ---------------------------------------------------------------------
|
||||||
@interaction_bp.route('/agent/<int:agent_id>/edit', methods=['GET'])
|
@interaction_bp.route('/agent/<int:agent_id>/edit', methods=['GET'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def edit_agent(agent_id):
|
def edit_agent(agent_id):
|
||||||
@@ -338,7 +344,7 @@ def save_agent(agent_id):
|
|||||||
return jsonify({'success': False, 'message': 'Validation failed'})
|
return jsonify({'success': False, 'message': 'Validation failed'})
|
||||||
|
|
||||||
|
|
||||||
# Routes for Task management
|
# Routes for Task management ----------------------------------------------------------------------
|
||||||
@interaction_bp.route('/task/<int:task_id>/edit', methods=['GET'])
|
@interaction_bp.route('/task/<int:task_id>/edit', methods=['GET'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def edit_task(task_id):
|
def edit_task(task_id):
|
||||||
@@ -374,7 +380,7 @@ def save_task(task_id):
|
|||||||
return jsonify({'success': False, 'message': 'Validation failed'})
|
return jsonify({'success': False, 'message': 'Validation failed'})
|
||||||
|
|
||||||
|
|
||||||
# Routes for Tool management
|
# Routes for Tool management ----------------------------------------------------------------------
|
||||||
@interaction_bp.route('/tool/<int:tool_id>/edit', methods=['GET'])
|
@interaction_bp.route('/tool/<int:tool_id>/edit', methods=['GET'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def edit_tool(tool_id):
|
def edit_tool(tool_id):
|
||||||
@@ -410,7 +416,7 @@ def save_tool(tool_id):
|
|||||||
return jsonify({'success': False, 'message': 'Validation failed'})
|
return jsonify({'success': False, 'message': 'Validation failed'})
|
||||||
|
|
||||||
|
|
||||||
# Component selection handlers
|
# Component selection handlers --------------------------------------------------------------------
|
||||||
@interaction_bp.route('/handle_agent_selection', methods=['POST'])
|
@interaction_bp.route('/handle_agent_selection', methods=['POST'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def handle_agent_selection():
|
def handle_agent_selection():
|
||||||
@@ -448,3 +454,92 @@ def handle_tool_selection():
|
|||||||
return redirect(prefixed_url_for('interaction_bp.edit_tool', tool_id=tool_id))
|
return redirect(prefixed_url_for('interaction_bp.edit_tool', tool_id=tool_id))
|
||||||
|
|
||||||
return redirect(prefixed_url_for('interaction_bp.edit_specialist'))
|
return redirect(prefixed_url_for('interaction_bp.edit_specialist'))
|
||||||
|
|
||||||
|
|
||||||
|
# Routes for Asset management ---------------------------------------------------------------------
|
||||||
|
@interaction_bp.route('/add_asset', methods=['GET', 'POST'])
|
||||||
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
|
def add_asset():
|
||||||
|
form = AddEveAIAssetForm(request.form)
|
||||||
|
tenant_id = session.get('tenant').get('id')
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
try:
|
||||||
|
current_app.logger.info(f"Adding asset for tenant {tenant_id}")
|
||||||
|
|
||||||
|
api_input = {
|
||||||
|
'name': form.name.data,
|
||||||
|
'description': form.description.data,
|
||||||
|
'type': form.type.data,
|
||||||
|
'valid_from': form.valid_from.data,
|
||||||
|
'valid_to': form.valid_to.data,
|
||||||
|
}
|
||||||
|
new_asset, new_asset_version = create_asset_stack(api_input, tenant_id)
|
||||||
|
|
||||||
|
return redirect(prefixed_url_for('interaction_bp.edit_asset_version',
|
||||||
|
asset_version_id=new_asset_version.id))
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f'Failed to add asset for tenant {tenant_id}: {str(e)}')
|
||||||
|
flash('An error occurred while adding asset', 'error')
|
||||||
|
|
||||||
|
return render_template('interaction/add_asset.html')
|
||||||
|
|
||||||
|
|
||||||
|
@interaction_bp.route('/edit_asset_version/<int:asset_version_id>', methods=['GET', 'POST'])
|
||||||
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
|
def edit_asset_version(asset_version_id):
|
||||||
|
asset_version = EveAIAssetVersion.query.get_or_404(asset_version_id)
|
||||||
|
form = EditEveAIAssetVersionForm(asset_version)
|
||||||
|
asset_config = cache_manager.assets_config_cache.get_config(asset_version.asset.type,
|
||||||
|
asset_version.asset.type_version)
|
||||||
|
configuration_config = asset_config.get('configuration')
|
||||||
|
form.add_dynamic_fields("configuration", configuration_config, asset_version.configuration)
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
# Update the configuration dynamic fields
|
||||||
|
configuration = form.get_dynamic_data("configuration")
|
||||||
|
processed_configuration = {}
|
||||||
|
tenant_id = session.get('tenant').get('id')
|
||||||
|
# if files are returned, we will store the file_names in the configuration, and add the file to the appropriate
|
||||||
|
# bucket, in the appropriate location
|
||||||
|
for field_name, field_value in configuration.items():
|
||||||
|
# Handle file field - check if the value is a FileStorage instance
|
||||||
|
if isinstance(field_value, FileStorage) and field_value.filename:
|
||||||
|
try:
|
||||||
|
# Upload file and retrieve object_name for the file
|
||||||
|
object_name = add_asset_version_file(asset_version, field_name, field_value, tenant_id)
|
||||||
|
|
||||||
|
# Store object reference in configuration instead of file content
|
||||||
|
processed_configuration[field_name] = object_name
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Failed to upload file for asset version {asset_version.id}: {str(e)}")
|
||||||
|
flash(f"Failed to upload file '{field_value.filename}': {str(e)}", "danger")
|
||||||
|
return render_template('interaction/edit_asset_version.html', form=form,
|
||||||
|
asset_version=asset_version)
|
||||||
|
# Handle normal fields
|
||||||
|
else:
|
||||||
|
processed_configuration[field_name] = field_value
|
||||||
|
|
||||||
|
# Update the asset version with processed configuration
|
||||||
|
asset_version.configuration = processed_configuration
|
||||||
|
|
||||||
|
# Update logging information
|
||||||
|
update_logging_information(asset_version, dt.now(tz.utc))
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.session.commit()
|
||||||
|
flash('Asset uploaded successfully!', 'success')
|
||||||
|
current_app.logger.info(f'Asset Version {asset_version.id} updated successfully')
|
||||||
|
return redirect(prefixed_url_for('interaction_bp.assets'))
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
flash(f'Failed to upload asset. Error: {str(e)}', 'danger')
|
||||||
|
current_app.logger.error(f'Failed to update asset version {asset_version.id}. Error: {str(e)}')
|
||||||
|
return render_template('interaction/edit_asset_version.html', form=form)
|
||||||
|
else:
|
||||||
|
form_validation_failed(request, form)
|
||||||
|
|
||||||
|
return render_template('interaction/edit_asset_version.html', form=form)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Import all specialist implementations here to ensure registration
|
# Import all specialist implementations here to ensure registration
|
||||||
from . import standard_rag
|
from . import standard_rag
|
||||||
|
from . import dossier_retriever
|
||||||
|
|
||||||
# List of all available specialist implementations
|
# List of all available specialist implementations
|
||||||
__all__ = ['standard_rag']
|
__all__ = ['standard_rag', 'dossier_retriever']
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod, abstractproperty
|
from abc import ABC, abstractmethod, abstractproperty
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
|
from sqlalchemy import func, or_, desc
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from common.extensions import db
|
||||||
|
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||||
|
from common.utils.model_utils import get_embedding_model_and_class
|
||||||
|
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments, RetrieverMetadata
|
||||||
from config.logging_config import TuningLogger
|
from config.logging_config import TuningLogger
|
||||||
|
|
||||||
|
|
||||||
@@ -12,6 +18,7 @@ class BaseRetriever(ABC):
|
|||||||
def __init__(self, tenant_id: int, retriever_id: int):
|
def __init__(self, tenant_id: int, retriever_id: int):
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.retriever_id = retriever_id
|
self.retriever_id = retriever_id
|
||||||
|
self.retriever = Retriever.query.get_or_404(retriever_id)
|
||||||
self.tuning = False
|
self.tuning = False
|
||||||
self.tuning_logger = None
|
self.tuning_logger = None
|
||||||
self._setup_tuning_logger()
|
self._setup_tuning_logger()
|
||||||
@@ -43,6 +50,31 @@ class BaseRetriever(ABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
|
current_app.logger.error(f"Processor: Error in tuning logging: {e}")
|
||||||
|
|
||||||
|
def setup_standard_retrieval_params(self) -> Tuple[Any, Any, Any, float, int]:
|
||||||
|
"""
|
||||||
|
Set up common parameters needed for standard retrieval functionality
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- embedding_model: The model to use for embeddings
|
||||||
|
- embedding_model_class: The class for storing embeddings
|
||||||
|
- catalog_id: ID of the catalog
|
||||||
|
- similarity_threshold: Threshold for similarity matching
|
||||||
|
- k: Maximum number of results to return
|
||||||
|
"""
|
||||||
|
catalog_id = self.retriever.catalog_id
|
||||||
|
catalog = Catalog.query.get_or_404(catalog_id)
|
||||||
|
embedding_model = "mistral.mistral-embed"
|
||||||
|
|
||||||
|
embedding_model, embedding_model_class = get_embedding_model_and_class(
|
||||||
|
self.tenant_id, catalog_id, embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_threshold = self.retriever.configuration.get('es_similarity_threshold', 0.3)
|
||||||
|
k = self.retriever.configuration.get('es_k', 8)
|
||||||
|
|
||||||
|
return embedding_model, embedding_model_class, catalog_id, similarity_threshold, k
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
|
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
374
eveai_chat_workers/retrievers/dossier_retriever.py
Normal file
374
eveai_chat_workers/retrievers/dossier_retriever.py
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
"""
|
||||||
|
DossierRetriever implementation that adds metadata filtering to retrieval
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime as dt, date, timezone as tz
|
||||||
|
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||||
|
from sqlalchemy import func, or_, desc, and_, text, cast, JSON, String, Integer, Float, Boolean, DateTime
|
||||||
|
from sqlalchemy.sql import expression
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from common.extensions import db
|
||||||
|
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||||
|
from common.utils.model_utils import get_embedding_model_and_class
|
||||||
|
from .base import BaseRetriever
|
||||||
|
from .registry import RetrieverRegistry
|
||||||
|
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class DossierRetriever(BaseRetriever):
|
||||||
|
"""
|
||||||
|
Dossier Retriever implementation that adds metadata filtering
|
||||||
|
to standard retrieval functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: int, retriever_id: int):
|
||||||
|
super().__init__(tenant_id, retriever_id)
|
||||||
|
|
||||||
|
# Set up standard retrieval parameters
|
||||||
|
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
|
||||||
|
|
||||||
|
# Dossier-specific configuration
|
||||||
|
self.tagging_fields_filter = self.retriever.configuration.get('tagging_fields_filter', {})
|
||||||
|
self.dynamic_arguments = self.retriever.configuration.get('dynamic_arguments', {})
|
||||||
|
|
||||||
|
self.log_tuning("Dossier retriever initialized", {
|
||||||
|
"tagging_fields_filter": self.tagging_fields_filter,
|
||||||
|
"dynamic_arguments": self.dynamic_arguments,
|
||||||
|
"similarity_threshold": self.similarity_threshold,
|
||||||
|
"k": self.k
|
||||||
|
})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
return "DOSSIER_RETRIEVER"
|
||||||
|
|
||||||
|
def _parse_metadata(self, metadata: Any) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse metadata ensuring it's a dictionary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Input metadata which could be string, dict, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Parsed metadata as dictionary
|
||||||
|
"""
|
||||||
|
if metadata is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if isinstance(metadata, dict):
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
if isinstance(metadata, str):
|
||||||
|
try:
|
||||||
|
return json.loads(metadata)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
current_app.logger.warning(f"Failed to parse metadata JSON string: {metadata}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
current_app.logger.warning(f"Unexpected metadata type: {type(metadata)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _apply_metadata_filter(self, query_obj, arguments: RetrieverArguments):
|
||||||
|
"""
|
||||||
|
Apply metadata filters to the query based on tagging_fields_filter configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_obj: SQLAlchemy query object
|
||||||
|
arguments: Retriever arguments (for variable substitution)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified SQLAlchemy query object
|
||||||
|
"""
|
||||||
|
if not self.tagging_fields_filter:
|
||||||
|
return query_obj
|
||||||
|
|
||||||
|
# Get dynamic argument values
|
||||||
|
dynamic_values = {}
|
||||||
|
for arg_name, arg_config in self.dynamic_arguments.items():
|
||||||
|
if hasattr(arguments, arg_name):
|
||||||
|
dynamic_values[arg_name] = getattr(arguments, arg_name)
|
||||||
|
|
||||||
|
# Build the filter
|
||||||
|
filter_condition = self._build_filter_condition(self.tagging_fields_filter, dynamic_values)
|
||||||
|
if filter_condition is not None:
|
||||||
|
query_obj = query_obj.filter(filter_condition)
|
||||||
|
self.log_tuning("Applied metadata filter", {
|
||||||
|
"filter_sql": str(filter_condition),
|
||||||
|
"dynamic_values": dynamic_values
|
||||||
|
})
|
||||||
|
|
||||||
|
return query_obj
|
||||||
|
|
||||||
|
def _build_filter_condition(self, filter_def: Dict[str, Any], dynamic_values: Dict[str, Any]) -> Optional[
|
||||||
|
expression.BinaryExpression]:
|
||||||
|
"""
|
||||||
|
Recursively build SQLAlchemy filter condition from filter definition
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_def: Filter definition (logical group or field condition)
|
||||||
|
dynamic_values: Values for dynamic variable substitution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLAlchemy expression or None if invalid
|
||||||
|
"""
|
||||||
|
# Handle logical groups (AND, OR, NOT)
|
||||||
|
if 'logical' in filter_def:
|
||||||
|
logical_op = filter_def['logical'].lower()
|
||||||
|
subconditions = [
|
||||||
|
self._build_filter_condition(cond, dynamic_values)
|
||||||
|
for cond in filter_def.get('conditions', [])
|
||||||
|
]
|
||||||
|
# Remove None values
|
||||||
|
subconditions = [c for c in subconditions if c is not None]
|
||||||
|
|
||||||
|
if not subconditions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if logical_op == 'and':
|
||||||
|
return and_(*subconditions)
|
||||||
|
elif logical_op == 'or':
|
||||||
|
return or_(*subconditions)
|
||||||
|
elif logical_op == 'not':
|
||||||
|
if len(subconditions) == 1:
|
||||||
|
return ~subconditions[0]
|
||||||
|
else:
|
||||||
|
# NOT should have exactly one condition
|
||||||
|
current_app.logger.warning(f"NOT operator requires exactly one condition, got {len(subconditions)}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"Unknown logical operator: {logical_op}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Handle field conditions
|
||||||
|
elif 'field' in filter_def and 'operator' in filter_def and 'value' in filter_def:
|
||||||
|
field_name = filter_def['field']
|
||||||
|
operator = filter_def['operator'].lower()
|
||||||
|
value = self._resolve_value(filter_def['value'], dynamic_values, filter_def.get('default'))
|
||||||
|
|
||||||
|
# Skip if we couldn't resolve the value
|
||||||
|
if value is None and operator not in ['is_null', 'is_not_null']:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create the field expression to match JSON data
|
||||||
|
field_expr = cast(DocumentVersion.catalog_properties['tagging_fields'][field_name], String)
|
||||||
|
|
||||||
|
# Apply the appropriate operator
|
||||||
|
return self._apply_operator(field_expr, operator, value, filter_def)
|
||||||
|
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"Invalid filter definition: {filter_def}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resolve_value(self, value_def: Any, dynamic_values: Dict[str, Any], default: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
Resolve a value definition, handling variables and defaults
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value_def: Value definition (could be literal, variable reference, or list)
|
||||||
|
dynamic_values: Values for dynamic variable substitution
|
||||||
|
default: Default value if variable not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved value
|
||||||
|
"""
|
||||||
|
# Handle lists (recursively resolve each item)
|
||||||
|
if isinstance(value_def, list):
|
||||||
|
return [self._resolve_value(item, dynamic_values) for item in value_def]
|
||||||
|
|
||||||
|
# Handle variable references (strings starting with $)
|
||||||
|
if isinstance(value_def, str) and value_def.startswith('$'):
|
||||||
|
var_name = value_def[1:] # Remove $ prefix
|
||||||
|
if var_name in dynamic_values:
|
||||||
|
return dynamic_values[var_name]
|
||||||
|
else:
|
||||||
|
# Use default if provided
|
||||||
|
return default
|
||||||
|
|
||||||
|
# Return literal values as-is
|
||||||
|
return value_def
|
||||||
|
|
||||||
|
def _apply_operator(self, field_expr, operator: str, value: Any, filter_def: Dict[str, Any]) -> Optional[
|
||||||
|
expression.BinaryExpression]:
|
||||||
|
"""
|
||||||
|
Apply the specified operator to create a filter condition
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_expr: SQLAlchemy field expression
|
||||||
|
operator: Operator to apply
|
||||||
|
value: Value to compare against
|
||||||
|
filter_def: Original filter definition (for additional options)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLAlchemy expression
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# String operators
|
||||||
|
if operator == 'eq':
|
||||||
|
return field_expr == str(value)
|
||||||
|
elif operator == 'neq':
|
||||||
|
return field_expr != str(value)
|
||||||
|
elif operator == 'contains':
|
||||||
|
return field_expr.contains(str(value))
|
||||||
|
elif operator == 'not_contains':
|
||||||
|
return ~field_expr.contains(str(value))
|
||||||
|
elif operator == 'starts_with':
|
||||||
|
return field_expr.startswith(str(value))
|
||||||
|
elif operator == 'ends_with':
|
||||||
|
return field_expr.endswith(str(value))
|
||||||
|
elif operator == 'in':
|
||||||
|
return field_expr.in_([str(v) for v in value])
|
||||||
|
elif operator == 'not_in':
|
||||||
|
return ~field_expr.in_([str(v) for v in value])
|
||||||
|
elif operator == 'regex' or operator == 'not_regex':
|
||||||
|
# PostgreSQL regex using ~ or !~ operator
|
||||||
|
case_insensitive = filter_def.get('case_insensitive', False)
|
||||||
|
regex_op = '~*' if case_insensitive else '~'
|
||||||
|
if operator == 'not_regex':
|
||||||
|
regex_op = '!~*' if case_insensitive else '!~'
|
||||||
|
return text(
|
||||||
|
f"{field_expr.compile(compile_kwargs={'literal_binds': True})} {regex_op} :regex_value").bindparams(
|
||||||
|
regex_value=str(value))
|
||||||
|
|
||||||
|
# Numeric/Date operators
|
||||||
|
elif operator == 'gt':
|
||||||
|
return cast(field_expr, Float) > float(value)
|
||||||
|
elif operator == 'gte':
|
||||||
|
return cast(field_expr, Float) >= float(value)
|
||||||
|
elif operator == 'lt':
|
||||||
|
return cast(field_expr, Float) < float(value)
|
||||||
|
elif operator == 'lte':
|
||||||
|
return cast(field_expr, Float) <= float(value)
|
||||||
|
elif operator == 'between':
|
||||||
|
if len(value) == 2:
|
||||||
|
return cast(field_expr, Float).between(float(value[0]), float(value[1]))
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"BETWEEN operator requires exactly two values, got {len(value)}")
|
||||||
|
return None
|
||||||
|
elif operator == 'not_between':
|
||||||
|
if len(value) == 2:
|
||||||
|
return ~cast(field_expr, Float).between(float(value[0]), float(value[1]))
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"NOT_BETWEEN operator requires exactly two values, got {len(value)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Null checking
|
||||||
|
elif operator == 'is_null':
|
||||||
|
return field_expr.is_(None)
|
||||||
|
elif operator == 'is_not_null':
|
||||||
|
return field_expr.isnot(None)
|
||||||
|
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"Unknown operator: {operator}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
current_app.logger.error(f"Error applying operator {operator} with value {value}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
|
||||||
|
"""
|
||||||
|
Retrieve documents based on query with added metadata filtering
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: Validated RetrieverArguments containing at minimum:
|
||||||
|
- query: str - The search query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[RetrieverResult]: List of retrieved documents with similarity scores
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = arguments.query
|
||||||
|
|
||||||
|
# Get query embedding
|
||||||
|
query_embedding = self.embedding_model.embed_query(query)
|
||||||
|
|
||||||
|
# Get the appropriate embedding database model
|
||||||
|
db_class = self.embedding_model_class
|
||||||
|
|
||||||
|
# Get current date for validity checks
|
||||||
|
current_date = dt.now(tz=tz.utc).date()
|
||||||
|
|
||||||
|
# Create subquery for latest versions
|
||||||
|
subquery = (
|
||||||
|
db.session.query(
|
||||||
|
DocumentVersion.doc_id,
|
||||||
|
func.max(DocumentVersion.id).label('latest_version_id')
|
||||||
|
)
|
||||||
|
.group_by(DocumentVersion.doc_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main query
|
||||||
|
query_obj = (
|
||||||
|
db.session.query(
|
||||||
|
db_class,
|
||||||
|
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
|
||||||
|
)
|
||||||
|
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||||
|
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||||
|
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||||
|
.filter(
|
||||||
|
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||||
|
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||||
|
(1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
|
||||||
|
Document.catalog_id == self.catalog_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply metadata filtering
|
||||||
|
query_obj = self._apply_metadata_filter(query_obj, arguments)
|
||||||
|
|
||||||
|
# Apply ordering and limit
|
||||||
|
query_obj = query_obj.order_by(desc('similarity')).limit(self.k)
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
results = query_obj.all()
|
||||||
|
|
||||||
|
# Transform results into standard format
|
||||||
|
processed_results = []
|
||||||
|
for doc, similarity in results:
|
||||||
|
# Parse user_metadata to ensure it's a dictionary
|
||||||
|
user_metadata = self._parse_metadata(doc.document_version.user_metadata)
|
||||||
|
processed_results.append(
|
||||||
|
RetrieverResult(
|
||||||
|
id=doc.id,
|
||||||
|
chunk=doc.chunk,
|
||||||
|
similarity=float(similarity),
|
||||||
|
metadata=RetrieverMetadata(
|
||||||
|
document_id=doc.document_version.doc_id,
|
||||||
|
version_id=doc.document_version.id,
|
||||||
|
document_name=doc.document_version.document.name,
|
||||||
|
user_metadata=user_metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log the retrieval
|
||||||
|
if self.tuning:
|
||||||
|
compiled_query = str(query_obj.statement.compile(
|
||||||
|
compile_kwargs={"literal_binds": True} # This will include the actual values in the SQL
|
||||||
|
))
|
||||||
|
self.log_tuning('retrieve', {
|
||||||
|
"arguments": arguments.model_dump(),
|
||||||
|
"similarity_threshold": self.similarity_threshold,
|
||||||
|
"k": self.k,
|
||||||
|
"query": compiled_query,
|
||||||
|
"results_count": len(results),
|
||||||
|
"processed_results_count": len(processed_results),
|
||||||
|
})
|
||||||
|
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
current_app.logger.error(f'Error in Dossier retrieval: {e}')
|
||||||
|
db.session.rollback()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f'Unexpected error in Dossier retrieval: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Register the retriever type
|
||||||
|
RetrieverRegistry.register("DOSSIER_RETRIEVER", DossierRetriever)
|
||||||
@@ -23,20 +23,12 @@ class StandardRAGRetriever(BaseRetriever):
|
|||||||
def __init__(self, tenant_id: int, retriever_id: int):
|
def __init__(self, tenant_id: int, retriever_id: int):
|
||||||
super().__init__(tenant_id, retriever_id)
|
super().__init__(tenant_id, retriever_id)
|
||||||
|
|
||||||
retriever = Retriever.query.get_or_404(retriever_id)
|
# Set up standard retrieval parameters
|
||||||
self.catalog_id = retriever.catalog_id
|
self.embedding_model, self.embedding_model_class, self.catalog_id, self.similarity_threshold, self.k = self.setup_standard_retrieval_params()
|
||||||
self.tenant_id = tenant_id
|
self.log_tuning("Standard RAG retriever initialized", {
|
||||||
catalog = Catalog.query.get_or_404(self.catalog_id)
|
"similarity_threshold": self.similarity_threshold,
|
||||||
embedding_model = "mistral.mistral-embed"
|
"k": self.k
|
||||||
|
})
|
||||||
self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id,
|
|
||||||
self.catalog_id,
|
|
||||||
embedding_model)
|
|
||||||
self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
|
|
||||||
self.k = retriever.configuration.get('es_k', 8)
|
|
||||||
self.tuning = retriever.tuning
|
|
||||||
|
|
||||||
self.log_tuning("Standard RAG retriever initialized")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
"""Add EveAIAsset m& associated models
|
||||||
|
|
||||||
|
Revision ID: 4a9f7a6285cc
|
||||||
|
Revises: ed981a641be9
|
||||||
|
Create Date: 2025-03-12 09:43:05.390895
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import pgvector
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '4a9f7a6285cc'
|
||||||
|
down_revision = 'ed981a641be9'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('eve_ai_asset',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('type_version', sa.String(length=20), nullable=True),
|
||||||
|
sa.Column('valid_from', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('valid_to', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_table('eve_ai_asset_version',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('asset_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('bucket_name', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column('arguments', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['asset_id'], ['eve_ai_asset.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_table('eve_ai_asset_instruction',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('asset_version_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_table('eve_ai_processed_asset',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('asset_version_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('specialist_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('chat_session_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('bucket_name', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('object_name', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['asset_version_id'], ['eve_ai_asset_version.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['chat_session_id'], ['chat_session.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['specialist_id'], ['specialist.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('eve_ai_processed_asset')
|
||||||
|
op.drop_table('eve_ai_asset_instruction')
|
||||||
|
op.drop_table('eve_ai_asset_version')
|
||||||
|
op.drop_table('eve_ai_asset')
|
||||||
|
# ### end Alembic commands ###
|
||||||
Reference in New Issue
Block a user