- Allowing for multiple types of Catalogs

- Introduction of retrievers
- Ensuring processing information is collected from Catalog iso Tenant
- Introduction of a generic Form class to enable dynamic fields based on a configuration
- Realisation of Retriever functionality to support dynamic fields
This commit is contained in:
Josako
2024-10-25 14:11:47 +02:00
parent 30fec27488
commit aa358df28e
19 changed files with 753 additions and 145 deletions

View File

@@ -0,0 +1,145 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
@property
def catalog_id(self) -> int:
return self._catalog_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.model_variables['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding

View File

@@ -1,138 +1,39 @@
from langchain_core.retrievers import BaseRetriever from pydantic import BaseModel, PrivateAttr
from sqlalchemy import func, and_, or_, desc from typing import Dict, Any
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables from common.utils.model_utils import ModelVariables
class EveAIRetriever(BaseRetriever, BaseModel): class EveAIRetriever(BaseModel):
_model_variables: ModelVariables = PrivateAttr() _catalog_id: int = PrivateAttr()
_user_metadata: Dict[str, Any] = PrivateAttr()
_system_metadata: Dict[str, Any] = PrivateAttr()
_configuration: Dict[str, Any] = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr() _tenant_info: Dict[str, Any] = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]): def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
configuration: Dict[str, Any]):
super().__init__() super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}') self._catalog_id = catalog_id
self._model_variables = model_variables self._user_metadata = user_metadata
self._tenant_info = tenant_info self._system_metadata = system_metadata
self._configuration = configuration
@property @property
def model_variables(self) -> ModelVariables: def catalog_id(self):
return self._model_variables return self._catalog_id
@property @property
def tenant_info(self) -> Dict[str, Any]: def user_metadata(self):
return self._tenant_info return self._user_metadata
def _get_relevant_documents(self, query: str): @property
current_app.logger.debug(f'Retrieving relevant documents for query: {query}') def system_metadata(self):
query_embedding = self._get_query_embedding(query) return self._system_metadata
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.model_variables['rag_tuning']: @property
try: def configuration(self):
current_date = get_date_in_timezone(self.tenant_info['timezone']) return self._configuration
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text) # Any common methods that should be shared among retrievers can go here.
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding

View File

@@ -8,8 +8,10 @@ import sqlalchemy as sa
class Catalog(db.Model): class Catalog(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
parent_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
name = db.Column(db.String(50), nullable=False) name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True) description = db.Column(db.Text, nullable=True)
type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG")
# Embedding variables # Embedding variables
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li']) html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
@@ -21,17 +23,36 @@ class Catalog(db.Model):
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000) min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000) max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
# Embedding search variables ==> move to specialist?
es_k = db.Column(db.Integer, nullable=True, default=8)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4)
# Chat variables ==> Move to Specialist? # Chat variables ==> Move to Specialist?
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3) chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5) chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Tuning enablers # Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False) embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist?
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
configuration = db.Column(JSONB, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
class Retriever(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG")
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
configuration = db.Column(JSONB, nullable=True)
# Versioning Information # Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())

11
config/catalog_types.py Normal file
View File

@@ -0,0 +1,11 @@
# Catalog Types
CATALOG_TYPES = {
"DEFAULT": {
"name": "Default Catalog",
"Description": "Default Catalog"
},
"DOSSIER": {
"name": "Dossier Catalog",
"Description": "A Catalog in which several Dossiers can be stored"
},
}

30
config/retriever_types.py Normal file
View File

@@ -0,0 +1,30 @@
# Retriever Types
RETRIEVER_TYPES = {
"DEFAULT_RAG": {
"name": "Default RAG",
"description": "Retrieving all embeddings conform the query",
"configuration": {
"es_k": {
"name": "es_k",
"type": "int",
"description": "K-value to retrieve embeddings (max embeddings retrieved)",
"required": True,
"default": 8,
},
"es_similarity_threshold": {
"name": "es_similarity_threshold",
"type": "float",
"description": "Similarity threshold for retrieving embeddings",
"required": True,
"default": 0.3,
},
"rag_tuning": {
"name": "rag_tuning",
"type": "boolean",
"description": "Whether to do tuning logging or not.",
"required": False,
"default": False,
}
}
}
}

View File

@@ -16,7 +16,7 @@ When you change chunking of embedding information, you'll need to manually refre
{% for field in form %} {% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }} {{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %} {% endfor %}
<button type="submit" class="btn btn-primary">Register Catalog</button> <button type="submit" class="btn btn-primary">Save Catalog</button>
</form> </form>
{% endblock %} {% endblock %}

View File

@@ -0,0 +1,31 @@
{% extends 'base.html' %}
{% from "macros.html" import render_field2, render_dynamic_fields %}
{% block title %}Edit Retriever{% endblock %}
{% block content_title %}Edit Retriever{% endblock %}
{% block content_description %}Edit a Retriever (for a Catalog){% endblock %}
{% block content %}
<form method="post">
{{ form.hidden_tag() }}
{% set disabled_fields = ['type'] %}
{% set exclude_fields = [] %}
<!-- Render Static Fields -->
{% for field in form.get_static_fields() %}
{{ render_field2(field, disabled_fields, exclude_fields) }}
{% endfor %}
<!-- Render Dynamic Fields -->
{% for collection_name, fields in form.get_dynamic_fields().items() %}
<h4 class="mt-4">{{ collection_name }}</h4>
{% for field in fields %}
{{ render_field2(field, disabled_fields, exclude_fields) }}
{% endfor %}
{% endfor %}
<button type="submit" class="btn btn-primary">Save Retriever</button>
</form>
{% endblock %}
{% block content_footer %}
{% endblock %}

View File

@@ -0,0 +1,23 @@
{% extends 'base.html' %}
{% from "macros.html" import render_field %}
{% block title %}Retriever Registration{% endblock %}
{% block content_title %}Register Retriever{% endblock %}
{% block content_description %}Define a new retriever (for a catalog){% endblock %}
{% block content %}
<form method="post">
{{ form.hidden_tag() }}
{% set disabled_fields = [] %}
{% set exclude_fields = [] %}
{% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %}
<button type="submit" class="btn btn-primary">Register Retriever</button>
</form>
{% endblock %}
{% block content_footer %}
{% endblock %}

View File

@@ -0,0 +1,23 @@
{% extends 'base.html' %}
{% from 'macros.html' import render_selectable_table, render_pagination %}
{% block title %}Retrievers{% endblock %}
{% block content_title %}Retrievers{% endblock %}
{% block content_description %}View Retrieers for Tenant{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto"></div>{% endblock %}
{% block content %}
<div class="container">
<form method="POST" action="{{ url_for('document_bp.handle_retriever_selection') }}">
{{ render_selectable_table(headers=["Retriever ID", "Name", "Type", "Catalog ID"], rows=rows, selectable=True, id="retrieverssTable") }}
<div class="form-group mt-3">
<button type="submit" name="action" value="edit_retriever" class="btn btn-primary">Edit Retriever</button>
</div>
</form>
</div>
{% endblock %}
{% block content_footer %}
{{ render_pagination(pagination, 'document_bp.retrievers') }}
{% endblock %}

View File

@@ -23,6 +23,43 @@
{% endif %} {% endif %}
{% endmacro %} {% endmacro %}
{% macro render_field2(field, disabled_fields=[], exclude_fields=[], class='') %}
<!-- Debug info -->
<!-- Field name: {{ field.name }}, Field type: {{ field.__class__.__name__ }} -->
{% set disabled = field.name in disabled_fields %}
{% set exclude_fields = exclude_fields + ['csrf_token', 'submit'] %}
{% if field.name not in exclude_fields %}
{% if field.type == 'BooleanField' %}
<div class="form-group">
<div class="form-check form-switch">
{{ field(class="form-check-input " + class, disabled=disabled) }}
{{ field.label(class="form-check-label") }}
</div>
{% if field.errors %}
<div class="invalid-feedback d-block">
{% for error in field.errors %}
{{ error }}
{% endfor %}
</div>
{% endif %}
</div>
{% else %}
<div class="form-group">
{{ field.label(class="form-label") }}
{{ field(class="form-control " + class, disabled=disabled) }}
{% if field.errors %}
<div class="invalid-feedback d-block">
{% for error in field.errors %}
{{ error }}
{% endfor %}
</div>
{% endif %}
</div>
{% endif %}
{% endif %}
{% endmacro %}
{% macro render_included_field(field, disabled_fields=[], include_fields=[]) %} {% macro render_included_field(field, disabled_fields=[], include_fields=[]) %}
{% set disabled = field.name in disabled_fields %} {% set disabled = field.name in disabled_fields %}
{% if field.name in include_fields %} {% if field.name in include_fields %}

View File

@@ -83,6 +83,8 @@
{{ dropdown('Document Mgmt', 'note_stack', [ {{ dropdown('Document Mgmt', 'note_stack', [
{'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Retriever', 'url': '/document/retriever', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Retrievers', 'url': '/document/retrievers', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},

View File

@@ -1,4 +1,4 @@
from flask import session, current_app from flask import session, current_app, request
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField, from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
SelectField, FieldList, FormField, TextAreaField, URLField) SelectField, FieldList, FormField, TextAreaField, URLField)
@@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr
from flask_wtf.file import FileField, FileAllowed, FileRequired from flask_wtf.file import FileField, FileAllowed, FileRequired
import json import json
from wtforms_sqlalchemy.fields import QuerySelectField
from common.models.document import Catalog
from config.catalog_types import CATALOG_TYPES
from config.retriever_types import RETRIEVER_TYPES
from .dynamic_form_base import DynamicFormBase
def allowed_file(form, field): def allowed_file(form, field):
if field.data: if field.data:
@@ -26,6 +34,23 @@ def validate_json(form, field):
class CatalogForm(FlaskForm): class CatalogForm(FlaskForm):
name = StringField('Name', validators=[DataRequired(), Length(max=50)]) name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()]) description = TextAreaField('Description', validators=[Optional()])
# Parent ID (Optional for root-level catalogs)
parent = QuerySelectField(
'Parent Catalog',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
type = SelectField('Catalog Type', validators=[DataRequired()])
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
# HTML Embedding Variables # HTML Embedding Variables
html_tags = StringField('HTML Tags', validators=[DataRequired()], html_tags = StringField('HTML Tags', validators=[DataRequired()],
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td') default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
@@ -38,19 +63,65 @@ class CatalogForm(FlaskForm):
default=2000) default=2000)
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
default=3000) default=3000)
# Embedding Search variables
es_k = IntegerField('Limit for Searching Embeddings (5)',
default=5,
validators=[NumberRange(min=0)])
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5,
validators=[NumberRange(min=0, max=1)])
# Chat Variables # Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)]) chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)]) chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
# Tuning variables # Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False) embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()]
class RetrieverForm(FlaskForm):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
# Catalog for the Retriever
catalog = QuerySelectField(
'Catalog ID',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
type = SelectField('Retriever Type', validators=[DataRequired()])
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
class EditRetrieverForm(DynamicFormBase):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
# Catalog for the Retriever
catalog = QuerySelectField(
'Catalog ID',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True})
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Set the retriever type choices (loaded from config)
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
class AddDocumentForm(FlaskForm): class AddDocumentForm(FlaskForm):

View File

@@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote
import io import io
import json import json
from common.models.document import Document, DocumentVersion, Catalog from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.extensions import db, minio_client from common.extensions import db, minio_client
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \ from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
process_multiple_urls, get_documents_list, edit_document, \ process_multiple_urls, get_documents_list, edit_document, \
@@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \ from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
EveAIDoubleURLException EveAIDoubleURLException
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \ from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
CatalogForm CatalogForm, RetrieverForm, EditRetrieverForm
from common.utils.middleware import mw_before_request from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for from common.utils.nginx_utils import prefixed_url_for
@@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f
from .document_list_view import DocumentListView from .document_list_view import DocumentListView
from .document_version_list_view import DocumentVersionListView from .document_version_list_view import DocumentVersionListView
from config.retriever_types import RETRIEVER_TYPES
document_bp = Blueprint('document_bp', __name__, url_prefix='/document') document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
@@ -65,6 +67,7 @@ def catalog():
tenant_id = session.get('tenant').get('id') tenant_id = session.get('tenant').get('id')
new_catalog = Catalog() new_catalog = Catalog()
form.populate_obj(new_catalog) form.populate_obj(new_catalog)
new_catalog.parent_id = form.parent.data.get('id')
# Handle Embedding Variables # Handle Embedding Variables
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else [] new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \ new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
@@ -103,7 +106,7 @@ def catalogs():
the_catalogs = pagination.items the_catalogs = pagination.items
# prepare table data # prepare table data
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')]) rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')])
# Render the catalogs in a template # Render the catalogs in a template
return render_template('document/catalogs.html', rows=rows, pagination=pagination) return render_template('document/catalogs.html', rows=rows, pagination=pagination)
@@ -173,6 +176,121 @@ def edit_catalog(catalog_id):
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id) return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
@document_bp.route('/retriever', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retriever():
form = RetrieverForm()
if form.validate_on_submit():
tenant_id = session.get('tenant').get('id')
new_retriever = Retriever()
form.populate_obj(new_retriever)
new_retriever.catalog_id = form.catalog.data.id
set_logging_information(new_retriever, dt.now(tz.utc))
try:
db.session.add(new_retriever)
db.session.commit()
flash('Retriever successfully added!', 'success')
current_app.logger.info(f'Catalog {new_retriever.name} successfully added for tenant {tenant_id}!')
except SQLAlchemyError as e:
db.session.rollback()
flash(f'Failed to add retriever. Error: {e}', 'danger')
current_app.logger.error(f'Failed to add retriever {new_retriever.name}'
f'for tenant {tenant_id}. Error: {str(e)}')
# Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type)
return redirect(prefixed_url_for('document_bp.retriever', retriever_id=new_retriever.id))
return render_template('document/retriever.html', form=form)
@document_bp.route('/retriever/<int:retriever_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_retriever(retriever_id):
"""Edit an existing retriever configuration."""
current_app.logger.debug(f"Editing Retriever {retriever_id}")
# Get the retriever or return 404
retriever = Retriever.query.get_or_404(retriever_id)
if retriever.catalog_id:
# If catalog_id is just an ID, fetch the Catalog object
retriever.catalog = Catalog.query.get(retriever.catalog_id)
else:
retriever.catalog = None
# Create form instance with the retriever
form = EditRetrieverForm(request.form, obj=retriever)
configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
current_app.logger.debug(f"Configuration {configuration_config}")
form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
if form.validate_on_submit():
# Update basic fields
form.populate_obj(retriever)
retriever.configuration = form.get_dynamic_data('configuration')
# Update catalog relationship
retriever.catalog_id = form.catalog.data.id if form.catalog.data else None
# Update logging information
update_logging_information(retriever, dt.now(tz.utc))
# Save changes to database
try:
db.session.add(retriever)
db.session.commit()
flash('Retriever updated successfully!', 'success')
current_app.logger.info(f'Retriever {retriever.id} updated successfully')
except SQLAlchemyError as e:
db.session.rollback()
flash(f'Failed to update retriever. Error: {str(e)}', 'danger')
current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}')
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
return redirect(prefixed_url_for('document_bp.retrievers'))
else:
form_validation_failed(request, form)
current_app.logger.debug(f"Rendering Template for {retriever_id}")
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
@document_bp.route('/retrievers', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retrievers():
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
query = Retriever.query.order_by(Retriever.id)
pagination = query.paginate(page=page, per_page=per_page)
the_retrievers = pagination.items
# prepare table data
rows = prepare_table_for_macro(the_retrievers,
[('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')])
# Render the catalogs in a template
return render_template('document/retrievers.html', rows=rows, pagination=pagination)
@document_bp.route('/handle_retriever_selection', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_retriever_selection():
retriever_identification = request.form.get('selected_row')
retriever_id = ast.literal_eval(retriever_identification).get('value')
action = request.form['action']
if action == 'edit_retriever':
return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id))
return redirect(prefixed_url_for('document_bp.retrievers'))
@document_bp.route('/add_document', methods=['GET', 'POST']) @document_bp.route('/add_document', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def add_document(): def add_document():

View File

@@ -0,0 +1,92 @@
from flask_wtf import FlaskForm
from wtforms import IntegerField, FloatField, BooleanField, StringField, validators
from flask import current_app
class DynamicFormBase(FlaskForm):
def __init__(self, formdata=None, *args, **kwargs):
super(DynamicFormBase, self).__init__(*args, **kwargs)
# Maps collection names to lists of field names
self.dynamic_fields = {}
# Store formdata for later use
self.formdata = formdata
def add_dynamic_fields(self, collection_name, config, initial_data=None):
"""Add dynamic fields to the form based on the configuration."""
self.dynamic_fields[collection_name] = []
for field_name, field_def in config.items():
current_app.logger.debug(f"{field_name}: {field_def}")
# Prefix the field name with the collection name
full_field_name = f"{collection_name}_{field_name}"
field_type = field_def.get('type')
description = field_def.get('description', '')
required = field_def.get('required', False)
default = field_def.get('default')
# Determine validators
field_validators = [validators.InputRequired()] if required else [validators.Optional()]
# Map the field type to WTForms field classes
field_class = {
'int': IntegerField,
'float': FloatField,
'boolean': BooleanField,
'string': StringField,
}.get(field_type, StringField)
# Create the field instance
unbound_field = field_class(
label=description,
validators=field_validators,
default=default
)
# Bind the field to the form
bound_field = unbound_field.bind(form=self, name=full_field_name)
# Process the field with formdata
if self.formdata and full_field_name in self.formdata:
# If formdata is available and contains the field
bound_field.process(self.formdata)
elif initial_data and field_name in initial_data:
# Use initial data if provided
bound_field.process(formdata=None, data=initial_data[field_name])
else:
# Use default value
bound_field.process(formdata=None, data=default)
# Set collection name attribute for identification
# bound_field.collection_name = collection_name
# Add the field to the form
setattr(self, full_field_name, bound_field)
self._fields[full_field_name] = bound_field
self.dynamic_fields[collection_name].append(full_field_name)
def get_static_fields(self):
"""Return a list of static field instances."""
# Get names of dynamic fields
dynamic_field_names = set()
for field_list in self.dynamic_fields.values():
dynamic_field_names.update(field_list)
# Return all fields that are not dynamic
return [field for name, field in self._fields.items() if name not in dynamic_field_names]
def get_dynamic_fields(self):
"""Return a dictionary of dynamic fields per collection."""
result = {}
for collection_name, field_names in self.dynamic_fields.items():
result[collection_name] = [getattr(self, name) for name in field_names]
return result
def get_dynamic_data(self, collection_name):
"""Retrieve the data from dynamic fields of a specific collection."""
data = {}
if collection_name not in self.dynamic_fields:
return data
prefix_length = len(collection_name) + 1 # +1 for the underscore
for full_field_name in self.dynamic_fields[collection_name]:
original_field_name = full_field_name[prefix_length:]
field = getattr(self, full_field_name)
data[original_field_name] = field.data
return data

View File

@@ -22,7 +22,7 @@ from common.models.interaction import ChatSession, Interaction, InteractionEmbed
from common.extensions import db from common.extensions import db
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.eveai_retriever import EveAIRetriever from common.langchain.eveai_default_rag_retriever import EveAIDefaultRagRetriever
from common.langchain.eveai_history_retriever import EveAIHistoryRetriever from common.langchain.eveai_history_retriever import EveAIHistoryRetriever
from common.utils.business_event import BusinessEvent from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event from common.utils.business_event_context import current_event
@@ -139,7 +139,7 @@ def answer_using_tenant_rag(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc) new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Generate Answer using RAG"): with current_event.create_span("Generate Answer using RAG"):
retriever = EveAIRetriever(model_variables, tenant_info) retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm'] llm = model_variables['llm']
template = model_variables['rag_template'] template = model_variables['rag_template']
language_template = create_language_template(template, language) language_template = create_language_template(template, language)
@@ -243,7 +243,7 @@ def answer_using_llm(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc) new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Detail Answer using LLM"): with current_event.create_span("Detail Answer using LLM"):
retriever = EveAIRetriever(model_variables, tenant_info) retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm_no_rag'] llm = model_variables['llm_no_rag']
template = model_variables['encyclopedia_template'] template = model_variables['encyclopedia_template']
language_template = create_language_template(template, language) language_template = create_language_template(template, language)

View File

@@ -0,0 +1,56 @@
"""Add Retriever Model
Revision ID: 3717364e6429
Revises: 7b7b566e667f
Create Date: 2024-10-21 14:22:30.258679
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '3717364e6429'
down_revision = '7b7b566e667f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('retriever',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('catalog_id', sa.Integer(), nullable=True),
sa.Column('type', sa.String(length=50), nullable=False),
sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_by', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['catalog_id'], ['catalog.id'], ),
sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.drop_column('catalog', 'es_similarity_threshold')
op.drop_column('catalog', 'es_k')
op.drop_column('catalog', 'rag_tuning')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('catalog', sa.Column('rag_tuning', sa.BOOLEAN(), autoincrement=False, nullable=True))
op.add_column('catalog', sa.Column('es_k', sa.INTEGER(), autoincrement=False, nullable=True))
op.add_column('catalog', sa.Column('es_similarity_threshold', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True))
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.create_foreign_key('catalog_updated_by_fkey', 'catalog', 'user', ['updated_by'], ['id'])
op.create_foreign_key('catalog_created_by_fkey', 'catalog', 'user', ['created_by'], ['id'])
op.drop_table('retriever')
# ### end Alembic commands ###

View File

@@ -0,0 +1,46 @@
"""Extensions for more Catalog Types in Catalog Model
Revision ID: 7b7b566e667f
Revises: 28984b05d396
Create Date: 2024-10-21 07:39:52.260054
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7b7b566e667f'
down_revision = '28984b05d396'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('catalog', sa.Column('parent_id', sa.Integer(), nullable=True))
op.add_column('catalog', sa.Column('type', sa.String(length=50), nullable=False, server_default='DEFAULT'))
# Update all existing rows to have the 'type' set to 'DEFAULT'
op.execute("UPDATE catalog SET type='DEFAULT' WHERE type IS NULL")
# Remove the server default (optional but recommended)
op.alter_column('catalog', 'type', server_default=None)
op.add_column('catalog', sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.add_column('catalog', sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.add_column('catalog', sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.create_foreign_key(None, 'catalog', 'catalog', ['parent_id'], ['id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.drop_column('catalog', 'configuration')
op.drop_column('catalog', 'system_metadata')
op.drop_column('catalog', 'user_metadata')
op.drop_column('catalog', 'type')
op.drop_column('catalog', 'parent_id')
# ### end Alembic commands ###

View File

@@ -82,3 +82,4 @@ prometheus-client~=0.20.0
flower~=2.0.1 flower~=2.0.1
psutil~=6.0.0 psutil~=6.0.0
celery-redbeat~=2.2.0 celery-redbeat~=2.2.0
WTForms-SQLAlchemy~=0.4.1