- Allowing for multiple types of Catalogs
- Introduction of retrievers
- Ensuring processing information is collected from Catalog instead of Tenant
- Introduction of a generic Form class to enable dynamic fields based on a configuration
- Realisation of Retriever functionality to support dynamic fields
This commit is contained in:
145
common/langchain/eveai_default_rag_retriever.py
Normal file
145
common/langchain/eveai_default_rag_retriever.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from sqlalchemy import func, and_, or_, desc
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
from typing import Any, Dict
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from common.extensions import db
|
||||||
|
from common.models.document import Document, DocumentVersion
|
||||||
|
from common.utils.datetime_utils import get_date_in_timezone
|
||||||
|
from common.utils.model_utils import ModelVariables
|
||||||
|
|
||||||
|
|
||||||
|
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
    """Default RAG retriever: similarity search over the embeddings of one catalog.

    Only the latest version of each document is searched, and only documents
    whose validity window (``valid_from`` / ``valid_to``) contains the current
    date in the tenant's timezone. Retrieval settings are read from
    ``model_variables`` (keys ``embedding_model``, ``embedding_db_model``,
    ``k``, ``similarity_threshold`` and ``rag_tuning``); the timezone is read
    from ``tenant_info['timezone']``.
    """

    _catalog_id: int = PrivateAttr()
    _model_variables: ModelVariables = PrivateAttr()
    _tenant_info: Dict[str, Any] = PrivateAttr()

    def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
        """Store the catalog to search and the settings used for retrieval.

        Args:
            catalog_id: Id of the catalog whose documents are searched.
            model_variables: Mapping with the embedding model, the embedding
                table class and the retrieval parameters.
            tenant_info: Tenant data; only the ``'timezone'`` key is read here.
        """
        super().__init__()
        current_app.logger.debug(f'Model variables type: {type(model_variables)}')
        self._catalog_id = catalog_id
        self._model_variables = model_variables
        self._tenant_info = tenant_info

    @property
    def catalog_id(self) -> int:
        """Id of the catalog this retriever searches."""
        return self._catalog_id

    @property
    def model_variables(self) -> ModelVariables:
        """Model variables supplied at construction time."""
        return self._model_variables

    @property
    def tenant_info(self) -> Dict[str, Any]:
        """Tenant information supplied at construction time."""
        return self._tenant_info

    def _document_filters(self, current_date):
        """Return the criteria every query shares: validity window and catalog."""
        return (
            or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
            or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
            Document.catalog_id == self._catalog_id,
        )

    def _log_similarity_overview(self, db_class, query_embedding, current_date):
        """Write the similarity of every valid embedding in this catalog to the tuning log.

        Best effort: a database error is logged and rolled back so that the
        actual retrieval can still run.
        """
        tuning_logger = current_app.rag_tuning_logger
        try:
            tuning_logger.debug(f'Current date: {current_date}\n')

            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)
            # Overview of all valid documents (without chunk text). Restricted to
            # this retriever's catalog so it matches what the retrieval can return.
            overview_query = (
                db.session.query(
                    Document.id.label('document_id'),
                    DocumentVersion.id.label('version_id'),
                    db_class.id.label('embedding_id'),
                    similarity.label('similarity'),
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .filter(*self._document_filters(current_date))
                .order_by(desc('similarity'))
            )
            overview_rows = overview_query.all()

            # Header and rows go to the same (tuning) log so they stay together.
            tuning_logger.debug("Debug: Similarity for all valid documents:")
            for row in overview_rows:
                tuning_logger.debug(f"Doc ID: {row.document_id}, "
                                    f"Version ID: {row.version_id}, "
                                    f"Embedding ID: {row.embedding_id}, "
                                    f"Similarity: {row.similarity}")
            tuning_logger.debug('---------------------------------------\n')
        except SQLAlchemyError as e:
            current_app.logger.error(f'Error generating overview: {e}')
            db.session.rollback()

    def _get_relevant_documents(self, query: str):
        """Return the chunks most similar to ``query`` within this catalog.

        Args:
            query: The text to search for.

        Returns:
            A list of strings of the form ``'SOURCE: <embedding id>\\n\\n<chunk>\\n\\n'``,
            ordered by descending similarity and limited to ``k`` entries whose
            similarity exceeds ``similarity_threshold``. An empty list is
            returned when the database query fails.
        """
        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
        current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']
        rag_tuning = self.model_variables['rag_tuning']

        # Computed once and shared by the tuning overview and the retrieval query.
        current_date = get_date_in_timezone(self.tenant_info['timezone'])

        if rag_tuning:
            self._log_similarity_overview(db_class, query_embedding, current_date)

            tuning_logger = current_app.rag_tuning_logger
            tuning_logger.debug('Parameters for Retrieval of documents: \n')
            tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
            tuning_logger.debug(f'K: {k}\n')
            tuning_logger.debug('---------------------------------------\n')

        try:
            # Subquery to find the latest version of each document
            latest_versions = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id'),
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Main query to filter embeddings
            similarity = 1 - db_class.embedding.cosine_distance(query_embedding)
            query_obj = (
                db.session.query(db_class, similarity.label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(latest_versions, DocumentVersion.id == latest_versions.c.latest_version_id)
                .filter(
                    *self._document_filters(current_date),
                    similarity > similarity_threshold,
                )
                .order_by(desc('similarity'))
                .limit(k)
            )

            if rag_tuning:
                tuning_logger = current_app.rag_tuning_logger
                tuning_logger.debug('Query executed for Retrieval of documents: \n')
                tuning_logger.debug(f'{query_obj.statement}\n')
                tuning_logger.debug('---------------------------------------\n')

            res = query_obj.all()

            if rag_tuning:
                tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
                tuning_logger.debug('Data retrieved: \n')
                tuning_logger.debug(f'{res}\n')
                tuning_logger.debug('---------------------------------------\n')

            result = []
            for embedding, score in res:
                if rag_tuning:
                    # The second column is the similarity (1 - cosine distance),
                    # and the id is that of the embedding row.
                    tuning_logger.debug(f'Embedding ID: {embedding.id} - Similarity: {score}\n')
                    tuning_logger.debug(f'Chunk: \n {embedding.chunk}\n\n')
                result.append(f'SOURCE: {embedding.id}\n\n{embedding.chunk}\n\n')

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []
        return result

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with the configured embedding model and return the vector."""
        embedding_model = self.model_variables['embedding_model']
        query_embedding = embedding_model.embed_query(query)
        return query_embedding
|
||||||
@@ -1,138 +1,39 @@
|
|||||||
from langchain_core.retrievers import BaseRetriever
|
from pydantic import BaseModel, PrivateAttr
|
||||||
from sqlalchemy import func, and_, or_, desc
|
from typing import Dict, Any
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
|
||||||
from typing import Any, Dict
|
|
||||||
from flask import current_app
|
|
||||||
|
|
||||||
from common.extensions import db
|
|
||||||
from common.models.document import Document, DocumentVersion
|
|
||||||
from common.utils.datetime_utils import get_date_in_timezone
|
|
||||||
from common.utils.model_utils import ModelVariables
|
from common.utils.model_utils import ModelVariables
|
||||||
|
|
||||||
|
|
||||||
class EveAIRetriever(BaseRetriever, BaseModel):
|
class EveAIRetriever(BaseModel):
|
||||||
_model_variables: ModelVariables = PrivateAttr()
|
_catalog_id: int = PrivateAttr()
|
||||||
|
_user_metadata: Dict[str, Any] = PrivateAttr()
|
||||||
|
_system_metadata: Dict[str, Any] = PrivateAttr()
|
||||||
|
_configuration: Dict[str, Any] = PrivateAttr()
|
||||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||||
|
_model_variables: ModelVariables = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
|
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
|
||||||
|
configuration: Dict[str, Any]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
|
self._catalog_id = catalog_id
|
||||||
self._model_variables = model_variables
|
self._user_metadata = user_metadata
|
||||||
self._tenant_info = tenant_info
|
self._system_metadata = system_metadata
|
||||||
|
self._configuration = configuration
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_variables(self) -> ModelVariables:
|
def catalog_id(self):
|
||||||
return self._model_variables
|
return self._catalog_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tenant_info(self) -> Dict[str, Any]:
|
def user_metadata(self):
|
||||||
return self._tenant_info
|
return self._user_metadata
|
||||||
|
|
||||||
def _get_relevant_documents(self, query: str):
|
@property
|
||||||
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
|
def system_metadata(self):
|
||||||
query_embedding = self._get_query_embedding(query)
|
return self._system_metadata
|
||||||
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
|
|
||||||
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
|
|
||||||
db_class = self.model_variables['embedding_db_model']
|
|
||||||
similarity_threshold = self.model_variables['similarity_threshold']
|
|
||||||
k = self.model_variables['k']
|
|
||||||
|
|
||||||
if self.model_variables['rag_tuning']:
|
@property
|
||||||
try:
|
def configuration(self):
|
||||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
return self._configuration
|
||||||
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
|
|
||||||
|
|
||||||
# Debug query to show similarity for all valid documents (without chunk text)
|
# Any common methods that should be shared among retrievers can go here.
|
||||||
debug_query = (
|
|
||||||
db.session.query(
|
|
||||||
Document.id.label('document_id'),
|
|
||||||
DocumentVersion.id.label('version_id'),
|
|
||||||
db_class.id.label('embedding_id'),
|
|
||||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
|
|
||||||
)
|
|
||||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
|
||||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
|
||||||
.filter(
|
|
||||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
|
||||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
|
|
||||||
)
|
|
||||||
.order_by(desc('similarity'))
|
|
||||||
)
|
|
||||||
|
|
||||||
debug_results = debug_query.all()
|
|
||||||
|
|
||||||
current_app.logger.debug("Debug: Similarity for all valid documents:")
|
|
||||||
for row in debug_results:
|
|
||||||
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
|
|
||||||
f"Version ID: {row.version_id}, "
|
|
||||||
f"Embedding ID: {row.embedding_id}, "
|
|
||||||
f"Similarity: {row.similarity}")
|
|
||||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
current_app.logger.error(f'Error generating overview: {e}')
|
|
||||||
db.session.rollback()
|
|
||||||
|
|
||||||
if self.model_variables['rag_tuning']:
|
|
||||||
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'K: {k}\n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
|
||||||
# Subquery to find the latest version of each document
|
|
||||||
subquery = (
|
|
||||||
db.session.query(
|
|
||||||
DocumentVersion.doc_id,
|
|
||||||
func.max(DocumentVersion.id).label('latest_version_id')
|
|
||||||
)
|
|
||||||
.group_by(DocumentVersion.doc_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
# Main query to filter embeddings
|
|
||||||
query_obj = (
|
|
||||||
db.session.query(db_class,
|
|
||||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
|
|
||||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
|
||||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
|
||||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
|
||||||
.filter(
|
|
||||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
|
||||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
|
||||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
|
|
||||||
)
|
|
||||||
.order_by(desc('similarity'))
|
|
||||||
.limit(k)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.model_variables['rag_tuning']:
|
|
||||||
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
|
||||||
|
|
||||||
res = query_obj.all()
|
|
||||||
|
|
||||||
if self.model_variables['rag_tuning']:
|
|
||||||
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'{res}\n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for doc in res:
|
|
||||||
if self.model_variables['rag_tuning']:
|
|
||||||
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
|
||||||
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
|
||||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
current_app.logger.error(f'Error retrieving relevant documents: {e}')
|
|
||||||
db.session.rollback()
|
|
||||||
return []
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _get_query_embedding(self, query: str):
|
|
||||||
embedding_model = self.model_variables['embedding_model']
|
|
||||||
query_embedding = embedding_model.embed_query(query)
|
|
||||||
return query_embedding
|
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
class Catalog(db.Model):
|
class Catalog(db.Model):
|
||||||
id = db.Column(db.Integer, primary_key=True)
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
parent_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
|
||||||
name = db.Column(db.String(50), nullable=False)
|
name = db.Column(db.String(50), nullable=False)
|
||||||
description = db.Column(db.Text, nullable=True)
|
description = db.Column(db.Text, nullable=True)
|
||||||
|
type = db.Column(db.String(50), nullable=False, default="DEFAULT_CATALOG")
|
||||||
|
|
||||||
# Embedding variables
|
# Embedding variables
|
||||||
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
|
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
|
||||||
@@ -21,17 +23,36 @@ class Catalog(db.Model):
|
|||||||
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
|
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
|
||||||
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
|
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
|
||||||
|
|
||||||
# Embedding search variables ==> move to specialist?
|
|
||||||
es_k = db.Column(db.Integer, nullable=True, default=8)
|
|
||||||
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4)
|
|
||||||
|
|
||||||
# Chat variables ==> Move to Specialist?
|
# Chat variables ==> Move to Specialist?
|
||||||
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
||||||
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
||||||
|
|
||||||
# Tuning enablers
|
# Tuning enablers
|
||||||
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
|
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
|
||||||
rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist?
|
|
||||||
|
# Meta Data
|
||||||
|
user_metadata = db.Column(JSONB, nullable=True)
|
||||||
|
system_metadata = db.Column(JSONB, nullable=True)
|
||||||
|
configuration = db.Column(JSONB, nullable=True)
|
||||||
|
|
||||||
|
# Versioning Information
|
||||||
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||||
|
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
|
||||||
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||||
|
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||||
|
|
||||||
|
|
||||||
|
class Retriever(db.Model):
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
name = db.Column(db.String(50), nullable=False)
|
||||||
|
description = db.Column(db.Text, nullable=True)
|
||||||
|
catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
|
||||||
|
type = db.Column(db.String(50), nullable=False, default="DEFAULT_RAG")
|
||||||
|
|
||||||
|
# Meta Data
|
||||||
|
user_metadata = db.Column(JSONB, nullable=True)
|
||||||
|
system_metadata = db.Column(JSONB, nullable=True)
|
||||||
|
configuration = db.Column(JSONB, nullable=True)
|
||||||
|
|
||||||
# Versioning Information
|
# Versioning Information
|
||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||||
|
|||||||
11
config/catalog_types.py
Normal file
11
config/catalog_types.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Catalog Types
#
# Maps each catalog type key to its display metadata: a human-readable "name"
# and a "description". The keys of the inner dicts are lower-case, matching
# RETRIEVER_TYPES.
#
# NOTE(review): the Catalog model declares default="DEFAULT_CATALOG" for its
# `type` column, which is not a key of this mapping — confirm which one is intended.
CATALOG_TYPES = {
    "DEFAULT": {
        "name": "Default Catalog",
        "description": "Default Catalog",
    },
    "DOSSIER": {
        "name": "Dossier Catalog",
        "description": "A Catalog in which several Dossiers can be stored",
    },
}
|
||||||
30
config/retriever_types.py
Normal file
30
config/retriever_types.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Retriever Types
#
# Maps each retriever type key to its display metadata ("name", "description")
# and to a "configuration" mapping that describes the configurable fields of
# that retriever type. Every field entry carries its own "name", a "type"
# string, a "description", a "required" flag and a "default" value.
RETRIEVER_TYPES = {
    "DEFAULT_RAG": {
        "name": "Default RAG",
        "description": "Retrieving all embeddings conform the query",
        "configuration": {
            # Upper bound on the number of embeddings returned.
            "es_k": {
                "name": "es_k",
                "type": "int",
                "description": "K-value to retrieve embeddings (max embeddings retrieved)",
                "required": True,
                "default": 8,
            },
            # Minimum similarity for an embedding to be returned.
            "es_similarity_threshold": {
                "name": "es_similarity_threshold",
                "type": "float",
                "description": "Similarity threshold for retrieving embeddings",
                "required": True,
                "default": 0.3,
            },
            # Switch for tuning logging.
            "rag_tuning": {
                "name": "rag_tuning",
                "type": "boolean",
                "description": "Whether to do tuning logging or not.",
                "required": False,
                "default": False,
            },
        },
    },
}
|
||||||
@@ -16,7 +16,7 @@ When you change chunking of embedding information, you'll need to manually refre
|
|||||||
{% for field in form %}
|
{% for field in form %}
|
||||||
{{ render_field(field, disabled_fields, exclude_fields) }}
|
{{ render_field(field, disabled_fields, exclude_fields) }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
<button type="submit" class="btn btn-primary">Register Catalog</button>
|
<button type="submit" class="btn btn-primary">Save Catalog</button>
|
||||||
</form>
|
</form>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|
||||||
|
|||||||
31
eveai_app/templates/document/edit_retriever.html
Normal file
31
eveai_app/templates/document/edit_retriever.html
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{% extends 'base.html' %}
|
||||||
|
{% from "macros.html" import render_field2, render_dynamic_fields %}
|
||||||
|
|
||||||
|
{% block title %}Edit Retriever{% endblock %}
|
||||||
|
|
||||||
|
{% block content_title %}Edit Retriever{% endblock %}
|
||||||
|
{% block content_description %}Edit a Retriever (for a Catalog){% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<form method="post">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
{% set disabled_fields = ['type'] %}
|
||||||
|
{% set exclude_fields = [] %}
|
||||||
|
<!-- Render Static Fields -->
|
||||||
|
{% for field in form.get_static_fields() %}
|
||||||
|
{{ render_field2(field, disabled_fields, exclude_fields) }}
|
||||||
|
{% endfor %}
|
||||||
|
<!-- Render Dynamic Fields -->
|
||||||
|
{% for collection_name, fields in form.get_dynamic_fields().items() %}
|
||||||
|
<h4 class="mt-4">{{ collection_name }}</h4>
|
||||||
|
{% for field in fields %}
|
||||||
|
{{ render_field2(field, disabled_fields, exclude_fields) }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endfor %}
|
||||||
|
<button type="submit" class="btn btn-primary">Save Retriever</button>
|
||||||
|
</form>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block content_footer %}
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
23
eveai_app/templates/document/retriever.html
Normal file
23
eveai_app/templates/document/retriever.html
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{% extends 'base.html' %}
|
||||||
|
{% from "macros.html" import render_field %}
|
||||||
|
|
||||||
|
{% block title %}Retriever Registration{% endblock %}
|
||||||
|
|
||||||
|
{% block content_title %}Register Retriever{% endblock %}
|
||||||
|
{% block content_description %}Define a new retriever (for a catalog){% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<form method="post">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
{% set disabled_fields = [] %}
|
||||||
|
{% set exclude_fields = [] %}
|
||||||
|
{% for field in form %}
|
||||||
|
{{ render_field(field, disabled_fields, exclude_fields) }}
|
||||||
|
{% endfor %}
|
||||||
|
<button type="submit" class="btn btn-primary">Register Retriever</button>
|
||||||
|
</form>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block content_footer %}
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
23
eveai_app/templates/document/retrievers.html
Normal file
23
eveai_app/templates/document/retrievers.html
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{% extends 'base.html' %}
|
||||||
|
{% from 'macros.html' import render_selectable_table, render_pagination %}
|
||||||
|
|
||||||
|
{% block title %}Retrievers{% endblock %}
|
||||||
|
|
||||||
|
{% block content_title %}Retrievers{% endblock %}
|
||||||
|
{% block content_description %}View Retrievers for Tenant{% endblock %}
|
||||||
|
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto"></div>{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<div class="container">
|
||||||
|
<form method="POST" action="{{ url_for('document_bp.handle_retriever_selection') }}">
|
||||||
|
{{ render_selectable_table(headers=["Retriever ID", "Name", "Type", "Catalog ID"], rows=rows, selectable=True, id="retrieverssTable") }}
|
||||||
|
<div class="form-group mt-3">
|
||||||
|
<button type="submit" name="action" value="edit_retriever" class="btn btn-primary">Edit Retriever</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block content_footer %}
|
||||||
|
{{ render_pagination(pagination, 'document_bp.retrievers') }}
|
||||||
|
{% endblock %}
|
||||||
@@ -23,6 +23,43 @@
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
|
{% macro render_field2(field, disabled_fields=[], exclude_fields=[], class='') %}
    {#
        Render a single form field inside a "form-group" wrapper.

        field           -- the form field to render
        disabled_fields -- names of fields rendered with disabled=True
        exclude_fields  -- names of fields that are not rendered at all;
                           'csrf_token' and 'submit' are always excluded
        class           -- extra CSS classes appended to the input's class

        BooleanField fields are rendered as a switch (input before label); all
        other fields as label followed by a "form-control" input. Field errors
        are listed below the input. No debug markup is emitted, so field names
        and field classes do not end up in the page source.
    #}
    {% set disabled = field.name in disabled_fields %}
    {% set exclude_fields = exclude_fields + ['csrf_token', 'submit'] %}
    {% if field.name not in exclude_fields %}
        <div class="form-group">
            {% if field.type == 'BooleanField' %}
                <div class="form-check form-switch">
                    {{ field(class="form-check-input " + class, disabled=disabled) }}
                    {{ field.label(class="form-check-label") }}
                </div>
            {% else %}
                {{ field.label(class="form-label") }}
                {{ field(class="form-control " + class, disabled=disabled) }}
            {% endif %}
            {% if field.errors %}
                <div class="invalid-feedback d-block">
                    {% for error in field.errors %}
                        {{ error }}
                    {% endfor %}
                </div>
            {% endif %}
        </div>
    {% endif %}
{% endmacro %}
|
||||||
|
|
||||||
{% macro render_included_field(field, disabled_fields=[], include_fields=[]) %}
|
{% macro render_included_field(field, disabled_fields=[], include_fields=[]) %}
|
||||||
{% set disabled = field.name in disabled_fields %}
|
{% set disabled = field.name in disabled_fields %}
|
||||||
{% if field.name in include_fields %}
|
{% if field.name in include_fields %}
|
||||||
|
|||||||
@@ -83,6 +83,8 @@
|
|||||||
{{ dropdown('Document Mgmt', 'note_stack', [
|
{{ dropdown('Document Mgmt', 'note_stack', [
|
||||||
{'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']},
|
{'name': 'Add Catalog', 'url': '/document/catalog', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
{'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']},
|
{'name': 'All Catalogs', 'url': '/document/catalogs', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
|
{'name': 'Add Retriever', 'url': '/document/retriever', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
|
{'name': 'All Retrievers', 'url': '/document/retrievers', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
|
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
|
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},
|
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from flask import session, current_app
|
from flask import session, current_app, request
|
||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
|
from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, FloatField, SelectMultipleField,
|
||||||
SelectField, FieldList, FormField, TextAreaField, URLField)
|
SelectField, FieldList, FormField, TextAreaField, URLField)
|
||||||
@@ -6,6 +6,14 @@ from wtforms.validators import DataRequired, Length, Optional, URL, ValidationEr
|
|||||||
from flask_wtf.file import FileField, FileAllowed, FileRequired
|
from flask_wtf.file import FileField, FileAllowed, FileRequired
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from wtforms_sqlalchemy.fields import QuerySelectField
|
||||||
|
|
||||||
|
from common.models.document import Catalog
|
||||||
|
|
||||||
|
from config.catalog_types import CATALOG_TYPES
|
||||||
|
from config.retriever_types import RETRIEVER_TYPES
|
||||||
|
from .dynamic_form_base import DynamicFormBase
|
||||||
|
|
||||||
|
|
||||||
def allowed_file(form, field):
|
def allowed_file(form, field):
|
||||||
if field.data:
|
if field.data:
|
||||||
@@ -26,6 +34,23 @@ def validate_json(form, field):
|
|||||||
class CatalogForm(FlaskForm):
|
class CatalogForm(FlaskForm):
|
||||||
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
name = StringField('Name', validators=[DataRequired(), Length(max=50)])
|
||||||
description = TextAreaField('Description', validators=[Optional()])
|
description = TextAreaField('Description', validators=[Optional()])
|
||||||
|
# Parent ID (Optional for root-level catalogs)
|
||||||
|
parent = QuerySelectField(
|
||||||
|
'Parent Catalog',
|
||||||
|
query_factory=lambda: Catalog.query.all(),
|
||||||
|
allow_blank=True,
|
||||||
|
get_label='name',
|
||||||
|
validators=[Optional()],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config)
|
||||||
|
type = SelectField('Catalog Type', validators=[DataRequired()])
|
||||||
|
|
||||||
|
# Metadata fields
|
||||||
|
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
|
||||||
|
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
|
||||||
|
configuration = TextAreaField('Configuration', validators=[Optional(), validate_json])
|
||||||
|
|
||||||
# HTML Embedding Variables
|
# HTML Embedding Variables
|
||||||
html_tags = StringField('HTML Tags', validators=[DataRequired()],
|
html_tags = StringField('HTML Tags', validators=[DataRequired()],
|
||||||
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
|
default='p, h1, h2, h3, h4, h5, h6, li, , tbody, tr, td')
|
||||||
@@ -38,19 +63,65 @@ class CatalogForm(FlaskForm):
|
|||||||
default=2000)
|
default=2000)
|
||||||
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
|
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()],
|
||||||
default=3000)
|
default=3000)
|
||||||
# Embedding Search variables
|
|
||||||
es_k = IntegerField('Limit for Searching Embeddings (5)',
|
|
||||||
default=5,
|
|
||||||
validators=[NumberRange(min=0)])
|
|
||||||
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
|
|
||||||
default=0.5,
|
|
||||||
validators=[NumberRange(min=0, max=1)])
|
|
||||||
# Chat Variables
|
# Chat Variables
|
||||||
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
|
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
|
||||||
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
|
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
|
||||||
# Tuning variables
|
# Tuning variables
|
||||||
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
|
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
|
||||||
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
|
|
||||||
|
def __init__(self, *args, **kwargs):
    """Build the form, then fill the ``type`` choices from CATALOG_TYPES."""
    super().__init__(*args, **kwargs)
    # One (key, display name) pair per entry in CATALOG_TYPES.
    type_choices = []
    for type_key, type_def in CATALOG_TYPES.items():
        type_choices.append((type_key, type_def['name']))
    self.type.choices = type_choices
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieverForm(FlaskForm):
    """Form with the basic retriever fields: name, description, catalog, type and metadata."""

    name = StringField('Name', validators=[DataRequired(), Length(max=50)])
    description = TextAreaField('Description', validators=[Optional()])

    # Catalog the retriever belongs to; the selection may be left blank.
    catalog = QuerySelectField(
        'Catalog ID',
        query_factory=lambda: Catalog.query.all(),
        allow_blank=True,
        get_label='name',
        validators=[Optional()],
    )

    # Retriever type; the choices are filled in by __init__ from RETRIEVER_TYPES.
    type = SelectField('Retriever Type', validators=[DataRequired()])

    # Optional metadata entered as text and checked by validate_json.
    user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
    system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])

    def __init__(self, *args, **kwargs):
        """Build the form, then fill the ``type`` choices from RETRIEVER_TYPES."""
        super().__init__(*args, **kwargs)
        # One (key, display name) pair per entry in RETRIEVER_TYPES.
        retriever_type_choices = []
        for type_key, type_def in RETRIEVER_TYPES.items():
            retriever_type_choices.append((type_key, type_def['name']))
        self.type.choices = retriever_type_choices
|
||||||
|
|
||||||
|
|
||||||
|
class EditRetrieverForm(DynamicFormBase):
    """Edit form for a retriever; derives from DynamicFormBase so extra fields can be added at runtime."""

    name = StringField('Name', validators=[DataRequired(), Length(max=50)])
    description = TextAreaField('Description', validators=[Optional()])

    # Catalog the retriever belongs to; the selection may be left blank.
    catalog = QuerySelectField(
        'Catalog ID',
        query_factory=lambda: Catalog.query.all(),
        allow_blank=True,
        get_label='name',
        validators=[Optional()],
    )

    # Retriever type, rendered with the ``readonly`` attribute; choices come from RETRIEVER_TYPES.
    type = SelectField('Retriever Type', validators=[DataRequired()], render_kw={'readonly': True})

    # Optional metadata entered as text and checked by validate_json.
    user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
    system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])

    def __init__(self, *args, **kwargs):
        """Build the form, then fill the ``type`` choices from RETRIEVER_TYPES."""
        super().__init__(*args, **kwargs)

        # One (key, display name) pair per entry in RETRIEVER_TYPES.
        retriever_type_choices = []
        for type_key, type_def in RETRIEVER_TYPES.items():
            retriever_type_choices.append((type_key, type_def['name']))
        self.type.choices = retriever_type_choices
|
||||||
|
|
||||||
|
|
||||||
class AddDocumentForm(FlaskForm):
|
class AddDocumentForm(FlaskForm):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from urllib.parse import urlparse, unquote
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from common.models.document import Document, DocumentVersion, Catalog
|
from common.models.document import Document, DocumentVersion, Catalog, Retriever
|
||||||
from common.extensions import db, minio_client
|
from common.extensions import db, minio_client
|
||||||
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
|
from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
|
||||||
process_multiple_urls, get_documents_list, edit_document, \
|
process_multiple_urls, get_documents_list, edit_document, \
|
||||||
@@ -22,7 +22,7 @@ from common.utils.document_utils import validate_file_type, create_document_stac
|
|||||||
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
|
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
|
||||||
EveAIDoubleURLException
|
EveAIDoubleURLException
|
||||||
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
|
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm, \
|
||||||
CatalogForm
|
CatalogForm, RetrieverForm, EditRetrieverForm
|
||||||
from common.utils.middleware import mw_before_request
|
from common.utils.middleware import mw_before_request
|
||||||
from common.utils.celery_utils import current_celery
|
from common.utils.celery_utils import current_celery
|
||||||
from common.utils.nginx_utils import prefixed_url_for
|
from common.utils.nginx_utils import prefixed_url_for
|
||||||
@@ -30,6 +30,8 @@ from common.utils.view_assistants import form_validation_failed, prepare_table_f
|
|||||||
from .document_list_view import DocumentListView
|
from .document_list_view import DocumentListView
|
||||||
from .document_version_list_view import DocumentVersionListView
|
from .document_version_list_view import DocumentVersionListView
|
||||||
|
|
||||||
|
from config.retriever_types import RETRIEVER_TYPES
|
||||||
|
|
||||||
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
|
document_bp = Blueprint('document_bp', __name__, url_prefix='/document')
|
||||||
|
|
||||||
|
|
||||||
@@ -65,6 +67,7 @@ def catalog():
|
|||||||
tenant_id = session.get('tenant').get('id')
|
tenant_id = session.get('tenant').get('id')
|
||||||
new_catalog = Catalog()
|
new_catalog = Catalog()
|
||||||
form.populate_obj(new_catalog)
|
form.populate_obj(new_catalog)
|
||||||
|
new_catalog.parent_id = form.parent.data.get('id')
|
||||||
# Handle Embedding Variables
|
# Handle Embedding Variables
|
||||||
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
|
new_catalog.html_tags = [tag.strip() for tag in form.html_tags.data.split(',')] if form.html_tags.data else []
|
||||||
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
|
new_catalog.html_end_tags = [tag.strip() for tag in form.html_end_tags.data.split(',')] \
|
||||||
@@ -103,7 +106,7 @@ def catalogs():
|
|||||||
the_catalogs = pagination.items
|
the_catalogs = pagination.items
|
||||||
|
|
||||||
# prepare table data
|
# prepare table data
|
||||||
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', '')])
|
rows = prepare_table_for_macro(the_catalogs, [('id', ''), ('name', ''), ('type', '')])
|
||||||
|
|
||||||
# Render the catalogs in a template
|
# Render the catalogs in a template
|
||||||
return render_template('document/catalogs.html', rows=rows, pagination=pagination)
|
return render_template('document/catalogs.html', rows=rows, pagination=pagination)
|
||||||
@@ -173,6 +176,121 @@ def edit_catalog(catalog_id):
|
|||||||
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
|
return render_template('document/edit_catalog.html', form=form, catalog_id=catalog_id)
|
||||||
|
|
||||||
|
|
||||||
|
@document_bp.route('/retriever', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retriever():
    """Create a new retriever (step 1) and continue to its configuration (step 2).

    On a valid submission the retriever is stored and the user is redirected to
    ``edit_retriever``, where the type-dependent configuration is entered. When
    the form is not submitted, is invalid, or the insert fails, the creation
    form is rendered (again).
    """
    form = RetrieverForm()

    if form.validate_on_submit():
        tenant_id = session.get('tenant').get('id')
        new_retriever = Retriever()
        form.populate_obj(new_retriever)
        # The catalog field allows a blank selection, in which case its data is None.
        new_retriever.catalog_id = form.catalog.data.id if form.catalog.data else None

        set_logging_information(new_retriever, dt.now(tz.utc))

        try:
            db.session.add(new_retriever)
            db.session.commit()
        except SQLAlchemyError as e:
            db.session.rollback()
            flash(f'Failed to add retriever. Error: {e}', 'danger')
            current_app.logger.error(f'Failed to add retriever {new_retriever.name} '
                                     f'for tenant {tenant_id}. Error: {str(e)}')
            # Fall through and show the form again: there is no stored retriever to configure.
        else:
            flash('Retriever successfully added!', 'success')
            current_app.logger.info(f'Retriever {new_retriever.name} successfully added for tenant {tenant_id}!')
            # Enable step 2 of creation of retriever - add configuration of the retriever (dependent on type)
            return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=new_retriever.id))

    return render_template('document/retriever.html', form=form)
|
||||||
|
|
||||||
|
|
||||||
|
@document_bp.route('/retriever/<int:retriever_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_retriever(retriever_id):
    """Edit a retriever, including the dynamic configuration fields defined for its type.

    Responds with 404 when the retriever does not exist. A valid submission is
    saved and redirects to the retriever list; in every other case (including a
    failed save) the edit template is rendered.
    """
    current_app.logger.debug(f"Editing Retriever {retriever_id}")

    retriever = Retriever.query.get_or_404(retriever_id)

    # Expose the Catalog object (or None) as ``retriever.catalog`` before building the form.
    retriever.catalog = Catalog.query.get(retriever.catalog_id) if retriever.catalog_id else None

    form = EditRetrieverForm(request.form, obj=retriever)

    # Add one form field per configuration entry defined for this retriever type.
    configuration_config = RETRIEVER_TYPES[retriever.type]["configuration"]
    current_app.logger.debug(f"Configuration {configuration_config}")
    form.add_dynamic_fields("configuration", configuration_config, retriever.configuration)

    if not form.validate_on_submit():
        form_validation_failed(request, form)
        current_app.logger.debug(f"Rendering Template for {retriever_id}")
        return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)

    # Copy the submitted values onto the model; the dynamic fields are stored as one dict.
    form.populate_obj(retriever)
    retriever.configuration = form.get_dynamic_data('configuration')

    selected_catalog = form.catalog.data
    retriever.catalog_id = selected_catalog.id if selected_catalog else None

    update_logging_information(retriever, dt.now(tz.utc))

    try:
        db.session.add(retriever)
        db.session.commit()
        flash('Retriever updated successfully!', 'success')
        current_app.logger.info(f'Retriever {retriever.id} updated successfully')
    except SQLAlchemyError as e:
        db.session.rollback()
        flash(f'Failed to update retriever. Error: {str(e)}', 'danger')
        current_app.logger.error(f'Failed to update retriever {retriever_id}. Error: {str(e)}')
        return render_template('document/edit_retriever.html', form=form, retriever_id=retriever_id)

    return redirect(prefixed_url_for('document_bp.retrievers'))
|
||||||
|
|
||||||
|
|
||||||
|
@document_bp.route('/retrievers', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def retrievers():
    """Render a paginated list of retrievers ordered by id.

    ``page`` (default 1) and ``per_page`` (default 10) are read from the query string.
    """
    requested_page = request.args.get('page', 1, type=int)
    page_size = request.args.get('per_page', 10, type=int)

    pagination = Retriever.query.order_by(Retriever.id).paginate(page=requested_page, per_page=page_size)

    # Columns shown in the table, in display order.
    columns = [('id', ''), ('name', ''), ('type', ''), ('catalog_id', '')]
    rows = prepare_table_for_macro(pagination.items, columns)

    return render_template('document/retrievers.html', rows=rows, pagination=pagination)
|
||||||
|
|
||||||
|
|
||||||
|
@document_bp.route('/handle_retriever_selection', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_retriever_selection():
    """Dispatch the action chosen for the row selected in the retriever list.

    The ``selected_row`` form value is expected to be the literal of a dict
    holding the retriever id under ``'value'``. When it is missing or cannot be
    interpreted that way, the user is sent back to the retriever list instead
    of the request failing with a server error.
    """
    retriever_identification = request.form.get('selected_row')
    try:
        retriever_id = ast.literal_eval(retriever_identification).get('value')
    except (ValueError, SyntaxError, TypeError, AttributeError):
        # Nothing selected (None), not a valid literal, or not a dict.
        retriever_id = None

    if retriever_id is None:
        current_app.logger.warning('No valid retriever selection received')
        return redirect(prefixed_url_for('document_bp.retrievers'))

    action = request.form['action']

    if action == 'edit_retriever':
        return redirect(prefixed_url_for('document_bp.edit_retriever', retriever_id=retriever_id))

    # Unknown action: fall back to the list.
    return redirect(prefixed_url_for('document_bp.retrievers'))
|
||||||
|
|
||||||
|
|
||||||
@document_bp.route('/add_document', methods=['GET', 'POST'])
|
@document_bp.route('/add_document', methods=['GET', 'POST'])
|
||||||
@roles_accepted('Super User', 'Tenant Admin')
|
@roles_accepted('Super User', 'Tenant Admin')
|
||||||
def add_document():
|
def add_document():
|
||||||
|
|||||||
92
eveai_app/views/dynamic_form_base.py
Normal file
92
eveai_app/views/dynamic_form_base.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from flask_wtf import FlaskForm
|
||||||
|
from wtforms import IntegerField, FloatField, BooleanField, StringField, validators
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
class DynamicFormBase(FlaskForm):
    """FlaskForm that can be extended at runtime with fields described by a configuration dict.

    Dynamic fields are grouped in named collections. Each field is registered on
    the form under the name ``<collection_name>_<field_name>``.
    """

    def __init__(self, formdata=None, *args, **kwargs):
        """Create the form and remember ``formdata`` for the dynamic fields.

        ``formdata`` is only stored here; it is not forwarded to FlaskForm, which
        is initialised with the remaining arguments.
        """
        super(DynamicFormBase, self).__init__(*args, **kwargs)
        # Maps collection names to lists of field names
        self.dynamic_fields = {}
        # Store formdata for later use
        self.formdata = formdata

    def add_dynamic_fields(self, collection_name, config, initial_data=None):
        """Add dynamic fields to the form based on the configuration.

        Args:
            collection_name: Name of the collection; also the field-name prefix.
            config: Dict mapping field names to definitions with the optional keys
                ``type`` ('int', 'float', 'boolean' or 'string'; anything else is
                treated as a string field), ``description`` (used as label),
                ``required`` and ``default``.
            initial_data: Optional dict with stored values keyed by field name.
        """
        self.dynamic_fields[collection_name] = []
        field_classes = {
            'int': IntegerField,
            'float': FloatField,
            'boolean': BooleanField,
            'string': StringField,
        }
        # An empty or missing formdata means "not submitted": only stored/default values apply.
        submitted_data = self.formdata if self.formdata else None

        for field_name, field_def in config.items():
            current_app.logger.debug(f"{field_name}: {field_def}")
            # Prefix the field name with the collection name
            full_field_name = f"{collection_name}_{field_name}"
            description = field_def.get('description', '')
            required = field_def.get('required', False)
            default = field_def.get('default')

            field_validators = [validators.InputRequired()] if required else [validators.Optional()]
            field_class = field_classes.get(field_def.get('type'), StringField)

            unbound_field = field_class(
                label=description,
                validators=field_validators,
                default=default
            )
            bound_field = unbound_field.bind(form=self, name=full_field_name)

            # Stored value if there is one, otherwise the configured default.
            if initial_data and field_name in initial_data:
                fallback = initial_data[field_name]
            else:
                fallback = default

            # On a submission every dynamic field is processed against the form data,
            # also when its name is absent from it: an unchecked checkbox is simply not
            # sent by the browser, and must then read as False rather than keep the
            # stored value.
            bound_field.process(formdata=submitted_data, data=fallback)

            # Add the field to the form
            setattr(self, full_field_name, bound_field)
            self._fields[full_field_name] = bound_field
            self.dynamic_fields[collection_name].append(full_field_name)

    def get_static_fields(self):
        """Return a list of the field instances that were not added dynamically."""
        dynamic_field_names = set()
        for field_list in self.dynamic_fields.values():
            dynamic_field_names.update(field_list)

        return [field for name, field in self._fields.items() if name not in dynamic_field_names]

    def get_dynamic_fields(self):
        """Return a dict mapping each collection name to its list of field instances."""
        return {
            collection_name: [getattr(self, name) for name in field_names]
            for collection_name, field_names in self.dynamic_fields.items()
        }

    def get_dynamic_data(self, collection_name):
        """Return the data of one collection as a dict keyed by the unprefixed field names.

        An unknown collection name yields an empty dict.
        """
        data = {}
        if collection_name not in self.dynamic_fields:
            return data
        prefix_length = len(collection_name) + 1  # +1 for the underscore
        for full_field_name in self.dynamic_fields[collection_name]:
            original_field_name = full_field_name[prefix_length:]
            data[original_field_name] = getattr(self, full_field_name).data
        return data
|
||||||
@@ -22,7 +22,7 @@ from common.models.interaction import ChatSession, Interaction, InteractionEmbed
|
|||||||
from common.extensions import db
|
from common.extensions import db
|
||||||
from common.utils.celery_utils import current_celery
|
from common.utils.celery_utils import current_celery
|
||||||
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
|
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
|
||||||
from common.langchain.eveai_retriever import EveAIRetriever
|
from common.langchain.eveai_default_rag_retriever import EveAIDefaultRagRetriever
|
||||||
from common.langchain.eveai_history_retriever import EveAIHistoryRetriever
|
from common.langchain.eveai_history_retriever import EveAIHistoryRetriever
|
||||||
from common.utils.business_event import BusinessEvent
|
from common.utils.business_event import BusinessEvent
|
||||||
from common.utils.business_event_context import current_event
|
from common.utils.business_event_context import current_event
|
||||||
@@ -139,7 +139,7 @@ def answer_using_tenant_rag(question, language, tenant, chat_session):
|
|||||||
new_interaction.detailed_question_at = dt.now(tz.utc)
|
new_interaction.detailed_question_at = dt.now(tz.utc)
|
||||||
|
|
||||||
with current_event.create_span("Generate Answer using RAG"):
|
with current_event.create_span("Generate Answer using RAG"):
|
||||||
retriever = EveAIRetriever(model_variables, tenant_info)
|
retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
|
||||||
llm = model_variables['llm']
|
llm = model_variables['llm']
|
||||||
template = model_variables['rag_template']
|
template = model_variables['rag_template']
|
||||||
language_template = create_language_template(template, language)
|
language_template = create_language_template(template, language)
|
||||||
@@ -243,7 +243,7 @@ def answer_using_llm(question, language, tenant, chat_session):
|
|||||||
new_interaction.detailed_question_at = dt.now(tz.utc)
|
new_interaction.detailed_question_at = dt.now(tz.utc)
|
||||||
|
|
||||||
with current_event.create_span("Detail Answer using LLM"):
|
with current_event.create_span("Detail Answer using LLM"):
|
||||||
retriever = EveAIRetriever(model_variables, tenant_info)
|
retriever = EveAIDefaultRagRetriever(model_variables, tenant_info)
|
||||||
llm = model_variables['llm_no_rag']
|
llm = model_variables['llm_no_rag']
|
||||||
template = model_variables['encyclopedia_template']
|
template = model_variables['encyclopedia_template']
|
||||||
language_template = create_language_template(template, language)
|
language_template = create_language_template(template, language)
|
||||||
|
|||||||
Binary file not shown.
@@ -0,0 +1,56 @@
|
|||||||
|
"""Add Retriever Model
|
||||||
|
|
||||||
|
Revision ID: 3717364e6429
|
||||||
|
Revises: 7b7b566e667f
|
||||||
|
Create Date: 2024-10-21 14:22:30.258679
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import pgvector
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '3717364e6429'
|
||||||
|
down_revision = '7b7b566e667f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Create the ``retriever`` table and drop three columns from ``catalog``."""
    # ### commands auto generated by Alembic - please adjust! ###
    retriever_columns = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('catalog_id', sa.Integer(), nullable=True),
        sa.Column('type', sa.String(length=50), nullable=False),
        sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('configuration', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
        sa.Column('updated_by', sa.Integer(), nullable=True),
    ]
    retriever_constraints = [
        sa.ForeignKeyConstraint(['catalog_id'], ['catalog.id'], ),
        sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], ),
        sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], ),
        sa.PrimaryKeyConstraint('id'),
    ]
    op.create_table('retriever', *retriever_columns, *retriever_constraints)

    # These columns are removed from catalog; their data is lost on upgrade.
    for column_name in ('es_similarity_threshold', 'es_k', 'rag_tuning'):
        op.drop_column('catalog', column_name)
    # ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Undo upgrade(): restore the three ``catalog`` columns and drop the ``retriever`` table.

    The restored columns are nullable and come back empty; the values dropped
    by upgrade() are not recovered.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('catalog', sa.Column('rag_tuning', sa.BOOLEAN(), autoincrement=False, nullable=True))
    op.add_column('catalog', sa.Column('es_k', sa.INTEGER(), autoincrement=False, nullable=True))
    op.add_column('catalog', sa.Column('es_similarity_threshold', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True))
    # NOTE: the autogenerated foreign-key operations on catalog were removed here.
    # upgrade() does not touch catalog's foreign keys, so there is nothing to revert,
    # and a constraint cannot be dropped without a name.
    op.drop_table('retriever')
    # ### end Alembic commands ###
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Extensions for more Catalog Types in Catalog Model
|
||||||
|
|
||||||
|
Revision ID: 7b7b566e667f
|
||||||
|
Revises: 28984b05d396
|
||||||
|
Create Date: 2024-10-21 07:39:52.260054
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import pgvector
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '7b7b566e667f'
|
||||||
|
down_revision = '28984b05d396'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Extend ``catalog`` with parent_id, type, metadata and configuration columns."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('catalog', sa.Column('parent_id', sa.Integer(), nullable=True))

    # 'type' is NOT NULL: add it with a temporary server default, set remaining
    # NULLs to 'DEFAULT', then remove the server default again.
    type_column = sa.Column('type', sa.String(length=50), nullable=False, server_default='DEFAULT')
    op.add_column('catalog', type_column)
    op.execute("UPDATE catalog SET type='DEFAULT' WHERE type IS NULL")
    op.alter_column('catalog', 'type', server_default=None)

    # Nullable JSONB columns.
    for jsonb_column_name in ('user_metadata', 'system_metadata', 'configuration'):
        op.add_column('catalog',
                      sa.Column(jsonb_column_name, postgresql.JSONB(astext_type=sa.Text()), nullable=True))

    # Self-referencing foreign key: catalog.parent_id -> catalog.id (left unnamed).
    op.create_foreign_key(None, 'catalog', 'catalog', ['parent_id'], ['id'])

    # ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Undo upgrade(): drop the columns that were added to ``catalog``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # The foreign key on parent_id was created without a name, so it cannot be
    # dropped by name here. PostgreSQL drops table constraints involving a column
    # together with the column, so dropping parent_id below removes it as well.
    op.drop_column('catalog', 'configuration')
    op.drop_column('catalog', 'system_metadata')
    op.drop_column('catalog', 'user_metadata')
    op.drop_column('catalog', 'type')
    op.drop_column('catalog', 'parent_id')
    # ### end Alembic commands ###
|
||||||
@@ -82,3 +82,4 @@ prometheus-client~=0.20.0
|
|||||||
flower~=2.0.1
|
flower~=2.0.1
|
||||||
psutil~=6.0.0
|
psutil~=6.0.0
|
||||||
celery-redbeat~=2.2.0
|
celery-redbeat~=2.2.0
|
||||||
|
WTForms-SQLAlchemy~=0.4.1
|
||||||
|
|||||||
Reference in New Issue
Block a user