Extra commit for files in 'common'
- Add functionality to add a default dictionary for configuration fields - Correct entitlement processing - Remove get_template functionality from ModelVariables, define it directly with LLM model definition in configuration file.
This commit is contained in:
@@ -10,7 +10,6 @@ from flask_wtf import CSRFProtect
|
|||||||
from flask_restx import Api
|
from flask_restx import Api
|
||||||
from prometheus_flask_exporter import PrometheusMetrics
|
from prometheus_flask_exporter import PrometheusMetrics
|
||||||
|
|
||||||
from .langchain.templates.template_manager import TemplateManager
|
|
||||||
from .utils.cache.eveai_cache_manager import EveAICacheManager
|
from .utils.cache.eveai_cache_manager import EveAICacheManager
|
||||||
from .utils.simple_encryption import SimpleEncryption
|
from .utils.simple_encryption import SimpleEncryption
|
||||||
from .utils.minio_utils import MinioClient
|
from .utils.minio_utils import MinioClient
|
||||||
@@ -30,6 +29,5 @@ api_rest = Api()
|
|||||||
simple_encryption = SimpleEncryption()
|
simple_encryption = SimpleEncryption()
|
||||||
minio_client = MinioClient()
|
minio_client = MinioClient()
|
||||||
metrics = PrometheusMetrics.for_app_factory()
|
metrics = PrometheusMetrics.for_app_factory()
|
||||||
template_manager = TemplateManager()
|
|
||||||
cache_manager = EveAICacheManager()
|
cache_manager = EveAICacheManager()
|
||||||
|
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
import os
|
|
||||||
import yaml
|
|
||||||
from typing import Dict, Optional, Any
|
|
||||||
from packaging import version
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from flask import current_app, Flask
|
|
||||||
|
|
||||||
from common.utils.os_utils import get_project_root
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PromptTemplate:
|
|
||||||
"""Represents a versioned prompt template"""
|
|
||||||
content: str
|
|
||||||
version: str
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class TemplateManager:
|
|
||||||
"""Manages versioned prompt templates"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.templates_dir = None
|
|
||||||
self._templates = None
|
|
||||||
self.app = None
|
|
||||||
|
|
||||||
def init_app(self, app: Flask) -> None:
|
|
||||||
# Initialize template manager
|
|
||||||
base_dir = "/app"
|
|
||||||
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
|
|
||||||
self.app = app
|
|
||||||
self._templates = self._load_templates()
|
|
||||||
# Log available templates for each supported model
|
|
||||||
for llm in app.config['SUPPORTED_LLMS']:
|
|
||||||
try:
|
|
||||||
available_templates = self.list_templates(llm)
|
|
||||||
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
|
|
||||||
except ValueError:
|
|
||||||
app.logger.warning(f"No templates found for {llm}")
|
|
||||||
|
|
||||||
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
|
|
||||||
"""
|
|
||||||
Load all template versions from the templates directory.
|
|
||||||
Structure: {provider.model -> {template_name -> {version -> template}}}
|
|
||||||
Directory structure:
|
|
||||||
prompts/
|
|
||||||
├── provider/
|
|
||||||
│ └── model/
|
|
||||||
│ └── template_name/
|
|
||||||
│ └── version.yaml
|
|
||||||
"""
|
|
||||||
templates = {}
|
|
||||||
|
|
||||||
# Iterate through providers (anthropic, openai)
|
|
||||||
for provider in os.listdir(self.templates_dir):
|
|
||||||
provider_path = os.path.join(self.templates_dir, provider)
|
|
||||||
if not os.path.isdir(provider_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Iterate through models (claude-3, gpt-4o)
|
|
||||||
for model in os.listdir(provider_path):
|
|
||||||
model_path = os.path.join(provider_path, model)
|
|
||||||
if not os.path.isdir(model_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
provider_model = f"{provider}.{model}"
|
|
||||||
templates[provider_model] = {}
|
|
||||||
|
|
||||||
# Iterate through template types (rag, summary, etc.)
|
|
||||||
for template_name in os.listdir(model_path):
|
|
||||||
template_path = os.path.join(model_path, template_name)
|
|
||||||
if not os.path.isdir(template_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
template_versions = {}
|
|
||||||
# Load all version files for this template
|
|
||||||
for version_file in os.listdir(template_path):
|
|
||||||
if not version_file.endswith('.yaml'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
version_str = version_file[:-5] # Remove .yaml
|
|
||||||
if not self._is_valid_version(version_str):
|
|
||||||
current_app.logger.warning(
|
|
||||||
f"Invalid version format for {template_name}: {version_str}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(os.path.join(template_path, version_file)) as f:
|
|
||||||
template_data = yaml.safe_load(f)
|
|
||||||
# Verify required fields
|
|
||||||
if not template_data.get('content'):
|
|
||||||
raise ValueError("Template content is required")
|
|
||||||
|
|
||||||
template_versions[version_str] = PromptTemplate(
|
|
||||||
content=template_data['content'],
|
|
||||||
version=version_str,
|
|
||||||
metadata=template_data.get('metadata', {})
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
current_app.logger.error(
|
|
||||||
f"Error loading template {template_name} version {version_str}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if template_versions:
|
|
||||||
templates[provider_model][template_name] = template_versions
|
|
||||||
|
|
||||||
return templates
|
|
||||||
|
|
||||||
def _is_valid_version(self, version_str: str) -> bool:
|
|
||||||
"""Validate semantic versioning string"""
|
|
||||||
try:
|
|
||||||
version.parse(version_str)
|
|
||||||
return True
|
|
||||||
except version.InvalidVersion:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_template(self,
|
|
||||||
provider_model: str,
|
|
||||||
template_name: str,
|
|
||||||
template_version: Optional[str] = None) -> PromptTemplate:
|
|
||||||
"""
|
|
||||||
Get a specific template version. If version not specified,
|
|
||||||
returns the latest version.
|
|
||||||
"""
|
|
||||||
if provider_model not in self._templates:
|
|
||||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
|
||||||
|
|
||||||
if template_name not in self._templates[provider_model]:
|
|
||||||
raise ValueError(f"Unknown template: {template_name}")
|
|
||||||
|
|
||||||
versions = self._templates[provider_model][template_name]
|
|
||||||
|
|
||||||
if template_version:
|
|
||||||
if template_version not in versions:
|
|
||||||
raise ValueError(f"Template version {template_version} not found")
|
|
||||||
return versions[template_version]
|
|
||||||
|
|
||||||
# Return latest version
|
|
||||||
latest = max(versions.keys(), key=version.parse)
|
|
||||||
return versions[latest]
|
|
||||||
|
|
||||||
def list_templates(self, provider_model: str) -> Dict[str, list]:
|
|
||||||
"""
|
|
||||||
List all available templates and their versions for a provider.model
|
|
||||||
Returns: {template_name: [version1, version2, ...]}
|
|
||||||
"""
|
|
||||||
if provider_model not in self._templates:
|
|
||||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
template_name: sorted(versions.keys(), key=version.parse)
|
|
||||||
for template_name, versions in self._templates[provider_model].items()
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import text
|
|||||||
from common.extensions import db
|
from common.extensions import db
|
||||||
from datetime import datetime as dt, timezone as tz
|
from datetime import datetime as dt, timezone as tz
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from sqlalchemy import event
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
@@ -50,6 +51,7 @@ class License(db.Model):
|
|||||||
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||||
tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom'
|
tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False) # 'small', 'medium', 'custom'
|
||||||
start_date = db.Column(db.Date, nullable=False)
|
start_date = db.Column(db.Date, nullable=False)
|
||||||
|
end_date = db.Column(db.Date, nullable=True)
|
||||||
nr_of_periods = db.Column(db.Integer, nullable=False)
|
nr_of_periods = db.Column(db.Integer, nullable=False)
|
||||||
currency = db.Column(db.String(20), nullable=False)
|
currency = db.Column(db.String(20), nullable=False)
|
||||||
yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
|
yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
|
||||||
@@ -81,54 +83,25 @@ class License(db.Model):
|
|||||||
order_by='LicensePeriod.period_number',
|
order_by='LicensePeriod.period_number',
|
||||||
cascade='all, delete-orphan')
|
cascade='all, delete-orphan')
|
||||||
|
|
||||||
@hybrid_property
|
def calculate_end_date(start_date, nr_of_periods):
|
||||||
def end_date(self):
|
"""Utility functie om einddatum te berekenen"""
|
||||||
"""
|
if start_date and nr_of_periods:
|
||||||
Berekent de einddatum van de licentie op basis van start_date en nr_of_periods.
|
return start_date + relativedelta(months=nr_of_periods) - relativedelta(days=1)
|
||||||
Elke periode is 1 maand, dus einddatum = startdatum + nr_of_periods maanden - 1 dag
|
return None
|
||||||
"""
|
|
||||||
if self.start_date and self.nr_of_periods:
|
|
||||||
return self.start_date + relativedelta(months=self.nr_of_periods) - relativedelta(days=1)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@end_date.expression
|
# Luister naar start_date wijzigingen
|
||||||
def end_date(cls):
|
@event.listens_for(License.start_date, 'set')
|
||||||
"""
|
def set_start_date(target, value, oldvalue, initiator):
|
||||||
SQL expressie versie van de end_date property voor gebruik in queries
|
"""Bijwerken van end_date wanneer start_date wordt aangepast"""
|
||||||
"""
|
if value and target.nr_of_periods:
|
||||||
return db.func.date_add(
|
target.end_date = calculate_end_date(value, target.nr_of_periods)
|
||||||
db.func.date_add(
|
|
||||||
cls.start_date,
|
|
||||||
db.text(f'INTERVAL cls.nr_of_periods MONTH')
|
|
||||||
),
|
|
||||||
db.text('INTERVAL -1 DAY')
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_configuration(self, **changes):
|
# Luister naar nr_of_periods wijzigingen
|
||||||
"""
|
@event.listens_for(License.nr_of_periods, 'set')
|
||||||
Update license configuration
|
def set_nr_of_periods(target, value, oldvalue, initiator):
|
||||||
These changes will only apply to future periods, not existing ones
|
"""Bijwerken van end_date wanneer nr_of_periods wordt aangepast"""
|
||||||
|
if value and target.start_date:
|
||||||
Args:
|
target.end_date = calculate_end_date(target.start_date, value)
|
||||||
**changes: Dictionary of changes to apply to the license
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
allowed_fields = [
|
|
||||||
'tier_id', 'currency', 'basic_fee', 'max_storage_mb',
|
|
||||||
'additional_storage_price', 'additional_storage_bucket',
|
|
||||||
'included_embedding_mb', 'additional_embedding_price', 'additional_embedding_bucket',
|
|
||||||
'included_interaction_tokens', 'additional_interaction_token_price',
|
|
||||||
'additional_interaction_bucket', 'overage_embedding', 'overage_interaction',
|
|
||||||
'additional_storage_allowed', 'additional_embedding_allowed',
|
|
||||||
'additional_interaction_allowed'
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply only allowed changes
|
|
||||||
for key, value in changes.items():
|
|
||||||
if key in allowed_fields:
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
|
|
||||||
class LicenseTier(db.Model):
|
class LicenseTier(db.Model):
|
||||||
@@ -209,22 +182,22 @@ class LicensePeriod(db.Model):
|
|||||||
period_end = db.Column(db.Date, nullable=False)
|
period_end = db.Column(db.Date, nullable=False)
|
||||||
|
|
||||||
# License configuration snapshot - copied from license when period is created
|
# License configuration snapshot - copied from license when period is created
|
||||||
currency = db.Column(db.String(20), nullable=False)
|
currency = db.Column(db.String(20), nullable=True)
|
||||||
basic_fee = db.Column(db.Float, nullable=False)
|
basic_fee = db.Column(db.Float, nullable=True)
|
||||||
max_storage_mb = db.Column(db.Integer, nullable=False)
|
max_storage_mb = db.Column(db.Integer, nullable=True)
|
||||||
additional_storage_price = db.Column(db.Float, nullable=False)
|
additional_storage_price = db.Column(db.Float, nullable=True)
|
||||||
additional_storage_bucket = db.Column(db.Integer, nullable=False)
|
additional_storage_bucket = db.Column(db.Integer, nullable=True)
|
||||||
included_embedding_mb = db.Column(db.Integer, nullable=False)
|
included_embedding_mb = db.Column(db.Integer, nullable=True)
|
||||||
additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
|
additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True)
|
||||||
additional_embedding_bucket = db.Column(db.Integer, nullable=False)
|
additional_embedding_bucket = db.Column(db.Integer, nullable=True)
|
||||||
included_interaction_tokens = db.Column(db.Integer, nullable=False)
|
included_interaction_tokens = db.Column(db.Integer, nullable=True)
|
||||||
additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
|
additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True)
|
||||||
additional_interaction_bucket = db.Column(db.Integer, nullable=False)
|
additional_interaction_bucket = db.Column(db.Integer, nullable=True)
|
||||||
|
|
||||||
# Allowance flags - can be changed from False to True within a period
|
# Allowance flags - can be changed from False to True within a period
|
||||||
additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
|
additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False)
|
||||||
additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
|
additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False)
|
||||||
additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)
|
additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False)
|
||||||
|
|
||||||
# Status tracking
|
# Status tracking
|
||||||
status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)
|
status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)
|
||||||
|
|||||||
@@ -24,45 +24,65 @@ class LicensePeriodServices:
|
|||||||
Raises:
|
Raises:
|
||||||
EveAIException: and derived classes
|
EveAIException: and derived classes
|
||||||
"""
|
"""
|
||||||
current_date = dt.now(tz.utc).date()
|
try:
|
||||||
license_period = (db.session.query(LicensePeriod)
|
current_app.logger.debug(f"Finding current license period for tenant {tenant_id}")
|
||||||
.filter_by(tenant_id=tenant_id)
|
current_date = dt.now(tz.utc).date()
|
||||||
.filter(and_(LicensePeriod.period_start_date <= current_date,
|
license_period = (db.session.query(LicensePeriod)
|
||||||
LicensePeriod.period_end_date >= current_date))
|
.filter_by(tenant_id=tenant_id)
|
||||||
.first())
|
.filter(and_(LicensePeriod.period_start <= current_date,
|
||||||
if not license_period:
|
LicensePeriod.period_end >= current_date))
|
||||||
license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id)
|
.first())
|
||||||
if license_period:
|
current_app.logger.debug(f"End searching for license period for tenant {tenant_id} ")
|
||||||
match license_period.status:
|
if not license_period:
|
||||||
case PeriodStatus.UPCOMING:
|
current_app.logger.debug(f"No license period found for tenant {tenant_id} on date {current_date}")
|
||||||
LicensePeriodServices._complete_last_license_period()
|
license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id)
|
||||||
LicensePeriodServices._activate_license_period(license_period)
|
current_app.logger.debug(f"Created license period {license_period.id} for tenant {tenant_id}")
|
||||||
if not license_period.license_usage:
|
if license_period:
|
||||||
new_license_usage = LicenseUsage()
|
current_app.logger.debug(f"Found license period {license_period.id} for tenant {tenant_id} "
|
||||||
new_license_usage.license_period = license_period
|
f"with status {license_period.status}")
|
||||||
try:
|
match license_period.status:
|
||||||
db.session.add(new_license_usage)
|
case PeriodStatus.UPCOMING:
|
||||||
db.session.commit()
|
current_app.logger.debug(f"In upcoming state")
|
||||||
except SQLAlchemyError as e:
|
LicensePeriodServices._complete_last_license_period(tenant_id=tenant_id)
|
||||||
db.session.rollback()
|
current_app.logger.debug(f"Completed last license period for tenant {tenant_id}")
|
||||||
current_app.logger.error(
|
LicensePeriodServices._activate_license_period(license_period=license_period)
|
||||||
f"Error creating new license usage for license period {license_period.id}: {str(e)}")
|
current_app.logger.debug(f"Activated license period {license_period.id} for tenant {tenant_id}")
|
||||||
raise e
|
if not license_period.license_usage:
|
||||||
if license_period.status == PeriodStatus.ACTIVE:
|
new_license_usage = LicenseUsage(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
new_license_usage.license_period = license_period
|
||||||
|
try:
|
||||||
|
db.session.add(new_license_usage)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error creating new license usage for license period "
|
||||||
|
f"{license_period.id}: {str(e)}")
|
||||||
|
raise e
|
||||||
|
if license_period.status == PeriodStatus.ACTIVE:
|
||||||
|
return license_period
|
||||||
|
else:
|
||||||
|
# Status is PENDING, so no prepaid payment received. There is no license period we can use.
|
||||||
|
# We allow for a delay of 5 days before raising an exception.
|
||||||
|
current_date = dt.now(tz.utc).date()
|
||||||
|
delta = abs(current_date - license_period.period_start_date)
|
||||||
|
if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)):
|
||||||
|
raise EveAIPendingLicensePeriod()
|
||||||
|
case PeriodStatus.ACTIVE:
|
||||||
return license_period
|
return license_period
|
||||||
else:
|
case PeriodStatus.PENDING:
|
||||||
# Status is PENDING, so no prepaid payment received. There is no license period we can use.
|
return license_period
|
||||||
# We allow for a delay of 5 days before raising an exception.
|
else:
|
||||||
current_date = dt.now(tz.utc).date()
|
raise EveAILicensePeriodsExceeded(license_id=None)
|
||||||
delta = abs(current_date - license_period.period_start_date)
|
except SQLAlchemyError as e:
|
||||||
if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)):
|
db.session.rollback()
|
||||||
raise EveAIPendingLicensePeriod()
|
current_app.logger.error(f"Error finding current license period for tenant {tenant_id}: {str(e)}")
|
||||||
case PeriodStatus.ACTIVE:
|
raise e
|
||||||
return license_period
|
except Exception as e:
|
||||||
case PeriodStatus.PENDING:
|
raise e
|
||||||
return license_period
|
|
||||||
else:
|
|
||||||
raise EveAILicensePeriodsExceeded(license_id=None)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod:
|
def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod:
|
||||||
@@ -87,13 +107,17 @@ class LicensePeriodServices:
|
|||||||
if not the_license:
|
if not the_license:
|
||||||
current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}")
|
current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}")
|
||||||
raise EveAINoActiveLicense(tenant_id=tenant_id)
|
raise EveAINoActiveLicense(tenant_id=tenant_id)
|
||||||
|
else:
|
||||||
|
current_app.logger.debug(f"Found active license {the_license.id} for tenant {tenant_id} "
|
||||||
|
f"on date {current_date}")
|
||||||
|
|
||||||
next_period_number = 1
|
next_period_number = 1
|
||||||
if the_license.periods:
|
if the_license.periods:
|
||||||
# If there are existing periods, get the next sequential number
|
# If there are existing periods, get the next sequential number
|
||||||
next_period_number = max(p.period_number for p in the_license.periods) + 1
|
next_period_number = max(p.period_number for p in the_license.periods) + 1
|
||||||
|
current_app.logger.debug(f"Next period number for tenant {tenant_id} is {next_period_number}")
|
||||||
|
|
||||||
if next_period_number > the_license.max_periods:
|
if next_period_number > the_license.nr_of_periods:
|
||||||
raise EveAILicensePeriodsExceeded(license_id=the_license.id)
|
raise EveAILicensePeriodsExceeded(license_id=the_license.id)
|
||||||
|
|
||||||
new_license_period = LicensePeriod(
|
new_license_period = LicensePeriod(
|
||||||
@@ -103,18 +127,16 @@ class LicensePeriodServices:
|
|||||||
period_start=the_license.start_date + relativedelta(months=next_period_number-1),
|
period_start=the_license.start_date + relativedelta(months=next_period_number-1),
|
||||||
period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1),
|
period_end=the_license.end_date + relativedelta(months=next_period_number, days=-1),
|
||||||
status=PeriodStatus.UPCOMING,
|
status=PeriodStatus.UPCOMING,
|
||||||
|
upcoming_at=dt.now(tz.utc),
|
||||||
)
|
)
|
||||||
set_logging_information(new_license_period, dt.now(tz.utc))
|
set_logging_information(new_license_period, dt.now(tz.utc))
|
||||||
new_license_usage = LicenseUsage(
|
|
||||||
license_period=new_license_period,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
)
|
|
||||||
set_logging_information(new_license_usage, dt.now(tz.utc))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
current_app.logger.debug(f"Creating next license period for tenant {tenant_id} ")
|
||||||
db.session.add(new_license_period)
|
db.session.add(new_license_period)
|
||||||
db.session.add(new_license_usage)
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
current_app.logger.info(f"Created next license period for tenant {tenant_id} "
|
||||||
|
f"with id {new_license_period.id}")
|
||||||
return new_license_period
|
return new_license_period
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
@@ -136,15 +158,17 @@ class LicensePeriodServices:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If neither license_period_id nor license_period is provided
|
ValueError: If neither license_period_id nor license_period is provided
|
||||||
"""
|
"""
|
||||||
|
current_app.logger.debug(f"Activating license period")
|
||||||
if license_period is None and license_period_id is None:
|
if license_period is None and license_period_id is None:
|
||||||
raise ValueError("Either license_period_id or license_period must be provided")
|
raise ValueError("Either license_period_id or license_period must be provided")
|
||||||
|
|
||||||
# Get a license period object if only ID was provided
|
# Get a license period object if only ID was provided
|
||||||
if license_period is None:
|
if license_period is None:
|
||||||
|
current_app.logger.debug(f"Getting license period {license_period_id} to activate")
|
||||||
license_period = LicensePeriod.query.get_or_404(license_period_id)
|
license_period = LicensePeriod.query.get_or_404(license_period_id)
|
||||||
|
|
||||||
if license_period.upcoming_at is not None:
|
if license_period.pending_at is not None:
|
||||||
license_period.pending_at.upcoming_at = dt.now(tz.utc)
|
license_period.pending_at = dt.now(tz.utc)
|
||||||
license_period.status = PeriodStatus.PENDING
|
license_period.status = PeriodStatus.PENDING
|
||||||
if license_period.prepaid_payment:
|
if license_period.prepaid_payment:
|
||||||
# There is a payment received for the given period
|
# There is a payment received for the given period
|
||||||
@@ -173,7 +197,7 @@ class LicensePeriodServices:
|
|||||||
if not license_period.license_usage:
|
if not license_period.license_usage:
|
||||||
license_period.license_usage = LicenseUsage(
|
license_period.license_usage = LicenseUsage(
|
||||||
tenant_id=license_period.tenant_id,
|
tenant_id=license_period.tenant_id,
|
||||||
license_period=license_period,
|
license_period_id=license_period.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
license_period.license_usage.recalculate_storage()
|
license_period.license_usage.recalculate_storage()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbedd
|
|||||||
from common.langchain.tracked_transcription import TrackedOpenAITranscription
|
from common.langchain.tracked_transcription import TrackedOpenAITranscription
|
||||||
from common.models.user import Tenant
|
from common.models.user import Tenant
|
||||||
from config.model_config import MODEL_CONFIG
|
from config.model_config import MODEL_CONFIG
|
||||||
from common.extensions import template_manager
|
from common.extensions import cache_manager
|
||||||
from common.models.document import EmbeddingMistral
|
from common.models.document import EmbeddingMistral
|
||||||
from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
|
from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel
|
||||||
from crewai import LLM
|
from crewai import LLM
|
||||||
@@ -139,6 +139,19 @@ def process_pdf():
|
|||||||
full_model_name = 'mistral-ocr-latest'
|
full_model_name = 'mistral-ocr-latest'
|
||||||
|
|
||||||
|
|
||||||
|
def get_template(template_name: str, version: Optional[str] = "1.0") -> tuple[
|
||||||
|
Any, BaseChatModel | None | ChatOpenAI | ChatMistralAI]:
|
||||||
|
"""
|
||||||
|
Get a prompt template
|
||||||
|
"""
|
||||||
|
prompt = cache_manager.prompts_config_cache.get_config(template_name, version)
|
||||||
|
if "llm_model" in prompt:
|
||||||
|
llm = get_embedding_llm(full_model_name=prompt["llm_model"])
|
||||||
|
else:
|
||||||
|
llm = get_embedding_llm()
|
||||||
|
|
||||||
|
return prompt["content"], llm
|
||||||
|
|
||||||
|
|
||||||
class ModelVariables:
|
class ModelVariables:
|
||||||
"""Manages model-related variables and configurations"""
|
"""Manages model-related variables and configurations"""
|
||||||
@@ -261,31 +274,6 @@ class ModelVariables:
|
|||||||
def transcribe(self, *args, **kwargs):
|
def transcribe(self, *args, **kwargs):
|
||||||
raise DeprecationWarning("Use transcription_model.transcribe() instead")
|
raise DeprecationWarning("Use transcription_model.transcribe() instead")
|
||||||
|
|
||||||
def get_template(self, template_name: str, version: Optional[str] = None) -> str:
|
|
||||||
"""
|
|
||||||
Get a template for the tenant's configured LLM
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template_name: Name of the template to retrieve
|
|
||||||
version: Optional specific version to retrieve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The template content
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
template = template_manager.get_template(
|
|
||||||
self._variables['llm_full_model'],
|
|
||||||
template_name,
|
|
||||||
version
|
|
||||||
)
|
|
||||||
return template.content
|
|
||||||
except Exception as e:
|
|
||||||
current_app.logger.error(f"Error getting template {template_name}: {str(e)}")
|
|
||||||
# Fall back to old template loading if template_manager fails
|
|
||||||
if template_name in self._variables.get('templates', {}):
|
|
||||||
return self._variables['templates'][template_name]
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# Helper function to get cached model variables
|
# Helper function to get cached model variables
|
||||||
def get_model_variables(tenant_id: int) -> ModelVariables:
|
def get_model_variables(tenant_id: int) -> ModelVariables:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from common.utils.nginx_utils import prefixed_url_for as puf
|
||||||
|
|
||||||
|
|
||||||
def to_local_time(utc_dt, timezone_str):
|
def to_local_time(utc_dt, timezone_str):
|
||||||
@@ -42,6 +43,10 @@ def status_color(status_name):
|
|||||||
return colors.get(status_name, 'secondary')
|
return colors.get(status_name, 'secondary')
|
||||||
|
|
||||||
|
|
||||||
|
def prefixed_url_for(endpoint):
|
||||||
|
return puf(endpoint)
|
||||||
|
|
||||||
|
|
||||||
def register_filters(app):
|
def register_filters(app):
|
||||||
"""
|
"""
|
||||||
Registers custom filters with the Flask app.
|
Registers custom filters with the Flask app.
|
||||||
@@ -49,3 +54,6 @@ def register_filters(app):
|
|||||||
app.jinja_env.filters['to_local_time'] = to_local_time
|
app.jinja_env.filters['to_local_time'] = to_local_time
|
||||||
app.jinja_env.filters['time_difference'] = time_difference
|
app.jinja_env.filters['time_difference'] = time_difference
|
||||||
app.jinja_env.filters['status_color'] = status_color
|
app.jinja_env.filters['status_color'] = status_color
|
||||||
|
app.jinja_env.filters['prefixed_url_for'] = prefixed_url_for
|
||||||
|
|
||||||
|
app.jinja_env.globals['prefixed_url_for'] = prefixed_url_for
|
||||||
|
|||||||
Reference in New Issue
Block a user