Remove ModelVariables (model_utils) from application & optimize Tenant
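Tenant-level LLM configuration (rag_context, llm_model) is dropped, and the ModelVariables helper class in model_utils is replaced by module-level helpers: get_template() now also resolves the LLM (with an optional temperature), and get_transcription_model() builds the Whisper client directly. A rough sketch of the call-site migration, using only names that appear in the diff below:

    # before: per-tenant ModelVariables object
    model_variables = get_model_variables(tenant_id)
    llm = model_variables.get_llm(temperature=0.3)
    template = model_variables.get_template('history')

    # after: one call returns both the prompt content and a configured LLM
    template, llm = get_template("history", temperature=0.3)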
@@ -24,17 +24,13 @@ class Tenant(db.Model):
     name = db.Column(db.String(80), unique=True, nullable=False)
     website = db.Column(db.String(255), nullable=True)
     timezone = db.Column(db.String(50), nullable=True, default='UTC')
-    rag_context = db.Column(db.Text, nullable=True)
     type = db.Column(db.String(20), nullable=True, server_default='Active')
 
     # language information
     default_language = db.Column(db.String(2), nullable=True)
     allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
 
-    # LLM specific choices
-    llm_model = db.Column(db.String(50), nullable=True)
-
     # Entitlements
     currency = db.Column(db.String(20), nullable=True)
     storage_dirty = db.Column(db.Boolean, nullable=True, default=False)
 
@@ -62,11 +58,9 @@ class Tenant(db.Model):
             'name': self.name,
             'website': self.website,
             'timezone': self.timezone,
-            'rag_context': self.rag_context,
             'type': self.type,
             'default_language': self.default_language,
             'allowed_languages': self.allowed_languages,
-            'llm_model': self.llm_model,
             'currency': self.currency,
         }
 
@@ -139,142 +139,26 @@ def process_pdf():
     full_model_name = 'mistral-ocr-latest'
 
 
-def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
+def get_template(template_name: str, version: Optional[str] = "1.0", temperature: float = 0.3) -> tuple[
         Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
     """
    Get a prompt template
    """
    prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
    if "llm_model" in prompt:
-        llm = get_embedding_llm(full_model_name=prompt["llm_model"])
+        llm = get_embedding_llm(full_model_name=prompt["llm_model"], temperature=temperature)
    else:
-        llm = get_embedding_llm()
+        llm = get_embedding_llm(temperature=temperature)
 
    return prompt["content"], llm
 
 
-class ModelVariables:
-    """Manages model-related variables and configurations"""
-
-    def __init__(self, tenant_id: int, variables: Dict[str, Any] = None):
-        """
-        Initialize ModelVariables with tenant and optional template manager
-
-        Args:
-            tenant_id: Tenant instance
-            variables: Optional variables
-        """
-        current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}')
-        self.tenant_id = tenant_id
-        self._variables = variables if variables is not None else self._initialize_variables()
-        current_app.logger.info(f'Model _variables initialized to {self._variables}')
-        self._llm_instances = {}
-        self.llm_metrics_handler = LLMMetricsHandler()
-        self._transcription_model = None
-
-    def _initialize_variables(self) -> Dict[str, Any]:
-        """Initialize the variables dictionary"""
-        variables = {}
-
-        tenant = Tenant.query.get(self.tenant_id)
-        if not tenant:
-            raise EveAITenantNotFound(self.tenant_id)
-
-        # Set model providers
-        variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.')
-        variables['llm_full_model'] = tenant.llm_model
-
-        # Set model-specific configurations
-        model_config = MODEL_CONFIG.get(variables['llm_provider'], {}).get(variables['llm_model'], {})
-        variables.update(model_config)
-
-        # Additional configurations
-        variables['annotation_chunk_length'] = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
-        variables['max_compression_duration'] = current_app.config['MAX_COMPRESSION_DURATION']
-        variables['max_transcription_duration'] = current_app.config['MAX_TRANSCRIPTION_DURATION']
-        variables['compression_cpu_limit'] = current_app.config['COMPRESSION_CPU_LIMIT']
-        variables['compression_process_delay'] = current_app.config['COMPRESSION_PROCESS_DELAY']
-
-        return variables
-
-    @property
-    def annotation_chunk_length(self):
-        return self._variables['annotation_chunk_length']
-
-    @property
-    def max_compression_duration(self):
-        return self._variables['max_compression_duration']
-
-    @property
-    def max_transcription_duration(self):
-        return self._variables['max_transcription_duration']
-
-    @property
-    def compression_cpu_limit(self):
-        return self._variables['compression_cpu_limit']
-
-    @property
-    def compression_process_delay(self):
-        return self._variables['compression_process_delay']
-
-    def get_llm(self, temperature: float = 0.3, **kwargs) -> Any:
-        """
-        Get an LLM instance with specific configuration
-
-        Args:
-            temperature: The temperature for the LLM
-            **kwargs: Additional configuration parameters
-
-        Returns:
-            An instance of the configured LLM
-        """
-        cache_key = f"{temperature}_{hash(frozenset(kwargs.items()))}"
-
-        if cache_key not in self._llm_instances:
-            provider = self._variables['llm_provider']
-            model = self._variables['llm_model']
-
-            if provider == 'openai':
-                self._llm_instances[cache_key] = ChatOpenAI(
-                    api_key=os.getenv('OPENAI_API_KEY'),
-                    model=model,
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            elif provider == 'anthropic':
-                self._llm_instances[cache_key] = ChatAnthropic(
-                    api_key=os.getenv('ANTHROPIC_API_KEY'),
-                    model=current_app.config['ANTHROPIC_LLM_VERSIONS'][model],
-                    temperature=temperature,
-                    callbacks=[self.llm_metrics_handler],
-                    **kwargs
-                )
-            else:
-                raise ValueError(f"Unsupported LLM provider: {provider}")
-
-        return self._llm_instances[cache_key]
-
-    @property
-    def transcription_model(self) -> TrackedOpenAITranscription:
-        """Get the transcription model instance"""
-        if self._transcription_model is None:
-            api_key = os.getenv('OPENAI_API_KEY')
-            self._transcription_model = TrackedOpenAITranscription(
-                api_key=api_key,
-                model='whisper-1'
-            )
-        return self._transcription_model
-
-    # Remove the old transcription-related methods since they're now handled by TrackedOpenAITranscription
-    @property
-    def transcription_client(self):
-        raise DeprecationWarning("Use transcription_model instead")
-
-    def transcribe(self, *args, **kwargs):
-        raise DeprecationWarning("Use transcription_model.transcribe() instead")
-
-
-# Helper function to get cached model variables
-def get_model_variables(tenant_id: int) -> ModelVariables:
-    return ModelVariables(tenant_id=tenant_id)
+def get_transcription_model(model_name: str = "whisper-1") -> TrackedOpenAITranscription:
+    """
+    Get a transcription model instance
+    """
+    api_key = os.getenv('OPENAI_API_KEY')
+    return TrackedOpenAITranscription(
+        api_key=api_key,
+        model=model_name
+    )
@@ -81,11 +81,7 @@ class Config(object):
     ANTHROPIC_LLM_VERSIONS = {'claude-3-5-sonnet': 'claude-3-5-sonnet-20240620', }
 
     # Annotation text chunk length
-    ANNOTATION_TEXT_CHUNK_LENGTH = {
-        'openai.gpt-4o': 10000,
-        'openai.gpt-4o-mini': 10000,
-        'anthropic.claude-3-5-sonnet': 8000
-    }
+    ANNOTATION_TEXT_CHUNK_LENGTH = 10000
 
     # Environemnt Loaders
     OPENAI_API_KEY = environ.get('OPENAI_API_KEY')
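Note: ANNOTATION_TEXT_CHUNK_LENGTH is now a single model-independent value, so callers read the Flask config directly instead of keying by tenant.llm_model. A minimal sketch of the new lookup (mirroring the TranscriptionBaseProcessor hunk further down):

    from flask import current_app

    # previously: current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH'][tenant.llm_model]
    chunk_size = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']  # now simply 10000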
@@ -10,7 +10,7 @@
 <form method="post">
     {{ form.hidden_tag() }}
     <!-- Main Tenant Information -->
-    {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'rag_context', 'type'] %}
+    {% set main_fields = ['name', 'code', 'website', 'default_language', 'allowed_languages', 'type'] %}
     {% for field in form %}
         {{ render_included_field(field, disabled_fields=main_fields, include_fields=main_fields) }}
     {% endfor %}
@@ -20,11 +20,6 @@
     <div class="col-lg-12">
         <div class="nav-wrapper position-relative end-0">
             <ul class="nav nav-pills nav-fill p-1" role="tablist">
-                <li class="nav-item" role="presentation">
-                    <a class="nav-link mb-0 px-0 py-1 active" data-toggle="tab" href="#model-info-tab" role="tab" aria-controls="model-info" aria-selected="true">
-                        Model Information
-                    </a>
-                </li>
                 <li class="nav-item">
                     <a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#license-info-tab" role="tab" aria-controls="license-info" aria-selected="false">
                         License Information
@@ -33,13 +28,6 @@
             </ul>
         </div>
         <div class="tab-content tab-space">
-            <!-- Model Information Tab -->
-            <div class="tab-pane fade show active" id="model-info-tab" role="tabpanel">
-                {% set model_fields = ['llm_model'] %}
-                {% for field in form %}
-                    {{ render_included_field(field, disabled_fields=model_fields, include_fields=model_fields) }}
-                {% endfor %}
-            </div>
             <!-- License Information Tab -->
             <div class="tab-pane fade" id="license-info-tab" role="tabpanel">
                 {% set license_fields = ['currency', 'usage_email', ] %}
@@ -22,8 +22,6 @@ class TenantForm(FlaskForm):
     currency = SelectField('Currency', choices=[], validators=[DataRequired()])
     # Timezone
     timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
-    # LLM fields
-    llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
 
     # For Super Users only - Allow to assign the tenant to the partner
     assign_to_partner = BooleanField('Assign to Partner', default=False)
@@ -177,7 +177,6 @@ def register_chat_session_cache_handlers(cache_manager):
     cache_manager.register_handler(ChatSessionCacheHandler, 'eveai_chat_workers')
 
 
-# Helper function similar to get_model_variables
 def get_chat_history(session_id: str) -> CachedSession:
     """
     Get cached chat history for a session, loading from database if needed
@@ -12,7 +12,8 @@ from common.utils.business_event_context import current_event
 from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult
 from eveai_chat_workers.chat_session_cache import get_chat_history
 from common.models.interaction import Specialist
-from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
+from common.utils.model_utils import create_language_template, replace_variable_in_template, \
+    get_template
 from eveai_chat_workers.specialists.base_specialist import BaseSpecialistExecutor
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
 
@@ -38,9 +39,6 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()
 
-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
     @property
     def type(self) -> str:
         return "STANDARD_RAG_SPECIALIST"
@@ -68,8 +66,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         ])
 
         # Get LLM and template
-        llm = self.model_variables.get_llm(temperature=0.3)
-        template = self.model_variables.get_template('history')
+        template, llm = get_template("history", temperature=0.3)
         language_template = create_language_template(template, language)
 
         # Create prompt
@@ -179,11 +176,7 @@ class SpecialistExecutor(BaseSpecialistExecutor):
         with current_event.create_span("Specialist RAG invocation"):
             try:
                 self.update_progress(self.task_id, "EveAI Chain Start", {})
-                # Get LLM with specified temperature
-                llm = self.model_variables.get_llm(temperature=self.temperature)
-
-                # Get template
-                template = self.model_variables.get_template('rag')
+                template, llm = get_template("rag", temperature=self.temperature)
                 language_template = create_language_template(template, language)
                 full_template = replace_variable_in_template(
                     language_template,
@@ -8,7 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate
 
 from common.models.interaction import Specialist
 from common.utils.business_event_context import current_event
-from common.utils.model_utils import get_model_variables, get_crewai_llm, create_language_template
+from common.utils.model_utils import get_crewai_llm, create_language_template, get_template
 from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
 from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAIAgent, EveAICrewAITask
 from crewai.tools import BaseTool
@@ -41,9 +41,6 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         # Initialize retrievers
         self.retrievers = self._initialize_retrievers()
 
-        # Initialize model variables
-        self.model_variables = get_model_variables(tenant_id)
-
         # initialize the Flow
         self.flow = None
 
@@ -212,8 +209,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
         try:
             with current_event.create_span("Specialist Detail Question"):
                 # Get LLM and template
-                llm = self.model_variables.get_llm(temperature=0.3)
-                template = cache_manager.prompts_config_cache.get_config('history').get('content', '')
+                template, llm = get_template("history", temperature=0.3)
                 language_template = create_language_template(template, language)
 
                 # Create prompt
@@ -8,20 +8,23 @@ import tempfile
 from common.extensions import minio_client
 import subprocess
 
+from flask import current_app
+
+from common.utils.model_utils import get_transcription_model
 from .processor_registry import ProcessorRegistry
 from .transcription_processor import TranscriptionBaseProcessor
 from common.utils.business_event_context import current_event
 
 
 class AudioProcessor(TranscriptionBaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.transcription_model = model_variables.transcription_model
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.transcription_model = get_transcription_model()
         self.ffmpeg_path = 'ffmpeg'
-        self.max_compression_duration = model_variables.max_compression_duration
-        self.max_transcription_duration = model_variables.max_transcription_duration
-        self.compression_cpu_limit = model_variables.compression_cpu_limit  # CPU usage limit in percentage
-        self.compression_process_delay = model_variables.compression_process_delay  # Delay between processing chunks in seconds
+        self.max_compression_duration = current_app.config['MAX_COMPRESSION_DURATION']
+        self.max_transcription_duration = current_app.config['MAX_TRANSCRIPTION_DURATION']
+        self.compression_cpu_limit = current_app.config['COMPRESSION_CPU_LIMIT']  # CPU usage limit in percentage
+        self.compression_process_delay = current_app.config['COMPRESSION_PROCESS_DELAY']  # Delay between processing chunks in seconds
        self.file_type = document_version.file_type
 
     def _get_transcription(self):
@@ -154,7 +157,7 @@ class AudioProcessor(TranscriptionBaseProcessor):
         file_size = os.path.getsize(temp_audio.name)
 
         with open(temp_audio.name, 'rb') as audio_file:
-            transcription = self.model_variables.transcription_model.transcribe(
+            transcription = self.transcription_model.transcribe(
                 file=audio_file,
                 language=self.document_version.language,
                 response_format='verbose_json',
@@ -7,9 +7,8 @@ from config.logging_config import TuningLogger
 
 
 class BaseProcessor(ABC):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
+    def __init__(self, tenant, document_version, catalog, processor):
         self.tenant = tenant
-        self.model_variables = model_variables
         self.document_version = document_version
         self.catalog = catalog
         self.processor = processor
@@ -7,8 +7,8 @@ import re
 
 
 class DocxProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         self.config = processor.configuration
         self.extract_comments = self.config.get('extract_comments', False)
         self.extract_headers_footers = self.config.get('extract_headers_footers', False)
@@ -11,8 +11,8 @@ from common.utils.string_list_converter import StringListConverter as SLC
 
 
 class HTMLProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
         cat_conf = catalog.configuration
         proc_conf = processor.configuration
         self.html_tags = SLC.string_to_list(proc_conf['html_tags'])
@@ -18,8 +18,8 @@ def _find_first_h1(markdown: str) -> str:
 
 
 class MarkdownProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
 
         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0
@@ -16,8 +16,8 @@ from .processor_registry import ProcessorRegistry
 
 
 class PDFProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
 
         self.chunk_size = catalog.max_chunk_size
         self.chunk_overlap = 0
@@ -3,6 +3,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
+from flask import current_app
 
 from common.utils.model_utils import create_language_template, get_embedding_llm, get_template
 from .base_processor import BaseProcessor
@@ -10,9 +11,9 @@ from common.utils.business_event_context import current_event
 
 
 class TranscriptionBaseProcessor(BaseProcessor):
-    def __init__(self, tenant, model_variables, document_version, catalog, processor):
-        super().__init__(tenant, model_variables, document_version, catalog, processor)
-        self.annotation_chunk_size = model_variables.annotation_chunk_length
+    def __init__(self, tenant, document_version, catalog, processor):
+        super().__init__(tenant, document_version, catalog, processor)
+        self.annotation_chunk_size = current_app.config['ANNOTATION_TEXT_CHUNK_LENGTH']
         self.annotation_chunk_overlap = 0
 
     def process(self):
@@ -17,8 +17,7 @@ from common.models.document import DocumentVersion, Embedding, Document, Process
 from common.models.user import Tenant
 from common.utils.celery_utils import current_celery
 from common.utils.database import Database
-from common.utils.model_utils import create_language_template, get_model_variables, get_embedding_model_and_class, \
-    get_embedding_llm, get_template
+from common.utils.model_utils import create_language_template, get_embedding_model_and_class, get_template
 
 from common.utils.business_event import BusinessEvent
 from common.utils.business_event_context import current_event
@@ -58,9 +57,6 @@ def create_embeddings(tenant_id, document_version_id):
     catalog_id = doc.catalog_id
     catalog = Catalog.query.get_or_404(catalog_id)
 
-    # Select variables to work with depending on tenant and model
-    model_variables = get_model_variables(tenant_id)
-
     # Define processor related information
     processor_type, processor_class = ProcessorRegistry.get_processor_for_file_type(document_version.file_type)
     processor = get_processor_for_document(catalog_id, document_version.file_type, document_version.sub_file_type)
@@ -102,7 +98,6 @@ def create_embeddings(tenant_id, document_version_id):
     with current_event.create_span(f"{processor_type} Processing"):
         document_processor = processor_class(
             tenant=tenant,
-            model_variables=model_variables,
             document_version=document_version,
             catalog=catalog,
             processor=processor
@@ -114,7 +109,7 @@ def create_embeddings(tenant_id, document_version_id):
         })
 
     with current_event.create_span("Embedding"):
-        embed_markdown(tenant, model_variables, document_version, catalog, document_processor, markdown, title)
+        embed_markdown(tenant, document_version, catalog, document_processor, markdown, title)
 
     current_event.log("Finished Embedding Creation Task")
 
@@ -142,7 +137,7 @@ def delete_embeddings_for_document_version(document_version):
         raise
 
 
-def embed_markdown(tenant, model_variables, document_version, catalog, processor, markdown, title):
+def embed_markdown(tenant, document_version, catalog, processor, markdown, title):
     # Create potential chunks
     potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, processor, markdown,
                                                             catalog.max_chunk_size)
@@ -154,7 +149,7 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor
 
     # Enrich chunks
     with current_event.create_span("Enrich Chunks"):
-        enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)
+        enriched_chunks = enrich_chunks(tenant, document_version, title, chunks)
         processor.log_tuning("Enriched Chunks: ", {'enriched_chunks': enriched_chunks})
 
     # Create embeddings
@@ -178,10 +173,10 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor
                       f'on document version {document_version.id} :-)')
 
 
-def enrich_chunks(tenant, model_variables, document_version, title, chunks):
+def enrich_chunks(tenant, document_version, title, chunks):
     summary = ''
     if len(chunks) > 1:
-        summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
+        summary = summarize_chunk(tenant, document_version, chunks[0])
 
     chunk_total_context = (f'Filename: {document_version.object_name}\n'
                            f'User Context:\n{document_version.user_context}\n\n'
@@ -209,7 +204,7 @@ def enrich_chunks(tenant, model_variables, document_version, title, chunks):
     return enriched_chunks
 
 
-def summarize_chunk(tenant, model_variables, document_version, chunk):
+def summarize_chunk(tenant, document_version, chunk):
     current_event.log("Starting Summarizing Chunk")
 
     template, llm = get_template("summary")
@@ -0,0 +1,31 @@
+"""Remove rag_context and llm_model from Tenant
+
+Revision ID: 7d3c6f48735c
+Revises: 4eae969dcac2
+Create Date: 2025-05-20 08:09:23.673137
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '7d3c6f48735c'
+down_revision = '4eae969dcac2'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tenant', schema=None) as batch_op:
+        batch_op.drop_column('rag_context')
+        batch_op.drop_column('llm_model')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
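Note that downgrade() is generated as a no-op, so this migration is one-way. A reversible sketch would re-add the dropped columns (the column data itself is not recoverable); the column definitions here are assumed from the Tenant model hunk above:

    def downgrade():
        with op.batch_alter_table('tenant', schema=None) as batch_op:
            # definitions mirror the removed model fields
            batch_op.add_column(sa.Column('llm_model', sa.String(length=50), nullable=True))
            batch_op.add_column(sa.Column('rag_context', sa.Text(), nullable=True))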