- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors - Introduction of caching system - Introduction of a better template manager - Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists - Start adaptation of chat client
This commit is contained in:
23
common/langchain/outputs/base.py
Normal file
23
common/langchain/outputs/base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Output Schema Management - common/langchain/outputs/base.py
|
||||
from typing import Dict, Type, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseSpecialistOutput(BaseModel):
|
||||
"""Base class for all specialist outputs"""
|
||||
pass
|
||||
|
||||
|
||||
class OutputRegistry:
|
||||
"""Registry for specialist output schemas"""
|
||||
_schemas: Dict[str, Type[BaseSpecialistOutput]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, specialist_type: str, schema_class: Type[BaseSpecialistOutput]):
|
||||
cls._schemas[specialist_type] = schema_class
|
||||
|
||||
@classmethod
|
||||
def get_schema(cls, specialist_type: str) -> Type[BaseSpecialistOutput]:
|
||||
if specialist_type not in cls._schemas:
|
||||
raise ValueError(f"No output schema registered for {specialist_type}")
|
||||
return cls._schemas[specialist_type]
|
||||
22
common/langchain/outputs/rag.py
Normal file
22
common/langchain/outputs/rag.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# RAG Specialist Output - common/langchain/outputs/rag.py
|
||||
from typing import List
|
||||
from pydantic import Field
|
||||
from .base import BaseSpecialistOutput
|
||||
|
||||
|
||||
class RAGOutput(BaseSpecialistOutput):
|
||||
"""Output schema for RAG specialist"""
|
||||
"""Default docstring - to be replaced with actual prompt"""
|
||||
|
||||
answer: str = Field(
|
||||
...,
|
||||
description="The answer to the user question, based on the given sources",
|
||||
)
|
||||
citations: List[int] = Field(
|
||||
...,
|
||||
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
|
||||
)
|
||||
insufficient_info: bool = Field(
|
||||
False, # Default value is set to False
|
||||
description="A boolean indicating whether given sources were sufficient or not to generate the answer"
|
||||
)
|
||||
@@ -1,145 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import func, and_, or_, desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
|
||||
super().__init__()
|
||||
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
|
||||
self._catalog_id = catalog_id
|
||||
self._model_variables = model_variables
|
||||
self._tenant_info = tenant_info
|
||||
|
||||
@property
|
||||
def catalog_id(self) -> int:
|
||||
return self._catalog_id
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def tenant_info(self) -> Dict[str, Any]:
|
||||
return self._tenant_info
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
|
||||
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
|
||||
db_class = self.model_variables['embedding_db_model']
|
||||
similarity_threshold = self.model_variables['similarity_threshold']
|
||||
k = self.model_variables['k']
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
|
||||
|
||||
# Debug query to show similarity for all valid documents (without chunk text)
|
||||
debug_query = (
|
||||
db.session.query(
|
||||
Document.id.label('document_id'),
|
||||
DocumentVersion.id.label('version_id'),
|
||||
db_class.id.label('embedding_id'),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
|
||||
)
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
|
||||
)
|
||||
.order_by(desc('similarity'))
|
||||
)
|
||||
|
||||
debug_results = debug_query.all()
|
||||
|
||||
current_app.logger.debug("Debug: Similarity for all valid documents:")
|
||||
for row in debug_results:
|
||||
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
|
||||
f"Version ID: {row.version_id}, "
|
||||
f"Embedding ID: {row.embedding_id}, "
|
||||
f"Similarity: {row.similarity}")
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error generating overview: {e}')
|
||||
db.session.rollback()
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
|
||||
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
|
||||
current_app.rag_tuning_logger.debug(f'K: {k}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
# Subquery to find the latest version of each document
|
||||
subquery = (
|
||||
db.session.query(
|
||||
DocumentVersion.doc_id,
|
||||
func.max(DocumentVersion.id).label('latest_version_id')
|
||||
)
|
||||
.group_by(DocumentVersion.doc_id)
|
||||
.subquery()
|
||||
)
|
||||
# Main query to filter embeddings
|
||||
query_obj = (
|
||||
db.session.query(db_class,
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
|
||||
Document.catalog_id == self._catalog_id
|
||||
)
|
||||
.order_by(desc('similarity'))
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
|
||||
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
res = query_obj.all()
|
||||
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
|
||||
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
|
||||
current_app.rag_tuning_logger.debug(f'{res}\n')
|
||||
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
|
||||
|
||||
result = []
|
||||
for doc in res:
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving relevant documents: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
return result
|
||||
|
||||
def _get_query_embedding(self, query: str):
|
||||
embedding_model = self.model_variables['embedding_model']
|
||||
query_embedding = embedding_model.embed_query(query)
|
||||
return query_embedding
|
||||
@@ -1,154 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import func, and_, or_, desc, cast, JSON
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Any, Dict, List, Optional
|
||||
from flask import current_app
|
||||
from contextlib import contextmanager
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion, Catalog
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIDossierRetriever(BaseRetriever, BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
_active_filters: Optional[Dict[str, Any]] = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self._catalog_id = catalog_id
|
||||
self._model_variables = model_variables
|
||||
self._tenant_info = tenant_info
|
||||
self._active_filters = None
|
||||
|
||||
@contextmanager
|
||||
def filtering(self, metadata_filters: Dict[str, Any]):
|
||||
"""Context manager for temporarily setting metadata filters"""
|
||||
previous_filters = self._active_filters
|
||||
self._active_filters = metadata_filters
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._active_filters = previous_filters
|
||||
|
||||
def _build_metadata_filter_conditions(self, query):
|
||||
"""Build SQL conditions for metadata filtering"""
|
||||
if not self._active_filters:
|
||||
return query
|
||||
|
||||
conditions = []
|
||||
for field, value in self._active_filters.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
# Handle both single values and lists of values
|
||||
if isinstance(value, (list, tuple)):
|
||||
# Multiple values - create OR condition
|
||||
or_conditions = []
|
||||
for val in value:
|
||||
or_conditions.append(
|
||||
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(val)
|
||||
)
|
||||
if or_conditions:
|
||||
conditions.append(or_(*or_conditions))
|
||||
else:
|
||||
# Single value - direct comparison
|
||||
conditions.append(
|
||||
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(value)
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
return query
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
|
||||
if self._active_filters:
|
||||
current_app.logger.debug(f'Using metadata filters: {self._active_filters}')
|
||||
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
db_class = self.model_variables['embedding_db_model']
|
||||
similarity_threshold = self.model_variables['similarity_threshold']
|
||||
k = self.model_variables['k']
|
||||
|
||||
try:
|
||||
current_date = get_date_in_timezone(self.tenant_info['timezone'])
|
||||
|
||||
# Subquery to find the latest version of each document
|
||||
subquery = (
|
||||
db.session.query(
|
||||
DocumentVersion.doc_id,
|
||||
func.max(DocumentVersion.id).label('latest_version_id')
|
||||
)
|
||||
.group_by(DocumentVersion.doc_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Build base query
|
||||
# Build base query
|
||||
query_obj = (
|
||||
db.session.query(db_class,
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
|
||||
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
|
||||
.join(Document, DocumentVersion.doc_id == Document.id)
|
||||
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
|
||||
.filter(
|
||||
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
|
||||
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
|
||||
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
|
||||
Document.catalog_id == self._catalog_id
|
||||
)
|
||||
)
|
||||
|
||||
# Apply metadata filters
|
||||
query_obj = self._build_metadata_filter_conditions(query_obj)
|
||||
|
||||
# Order and limit results
|
||||
query_obj = query_obj.order_by(desc('similarity')).limit(k)
|
||||
|
||||
# Debug logging for RAG tuning if enabled
|
||||
if self.model_variables['rag_tuning']:
|
||||
self._log_rag_tuning(query_obj, query_embedding)
|
||||
|
||||
res = query_obj.all()
|
||||
|
||||
result = []
|
||||
for doc in res:
|
||||
if self.model_variables['rag_tuning']:
|
||||
current_app.logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
|
||||
current_app.logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
|
||||
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving relevant documents: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
|
||||
return result
|
||||
|
||||
def _log_rag_tuning(self, query_obj, query_embedding):
|
||||
"""Log debug information for RAG tuning"""
|
||||
current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
|
||||
current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
|
||||
if self._active_filters:
|
||||
current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
|
||||
current_app.rag_tuning_logger.debug(f"{self._active_filters}")
|
||||
|
||||
def _get_query_embedding(self, query: str):
|
||||
"""Get embedding for the query text"""
|
||||
embedding_model = self.model_variables['embedding_model']
|
||||
query_embedding = embedding_model.embed_query(query)
|
||||
return query_embedding
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def tenant_info(self) -> Dict[str, Any]:
|
||||
return self._tenant_info
|
||||
@@ -1,52 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import Field, BaseModel, PrivateAttr
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.interaction import ChatSession, Interaction
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIHistoryRetriever(BaseRetriever, BaseModel):
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_session_id: str = PrivateAttr()
|
||||
|
||||
def __init__(self, model_variables: ModelVariables, session_id: str):
|
||||
super().__init__()
|
||||
self._model_variables = model_variables
|
||||
self._session_id = session_id
|
||||
|
||||
@property
|
||||
def model_variables(self) -> ModelVariables:
|
||||
return self._model_variables
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
def _get_relevant_documents(self, query: str):
|
||||
current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
|
||||
|
||||
try:
|
||||
query_obj = (
|
||||
db.session.query(Interaction)
|
||||
.join(ChatSession, Interaction.chat_session_id == ChatSession.id)
|
||||
.filter(ChatSession.session_id == self.session_id)
|
||||
.order_by(asc(Interaction.id))
|
||||
)
|
||||
|
||||
interactions = query_obj.all()
|
||||
|
||||
result = []
|
||||
for interaction in interactions:
|
||||
result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'Error retrieving history of interactions: {e}')
|
||||
db.session.rollback()
|
||||
return []
|
||||
|
||||
return result
|
||||
@@ -1,40 +0,0 @@
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from typing import Dict, Any
|
||||
|
||||
from common.utils.model_utils import ModelVariables
|
||||
|
||||
|
||||
class EveAIRetriever(BaseModel):
|
||||
_catalog_id: int = PrivateAttr()
|
||||
_user_metadata: Dict[str, Any] = PrivateAttr()
|
||||
_system_metadata: Dict[str, Any] = PrivateAttr()
|
||||
_configuration: Dict[str, Any] = PrivateAttr()
|
||||
_tenant_info: Dict[str, Any] = PrivateAttr()
|
||||
_model_variables: ModelVariables = PrivateAttr()
|
||||
_tuning: bool = PrivateAttr()
|
||||
|
||||
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
|
||||
configuration: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self._catalog_id = catalog_id
|
||||
self._user_metadata = user_metadata
|
||||
self._system_metadata = system_metadata
|
||||
self._configuration = configuration
|
||||
|
||||
@property
|
||||
def catalog_id(self):
|
||||
return self._catalog_id
|
||||
|
||||
@property
|
||||
def user_metadata(self):
|
||||
return self._user_metadata
|
||||
|
||||
@property
|
||||
def system_metadata(self):
|
||||
return self._system_metadata
|
||||
|
||||
@property
|
||||
def configuration(self):
|
||||
return self._configuration
|
||||
|
||||
# Any common methods that should be shared among retrievers can go here.
|
||||
154
common/langchain/templates/template_manager.py
Normal file
154
common/langchain/templates/template_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, Optional, Any
|
||||
from packaging import version
|
||||
from dataclasses import dataclass
|
||||
from flask import current_app, Flask
|
||||
|
||||
from common.utils.os_utils import get_project_root
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Represents a versioned prompt template"""
|
||||
content: str
|
||||
version: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""Manages versioned prompt templates"""
|
||||
|
||||
def __init__(self):
|
||||
self.templates_dir = None
|
||||
self._templates = None
|
||||
self.app = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
# Initialize template manager
|
||||
base_dir = "/app"
|
||||
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
|
||||
app.logger.debug(f'Loading templates from {self.templates_dir}')
|
||||
self.app = app
|
||||
self._templates = self._load_templates()
|
||||
# Log available templates for each supported model
|
||||
for llm in app.config['SUPPORTED_LLMS']:
|
||||
try:
|
||||
available_templates = self.list_templates(llm)
|
||||
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
|
||||
except ValueError:
|
||||
app.logger.warning(f"No templates found for {llm}")
|
||||
|
||||
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
|
||||
"""
|
||||
Load all template versions from the templates directory.
|
||||
Structure: {provider.model -> {template_name -> {version -> template}}}
|
||||
Directory structure:
|
||||
prompts/
|
||||
├── provider/
|
||||
│ └── model/
|
||||
│ └── template_name/
|
||||
│ └── version.yaml
|
||||
"""
|
||||
templates = {}
|
||||
|
||||
# Iterate through providers (anthropic, openai)
|
||||
for provider in os.listdir(self.templates_dir):
|
||||
provider_path = os.path.join(self.templates_dir, provider)
|
||||
if not os.path.isdir(provider_path):
|
||||
continue
|
||||
|
||||
# Iterate through models (claude-3, gpt-4o)
|
||||
for model in os.listdir(provider_path):
|
||||
model_path = os.path.join(provider_path, model)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
provider_model = f"{provider}.{model}"
|
||||
templates[provider_model] = {}
|
||||
|
||||
# Iterate through template types (rag, summary, etc.)
|
||||
for template_name in os.listdir(model_path):
|
||||
template_path = os.path.join(model_path, template_name)
|
||||
if not os.path.isdir(template_path):
|
||||
continue
|
||||
|
||||
template_versions = {}
|
||||
# Load all version files for this template
|
||||
for version_file in os.listdir(template_path):
|
||||
if not version_file.endswith('.yaml'):
|
||||
continue
|
||||
|
||||
version_str = version_file[:-5] # Remove .yaml
|
||||
if not self._is_valid_version(version_str):
|
||||
current_app.logger.warning(
|
||||
f"Invalid version format for {template_name}: {version_str}")
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(os.path.join(template_path, version_file)) as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
# Verify required fields
|
||||
if not template_data.get('content'):
|
||||
raise ValueError("Template content is required")
|
||||
|
||||
template_versions[version_str] = PromptTemplate(
|
||||
content=template_data['content'],
|
||||
version=version_str,
|
||||
metadata=template_data.get('metadata', {})
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error loading template {template_name} version {version_str}: {e}")
|
||||
continue
|
||||
|
||||
if template_versions:
|
||||
templates[provider_model][template_name] = template_versions
|
||||
|
||||
return templates
|
||||
|
||||
def _is_valid_version(self, version_str: str) -> bool:
|
||||
"""Validate semantic versioning string"""
|
||||
try:
|
||||
version.parse(version_str)
|
||||
return True
|
||||
except version.InvalidVersion:
|
||||
return False
|
||||
|
||||
def get_template(self,
|
||||
provider_model: str,
|
||||
template_name: str,
|
||||
template_version: Optional[str] = None) -> PromptTemplate:
|
||||
"""
|
||||
Get a specific template version. If version not specified,
|
||||
returns the latest version.
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
if template_name not in self._templates[provider_model]:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
versions = self._templates[provider_model][template_name]
|
||||
|
||||
if template_version:
|
||||
if template_version not in versions:
|
||||
raise ValueError(f"Template version {template_version} not found")
|
||||
return versions[template_version]
|
||||
|
||||
# Return latest version
|
||||
latest = max(versions.keys(), key=version.parse)
|
||||
return versions[latest]
|
||||
|
||||
def list_templates(self, provider_model: str) -> Dict[str, list]:
|
||||
"""
|
||||
List all available templates and their versions for a provider.model
|
||||
Returns: {template_name: [version1, version2, ...]}
|
||||
"""
|
||||
if provider_model not in self._templates:
|
||||
raise ValueError(f"Unknown provider.model: {provider_model}")
|
||||
|
||||
return {
|
||||
template_name: sorted(versions.keys(), key=version.parse)
|
||||
for template_name, versions in self._templates[provider_model].items()
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
import time
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
def tracked_transcribe(client, *args, **kwargs):
|
||||
start_time = time.time()
|
||||
|
||||
# Extract the file and model from kwargs if present, otherwise use defaults
|
||||
file = kwargs.get('file')
|
||||
model = kwargs.get('model', 'whisper-1')
|
||||
duration = kwargs.pop('duration', 600)
|
||||
|
||||
result = client.audio.transcriptions.create(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Token usage for transcriptions is actually the duration in seconds we pass, as the whisper model is priced per second transcribed
|
||||
|
||||
metrics = {
|
||||
'total_tokens': duration,
|
||||
'prompt_tokens': 0, # For transcriptions, all tokens are considered "completion"
|
||||
'completion_tokens': duration,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'ASR',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
return result
|
||||
77
common/langchain/tracked_transcription.py
Normal file
77
common/langchain/tracked_transcription.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# common/langchain/tracked_transcription.py
|
||||
from typing import Any, Optional, Dict
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedOpenAITranscription:
|
||||
"""Wrapper for OpenAI transcription with metric tracking"""
|
||||
|
||||
def __init__(self, api_key: str, **kwargs: Any):
|
||||
"""Initialize with OpenAI client settings"""
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.model = kwargs.get('model', 'whisper-1')
|
||||
|
||||
def transcribe(self,
|
||||
file: Any,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
duration: Optional[int] = None) -> str:
|
||||
"""
|
||||
Transcribe audio with metrics tracking
|
||||
|
||||
Args:
|
||||
file: Audio file to transcribe
|
||||
model: Model to use (defaults to whisper-1)
|
||||
language: Optional language of the audio
|
||||
prompt: Optional prompt to guide transcription
|
||||
response_format: Response format (json, text, etc)
|
||||
temperature: Sampling temperature
|
||||
duration: Duration of audio in seconds for metrics
|
||||
|
||||
Returns:
|
||||
Transcription text
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Create transcription options
|
||||
options = {
|
||||
"file": file,
|
||||
"model": model or self.model,
|
||||
}
|
||||
if language:
|
||||
options["language"] = language
|
||||
if prompt:
|
||||
options["prompt"] = prompt
|
||||
if response_format:
|
||||
options["response_format"] = response_format
|
||||
if temperature:
|
||||
options["temperature"] = temperature
|
||||
|
||||
response = self.client.audio.transcriptions.create(**options)
|
||||
|
||||
# Calculate metrics
|
||||
end_time = time.time()
|
||||
|
||||
# Token usage for transcriptions is based on audio duration
|
||||
metrics = {
|
||||
'total_tokens': duration or 600, # Default to 10 minutes if duration not provided
|
||||
'prompt_tokens': 0, # For transcriptions, all tokens are completion
|
||||
'completion_tokens': duration or 600,
|
||||
'time_elapsed': end_time - start_time,
|
||||
'interaction_type': 'ASR',
|
||||
}
|
||||
current_event.log_llm_metrics(metrics)
|
||||
|
||||
# Return text from response
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Transcription failed: {str(e)}")
|
||||
Reference in New Issue
Block a user