# eveAI/eveai_workers/tasks.py
import os
from datetime import datetime as dt, timezone as tz
import gevent
from bs4 import BeautifulSoup
import html
from celery import states
from flask import current_app
# LangChain imports
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy.exc import SQLAlchemyError
# Unstructured commercial client imports
from unstructured_client import UnstructuredClient
from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError
from pytube import YouTube
from common.extensions import db
from common.models.document import DocumentVersion, Embedding
from common.models.user import Tenant
from common.utils.celery_utils import current_celery
from common.utils.database import Database
from common.utils.model_utils import select_model_variables, create_language_template


@current_celery.task(name='create_embeddings', queue='embeddings')
def create_embeddings(tenant_id, document_version_id):
    # Set up remote debugging only if PYCHARM_DEBUG=True
    if current_app.config['PYCHARM_DEBUG']:
        import pydevd_pycharm
        pydevd_pycharm.settrace('localhost', port=50170, stdoutToServer=True, stderrToServer=True)

    current_app.logger.info(f'Creating embeddings for tenant {tenant_id} on document version {document_version_id}.')

    try:
        # Retrieve the tenant we are processing for
        tenant = Tenant.query.get(tenant_id)
        if tenant is None:
            raise Exception(f'Tenant {tenant_id} not found')
        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()
        # Select the variables to work with, depending on tenant and model
        model_variables = select_model_variables(tenant)
        # Retrieve the document version to process
        document_version = DocumentVersion.query.get(document_version_id)
        if document_version is None:
            raise Exception(f'Document version {document_version_id} not found')
    except Exception as e:
        current_app.logger.error(f'Create Embeddings request could not be resolved '
                                 f'for document version {document_version_id} '
                                 f'for tenant {tenant_id}, '
                                 f'error: {e}')
        raise

    try:
        db.session.add(document_version)
        # Mark processing as started
        document_version.processing = True
        document_version.processing_started_at = dt.now(tz.utc)
        document_version.processing_finished_at = None
        document_version.processing_error = None
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Unable to save embedding status information '
                                 f'in document version {document_version_id} '
                                 f'for tenant {tenant_id}, '
                                 f'error: {e}')
        raise

    delete_embeddings_for_document_version(document_version)

    try:
        match document_version.file_type:
            case 'pdf':
                process_pdf(tenant, model_variables, document_version)
            case 'html':
                process_html(tenant, model_variables, document_version)
            case 'youtube':
                process_youtube(tenant, model_variables, document_version)
            case _:
                raise Exception(f'No functionality defined for file type {document_version.file_type} '
                                f'for tenant {tenant_id} '
                                f'while creating embeddings for document version {document_version_id}')
    except Exception as e:
        current_app.logger.error(f'Error creating embeddings for tenant {tenant_id} '
                                 f'on document version {document_version_id} '
                                 f'error: {e}')
        document_version.processing = False
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing_error = str(e)[:255]
        db.session.commit()
        create_embeddings.update_state(state=states.FAILURE)
        raise
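
# Usage sketch (illustrative, not part of this module): a caller such as the web
# app is assumed to enqueue this task through Celery's standard API, e.g.
#   create_embeddings.delay(tenant.id, document_version.id)
# or, targeting the workers' queue explicitly,
#   create_embeddings.apply_async(args=[tenant.id, document_version.id], queue='embeddings')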


def process_pdf(tenant, model_variables, document_version):
    file_path = os.path.join(current_app.config['UPLOAD_FOLDER'],
                             document_version.file_location,
                             document_version.file_name)
    if os.path.exists(file_path):
        with open(file_path, 'rb') as f:
            files = shared.Files(content=f.read(), file_name=document_version.file_name)
            # Let Unstructured chunk by title, using the tenant's chunk size limits
            req = shared.PartitionParameters(
                files=files,
                strategy='hi_res',
                hi_res_model_name='yolox',
                coordinates=True,
                extract_image_block_types=['Image', 'Table'],
                chunking_strategy='by_title',
                combine_under_n_chars=model_variables['min_chunk_size'],
                max_characters=model_variables['max_chunk_size'],
            )
    else:
        current_app.logger.error(f'The physical file for document version {document_version.id} '
                                 f'for tenant {tenant.id} '
                                 f'at {file_path} does not exist')
        create_embeddings.update_state(state=states.FAILURE)
        raise FileNotFoundError(f'File {file_path} not found for document version {document_version.id}')

    try:
        chunks = partition_doc_unstructured(tenant, document_version, req)
    except Exception as e:
        current_app.logger.error(f'Unable to create embeddings for tenant {tenant.id} '
                                 f'while processing PDF on document version {document_version.id} '
                                 f'error: {e}')
        create_embeddings.update_state(state=states.FAILURE)
        raise

    summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
    document_version.system_context = f'Summary: {summary}\n'
    enriched_chunks = enrich_chunks(tenant, document_version, chunks)
    embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)

    try:
        db.session.add(document_version)
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing = False
        db.session.add_all(embeddings)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
                                 f'on PDF, document version {document_version.id}, '
                                 f'error: {e}')
        db.session.rollback()
        create_embeddings.update_state(state=states.FAILURE)
        raise

    current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
                            f'on document version {document_version.id} :-)')


def delete_embeddings_for_document_version(document_version):
    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
    for embedding in embeddings_to_delete:
        db.session.delete(embedding)
    try:
        db.session.commit()
        current_app.logger.info(f'Deleted embeddings for document version {document_version.id}')
    except SQLAlchemyError as e:
        current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}, '
                                 f'error: {e}')
        raise


def process_html(tenant, model_variables, document_version):
    # The tags to be considered can depend on the tenant
    html_tags = model_variables['html_tags']
    html_end_tags = model_variables['html_end_tags']
    html_included_elements = model_variables['html_included_elements']
    html_excluded_elements = model_variables['html_excluded_elements']

    file_path = os.path.join(current_app.config['UPLOAD_FOLDER'],
                             document_version.file_location,
                             document_version.file_name)
    if os.path.exists(file_path):
        with open(file_path, 'rb') as f:
            html_content = f.read()
    else:
        current_app.logger.error(f'The physical file for document version {document_version.id} '
                                 f'for tenant {tenant.id} '
                                 f'at {file_path} does not exist')
        create_embeddings.update_state(state=states.FAILURE)
        raise FileNotFoundError(f'File {file_path} not found for document version {document_version.id}')

    extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements,
                                       excluded_elements=html_excluded_elements)
    potential_chunks = create_potential_chunks(extracted_data, html_end_tags)
    current_app.embed_tuning_logger.debug(f'Nr of potential chunks: {len(potential_chunks)}')
    chunks = combine_chunks(potential_chunks,
                            model_variables['min_chunk_size'],
                            model_variables['max_chunk_size'])
    current_app.logger.debug(f'Nr of chunks: {len(chunks)}')

    if len(chunks) > 1:
        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
        document_version.system_context = (f'Title: {title}\n'
                                           f'Summary: {summary}\n')
    else:
        document_version.system_context = f'Title: {title}\n'

    enriched_chunks = enrich_chunks(tenant, document_version, chunks)
    embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)

    try:
        db.session.add(document_version)
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing = False
        db.session.add_all(embeddings)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
                                 f'on HTML, document version {document_version.id}, '
                                 f'error: {e}')
        raise

    current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
                            f'on document version {document_version.id} :-)')


def enrich_chunks(tenant, document_version, chunks):
    current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
                             f'on document version {document_version.id}')
    current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
    chunk_total_context = (f'Filename: {document_version.file_name}\n'
                           f'User Context: {document_version.user_context}\n'
                           f'{document_version.system_context}\n\n')
    enriched_chunks = []
    # The first chunk gets only the filename and user context; the remaining
    # chunks are prefixed with the full context (including the system context).
    initial_chunk = (f'Filename: {document_version.file_name}\n'
                     f'User Context:\n{document_version.user_context}\n\n'
                     f'{chunks[0]}')
    enriched_chunks.append(initial_chunk)
    for chunk in chunks[1:]:
        enriched_chunks.append(f'{chunk_total_context}\n{chunk}')
    current_app.logger.debug(f'Finished enriching chunks for tenant {tenant.id} '
                             f'on document version {document_version.id}')
    return enriched_chunks


def summarize_chunk(tenant, model_variables, document_version, chunk):
    current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} '
                             f'on document version {document_version.id}')
    llm = model_variables['llm']
    template = model_variables['summary_template']
    language_template = create_language_template(template, document_version.language)
    current_app.logger.debug(f'Language prompt: {language_template}')
    chain = load_summarize_chain(llm, chain_type='stuff', prompt=ChatPromptTemplate.from_template(language_template))
    doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0)
    # create_documents expects a list of texts, so wrap the single chunk
    text_to_summarize = doc_creator.create_documents([chunk])
    try:
        summary = chain.run(text_to_summarize)
        current_app.logger.debug(f'Finished summarizing chunk for tenant {tenant.id} '
                                 f'on document version {document_version.id}.')
        return summary
    except LangChainException as e:
        current_app.logger.error(f'Error creating summary for chunk enrichment for tenant {tenant.id} '
                                 f'on document version {document_version.id} '
                                 f'error: {e}')
        raise


def partition_doc_unstructured(tenant, document_version, unstructured_request):
    current_app.logger.debug(f'Partitioning document version {document_version.id} for tenant {tenant.id}')
    # Initiate the connection to unstructured.io
    url = current_app.config.get('UNSTRUCTURED_FULL_URL')
    api_key = current_app.config.get('UNSTRUCTURED_API_KEY')
    unstructured_client = UnstructuredClient(server_url=url, api_key_auth=api_key)
    try:
        res = unstructured_client.general.partition(unstructured_request)
        chunks = []
        for el in res.elements:
            match el['type']:
                case 'CompositeElement':
                    chunks.append(el['text'])
                case 'Image':
                    pass
                case 'Table':
                    chunks.append(el['metadata']['text_as_html'])
        current_app.logger.debug(f'Finished partitioning document version {document_version.id} for tenant {tenant.id}')
        return chunks
    except SDKError as e:
        current_app.logger.error(f'Error creating embeddings for tenant {tenant.id} '
                                 f'on document version {document_version.id} while chunking, '
                                 f'error: {e}')
        raise


def embed_chunks(tenant, model_variables, document_version, chunks):
    current_app.logger.debug(f'Embedding chunks for tenant {tenant.id} '
                             f'on document version {document_version.id}')
    embedding_model = model_variables['embedding_model']
    try:
        embeddings = embedding_model.embed_documents(chunks)
        current_app.logger.debug(f'Finished embedding chunks for tenant {tenant.id} '
                                 f'on document version {document_version.id}')
    except LangChainException as e:
        current_app.logger.error(f'Error creating embeddings for tenant {tenant.id} '
                                 f'on document version {document_version.id} while calling the OpenAI API, '
                                 f'error: {e}')
        raise
    # Build the embedding database records; committing them is left to the caller
    new_embeddings = []
    for chunk, embedding in zip(chunks, embeddings):
        new_embedding = model_variables['embedding_db_model']()
        new_embedding.document_version = document_version
        new_embedding.active = True
        new_embedding.chunk = chunk
        new_embedding.embedding = embedding
        new_embeddings.append(new_embedding)
    return new_embeddings


def parse_html(html_content, tags, included_elements=None, excluded_elements=None):
    soup = BeautifulSoup(html_content, 'html.parser')
    extracted_content = []
    if included_elements:
        elements_to_parse = soup.find_all(included_elements)
    else:
        elements_to_parse = [soup]  # parse the entire document if no included_elements specified
    current_app.embed_tuning_logger.debug(f'Included Elements: {included_elements}')
    current_app.embed_tuning_logger.debug(f'Nr of included elements: {len(included_elements) if included_elements else 0}')
    current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
    current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
    # Iterate through the found included elements
    for element in elements_to_parse:
        # Find all specified tags within each included element
        for sub_element in element.find_all(tags):
            if excluded_elements and sub_element.find_parent(excluded_elements):
                continue  # Skip this sub_element if it's within any of the excluded_elements
            sub_content = html.unescape(sub_element.get_text(strip=False))
            extracted_content.append((sub_element.name, sub_content))
    # Guard against documents without a <title> element
    title_tag = soup.find('title')
    title = title_tag.get_text(strip=True) if title_tag else ''
    return extracted_content, title
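
# For illustration (assumed input): parsing '<h2>Intro</h2><p>Hello &amp; welcome</p>'
# with tags=['h2', 'p'] yields [('h2', 'Intro'), ('p', 'Hello & welcome')] together
# with the page title; HTML entities are resolved via html.unescape().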


def create_potential_chunks(extracted_data, end_tags):
    potential_chunks = []
    current_chunk = []
    for tag, text in extracted_data:
        formatted_text = f"- {text}" if tag == 'li' else f"{text}\n"
        if current_chunk and tag in end_tags and current_chunk[-1][0] in end_tags:
            # Consecutive li and p elements stay together
            current_chunk.append((tag, formatted_text))
        else:
            # End the current chunk if the last element was an end tag
            if current_chunk and current_chunk[-1][0] in end_tags:
                potential_chunks.append(current_chunk)
                current_chunk = []
            current_chunk.append((tag, formatted_text))
    # Add the last chunk
    if current_chunk:
        potential_chunks.append(current_chunk)
    return potential_chunks
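
# For illustration (assumed input): with end_tags=['p', 'li'], the extracted data
# [('h2', 'Intro'), ('p', 'First'), ('p', 'Second'), ('h2', 'Next')] produces two
# potential chunks: the 'h2' heading grouped with the two 'p' elements that follow
# it, and a second chunk starting at the next 'h2'.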


def combine_chunks(potential_chunks, min_chars, max_chars):
    actual_chunks = []
    current_chunk = ""
    current_length = 0
    for chunk in potential_chunks:
        current_app.embed_tuning_logger.debug(f'chunk: {chunk}')
        chunk_content = ''.join(text for _, text in chunk)
        current_app.embed_tuning_logger.debug(f'chunk_content: {chunk_content}')
        chunk_length = len(chunk_content)
        if current_length + chunk_length > max_chars:
            if current_length >= min_chars:
                current_app.embed_tuning_logger.debug(f'Adding chunk to actual_chunks: {current_chunk}')
                actual_chunks.append(current_chunk)
                current_chunk = chunk_content
                current_length = chunk_length
            else:
                # The current chunk is still below min_chars, so keep adding
                # even though this pushes it past max_chars
                current_chunk += chunk_content
                current_length += chunk_length
        else:
            current_chunk += chunk_content
            current_length += chunk_length
    current_app.embed_tuning_logger.debug(f'Remaining Chunk: {current_chunk}')
    current_app.embed_tuning_logger.debug(f'Remaining Length: {current_length}')
    # Handle the last chunk
    if current_chunk and current_length >= 0:
        actual_chunks.append(current_chunk)
    return actual_chunks
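
# For illustration (assumed sizes): with min_chars=200 and max_chars=1000, potential
# chunks of 600, 500 and 150 characters become two actual chunks: the first holds the
# 600-character chunk (adding the next would exceed max_chars), the second combines
# the 500- and 150-character chunks.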


def process_youtube(tenant, model_variables, document_version):
    base_path = os.path.join(current_app.config['UPLOAD_FOLDER'],
                             document_version.file_location)
    # Clean old files if necessary
    of, title, description, author = download_youtube(document_version.url, base_path, 'downloaded.mp4', tenant)
    document_version.system_context = f'Title: {title}\nDescription: {description}\nAuthor: {author}'
    compress_audio(base_path, 'downloaded.mp4', 'compressed.mp3', tenant)
    transcribe_audio(base_path, 'compressed.mp3', 'transcription.txt', document_version.language, tenant,
                     model_variables)
    annotate_transcription(base_path, 'transcription.txt', 'transcription.md', tenant, model_variables)
    potential_chunks = create_potential_chunks_for_markdown(base_path, 'transcription.md', tenant)
    actual_chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'],
                                                model_variables['max_chunk_size'])
    enriched_chunks = enrich_chunks(tenant, document_version, actual_chunks)
    embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)

    try:
        db.session.add(document_version)
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing = False
        db.session.add_all(embeddings)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
                                 f'on YouTube document version {document_version.id}, '
                                 f'error: {e}')
        raise

    current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
                            f'on YouTube document version {document_version.id} :-)')


def download_youtube(url, file_location, file_name, tenant):
    try:
        current_app.logger.info(f'Downloading YouTube video: {url} on location {file_location} for tenant: {tenant.id}')
        yt = YouTube(url)
        stream = yt.streams.get_audio_only()
        output_file = stream.download(output_path=file_location, filename=file_name)
        current_app.logger.info(f'Downloaded YouTube video: {url} on location {file_location} for tenant: {tenant.id}')
        return output_file, yt.title, yt.description, yt.author
    except Exception as e:
        current_app.logger.error(f'Error downloading YouTube video: {url} on location {file_location} for '
                                 f'tenant: {tenant.id} with error: {e}')
        raise


def compress_audio(file_location, input_file, output_file, tenant):
    try:
        current_app.logger.info(f'Compressing audio on {file_location} for tenant: {tenant.id}')
        result = os.popen(f'scripts/compress.sh -d {file_location} -i {input_file} -o {output_file}')
        output_file_path = os.path.join(file_location, output_file)
        count = 0
        while not os.path.exists(output_file_path) and count < 10:
            gevent.sleep(1)
            current_app.logger.debug(f'Waiting for {output_file_path} to be created... Count: {count}')
            count += 1
        current_app.logger.info(f'Compressed audio for {file_location} for tenant: {tenant.id}')
        return result
    except Exception as e:
        current_app.logger.error(f'Error compressing audio on {file_location} for tenant: {tenant.id} with error: {e}')
        raise


def transcribe_audio(file_location, input_file, output_file, language, tenant, model_variables):
    try:
        current_app.logger.info(f'Transcribing audio on {file_location} for tenant: {tenant.id}')
        client = model_variables['transcription_client']
        model = model_variables['transcription_model']
        input_file_path = os.path.join(file_location, input_file)
        output_file_path = os.path.join(file_location, output_file)
        count = 0
        while not os.path.exists(input_file_path) and count < 10:
            gevent.sleep(1)
            current_app.logger.debug(f'Waiting for {input_file_path} to exist... Count: {count}')
            count += 1
        with open(input_file_path, 'rb') as audio_file:
            transcription = client.audio.transcriptions.create(
                file=audio_file,
                model=model,
                language=language,
                response_format='verbose_json',
            )
        with open(output_file_path, 'w') as transcript_file:
            transcript_file.write(transcription.text)
        current_app.logger.info(f'Transcribed audio for {file_location} for tenant: {tenant.id}')
    except Exception as e:
        current_app.logger.error(f'Error transcribing audio for {file_location} for tenant: {tenant.id}, '
                                 f'with error: {e}')
        raise
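
# Note: the call above assumes that model_variables['transcription_client'] exposes an
# OpenAI-style interface (client.audio.transcriptions.create(...)); any client with
# that method signature would work here.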


def annotate_transcription(file_location, input_file, output_file, tenant, model_variables):
    try:
        current_app.logger.debug(f'Annotating transcription on {file_location} for tenant {tenant.id}')
        llm = model_variables['llm']
        template = model_variables['transcript_template']
        transcript_prompt = ChatPromptTemplate.from_template(template)
        setup = RunnablePassthrough()
        output_parser = StrOutputParser()
        with open(os.path.join(file_location, input_file), 'r') as f:
            transcript = f.read()
        chain = setup | transcript_prompt | llm | output_parser
        input_transcript = {"transcript": transcript}
        annotated_transcript = chain.invoke(input_transcript)
        with open(os.path.join(file_location, output_file), 'w') as f:
            f.write(annotated_transcript)
        current_app.logger.info(f'Annotated transcription for {file_location} for tenant {tenant.id}')
    except Exception as e:
        current_app.logger.error(f'Error annotating transcription for {file_location} for tenant {tenant.id}, '
                                 f'with error: {e}')
        raise


def create_potential_chunks_for_markdown(base_path, input_file, tenant):
    current_app.logger.info(f'Creating potential chunks for {base_path} for tenant {tenant.id}')
    with open(os.path.join(base_path, input_file), 'r') as f:
        markdown = f.read()
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        # ("###", "Header 3"),
    ]
    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False)
    md_header_splits = markdown_splitter.split_text(markdown)
    potential_chunks = [doc.page_content for doc in md_header_splits]
    return potential_chunks
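
# For illustration: an annotated transcript such as
#   '# Topic\nintro text\n## Detail\nmore text'
# yields one potential chunk per '#'/'##' section, with the header line kept in the
# chunk text because strip_headers=False.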


def combine_chunks_for_markdown(potential_chunks, min_chars, max_chars):
    actual_chunks = []
    current_chunk = ""
    current_length = 0
    for chunk in potential_chunks:
        chunk_length = len(chunk)
        if current_length + chunk_length > max_chars:
            if current_length >= min_chars:
                actual_chunks.append(current_chunk)
                current_chunk = chunk
                current_length = chunk_length
            else:
                # The current chunk is still below min_chars, so keep adding
                # even though this pushes it past max_chars
                current_chunk += f'\n{chunk}'
                current_length += chunk_length
        else:
            current_chunk += f'\n{chunk}'
            current_length += chunk_length
    # Handle the last chunk
    if current_chunk and current_length >= 0:
        actual_chunks.append(current_chunk)
    return actual_chunks