diff --git a/common/langchain/eveai_default_rag_retriever.py b/common/langchain/eveai_default_rag_retriever.py
new file mode 100644
index 0000000..8d03587
--- /dev/null
+++ b/common/langchain/eveai_default_rag_retriever.py
@@ -0,0 +1,145 @@
+from langchain_core.retrievers import BaseRetriever
+from sqlalchemy import func, and_, or_, desc
+from sqlalchemy.exc import SQLAlchemyError
+from pydantic import BaseModel, Field, PrivateAttr
+from typing import Any, Dict
+from flask import current_app
+
+from common.extensions import db
+from common.models.document import Document, DocumentVersion
+from common.utils.datetime_utils import get_date_in_timezone
+from common.utils.model_utils import ModelVariables
+
+
+class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
+ _catalog_id: int = PrivateAttr()
+ _model_variables: ModelVariables = PrivateAttr()
+ _tenant_info: Dict[str, Any] = PrivateAttr()
+
+ def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
+ super().__init__()
+ current_app.logger.debug(f'Model variables type: {type(model_variables)}')
+ self._catalog_id = catalog_id
+ self._model_variables = model_variables
+ self._tenant_info = tenant_info
+
+ @property
+ def catalog_id(self) -> int:
+ return self._catalog_id
+
+ @property
+ def model_variables(self) -> ModelVariables:
+ return self._model_variables
+
+ @property
+ def tenant_info(self) -> Dict[str, Any]:
+ return self._tenant_info
+
+ def _get_relevant_documents(self, query: str):
+ current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
+ query_embedding = self._get_query_embedding(query)
+ current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
+ current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
+ db_class = self.model_variables['embedding_db_model']
+ similarity_threshold = self.model_variables['similarity_threshold']
+ k = self.model_variables['k']
+
+ if self.model_variables['rag_tuning']:
+ try:
+ current_date = get_date_in_timezone(self.tenant_info['timezone'])
+ current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
+
+ # Debug query to show similarity for all valid documents (without chunk text)
+ debug_query = (
+ db.session.query(
+ Document.id.label('document_id'),
+ DocumentVersion.id.label('version_id'),
+ db_class.id.label('embedding_id'),
+ (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
+ )
+ .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
+ .join(Document, DocumentVersion.doc_id == Document.id)
+ .filter(
+ or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
+ or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
+ )
+ .order_by(desc('similarity'))
+ )
+
+ debug_results = debug_query.all()
+
+ current_app.logger.debug("Debug: Similarity for all valid documents:")
+ for row in debug_results:
+ current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
+ f"Version ID: {row.version_id}, "
+ f"Embedding ID: {row.embedding_id}, "
+ f"Similarity: {row.similarity}")
+ current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
+ except SQLAlchemyError as e:
+ current_app.logger.error(f'Error generating overview: {e}')
+ db.session.rollback()
+
+ if self.model_variables['rag_tuning']:
+ current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
+ current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
+ current_app.rag_tuning_logger.debug(f'K: {k}\n')
+ current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
+
+ try:
+ current_date = get_date_in_timezone(self.tenant_info['timezone'])
+ # Subquery to find the latest version of each document
+ subquery = (
+ db.session.query(
+ DocumentVersion.doc_id,
+ func.max(DocumentVersion.id).label('latest_version_id')
+ )
+ .group_by(DocumentVersion.doc_id)
+ .subquery()
+ )
+ # Main query to filter embeddings
+ query_obj = (
+ db.session.query(db_class,
+ (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
+ .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
+ .join(Document, DocumentVersion.doc_id == Document.id)
+ .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
+ .filter(
+ or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
+ or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
+ (1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
+ Document.catalog_id == self._catalog_id
+ )
+ .order_by(desc('similarity'))
+ .limit(k)
+ )
+
+ if self.model_variables['rag_tuning']:
+ current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
+ current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
+ current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
+
+ res = query_obj.all()
+
+ if self.model_variables['rag_tuning']:
+ current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
+ current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
+ current_app.rag_tuning_logger.debug(f'{res}\n')
+ current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
+
+ result = []
+ for doc in res:
+ if self.model_variables['rag_tuning']:
+ current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
+ current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
+ result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
+
+ except SQLAlchemyError as e:
+ current_app.logger.error(f'Error retrieving relevant documents: {e}')
+ db.session.rollback()
+ return []
+ return result
+
+ def _get_query_embedding(self, query: str):
+ embedding_model = self.model_variables['embedding_model']
+ query_embedding = embedding_model.embed_query(query)
+ return query_embedding
diff --git a/common/langchain/eveai_retriever.py b/common/langchain/eveai_retriever.py
index 1e517f6..d394920 100644
--- a/common/langchain/eveai_retriever.py
+++ b/common/langchain/eveai_retriever.py
@@ -1,138 +1,39 @@
-from langchain_core.retrievers import BaseRetriever
-from sqlalchemy import func, and_, or_, desc
-from sqlalchemy.exc import SQLAlchemyError
-from pydantic import BaseModel, Field, PrivateAttr
-from typing import Any, Dict
-from flask import current_app
+from pydantic import BaseModel, PrivateAttr
+from typing import Dict, Any
-from common.extensions import db
-from common.models.document import Document, DocumentVersion
-from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
-class EveAIRetriever(BaseRetriever, BaseModel):
- _model_variables: ModelVariables = PrivateAttr()
+class EveAIRetriever(BaseModel):
+ _catalog_id: int = PrivateAttr()
+ _user_metadata: Dict[str, Any] = PrivateAttr()
+ _system_metadata: Dict[str, Any] = PrivateAttr()
+ _configuration: Dict[str, Any] = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
+ _model_variables: ModelVariables = PrivateAttr()
- def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
+ def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
+ configuration: Dict[str, Any]):
super().__init__()
- current_app.logger.debug(f'Model variables type: {type(model_variables)}')
- self._model_variables = model_variables
- self._tenant_info = tenant_info
+ self._catalog_id = catalog_id
+ self._user_metadata = user_metadata
+ self._system_metadata = system_metadata
+ self._configuration = configuration
@property
- def model_variables(self) -> ModelVariables:
- return self._model_variables
+ def catalog_id(self):
+ return self._catalog_id
@property
- def tenant_info(self) -> Dict[str, Any]:
- return self._tenant_info
+ def user_metadata(self):
+ return self._user_metadata
- def _get_relevant_documents(self, query: str):
- current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
- query_embedding = self._get_query_embedding(query)
- current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
- current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
- db_class = self.model_variables['embedding_db_model']
- similarity_threshold = self.model_variables['similarity_threshold']
- k = self.model_variables['k']
+ @property
+ def system_metadata(self):
+ return self._system_metadata
- if self.model_variables['rag_tuning']:
- try:
- current_date = get_date_in_timezone(self.tenant_info['timezone'])
- current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
+ @property
+ def configuration(self):
+ return self._configuration
- # Debug query to show similarity for all valid documents (without chunk text)
- debug_query = (
- db.session.query(
- Document.id.label('document_id'),
- DocumentVersion.id.label('version_id'),
- db_class.id.label('embedding_id'),
- (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
- )
- .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
- .join(Document, DocumentVersion.doc_id == Document.id)
- .filter(
- or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
- or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
- )
- .order_by(desc('similarity'))
- )
-
- debug_results = debug_query.all()
-
- current_app.logger.debug("Debug: Similarity for all valid documents:")
- for row in debug_results:
- current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
- f"Version ID: {row.version_id}, "
- f"Embedding ID: {row.embedding_id}, "
- f"Similarity: {row.similarity}")
- current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
- except SQLAlchemyError as e:
- current_app.logger.error(f'Error generating overview: {e}')
- db.session.rollback()
-
- if self.model_variables['rag_tuning']:
- current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
- current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
- current_app.rag_tuning_logger.debug(f'K: {k}\n')
- current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
-
- try:
- current_date = get_date_in_timezone(self.tenant_info['timezone'])
- # Subquery to find the latest version of each document
- subquery = (
- db.session.query(
- DocumentVersion.doc_id,
- func.max(DocumentVersion.id).label('latest_version_id')
- )
- .group_by(DocumentVersion.doc_id)
- .subquery()
- )
- # Main query to filter embeddings
- query_obj = (
- db.session.query(db_class,
- (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
- .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
- .join(Document, DocumentVersion.doc_id == Document.id)
- .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
- .filter(
- or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
- or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
- (1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
- )
- .order_by(desc('similarity'))
- .limit(k)
- )
-
- if self.model_variables['rag_tuning']:
- current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
- current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
- current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
-
- res = query_obj.all()
-
- if self.model_variables['rag_tuning']:
- current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
- current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
- current_app.rag_tuning_logger.debug(f'{res}\n')
- current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
-
- result = []
- for doc in res:
- if self.model_variables['rag_tuning']:
- current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
- current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
- result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
-
- except SQLAlchemyError as e:
- current_app.logger.error(f'Error retrieving relevant documents: {e}')
- db.session.rollback()
- return []
- return result
-
- def _get_query_embedding(self, query: str):
- embedding_model = self.model_variables['embedding_model']
- query_embedding = embedding_model.embed_query(query)
- return query_embedding
+ # Any common methods that should be shared among retrievers can go here.
diff --git a/common/models/document.py b/common/models/document.py
index 7d08e67..3d9d4c9 100644
--- a/common/models/document.py
+++ b/common/models/document.py
@@ -8,8 +8,10 @@ import sqlalchemy as sa
class Catalog(db.Model):
id = db.Column(db.Integer, primary_key=True)
+ parent_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
+ type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG")
# Embedding variables
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
@@ -21,17 +23,36 @@ class Catalog(db.Model):
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
- # Embedding search variables ==> move to specialist?
- es_k = db.Column(db.Integer, nullable=True, default=8)
- es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4)
-
# Chat variables ==> Move to Specialist?
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
- rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist?
+
+ # Meta Data
+ user_metadata = db.Column(JSONB, nullable=True)
+ system_metadata = db.Column(JSONB, nullable=True)
+ configuration = db.Column(JSONB, nullable=True)
+
+ # Versioning Information
+ created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
+ created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
+ updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
+
+
+class Retriever(db.Model):
+ id = db.Column(db.Integer, primary_key=True)
+ name = db.Column(db.String(50), nullable=False)
+ description = db.Column(db.Text, nullable=True)
+ catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
+ type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG")
+
+ # Meta Data
+ user_metadata = db.Column(JSONB, nullable=True)
+ system_metadata = db.Column(JSONB, nullable=True)
+ configuration = db.Column(JSONB, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
diff --git a/config/catalog_types.py b/config/catalog_types.py
new file mode 100644
index 0000000..1346bdf
--- /dev/null
+++ b/config/catalog_types.py
@@ -0,0 +1,11 @@
+# Catalog Types
+CATALOG_TYPES = {
+ "DEFAULT": {
+ "name": "Default Catalog",
+ "Description": "Default Catalog"
+ },
+ "DOSSIER": {
+ "name": "Dossier Catalog",
+ "Description": "A Catalog in which several Dossiers can be stored"
+ },
+}
diff --git a/config/retriever_types.py b/config/retriever_types.py
new file mode 100644
index 0000000..84c2539
--- /dev/null
+++ b/config/retriever_types.py
@@ -0,0 +1,30 @@
+# Retriever Types
+RETRIEVER_TYPES = {
+ "DEFAULT_RAG": {
+ "name": "Default RAG",
+ "description": "Retrieving all embeddings conform the query",
+ "configuration": {
+ "es_k": {
+ "name": "es_k",
+ "type": "int",
+ "description": "K-value to retrieve embeddings (max embeddings retrieved)",
+ "required": True,
+ "default": 8,
+ },
+ "es_similarity_threshold": {
+ "name": "es_similarity_threshold",
+ "type": "float",
+ "description": "Similarity threshold for retrieving embeddings",
+ "required": True,
+ "default": 0.3,
+ },
+ "rag_tuning": {
+ "name": "rag_tuning",
+ "type": "boolean",
+ "description": "Whether to do tuning logging or not.",
+ "required": False,
+ "default": False,
+ }
+ }
+ }
+}
diff --git a/eveai_app/templates/document/edit_catalog.html b/eveai_app/templates/document/edit_catalog.html
index f205e4c..51a548b 100644
--- a/eveai_app/templates/document/edit_catalog.html
+++ b/eveai_app/templates/document/edit_catalog.html
@@ -16,7 +16,7 @@ When you change chunking of embedding information, you'll need to manually refre
{% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %}
-
+
{% endblock %}
diff --git a/eveai_app/templates/document/edit_retriever.html b/eveai_app/templates/document/edit_retriever.html
new file mode 100644
index 0000000..f8af75c
--- /dev/null
+++ b/eveai_app/templates/document/edit_retriever.html
@@ -0,0 +1,31 @@
+{% extends 'base.html' %}
+{% from "macros.html" import render_field2, render_dynamic_fields %}
+
+{% block title %}Edit Retriever{% endblock %}
+
+{% block content_title %}Edit Retriever{% endblock %}
+{% block content_description %}Edit a Retriever (for a Catalog){% endblock %}
+
+{% block content %}
+
+{% endblock %}
+
+{% block content_footer %}
+
+{% endblock %}
diff --git a/eveai_app/templates/document/retriever.html b/eveai_app/templates/document/retriever.html
new file mode 100644
index 0000000..9a90349
--- /dev/null
+++ b/eveai_app/templates/document/retriever.html
@@ -0,0 +1,23 @@
+{% extends 'base.html' %}
+{% from "macros.html" import render_field %}
+
+{% block title %}Retriever Registration{% endblock %}
+
+{% block content_title %}Register Retriever{% endblock %}
+{% block content_description %}Define a new retriever (for a catalog){% endblock %}
+
+{% block content %}
+
+{% endblock %}
+
+{% block content_footer %}
+
+{% endblock %}
diff --git a/eveai_app/templates/document/retrievers.html b/eveai_app/templates/document/retrievers.html
new file mode 100644
index 0000000..cbc44b0
--- /dev/null
+++ b/eveai_app/templates/document/retrievers.html
@@ -0,0 +1,23 @@
+{% extends 'base.html' %}
+{% from 'macros.html' import render_selectable_table, render_pagination %}
+
+{% block title %}Retrievers{% endblock %}
+
+{% block content_title %}Retrievers{% endblock %}
+{% block content_description %}View Retrieers for Tenant{% endblock %}
+{% block content_class %}{% endblock %}
+
+{% block content %}
+
+{% endblock %}
+
+{% block content_footer %}
+ {{ render_pagination(pagination, 'document_bp.retrievers') }}
+{% endblock %}
\ No newline at end of file
diff --git a/eveai_app/templates/macros.html b/eveai_app/templates/macros.html
index 94e3896..e75ff28 100644
--- a/eveai_app/templates/macros.html
+++ b/eveai_app/templates/macros.html
@@ -23,6 +23,43 @@
{% endif %}
{% endmacro %}
+{% macro render_field2(field, disabled_fields=[], exclude_fields=[], class='') %}
+
+
+
+ {% set disabled = field.name in disabled_fields %}
+ {% set exclude_fields = exclude_fields + ['csrf_token', 'submit'] %}
+ {% if field.name not in exclude_fields %}
+ {% if field.type == 'BooleanField' %}
+
+ {% else %}
+
+ {% endif %}
+ {% endif %}
+{% endmacro %}
+
{% macro render_included_field(field, disabled_fields=[], include_fields=[]) %}
{% set disabled = field.name in disabled_fields %}
{% if field.name in include_fields %}
diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html
index bb11f33..16dbd1a 100644
--- a/eveai_app/templates/navbar.html
+++ b/eveai_app/templates/navbar.html
@@ -83,6 +83,8 @@
{{ dropdown('Document Mgmt', 'note_stack', [
{'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']},
+ {'name': 'Add Retriever', 'url': '/document/retriever', 'roles': ['Super User', 'Tenant Admin']},
+ {'name': 'All Retrievers', 'url': '/document/retrievers', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},
diff --git a/eveai_app/views/document_forms.py b/eveai_app/views/document_forms.py
index 87427e1..b035110 100644
--- a/eveai_app/views/document_forms.py
+++ b/eveai_app/views/document_forms.py
@@ -1,4 +1,4 @@
-from flask import session, current_app
+from flask import session, current_app, request
from flask_wtf import FlaskForm
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
SelectField, FieldList, FormField, TextAreaField, URLField)
@@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr
from flask_wtf.file import FileField, FileAllowed, FileRequired
import json
+from wtforms_sqlalchemy.fields import QuerySelectField
+
+from common.models.document import Catalog
+
+from config.catalog_types import CATALOG_TYPES
+from config.retriever_types import RETRIEVER_TYPES
+from .dynamic_form_base import DynamicFormBase
+
def allowed_file(form, field):
if field.data:
@@ -26,6 +34,23 @@ def validate_json(form, field):
class CatalogForm(FlaskForm):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
+ # Parent ID (Optional for root-level catalogs)
+ parent = QuerySelectField(
+ 'Parent Catalog',
+ query_factory=lambda: Catalog.query.all(),
+ allow_blank=True,
+ get_label='name',
+ validators=[Optional()],
+ )
+
+ # Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
+ type = SelectField('Catalog Type', validators=[DataRequired()])
+
+ # Metadata fields
+ user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
+ system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
+ configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
+
# HTML Embedding Variables
html_tags = StringField('HTML Tags', validators=[DataRequired()],
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
@@ -38,19 +63,65 @@ class CatalogForm(FlaskForm):
default=2000)
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
default=3000)
- # Embedding Search variables
- es_k = IntegerField('Limit for Searching Embeddings (5)',
- default=5,
- validators=[NumberRange(min=0)])
- es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
- default=0.5,
- validators=[NumberRange(min=0, max=1)])
# Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
# Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
- rag_tuning = BooleanField('Enable RAG Tuning', default=False)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Dynamically populate the 'type' field using the constructor
+ self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()]
+
+
+class RetrieverForm(FlaskForm):
+ name = StringField('Name', validators=[DataRequired(), Length(max=50)])
+ description = TextAreaField('Description', validators=[Optional()])
+ # Catalog for the Retriever
+ catalog = QuerySelectField(
+ 'Catalog ID',
+ query_factory=lambda: Catalog.query.all(),
+ allow_blank=True,
+ get_label='name',
+ validators=[Optional()],
+ )
+ # Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
+ type = SelectField('Retriever Type', validators=[DataRequired()])
+
+ # Metadata fields
+ user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
+ system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Dynamically populate the 'type' field using the constructor
+ self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
+
+
+class EditRetrieverForm(DynamicFormBase):
+ name = StringField('Name', validators=[DataRequired(), Length(max=50)])
+ description = TextAreaField('Description', validators=[Optional()])
+ # Catalog for the Retriever
+ catalog = QuerySelectField(
+ 'Catalog ID',
+ query_factory=lambda: Catalog.query.all(),
+ allow_blank=True,
+ get_label='name',
+ validators=[Optional()],
+ )
+ # Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
+ type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True})
+
+ # Metadata fields
+ user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
+ system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Set the retriever type choices (loaded from config)
+ self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
class AddDocumentForm(FlaskForm):
diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py
index f8f2a4b..260163f 100644
--- a/eveai_app/views/document_views.py
+++ b/eveai_app/views/document_views.py
@@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote
import io
import json
-from common.models.document import Document, DocumentVersion, Catalog
+from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.extensions import db, minio_client
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
process_multiple_urls, get_documents_list, edit_document, \
@@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
EveAIDoubleURLException
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
- CatalogForm
+ CatalogForm, RetrieverForm, EditRetrieverForm
from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for
@@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f
from .document_list_view import DocumentListView
from .document_version_list_view import DocumentVersionListView
+from config.retriever_types import RETRIEVER_TYPES
+
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
@@ -65,6 +67,7 @@ def catalog():
tenant_id = session.get('tenant').get('id')
new_catalog = Catalog()
form.populate_obj(new_catalog)
+ new_catalog.parent_id = form.parent.data.get('id')
# Handle Embedding Variables
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
@@ -103,7 +106,7 @@ def catalogs():
the_catalogs = pagination.items
# prepare table data
- rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')])
+ rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')])
# Render the catalogs in a template
return render_template('document/catalogs.html', rows=rows, pagination=pagination)
@@ -173,6 +176,121 @@ def edit_catalog(catalog_id):
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
+@document_bp.route('/retriever', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def retriever():
+ form = RetrieverForm()
+
+ if form.validate_on_submit():
+ tenant_id = session.get('tenant').get('id')
+ new_retriever = Retriever()
+ form.populate_obj(new_retriever)
+ new_retriever.catalog_id = form.catalog.data.id
+
+ set_logging_information(new_retriever, dt.now(tz.utc))
+
+ try:
+ db.session.add(new_retriever)
+ db.session.commit()
+ flash('Retriever successfully added!', 'success')
+ current_app.logger.info(f'Catalog {new_retriever.name} successfully added for tenant {tenant_id}!')
+ except SQLAlchemyError as e:
+ db.session.rollback()
+ flash(f'Failed to add retriever. Error: {e}', 'danger')
+ current_app.logger.error(f'Failed to add retriever {new_retriever.name}'
+ f'for tenant {tenant_id}. Error: {str(e)}')
+
+ # Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type)
+ return redirect(prefixed_url_for('document_bp.retriever', retriever_id=new_retriever.id))
+
+ return render_template('document/retriever.html', form=form)
+
+
+@document_bp.route('/retriever/', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def edit_retriever(retriever_id):
+ """Edit an existing retriever configuration."""
+ current_app.logger.debug(f"Editing Retriever {retriever_id}")
+
+ # Get the retriever or return 404
+ retriever = Retriever.query.get_or_404(retriever_id)
+
+ if retriever.catalog_id:
+ # If catalog_id is just an ID, fetch the Catalog object
+ retriever.catalog = Catalog.query.get(retriever.catalog_id)
+ else:
+ retriever.catalog = None
+
+ # Create form instance with the retriever
+ form = EditRetrieverForm(request.form, obj=retriever)
+
+ configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
+ current_app.logger.debug(f"Configuration {configuration_config}")
+ form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
+
+ if form.validate_on_submit():
+ # Update basic fields
+ form.populate_obj(retriever)
+ retriever.configuration = form.get_dynamic_data('configuration')
+
+ # Update catalog relationship
+ retriever.catalog_id = form.catalog.data.id if form.catalog.data else None
+
+ # Update logging information
+ update_logging_information(retriever, dt.now(tz.utc))
+
+ # Save changes to database
+ try:
+ db.session.add(retriever)
+ db.session.commit()
+ flash('Retriever updated successfully!', 'success')
+ current_app.logger.info(f'Retriever {retriever.id} updated successfully')
+ except SQLAlchemyError as e:
+ db.session.rollback()
+ flash(f'Failed to update retriever. Error: {str(e)}', 'danger')
+ current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}')
+ return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
+
+ return redirect(prefixed_url_for('document_bp.retrievers'))
+ else:
+ form_validation_failed(request, form)
+
+ current_app.logger.debug(f"Rendering Template for {retriever_id}")
+ return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
+
+
+@document_bp.route('/retrievers', methods=['GET', 'POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def retrievers():
+ page = request.args.get('page', 1, type=int)
+ per_page = request.args.get('per_page', 10, type=int)
+
+ query = Retriever.query.order_by(Retriever.id)
+
+ pagination = query.paginate(page=page, per_page=per_page)
+ the_retrievers = pagination.items
+
+ # prepare table data
+ rows = prepare_table_for_macro(the_retrievers,
+ [('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')])
+
+ # Render the catalogs in a template
+ return render_template('document/retrievers.html', rows=rows, pagination=pagination)
+
+
+@document_bp.route('/handle_retriever_selection', methods=['POST'])
+@roles_accepted('Super User', 'Tenant Admin')
+def handle_retriever_selection():
+ retriever_identification = request.form.get('selected_row')
+ retriever_id = ast.literal_eval(retriever_identification).get('value')
+ action = request.form['action']
+
+ if action == 'edit_retriever':
+ return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id))
+
+ return redirect(prefixed_url_for('document_bp.retrievers'))
+
+
@document_bp.route('/add_document', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def add_document():
diff --git a/eveai_app/views/dynamic_form_base.py b/eveai_app/views/dynamic_form_base.py
new file mode 100644
index 0000000..46ba100
--- /dev/null
+++ b/eveai_app/views/dynamic_form_base.py
@@ -0,0 +1,92 @@
+from flask_wtf import FlaskForm
+from wtforms import IntegerField, FloatField, BooleanField, StringField, validators
+from flask import current_app
+
+class DynamicFormBase(FlaskForm):
+ def __init__(self, formdata=None, *args, **kwargs):
+ super(DynamicFormBase, self).__init__(*args, **kwargs)
+ # Maps collection names to lists of field names
+ self.dynamic_fields = {}
+ # Store formdata for later use
+ self.formdata = formdata
+
+ def add_dynamic_fields(self, collection_name, config, initial_data=None):
+ """Add dynamic fields to the form based on the configuration."""
+ self.dynamic_fields[collection_name] = []
+ for field_name, field_def in config.items():
+ current_app.logger.debug(f"{field_name}: {field_def}")
+ # Prefix the field name with the collection name
+ full_field_name = f"{collection_name}_{field_name}"
+ field_type = field_def.get('type')
+ description = field_def.get('description', '')
+ required = field_def.get('required', False)
+ default = field_def.get('default')
+
+ # Determine validators
+ field_validators = [validators.InputRequired()] if required else [validators.Optional()]
+
+ # Map the field type to WTForms field classes
+ field_class = {
+ 'int': IntegerField,
+ 'float': FloatField,
+ 'boolean': BooleanField,
+ 'string': StringField,
+ }.get(field_type, StringField)
+
+ # Create the field instance
+ unbound_field = field_class(
+ label=description,
+ validators=field_validators,
+ default=default
+ )
+
+ # Bind the field to the form
+ bound_field = unbound_field.bind(form=self, name=full_field_name)
+
+ # Process the field with formdata
+ if self.formdata and full_field_name in self.formdata:
+ # If formdata is available and contains the field
+ bound_field.process(self.formdata)
+ elif initial_data and field_name in initial_data:
+ # Use initial data if provided
+ bound_field.process(formdata=None, data=initial_data[field_name])
+ else:
+ # Use default value
+ bound_field.process(formdata=None, data=default)
+
+ # Set collection name attribute for identification
+ # bound_field.collection_name = collection_name
+
+ # Add the field to the form
+ setattr(self, full_field_name, bound_field)
+ self._fields[full_field_name] = bound_field
+ self.dynamic_fields[collection_name].append(full_field_name)
+
+ def get_static_fields(self):
+ """Return a list of static field instances."""
+ # Get names of dynamic fields
+ dynamic_field_names = set()
+ for field_list in self.dynamic_fields.values():
+ dynamic_field_names.update(field_list)
+
+ # Return all fields that are not dynamic
+ return [field for name, field in self._fields.items() if name not in dynamic_field_names]
+
+ def get_dynamic_fields(self):
+ """Return a dictionary of dynamic fields per collection."""
+ result = {}
+ for collection_name, field_names in self.dynamic_fields.items():
+ result[collection_name] = [getattr(self, name) for name in field_names]
+ return result
+
+ def get_dynamic_data(self, collection_name):
+ """Retrieve the data from dynamic fields of a specific collection."""
+ data = {}
+ if collection_name not in self.dynamic_fields:
+ return data
+ prefix_length = len(collection_name) + 1 # +1 for the underscore
+ for full_field_name in self.dynamic_fields[collection_name]:
+ original_field_name = full_field_name[prefix_length:]
+ field = getattr(self, full_field_name)
+ data[original_field_name] = field.data
+ return data
diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py
index 90f6641..9c21e8b 100644
--- a/eveai_chat_workers/tasks.py
+++ b/eveai_chat_workers/tasks.py
@@ -22,7 +22,7 @@ from common.models.interaction import ChatSession, Interaction, InteractionEmbed
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
-from common.langchain.eveai_retriever import EveAIRetriever
+from common.langchain.eveai_default_rag_retriever import EveAIDefaultRagRetriever
from common.langchain.eveai_history_retriever import EveAIHistoryRetriever
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
@@ -139,7 +139,7 @@ def answer_using_tenant_rag(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Generate Answer using RAG"):
- retriever = EveAIRetriever(model_variables, tenant_info)
+ retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm']
template = model_variables['rag_template']
language_template = create_language_template(template, language)
@@ -243,7 +243,7 @@ def answer_using_llm(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Detail Answer using LLM"):
- retriever = EveAIRetriever(model_variables, tenant_info)
+ retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm_no_rag']
template = model_variables['encyclopedia_template']
language_template = create_language_template(template, language)
diff --git a/integrations/Wordpress/eveai-chat-widget.zip b/integrations/Wordpress/eveai-chat-widget.zip
deleted file mode 100644
index 709bc0d..0000000
Binary files a/integrations/Wordpress/eveai-chat-widget.zip and /dev/null differ
diff --git a/migrations/tenant/versions/3717364e6429_add_retriever_model.py b/migrations/tenant/versions/3717364e6429_add_retriever_model.py
new file mode 100644
index 0000000..74cd51d
--- /dev/null
+++ b/migrations/tenant/versions/3717364e6429_add_retriever_model.py
@@ -0,0 +1,56 @@
+"""Add Retriever Model
+
+Revision ID: 3717364e6429
+Revises: 7b7b566e667f
+Create Date: 2024-10-21 14:22:30.258679
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import pgvector
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '3717364e6429'
+down_revision = '7b7b566e667f'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('retriever',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('name', sa.String(length=50), nullable=False),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('catalog_id', sa.Integer(), nullable=True),
+ sa.Column('type', sa.String(length=50), nullable=False),
+ sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+ sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+ sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
+ sa.Column('created_by', sa.Integer(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
+ sa.Column('updated_by', sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(['catalog_id'], ['catalog.id'], ),
+ sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
+ sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.drop_column('catalog', 'es_similarity_threshold')
+ op.drop_column('catalog', 'es_k')
+ op.drop_column('catalog', 'rag_tuning')
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('catalog', sa.Column('rag_tuning', sa.BOOLEAN(), autoincrement=False, nullable=True))
+ op.add_column('catalog', sa.Column('es_k', sa.INTEGER(), autoincrement=False, nullable=True))
+ op.add_column('catalog', sa.Column('es_similarity_threshold', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True))
+ op.drop_constraint(None, 'catalog', type_='foreignkey')
+ op.drop_constraint(None, 'catalog', type_='foreignkey')
+ op.create_foreign_key('catalog_updated_by_fkey', 'catalog', 'user', ['updated_by'], ['id'])
+ op.create_foreign_key('catalog_created_by_fkey', 'catalog', 'user', ['created_by'], ['id'])
+ op.drop_table('retriever')
+ # ### end Alembic commands ###
diff --git a/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py b/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py
new file mode 100644
index 0000000..b348b40
--- /dev/null
+++ b/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py
@@ -0,0 +1,46 @@
+"""Extensions for more Catalog Types in Catalog Model
+
+Revision ID: 7b7b566e667f
+Revises: 28984b05d396
+Create Date: 2024-10-21 07:39:52.260054
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import pgvector
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '7b7b566e667f'
+down_revision = '28984b05d396'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('catalog', sa.Column('parent_id', sa.Integer(), nullable=True))
+
+ op.add_column('catalog', sa.Column('type', sa.String(length=50), nullable=False, server_default='DEFAULT'))
+ # Update all existing rows to have the 'type' set to 'DEFAULT'
+ op.execute("UPDATE catalog SET type='DEFAULT' WHERE type IS NULL")
+ # Remove the server default (optional but recommended)
+ op.alter_column('catalog', 'type', server_default=None)
+
+ op.add_column('catalog', sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ op.add_column('catalog', sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ op.add_column('catalog', sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ op.create_foreign_key(None, 'catalog', 'catalog', ['parent_id'], ['id'])
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint(None, 'catalog', type_='foreignkey')
+ op.drop_column('catalog', 'configuration')
+ op.drop_column('catalog', 'system_metadata')
+ op.drop_column('catalog', 'user_metadata')
+ op.drop_column('catalog', 'type')
+ op.drop_column('catalog', 'parent_id')
+ # ### end Alembic commands ###
diff --git a/requirements.txt b/requirements.txt
index 2f57f6f..be8e3db 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -82,3 +82,4 @@ prometheus-client~=0.20.0
flower~=2.0.1
psutil~=6.0.0
celery-redbeat~=2.2.0
+WTForms-SQLAlchemy~=0.4.1