- Allowing for multiple types of Catalogs
- Introduction of retrievers - Ensuring processing information is collected from Catalog iso Tenant - Introduction of a generic Form class to enable dynamic fields based on a configuration - Realisation of Retriever functionality to support dynamic fields
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from flask import session, current_app
|
||||
from flask import session, current_app, request
|
||||
from flask_wtf import FlaskForm
|
||||
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
|
||||
SelectField, FieldList, FormField, TextAreaField, URLField)
|
||||
@@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr
|
||||
from flask_wtf.file import FileField, FileAllowed, FileRequired
|
||||
import json
|
||||
|
||||
from wtforms_sqlalchemy.fields import QuerySelectField
|
||||
|
||||
from common.models.document import Catalog
|
||||
|
||||
from config.catalog_types import CATALOG_TYPES
|
||||
from config.retriever_types import RETRIEVER_TYPES
|
||||
from .dynamic_form_base import DynamicFormBase
|
||||
|
||||
|
||||
def allowed_file(form, field):
|
||||
if field.data:
|
||||
@@ -26,6 +34,23 @@ def validate_json(form, field):
|
||||
class CatalogForm(FlaskForm):
|
||||
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||
description = TextAreaField('Description', validators=[Optional()])
|
||||
# Parent ID (Optional for root-level catalogs)
|
||||
parent = QuerySelectField(
|
||||
'Parent Catalog',
|
||||
query_factory=lambda: Catalog.query.all(),
|
||||
allow_blank=True,
|
||||
get_label='name',
|
||||
validators=[Optional()],
|
||||
)
|
||||
|
||||
# Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
|
||||
type = SelectField('Catalog Type', validators=[DataRequired()])
|
||||
|
||||
# Metadata fields
|
||||
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
|
||||
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
|
||||
configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
|
||||
|
||||
# HTML Embedding Variables
|
||||
html_tags = StringField('HTML Tags', validators=[DataRequired()],
|
||||
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
|
||||
@@ -38,19 +63,65 @@ class CatalogForm(FlaskForm):
|
||||
default=2000)
|
||||
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
|
||||
default=3000)
|
||||
# Embedding Search variables
|
||||
es_k = IntegerField('Limit for Searching Embeddings (5)',
|
||||
default=5,
|
||||
validators=[NumberRange(min=0)])
|
||||
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
|
||||
default=0.5,
|
||||
validators=[NumberRange(min=0, max=1)])
|
||||
# Chat Variables
|
||||
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
|
||||
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
|
||||
# Tuning variables
|
||||
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
|
||||
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Dynamically populate the 'type' field using the constructor
|
||||
self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()]
|
||||
|
||||
|
||||
class RetrieverForm(FlaskForm):
|
||||
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||
description = TextAreaField('Description', validators=[Optional()])
|
||||
# Catalog for the Retriever
|
||||
catalog = QuerySelectField(
|
||||
'Catalog ID',
|
||||
query_factory=lambda: Catalog.query.all(),
|
||||
allow_blank=True,
|
||||
get_label='name',
|
||||
validators=[Optional()],
|
||||
)
|
||||
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
|
||||
type = SelectField('Retriever Type', validators=[DataRequired()])
|
||||
|
||||
# Metadata fields
|
||||
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
|
||||
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Dynamically populate the 'type' field using the constructor
|
||||
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
|
||||
|
||||
|
||||
class EditRetrieverForm(DynamicFormBase):
|
||||
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||
description = TextAreaField('Description', validators=[Optional()])
|
||||
# Catalog for the Retriever
|
||||
catalog = QuerySelectField(
|
||||
'Catalog ID',
|
||||
query_factory=lambda: Catalog.query.all(),
|
||||
allow_blank=True,
|
||||
get_label='name',
|
||||
validators=[Optional()],
|
||||
)
|
||||
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
|
||||
type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True})
|
||||
|
||||
# Metadata fields
|
||||
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
|
||||
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Set the retriever type choices (loaded from config)
|
||||
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
|
||||
|
||||
|
||||
class AddDocumentForm(FlaskForm):
|
||||
|
||||
@@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote
|
||||
import io
|
||||
import json
|
||||
|
||||
from common.models.document import Document, DocumentVersion, Catalog
|
||||
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||
from common.extensions import db, minio_client
|
||||
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
|
||||
process_multiple_urls, get_documents_list, edit_document, \
|
||||
@@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac
|
||||
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
|
||||
EveAIDoubleURLException
|
||||
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
|
||||
CatalogForm
|
||||
CatalogForm, RetrieverForm, EditRetrieverForm
|
||||
from common.utils.middleware import mw_before_request
|
||||
from common.utils.celery_utils import current_celery
|
||||
from common.utils.nginx_utils import prefixed_url_for
|
||||
@@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f
|
||||
from .document_list_view import DocumentListView
|
||||
from .document_version_list_view import DocumentVersionListView
|
||||
|
||||
from config.retriever_types import RETRIEVER_TYPES
|
||||
|
||||
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
|
||||
|
||||
|
||||
@@ -65,6 +67,7 @@ def catalog():
|
||||
tenant_id = session.get('tenant').get('id')
|
||||
new_catalog = Catalog()
|
||||
form.populate_obj(new_catalog)
|
||||
new_catalog.parent_id = form.parent.data.get('id')
|
||||
# Handle Embedding Variables
|
||||
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
|
||||
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
|
||||
@@ -103,7 +106,7 @@ def catalogs():
|
||||
the_catalogs = pagination.items
|
||||
|
||||
# prepare table data
|
||||
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')])
|
||||
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')])
|
||||
|
||||
# Render the catalogs in a template
|
||||
return render_template('document/catalogs.html', rows=rows, pagination=pagination)
|
||||
@@ -173,6 +176,121 @@ def edit_catalog(catalog_id):
|
||||
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
|
||||
|
||||
|
||||
@document_bp.route('/retriever', methods=['GET', 'POST'])
|
||||
@roles_accepted('Super User', 'Tenant Admin')
|
||||
def retriever():
|
||||
form = RetrieverForm()
|
||||
|
||||
if form.validate_on_submit():
|
||||
tenant_id = session.get('tenant').get('id')
|
||||
new_retriever = Retriever()
|
||||
form.populate_obj(new_retriever)
|
||||
new_retriever.catalog_id = form.catalog.data.id
|
||||
|
||||
set_logging_information(new_retriever, dt.now(tz.utc))
|
||||
|
||||
try:
|
||||
db.session.add(new_retriever)
|
||||
db.session.commit()
|
||||
flash('Retriever successfully added!', 'success')
|
||||
current_app.logger.info(f'Catalog {new_retriever.name} successfully added for tenant {tenant_id}!')
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(f'Failed to add retriever. Error: {e}', 'danger')
|
||||
current_app.logger.error(f'Failed to add retriever {new_retriever.name}'
|
||||
f'for tenant {tenant_id}. Error: {str(e)}')
|
||||
|
||||
# Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type)
|
||||
return redirect(prefixed_url_for('document_bp.retriever', retriever_id=new_retriever.id))
|
||||
|
||||
return render_template('document/retriever.html', form=form)
|
||||
|
||||
|
||||
@document_bp.route('/retriever/<int:retriever_id>', methods=['GET', 'POST'])
|
||||
@roles_accepted('Super User', 'Tenant Admin')
|
||||
def edit_retriever(retriever_id):
|
||||
"""Edit an existing retriever configuration."""
|
||||
current_app.logger.debug(f"Editing Retriever {retriever_id}")
|
||||
|
||||
# Get the retriever or return 404
|
||||
retriever = Retriever.query.get_or_404(retriever_id)
|
||||
|
||||
if retriever.catalog_id:
|
||||
# If catalog_id is just an ID, fetch the Catalog object
|
||||
retriever.catalog = Catalog.query.get(retriever.catalog_id)
|
||||
else:
|
||||
retriever.catalog = None
|
||||
|
||||
# Create form instance with the retriever
|
||||
form = EditRetrieverForm(request.form, obj=retriever)
|
||||
|
||||
configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
|
||||
current_app.logger.debug(f"Configuration {configuration_config}")
|
||||
form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
|
||||
|
||||
if form.validate_on_submit():
|
||||
# Update basic fields
|
||||
form.populate_obj(retriever)
|
||||
retriever.configuration = form.get_dynamic_data('configuration')
|
||||
|
||||
# Update catalog relationship
|
||||
retriever.catalog_id = form.catalog.data.id if form.catalog.data else None
|
||||
|
||||
# Update logging information
|
||||
update_logging_information(retriever, dt.now(tz.utc))
|
||||
|
||||
# Save changes to database
|
||||
try:
|
||||
db.session.add(retriever)
|
||||
db.session.commit()
|
||||
flash('Retriever updated successfully!', 'success')
|
||||
current_app.logger.info(f'Retriever {retriever.id} updated successfully')
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(f'Failed to update retriever. Error: {str(e)}', 'danger')
|
||||
current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}')
|
||||
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
|
||||
|
||||
return redirect(prefixed_url_for('document_bp.retrievers'))
|
||||
else:
|
||||
form_validation_failed(request, form)
|
||||
|
||||
current_app.logger.debug(f"Rendering Template for {retriever_id}")
|
||||
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
|
||||
|
||||
|
||||
@document_bp.route('/retrievers', methods=['GET', 'POST'])
|
||||
@roles_accepted('Super User', 'Tenant Admin')
|
||||
def retrievers():
|
||||
page = request.args.get('page', 1, type=int)
|
||||
per_page = request.args.get('per_page', 10, type=int)
|
||||
|
||||
query = Retriever.query.order_by(Retriever.id)
|
||||
|
||||
pagination = query.paginate(page=page, per_page=per_page)
|
||||
the_retrievers = pagination.items
|
||||
|
||||
# prepare table data
|
||||
rows = prepare_table_for_macro(the_retrievers,
|
||||
[('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')])
|
||||
|
||||
# Render the catalogs in a template
|
||||
return render_template('document/retrievers.html', rows=rows, pagination=pagination)
|
||||
|
||||
|
||||
@document_bp.route('/handle_retriever_selection', methods=['POST'])
|
||||
@roles_accepted('Super User', 'Tenant Admin')
|
||||
def handle_retriever_selection():
|
||||
retriever_identification = request.form.get('selected_row')
|
||||
retriever_id = ast.literal_eval(retriever_identification).get('value')
|
||||
action = request.form['action']
|
||||
|
||||
if action == 'edit_retriever':
|
||||
return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id))
|
||||
|
||||
return redirect(prefixed_url_for('document_bp.retrievers'))
|
||||
|
||||
|
||||
@document_bp.route('/add_document', methods=['GET', 'POST'])
|
||||
@roles_accepted('Super User', 'Tenant Admin')
|
||||
def add_document():
|
||||
|
||||
92
eveai_app/views/dynamic_form_base.py
Normal file
92
eveai_app/views/dynamic_form_base.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from flask_wtf import FlaskForm
|
||||
from wtforms import IntegerField, FloatField, BooleanField, StringField, validators
|
||||
from flask import current_app
|
||||
|
||||
class DynamicFormBase(FlaskForm):
|
||||
def __init__(self, formdata=None, *args, **kwargs):
|
||||
super(DynamicFormBase, self).__init__(*args, **kwargs)
|
||||
# Maps collection names to lists of field names
|
||||
self.dynamic_fields = {}
|
||||
# Store formdata for later use
|
||||
self.formdata = formdata
|
||||
|
||||
def add_dynamic_fields(self, collection_name, config, initial_data=None):
|
||||
"""Add dynamic fields to the form based on the configuration."""
|
||||
self.dynamic_fields[collection_name] = []
|
||||
for field_name, field_def in config.items():
|
||||
current_app.logger.debug(f"{field_name}: {field_def}")
|
||||
# Prefix the field name with the collection name
|
||||
full_field_name = f"{collection_name}_{field_name}"
|
||||
field_type = field_def.get('type')
|
||||
description = field_def.get('description', '')
|
||||
required = field_def.get('required', False)
|
||||
default = field_def.get('default')
|
||||
|
||||
# Determine validators
|
||||
field_validators = [validators.InputRequired()] if required else [validators.Optional()]
|
||||
|
||||
# Map the field type to WTForms field classes
|
||||
field_class = {
|
||||
'int': IntegerField,
|
||||
'float': FloatField,
|
||||
'boolean': BooleanField,
|
||||
'string': StringField,
|
||||
}.get(field_type, StringField)
|
||||
|
||||
# Create the field instance
|
||||
unbound_field = field_class(
|
||||
label=description,
|
||||
validators=field_validators,
|
||||
default=default
|
||||
)
|
||||
|
||||
# Bind the field to the form
|
||||
bound_field = unbound_field.bind(form=self, name=full_field_name)
|
||||
|
||||
# Process the field with formdata
|
||||
if self.formdata and full_field_name in self.formdata:
|
||||
# If formdata is available and contains the field
|
||||
bound_field.process(self.formdata)
|
||||
elif initial_data and field_name in initial_data:
|
||||
# Use initial data if provided
|
||||
bound_field.process(formdata=None, data=initial_data[field_name])
|
||||
else:
|
||||
# Use default value
|
||||
bound_field.process(formdata=None, data=default)
|
||||
|
||||
# Set collection name attribute for identification
|
||||
# bound_field.collection_name = collection_name
|
||||
|
||||
# Add the field to the form
|
||||
setattr(self, full_field_name, bound_field)
|
||||
self._fields[full_field_name] = bound_field
|
||||
self.dynamic_fields[collection_name].append(full_field_name)
|
||||
|
||||
def get_static_fields(self):
|
||||
"""Return a list of static field instances."""
|
||||
# Get names of dynamic fields
|
||||
dynamic_field_names = set()
|
||||
for field_list in self.dynamic_fields.values():
|
||||
dynamic_field_names.update(field_list)
|
||||
|
||||
# Return all fields that are not dynamic
|
||||
return [field for name, field in self._fields.items() if name not in dynamic_field_names]
|
||||
|
||||
def get_dynamic_fields(self):
|
||||
"""Return a dictionary of dynamic fields per collection."""
|
||||
result = {}
|
||||
for collection_name, field_names in self.dynamic_fields.items():
|
||||
result[collection_name] = [getattr(self, name) for name in field_names]
|
||||
return result
|
||||
|
||||
def get_dynamic_data(self, collection_name):
|
||||
"""Retrieve the data from dynamic fields of a specific collection."""
|
||||
data = {}
|
||||
if collection_name not in self.dynamic_fields:
|
||||
return data
|
||||
prefix_length = len(collection_name) + 1 # +1 for the underscore
|
||||
for full_field_name in self.dynamic_fields[collection_name]:
|
||||
original_field_name = full_field_name[prefix_length:]
|
||||
field = getattr(self, full_field_name)
|
||||
data[original_field_name] = field.data
|
||||
return data
|
||||
Reference in New Issue
Block a user