154 lines
5.9 KiB
Python
154 lines
5.9 KiB
Python
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, Any

import yaml
from packaging import version
from flask import current_app, Flask

from common.utils.os_utils import get_project_root
|
|
|
|
|
|
@dataclass
class PromptTemplate:
    """A single versioned prompt template.

    Attributes:
        content: The template text.
        version: Version string the template is registered under (e.g. ``1.0.0``).
        metadata: Free-form metadata for the template; defaults to an empty dict.
    """

    content: str
    version: str
    # default_factory gives every instance its own dict; a plain `= {}` default
    # would be rejected by dataclasses as a shared mutable default.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class TemplateManager:
    """Manages versioned prompt templates loaded from YAML files.

    Templates are held as
    ``{provider.model -> {template_name -> {version -> PromptTemplate}}}``
    and are read from disk once, in :meth:`init_app`.
    """

    # Root of the prompt tree used when init_app() is not given one.
    DEFAULT_TEMPLATES_DIR = os.path.join("/app", "config", "prompts")

    def __init__(self):
        self.templates_dir = None  # set by init_app()
        self._templates = None     # populated by init_app()
        self.app = None

    def init_app(self, app: Flask, templates_dir: Optional[str] = None) -> None:
        """Bind to a Flask app and load every template from disk.

        Args:
            app: The Flask application. Its logger receives load diagnostics, and
                ``app.config['SUPPORTED_LLMS']`` (if present) lists the
                ``provider.model`` keys whose available templates are logged.
            templates_dir: Root of the prompt tree. Defaults to
                ``DEFAULT_TEMPLATES_DIR`` (``/app/config/prompts``).

        Raises:
            FileNotFoundError: If the templates directory (or a directory inside
                it) cannot be listed.
        """
        self.templates_dir = templates_dir or self.DEFAULT_TEMPLATES_DIR
        self.app = app
        self._templates = self._load_templates()

        # Log available templates for each supported model.
        for llm in app.config.get('SUPPORTED_LLMS', []):
            try:
                available_templates = self.list_templates(llm)
            except ValueError:
                app.logger.warning(f"No templates found for {llm}")
                continue
            if available_templates:
                app.logger.info(f"Loaded templates for {llm}: {available_templates}")
            else:
                # The model directory exists but contains no loadable templates.
                app.logger.warning(f"No templates found for {llm}")

    @property
    def _logger(self):
        """Logger used while loading templates.

        Prefers the app bound in init_app(), because loading may happen before an
        application context is pushed, and ``current_app`` raises RuntimeError
        outside one.
        """
        return self.app.logger if self.app is not None else current_app.logger

    def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
        """
        Load all template versions from the templates directory.

        Structure: {provider.model -> {template_name -> {version -> template}}}

        Directory structure:
            prompts/
            ├── provider/
            │   └── model/
            │       └── template_name/
            │           └── version.yaml

        A ``provider.model`` key is created for every model directory, even one
        with no loadable templates; a template name is only registered if at
        least one of its version files loaded successfully.
        """
        templates: Dict[str, Dict[str, Dict[str, PromptTemplate]]] = {}

        for provider in self._subdirs(self.templates_dir):  # e.g. anthropic, openai
            provider_path = os.path.join(self.templates_dir, provider)

            for model in self._subdirs(provider_path):  # e.g. claude-3, gpt-4o
                model_path = os.path.join(provider_path, model)
                provider_model = f"{provider}.{model}"
                templates[provider_model] = {}

                for template_name in self._subdirs(model_path):  # e.g. rag, summary
                    template_versions = self._load_template_versions(
                        os.path.join(model_path, template_name), template_name)
                    if template_versions:
                        templates[provider_model][template_name] = template_versions

        return templates

    @staticmethod
    def _subdirs(path: str) -> list:
        """Return the names of the immediate subdirectories of *path*, sorted."""
        return sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    def _load_template_versions(self,
                                template_path: str,
                                template_name: str) -> Dict[str, PromptTemplate]:
        """Load every ``<version>.yaml`` file in *template_path*.

        Files with an invalid version name or unreadable/invalid content are
        logged and skipped, so one bad file does not prevent the others loading.
        """
        template_versions: Dict[str, PromptTemplate] = {}

        for version_file in sorted(os.listdir(template_path)):
            if not version_file.endswith('.yaml'):
                continue

            version_str = version_file[:-len('.yaml')]
            if not self._is_valid_version(version_str):
                self._logger.warning(
                    f"Invalid version format for {template_name}: {version_str}")
                continue

            try:
                template_versions[version_str] = self._load_template_file(
                    os.path.join(template_path, version_file), version_str)
            except Exception as e:  # per-file boundary: log and keep loading the rest
                self._logger.error(
                    f"Error loading template {template_name} version {version_str}: {e}")

        return template_versions

    @staticmethod
    def _load_template_file(file_path: str, version_str: str) -> PromptTemplate:
        """Parse one template YAML file into a PromptTemplate.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the document is not a mapping, has no non-empty
                ``content``, or is not valid UTF-8.
        """
        with open(file_path, encoding='utf-8') as f:
            template_data = yaml.safe_load(f)

        # safe_load returns None for an empty file and a list/scalar for
        # non-mapping documents; reject those explicitly instead of failing
        # later with an AttributeError.
        if not isinstance(template_data, dict):
            raise ValueError("Template file must contain a YAML mapping")
        if not template_data.get('content'):
            raise ValueError("Template content is required")

        return PromptTemplate(
            content=template_data['content'],
            version=version_str,
            # `metadata:` with no value parses as None; normalise it to {}.
            metadata=template_data.get('metadata') or {},
        )

    def _is_valid_version(self, version_str: str) -> bool:
        """Return True if *version_str* is a valid PEP 440 version (e.g. ``1.2.0``).

        ``version.Version`` is used instead of ``version.parse`` because older
        ``packaging`` releases made ``parse`` fall back to ``LegacyVersion``
        rather than raising, which would accept any string.
        """
        try:
            version.Version(version_str)
        except version.InvalidVersion:
            return False
        return True

    def _loaded_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
        """Return the loaded template tree, or raise if init_app() was never called."""
        if self._templates is None:
            raise RuntimeError("TemplateManager.init_app() has not been called")
        return self._templates

    def get_template(self,
                     provider_model: str,
                     template_name: str,
                     template_version: Optional[str] = None) -> PromptTemplate:
        """
        Get a specific template version. If no version is given (or it is empty),
        return the highest version by PEP 440 ordering.

        Raises:
            ValueError: If the provider.model, template, or version is unknown.
            RuntimeError: If init_app() has not been called.
        """
        templates = self._loaded_templates()

        if provider_model not in templates:
            raise ValueError(f"Unknown provider.model: {provider_model}")

        model_templates = templates[provider_model]
        if template_name not in model_templates:
            raise ValueError(f"Unknown template: {template_name}")

        versions = model_templates[template_name]

        if template_version:
            if template_version not in versions:
                raise ValueError(f"Template version {template_version} not found")
            return versions[template_version]

        # Every registered template has at least one version, so max() is safe.
        latest = max(versions, key=version.parse)
        return versions[latest]

    def list_templates(self, provider_model: str) -> Dict[str, list]:
        """
        List all available templates and their versions for a provider.model.

        Returns: {template_name: [version1, version2, ...]}, with versions sorted
        ascending by PEP 440 ordering.

        Raises:
            ValueError: If the provider.model is unknown.
            RuntimeError: If init_app() has not been called.
        """
        templates = self._loaded_templates()

        if provider_model not in templates:
            raise ValueError(f"Unknown provider.model: {provider_model}")

        return {
            template_name: sorted(versions, key=version.parse)
            for template_name, versions in templates[provider_model].items()
        }
|