- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors - Introduction of caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start adaptation of chat client
This commit is contained in:
154
common/langchain/templates/template_manager.py
Normal file
154
common/langchain/templates/template_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, Optional, Any
|
||||
from packaging import version
|
||||
from dataclasses import dataclass
|
||||
from flask import current_app, Flask
|
||||
|
||||
from common.utils.os_utils import get_project_root
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Represents a versioned prompt template"""
|
||||
content: str
|
||||
version: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""Manages versioned prompt templates"""
|
||||
|
||||
def __init__(self):
|
||||
self.templates_dir = None
|
||||
self._templates = None
|
||||
self.app = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
# Initialize template manager
|
||||
base_dir = "/app"
|
||||
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
|
||||
app.logger.debug(f'Loading templates from {self.templates_dir}')
|
||||
self.app = app
|
||||
self._templates = self._load_templates()
|
||||
# Log available templates for each supported model
|
||||
for llm in app.config['SUPPORTED_LLMS']:
|
||||
try:
|
||||
available_templates = self.list_templates(llm)
|
||||
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
|
||||
except ValueError:
|
||||
app.logger.warning(f"No templates found for {llm}")
|
||||
|
||||
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
|
||||
"""
|
||||
Load all template versions from the templates directory.
|
||||
Structure: {provider.model -> {template_name -> {version -> template}}}
|
||||
Directory structure:
|
||||
prompts/
|
||||
├── provider/
|
||||
│ └── model/
|
||||
│ └── template_name/
|
||||
│ └── version.yaml
|
||||
"""
|
||||
templates = {}
|
||||
|
||||
# Iterate through providers (anthropic, openai)
|
||||
for provider in os.listdir(self.templates_dir):
|
||||
provider_path = os.path.join(self.templates_dir, provider)
|
||||
if not os.path.isdir(provider_path):
|
||||
continue
|
||||
|
||||
# Iterate through models (claude-3, gpt-4o)
|
||||
for model in os.listdir(provider_path):
|
||||
model_path = os.path.join(provider_path, model)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
provider_model = f"{provider}.{model}"
|
||||
templates[provider_model] = {}
|
||||
|
||||
# Iterate through template types (rag, summary, etc.)
|
||||
for template_name in os.listdir(model_path):
|
||||
template_path = os.path.join(model_path, template_name)
|
||||
if not os.path.isdir(template_path):
|
||||
continue
|
||||
|
||||
template_versions = {}
|
||||
# Load all version files for this template
|
||||
for version_file in os.listdir(template_path):
|
||||
if not version_file.endswith('.yaml'):
|
||||
continue
|
||||
|
||||
version_str = version_file[:-5] # Remove .yaml
|
||||
if not self._is_valid_version(version_str):
|
||||
current_app.logger.warning(
|
||||
f"Invalid version format for {template_name}: {version_str}")
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(os.path.join(template_path, version_file)) as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
# Verify required fields
|
||||
if not template_data.get('content'):
|
||||
raise ValueError("Template content is required")
|
||||
|
||||
template_versions[version_str] = PromptTemplate(
|
||||
content=template_data['content'],
|
||||
version=version_str,
|
||||
metadata=template_data.get('metadata', {})
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error loading template {template_name} version {version_str}: {e}")
|
||||
continue
|
||||
|
||||
if template_versions:
|
||||
templates[provider_model][template_name] = template_versions
|
||||
|
||||
return templates
|
||||
|
||||
def _is_valid_version(self, version_str: str) -> bool:
|
||||
"""Validate semantic versioning string"""
|
||||
try:
|
||||
version.parse(version_str)
|
||||
return True
|
||||
except version.InvalidVersion:
|
||||
return False
|
||||
|
||||
def get_template(self,
|
||||
provider_model: str,
|
||||
template_name: str,
|
||||
template_version: Optional[str] = None) -> PromptTemplate:
|
||||
"""
|
||||
Get a specific template version. If version not specified,
|
||||
returns the latest version.
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
if template_name not in self._templates[provider_model]:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
versions = self._templates[provider_model][template_name]
|
||||
|
||||
if template_version:
|
||||
if template_version not in versions:
|
||||
raise ValueError(f"Template version {template_version} not found")
|
||||
return versions[template_version]
|
||||
|
||||
# Return latest version
|
||||
latest = max(versions.keys(), key=version.parse)
|
||||
return versions[latest]
|
||||
|
||||
def list_templates(self, provider_model: str) -> Dict[str, list]:
|
||||
"""
|
||||
List all available templates and their versions for a provider.model
|
||||
Returns: {template_name: [version1, version2, ...]}
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
return {
|
||||
template_name: sorted(versions.keys(), key=version.parse)
|
||||
for template_name, versions in self._templates[provider_model].items()
|
||||
}
|
||||
Reference in New Issue
Block a user