- Add Catalog Functionality

This commit is contained in:
Josako
2024-10-15 18:14:57 +02:00
parent 3316a8bc47
commit 3e644f1652
16 changed files with 303 additions and 84 deletions

View File

@@ -21,17 +21,17 @@ class Catalog(db.Model):
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
# Embedding search variables
es_k = db.Column(db.Integer, nullable=True, default=5)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7)
# Embedding search variables ==> move to specialist?
es_k = db.Column(db.Integer, nullable=True, default=8)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.4)
# Chat variables
# Chat variables ==> Move to Specialist?
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
rag_tuning = db.Column(db.Boolean, nullable=True, default=False)
rag_tuning = db.Column(db.Boolean, nullable=True, default=False) # Move to Specialist?
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
@@ -42,6 +42,7 @@ class Catalog(db.Model):
class Document(db.Model):
id = db.Column(db.Integer, primary_key=True)
# tenant_id = db.Column(db.Integer, db.ForeignKey(Tenant.id), nullable=False)
catalog_id = db.Column(db.Integer, db.ForeignKey(Catalog.id), nullable=True)
name = db.Column(db.String(100), nullable=False)
valid_from = db.Column(db.DateTime, nullable=True)

View File

@@ -22,7 +22,7 @@ def create_document_stack(api_input, file, filename, extension, tenant_id):
db.session.add(new_doc)
# Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc,
new_doc_vers = create_version_for_document(new_doc, tenant_id,
api_input.get('url', ''),
api_input.get('language', 'en'),
api_input.get('user_context', ''),
@@ -63,7 +63,7 @@ def create_document(form, filename, catalog_id):
return new_doc
def create_version_for_document(document, url, language, user_context, user_metadata):
def create_version_for_document(document, tenant_id, url, language, user_context, user_metadata):
new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
@@ -83,7 +83,7 @@ def create_version_for_document(document, url, language, user_context, user_meta
set_logging_information(new_doc_vers, dt.now(tz.utc))
mark_tenant_storage_dirty(document.tenant_id)
mark_tenant_storage_dirty(tenant_id)
return new_doc_vers
@@ -287,7 +287,7 @@ def edit_document_version(version_id, user_context):
return None, str(e)
def refresh_document_with_info(doc_id, api_input):
def refresh_document_with_info(doc_id, tenant_id, api_input):
doc = Document.query.get_or_404(doc_id)
old_doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first()
@@ -295,11 +295,11 @@ def refresh_document_with_info(doc_id, api_input):
return None, "This document has no URL. Only documents with a URL can be refreshed."
new_doc_vers = create_version_for_document(
doc,
doc, tenant_id,
old_doc_vers.url,
api_input.get('language', old_doc_vers.language),
api_input.get('user_context', old_doc_vers.user_context),
api_input.get('user_metadata', old_doc_vers.user_metadata)
api_input.get('user_metadata', old_doc_vers.user_metadata),
)
set_logging_information(new_doc_vers, dt.now(tz.utc))
@@ -329,7 +329,7 @@ def refresh_document_with_info(doc_id, api_input):
# Update the existing refresh_document function to use the new refresh_document_with_info
def refresh_document(doc_id):
def refresh_document(doc_id, tenant_id):
current_app.logger.info(f'Refreshing document {doc_id}')
doc = Document.query.get_or_404(doc_id)
old_doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first()
@@ -340,7 +340,7 @@ def refresh_document(doc_id):
'user_metadata': old_doc_vers.user_metadata
}
return refresh_document_with_info(doc_id, api_input)
return refresh_document_with_info(doc_id, tenant_id, api_input)
# Function triggered when a document_version is created or updated

View File

@@ -14,7 +14,7 @@ from portkey_ai.langchain.portkey_langchain_callback_handler import LangchainCal
from common.langchain.llm_metrics_handler import LLMMetricsHandler
from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings
from common.langchain.tracked_transcribe import tracked_transcribe
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI
from common.models.document import EmbeddingSmallOpenAI, EmbeddingLargeOpenAI, Catalog
from common.models.user import Tenant
from config.model_config import MODEL_CONFIG
from common.utils.business_event_context import current_event
@@ -42,8 +42,9 @@ def set_language_prompt_template(cls, language_prompt):
class ModelVariables(MutableMapping):
def __init__(self, tenant: Tenant):
def __init__(self, tenant: Tenant, catalog_id=None):
self.tenant = tenant
self.catalog_id = catalog_id
self._variables = self._initialize_variables()
self._embedding_model = None
self._llm = None
@@ -57,25 +58,32 @@ class ModelVariables(MutableMapping):
def _initialize_variables(self):
variables = {}
# We initialize the variables that are available knowing the tenant. For the other, we will apply 'lazy loading'
variables['k'] = self.tenant.es_k or 5
variables['similarity_threshold'] = self.tenant.es_similarity_threshold or 0.7
variables['RAG_temperature'] = self.tenant.chat_RAG_temperature or 0.3
variables['no_RAG_temperature'] = self.tenant.chat_no_RAG_temperature or 0.5
variables['embed_tuning'] = self.tenant.embed_tuning or False
variables['rag_tuning'] = self.tenant.rag_tuning or False
# Get the Catalog if catalog_id is passed
if self.catalog_id:
catalog = Catalog.query.get_or_404(self.catalog_id)
# We initialize the variables that are available knowing the tenant.
variables['embed_tuning'] = catalog.embed_tuning or False
# Set HTML Chunking Variables
variables['html_tags'] = catalog.html_tags
variables['html_end_tags'] = catalog.html_end_tags
variables['html_included_elements'] = catalog.html_included_elements
variables['html_excluded_elements'] = catalog.html_excluded_elements
variables['html_excluded_classes'] = catalog.html_excluded_classes
# Set Chunk Size variables
variables['min_chunk_size'] = catalog.min_chunk_size
variables['max_chunk_size'] = catalog.max_chunk_size
# Set the RAG Context (will have to change once specialists are defined
variables['rag_context'] = self.tenant.rag_context or " "
# Set HTML Chunking Variables
variables['html_tags'] = self.tenant.html_tags
variables['html_end_tags'] = self.tenant.html_end_tags
variables['html_included_elements'] = self.tenant.html_included_elements
variables['html_excluded_elements'] = self.tenant.html_excluded_elements
variables['html_excluded_classes'] = self.tenant.html_excluded_classes
# Set Chunk Size variables
variables['min_chunk_size'] = self.tenant.min_chunk_size
variables['max_chunk_size'] = self.tenant.max_chunk_size
# Temporary setting until we have Specialists
variables['rag_tuning'] = False
variables['RAG_temperature'] = 0.3
variables['no_RAG_temperature'] = 0.5
variables['k'] = 8
variables['similarity_threshold'] = 0.4
# Set model providers
variables['embedding_provider'], variables['embedding_model'] = self.tenant.embedding_model.rsplit('.', 1)
@@ -195,7 +203,12 @@ class ModelVariables(MutableMapping):
return self.transcription_client
elif key in self._variables.get('prompt_templates', []):
return self.get_prompt_template(key)
return self._variables.get(key)
else:
value = self._variables.get(key)
if value is not None:
return value
else:
raise KeyError(f'Variable {key} does not exist in ModelVariables')
def __setitem__(self, key: str, value: Any) -> None:
self._variables[key] = value
@@ -225,8 +238,8 @@ class ModelVariables(MutableMapping):
return self._variables.values()
def select_model_variables(tenant):
model_variables = ModelVariables(tenant=tenant)
def select_model_variables(tenant, catalog_id=None):
model_variables = ModelVariables(tenant=tenant, catalog_id=catalog_id)
return model_variables