import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Optional, Tuple

import yaml
from flask import Flask, current_app
from packaging import version

from common.utils.os_utils import get_project_root


@dataclass
class PromptTemplate:
    """Represents a versioned prompt template"""
    content: str  # prompt text from the YAML file's 'content' key
    version: str  # version taken from the file name, e.g. "1.0.0" for 1.0.0.yaml
    metadata: Dict[str, Any]  # YAML 'metadata' value; {} when absent or null


class TemplateManager:
    """Manages versioned prompt templates.

    Templates are read once by :meth:`init_app` from a directory tree of
    ``<provider>/<model>/<template_name>/<version>.yaml`` files and then served
    from memory by :meth:`get_template` and :meth:`list_templates`.
    """

    def __init__(self):
        self.templates_dir = None
        self._templates = None
        self.app = None

    def init_app(self, app: Flask) -> None:
        """Bind the manager to *app* and load all templates from disk.

        The templates directory defaults to ``/app/config/prompts``; set the
        ``PROMPT_TEMPLATES_DIR`` config key to load from elsewhere (e.g. in
        local development or tests). After loading, the templates found for
        each entry of ``app.config['SUPPORTED_LLMS']`` are logged.
        """
        default_dir = os.path.join("/app", 'config', 'prompts')
        self.templates_dir = app.config.get('PROMPT_TEMPLATES_DIR') or default_dir
        self.app = app
        self._templates = self._load_templates()

        # Purely diagnostic: report what was found for each configured model.
        for llm in app.config.get('SUPPORTED_LLMS', []):
            try:
                available_templates = self.list_templates(llm)
            except ValueError:
                app.logger.warning(f"No templates found for {llm}")
                continue
            if available_templates:
                app.logger.info(f"Loaded templates for {llm}: {available_templates}")
            else:
                # The model directory exists but no template in it loaded successfully.
                app.logger.warning(f"No templates found for {llm}")

    @property
    def _logger(self) -> logging.Logger:
        """Logger of the bound app, or this module's logger before init_app."""
        # Deliberately not flask.current_app: template loading happens inside
        # init_app, which usually runs without an active application context.
        if self.app is not None:
            return self.app.logger
        return logging.getLogger(__name__)

    def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
        """
        Load all template versions from the templates directory.

        Structure: {provider.model -> {template_name -> {version -> template}}}
        Directory structure:
            prompts/
            ├── provider/
            │   └── model/
            │       └── template_name/
            │           └── version.yaml

        Every provider/model directory gets an entry, even when none of its
        templates load. A template is only included if at least one of its
        versions loaded successfully.

        Raises:
            OSError: if ``self.templates_dir`` itself cannot be listed.
        """
        templates: Dict[str, Dict[str, Dict[str, PromptTemplate]]] = {}
        for provider, provider_path in self._subdirectories(self.templates_dir):
            for model, model_path in self._subdirectories(provider_path):
                provider_model = f"{provider}.{model}"
                templates[provider_model] = {}
                for template_name, template_path in self._subdirectories(model_path):
                    template_versions = self._load_template_versions(template_name, template_path)
                    if template_versions:
                        templates[provider_model][template_name] = template_versions
        return templates

    @staticmethod
    def _subdirectories(path: str) -> Iterator[Tuple[str, str]]:
        """Yield ``(name, full_path)`` for each subdirectory of *path*, sorted by name."""
        for name in sorted(os.listdir(path)):
            full_path = os.path.join(path, name)
            if os.path.isdir(full_path):
                yield name, full_path

    def _load_template_versions(self, template_name: str,
                                template_path: str) -> Dict[str, PromptTemplate]:
        """Load every ``<version>.yaml`` file in *template_path*.

        Files whose name is not a valid version, or whose content cannot be
        read or validated, are logged and skipped. One bad file therefore
        does not prevent the remaining versions from loading.
        """
        template_versions: Dict[str, PromptTemplate] = {}
        for version_file in sorted(os.listdir(template_path)):
            if not version_file.endswith('.yaml'):
                continue
            version_str = version_file[:-len('.yaml')]

            if not self._is_valid_version(version_str):
                self._logger.warning(
                    f"Invalid version format for {template_name}: {version_str}")
                continue

            try:
                template_versions[version_str] = self._load_template_file(
                    os.path.join(template_path, version_file), version_str)
            except Exception as e:  # best-effort per file: log and keep loading the rest
                self._logger.error(
                    f"Error loading template {template_name} version {version_str}: {e}")
        return template_versions

    @staticmethod
    def _load_template_file(file_path: str, version_str: str) -> PromptTemplate:
        """Parse one template YAML file into a :class:`PromptTemplate`.

        Raises:
            OSError: if the file cannot be read.
            yaml.YAMLError: if the file is not valid YAML.
            ValueError: if the file is not a mapping or has an empty/missing ``content``.
        """
        with open(file_path, encoding='utf-8') as f:
            template_data = yaml.safe_load(f)

        # An empty file loads as None, and a list or scalar has no .get(); reject both clearly.
        if not isinstance(template_data, dict):
            raise ValueError("Template file must contain a YAML mapping")
        if not template_data.get('content'):
            raise ValueError("Template content is required")

        return PromptTemplate(
            content=template_data['content'],
            version=version_str,
            # A bare `metadata:` key loads as None; normalise it to an empty dict.
            metadata=template_data.get('metadata') or {},
        )

    def _is_valid_version(self, version_str: str) -> bool:
        """Validate semantic versioning string"""
        try:
            version.parse(version_str)
            return True
        except version.InvalidVersion:
            return False

    def _loaded_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
        """Return the loaded registry, failing clearly if init_app was never called."""
        if self._templates is None:
            raise RuntimeError("TemplateManager is not initialized; call init_app() first")
        return self._templates

    def get_template(self, provider_model: str, template_name: str,
                     template_version: Optional[str] = None) -> PromptTemplate:
        """
        Get a specific template version.
        If version not specified, returns the latest version.

        Args:
            provider_model: Key of the form ``"<provider>.<model>"``.
            template_name: Template directory name under that model.
            template_version: Exact version string (file name without ``.yaml``).
                ``None`` or an empty string selects the highest version.

        Raises:
            ValueError: if the provider.model, template or version is unknown.
            RuntimeError: if called before :meth:`init_app`.
        """
        templates = self._loaded_templates()
        if provider_model not in templates:
            raise ValueError(f"Unknown provider.model: {provider_model}")
        if template_name not in templates[provider_model]:
            raise ValueError(f"Unknown template: {template_name}")

        versions = templates[provider_model][template_name]

        if template_version:
            if template_version not in versions:
                raise ValueError(f"Template version {template_version} not found")
            return versions[template_version]

        # Compare parsed versions so "1.10.0" ranks above "1.9.0". The loader
        # only stores templates with at least one version, so max() never sees
        # an empty mapping.
        latest = max(versions, key=version.parse)
        return versions[latest]

    def list_templates(self, provider_model: str) -> Dict[str, list]:
        """
        List all available templates and their versions for a provider.model

        Returns:
            {template_name: [version1, version2, ...]} with versions in ascending order

        Raises:
            ValueError: if the provider.model is unknown.
            RuntimeError: if called before :meth:`init_app`.
        """
        templates = self._loaded_templates()
        if provider_model not in templates:
            raise ValueError(f"Unknown provider.model: {provider_model}")

        return {
            template_name: sorted(versions, key=version.parse)
            for template_name, versions in templates[provider_model].items()
        }