- Tuning moved to the Retriever itself instead of living in its configuration, as this is an attribute that should be available for all types of Retrievers
219 lines
11 KiB
Python
219 lines
11 KiB
Python
from flask import session, current_app, request
|
|
from flask_wtf import FlaskForm
|
|
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
|
|
SelectField, FieldList, FormField, TextAreaField, URLField)
|
|
from wtforms.validators import DataRequired, Length, Optional, URL, ValidationError, NumberRange
|
|
from flask_wtf.file import FileField, FileAllowed, FileRequired
|
|
import json
|
|
|
|
from wtforms_sqlalchemy.fields import QuerySelectField
|
|
|
|
from common.extensions import db
|
|
from common.models.document import Catalog
|
|
|
|
from config.catalog_types import CATALOG_TYPES
|
|
from config.retriever_types import RETRIEVER_TYPES
|
|
from .dynamic_form_base import DynamicFormBase
|
|
|
|
|
|
def allowed_file(form, field):
    """WTForms-style validator that rejects uploads with an unsupported extension.

    Nothing is checked when ``field.data`` is falsy. Otherwise the text after
    the last ``.`` in ``field.data.filename`` is lower-cased and looked up in
    the ``SUPPORTED_FILE_TYPES`` entry of the Flask app config (an empty list
    is used when that entry is missing).

    Args:
        form: The form being validated (unused).
        field: The field whose ``data`` carries the uploaded file.

    Raises:
        ValidationError: If the filename has no ``.`` or its extension is not
            in the configured list.
    """
    upload = field.data
    if not upload:
        return

    filename = upload.filename
    supported = current_app.config.get('SUPPORTED_FILE_TYPES', [])

    # A name without a dot has no extension to compare against.
    if '.' not in filename:
        raise ValidationError('Unsupported file type.')

    extension = filename.rsplit('.', 1)[1].lower()
    if extension not in supported:
        raise ValidationError('Unsupported file type.')
|
|
|
|
|
|
def validate_json(form, field):
    """WTForms-style validator requiring non-empty field data to parse as JSON.

    Args:
        form: The form being validated (unused).
        field: The field whose ``data`` is checked; falsy data is accepted
            without parsing.

    Raises:
        ValidationError: If ``json.loads`` reports a ``JSONDecodeError``.
    """
    raw = field.data
    if not raw:
        return

    try:
        json.loads(raw)
    except json.JSONDecodeError:
        raise ValidationError('Invalid JSON format')
|
|
|
|
|
|
class CatalogForm(FlaskForm):
    """Flask-WTF form holding a catalog's settings.

    Declares the descriptive fields (name, description, type), three optional
    JSON text areas, the HTML embedding options, two chat temperatures and the
    embedding-tuning switch. The choices of ``type`` are filled in from
    ``CATALOG_TYPES`` every time the form is instantiated.
    """

    name = StringField(
        'Name',
        validators=[DataRequired(), Length(max=50)],
    )
    description = TextAreaField(
        'Description',
        validators=[Optional()],
    )

    # Catalog type; the selectable choices are assigned in __init__.
    type = SelectField(
        'Catalog Type',
        validators=[DataRequired()],
    )

    # JSON text areas: may be left empty, otherwise checked by validate_json.
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )
    system_metadata = TextAreaField(
        'System Metadata',
        validators=[Optional(), validate_json],
    )
    configuration = TextAreaField(
        'Configuration',
        validators=[Optional(), validate_json],
    )

    # HTML embedding options (comma-separated lists).
    # NOTE(review): the html_tags default has an empty entry between 'li' and
    # 'tbody' (possibly a dropped 'table') - confirm before changing it.
    html_tags = StringField(
        'HTML Tags',
        default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td',
        validators=[DataRequired()],
    )
    html_end_tags = StringField(
        'HTML End Tags',
        default='p, li',
        validators=[DataRequired()],
    )
    html_included_elements = StringField(
        'HTML Included Elements',
        default='article, main',
        validators=[Optional()],
    )
    html_excluded_elements = StringField(
        'HTML Excluded Elements',
        default='header, footer, nav, script',
        validators=[Optional()],
    )
    html_excluded_classes = StringField(
        'HTML Excluded Classes',
        validators=[Optional()],
    )
    min_chunk_size = IntegerField(
        'Minimum Chunk Size (2000)',
        default=2000,
        validators=[NumberRange(min=0), Optional()],
    )
    max_chunk_size = IntegerField(
        'Maximum Chunk Size (3000)',
        default=3000,
        validators=[NumberRange(min=0), Optional()],
    )

    # Chat temperatures, each limited to the range 0..1.
    chat_RAG_temperature = FloatField(
        'RAG Temperature',
        validators=[NumberRange(min=0, max=1)],
        default=0.3,
    )
    chat_no_RAG_temperature = FloatField(
        'No RAG Temperature',
        validators=[NumberRange(min=0, max=1)],
        default=0.5,
    )

    # Embedding tuning switch, off by default.
    embed_tuning = BooleanField(
        'Enable Embedding Tuning',
        default=False,
    )

    def __init__(self, *args, **kwargs):
        """Build the form, then load the ``type`` choices from ``CATALOG_TYPES``."""
        super().__init__(*args, **kwargs)
        # One (key, display name) pair per configured catalog type.
        type_choices = [
            (type_key, type_def['name'])
            for type_key, type_def in CATALOG_TYPES.items()
        ]
        self.type.choices = type_choices
|
|
|
|
|
|
class EditCatalogForm(DynamicFormBase):
    """Catalog form built on ``DynamicFormBase`` with a read-only type field.

    Declares name, description, the catalog type (rendered with the
    ``readonly`` attribute), two optional JSON text areas, the HTML embedding
    options, two chat temperatures and the embedding-tuning switch.
    """

    name = StringField(
        'Name',
        validators=[DataRequired(), Length(max=50)],
    )
    description = TextAreaField(
        'Description',
        validators=[Optional()],
    )

    # Catalog type shown as plain text input, rendered read-only.
    type = StringField(
        'Catalog Type',
        render_kw={'readonly': True},
        validators=[DataRequired()],
    )

    # JSON text areas: may be left empty, otherwise checked by validate_json.
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )
    system_metadata = TextAreaField(
        'System Metadata',
        validators=[Optional(), validate_json],
    )

    # HTML embedding options (comma-separated lists).
    # NOTE(review): the html_tags default has an empty entry between 'li' and
    # 'tbody' (possibly a dropped 'table') - confirm before changing it.
    html_tags = StringField(
        'HTML Tags',
        default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td',
        validators=[DataRequired()],
    )
    html_end_tags = StringField(
        'HTML End Tags',
        default='p, li',
        validators=[DataRequired()],
    )
    html_included_elements = StringField(
        'HTML Included Elements',
        default='article, main',
        validators=[Optional()],
    )
    html_excluded_elements = StringField(
        'HTML Excluded Elements',
        default='header, footer, nav, script',
        validators=[Optional()],
    )
    html_excluded_classes = StringField(
        'HTML Excluded Classes',
        validators=[Optional()],
    )
    min_chunk_size = IntegerField(
        'Minimum Chunk Size (2000)',
        default=2000,
        validators=[NumberRange(min=0), Optional()],
    )
    max_chunk_size = IntegerField(
        'Maximum Chunk Size (3000)',
        default=3000,
        validators=[NumberRange(min=0), Optional()],
    )

    # Chat temperatures, each limited to the range 0..1.
    chat_RAG_temperature = FloatField(
        'RAG Temperature',
        validators=[NumberRange(min=0, max=1)],
        default=0.3,
    )
    chat_no_RAG_temperature = FloatField(
        'No RAG Temperature',
        validators=[NumberRange(min=0, max=1)],
        default=0.5,
    )

    # Embedding tuning switch, off by default.
    embed_tuning = BooleanField(
        'Enable Embedding Tuning',
        default=False,
    )
|
|
|
|
|
|
class RetrieverForm(FlaskForm):
    """Flask-WTF form holding a retriever's settings.

    Declares name, description, an optional catalog selection, the retriever
    type, the tuning switch and two optional JSON text areas. The choices of
    ``type`` are filled in from ``RETRIEVER_TYPES`` on instantiation.
    """

    name = StringField(
        'Name',
        validators=[DataRequired(), Length(max=50)],
    )
    description = TextAreaField(
        'Description',
        validators=[Optional()],
    )

    # Catalog the retriever is attached to; a blank selection is allowed and
    # options are labelled with each catalog's ``name``.
    catalog = QuerySelectField(
        'Catalog ID',
        validators=[Optional()],
        get_label='name',
        allow_blank=True,
        query_factory=lambda: Catalog.query.all(),
    )

    # Retriever type; the selectable choices are assigned in __init__.
    type = SelectField(
        'Retriever Type',
        validators=[DataRequired()],
    )
    tuning = BooleanField(
        'Enable Tuning',
        default=False,
    )

    # JSON text areas: may be left empty, otherwise checked by validate_json.
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )
    system_metadata = TextAreaField(
        'System Metadata',
        validators=[Optional(), validate_json],
    )

    def __init__(self, *args, **kwargs):
        """Build the form, then load the ``type`` choices from ``RETRIEVER_TYPES``."""
        super().__init__(*args, **kwargs)
        # One (key, display name) pair per configured retriever type.
        type_choices = [
            (type_key, type_def['name'])
            for type_key, type_def in RETRIEVER_TYPES.items()
        ]
        self.type.choices = type_choices
|
|
|
|
|
|
class EditRetrieverForm(DynamicFormBase):
    """Retriever form built on ``DynamicFormBase``.

    Declares name, description, an optional catalog selection, the retriever
    type (rendered with the ``readonly`` attribute), the tuning switch and two
    optional JSON text areas. The choices of ``type`` are filled in from
    ``RETRIEVER_TYPES`` on instantiation.
    """

    name = StringField(
        'Name',
        validators=[DataRequired(), Length(max=50)],
    )
    description = TextAreaField(
        'Description',
        validators=[Optional()],
    )

    # Catalog the retriever is attached to; a blank selection is allowed and
    # options are labelled with each catalog's ``name``.
    catalog = QuerySelectField(
        'Catalog ID',
        validators=[Optional()],
        get_label='name',
        allow_blank=True,
        query_factory=lambda: Catalog.query.all(),
    )

    # Retriever type; the selectable choices are assigned in __init__.
    # NOTE(review): 'readonly' is rendered on a <select> here, where browsers
    # are not expected to honour it - confirm whether a read-only StringField
    # was intended instead.
    type = SelectField(
        'Retriever Type',
        render_kw={'readonly': True},
        validators=[DataRequired()],
    )
    tuning = BooleanField(
        'Enable Tuning',
        default=False,
    )

    # JSON text areas: may be left empty, otherwise checked by validate_json.
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )
    system_metadata = TextAreaField(
        'System Metadata',
        validators=[Optional(), validate_json],
    )

    def __init__(self, *args, **kwargs):
        """Build the form, then load the ``type`` choices from ``RETRIEVER_TYPES``."""
        super().__init__(*args, **kwargs)
        # One (key, display name) pair per configured retriever type.
        type_choices = [
            (type_key, type_def['name'])
            for type_key, type_def in RETRIEVER_TYPES.items()
        ]
        self.type.choices = type_choices
|
|
|
|
|
|
class AddDocumentForm(DynamicFormBase):
    """File-upload form built on ``DynamicFormBase``.

    Declares a required file (checked by ``allowed_file``), an optional name,
    language, user context and start date, plus an optional JSON text area.
    Language choices come from the ``tenant`` entry of the Flask session.
    """

    file = FileField(
        'File',
        validators=[FileRequired(), allowed_file],
    )
    name = StringField(
        'Name',
        validators=[Length(max=100)],
    )
    # Choices start empty and are assigned in __init__.
    language = SelectField(
        'Language',
        validators=[Optional()],
        choices=[],
    )
    user_context = TextAreaField(
        'User Context',
        validators=[Optional()],
    )
    # NOTE(review): 'form-control datepicker' is passed as the field id, not
    # as a CSS class - confirm this is intended.
    valid_from = DateField(
        'Valid from',
        validators=[Optional()],
        id='form-control datepicker',
    )
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )

    def __init__(self, *args, **kwargs):
        """Build the form and set up the language field from the session tenant.

        Requires a ``tenant`` mapping in the Flask session: its
        ``allowed_languages`` become the choices, and its ``default_language``
        is used when the field has no data yet.
        """
        super().__init__(*args, **kwargs)
        tenant = session.get('tenant')
        self.language.choices = [
            (code, code) for code in tenant.get('allowed_languages')
        ]
        if not self.language.data:
            self.language.data = tenant.get('default_language')
|
|
|
|
|
|
class AddURLForm(DynamicFormBase):
    """Single-URL form built on ``DynamicFormBase``.

    Declares a required, URL-validated address, an optional name, language,
    user context and start date, plus an optional JSON text area. Language
    choices come from the ``tenant`` entry of the Flask session.
    """

    url = URLField(
        'URL',
        validators=[DataRequired(), URL()],
    )
    name = StringField(
        'Name',
        validators=[Length(max=100)],
    )
    # Choices start empty and are assigned in __init__.
    language = SelectField(
        'Language',
        validators=[Optional()],
        choices=[],
    )
    user_context = TextAreaField(
        'User Context',
        validators=[Optional()],
    )
    # NOTE(review): 'form-control datepicker' is passed as the field id, not
    # as a CSS class - confirm this is intended.
    valid_from = DateField(
        'Valid from',
        validators=[Optional()],
        id='form-control datepicker',
    )
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )

    def __init__(self, *args, **kwargs):
        """Build the form and set up the language field from the session tenant.

        Requires a ``tenant`` mapping in the Flask session: its
        ``allowed_languages`` become the choices, and its ``default_language``
        is used when the field has no data yet.
        """
        super().__init__(*args, **kwargs)
        tenant = session.get('tenant')
        self.language.choices = [
            (code, code) for code in tenant.get('allowed_languages')
        ]
        if not self.language.data:
            self.language.data = tenant.get('default_language')
|
|
|
|
|
|
class AddURLsForm(FlaskForm):
    """Flask-WTF form for submitting several URLs at once (one per line).

    Declares a required text area of URLs, an optional name prefix, language,
    user context and start date, and a submit button. Language choices come
    from the ``tenant`` entry of the Flask session.
    """

    urls = TextAreaField('URL(s) (one per line)', validators=[DataRequired()])
    name = StringField('Name Prefix', validators=[Length(max=100)])
    # Choices start empty and are assigned in __init__.
    language = SelectField('Language', choices=[], validators=[Optional()])
    user_context = TextAreaField('User Context', validators=[Optional()])
    # NOTE(review): 'form-control datepicker' is passed as the field id, not
    # as a CSS class - confirm this is intended.
    valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()])

    submit = SubmitField('Submit')

    def __init__(self, *args, **kwargs):
        """Build the form and set up the language field from the session tenant.

        Positional and keyword arguments are forwarded to ``FlaskForm`` so
        callers can supply the usual form options (for example ``formdata``,
        ``obj`` or ``prefix``); calling the constructor without arguments
        keeps working as before.

        Requires a ``tenant`` mapping in the Flask session: its
        ``allowed_languages`` become the choices, and its ``default_language``
        is used when the field has no data yet.
        """
        super().__init__(*args, **kwargs)
        tenant = session.get('tenant')
        self.language.choices = [(language, language) for language in tenant.get('allowed_languages')]
        if not self.language.data:
            self.language.data = tenant.get('default_language')
|
|
|
|
|
|
class EditDocumentForm(FlaskForm):
    """Flask-WTF form with a document's name and optional validity dates."""

    name = StringField(
        'Name',
        validators=[Length(max=100)],
    )
    # NOTE(review): both date fields are given the same id, and it contains a
    # space ('form-control datepicker' reads like a CSS class list) - confirm
    # this is intended.
    valid_from = DateField(
        'Valid from',
        validators=[Optional()],
        id='form-control datepicker',
    )
    valid_to = DateField(
        'Valid to',
        validators=[Optional()],
        id='form-control datepicker',
    )

    submit = SubmitField('Submit')
|
|
|
|
|
|
class EditDocumentVersionForm(DynamicFormBase):
    """Document-version form built on ``DynamicFormBase``.

    Declares a language text field, optional user and system context text
    areas, two optional JSON text areas and a submit button.
    """

    # Plain text field without validators.
    language = StringField('Language')

    user_context = TextAreaField(
        'User Context',
        validators=[Optional()],
    )
    system_context = TextAreaField(
        'System Context',
        validators=[Optional()],
    )

    # JSON text areas: may be left empty, otherwise checked by validate_json.
    user_metadata = TextAreaField(
        'User Metadata',
        validators=[Optional(), validate_json],
    )
    system_metadata = TextAreaField(
        'System Metadata',
        validators=[Optional(), validate_json],
    )

    submit = SubmitField('Submit')
|