Remove ModelVariables (model_utils) from application & optimize Tenant
@@ -24,17 +24,13 @@ class Tenant(db.Model):
     name = db.Column(db.String(80), unique=True, nullable=False)
     website = db.Column(db.String(255), nullable=True)
     timezone = db.Column(db.String(50), nullable=True, default='UTC')
-    rag_context = db.Column(db.Text, nullable=True)
     type = db.Column(db.String(20), nullable=True, server_default='Active')

     # language information
     default_language = db.Column(db.String(2), nullable=True)
     allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)

-    # LLM specific choices
-    llm_model = db.Column(db.String(50), nullable=True)
-
     # Entitlements
     currency = db.Column(db.String(20), nullable=True)
     storage_dirty = db.Column(db.Boolean, nullable=True, default=False)
@@ -62,11 +58,9 @@ class Tenant(db.Model):
         'name': self.name,
         'website': self.website,
         'timezone': self.timezone,
-        'rag_context': self.rag_context,
         'type': self.type,
         'default_language': self.default_language,
         'allowed_languages': self.allowed_languages,
-        'llm_model': self.llm_model,
         'currency': self.currency,
     }
@@ -139,142 +139,26 @@ def process_pdf():
     full_model_name = 'mistral-ocr-latest'


-def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
+def get_template(template_name: str, version: Optional[str] = "1.0", temperature: float = 0.3) -> tuple[
     Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
     """
     Get a prompt template
     """
     prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
     if "llm_model" in prompt:
-        llm = get_embedding_llm(full_model_name=prompt["llm_model"])
+        llm = get_embedding_llm(full_model_name=prompt["llm_model"], temperature=temperature)
     else:
-        llm = get_embedding_llm()
+        llm = get_embedding_llm(temperature=temperature)

     return prompt["content"], llm


-class ModelVariables:
-    """Manages model-related variables and configurations"""
-
-    def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
-        """
-        Initialize ModelVariables with tenant and optional template manager
-
-        Args:
-            tenant_id: Tenant instance
-            variables: Optional variables
-        """
-        current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
-        self.tenant_id = tenant_id
-        self._variables = variables if variables is not None else self._initialize_variables()
-        current_app.logger.info(f'Model _variables initialized to {self._variables}')
-        self._llm_instances = {}
-        self.llm_metrics_handler = LLMMetricsHandler()
-        self._transcription_model = None
-
-    def _initialize_variables(self) -> Dict[str, Any]:
-        """Initialize the variables dictionary"""
-        variables = {}
-
-        tenant = Tenant.query.get(self.tenant_id)
-        if not tenant:
-            raise EveAITenantNotFound(self.tenant_id)
-
-        # Set model providers
-        variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
-        variables['llm_full_model'] = tenant.llm_model
-
-        # Set model-specific configurations
-        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
-        variables.update(model_config)
-
-        # Additional configurations
-        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
-        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
-        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
-        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
-        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
-
-        return variables
-
-    @property
-    def annotation_chunk_length(self):
-        return self._variables['annotation_chunk_length']
-
-    @property
-    def max_compression_duration(self):
-        return self._variables['max_compression_duration']
-
-    @property
-    def max_transcription_duration(self):
-        return self._variables['max_transcription_duration']
-
-    @property
-    def compression_cpu_limit(self):
-        return self._variables['compression_cpu_limit']
-
-    @property
-    def compression_process_delay(self):
-        return self._variables['compression_process_delay']
-
-    def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
-        """
-        Get an LLM instance with specific configuration
-
-        Args:
-            temperature: The temperature for the LLM
-            **kwargs: Additional configuration parameters
-
-        Returns:
-            An instance of the configured LLM
-        """
-        cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
-
-        if cache_key not in self._llm_instances:
-            provider = self._variables['llm_provider']
-            model = self._variables['llm_model']
-
-            if provider == 'openai':
-                self._llm_instances[cache_key] = ChatOpenAI(
-                    api_key=os.getenv('OPENAI_API_KEY'),
-                    model=model,
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            elif provider == 'anthropic':
-                self._llm_instances[cache_key] = ChatAnthropic(
-                    api_key=os.getenv('ANTHROPIC_API_KEY'),
-                    model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            else:
-                raise ValueError(f"Unsupported LLM provider: {provider}")
-
-        return self._llm_instances[cache_key]
-
-    @property
-    def transcription_model(self) -> TrackedOpenAITranscription:
-        """Get the transcription model instance"""
-        if self._transcription_model is None:
-            api_key = os.getenv('OPENAI_API_KEY')
-            self._transcription_model = TrackedOpenAITranscription(
-                api_key=api_key,
-                model='whisper-1'
-            )
-        return self._transcription_model
-
-    # Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
-    @property
-    def transcription_client(self):
-        raise DeprecationWarning("Use transcription_model instead")
-
-    def transcribe(self, *args, **kwargs):
-        raise DeprecationWarning("Use transcription_model.transcribe() instead")
-
-
-# Helper function to get cached model variables
-def get_model_variables(tenant_id: int) -> ModelVariables:
-    return ModelVariables(tenant_id=tenant_id)
+def get_transcription_model(model_name: str = "whisper-1") -> TrackedOpenAITranscription:
+    """
+    Get a transcription model instance
+    """
+    api_key = os.getenv('OPENAI_API_KEY')
+    return TrackedOpenAITranscription(
+        api_key=api_key,
+        model=model_name
+    )
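With ModelVariables gone, callers obtain a prompt and a matching LLM from a single get_template() call, and transcription comes from a standalone helper. A minimal usage sketch of the new model_utils API (the audio file name and language value are illustrative, not from this commit):

    # Prompt and configured LLM in one call; temperature now flows through get_template
    template, llm = get_template("summary", temperature=0.3)

    # Transcription without per-tenant model state; defaults to 'whisper-1'
    model = get_transcription_model()
    with open('meeting.mp3', 'rb') as audio_file:  # illustrative file name
        transcription = model.transcribe(
            file=audio_file,
            language='en',
            response_format='verbose_json',
        )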
@@ -81,11 +81,7 @@ class Config(object):
     ANTHROPIC_LLM_VERSIONS = {'claude-3-5-sonnet': 'claude-3-5-sonnet-20240620', }

     # Annotation text chunk length
-    ANNOTATION_TEXT_CHUNK_LENGTH = {
-        'openai.gpt-4o': 10000,
-        'openai.gpt-4o-mini': 10000,
-        'anthropic.claude-3-5-sonnet': 8000
-    }
+    ANNOTATION_TEXT_CHUNK_LENGTH = 10000

     # Environment Loaders
     OPENAI_API_KEY = environ.get('OPENAI_API_KEY')
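Since the chunk length no longer varies by model, consumers read a single scalar from the app config. A minimal sketch of the before/after lookup:

    # Before: keyed per tenant model
    # length = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]

    # After: one value for all models
    length = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']  # 10000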
@@ -10,7 +10,7 @@
     <form method="post">
         {{ form.hidden_tag() }}
         <!-- Main Tenant Information -->
-        {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'rag_context', 'type'] %}
+        {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'type'] %}
         {% for field in form %}
             {{ render_included_field(field, disabled_fields=main_fields, include_fields=main_fields) }}
         {% endfor %}
@@ -20,11 +20,6 @@
         <div class="col-lg-12">
             <div class="nav-wrapper position-relative end-0">
                 <ul class="nav nav-pills nav-fill p-1" role="tablist">
-                    <li class="nav-item" role="presentation">
-                        <a class="nav-link mb-0 px-0 py-1 active" data-toggle="tab" href="#model-info-tab" role="tab" aria-controls="model-info" aria-selected="true">
-                            Model Information
-                        </a>
-                    </li>
                     <li class="nav-item">
                         <a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#license-info-tab" role="tab" aria-controls="license-info" aria-selected="false">
                             License Information
@@ -33,13 +28,6 @@
                 </ul>
             </div>
             <div class="tab-content tab-space">
-                <!-- Model Information Tab -->
-                <div class="tab-pane fade show active" id="model-info-tab" role="tabpanel">
-                    {% set model_fields = ['llm_model'] %}
-                    {% for field in form %}
-                        {{ render_included_field(field, disabled_fields=model_fields, include_fields=model_fields) }}
-                    {% endfor %}
-                </div>
                 <!-- License Information Tab -->
                 <div class="tab-pane fade" id="license-info-tab" role="tabpanel">
                     {% set license_fields = ['currency', 'usage_email', ] %}
@@ -22,8 +22,6 @@ class TenantForm(FlaskForm):
     currency = SelectField('Currency', choices=[], validators=[DataRequired()])
     # Timezone
     timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
-    # LLM fields
-    llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])

     # For Super Users only - Allow to assign the tenant to the partner
     assign_to_partner = BooleanField('Assign to Partner', default=False)
@@ -177,7 +177,6 @@ def register_chat_session_cache_handlers(cache_manager):
     cache_manager.register_handler(ChatSessionCacheHandler, 'eveai_chat_workers')


-# Helper function similar to get_model_variables
 def get_chat_history(session_id: str) -> CachedSession:
     """
     Get cached chat history for a session, loading from database if needed
@@ -12,7 +12,8 @@ from common.utils.business_event_context import current_event
 from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
 from eveai_chat_workers.chat_session_cache import get_chat_history
 from common.models.interaction import Specialist
-from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
+from common.utils.model_utils import create_language_template, replace_variable_in_template, \
+    get_template
 from eveai_chat_workers.specialists.base_specialist import BaseSpecialistExecutor
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
@@ -38,9 +39,6 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()

-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
     @property
     def type(self) -> str:
         return "STANDARD_RAG_SPECIALIST"
@@ -68,8 +66,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         ])

         # Get LLM and template
-        llm = self.model_variables.get_llm(temperature=0.3)
-        template = self.model_variables.get_template('history')
+        template, llm = get_template("history", temperature=0.3)
         language_template = create_language_template(template, language)

         # Create prompt
@@ -179,11 +176,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         with current_event.create_span("Specialist RAG invocation"):
             try:
                 self.update_progress(self.task_id, "EveAI Chain Start", {})
-                # Get LLM with specified temperature
-                llm = self.model_variables.get_llm(temperature=self.temperature)
-
-                # Get template
-                template = self.model_variables.get_template('rag')
+                template, llm = get_template("rag", temperature=self.temperature)
                 language_template = create_language_template(template, language)
                 full_template = replace_variable_in_template(
                     language_template,
@@ -8,7 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate

 from common.models.interaction import Specialist
 from common.utils.business_event_context import current_event
-from common.utils.model_utils import get_model_variables, get_crewai_llm, create_language_template
+from common.utils.model_utils import get_crewai_llm, create_language_template, get_template
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
 from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAIAgent, EveAICrewAITask
 from crewai.tools import BaseTool
@@ -41,9 +41,6 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()

-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
         # initialize the Flow
         self.flow = None
@@ -212,8 +209,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         try:
             with current_event.create_span("Specialist Detail Question"):
                 # Get LLM and template
-                llm = self.model_variables.get_llm(temperature=0.3)
-                template = cache_manager.prompts_config_cache.get_config('history').get('content', '')
+                template, llm = get_template("history", temperature=0.3)
                 language_template = create_language_template(template, language)

                 # Create prompt
@@ -8,20 +8,23 @@ import tempfile
 from common.extensions import minio_client
 import subprocess

+from flask import current_app
+
+from common.utils.model_utils import get_transcription_model
 from .processor_registry import ProcessorRegistry
 from .transcription_processor import TranscriptionBaseProcessor
 from common.utils.business_event_context import current_event


 class AudioProcessor(TranscriptionBaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.transcription_model = model_variables.transcription_model
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.transcription_model = get_transcription_model()
         self.ffmpeg_path = 'ffmpeg'
-        self.max_compression_duration = model_variables.max_compression_duration
-        self.max_transcription_duration = model_variables.max_transcription_duration
-        self.compression_cpu_limit = model_variables.compression_cpu_limit  # CPU usage limit in percentage
-        self.compression_process_delay = model_variables.compression_process_delay  # Delay between processing chunks in seconds
+        self.max_compression_duration = current_app.config['MAX_COMPRESSION_DURATION']
+        self.max_transcription_duration = current_app.config['MAX_TRANSCRIPTION_DURATION']
+        self.compression_cpu_limit = current_app.config['COMPRESSION_CPU_LIMIT']  # CPU usage limit in percentage
+        self.compression_process_delay = current_app.config['COMPRESSION_PROCESS_DELAY']  # Delay between processing chunks in seconds
         self.file_type = document_version.file_type

     def _get_transcription(self):
@@ -154,7 +157,7 @@ class AudioProcessor(TranscriptionBaseProcessor):
             file_size = os.path.getsize(temp_audio.name)

             with open(temp_audio.name, 'rb') as audio_file:
-                transcription = self.model_variables.transcription_model.transcribe(
+                transcription = self.transcription_model.transcribe(
                     file=audio_file,
                     language=self.document_version.language,
                     response_format='verbose_json',
@@ -7,9 +7,8 @@ from config.logging_config import TuningLogger


 class BaseProcessor(ABC):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
+    def __init__(self, tenant, document_version, catalog, processor):
         self.tenant = tenant
-        self.model_variables = model_variables
         self.document_version = document_version
         self.catalog = catalog
         self.processor = processor
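Every processor subclass repeats the same slimmed constructor, as the following hunks show. The shared pattern, sketched with an illustrative class name:

    class MyProcessor(BaseProcessor):
        def __init__(self, tenant, document_version, catalog, processor):
            # model_variables is gone from the chain; model state now comes
            # from module helpers and current_app.config instead
            super().__init__(tenant, document_version, catalog, processor)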
@@ -7,8 +7,8 @@ import re


 class DocxProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         self.config = processor.configuration
         self.extract_comments = self.config.get('extract_comments', False)
         self.extract_headers_footers = self.config.get('extract_headers_footers', False)
@@ -11,8 +11,8 @@ from common.utils.string_list_converter import StringListConverter as SLC


 class HTMLProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         cat_conf = catalog.configuration
         proc_conf = processor.configuration
         self.html_tags = SLC.string_to_list(proc_conf['html_tags'])
@@ -18,8 +18,8 @@ def _find_first_h1(markdown: str) -> str:


 class MarkdownProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)

         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0
@@ -16,8 +16,8 @@ from .processor_registry import ProcessorRegistry


 class PDFProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)

         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0
@@ -3,6 +3,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
+from flask import current_app

 from common.utils.model_utils import create_language_template, get_embedding_llm, get_template
 from .base_processor import BaseProcessor
@@ -10,9 +11,9 @@ from common.utils.business_event_context import current_event


 class TranscriptionBaseProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.annotation_chunk_size = model_variables.annotation_chunk_length
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.annotation_chunk_size = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']
         self.annotation_chunk_overlap = 0

     def process(self):
@@ -17,8 +17,7 @@ from common.models.document import DocumentVersion, Embedding, Document, Process
 from common.models.user import Tenant
 from common.utils.celery_utils import current_celery
 from common.utils.database import Database
-from common.utils.model_utils import create_language_template, get_model_variables, get_embedding_model_and_class, \
-    get_embedding_llm, get_template
+from common.utils.model_utils import create_language_template, get_embedding_model_and_class, get_template

 from common.utils.business_event import BusinessEvent
 from common.utils.business_event_context import current_event
@@ -58,9 +57,6 @@ def create_embeddings(tenant_id, document_version_id):
     catalog_id = doc.catalog_id
     catalog = Catalog.query.get_or_404(catalog_id)

-    # Select variables to work with depending on tenant and model
-    model_variables = get_model_variables(tenant_id)
-
     # Define processor related information
     processor_type, processor_class = ProcessorRegistry.get_processor_for_file_type(document_version.file_type)
     processor = get_processor_for_document(catalog_id, document_version.file_type, document_version.sub_file_type)
@@ -102,7 +98,6 @@ def create_embeddings(tenant_id, document_version_id):
     with current_event.create_span(f"{processor_type} Processing"):
         document_processor = processor_class(
             tenant=tenant,
-            model_variables=model_variables,
             document_version=document_version,
             catalog=catalog,
             processor=processor
@@ -114,7 +109,7 @@ def create_embeddings(tenant_id, document_version_id):
         })

     with current_event.create_span("Embedding"):
-        embed_markdown(tenant, model_variables, document_version, catalog, document_processor, markdown, title)
+        embed_markdown(tenant, document_version, catalog, document_processor, markdown, title)

     current_event.log("Finished Embedding Creation Task")
@@ -142,7 +137,7 @@ def delete_embeddings_for_document_version(document_version):
         raise


-def embed_markdown(tenant, model_variables, document_version, catalog, processor, markdown, title):
+def embed_markdown(tenant, document_version, catalog, processor, markdown, title):
     # Create potential chunks
     potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, processor, markdown,
                                                             catalog.max_chunk_size)
@@ -154,7 +149,7 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor

     # Enrich chunks
     with current_event.create_span("Enrich Chunks"):
-        enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)
+        enriched_chunks = enrich_chunks(tenant, document_version, title, chunks)
         processor.log_tuning("Enriched Chunks: ", {'enriched_chunks': enriched_chunks})

     # Create embeddings
@@ -178,10 +173,10 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor
                       f'on document version {document_version.id} :-)')


-def enrich_chunks(tenant, model_variables, document_version, title, chunks):
+def enrich_chunks(tenant, document_version, title, chunks):
     summary = ''
     if len(chunks) > 1:
-        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
+        summary = summarize_chunk(tenant, document_version, chunks[0])

     chunk_total_context = (f'Filename: {document_version.object_name}\n'
                            f'User Context:\n{document_version.user_context}\n\n'
@@ -209,7 +204,7 @@ def enrich_chunks(tenant, model_variables, document_version, title, chunks):
     return enriched_chunks


-def summarize_chunk(tenant, model_variables, document_version, chunk):
+def summarize_chunk(tenant, document_version, chunk):
     current_event.log("Starting Summarizing Chunk")

     template, llm = get_template("summary")
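For context, the (template, llm) pair returned by get_template() composes into a LangChain pipeline in the usual way. A hedged sketch; the prompt's input variable name is assumed, not taken from this commit:

    # Assumes the 'summary' prompt exposes a {text} variable (illustrative)
    template, llm = get_template("summary")
    language_template = create_language_template(template, language)
    prompt = ChatPromptTemplate.from_template(language_template)
    chain = prompt | llm | StrOutputParser()
    summary = chain.invoke({"text": chunk})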
@@ -0,0 +1,31 @@
+"""Remove rag_context and llm_model from Tenant
+
+Revision ID: 7d3c6f48735c
+Revises: 4eae969dcac2
+Create Date: 2025-05-20 08:09:23.673137
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '7d3c6f48735c'
+down_revision = '4eae969dcac2'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tenant', schema=None) as batch_op:
+        batch_op.drop_column('rag_context')
+        batch_op.drop_column('llm_model')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
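Note that the auto-generated downgrade() is a no-op, so this migration cannot be rolled back as written. A hedged sketch of a real downgrade, with column types mirroring the definitions removed from the Tenant model:

    def downgrade():
        # Re-add the columns dropped in upgrade()
        with op.batch_alter_table('tenant', schema=None) as batch_op:
            batch_op.add_column(sa.Column('llm_model', sa.String(length=50), nullable=True))
            batch_op.add_column(sa.Column('rag_context', sa.Text(), nullable=True))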