Remove ModelVariables (model_utils) from application & optimize Tenant

Josako
2025-05-20 10:17:08 +02:00
parent 70de4c0328
commit d789e431ca
17 changed files with 83 additions and 206 deletions

View File

@@ -24,16 +24,12 @@ class Tenant(db.Model):
     name = db.Column(db.String(80), unique=True, nullable=False)
     website = db.Column(db.String(255), nullable=True)
     timezone = db.Column(db.String(50), nullable=True, default='UTC')
-    rag_context = db.Column(db.Text, nullable=True)
     type = db.Column(db.String(20), nullable=True, server_default='Active')
     # language information
     default_language = db.Column(db.String(2), nullable=True)
     allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
-    # LLM specific choices
-    llm_model = db.Column(db.String(50), nullable=True)
     # Entitlements
     currency = db.Column(db.String(20), nullable=True)
     storage_dirty = db.Column(db.Boolean, nullable=True, default=False)
@@ -62,11 +58,9 @@ class Tenant(db.Model):
             'name': self.name,
             'website': self.website,
             'timezone': self.timezone,
-            'rag_context': self.rag_context,
             'type': self.type,
             'default_language': self.default_language,
             'allowed_languages': self.allowed_languages,
-            'llm_model': self.llm_model,
             'currency': self.currency,
         }

View File

@@ -139,142 +139,26 @@ def process_pdf():
     full_model_name = 'mistral-ocr-latest'


-def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
+def get_template(template_name: str, version: Optional[str] = "1.0", temperature: float = 0.3) -> tuple[
         Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
     """
     Get a prompt template
     """
     prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
     if "llm_model" in prompt:
-        llm = get_embedding_llm(full_model_name=prompt["llm_model"])
+        llm = get_embedding_llm(full_model_name=prompt["llm_model"], temperature=temperature)
     else:
-        llm = get_embedding_llm()
+        llm = get_embedding_llm(temperature=temperature)
     return prompt["content"], llm


-class ModelVariables:
-    """Manages model-related variables and configurations"""
-
-    def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
-        """
-        Initialize ModelVariables with tenant and optional template manager
-        Args:
-            tenant_id: Tenant instance
-            variables: Optional variables
-        """
-        current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
-        self.tenant_id = tenant_id
-        self._variables = variables if variables is not None else self._initialize_variables()
-        current_app.logger.info(f'Model _variables initialized to {self._variables}')
-        self._llm_instances = {}
-        self.llm_metrics_handler = LLMMetricsHandler()
-        self._transcription_model = None
-
-    def _initialize_variables(self) -> Dict[str, Any]:
-        """Initialize the variables dictionary"""
-        variables = {}
-        tenant = Tenant.query.get(self.tenant_id)
-        if not tenant:
-            raise EveAITenantNotFound(self.tenant_id)
-
-        # Set model providers
-        variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
-        variables['llm_full_model'] = tenant.llm_model
-
-        # Set model-specific configurations
-        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
-        variables.update(model_config)
-
-        # Additional configurations
-        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
-        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
-        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
-        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
-        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
-        return variables
-
-    @property
-    def annotation_chunk_length(self):
-        return self._variables['annotation_chunk_length']
-
-    @property
-    def max_compression_duration(self):
-        return self._variables['max_compression_duration']
-
-    @property
-    def max_transcription_duration(self):
-        return self._variables['max_transcription_duration']
-
-    @property
-    def compression_cpu_limit(self):
-        return self._variables['compression_cpu_limit']
-
-    @property
-    def compression_process_delay(self):
-        return self._variables['compression_process_delay']
-
-    def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
-        """
-        Get an LLM instance with specific configuration
-        Args:
-            temperature: The temperature for the LLM
-            **kwargs: Additional configuration parameters
-        Returns:
-            An instance of the configured LLM
-        """
-        cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
-        if cache_key not in self._llm_instances:
-            provider = self._variables['llm_provider']
-            model = self._variables['llm_model']
-            if provider == 'openai':
-                self._llm_instances[cache_key] = ChatOpenAI(
-                    api_key=os.getenv('OPENAI_API_KEY'),
-                    model=model,
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            elif provider == 'anthropic':
-                self._llm_instances[cache_key] = ChatAnthropic(
-                    api_key=os.getenv('ANTHROPIC_API_KEY'),
-                    model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            else:
-                raise ValueError(f"Unsupported LLM provider: {provider}")
-        return self._llm_instances[cache_key]
-
-    @property
-    def transcription_model(self) -> TrackedOpenAITranscription:
-        """Get the transcription model instance"""
-        if self._transcription_model is None:
-            api_key = os.getenv('OPENAI_API_KEY')
-            self._transcription_model = TrackedOpenAITranscription(
-                api_key=api_key,
-                model='whisper-1'
-            )
-        return self._transcription_model
-
-    # Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
-    @property
-    def transcription_client(self):
-        raise DeprecationWarning("Use transcription_model instead")
-
-    def transcribe(self, *args, **kwargs):
-        raise DeprecationWarning("Use transcription_model.transcribe() instead")
-
-
-# Helper function to get cached model variables
-def get_model_variables(tenant_id: int) -> ModelVariables:
-    return ModelVariables(tenant_id=tenant_id)
+def get_transcription_model(model_name: str = "whisper-1") -> TrackedOpenAITranscription:
+    """
+    Get a transcription model instance
+    """
+    api_key = os.getenv('OPENAI_API_KEY')
+    return TrackedOpenAITranscription(
+        api_key=api_key,
+        model=model_name
+    )
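A call site that previously went through ModelVariables now collapses to this pattern (a minimal sketch assembled from the signatures in this hunk; the prompt name "rag" and the temperature value are illustrative, taken from the specialist changes further down):

    # Prompt text and a configured LLM come from a single helper; the
    # provider/model choice now lives in the prompt config, not on the Tenant.
    template, llm = get_template("rag", temperature=0.3)

    # Transcription likewise no longer hangs off ModelVariables:
    transcription_model = get_transcription_model()  # defaults to 'whisper-1'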

View File

@@ -81,11 +81,7 @@ class Config(object):
     ANTHROPIC_LLM_VERSIONS = {'claude-3-5-sonnet': 'claude-3-5-sonnet-20240620', }

     # Annotation text chunk length
-    ANNOTATION_TEXT_CHUNK_LENGTH = {
-        'openai.gpt-4o': 10000,
-        'openai.gpt-4o-mini': 10000,
-        'anthropic.claude-3-5-sonnet': 8000
-    }
+    ANNOTATION_TEXT_CHUNK_LENGTH = 10000

     # Environemnt Loaders
     OPENAI_API_KEY = environ.get('OPENAI_API_KEY')
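Note that this collapses the per-model mapping into a single value, dropping the smaller 8000-character limit previously applied to claude-3-5-sonnet. Consumers now read the scalar directly, as the transcription processor below does:

    self.annotation_chunk_size = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']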

View File

@@ -10,7 +10,7 @@
     <form method="post">
         {{ form.hidden_tag() }}
         <!-- Main Tenant Information -->
-        {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'rag_context', 'type'] %}
+        {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'type'] %}
         {% for field in form %}
             {{ render_included_field(field, disabled_fields=main_fields, include_fields=main_fields) }}
         {% endfor %}
@@ -20,11 +20,6 @@
     <div class="col-lg-12">
         <div class="nav-wrapper position-relative end-0">
             <ul class="nav nav-pills nav-fill p-1" role="tablist">
-                <li class="nav-item" role="presentation">
-                    <a class="nav-link mb-0 px-0 py-1 active" data-toggle="tab" href="#model-info-tab" role="tab" aria-controls="model-info" aria-selected="true">
-                        Model Information
-                    </a>
-                </li>
                 <li class="nav-item">
-                    <a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#license-info-tab" role="tab" aria-controls="license-info" aria-selected="false">
+                    <a class="nav-link mb-0 px-0 py-1 active" data-toggle="tab" href="#license-info-tab" role="tab" aria-controls="license-info" aria-selected="true">
                         License Information
@@ -33,13 +28,6 @@
             </ul>
         </div>
         <div class="tab-content tab-space">
-            <!-- Model Information Tab -->
-            <div class="tab-pane fade show active" id="model-info-tab" role="tabpanel">
-                {% set model_fields = ['llm_model'] %}
-                {% for field in form %}
-                    {{ render_included_field(field, disabled_fields=model_fields, include_fields=model_fields) }}
-                {% endfor %}
-            </div>
             <!-- License Information Tab -->
-            <div class="tab-pane fade" id="license-info-tab" role="tabpanel">
+            <div class="tab-pane fade show active" id="license-info-tab" role="tabpanel">
                 {% set license_fields = ['currency', 'usage_email', ] %}

View File

@@ -22,8 +22,6 @@ class TenantForm(FlaskForm):
     currency = SelectField('Currency', choices=[], validators=[DataRequired()])
     # Timezone
     timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
-    # LLM fields
-    llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
     # For Super Users only - Allow to assign the tenant to the partner
     assign_to_partner = BooleanField('Assign to Partner', default=False)

View File

@@ -177,7 +177,6 @@ def register_chat_session_cache_handlers(cache_manager):
     cache_manager.register_handler(ChatSessionCacheHandler, 'eveai_chat_workers')


-# Helper function similar to get_model_variables
 def get_chat_history(session_id: str) -> CachedSession:
     """
     Get cached chat history for a session, loading from database if needed

View File

@@ -12,7 +12,8 @@ from common.utils.business_event_context import current_event
 from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
 from eveai_chat_workers.chat_session_cache import get_chat_history
 from common.models.interaction import Specialist
-from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
+from common.utils.model_utils import create_language_template, replace_variable_in_template, \
+    get_template
 from eveai_chat_workers.specialists.base_specialist import BaseSpecialistExecutor
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
@@ -38,9 +39,6 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()

-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
     @property
     def type(self) -> str:
         return "STANDARD_RAG_SPECIALIST"
@@ -68,8 +66,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         ])

         # Get LLM and template
-        llm = self.model_variables.get_llm(temperature=0.3)
-        template = self.model_variables.get_template('history')
+        template, llm = get_template("history", temperature=0.3)
         language_template = create_language_template(template, language)

         # Create prompt
@@ -179,11 +176,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         with current_event.create_span("Specialist RAG invocation"):
             try:
                 self.update_progress(self.task_id, "EveAI Chain Start", {})
-                # Get LLM with specified temperature
-                llm = self.model_variables.get_llm(temperature=self.temperature)
-                # Get template
-                template = self.model_variables.get_template('rag')
+                template, llm = get_template("rag", temperature=self.temperature)
                 language_template = create_language_template(template, language)
                 full_template = replace_variable_in_template(
                     language_template,

View File

@@ -8,7 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate
 from common.models.interaction import Specialist
 from common.utils.business_event_context import current_event
-from common.utils.model_utils import get_model_variables, get_crewai_llm, create_language_template
+from common.utils.model_utils import get_crewai_llm, create_language_template, get_template
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
 from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAIAgent, EveAICrewAITask
 from crewai.tools import BaseTool
@@ -41,9 +41,6 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()

-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
         # initialize the Flow
         self.flow = None
@@ -212,8 +209,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         try:
             with current_event.create_span("Specialist Detail Question"):
                 # Get LLM and template
-                llm = self.model_variables.get_llm(temperature=0.3)
-                template = cache_manager.prompts_config_cache.get_config('history').get('content', '')
+                template, llm = get_template("history", temperature=0.3)
                 language_template = create_language_template(template, language)

                 # Create prompt

View File

@@ -8,20 +8,23 @@ import tempfile
 from common.extensions import minio_client
 import subprocess

+from flask import current_app
+
+from common.utils.model_utils import get_transcription_model
 from .processor_registry import ProcessorRegistry
 from .transcription_processor import TranscriptionBaseProcessor
 from common.utils.business_event_context import current_event


 class AudioProcessor(TranscriptionBaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.transcription_model = model_variables.transcription_model
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.transcription_model = get_transcription_model()
         self.ffmpeg_path = 'ffmpeg'
-        self.max_compression_duration = model_variables.max_compression_duration
-        self.max_transcription_duration = model_variables.max_transcription_duration
-        self.compression_cpu_limit = model_variables.compression_cpu_limit  # CPU usage limit in percentage
-        self.compression_process_delay = model_variables.compression_process_delay  # Delay between processing chunks in seconds
+        self.max_compression_duration = current_app.config['MAX_COMPRESSION_DURATION']
+        self.max_transcription_duration = current_app.config['MAX_TRANSCRIPTION_DURATION']
+        self.compression_cpu_limit = current_app.config['COMPRESSION_CPU_LIMIT']  # CPU usage limit in percentage
+        self.compression_process_delay = current_app.config['COMPRESSION_PROCESS_DELAY']  # Delay between processing chunks in seconds
         self.file_type = document_version.file_type

     def _get_transcription(self):
@@ -154,7 +157,7 @@ class AudioProcessor(TranscriptionBaseProcessor):
                 file_size = os.path.getsize(temp_audio.name)

                 with open(temp_audio.name, 'rb') as audio_file:
-                    transcription = self.model_variables.transcription_model.transcribe(
+                    transcription = self.transcription_model.transcribe(
                         file=audio_file,
                         language=self.document_version.language,
                         response_format='verbose_json',

View File

@@ -7,9 +7,8 @@ from config.logging_config import TuningLogger


 class BaseProcessor(ABC):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
+    def __init__(self, tenant, document_version, catalog, processor):
         self.tenant = tenant
-        self.model_variables = model_variables
         self.document_version = document_version
         self.catalog = catalog
         self.processor = processor
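Every processor subclass below inherits this four-argument constructor, so instantiation becomes uniform. For reference, this is the call from the create_embeddings task later in this commit:

    document_processor = processor_class(
        tenant=tenant,
        document_version=document_version,
        catalog=catalog,
        processor=processor
    )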

View File

@@ -7,8 +7,8 @@ import re


 class DocxProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         self.config = processor.configuration
         self.extract_comments = self.config.get('extract_comments', False)
         self.extract_headers_footers = self.config.get('extract_headers_footers', False)

View File

@@ -11,8 +11,8 @@ from common.utils.string_list_converter import StringListConverter as SLC


 class HTMLProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         cat_conf = catalog.configuration
         proc_conf = processor.configuration
         self.html_tags = SLC.string_to_list(proc_conf['html_tags'])

View File

@@ -18,8 +18,8 @@ def _find_first_h1(markdown: str) -> str:


 class MarkdownProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0

View File

@@ -16,8 +16,8 @@ from .processor_registry import ProcessorRegistry


 class PDFProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0

View File

@@ -3,6 +3,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
+from flask import current_app

 from common.utils.model_utils import create_language_template, get_embedding_llm, get_template
 from .base_processor import BaseProcessor
@@ -10,9 +11,9 @@ from common.utils.business_event_context import current_event


 class TranscriptionBaseProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.annotation_chunk_size = model_variables.annotation_chunk_length
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.annotation_chunk_size = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']
         self.annotation_chunk_overlap = 0

     def process(self):

View File

@@ -17,8 +17,7 @@ from common.models.document import DocumentVersion, Embedding, Document, Process
 from common.models.user import Tenant
 from common.utils.celery_utils import current_celery
 from common.utils.database import Database
-from common.utils.model_utils import create_language_template, get_model_variables, get_embedding_model_and_class, \
-    get_embedding_llm, get_template
+from common.utils.model_utils import create_language_template, get_embedding_model_and_class, get_template
 from common.utils.business_event import BusinessEvent
 from common.utils.business_event_context import current_event
@@ -58,9 +57,6 @@ def create_embeddings(tenant_id, document_version_id):
     catalog_id = doc.catalog_id
     catalog = Catalog.query.get_or_404(catalog_id)

-    # Select variables to work with depending on tenant and model
-    model_variables = get_model_variables(tenant_id)
-
     # Define processor related information
     processor_type, processor_class = ProcessorRegistry.get_processor_for_file_type(document_version.file_type)
     processor = get_processor_for_document(catalog_id, document_version.file_type, document_version.sub_file_type)
@@ -102,7 +98,6 @@ def create_embeddings(tenant_id, document_version_id):
     with current_event.create_span(f"{processor_type} Processing"):
         document_processor = processor_class(
             tenant=tenant,
-            model_variables=model_variables,
             document_version=document_version,
             catalog=catalog,
             processor=processor
@@ -114,7 +109,7 @@ def create_embeddings(tenant_id, document_version_id):
         })

     with current_event.create_span("Embedding"):
-        embed_markdown(tenant, model_variables, document_version, catalog, document_processor, markdown, title)
+        embed_markdown(tenant, document_version, catalog, document_processor, markdown, title)

     current_event.log("Finished Embedding Creation Task")
@@ -142,7 +137,7 @@ def delete_embeddings_for_document_version(document_version):
         raise


-def embed_markdown(tenant, model_variables, document_version, catalog, processor, markdown, title):
+def embed_markdown(tenant, document_version, catalog, processor, markdown, title):
     # Create potential chunks
     potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, processor, markdown,
                                                             catalog.max_chunk_size)
@@ -154,7 +149,7 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor
     # Enrich chunks
     with current_event.create_span("Enrich Chunks"):
-        enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)
+        enriched_chunks = enrich_chunks(tenant, document_version, title, chunks)
         processor.log_tuning("Enriched Chunks: ", {'enriched_chunks': enriched_chunks})

     # Create embeddings
@@ -178,10 +173,10 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor
                       f'on document version {document_version.id} :-)')


-def enrich_chunks(tenant, model_variables, document_version, title, chunks):
+def enrich_chunks(tenant, document_version, title, chunks):
     summary = ''
     if len(chunks) > 1:
-        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
+        summary = summarize_chunk(tenant, document_version, chunks[0])

     chunk_total_context = (f'Filename: {document_version.object_name}\n'
                            f'User Context:\n{document_version.user_context}\n\n'
@@ -209,7 +204,7 @@ def enrich_chunks(tenant, model_variables, document_version, title, chunks):
     return enriched_chunks


-def summarize_chunk(tenant, model_variables, document_version, chunk):
+def summarize_chunk(tenant, document_version, chunk):
     current_event.log("Starting Summarizing Chunk")

     template, llm = get_template("summary")

View File

@@ -0,0 +1,31 @@
"""Remove rag_context and llm_model from Tenant
Revision ID: 7d3c6f48735c
Revises: 4eae969dcac2
Create Date: 2025-05-20 08:09:23.673137
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7d3c6f48735c'
down_revision = '4eae969dcac2'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant', schema=None) as batch_op:
batch_op.drop_column('rag_context')
batch_op.drop_column('llm_model')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
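The migration applies as usual (a sketch assuming the project drives Alembic through Flask-Migrate; with bare Alembic, run `alembic upgrade head` instead):

    flask db upgrade  # runs upgrade() above, dropping tenant.rag_context and tenant.llm_model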