Improving chat functionality significantly throughout the application.

Josako
2024-06-12 11:07:18 +02:00
parent 27b6de8734
commit be311c440b
22 changed files with 604 additions and 127 deletions


@@ -1,29 +1,26 @@
+from datetime import datetime as dt, timezone as tz
+from flask import current_app
+from sqlalchemy.exc import SQLAlchemyError
+from celery import states
 from celery.exceptions import Ignore
 import os
-from datetime import datetime as dt, timezone as tz
+from bs4 import BeautifulSoup
 import html
-from celery import states
-from flask import current_app
-# OpenAI imports
-from langchain.chains.summarize import load_summarize_chain
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_core.exceptions import LangChainException
-from langchain_core.prompts import ChatPromptTemplate
-from sqlalchemy.exc import SQLAlchemyError
 # Unstructured commercial client imports
 from unstructured_client import UnstructuredClient
 from unstructured_client.models import shared
 from unstructured_client.models.errors import SDKError
+# OpenAI imports
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.chains.summarize import load_summarize_chain
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_core.exceptions import LangChainException
-from common.utils.database import Database
-from common.models.document import DocumentVersion
-from common.models.user import Tenant
 from common.extensions import db
+from common.models.document import DocumentVersion, Embedding
+from common.models.user import Tenant
 from common.utils.celery_utils import current_celery
-from common.utils.model_utils import select_model_variables
-from bs4 import BeautifulSoup
+from common.utils.database import Database
+from common.utils.model_utils import select_model_variables, create_language_template
 
 @current_celery.task(name='create_embeddings', queue='embeddings')
@@ -65,6 +62,8 @@ def create_embeddings(tenant_id, document_version_id):
         # start processing
         document_version.processing = True
         document_version.processing_started_at = dt.now(tz.utc)
+        document_version.processing_finished_at = None
+        document_version.processing_error = None
         db.session.commit()
     except SQLAlchemyError as e:
@@ -73,6 +72,8 @@ def create_embeddings(tenant_id, document_version_id):
                                  f'for tenant {tenant_id}')
         raise
+
+    delete_embeddings_for_document_version(document_version)
     try:
         match document_version.file_type:
             case 'pdf':
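
Taken together, resetting processing_finished_at and processing_error and calling delete_embeddings_for_document_version up front make the task safe to re-run for a document version that was already processed: stale status fields and stale vectors are cleared before new embeddings are written. A sketch of how a re-run could be queued, assuming current_celery exposes the standard Celery app API (the IDs are illustrative placeholders):

# Re-queue embedding creation for an already-processed document version.
# Task name and queue come from the decorator above; the IDs are placeholders.
current_celery.send_task('create_embeddings',
                         kwargs={'tenant_id': 1, 'document_version_id': 2},
                         queue='embeddings')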
@@ -152,6 +153,18 @@ def process_pdf(tenant, model_variables, document_version):
                              f'on document version {document_version.id} :-)')
+
+def delete_embeddings_for_document_version(document_version):
+    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
+    for embedding in embeddings_to_delete:
+        db.session.delete(embedding)
+    try:
+        db.session.commit()
+        current_app.logger.info(f'Deleted embeddings for document version {document_version.id}')
+    except SQLAlchemyError as e:
+        current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}')
+        raise
+
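
The new helper loads every Embedding row only to delete it one by one. SQLAlchemy can express the same cleanup as a single bulk DELETE; a minimal sketch of that variant, with the caveat that query-level deletes bypass ORM cascades and session events:

def delete_embeddings_bulk(document_version):
    # Emits one DELETE ... WHERE doc_vers_id = :id statement instead of
    # loading and deleting each Embedding row individually.
    db.session.query(Embedding) \
        .filter_by(doc_vers_id=document_version.id) \
        .delete(synchronize_session=False)
    db.session.commit()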
 def process_html(tenant, model_variables, document_version):
     # The tags to be considered can be dependent on the tenant
     html_tags = model_variables['html_tags']
@@ -176,12 +189,15 @@ def process_html(tenant, model_variables, document_version):
     extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements,
                                        excluded_elements=html_excluded_elements)
     potential_chunks = create_potential_chunks(extracted_data, html_end_tags)
+    current_app.embed_tuning_logger.debug(f'Nr of potential chunks: {len(potential_chunks)}')
     chunks = combine_chunks(potential_chunks,
                             model_variables['min_chunk_size'],
                             model_variables['max_chunk_size']
                             )
+    current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
-    if len(chunks) > 0:
+    if len(chunks) > 1:
         summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
         document_version.system_context = (f'Title: {title}\n'
                                            f'Summary: {summary}\n')
@@ -210,6 +226,7 @@ def process_html(tenant, model_variables, document_version):
 def enrich_chunks(tenant, document_version, chunks):
     current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
                              f'on document version {document_version.id}')
+    current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
     chunk_total_context = (f'Filename: {document_version.file_name}\n'
                            f'User Context:{document_version.user_context}\n'
                            f'{document_version.system_context}\n\n')
@@ -233,8 +250,10 @@ def summarize_chunk(tenant, model_variables, document_version, chunk):
     current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} '
                              f'on document version {document_version.id}')
     llm = model_variables['llm']
-    prompt = model_variables['summary_prompt']
-    chain = load_summarize_chain(llm, chain_type='stuff', prompt=prompt)
+    template = model_variables['summary_template']
+    language_template = create_language_template(template, document_version.language)
+    current_app.logger.debug(f'Language prompt: {language_template}')
+    chain = load_summarize_chain(llm, chain_type='stuff', prompt=ChatPromptTemplate.from_template(language_template))
     doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0)
     text_to_summarize = doc_creator.create_documents(chunk)
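
create_language_template is imported from common.utils.model_utils but its body is not part of this diff; from the call site it merges the tenant's summary template with the document's language before the result is wrapped in ChatPromptTemplate.from_template. A hypothetical shape, assuming the stored template keeps the {text} variable that the 'stuff' summarize chain fills in:

def create_language_template(template: str, language: str) -> str:
    # Hypothetical stand-in; the real helper lives in common.utils.model_utils.
    # Prepends a language instruction while preserving the {text} placeholder.
    return f'Answer in {language or "English"}.\n\n{template}'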
@@ -319,7 +338,6 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
     current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}')
     current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
     current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
-    current_app.embed_tuning_logger.debug(f'{elements_to_parse}')
     # Iterate through the found included elements
     for element in elements_to_parse:
@@ -327,7 +345,8 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
         for sub_element in element.find_all(tags):
             if excluded_elements and sub_element.find_parent(excluded_elements):
                 continue  # Skip this sub_element if it's within any of the excluded_elements
-            extracted_content.append((sub_element.name, sub_element.get_text(strip=True)))
+            sub_content = html.unescape(sub_element.get_text(strip=False))
+            extracted_content.append((sub_element.name, sub_content))
     title = soup.find('title').get_text(strip=True)
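
The switch to get_text(strip=False) keeps the whitespace inside each element, and html.unescape decodes entity references that survive parsing, such as double-escaped ampersands. A quick illustration:

from bs4 import BeautifulSoup
import html

soup = BeautifulSoup('<p>Fish &amp;amp; chips </p>', 'html.parser')
text = soup.p.get_text(strip=False)  # 'Fish &amp; chips ' (one level decoded)
print(html.unescape(text))           # 'Fish & chips ' (fully decoded, space kept)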
@@ -362,11 +381,14 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
     current_length = 0
     for chunk in potential_chunks:
+        current_app.embed_tuning_logger.debug(f'chunk: {chunk}')
         chunk_content = ''.join(text for _, text in chunk)
+        current_app.embed_tuning_logger.debug(f'chunk_content: {chunk_content}')
         chunk_length = len(chunk_content)
         if current_length + chunk_length > max_chars:
             if current_length >= min_chars:
+                current_app.embed_tuning_logger.debug(f'Adding chunk to actual_chunks: {current_chunk}')
                 actual_chunks.append(current_chunk)
             current_chunk = chunk_content
             current_length = chunk_length
@@ -378,8 +400,11 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
             current_chunk += chunk_content
             current_length += chunk_length
+    current_app.embed_tuning_logger.debug(f'Remaining Chunk: {current_chunk}')
+    current_app.embed_tuning_logger.debug(f'Remaining Length: {current_length}')
     # Handle the last chunk
-    if current_chunk and current_length >= min_chars:
+    if current_chunk and current_length >= 0:
         actual_chunks.append(current_chunk)
     return actual_chunks
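
combine_chunks is a greedy packer: pieces accumulate until the next one would cross max_chars, an accumulated chunk is only emitted once it reaches min_chars, and with the changed final condition the trailing remainder is now always kept. The same strategy reduced to a self-contained sketch that takes plain strings instead of (tag, text) tuples:

def pack_chunks(pieces, min_chars, max_chars):
    chunks, current = [], ''
    for piece in pieces:
        if len(current) + len(piece) > max_chars:
            if len(current) >= min_chars:
                chunks.append(current)
            # an undersized accumulation is dropped here, as in the original
            current = piece
        else:
            current += piece
    if current:  # trailing remainder always kept (current_length >= 0)
        chunks.append(current)
    return chunks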