import io
import os
from datetime import datetime as dt, timezone as tz

from celery import states
from flask import current_app

# LangChain / OpenAI imports
from langchain.text_splitter import MarkdownHeaderTextSplitter
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db, minio_client, template_manager
from common.models.document import DocumentVersion, Embedding, Document, Processor, Catalog
from common.models.user import Tenant
from common.utils.celery_utils import current_celery
from common.utils.database import Database
from common.utils.model_utils import create_language_template, get_model_variables
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
from config.processor_types import PROCESSOR_TYPES
from eveai_workers.processors.processor_registry import ProcessorRegistry


# Healthcheck task
@current_celery.task(name='ping', queue='embeddings')
def ping():
    """Liveness probe for the embeddings queue."""
    return 'pong'


@current_celery.task(name='create_embeddings', queue='embeddings')
def create_embeddings(tenant_id, document_version_id):
    """Process a document version into markdown and store its embeddings.

    Looks up the tenant, switches to the tenant's database schema, resolves
    the document version / catalog / processor configuration, then runs the
    type-specific processor and embeds the resulting markdown. Progress and
    failure state are recorded on the DocumentVersion row.

    Args:
        tenant_id: ID of the tenant the document belongs to.
        document_version_id: ID of the DocumentVersion to process.

    Raises:
        Exception: if any referenced entity is missing, or if processing /
            embedding fails (the error is also persisted on the version).
    """
    try:
        # Retrieve Tenant for which we are processing
        tenant = Tenant.query.get(tenant_id)
        if tenant is None:
            raise Exception(f'Tenant {tenant_id} not found')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        # Retrieve document version to process
        document_version = DocumentVersion.query.get(document_version_id)
        if document_version is None:
            raise Exception(f'Document version {document_version_id} not found')

        # Retrieve the Catalog. Plain .get() + raise instead of get_or_404:
        # aborting with an HTTP 404 response makes no sense inside a Celery
        # worker, and this matches the lookup style used above.
        doc = Document.query.get(document_version.doc_id)
        if doc is None:
            raise Exception(f'Document {document_version.doc_id} not found')
        catalog_id = doc.catalog_id
        catalog = Catalog.query.get(catalog_id)
        if catalog is None:
            raise Exception(f'Catalog {catalog_id} not found')

        # Select variables to work with depending on tenant and model
        model_variables = get_model_variables(tenant_id)

        # Define processor related information
        processor_type, processor_class = ProcessorRegistry.get_processor_for_file_type(document_version.file_type)
        processor = get_processor_for_document(catalog_id, document_version.file_type,
                                               document_version.sub_file_type)
    except Exception as e:
        current_app.logger.error(f'Create Embeddings request received '
                                 f'for badly configured document version {document_version_id} '
                                 f'for tenant {tenant_id}, '
                                 f'error: {e}')
        raise

    # BusinessEvent creates a context, which is why we need to use it with a with block
    with BusinessEvent('Create Embeddings', tenant_id,
                       document_version_id=document_version_id,
                       document_version_file_size=document_version.file_size):
        current_app.logger.info(f'Creating embeddings for tenant {tenant_id} on document version {document_version_id}')
        try:
            db.session.add(document_version)
            # start processing: mark the version in-flight and clear any
            # state left over from a previous run
            document_version.processing = True
            document_version.processing_started_at = dt.now(tz.utc)
            document_version.processing_finished_at = None
            document_version.processing_error = None
            db.session.commit()
        except SQLAlchemyError as e:
            current_app.logger.error(f'Unable to save Embedding status information '
                                     f'in document version {document_version_id} '
                                     f'for tenant {tenant_id}')
            raise

        # Re-processing a version replaces any embeddings from earlier runs
        delete_embeddings_for_document_version(document_version)

        try:
            with current_event.create_span(f"{processor_type} Processing"):
                document_processor = processor_class(
                    tenant=tenant,
                    model_variables=model_variables,
                    document_version=document_version,
                    catalog=catalog,
                    processor=processor,
                )
                markdown, title = document_processor.process()
            with current_event.create_span("Embedding"):
                embed_markdown(tenant, model_variables, document_version, catalog, markdown, title)
            current_event.log("Finished Embedding Creation Task")
        except Exception as e:
            current_app.logger.error(f'Error creating embeddings for tenant {tenant_id} '
                                     f'on document version {document_version_id} '
                                     f'error: {e}')
            # Persist the failure on the version so callers/UI can surface it
            document_version.processing = False
            document_version.processing_finished_at = dt.now(tz.utc)
            document_version.processing_error = str(e)[:255]  # truncate to fit the column
            db.session.commit()
            create_embeddings.update_state(state=states.FAILURE)
            raise


def delete_embeddings_for_document_version(document_version):
    """Delete all existing embeddings attached to *document_version*.

    Raises:
        SQLAlchemyError: if the delete cannot be committed.
    """
    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
    for embedding in embeddings_to_delete:
        db.session.delete(embedding)
    try:
        db.session.commit()
        current_app.logger.info(f'Deleted embeddings for document version {document_version.id}')
    except SQLAlchemyError as e:
        current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}')
        raise


def embed_markdown(tenant, model_variables, document_version, catalog, markdown, title):
    """Chunk, enrich and embed the markdown rendition of a document version.

    Splits the stored markdown into header-based chunks, merges them into
    catalog-configured size bounds, prefixes each chunk with contextual
    metadata, embeds them, and persists the resulting Embedding rows while
    marking the version as finished.

    Raises:
        SQLAlchemyError: if the embeddings cannot be saved.
    """
    # Create potential chunks (note: the markdown argument is not used here;
    # the splitter re-downloads the stored .md file from MinIO)
    potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version,
                                                           f"{document_version.id}.md")
    # Combine chunks for embedding
    chunks = combine_chunks_for_markdown(potential_chunks, catalog.min_chunk_size, catalog.max_chunk_size)

    # Enrich chunks
    with current_event.create_span("Enrich Chunks"):
        enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)

    # Create embeddings
    with current_event.create_span("Create Embeddings"):
        embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)

    # Update document version and save embeddings
    try:
        db.session.add(document_version)
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing = False
        db.session.add_all(embeddings)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
                                 f'on HTML, document version {document_version.id}'
                                 f'error: {e}')
        raise

    current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
                            f'on document version {document_version.id} :-)')


def enrich_chunks(tenant, model_variables, document_version, title, chunks):
    """Prefix each chunk with document context so embeddings carry metadata.

    The first chunk gets filename/context/metadata/title; subsequent chunks
    additionally get an LLM-generated summary of the first chunk so they can
    stand alone at retrieval time.

    Returns:
        list[str]: the enriched chunks, in input order.
    """
    # Guard: nothing to enrich for an empty document (the original code
    # would raise IndexError on chunks[0] here)
    if not chunks:
        return []

    summary = ''
    if len(chunks) > 1:
        # Only pay for a summary when there are follow-up chunks to enrich
        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])

    chunk_total_context = (f'Filename: {document_version.object_name}\n'
                           f'User Context:\n{document_version.user_context}\n\n'
                           f'User Metadata:\n{document_version.user_metadata}\n\n'
                           f'Title: {title}\n'
                           f'Summary:\n{summary}\n'
                           f'System Context:\n{document_version.system_context}\n\n'
                           f'System Metadata:\n{document_version.system_metadata}\n\n'
                           )

    enriched_chunks = []
    # First chunk carries the full context but no summary (it IS the summary source)
    initial_chunk = (f'Filename: {document_version.object_name}\n'
                     f'User Context:\n{document_version.user_context}\n\n'
                     f'User Metadata:\n{document_version.user_metadata}\n\n'
                     f'Title: {title}\n'
                     f'System Context:\n{document_version.system_context}\n\n'
                     f'System Metadata:\n{document_version.system_metadata}\n\n'
                     f'{chunks[0]}'
                     )
    enriched_chunks.append(initial_chunk)

    for chunk in chunks[1:]:
        enriched_chunk = f'{chunk_total_context}\n{chunk}'
        enriched_chunks.append(enriched_chunk)

    return enriched_chunks


def summarize_chunk(tenant, model_variables, document_version, chunk):
    """Summarize *chunk* with the tenant's LLM using the 'summary' template.

    Returns:
        str: the generated summary.

    Raises:
        LangChainException: if the LLM call fails.
    """
    current_event.log("Starting Summarizing Chunk")
    llm = model_variables.get_llm()
    template = model_variables.get_template("summary")
    # Localize the prompt to the document's language
    language_template = create_language_template(template, document_version.language)
    summary_prompt = ChatPromptTemplate.from_template(language_template)
    setup = RunnablePassthrough()
    output_parser = StrOutputParser()

    chain = setup | summary_prompt | llm | output_parser
    try:
        summary = chain.invoke({"text": chunk})
        current_event.log("Finished Summarizing Chunk")
        return summary
    except LangChainException as e:
        current_app.logger.error(f'Error creating summary for chunk enrichment for tenant {tenant.id} '
                                 f'on document version {document_version.id} '
                                 f'error: {e}')
        raise


def embed_chunks(tenant, model_variables, document_version, chunks):
    """Embed *chunks* and build (unsaved) Embedding model instances.

    Returns:
        list: new embedding ORM objects, one per chunk, not yet added to
        the session.

    Raises:
        LangChainException: if the embedding API call fails.
    """
    embedding_model = model_variables.embedding_model
    try:
        embeddings = embedding_model.embed_documents(chunks)
    except LangChainException as e:
        current_app.logger.error(f'Error creating embeddings for tenant {tenant.id} '
                                 f'on document version {document_version.id} while calling OpenAI API'
                                 f'error: {e}')
        raise

    # Add embeddings to the database (objects only; caller commits)
    new_embeddings = []
    for chunk, embedding in zip(chunks, embeddings):
        new_embedding = model_variables.embedding_model_class()
        new_embedding.document_version = document_version
        new_embedding.active = True
        new_embedding.chunk = chunk
        new_embedding.embedding = embedding
        new_embeddings.append(new_embedding)

    return new_embeddings


def create_potential_chunks_for_markdown(tenant_id, document_version, input_file):
    """Split the stored markdown rendition of a version into header chunks.

    Downloads `<object_name stem>.md` from MinIO and splits it on H1/H2
    headers (headers are kept in the chunk text).

    Note: *input_file* is currently unused — the object name is derived from
    document_version.object_name; kept for interface compatibility.

    Returns:
        list[str]: one string per header section.
    """
    try:
        current_app.logger.info(f'Creating potential chunks for tenant {tenant_id}')
        markdown_on = document_version.object_name.rsplit('.', 1)[0] + '.md'
        # Download the markdown file from MinIO
        markdown_data = minio_client.download_document_file(tenant_id,
                                                            document_version.bucket_name,
                                                            markdown_on,
                                                            )
        markdown = markdown_data.decode('utf-8')
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
        ]
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False)
        md_header_splits = markdown_splitter.split_text(markdown)
        potential_chunks = [doc.page_content for doc in md_header_splits]
        return potential_chunks
    except Exception as e:
        current_app.logger.error(f'Error creating potential chunks for tenant {tenant_id}, with error: {e}')
        raise


def combine_chunks_for_markdown(potential_chunks, min_chars, max_chars):
    """Greedily merge header sections into chunks bounded by size limits.

    Sections are concatenated (newline-separated) until adding the next one
    would exceed *max_chars*; a chunk is only emitted once it has reached
    *min_chars*, so a chunk may exceed max_chars rather than fall below the
    minimum. The trailing chunk is always emitted.

    Args:
        potential_chunks: header-based sections in document order.
        min_chars: minimum chunk size before a split is allowed.
        max_chars: soft maximum chunk size.

    Returns:
        list[str]: the combined chunks.
    """
    def _join(base, addition):
        # Avoid the spurious leading '\n' the old `current_chunk += f'\n{chunk}'`
        # produced on every chunk's first section.
        return addition if not base else f'{base}\n{addition}'

    actual_chunks = []
    current_chunk = ""
    current_length = 0

    for chunk in potential_chunks:
        chunk_length = len(chunk)
        if current_length + chunk_length > max_chars:
            if current_length >= min_chars:
                # Big enough: emit and start a fresh chunk with this section
                actual_chunks.append(current_chunk)
                current_chunk = chunk
                current_length = chunk_length
            else:
                # Still below min_chars: keep growing even past max_chars
                current_chunk = _join(current_chunk, chunk)
                current_length += chunk_length
        else:
            current_chunk = _join(current_chunk, chunk)
            current_length += chunk_length

    # Handle the last chunk. The original guarded on `current_length >= 0`,
    # which is always true — possibly min_chars was intended, but we keep the
    # permissive behavior so no trailing content is ever dropped.
    if current_chunk:
        actual_chunks.append(current_chunk)

    return actual_chunks


def get_processor_for_document(catalog_id: int, file_type: str, sub_file_type: str = None) -> Processor:
    """
    Get the appropriate processor for a document based on catalog_id, file_type and optional sub_file_type.

    Args:
        catalog_id: ID of the catalog
        file_type: Type of file (e.g., 'pdf', 'html')
        sub_file_type: Optional sub-type for specialized processing

    Returns:
        Processor instance

    Raises:
        ValueError: If no matching processor is found
    """
    try:
        # Start with base query for catalog
        query = Processor.query.filter_by(catalog_id=catalog_id)

        # Find processor type that handles this file type
        matching_processor_type = None
        for proc_type, config in PROCESSOR_TYPES.items():
            supported_types = config['file_types']
            # file_types may be declared as a comma-separated string
            if isinstance(supported_types, str):
                supported_types = [t.strip() for t in supported_types.split(',')]
            if file_type in supported_types:
                matching_processor_type = proc_type
                break

        if not matching_processor_type:
            raise ValueError(f"No processor type found for file type: {file_type}")

        # Add processor type condition
        query = query.filter_by(type=matching_processor_type)

        # If sub_file_type is provided, add that condition
        if sub_file_type:
            query = query.filter_by(sub_file_type=sub_file_type)
        else:
            # If no sub_file_type, prefer processors without sub_file_type specification
            query = query.filter(or_(Processor.sub_file_type.is_(None),
                                     Processor.sub_file_type == ''))

        # Get the first matching processor
        processor = query.first()

        if not processor:
            if sub_file_type:
                raise ValueError(
                    f"No processor found for catalog {catalog_id} of type {matching_processor_type}, "
                    f"file type {file_type}, sub-type {sub_file_type}"
                )
            else:
                raise ValueError(
                    f"No processor found for catalog {catalog_id}, "
                    f"file type {file_type}"
                )

        return processor

    except Exception as e:
        current_app.logger.error(f"Error finding processor: {str(e)}")
        raise