import re
from datetime import datetime as dt, timezone as tz
from typing import Optional

from celery import states
from flask import current_app

# LangChain imports
from langchain.text_splitter import MarkdownHeaderTextSplitter
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError

from common.extensions import db
from common.models.document import DocumentVersion, Embedding, Document, Processor, Catalog
from common.models.user import Tenant
from common.utils.celery_utils import current_celery
from common.utils.database import Database
from common.utils.model_utils import create_language_template, get_model_variables, get_embedding_model_and_class, \
    get_embedding_llm
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
from config.type_defs.processor_types import PROCESSOR_TYPES
from eveai_workers.processors.processor_registry import ProcessorRegistry
from common.utils.eveai_exceptions import EveAIInvalidEmbeddingModel
from common.utils.config_field_types import json_to_pattern_list


# Healthcheck task
@current_celery.task(name='ping', queue='embeddings')
def ping():
    return 'pong'


@current_celery.task(name='create_embeddings', queue='embeddings')
def create_embeddings(tenant_id, document_version_id):
    document_version = None
    try:
        # Retrieve the tenant for which we are processing
        tenant = Tenant.query.get(tenant_id)
        if tenant is None:
            raise Exception(f'Tenant {tenant_id} not found')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        # Retrieve the document version to process
        document_version = DocumentVersion.query.get(document_version_id)
        if document_version is None:
            raise Exception(f'Document version {document_version_id} not found')

        # Retrieve the catalog the document belongs to
        doc = Document.query.get_or_404(document_version.doc_id)
        catalog_id = doc.catalog_id
        catalog = Catalog.query.get_or_404(catalog_id)

        # Select variables to work with depending on tenant and model
        model_variables = get_model_variables(tenant_id)

        # Define processor-related information
        processor_type, processor_class = ProcessorRegistry.get_processor_for_file_type(document_version.file_type)
        processor = get_processor_for_document(catalog_id, document_version.file_type, document_version.sub_file_type)
    except Exception as e:
        current_app.logger.error(f'Create Embeddings request received '
                                 f'for badly configured document version {document_version_id} '
                                 f'for tenant {tenant_id}, '
                                 f'error: {e}')
        if document_version:
            document_version.processing_error = str(e)
            db.session.commit()  # persist the error on the document version before re-raising
        raise

    # BusinessEvent creates a context, which is why we need to use it in a with block
    with BusinessEvent('Create Embeddings', tenant_id,
                       document_version_id=document_version_id,
                       document_version_file_size=document_version.file_size):
        current_app.logger.info(f'Creating embeddings for tenant {tenant_id} '
                                f'on document version {document_version_id}')
        try:
            db.session.add(document_version)
            # Start processing
            document_version.processing = True
            document_version.processing_started_at = dt.now(tz.utc)
            document_version.processing_finished_at = None
            document_version.processing_error = None
            db.session.commit()
        except SQLAlchemyError as e:
            current_app.logger.error(f'Unable to save Embedding status information '
                                     f'in document version {document_version_id} '
                                     f'for tenant {tenant_id}, '
                                     f'error: {e}')
            raise

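        # Re-embedding an existing document version: clear out embeddings from any previous run first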
        delete_embeddings_for_document_version(document_version)

        try:
            with current_event.create_span(f"{processor_type} Processing"):
                document_processor = processor_class(
                    tenant=tenant,
                    model_variables=model_variables,
                    document_version=document_version,
                    catalog=catalog,
                    processor=processor
                )
                markdown, title = document_processor.process()
                document_processor.log_tuning("Processor returned: ", {
                    'markdown': markdown,
                    'title': title
                })
            with current_event.create_span("Embedding"):
                embed_markdown(tenant, model_variables, document_version, catalog, document_processor,
                               markdown, title)
            current_event.log("Finished Embedding Creation Task")
        except Exception as e:
            current_app.logger.error(f'Error creating embeddings for tenant {tenant_id} '
                                     f'on document version {document_version_id}, '
                                     f'error: {e}')
            document_version.processing = False
            document_version.processing_finished_at = dt.now(tz.utc)
            document_version.processing_error = str(e)[:255]
            db.session.commit()
            create_embeddings.update_state(state=states.FAILURE)
            raise


def delete_embeddings_for_document_version(document_version):
    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
    for embedding in embeddings_to_delete:
        db.session.delete(embedding)
    try:
        db.session.commit()
        current_app.logger.info(f'Deleted embeddings for document version {document_version.id}')
    except SQLAlchemyError as e:
        current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}, '
                                 f'error: {e}')
        raise


def embed_markdown(tenant, model_variables, document_version, catalog, processor, markdown, title):
    # Create potential chunks
    potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, processor, markdown)
    processor.log_tuning("Potential Chunks: ", {'potential chunks': potential_chunks})

    # Combine chunks for embedding
    chunks = combine_chunks_for_markdown(potential_chunks, catalog.min_chunk_size, catalog.max_chunk_size, processor)
    processor.log_tuning("Chunks: ", {'chunks': chunks})

    # Enrich chunks
    with current_event.create_span("Enrich Chunks"):
        enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)
        processor.log_tuning("Enriched Chunks: ", {'enriched_chunks': enriched_chunks})

    # Create embeddings
    with current_event.create_span("Create Embeddings"):
        embeddings = embed_chunks(tenant, catalog, document_version, enriched_chunks)

    # Update document version and save embeddings
    try:
        db.session.add(document_version)
        document_version.processing_finished_at = dt.now(tz.utc)
        document_version.processing = False
        db.session.add_all(embeddings)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
                                 f'on document version {document_version.id}, '
                                 f'error: {e}')
        raise

    current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
                            f'on document version {document_version.id} :-)')


def enrich_chunks(tenant, model_variables, document_version, title, chunks):
    summary = ''
    if len(chunks) > 1:
        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
    chunk_total_context = (f'Filename: {document_version.object_name}\n'
                           f'User Context:\n{document_version.user_context}\n\n'
                           f'User Metadata:\n{document_version.user_metadata}\n\n'
                           f'Title: {title}\n'
                           f'Summary:\n{summary}\n'
                           f'System Context:\n{document_version.system_context}\n\n'
                           f'System Metadata:\n{document_version.system_metadata}\n\n')
    enriched_chunks = []
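    # The first chunk gets the full context inline but no summary, since the
    # summary itself is derived from that chunk; all later chunks also get the summary.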
    initial_chunk = (f'Filename: {document_version.object_name}\n'
                     f'User Context:\n{document_version.user_context}\n\n'
                     f'User Metadata:\n{document_version.user_metadata}\n\n'
                     f'Title: {title}\n'
                     f'System Context:\n{document_version.system_context}\n\n'
                     f'System Metadata:\n{document_version.system_metadata}\n\n'
                     f'{chunks[0]}')
    enriched_chunks.append(initial_chunk)
    for chunk in chunks[1:]:
        enriched_chunk = f'{chunk_total_context}\n{chunk}'
        enriched_chunks.append(enriched_chunk)
    return enriched_chunks


def summarize_chunk(tenant, model_variables, document_version, chunk):
    current_event.log("Starting Summarizing Chunk")
    llm = get_embedding_llm()
    template = model_variables.get_template("summary")
    language_template = create_language_template(template, document_version.language)
    summary_prompt = ChatPromptTemplate.from_template(language_template)
    setup = RunnablePassthrough()
    output_parser = StrOutputParser()
    chain = setup | summary_prompt | llm | output_parser
    try:
        summary = chain.invoke({"text": chunk})
        current_event.log("Finished Summarizing Chunk")
        return summary
    except LangChainException as e:
        current_app.logger.error(f'Error creating summary for chunk enrichment for tenant {tenant.id} '
                                 f'on document version {document_version.id}, '
                                 f'error: {e}')
        raise


def embed_chunks(tenant, catalog, document_version, chunks):
    embedding_model, embedding_model_class = get_embedding_model_and_class(tenant.id, catalog.id)

    # Actually embed
    try:
        embeddings = embedding_model.embed_documents(chunks)
    except LangChainException as e:
        current_app.logger.error(f'Error creating embeddings for tenant {tenant.id} '
                                 f'on document version {document_version.id} while calling the embedding API, '
                                 f'error: {e}')
        raise

    # Build embedding records for the database
    new_embeddings = []
    for chunk, embedding in zip(chunks, embeddings):
        new_embedding = embedding_model_class()
        new_embedding.document_version = document_version
        new_embedding.active = True
        new_embedding.chunk = chunk
        new_embedding.embedding = embedding
        new_embeddings.append(new_embedding)
    return new_embeddings


def create_potential_chunks_for_markdown(tenant_id, document_version, processor, markdown):
    try:
        current_app.logger.info(f'Creating potential chunks for tenant {tenant_id}')
        heading_level = processor.configuration.get('chunking_heading_level', 2)
        headers_to_split_on = [
            ('#' * i, f"Header {i}") for i in range(1, min(heading_level + 1, 7))
        ]
        processor.log_tuning('Headers to split on', {'header list': headers_to_split_on})
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False)
        md_header_splits = markdown_splitter.split_text(markdown)
        potential_chunks = [doc.page_content for doc in md_header_splits]
        return potential_chunks
    except Exception as e:
        current_app.logger.error(f'Error creating potential chunks for tenant {tenant_id}, with error: {e}')
        raise


def combine_chunks_for_markdown(potential_chunks, min_chars, max_chars, processor):
    actual_chunks = []
    current_chunk = ""
    current_length = 0

    def matches_chunking_pattern(text, patterns):
        if not patterns:
            return False
        # Get the first line of the text
        first_line = text.split('\n', 1)[0].strip()
        # Check if it's a header at an appropriate level
        header_match = re.match(r'^(#{1,6})\s+(.+)$', first_line)
        if not header_match:
            return False
        # Get the heading level (number of #s)
        header_level = len(header_match.group(1))
        # Get the header text
        header_text = header_match.group(2)
        # Check if the header matches any pattern
        for pattern in patterns:
            try:
                processor.log_tuning('Pattern check: ', {
                    'pattern': pattern,
                    'text': header_text
                })
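                # A match on any configured pattern forces a split at this header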
                if re.search(pattern, header_text, re.IGNORECASE):
                    return True
            except Exception as e:
                current_app.logger.warning(f"Invalid regex pattern '{pattern}': {str(e)}")
                continue
        return False

    chunking_patterns = json_to_pattern_list(processor.configuration.get('chunking_patterns', ""))
    processor.log_tuning('Chunking Patterns Extraction: ', {
        'Full Configuration': processor.configuration,
        'Chunking Patterns': chunking_patterns,
    })

    for chunk in potential_chunks:
        chunk_length = len(chunk)

        # Force a new chunk if a chunking pattern matches
        if chunking_patterns and matches_chunking_pattern(chunk, chunking_patterns):
            if current_chunk and current_length >= min_chars:
                actual_chunks.append(current_chunk)
            current_chunk = chunk
            current_length = chunk_length
            continue

        if current_length + chunk_length > max_chars:
            if current_length >= min_chars:
                actual_chunks.append(current_chunk)
                current_chunk = chunk
                current_length = chunk_length
            else:
                # The current chunk is still below min_chars, so keep adding even beyond max_chars
                current_chunk += f'\n{chunk}'
                current_length += chunk_length
        else:
            current_chunk += f'\n{chunk}'
            current_length += chunk_length

    # Handle the last chunk
    if current_chunk:
        actual_chunks.append(current_chunk)

    return actual_chunks


def get_processor_for_document(catalog_id: int, file_type: str, sub_file_type: Optional[str] = None) -> Processor:
    """
    Get the appropriate processor for a document based on catalog_id, file_type and optional sub_file_type.

    Args:
        catalog_id: ID of the catalog
        file_type: Type of file (e.g., 'pdf', 'html')
        sub_file_type: Optional sub-type for specialized processing

    Returns:
        Processor instance

    Raises:
        ValueError: If no matching processor is found
    """
    try:
        # Start with the base query for the catalog
        query = Processor.query.filter_by(catalog_id=catalog_id)

        # Find the processor type that handles this file type
        matching_processor_type = None
        for proc_type, config in PROCESSOR_TYPES.items():
            supported_types = config['file_types']
            if isinstance(supported_types, str):
                supported_types = [t.strip() for t in supported_types.split(',')]
            if file_type in supported_types:
                matching_processor_type = proc_type
                break

        if not matching_processor_type:
            raise ValueError(f"No processor type found for file type: {file_type}")

        # Add the processor type condition
        query = query.filter_by(type=matching_processor_type)

        # If a sub_file_type is provided, add that condition
        if sub_file_type:
            query = query.filter_by(sub_file_type=sub_file_type)
        else:
            # If no sub_file_type, prefer processors without a sub_file_type specification
            query = query.filter(or_(Processor.sub_file_type.is_(None), Processor.sub_file_type == ''))

        # Get the first matching processor
        processor = query.first()

        if not processor:
            if sub_file_type:
                raise ValueError(
                    f"No processor found for catalog {catalog_id} of type {matching_processor_type}, "
                    f"file type {file_type}, sub-type {sub_file_type}"
                )
            else:
                raise ValueError(
                    f"No processor found for catalog {catalog_id}, "
                    f"file type {file_type}"
                )

        return processor

    except Exception as e:
        current_app.logger.error(f"Error finding processor: {str(e)}")
        raise
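

# Example (hypothetical) of how this task might be enqueued from application code,
# assuming a configured Celery app; the task name and queue match the registration above:
#
#   current_celery.send_task('create_embeddings',
#                            args=[tenant_id, document_version_id],
#                            queue='embeddings')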