diff --git a/common/langchain/eveai_default_rag_retriever.py b/common/langchain/eveai_default_rag_retriever.py new file mode 100644 index 0000000..8d03587 --- /dev/null +++ b/common/langchain/eveai_default_rag_retriever.py @@ -0,0 +1,145 @@ +from langchain_core.retrievers import BaseRetriever +from sqlalchemy import func, and_, or_, desc +from sqlalchemy.exc import SQLAlchemyError +from pydantic import BaseModel, Field, PrivateAttr +from typing import Any, Dict +from flask import current_app + +from common.extensions import db +from common.models.document import Document, DocumentVersion +from common.utils.datetime_utils import get_date_in_timezone +from common.utils.model_utils import ModelVariables + + +class EveAIDefaultRagRetriever(BaseRetriever, BaseModel): + _catalog_id: int = PrivateAttr() + _model_variables: ModelVariables = PrivateAttr() + _tenant_info: Dict[str, Any] = PrivateAttr() + + def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]): + super().__init__() + current_app.logger.debug(f'Model variables type: {type(model_variables)}') + self._catalog_id = catalog_id + self._model_variables = model_variables + self._tenant_info = tenant_info + + @property + def catalog_id(self) -> int: + return self._catalog_id + + @property + def model_variables(self) -> ModelVariables: + return self._model_variables + + @property + def tenant_info(self) -> Dict[str, Any]: + return self._tenant_info + + def _get_relevant_documents(self, query: str): + current_app.logger.debug(f'Retrieving relevant documents for query: {query}') + query_embedding = self._get_query_embedding(query) + current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}') + current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}') + db_class = self.model_variables['embedding_db_model'] + similarity_threshold = self.model_variables['similarity_threshold'] + k = self.model_variables['k'] + + if self.model_variables['rag_tuning']: + try: + current_date = get_date_in_timezone(self.tenant_info['timezone']) + current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n') + + # Debug query to show similarity for all valid documents (without chunk text) + debug_query = ( + db.session.query( + Document.id.label('document_id'), + DocumentVersion.id.label('version_id'), + db_class.id.label('embedding_id'), + (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity') + ) + .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) + .join(Document, DocumentVersion.doc_id == Document.id) + .filter( + or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date), + or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date) + ) + .order_by(desc('similarity')) + ) + + debug_results = debug_query.all() + + current_app.logger.debug("Debug: Similarity for all valid documents:") + for row in debug_results: + current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, " + f"Version ID: {row.version_id}, " + f"Embedding ID: {row.embedding_id}, " + f"Similarity: {row.similarity}") + current_app.rag_tuning_logger.debug(f'---------------------------------------\n') + except SQLAlchemyError as e: + current_app.logger.error(f'Error generating overview: {e}') + db.session.rollback() + + if self.model_variables['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n') + current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n') + current_app.rag_tuning_logger.debug(f'K: {k}\n') + current_app.rag_tuning_logger.debug(f'---------------------------------------\n') + + try: + current_date = get_date_in_timezone(self.tenant_info['timezone']) + # Subquery to find the latest version of each document + subquery = ( + db.session.query( + DocumentVersion.doc_id, + func.max(DocumentVersion.id).label('latest_version_id') + ) + .group_by(DocumentVersion.doc_id) + .subquery() + ) + # Main query to filter embeddings + query_obj = ( + db.session.query(db_class, + (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')) + .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) + .join(Document, DocumentVersion.doc_id == Document.id) + .join(subquery, DocumentVersion.id == subquery.c.latest_version_id) + .filter( + or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date), + or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date), + (1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold, + Document.catalog_id == self._catalog_id + ) + .order_by(desc('similarity')) + .limit(k) + ) + + if self.model_variables['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n') + current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n') + current_app.rag_tuning_logger.debug(f'---------------------------------------\n') + + res = query_obj.all() + + if self.model_variables['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n') + current_app.rag_tuning_logger.debug(f'Data retrieved: \n') + current_app.rag_tuning_logger.debug(f'{res}\n') + current_app.rag_tuning_logger.debug(f'---------------------------------------\n') + + result = [] + for doc in res: + if self.model_variables['rag_tuning']: + current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') + current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') + result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n') + + except SQLAlchemyError as e: + current_app.logger.error(f'Error retrieving relevant documents: {e}') + db.session.rollback() + return [] + return result + + def _get_query_embedding(self, query: str): + embedding_model = self.model_variables['embedding_model'] + query_embedding = embedding_model.embed_query(query) + return query_embedding diff --git a/common/langchain/eveai_retriever.py b/common/langchain/eveai_retriever.py index 1e517f6..d394920 100644 --- a/common/langchain/eveai_retriever.py +++ b/common/langchain/eveai_retriever.py @@ -1,138 +1,39 @@ -from langchain_core.retrievers import BaseRetriever -from sqlalchemy import func, and_, or_, desc -from sqlalchemy.exc import SQLAlchemyError -from pydantic import BaseModel, Field, PrivateAttr -from typing import Any, Dict -from flask import current_app +from pydantic import BaseModel, PrivateAttr +from typing import Dict, Any -from common.extensions import db -from common.models.document import Document, DocumentVersion -from common.utils.datetime_utils import get_date_in_timezone from common.utils.model_utils import ModelVariables -class EveAIRetriever(BaseRetriever, BaseModel): - _model_variables: ModelVariables = PrivateAttr() +class EveAIRetriever(BaseModel): + _catalog_id: int = PrivateAttr() + _user_metadata: Dict[str, Any] = PrivateAttr() + _system_metadata: Dict[str, Any] = PrivateAttr() + _configuration: Dict[str, Any] = PrivateAttr() _tenant_info: Dict[str, Any] = PrivateAttr() + _model_variables: ModelVariables = PrivateAttr() - def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]): + def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any], + configuration: Dict[str, Any]): super().__init__() - current_app.logger.debug(f'Model variables type: {type(model_variables)}') - self._model_variables = model_variables - self._tenant_info = tenant_info + self._catalog_id = catalog_id + self._user_metadata = user_metadata + self._system_metadata = system_metadata + self._configuration = configuration @property - def model_variables(self) -> ModelVariables: - return self._model_variables + def catalog_id(self): + return self._catalog_id @property - def tenant_info(self) -> Dict[str, Any]: - return self._tenant_info + def user_metadata(self): + return self._user_metadata - def _get_relevant_documents(self, query: str): - current_app.logger.debug(f'Retrieving relevant documents for query: {query}') - query_embedding = self._get_query_embedding(query) - current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}') - current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}') - db_class = self.model_variables['embedding_db_model'] - similarity_threshold = self.model_variables['similarity_threshold'] - k = self.model_variables['k'] + @property + def system_metadata(self): + return self._system_metadata - if self.model_variables['rag_tuning']: - try: - current_date = get_date_in_timezone(self.tenant_info['timezone']) - current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n') + @property + def configuration(self): + return self._configuration - # Debug query to show similarity for all valid documents (without chunk text) - debug_query = ( - db.session.query( - Document.id.label('document_id'), - DocumentVersion.id.label('version_id'), - db_class.id.label('embedding_id'), - (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity') - ) - .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) - .join(Document, DocumentVersion.doc_id == Document.id) - .filter( - or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date), - or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date) - ) - .order_by(desc('similarity')) - ) - - debug_results = debug_query.all() - - current_app.logger.debug("Debug: Similarity for all valid documents:") - for row in debug_results: - current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, " - f"Version ID: {row.version_id}, " - f"Embedding ID: {row.embedding_id}, " - f"Similarity: {row.similarity}") - current_app.rag_tuning_logger.debug(f'---------------------------------------\n') - except SQLAlchemyError as e: - current_app.logger.error(f'Error generating overview: {e}') - db.session.rollback() - - if self.model_variables['rag_tuning']: - current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n') - current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n') - current_app.rag_tuning_logger.debug(f'K: {k}\n') - current_app.rag_tuning_logger.debug(f'---------------------------------------\n') - - try: - current_date = get_date_in_timezone(self.tenant_info['timezone']) - # Subquery to find the latest version of each document - subquery = ( - db.session.query( - DocumentVersion.doc_id, - func.max(DocumentVersion.id).label('latest_version_id') - ) - .group_by(DocumentVersion.doc_id) - .subquery() - ) - # Main query to filter embeddings - query_obj = ( - db.session.query(db_class, - (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')) - .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) - .join(Document, DocumentVersion.doc_id == Document.id) - .join(subquery, DocumentVersion.id == subquery.c.latest_version_id) - .filter( - or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date), - or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date), - (1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold - ) - .order_by(desc('similarity')) - .limit(k) - ) - - if self.model_variables['rag_tuning']: - current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n') - current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n') - current_app.rag_tuning_logger.debug(f'---------------------------------------\n') - - res = query_obj.all() - - if self.model_variables['rag_tuning']: - current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n') - current_app.rag_tuning_logger.debug(f'Data retrieved: \n') - current_app.rag_tuning_logger.debug(f'{res}\n') - current_app.rag_tuning_logger.debug(f'---------------------------------------\n') - - result = [] - for doc in res: - if self.model_variables['rag_tuning']: - current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') - current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') - result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n') - - except SQLAlchemyError as e: - current_app.logger.error(f'Error retrieving relevant documents: {e}') - db.session.rollback() - return [] - return result - - def _get_query_embedding(self, query: str): - embedding_model = self.model_variables['embedding_model'] - query_embedding = embedding_model.embed_query(query) - return query_embedding + # Any common methods that should be shared among retrievers can go here. diff --git a/common/models/document.py b/common/models/document.py index 7d08e67..3d9d4c9 100644 --- a/common/models/document.py +++ b/common/models/document.py @@ -8,8 +8,10 @@ import sqlalchemy as sa class Catalog(db.Model): id = db.Column(db.Integer, primary_key=True) + parent_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True) name = db.Column(db.String(50), nullable=False) description = db.Column(db.Text, nullable=True) + type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG") # Embedding variables html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li']) @@ -21,17 +23,36 @@ class Catalog(db.Model): min_chunk_size = db.Column(db.Integer, nullable=True, default=2000) max_chunk_size = db.Column(db.Integer, nullable=True, default=3000) - # Embedding search variables ==> move to specialist? - es_k = db.Column(db.Integer, nullable=True, default=8) - es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4) - # Chat variables ==> Move to Specialist? chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3) chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5) # Tuning enablers embed_tuning = db.Column(db.Boolean, nullable=True, default=False) - rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist? + + # Meta Data + user_metadata = db.Column(JSONB, nullable=True) + system_metadata = db.Column(JSONB, nullable=True) + configuration = db.Column(JSONB, nullable=True) + + # Versioning Information + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) + created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now()) + updated_by = db.Column(db.Integer, db.ForeignKey(User.id)) + + +class Retriever(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(50), nullable=False) + description = db.Column(db.Text, nullable=True) + catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True) + type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG") + + # Meta Data + user_metadata = db.Column(JSONB, nullable=True) + system_metadata = db.Column(JSONB, nullable=True) + configuration = db.Column(JSONB, nullable=True) # Versioning Information created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) diff --git a/config/catalog_types.py b/config/catalog_types.py new file mode 100644 index 0000000..1346bdf --- /dev/null +++ b/config/catalog_types.py @@ -0,0 +1,11 @@ +# Catalog Types +CATALOG_TYPES = { + "DEFAULT": { + "name": "Default Catalog", + "Description": "Default Catalog" + }, + "DOSSIER": { + "name": "Dossier Catalog", + "Description": "A Catalog in which several Dossiers can be stored" + }, +} diff --git a/config/retriever_types.py b/config/retriever_types.py new file mode 100644 index 0000000..84c2539 --- /dev/null +++ b/config/retriever_types.py @@ -0,0 +1,30 @@ +# Retriever Types +RETRIEVER_TYPES = { + "DEFAULT_RAG": { + "name": "Default RAG", + "description": "Retrieving all embeddings conform the query", + "configuration": { + "es_k": { + "name": "es_k", + "type": "int", + "description": "K-value to retrieve embeddings (max embeddings retrieved)", + "required": True, + "default": 8, + }, + "es_similarity_threshold": { + "name": "es_similarity_threshold", + "type": "float", + "description": "Similarity threshold for retrieving embeddings", + "required": True, + "default": 0.3, + }, + "rag_tuning": { + "name": "rag_tuning", + "type": "boolean", + "description": "Whether to do tuning logging or not.", + "required": False, + "default": False, + } + } + } +} diff --git a/eveai_app/templates/document/edit_catalog.html b/eveai_app/templates/document/edit_catalog.html index f205e4c..51a548b 100644 --- a/eveai_app/templates/document/edit_catalog.html +++ b/eveai_app/templates/document/edit_catalog.html @@ -16,7 +16,7 @@ When you change chunking of embedding information, you'll need to manually refre {% for field in form %} {{ render_field(field, disabled_fields, exclude_fields) }} {% endfor %} - + {% endblock %} diff --git a/eveai_app/templates/document/edit_retriever.html b/eveai_app/templates/document/edit_retriever.html new file mode 100644 index 0000000..f8af75c --- /dev/null +++ b/eveai_app/templates/document/edit_retriever.html @@ -0,0 +1,31 @@ +{% extends 'base.html' %} +{% from "macros.html" import render_field2, render_dynamic_fields %} + +{% block title %}Edit Retriever{% endblock %} + +{% block content_title %}Edit Retriever{% endblock %} +{% block content_description %}Edit a Retriever (for a Catalog){% endblock %} + +{% block content %} +
+ {{ form.hidden_tag() }} + {% set disabled_fields = ['type'] %} + {% set exclude_fields = [] %} + + {% for field in form.get_static_fields() %} + {{ render_field2(field, disabled_fields, exclude_fields) }} + {% endfor %} + + {% for collection_name, fields in form.get_dynamic_fields().items() %} +

{{ collection_name }}

+ {% for field in fields %} + {{ render_field2(field, disabled_fields, exclude_fields) }} + {% endfor %} + {% endfor %} + +
+{% endblock %} + +{% block content_footer %} + +{% endblock %} diff --git a/eveai_app/templates/document/retriever.html b/eveai_app/templates/document/retriever.html new file mode 100644 index 0000000..9a90349 --- /dev/null +++ b/eveai_app/templates/document/retriever.html @@ -0,0 +1,23 @@ +{% extends 'base.html' %} +{% from "macros.html" import render_field %} + +{% block title %}Retriever Registration{% endblock %} + +{% block content_title %}Register Retriever{% endblock %} +{% block content_description %}Define a new retriever (for a catalog){% endblock %} + +{% block content %} +
+ {{ form.hidden_tag() }} + {% set disabled_fields = [] %} + {% set exclude_fields = [] %} + {% for field in form %} + {{ render_field(field, disabled_fields, exclude_fields) }} + {% endfor %} + +
+{% endblock %} + +{% block content_footer %} + +{% endblock %} diff --git a/eveai_app/templates/document/retrievers.html b/eveai_app/templates/document/retrievers.html new file mode 100644 index 0000000..cbc44b0 --- /dev/null +++ b/eveai_app/templates/document/retrievers.html @@ -0,0 +1,23 @@ +{% extends 'base.html' %} +{% from 'macros.html' import render_selectable_table, render_pagination %} + +{% block title %}Retrievers{% endblock %} + +{% block content_title %}Retrievers{% endblock %} +{% block content_description %}View Retrieers for Tenant{% endblock %} +{% block content_class %}
{% endblock %} + +{% block content %} +
+
+ {{ render_selectable_table(headers=["Retriever ID", "Name", "Type", "Catalog ID"], rows=rows, selectable=True, id="retrieverssTable") }} +
+ +
+
+
+{% endblock %} + +{% block content_footer %} + {{ render_pagination(pagination, 'document_bp.retrievers') }} +{% endblock %} \ No newline at end of file diff --git a/eveai_app/templates/macros.html b/eveai_app/templates/macros.html index 94e3896..e75ff28 100644 --- a/eveai_app/templates/macros.html +++ b/eveai_app/templates/macros.html @@ -23,6 +23,43 @@ {% endif %} {% endmacro %} +{% macro render_field2(field, disabled_fields=[], exclude_fields=[], class='') %} + + + + {% set disabled = field.name in disabled_fields %} + {% set exclude_fields = exclude_fields + ['csrf_token', 'submit'] %} + {% if field.name not in exclude_fields %} + {% if field.type == 'BooleanField' %} +
+
+ {{ field(class="form-check-input " + class, disabled=disabled) }} + {{ field.label(class="form-check-label") }} +
+ {% if field.errors %} +
+ {% for error in field.errors %} + {{ error }} + {% endfor %} +
+ {% endif %} +
+ {% else %} +
+ {{ field.label(class="form-label") }} + {{ field(class="form-control " + class, disabled=disabled) }} + {% if field.errors %} +
+ {% for error in field.errors %} + {{ error }} + {% endfor %} +
+ {% endif %} +
+ {% endif %} + {% endif %} +{% endmacro %} + {% macro render_included_field(field, disabled_fields=[], include_fields=[]) %} {% set disabled = field.name in disabled_fields %} {% if field.name in include_fields %} diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html index bb11f33..16dbd1a 100644 --- a/eveai_app/templates/navbar.html +++ b/eveai_app/templates/navbar.html @@ -83,6 +83,8 @@ {{ dropdown('Document Mgmt', 'note_stack', [ {'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']}, + {'name': 'Add Retriever', 'url': '/document/retriever', 'roles': ['Super User', 'Tenant Admin']}, + {'name': 'All Retrievers', 'url': '/document/retrievers', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']}, diff --git a/eveai_app/views/document_forms.py b/eveai_app/views/document_forms.py index 87427e1..b035110 100644 --- a/eveai_app/views/document_forms.py +++ b/eveai_app/views/document_forms.py @@ -1,4 +1,4 @@ -from flask import session, current_app +from flask import session, current_app, request from flask_wtf import FlaskForm from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField, SelectField, FieldList, FormField, TextAreaField, URLField) @@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr from flask_wtf.file import FileField, FileAllowed, FileRequired import json +from wtforms_sqlalchemy.fields import QuerySelectField + +from common.models.document import Catalog + +from config.catalog_types import CATALOG_TYPES +from config.retriever_types import RETRIEVER_TYPES +from .dynamic_form_base import DynamicFormBase + def allowed_file(form, field): if field.data: @@ -26,6 +34,23 @@ def validate_json(form, field): class CatalogForm(FlaskForm): name = StringField('Name', validators=[DataRequired(), Length(max=50)]) description = TextAreaField('Description', validators=[Optional()]) + # Parent ID (Optional for root-level catalogs) + parent = QuerySelectField( + 'Parent Catalog', + query_factory=lambda: Catalog.query.all(), + allow_blank=True, + get_label='name', + validators=[Optional()], + ) + + # Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config) + type = SelectField('Catalog Type', validators=[DataRequired()]) + + # Metadata fields + user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json]) + system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json]) + configuration = TextAreaField('Configuration', validators=[Optional(), validate_json]) + # HTML Embedding Variables html_tags = StringField('HTML Tags', validators=[DataRequired()], default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td') @@ -38,19 +63,65 @@ class CatalogForm(FlaskForm): default=2000) max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], default=3000) - # Embedding Search variables - es_k = IntegerField('Limit for Searching Embeddings (5)', - default=5, - validators=[NumberRange(min=0)]) - es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)', - default=0.5, - validators=[NumberRange(min=0, max=1)]) # Chat Variables chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)]) chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)]) # Tuning variables embed_tuning = BooleanField('Enable Embedding Tuning', default=False) - rag_tuning = BooleanField('Enable RAG Tuning', default=False) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Dynamically populate the 'type' field using the constructor + self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()] + + +class RetrieverForm(FlaskForm): + name = StringField('Name', validators=[DataRequired(), Length(max=50)]) + description = TextAreaField('Description', validators=[Optional()]) + # Catalog for the Retriever + catalog = QuerySelectField( + 'Catalog ID', + query_factory=lambda: Catalog.query.all(), + allow_blank=True, + get_label='name', + validators=[Optional()], + ) + # Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config) + type = SelectField('Retriever Type', validators=[DataRequired()]) + + # Metadata fields + user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json]) + system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json]) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Dynamically populate the 'type' field using the constructor + self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()] + + +class EditRetrieverForm(DynamicFormBase): + name = StringField('Name', validators=[DataRequired(), Length(max=50)]) + description = TextAreaField('Description', validators=[Optional()]) + # Catalog for the Retriever + catalog = QuerySelectField( + 'Catalog ID', + query_factory=lambda: Catalog.query.all(), + allow_blank=True, + get_label='name', + validators=[Optional()], + ) + # Select Field for Retriever Type (Uses the RETRIEVER_TYPES defined in config) + type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True}) + + # Metadata fields + user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json]) + system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json]) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Set the retriever type choices (loaded from config) + self.type.choices = [(key, value['name']) for key, value in RETRIEVER_TYPES.items()] class AddDocumentForm(FlaskForm): diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py index f8f2a4b..260163f 100644 --- a/eveai_app/views/document_views.py +++ b/eveai_app/views/document_views.py @@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote import io import json -from common.models.document import Document, DocumentVersion, Catalog +from common.models.document import Document, DocumentVersion, Catalog, Retriever from common.extensions import db, minio_client from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \ process_multiple_urls, get_documents_list, edit_document, \ @@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \ EveAIDoubleURLException from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \ - CatalogForm + CatalogForm, RetrieverForm, EditRetrieverForm from common.utils.middleware import mw_before_request from common.utils.celery_utils import current_celery from common.utils.nginx_utils import prefixed_url_for @@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f from .document_list_view import DocumentListView from .document_version_list_view import DocumentVersionListView +from config.retriever_types import RETRIEVER_TYPES + document_bp = Blueprint('document_bp', __name__, url_prefix='/document') @@ -65,6 +67,7 @@ def catalog(): tenant_id = session.get('tenant').get('id') new_catalog = Catalog() form.populate_obj(new_catalog) + new_catalog.parent_id = form.parent.data.get('id') # Handle Embedding Variables new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else [] new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \ @@ -103,7 +106,7 @@ def catalogs(): the_catalogs = pagination.items # prepare table data - rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')]) + rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')]) # Render the catalogs in a template return render_template('document/catalogs.html', rows=rows, pagination=pagination) @@ -173,6 +176,121 @@ def edit_catalog(catalog_id): return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id) +@document_bp.route('/retriever', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def retriever(): + form = RetrieverForm() + + if form.validate_on_submit(): + tenant_id = session.get('tenant').get('id') + new_retriever = Retriever() + form.populate_obj(new_retriever) + new_retriever.catalog_id = form.catalog.data.id + + set_logging_information(new_retriever, dt.now(tz.utc)) + + try: + db.session.add(new_retriever) + db.session.commit() + flash('Retriever successfully added!', 'success') + current_app.logger.info(f'Catalog {new_retriever.name} successfully added for tenant {tenant_id}!') + except SQLAlchemyError as e: + db.session.rollback() + flash(f'Failed to add retriever. Error: {e}', 'danger') + current_app.logger.error(f'Failed to add retriever {new_retriever.name}' + f'for tenant {tenant_id}. Error: {str(e)}') + + # Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type) + return redirect(prefixed_url_for('document_bp.retriever', retriever_id=new_retriever.id)) + + return render_template('document/retriever.html', form=form) + + +@document_bp.route('/retriever/', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def edit_retriever(retriever_id): + """Edit an existing retriever configuration.""" + current_app.logger.debug(f"Editing Retriever {retriever_id}") + + # Get the retriever or return 404 + retriever = Retriever.query.get_or_404(retriever_id) + + if retriever.catalog_id: + # If catalog_id is just an ID, fetch the Catalog object + retriever.catalog = Catalog.query.get(retriever.catalog_id) + else: + retriever.catalog = None + + # Create form instance with the retriever + form = EditRetrieverForm(request.form, obj=retriever) + + configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"] + current_app.logger.debug(f"Configuration {configuration_config}") + form.add_dynamic_fields("configuration", configuration_config, retriever.configuration) + + if form.validate_on_submit(): + # Update basic fields + form.populate_obj(retriever) + retriever.configuration = form.get_dynamic_data('configuration') + + # Update catalog relationship + retriever.catalog_id = form.catalog.data.id if form.catalog.data else None + + # Update logging information + update_logging_information(retriever, dt.now(tz.utc)) + + # Save changes to database + try: + db.session.add(retriever) + db.session.commit() + flash('Retriever updated successfully!', 'success') + current_app.logger.info(f'Retriever {retriever.id} updated successfully') + except SQLAlchemyError as e: + db.session.rollback() + flash(f'Failed to update retriever. Error: {str(e)}', 'danger') + current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}') + return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id) + + return redirect(prefixed_url_for('document_bp.retrievers')) + else: + form_validation_failed(request, form) + + current_app.logger.debug(f"Rendering Template for {retriever_id}") + return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id) + + +@document_bp.route('/retrievers', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def retrievers(): + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 10, type=int) + + query = Retriever.query.order_by(Retriever.id) + + pagination = query.paginate(page=page, per_page=per_page) + the_retrievers = pagination.items + + # prepare table data + rows = prepare_table_for_macro(the_retrievers, + [('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')]) + + # Render the catalogs in a template + return render_template('document/retrievers.html', rows=rows, pagination=pagination) + + +@document_bp.route('/handle_retriever_selection', methods=['POST']) +@roles_accepted('Super User', 'Tenant Admin') +def handle_retriever_selection(): + retriever_identification = request.form.get('selected_row') + retriever_id = ast.literal_eval(retriever_identification).get('value') + action = request.form['action'] + + if action == 'edit_retriever': + return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id)) + + return redirect(prefixed_url_for('document_bp.retrievers')) + + @document_bp.route('/add_document', methods=['GET', 'POST']) @roles_accepted('Super User', 'Tenant Admin') def add_document(): diff --git a/eveai_app/views/dynamic_form_base.py b/eveai_app/views/dynamic_form_base.py new file mode 100644 index 0000000..46ba100 --- /dev/null +++ b/eveai_app/views/dynamic_form_base.py @@ -0,0 +1,92 @@ +from flask_wtf import FlaskForm +from wtforms import IntegerField, FloatField, BooleanField, StringField, validators +from flask import current_app + +class DynamicFormBase(FlaskForm): + def __init__(self, formdata=None, *args, **kwargs): + super(DynamicFormBase, self).__init__(*args, **kwargs) + # Maps collection names to lists of field names + self.dynamic_fields = {} + # Store formdata for later use + self.formdata = formdata + + def add_dynamic_fields(self, collection_name, config, initial_data=None): + """Add dynamic fields to the form based on the configuration.""" + self.dynamic_fields[collection_name] = [] + for field_name, field_def in config.items(): + current_app.logger.debug(f"{field_name}: {field_def}") + # Prefix the field name with the collection name + full_field_name = f"{collection_name}_{field_name}" + field_type = field_def.get('type') + description = field_def.get('description', '') + required = field_def.get('required', False) + default = field_def.get('default') + + # Determine validators + field_validators = [validators.InputRequired()] if required else [validators.Optional()] + + # Map the field type to WTForms field classes + field_class = { + 'int': IntegerField, + 'float': FloatField, + 'boolean': BooleanField, + 'string': StringField, + }.get(field_type, StringField) + + # Create the field instance + unbound_field = field_class( + label=description, + validators=field_validators, + default=default + ) + + # Bind the field to the form + bound_field = unbound_field.bind(form=self, name=full_field_name) + + # Process the field with formdata + if self.formdata and full_field_name in self.formdata: + # If formdata is available and contains the field + bound_field.process(self.formdata) + elif initial_data and field_name in initial_data: + # Use initial data if provided + bound_field.process(formdata=None, data=initial_data[field_name]) + else: + # Use default value + bound_field.process(formdata=None, data=default) + + # Set collection name attribute for identification + # bound_field.collection_name = collection_name + + # Add the field to the form + setattr(self, full_field_name, bound_field) + self._fields[full_field_name] = bound_field + self.dynamic_fields[collection_name].append(full_field_name) + + def get_static_fields(self): + """Return a list of static field instances.""" + # Get names of dynamic fields + dynamic_field_names = set() + for field_list in self.dynamic_fields.values(): + dynamic_field_names.update(field_list) + + # Return all fields that are not dynamic + return [field for name, field in self._fields.items() if name not in dynamic_field_names] + + def get_dynamic_fields(self): + """Return a dictionary of dynamic fields per collection.""" + result = {} + for collection_name, field_names in self.dynamic_fields.items(): + result[collection_name] = [getattr(self, name) for name in field_names] + return result + + def get_dynamic_data(self, collection_name): + """Retrieve the data from dynamic fields of a specific collection.""" + data = {} + if collection_name not in self.dynamic_fields: + return data + prefix_length = len(collection_name) + 1 # +1 for the underscore + for full_field_name in self.dynamic_fields[collection_name]: + original_field_name = full_field_name[prefix_length:] + field = getattr(self, full_field_name) + data[original_field_name] = field.data + return data diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py index 90f6641..9c21e8b 100644 --- a/eveai_chat_workers/tasks.py +++ b/eveai_chat_workers/tasks.py @@ -22,7 +22,7 @@ from common.models.interaction import ChatSession, Interaction, InteractionEmbed from common.extensions import db from common.utils.celery_utils import current_celery from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template -from common.langchain.eveai_retriever import EveAIRetriever +from common.langchain.eveai_default_rag_retriever import EveAIDefaultRagRetriever from common.langchain.eveai_history_retriever import EveAIHistoryRetriever from common.utils.business_event import BusinessEvent from common.utils.business_event_context import current_event @@ -139,7 +139,7 @@ def answer_using_tenant_rag(question, language, tenant, chat_session): new_interaction.detailed_question_at = dt.now(tz.utc) with current_event.create_span("Generate Answer using RAG"): - retriever = EveAIRetriever(model_variables, tenant_info) + retriever = EveAIDefaultRagRetriever(model_variables, tenant_info) llm = model_variables['llm'] template = model_variables['rag_template'] language_template = create_language_template(template, language) @@ -243,7 +243,7 @@ def answer_using_llm(question, language, tenant, chat_session): new_interaction.detailed_question_at = dt.now(tz.utc) with current_event.create_span("Detail Answer using LLM"): - retriever = EveAIRetriever(model_variables, tenant_info) + retriever = EveAIDefaultRagRetriever(model_variables, tenant_info) llm = model_variables['llm_no_rag'] template = model_variables['encyclopedia_template'] language_template = create_language_template(template, language) diff --git a/integrations/Wordpress/eveai-chat-widget.zip b/integrations/Wordpress/eveai-chat-widget.zip deleted file mode 100644 index 709bc0d..0000000 Binary files a/integrations/Wordpress/eveai-chat-widget.zip and /dev/null differ diff --git a/migrations/tenant/versions/3717364e6429_add_retriever_model.py b/migrations/tenant/versions/3717364e6429_add_retriever_model.py new file mode 100644 index 0000000..74cd51d --- /dev/null +++ b/migrations/tenant/versions/3717364e6429_add_retriever_model.py @@ -0,0 +1,56 @@ +"""Add Retriever Model + +Revision ID: 3717364e6429 +Revises: 7b7b566e667f +Create Date: 2024-10-21 14:22:30.258679 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '3717364e6429' +down_revision = '7b7b566e667f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('retriever', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=50), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('catalog_id', sa.Integer(), nullable=True), + sa.Column('type', sa.String(length=50), nullable=False), + sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['catalog_id'], ['catalog.id'], ), + sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ), + sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.drop_column('catalog', 'es_similarity_threshold') + op.drop_column('catalog', 'es_k') + op.drop_column('catalog', 'rag_tuning') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('catalog', sa.Column('rag_tuning', sa.BOOLEAN(), autoincrement=False, nullable=True)) + op.add_column('catalog', sa.Column('es_k', sa.INTEGER(), autoincrement=False, nullable=True)) + op.add_column('catalog', sa.Column('es_similarity_threshold', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True)) + op.drop_constraint(None, 'catalog', type_='foreignkey') + op.drop_constraint(None, 'catalog', type_='foreignkey') + op.create_foreign_key('catalog_updated_by_fkey', 'catalog', 'user', ['updated_by'], ['id']) + op.create_foreign_key('catalog_created_by_fkey', 'catalog', 'user', ['created_by'], ['id']) + op.drop_table('retriever') + # ### end Alembic commands ### diff --git a/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py b/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py new file mode 100644 index 0000000..b348b40 --- /dev/null +++ b/migrations/tenant/versions/7b7b566e667f_extensions_for_more_catalog_types_in_.py @@ -0,0 +1,46 @@ +"""Extensions for more Catalog Types in Catalog Model + +Revision ID: 7b7b566e667f +Revises: 28984b05d396 +Create Date: 2024-10-21 07:39:52.260054 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '7b7b566e667f' +down_revision = '28984b05d396' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('catalog', sa.Column('parent_id', sa.Integer(), nullable=True)) + + op.add_column('catalog', sa.Column('type', sa.String(length=50), nullable=False, server_default='DEFAULT')) + # Update all existing rows to have the 'type' set to 'DEFAULT' + op.execute("UPDATE catalog SET type='DEFAULT' WHERE type IS NULL") + # Remove the server default (optional but recommended) + op.alter_column('catalog', 'type', server_default=None) + + op.add_column('catalog', sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + op.add_column('catalog', sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + op.add_column('catalog', sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + op.create_foreign_key(None, 'catalog', 'catalog', ['parent_id'], ['id']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'catalog', type_='foreignkey') + op.drop_column('catalog', 'configuration') + op.drop_column('catalog', 'system_metadata') + op.drop_column('catalog', 'user_metadata') + op.drop_column('catalog', 'type') + op.drop_column('catalog', 'parent_id') + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 2f57f6f..be8e3db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -82,3 +82,4 @@ prometheus-client~=0.20.0 flower~=2.0.1 psutil~=6.0.0 celery-redbeat~=2.2.0 +WTForms-SQLAlchemy~=0.4.1