diff --git a/common/models/document.py b/common/models/document.py
index a904f46..89f6edb 100644
--- a/common/models/document.py
+++ b/common/models/document.py
@@ -10,7 +10,7 @@ class Catalog(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
- type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG")
+ type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
# Embedding variables
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
@@ -46,7 +46,7 @@ class Retriever(db.Model):
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
- type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG")
+ type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
tuning = db.Column(db.Boolean, nullable=True, default=False)
# Meta Data
diff --git a/common/models/interaction.py b/common/models/interaction.py
index 450eb96..c6f076c 100644
--- a/common/models/interaction.py
+++ b/common/models/interaction.py
@@ -50,12 +50,25 @@ class InteractionEmbedding(db.Model):
class Specialist(db.Model):
id = db.Column(db.Integer, primary_key=True)
- name = db.Column(db.String(20), nullable=False)
+ name = db.Column(db.String(50), nullable=False)
+ description = db.Column(db.Text, nullable=True)
+ type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
tuning = db.Column(db.Boolean, nullable=True, default=False)
configuration = db.Column(JSONB, nullable=True)
arguments = db.Column(JSONB, nullable=True)
+ # Relationship to retrievers through the association table
+ retrievers = db.relationship('SpecialistRetriever', backref='specialist', lazy=True,
+ cascade="all, delete-orphan")
+
+ # Versioning Information
+ created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
+ created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
+ updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
+
class SpecialistRetriever(db.Model):
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), primary_key=True)
retriever_id = db.Column(db.Integer, db.ForeignKey(Retriever.id, ondelete='CASCADE'), primary_key=True)
+
diff --git a/config/catalog_types.py b/config/catalog_types.py
index a5c2cc5..6155a60 100644
--- a/config/catalog_types.py
+++ b/config/catalog_types.py
@@ -1,7 +1,7 @@
# Catalog Types
CATALOG_TYPES = {
- "STANDARD": {
- "name": "Default Catalog",
+ "STANDARD_CATALOG": {
+ "name": "Standard Catalog",
"Description": "A Catalog with information in Evie's Library, to be considered as a whole",
"configuration": {}
},
diff --git a/config/retriever_types.py b/config/retriever_types.py
index 92673da..8a743c2 100644
--- a/config/retriever_types.py
+++ b/config/retriever_types.py
@@ -1,7 +1,7 @@
# Retriever Types
RETRIEVER_TYPES = {
"STANDARD_RAG": {
- "name": "Default RAG",
+ "name": "Standard RAG Retriever",
"description": "Retrieving all embeddings conform the query",
"configuration": {
"es_k": {
diff --git a/eveai_app/templates/document/retrievers.html b/eveai_app/templates/document/retrievers.html
index cbc44b0..3490552 100644
--- a/eveai_app/templates/document/retrievers.html
+++ b/eveai_app/templates/document/retrievers.html
@@ -4,13 +4,13 @@
{% block title %}Retrievers{% endblock %}
{% block content_title %}Retrievers{% endblock %}
-{% block content_description %}View Retrieers for Tenant{% endblock %}
+{% block content_description %}View Retrievers for Tenant{% endblock %}
{% block content_class %}
{% endblock %}
{% block content %}
+{% endblock %}
+
+{% block content_footer %}
+
+{% endblock %}
diff --git a/eveai_app/templates/interaction/specialist.html b/eveai_app/templates/interaction/specialist.html
new file mode 100644
index 0000000..a386bc6
--- /dev/null
+++ b/eveai_app/templates/interaction/specialist.html
@@ -0,0 +1,23 @@
+{% extends 'base.html' %}
+{% from "macros.html" import render_field %}
+
+{% block title %}Specialist Registration{% endblock %}
+
+{% block content_title %}Register Specialist{% endblock %}
+{% block content_description %}Define a new specialist{% endblock %}
+
+{% block content %}
+
+{% endblock %}
+
+{% block content_footer %}
+
+{% endblock %}
diff --git a/eveai_app/templates/interaction/specialists.html b/eveai_app/templates/interaction/specialists.html
new file mode 100644
index 0000000..574362f
--- /dev/null
+++ b/eveai_app/templates/interaction/specialists.html
@@ -0,0 +1,23 @@
+{% extends 'base.html' %}
+{% from 'macros.html' import render_selectable_table, render_pagination %}
+
+{% block title %}Retrievers{% endblock %}
+
+{% block content_title %}Specialists{% endblock %}
+{% block content_description %}View Specialists for Tenant{% endblock %}
+{% block content_class %}
{% endblock %}
+
+{% block content %}
+
+{% endblock %}
+
+{% block content_footer %}
+ {{ render_pagination(pagination, 'document_bp.retrievers') }}
+{% endblock %}
\ No newline at end of file
diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html
index 16dbd1a..2d48b48 100644
--- a/eveai_app/templates/navbar.html
+++ b/eveai_app/templates/navbar.html
@@ -95,6 +95,8 @@
{% endif %}
{% if current_user.is_authenticated %}
{{ dropdown('Interactions', 'hub', [
+ {'name': 'Add Specialist', 'url': '/interaction/specialist', 'roles': ['Super User', 'Tenant Admin']},
+ {'name': 'All Specialists', 'url': '/interaction/specialists', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Chat Sessions', 'url': '/interaction/chat_sessions', 'roles': ['Super User', 'Tenant Admin']},
]) }}
{% endif %}
diff --git a/eveai_app/views/document_forms.py b/eveai_app/views/document_forms.py
index 021bd40..344ac13 100644
--- a/eveai_app/views/document_forms.py
+++ b/eveai_app/views/document_forms.py
@@ -8,7 +8,6 @@ import json
from wtforms_sqlalchemy.fields import QuerySelectField
-from common.extensions import db
from common.models.document import Catalog
from config.catalog_types import CATALOG_TYPES
@@ -42,7 +41,6 @@ class CatalogForm(FlaskForm):
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
- configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
# HTML Embedding Variables
html_tags = StringField('HTML Tags', validators=[DataRequired()],
diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py
index 06de1ad..6323381 100644
--- a/eveai_app/views/document_views.py
+++ b/eveai_app/views/document_views.py
@@ -228,8 +228,6 @@ def edit_retriever(retriever_id):
configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
- if request.method == 'POST':
- current_app.logger.debug(f'Received POST request with {request.form}')
if form.validate_on_submit():
# Update basic fields
@@ -258,7 +256,6 @@ def edit_retriever(retriever_id):
else:
form_validation_failed(request, form)
- current_app.logger.debug(f"Rendering Template for {retriever_id}")
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
diff --git a/eveai_app/views/interaction_forms.py b/eveai_app/views/interaction_forms.py
new file mode 100644
index 0000000..f75d2ea
--- /dev/null
+++ b/eveai_app/views/interaction_forms.py
@@ -0,0 +1,60 @@
+from flask import session, current_app, request
+from flask_wtf import FlaskForm
+from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
+ SelectField, FieldList, FormField, TextAreaField, URLField)
+from wtforms.validators import DataRequired, Length, Optional, URL, ValidationError, NumberRange
+from flask_wtf.file import FileField, FileAllowed, FileRequired
+import json
+
+from wtforms_sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField
+
+from common.models.document import Retriever
+
+from config.catalog_types import CATALOG_TYPES
+from config.specialist_types import SPECIALIST_TYPES
+from .dynamic_form_base import DynamicFormBase
+
+
+def get_retrievers():
+ return Retriever.query.all()
+
+
+class SpecialistForm(FlaskForm):
+ name = StringField('Name', validators=[DataRequired(), Length(max=50)])
+ description = TextAreaField('Description', validators=[DataRequired()])
+
+ retrievers = QuerySelectMultipleField(
+ 'Retrievers',
+ query_factory=get_retrievers,
+ get_label='name', # Assuming your Retriever model has a 'name' field
+ allow_blank=True,
+ description='Select one or more retrievers to associate with this specialist'
+ )
+
+ type = SelectField('Specialist Type', validators=[DataRequired()])
+
+ tuning = BooleanField('Enable Retrieval Tuning', default=False)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Dynamically populate the 'type' field using the constructor
+ self.type.choices = [(key, value['name']) for key, value in SPECIALIST_TYPES.items()]
+
+
+class EditSpecialistForm(DynamicFormBase):
+ name = StringField('Name', validators=[DataRequired()])
+ description = TextAreaField('Description', validators=[DataRequired()])
+
+ retrievers = QuerySelectMultipleField(
+ 'Retrievers',
+ query_factory=get_retrievers,
+ get_label='name',
+ allow_blank=True,
+ description='Select one or more retrievers to associate with this specialist'
+ )
+
+ type = StringField('Specialist Type', validators=[DataRequired()], render_kw={'readonly': True})
+ tuning = BooleanField('Enable Retrieval Tuning', default=False)
+
+
+
diff --git a/eveai_app/views/interaction_views.py b/eveai_app/views/interaction_views.py
index ec6a1c0..415e368 100644
--- a/eveai_app/views/interaction_views.py
+++ b/eveai_app/views/interaction_views.py
@@ -15,14 +15,17 @@ from requests.exceptions import SSLError
from urllib.parse import urlparse
import io
-from common.models.document import Embedding, DocumentVersion
-from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
+from common.models.document import Embedding, DocumentVersion, Retriever
+from common.models.interaction import ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever
from common.extensions import db
+from common.utils.document_utils import set_logging_information, update_logging_information
+from config.specialist_types import SPECIALIST_TYPES
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm
from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for
from common.utils.view_assistants import form_validation_failed, prepare_table_for_macro
+from .interaction_forms import SpecialistForm, EditSpecialistForm
interaction_bp = Blueprint('interaction_bp', __name__, url_prefix='/interaction')
@@ -122,3 +125,150 @@ def show_chat_session(chat_session):
interactions = Interaction.query.filter_by(chat_session_id=chat_session.id).all()
return render_template('interaction/view_chat_session.html', chat_session=chat_session, interactions=interactions)
+
+@interaction_bp.route('/specialist', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def specialist():
+ form = SpecialistForm()
+
+ if form.validate_on_submit():
+ tenant_id = session.get('tenant').get('id')
+ try:
+ new_specialist = Specialist()
+
+ # Populate fields individually instead of using populate_obj (gives problem with QueryMultipleSelectField)
+ new_specialist.name = form.name.data
+ new_specialist.description = form.description.data
+ new_specialist.type = form.type.data
+ new_specialist.tuning = form.tuning.data
+
+ set_logging_information(new_specialist, dt.now(tz.utc))
+
+ db.session.add(new_specialist)
+ db.session.flush() # This assigns the ID to the specialist without committing the transaction
+
+ current_app.logger.debug(
+ f'New Specialist after flush - id: {new_specialist.id}, name: {new_specialist.name}')
+
+ # Create the retriever associations
+ selected_retrievers = form.retrievers.data
+ current_app.logger.debug(f'Selected Retrievers - {selected_retrievers}')
+ for retriever in selected_retrievers:
+ current_app.logger.debug(f'Creating association for Retriever - {retriever.id}')
+ specialist_retriever = SpecialistRetriever(
+ specialist_id=new_specialist.id,
+ retriever_id=retriever.id
+ )
+ db.session.add(specialist_retriever)
+
+ # Commit everything in one transaction
+ db.session.commit()
+ flash('Specialist successfully added!', 'success')
+ current_app.logger.info(f'Specialist {new_specialist.name} successfully added for tenant {tenant_id}!')
+
+ return redirect(prefixed_url_for('interaction_bp.edit_specialist', specialist_id=new_specialist.id))
+
+ except Exception as e:
+ db.session.rollback()
+ current_app.logger.error(f'Failed to add specialist. Error: {str(e)}', exc_info=True)
+ flash(f'Failed to add specialist. Error: {str(e)}', 'danger')
+ return render_template('interaction/specialist.html', form=form)
+
+ return render_template('interaction/specialists.html', form=form)
+
+
+@interaction_bp.route('/specialist/
', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def edit_specialist(specialist_id):
+ specialist = Specialist.query.get_or_404(specialist_id)
+ form = EditSpecialistForm(request.form, obj=specialist)
+
+ configuration_config = SPECIALIST_TYPES[specialist.type]["configuration"]
+ form.add_dynamic_fields("configuration", configuration_config, specialist.configuration)
+
+ if request.method == 'GET':
+ # Pre-populate the retrievers field with current associations
+ current_app.logger.debug(f'Specialist retrievers: {specialist.retrievers}')
+ current_app.logger.debug(f'Form Retrievers Data Before: {form.retrievers.data}')
+
+ # Get the actual Retriever objects for the associated retriever_ids
+ retriever_objects = Retriever.query.filter(
+ Retriever.id.in_([sr.retriever_id for sr in specialist.retrievers])
+ ).all()
+ form.retrievers.data = retriever_objects
+
+ current_app.logger.debug(f'Form Retrievers Data After: {form.retrievers.data}')
+
+ if form.validate_on_submit():
+ # Update the basic fields
+ form.populate_obj(specialist)
+ # Update the configuration dynamic fields
+ specialist.configuration = form.get_dynamic_data("configuration")
+
+ # Update retriever associations
+ current_retrievers = set(sr.retriever_id for sr in specialist.retrievers)
+ selected_retrievers = set(r.id for r in form.retrievers.data)
+
+ # Remove unselected retrievers
+ for sr in specialist.retrievers[:]:
+ if sr.retriever_id not in selected_retrievers:
+ db.session.delete(sr)
+
+ # Add new retrievers
+ for retriever_id in selected_retrievers - current_retrievers:
+ specialist_retriever = SpecialistRetriever(
+ specialist_id=specialist.id,
+ retriever_id=retriever_id
+ )
+ db.session.add(specialist_retriever)
+
+ # Update logging information
+ update_logging_information(specialist, dt.now(tz.utc))
+
+ try:
+ db.session.commit()
+ flash('Specialist updated successfully!', 'success')
+ current_app.logger.info(f'Specialist {specialist.id} updated successfully')
+ except SQLAlchemyError as e:
+ db.session.rollback()
+ flash(f'Failed to update specialist. Error: {str(e)}', 'danger')
+ current_app.logger.error(f'Failed to update specialist {specialist_id}. Error: {str(e)}')
+ return render_template('interaction/edit_specialist.html', form=form, specialist_id=specialist_id)
+
+ return redirect(prefixed_url_for('interaction_bp.specialists'))
+ else:
+ form_validation_failed(request, form)
+
+ return render_template('interaction/edit_specialist.html', form=form, specialist_id=specialist_id)
+
+
+@interaction_bp.route('/specialists', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def specialists():
+ page = request.args.get('page', 1, type=int)
+ per_page = request.args.get('per_page', 10, type=int)
+
+ query = Specialist.query.order_by(Specialist.id)
+
+ pagination = query.paginate(page=page, per_page=per_page)
+ the_specialists = pagination.items
+
+ # prepare table data
+ rows = prepare_table_for_macro(the_specialists,
+ [('id', ''), ('name', ''), ('type', '')])
+
+ # Render the catalogs in a template
+ return render_template('interaction/specialists.html', rows=rows, pagination=pagination)
+
+
+@interaction_bp.route('/handle_specialist_selection', methods=['POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def handle_specialist_selection():
+ specialist_identification = request.form.get('selected_row')
+ specialist_id = ast.literal_eval(specialist_identification).get('value')
+ action = request.form.get('action')
+
+ if action == "edit_specialist":
+ return redirect(prefixed_url_for('interaction_bp.edit_specialist', specialist_id=specialist_id))
+
+ return redirect(prefixed_url_for('interaction_bp.specialists'))
diff --git a/migrations/tenant/versions/03385f2000da_update_specialists.py b/migrations/tenant/versions/03385f2000da_update_specialists.py
new file mode 100644
index 0000000..eba8073
--- /dev/null
+++ b/migrations/tenant/versions/03385f2000da_update_specialists.py
@@ -0,0 +1,29 @@
+"""Update Retriever Arguments
+
+Revision ID: 03385f2000da
+Revises: e476c2013352
+Create Date: 2024-11-04 08:33:46.158354
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import pgvector
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '03385f2000da'
+down_revision = 'e476c2013352'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('retriever', sa.Column('arguments', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('retriever', 'arguments')
+ # ### end Alembic commands ###
diff --git a/migrations/tenant/versions/0c347651837c_specialist_name_to_50_characters.py b/migrations/tenant/versions/0c347651837c_specialist_name_to_50_characters.py
new file mode 100644
index 0000000..010f26e
--- /dev/null
+++ b/migrations/tenant/versions/0c347651837c_specialist_name_to_50_characters.py
@@ -0,0 +1,35 @@
+"""Specialist name to 50 characters
+
+Revision ID: 0c347651837c
+Revises: 03385f2000da
+Create Date: 2024-11-04 09:25:17.476384
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import pgvector
+
+
+# revision identifiers, used by Alembic.
+revision = '0c347651837c'
+down_revision = '03385f2000da'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column('specialist', 'name',
+ existing_type=sa.VARCHAR(length=20),
+ type_=sa.String(length=50),
+ existing_nullable=False)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column('specialist', 'name',
+ existing_type=sa.String(length=50),
+ type_=sa.VARCHAR(length=20),
+ existing_nullable=False)
+ # ### end Alembic commands ###
diff --git a/migrations/tenant/versions/e476c2013352_introducing_specialists.py b/migrations/tenant/versions/e476c2013352_introducing_specialists.py
new file mode 100644
index 0000000..b6fd40c
--- /dev/null
+++ b/migrations/tenant/versions/e476c2013352_introducing_specialists.py
@@ -0,0 +1,52 @@
+"""Introducing Specialists
+
+Revision ID: e476c2013352
+Revises: 331f8100eb87
+Create Date: 2024-11-04 08:08:19.737409
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import pgvector
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'e476c2013352'
+down_revision = '331f8100eb87'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('specialist',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('name', sa.String(length=20), nullable=False),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('type', sa.String(length=50), nullable=False),
+ sa.Column('tuning', sa.Boolean(), nullable=True),
+ sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+ sa.Column('arguments', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
+ sa.Column('created_by', sa.Integer(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
+ sa.Column('updated_by', sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
+ sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('specialist_retriever',
+ sa.Column('specialist_id', sa.Integer(), nullable=False),
+ sa.Column('retriever_id', sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(['retriever_id'], ['retriever.id'], ondelete='CASCADE'),
+ sa.ForeignKeyConstraint(['specialist_id'], ['specialist.id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('specialist_id', 'retriever_id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('specialist_retriever')
+ op.drop_table('specialist')
+ # ### end Alembic commands ###