Improving chat functionality significantly throughout the application.
This commit is contained in:
@@ -1,29 +1,26 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from celery import states
|
||||
from celery.exceptions import Ignore
|
||||
import os
|
||||
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from bs4 import BeautifulSoup
|
||||
import html
|
||||
from celery import states
|
||||
from flask import current_app
|
||||
# OpenAI imports
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_core.exceptions import LangChainException
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
# Unstructured commercial client imports
|
||||
from unstructured_client import UnstructuredClient
|
||||
from unstructured_client.models import shared
|
||||
from unstructured_client.models.errors import SDKError
|
||||
|
||||
# OpenAI imports
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_core.exceptions import LangChainException
|
||||
|
||||
from common.utils.database import Database
|
||||
from common.models.document import DocumentVersion
|
||||
from common.models.user import Tenant
|
||||
from common.extensions import db
|
||||
from common.models.document import DocumentVersion, Embedding
|
||||
from common.models.user import Tenant
|
||||
from common.utils.celery_utils import current_celery
|
||||
from common.utils.model_utils import select_model_variables
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from common.utils.database import Database
|
||||
from common.utils.model_utils import select_model_variables, create_language_template
|
||||
|
||||
|
||||
@current_celery.task(name='create_embeddings', queue='embeddings')
|
||||
@@ -65,6 +62,8 @@ def create_embeddings(tenant_id, document_version_id):
|
||||
# start processing
|
||||
document_version.processing = True
|
||||
document_version.processing_started_at = dt.now(tz.utc)
|
||||
document_version.processing_finished_at = None
|
||||
document_version.processing_error = None
|
||||
|
||||
db.session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
@@ -73,6 +72,8 @@ def create_embeddings(tenant_id, document_version_id):
|
||||
f'for tenant {tenant_id}')
|
||||
raise
|
||||
|
||||
delete_embeddings_for_document_version(document_version)
|
||||
|
||||
try:
|
||||
match document_version.file_type:
|
||||
case 'pdf':
|
||||
@@ -152,6 +153,18 @@ def process_pdf(tenant, model_variables, document_version):
|
||||
f'on document version {document_version.id} :-)')
|
||||
|
||||
|
||||
def delete_embeddings_for_document_version(document_version):
    """Delete all Embedding rows linked to the given document version.

    Queries every ``Embedding`` whose ``doc_vers_id`` equals
    ``document_version.id``, marks each one for deletion on the shared
    ``db.session`` and commits the transaction.

    Args:
        document_version: Object whose ``id`` selects the embeddings to remove.

    Raises:
        SQLAlchemyError: Re-raised when the commit fails, after the session
            has been rolled back and the failure has been logged.
    """
    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
    # Delete through the session object by object (not a bulk DELETE) so the
    # ORM's per-object delete handling stays exactly as it was.
    for embedding in embeddings_to_delete:
        db.session.delete(embedding)

    try:
        db.session.commit()
    except SQLAlchemyError as e:
        # A failed commit leaves the session in a failed transaction state;
        # roll back so the caller can keep using db.session while handling the error.
        db.session.rollback()
        current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}. '
                                 f'Error: {e}')
        raise
    else:
        current_app.logger.info(f'Deleted embeddings for document version {document_version.id}')
|
||||
|
||||
|
||||
def process_html(tenant, model_variables, document_version):
|
||||
# The tags to be considered can be dependent on the tenant
|
||||
html_tags = model_variables['html_tags']
|
||||
@@ -176,12 +189,15 @@ def process_html(tenant, model_variables, document_version):
|
||||
extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements,
|
||||
excluded_elements=html_excluded_elements)
|
||||
potential_chunks = create_potential_chunks(extracted_data, html_end_tags)
|
||||
current_app.embed_tuning_logger.debug(f'Nr of potential chunks: {len(potential_chunks)}')
|
||||
|
||||
chunks = combine_chunks(potential_chunks,
|
||||
model_variables['min_chunk_size'],
|
||||
model_variables['max_chunk_size']
|
||||
)
|
||||
current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
|
||||
|
||||
if len(chunks) > 0:
|
||||
if len(chunks) > 1:
|
||||
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
|
||||
document_version.system_context = (f'Title: {title}\n'
|
||||
f'Summary: {summary}\n')
|
||||
@@ -210,6 +226,7 @@ def process_html(tenant, model_variables, document_version):
|
||||
def enrich_chunks(tenant, document_version, chunks):
|
||||
current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
|
||||
f'on document version {document_version.id}')
|
||||
current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
|
||||
chunk_total_context = (f'Filename: {document_version.file_name}\n'
|
||||
f'User Context:{document_version.user_context}\n'
|
||||
f'{document_version.system_context}\n\n')
|
||||
@@ -233,8 +250,10 @@ def summarize_chunk(tenant, model_variables, document_version, chunk):
|
||||
current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} '
|
||||
f'on document version {document_version.id}')
|
||||
llm = model_variables['llm']
|
||||
prompt = model_variables['summary_prompt']
|
||||
chain = load_summarize_chain(llm, chain_type='stuff', prompt=prompt)
|
||||
template = model_variables['summary_template']
|
||||
language_template = create_language_template(template, document_version.language)
|
||||
current_app.logger.debug(f'Language prompt: {language_template}')
|
||||
chain = load_summarize_chain(llm, chain_type='stuff', prompt=ChatPromptTemplate.from_template(language_template))
|
||||
|
||||
doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0)
|
||||
text_to_summarize = doc_creator.create_documents(chunk)
|
||||
@@ -319,7 +338,6 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
|
||||
current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}')
|
||||
current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
|
||||
current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
|
||||
current_app.embed_tuning_logger.debug(f'{elements_to_parse}')
|
||||
|
||||
# Iterate through the found included elements
|
||||
for element in elements_to_parse:
|
||||
@@ -327,7 +345,8 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
|
||||
for sub_element in element.find_all(tags):
|
||||
if excluded_elements and sub_element.find_parent(excluded_elements):
|
||||
continue # Skip this sub_element if it's within any of the excluded_elements
|
||||
extracted_content.append((sub_element.name, sub_element.get_text(strip=True)))
|
||||
sub_content = html.unescape(sub_element.get_text(strip=False))
|
||||
extracted_content.append((sub_element.name, sub_content))
|
||||
|
||||
title = soup.find('title').get_text(strip=True)
|
||||
|
||||
@@ -362,11 +381,14 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
|
||||
current_length = 0
|
||||
|
||||
for chunk in potential_chunks:
|
||||
current_app.embed_tuning_logger.debug(f'chunk: {chunk}')
|
||||
chunk_content = ''.join(text for _, text in chunk)
|
||||
current_app.embed_tuning_logger.debug(f'chunk_content: {chunk_content}')
|
||||
chunk_length = len(chunk_content)
|
||||
|
||||
if current_length + chunk_length > max_chars:
|
||||
if current_length >= min_chars:
|
||||
current_app.embed_tuning_logger.debug(f'Adding chunk to actual_chunks: {current_chunk}')
|
||||
actual_chunks.append(current_chunk)
|
||||
current_chunk = chunk_content
|
||||
current_length = chunk_length
|
||||
@@ -378,8 +400,11 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
|
||||
current_chunk += chunk_content
|
||||
current_length += chunk_length
|
||||
|
||||
current_app.embed_tuning_logger.debug(f'Remaining Chunk: {current_chunk}')
|
||||
current_app.embed_tuning_logger.debug(f'Remaining Length: {current_length}')
|
||||
|
||||
# Handle the last chunk
|
||||
if current_chunk and current_length >= min_chars:
|
||||
if current_chunk and current_length >= 0:
|
||||
actual_chunks.append(current_chunk)
|
||||
|
||||
return actual_chunks
|
||||
|
||||
Reference in New Issue
Block a user