- Allowing for multiple types of Catalogs

- Introduction of retrievers
- Ensuring processing information is collected from Catalog iso Tenant
- Introduction of a generic Form class to enable dynamic fields based on a configuration
- Realisation of Retriever functionality to support dynamic fields
This commit is contained in:
Josako
2024-10-25 14:11:47 +02:00
parent 30fec27488
commit aa358df28e
19 changed files with 753 additions and 145 deletions

View File

@@ -0,0 +1,145 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
@property
def catalog_id(self) -> int:
return self._catalog_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.model_variables['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding

View File

@@ -1,138 +1,39 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from pydantic import BaseModel, PrivateAttr
from typing import Dict, Any
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIRetriever(BaseRetriever, BaseModel):
_model_variables: ModelVariables = PrivateAttr()
class EveAIRetriever(BaseModel):
_catalog_id: int = PrivateAttr()
_user_metadata: Dict[str, Any] = PrivateAttr()
_system_metadata: Dict[str, Any] = PrivateAttr()
_configuration: Dict[str, Any] = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
configuration: Dict[str, Any]):
super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
self._model_variables = model_variables
self._tenant_info = tenant_info
self._catalog_id = catalog_id
self._user_metadata = user_metadata
self._system_metadata = system_metadata
self._configuration = configuration
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
def catalog_id(self):
return self._catalog_id
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
def user_metadata(self):
return self._user_metadata
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
@property
def system_metadata(self):
return self._system_metadata
if self.model_variables['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
@property
def configuration(self):
return self._configuration
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding
# Any common methods that should be shared among retrievers can go here.

View File

@@ -8,8 +8,10 @@ import sqlalchemy as sa
class Catalog(db.Model):
id = db.Column(db.Integer, primary_key=True)
parent_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG")
# Embedding variables
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
@@ -21,17 +23,36 @@ class Catalog(db.Model):
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
# Embedding search variables ==> move to specialist?
es_k = db.Column(db.Integer, nullable=True, default=8)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4)
# Chat variables ==> Move to Specialist?
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist?
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
configuration = db.Column(JSONB, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
class Retriever(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), nullable=False)
description = db.Column(db.Text, nullable=True)
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG")
# Meta Data
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
configuration = db.Column(JSONB, nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())

11
config/catalog_types.py Normal file
View File

@@ -0,0 +1,11 @@
# Catalog Types
CATALOG_TYPES = {
"DEFAULT": {
"name": "Default Catalog",
"Description": "Default Catalog"
},
"DOSSIER": {
"name": "Dossier Catalog",
"Description": "A Catalog in which several Dossiers can be stored"
},
}

30
config/retriever_types.py Normal file
View File

@@ -0,0 +1,30 @@
# Retriever Types
RETRIEVER_TYPES = {
"DEFAULT_RAG": {
"name": "Default RAG",
"description": "Retrieving all embeddings conform the query",
"configuration": {
"es_k": {
"name": "es_k",
"type": "int",
"description": "K-value to retrieve embeddings (max embeddings retrieved)",
"required": True,
"default": 8,
},
"es_similarity_threshold": {
"name": "es_similarity_threshold",
"type": "float",
"description": "Similarity threshold for retrieving embeddings",
"required": True,
"default": 0.3,
},
"rag_tuning": {
"name": "rag_tuning",
"type": "boolean",
"description": "Whether to do tuning logging or not.",
"required": False,
"default": False,
}
}
}
}

View File

@@ -16,7 +16,7 @@ When you change chunking of embedding information, you'll need to manually refre
{% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %}
<button type="submit" class="btn btn-primary">Register Catalog</button>
<button type="submit" class="btn btn-primary">Save Catalog</button>
</form>
{% endblock %}

View File

@@ -0,0 +1,31 @@
{% extends 'base.html' %}
{% from "macros.html" import render_field2, render_dynamic_fields %}
{% block title %}Edit Retriever{% endblock %}
{% block content_title %}Edit Retriever{% endblock %}
{% block content_description %}Edit a Retriever (for a Catalog){% endblock %}
{% block content %}
<form method="post">
{{ form.hidden_tag() }}
{% set disabled_fields = ['type'] %}
{% set exclude_fields = [] %}
<!-- Render Static Fields -->
{% for field in form.get_static_fields() %}
{{ render_field2(field, disabled_fields, exclude_fields) }}
{% endfor %}
<!-- Render Dynamic Fields -->
{% for collection_name, fields in form.get_dynamic_fields().items() %}
<h4 class="mt-4">{{ collection_name }}</h4>
{% for field in fields %}
{{ render_field2(field, disabled_fields, exclude_fields) }}
{% endfor %}
{% endfor %}
<button type="submit" class="btn btn-primary">Save Retriever</button>
</form>
{% endblock %}
{% block content_footer %}
{% endblock %}

View File

@@ -0,0 +1,23 @@
{% extends 'base.html' %}
{% from "macros.html" import render_field %}
{% block title %}Retriever Registration{% endblock %}
{% block content_title %}Register Retriever{% endblock %}
{% block content_description %}Define a new retriever (for a catalog){% endblock %}
{% block content %}
<form method="post">
{{ form.hidden_tag() }}
{% set disabled_fields = [] %}
{% set exclude_fields = [] %}
{% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %}
<button type="submit" class="btn btn-primary">Register Retriever</button>
</form>
{% endblock %}
{% block content_footer %}
{% endblock %}

View File

@@ -0,0 +1,23 @@
{% extends 'base.html' %}
{% from 'macros.html' import render_selectable_table, render_pagination %}
{% block title %}Retrievers{% endblock %}
{% block content_title %}Retrievers{% endblock %}
{% block content_description %}View Retrieers for Tenant{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto"></div>{% endblock %}
{% block content %}
<div class="container">
<form method="POST" action="{{ url_for('document_bp.handle_retriever_selection') }}">
{{ render_selectable_table(headers=["Retriever ID", "Name", "Type", "Catalog ID"], rows=rows, selectable=True, id="retrieverssTable") }}
<div class="form-group mt-3">
<button type="submit" name="action" value="edit_retriever" class="btn btn-primary">Edit Retriever</button>
</div>
</form>
</div>
{% endblock %}
{% block content_footer %}
{{ render_pagination(pagination, 'document_bp.retrievers') }}
{% endblock %}

View File

@@ -23,6 +23,43 @@
{% endif %}
{% endmacro %}
{% macro render_field2(field, disabled_fields=[], exclude_fields=[], class='') %}
<!-- Debug info -->
<!-- Field name: {{ field.name }}, Field type: {{ field.__class__.__name__ }} -->
{% set disabled = field.name in disabled_fields %}
{% set exclude_fields = exclude_fields + ['csrf_token', 'submit'] %}
{% if field.name not in exclude_fields %}
{% if field.type == 'BooleanField' %}
<div class="form-group">
<div class="form-check form-switch">
{{ field(class="form-check-input " + class, disabled=disabled) }}
{{ field.label(class="form-check-label") }}
</div>
{% if field.errors %}
<div class="invalid-feedback d-block">
{% for error in field.errors %}
{{ error }}
{% endfor %}
</div>
{% endif %}
</div>
{% else %}
<div class="form-group">
{{ field.label(class="form-label") }}
{{ field(class="form-control " + class, disabled=disabled) }}
{% if field.errors %}
<div class="invalid-feedback d-block">
{% for error in field.errors %}
{{ error }}
{% endfor %}
</div>
{% endif %}
</div>
{% endif %}
{% endif %}
{% endmacro %}
{% macro render_included_field(field, disabled_fields=[], include_fields=[]) %}
{% set disabled = field.name in disabled_fields %}
{% if field.name in include_fields %}

View File

@@ -83,6 +83,8 @@
{{ dropdown('Document Mgmt', 'note_stack', [
{'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Retriever', 'url': '/document/retriever', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Retrievers', 'url': '/document/retrievers', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},

View File

@@ -1,4 +1,4 @@
from flask import session, current_app
from flask import session, current_app, request
from flask_wtf import FlaskForm
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
SelectField, FieldList, FormField, TextAreaField, URLField)
@@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr
from flask_wtf.file import FileField, FileAllowed, FileRequired
import json
from wtforms_sqlalchemy.fields import QuerySelectField
from common.models.document import Catalog
from config.catalog_types import CATALOG_TYPES
from config.retriever_types import RETRIEVER_TYPES
from .dynamic_form_base import DynamicFormBase
def allowed_file(form, field):
if field.data:
@@ -26,6 +34,23 @@ def validate_json(form, field):
class CatalogForm(FlaskForm):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
# Parent ID (Optional for root-level catalogs)
parent = QuerySelectField(
'Parent Catalog',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
type = SelectField('Catalog Type', validators=[DataRequired()])
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
# HTML Embedding Variables
html_tags = StringField('HTML Tags', validators=[DataRequired()],
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
@@ -38,19 +63,65 @@ class CatalogForm(FlaskForm):
default=2000)
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
default=3000)
# Embedding Search variables
es_k = IntegerField('Limit for Searching Embeddings (5)',
default=5,
validators=[NumberRange(min=0)])
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5,
validators=[NumberRange(min=0, max=1)])
# Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
# Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()]
class RetrieverForm(FlaskForm):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
# Catalog for the Retriever
catalog = QuerySelectField(
'Catalog ID',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
type = SelectField('Retriever Type', validators=[DataRequired()])
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dynamically populate the 'type' field using the constructor
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
class EditRetrieverForm(DynamicFormBase):
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
description = TextAreaField('Description', validators=[Optional()])
# Catalog for the Retriever
catalog = QuerySelectField(
'Catalog ID',
query_factory=lambda: Catalog.query.all(),
allow_blank=True,
get_label='name',
validators=[Optional()],
)
# Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config)
type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True})
# Metadata fields
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Set the retriever type choices (loaded from config)
self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()]
class AddDocumentForm(FlaskForm):

View File

@@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote
import io
import json
from common.models.document import Document, DocumentVersion, Catalog
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.extensions import db, minio_client
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
process_multiple_urls, get_documents_list, edit_document, \
@@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
EveAIDoubleURLException
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
CatalogForm
CatalogForm, RetrieverForm, EditRetrieverForm
from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for
@@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f
from .document_list_view import DocumentListView
from .document_version_list_view import DocumentVersionListView
from config.retriever_types import RETRIEVER_TYPES
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
@@ -65,6 +67,7 @@ def catalog():
tenant_id = session.get('tenant').get('id')
new_catalog = Catalog()
form.populate_obj(new_catalog)
new_catalog.parent_id = form.parent.data.get('id')
# Handle Embedding Variables
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
@@ -103,7 +106,7 @@ def catalogs():
the_catalogs = pagination.items
# prepare table data
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')])
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')])
# Render the catalogs in a template
return render_template('document/catalogs.html', rows=rows, pagination=pagination)
@@ -173,6 +176,121 @@ def edit_catalog(catalog_id):
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
@document_bp.route('/retriever', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retriever():
form = RetrieverForm()
if form.validate_on_submit():
tenant_id = session.get('tenant').get('id')
new_retriever = Retriever()
form.populate_obj(new_retriever)
new_retriever.catalog_id = form.catalog.data.id
set_logging_information(new_retriever, dt.now(tz.utc))
try:
db.session.add(new_retriever)
db.session.commit()
flash('Retriever successfully added!', 'success')
current_app.logger.info(f'Catalog {new_retriever.name} successfully added for tenant {tenant_id}!')
except SQLAlchemyError as e:
db.session.rollback()
flash(f'Failed to add retriever. Error: {e}', 'danger')
current_app.logger.error(f'Failed to add retriever {new_retriever.name}'
f'for tenant {tenant_id}. Error: {str(e)}')
# Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type)
return redirect(prefixed_url_for('document_bp.retriever', retriever_id=new_retriever.id))
return render_template('document/retriever.html', form=form)
@document_bp.route('/retriever/<int:retriever_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_retriever(retriever_id):
"""Edit an existing retriever configuration."""
current_app.logger.debug(f"Editing Retriever {retriever_id}")
# Get the retriever or return 404
retriever = Retriever.query.get_or_404(retriever_id)
if retriever.catalog_id:
# If catalog_id is just an ID, fetch the Catalog object
retriever.catalog = Catalog.query.get(retriever.catalog_id)
else:
retriever.catalog = None
# Create form instance with the retriever
form = EditRetrieverForm(request.form, obj=retriever)
configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
current_app.logger.debug(f"Configuration {configuration_config}")
form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)
if form.validate_on_submit():
# Update basic fields
form.populate_obj(retriever)
retriever.configuration = form.get_dynamic_data('configuration')
# Update catalog relationship
retriever.catalog_id = form.catalog.data.id if form.catalog.data else None
# Update logging information
update_logging_information(retriever, dt.now(tz.utc))
# Save changes to database
try:
db.session.add(retriever)
db.session.commit()
flash('Retriever updated successfully!', 'success')
current_app.logger.info(f'Retriever {retriever.id} updated successfully')
except SQLAlchemyError as e:
db.session.rollback()
flash(f'Failed to update retriever. Error: {str(e)}', 'danger')
current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}')
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
return redirect(prefixed_url_for('document_bp.retrievers'))
else:
form_validation_failed(request, form)
current_app.logger.debug(f"Rendering Template for {retriever_id}")
return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)
@document_bp.route('/retrievers', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retrievers():
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
query = Retriever.query.order_by(Retriever.id)
pagination = query.paginate(page=page, per_page=per_page)
the_retrievers = pagination.items
# prepare table data
rows = prepare_table_for_macro(the_retrievers,
[('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')])
# Render the catalogs in a template
return render_template('document/retrievers.html', rows=rows, pagination=pagination)
@document_bp.route('/handle_retriever_selection', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_retriever_selection():
retriever_identification = request.form.get('selected_row')
retriever_id = ast.literal_eval(retriever_identification).get('value')
action = request.form['action']
if action == 'edit_retriever':
return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id))
return redirect(prefixed_url_for('document_bp.retrievers'))
@document_bp.route('/add_document', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def add_document():

View File

@@ -0,0 +1,92 @@
from flask_wtf import FlaskForm
from wtforms import IntegerField, FloatField, BooleanField, StringField, validators
from flask import current_app
class DynamicFormBase(FlaskForm):
def __init__(self, formdata=None, *args, **kwargs):
super(DynamicFormBase, self).__init__(*args, **kwargs)
# Maps collection names to lists of field names
self.dynamic_fields = {}
# Store formdata for later use
self.formdata = formdata
def add_dynamic_fields(self, collection_name, config, initial_data=None):
"""Add dynamic fields to the form based on the configuration."""
self.dynamic_fields[collection_name] = []
for field_name, field_def in config.items():
current_app.logger.debug(f"{field_name}: {field_def}")
# Prefix the field name with the collection name
full_field_name = f"{collection_name}_{field_name}"
field_type = field_def.get('type')
description = field_def.get('description', '')
required = field_def.get('required', False)
default = field_def.get('default')
# Determine validators
field_validators = [validators.InputRequired()] if required else [validators.Optional()]
# Map the field type to WTForms field classes
field_class = {
'int': IntegerField,
'float': FloatField,
'boolean': BooleanField,
'string': StringField,
}.get(field_type, StringField)
# Create the field instance
unbound_field = field_class(
label=description,
validators=field_validators,
default=default
)
# Bind the field to the form
bound_field = unbound_field.bind(form=self, name=full_field_name)
# Process the field with formdata
if self.formdata and full_field_name in self.formdata:
# If formdata is available and contains the field
bound_field.process(self.formdata)
elif initial_data and field_name in initial_data:
# Use initial data if provided
bound_field.process(formdata=None, data=initial_data[field_name])
else:
# Use default value
bound_field.process(formdata=None, data=default)
# Set collection name attribute for identification
# bound_field.collection_name = collection_name
# Add the field to the form
setattr(self, full_field_name, bound_field)
self._fields[full_field_name] = bound_field
self.dynamic_fields[collection_name].append(full_field_name)
def get_static_fields(self):
"""Return a list of static field instances."""
# Get names of dynamic fields
dynamic_field_names = set()
for field_list in self.dynamic_fields.values():
dynamic_field_names.update(field_list)
# Return all fields that are not dynamic
return [field for name, field in self._fields.items() if name not in dynamic_field_names]
def get_dynamic_fields(self):
"""Return a dictionary of dynamic fields per collection."""
result = {}
for collection_name, field_names in self.dynamic_fields.items():
result[collection_name] = [getattr(self, name) for name in field_names]
return result
def get_dynamic_data(self, collection_name):
"""Retrieve the data from dynamic fields of a specific collection."""
data = {}
if collection_name not in self.dynamic_fields:
return data
prefix_length = len(collection_name) + 1 # +1 for the underscore
for full_field_name in self.dynamic_fields[collection_name]:
original_field_name = full_field_name[prefix_length:]
field = getattr(self, full_field_name)
data[original_field_name] = field.data
return data

View File

@@ -22,7 +22,7 @@ from common.models.interaction import ChatSession, Interaction, InteractionEmbed
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.eveai_retriever import EveAIRetriever
from common.langchain.eveai_default_rag_retriever import EveAIDefaultRagRetriever
from common.langchain.eveai_history_retriever import EveAIHistoryRetriever
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
@@ -139,7 +139,7 @@ def answer_using_tenant_rag(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Generate Answer using RAG"):
retriever = EveAIRetriever(model_variables, tenant_info)
retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm']
template = model_variables['rag_template']
language_template = create_language_template(template, language)
@@ -243,7 +243,7 @@ def answer_using_llm(question, language, tenant, chat_session):
new_interaction.detailed_question_at = dt.now(tz.utc)
with current_event.create_span("Detail Answer using LLM"):
retriever = EveAIRetriever(model_variables, tenant_info)
retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
llm = model_variables['llm_no_rag']
template = model_variables['encyclopedia_template']
language_template = create_language_template(template, language)

View File

@@ -0,0 +1,56 @@
"""Add Retriever Model
Revision ID: 3717364e6429
Revises: 7b7b566e667f
Create Date: 2024-10-21 14:22:30.258679
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '3717364e6429'
down_revision = '7b7b566e667f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('retriever',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('catalog_id', sa.Integer(), nullable=True),
sa.Column('type', sa.String(length=50), nullable=False),
sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_by', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['catalog_id'], ['catalog.id'], ),
sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.drop_column('catalog', 'es_similarity_threshold')
op.drop_column('catalog', 'es_k')
op.drop_column('catalog', 'rag_tuning')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('catalog', sa.Column('rag_tuning', sa.BOOLEAN(), autoincrement=False, nullable=True))
op.add_column('catalog', sa.Column('es_k', sa.INTEGER(), autoincrement=False, nullable=True))
op.add_column('catalog', sa.Column('es_similarity_threshold', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True))
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.create_foreign_key('catalog_updated_by_fkey', 'catalog', 'user', ['updated_by'], ['id'])
op.create_foreign_key('catalog_created_by_fkey', 'catalog', 'user', ['created_by'], ['id'])
op.drop_table('retriever')
# ### end Alembic commands ###

View File

@@ -0,0 +1,46 @@
"""Extensions for more Catalog Types in Catalog Model
Revision ID: 7b7b566e667f
Revises: 28984b05d396
Create Date: 2024-10-21 07:39:52.260054
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7b7b566e667f'
down_revision = '28984b05d396'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('catalog', sa.Column('parent_id', sa.Integer(), nullable=True))
op.add_column('catalog', sa.Column('type', sa.String(length=50), nullable=False, server_default='DEFAULT'))
# Update all existing rows to have the 'type' set to 'DEFAULT'
op.execute("UPDATE catalog SET type='DEFAULT' WHERE type IS NULL")
# Remove the server default (optional but recommended)
op.alter_column('catalog', 'type', server_default=None)
op.add_column('catalog', sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.add_column('catalog', sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.add_column('catalog', sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.create_foreign_key(None, 'catalog', 'catalog', ['parent_id'], ['id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'catalog', type_='foreignkey')
op.drop_column('catalog', 'configuration')
op.drop_column('catalog', 'system_metadata')
op.drop_column('catalog', 'user_metadata')
op.drop_column('catalog', 'type')
op.drop_column('catalog', 'parent_id')
# ### end Alembic commands ###

View File

@@ -82,3 +82,4 @@ prometheus-client~=0.20.0
flower~=2.0.1
psutil~=6.0.0
celery-redbeat~=2.2.0
WTForms-SQLAlchemy~=0.4.1