import json
import re
from datetime import datetime as dt, timezone as tz
from typing import Any, Dict, Optional

import xxhash
from flask import current_app
from flask_security import current_user
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy.inspection import inspect

from common.extensions import db
from common.langchain.persistent_llm_metrics_handler import PersistentLLMMetricsHandler
from common.models.user import TranslationCache
from common.utils.business_event_context import current_event
from common.utils.cache.base import CacheHandler
from common.utils.model_utils import get_template


class TranslationCacheHandler(CacheHandler[TranslationCache]):
    """Handles caching of translations with fallback to database and external translation service"""

    handler_name = 'translation_cache'

    def __init__(self, region):
        super().__init__(region, 'translation')
        self.configure_keys('hash_key')

    def _to_cache_data(self, instance: TranslationCache) -> Dict[str, Any]:
        """Convert TranslationCache instance to cache data using SQLAlchemy inspection"""
        if not instance:
            return {}

        mapper = inspect(TranslationCache)
        data = {}

        for column in mapper.columns:
            value = getattr(instance, column.name)

            # Serialize datetime values to ISO 8601 strings
            if isinstance(value, dt):
                data[column.name] = value.isoformat()
            else:
                data[column.name] = value

        return data
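
    # Illustrative sketch of the payload produced by _to_cache_data; the field names
    # below are the columns this module references elsewhere, while the exact set and
    # the values shown are assumptions that depend on the TranslationCache model:
    #
    #     {
    #         "cache_key": "a1b2c3d4e5f60718",
    #         "source_text": "Hello",
    #         "translated_text": "Hallo",
    #         "source_language": "en",
    #         "target_language": "de",
    #         "context": "No context provided.",
    #         "prompt_tokens": 42,
    #         "completion_tokens": 7,
    #         "created_at": "2024-01-01T12:00:00+00:00",
    #         "last_used_at": "2024-01-01T12:00:00+00:00",
    #         ...
    #     }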

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Optional[TranslationCache]:
        """Rebuild a TranslationCache instance from cached data and log cache-hit metrics."""
        if not data:
            return None

        # Create a new TranslationCache instance
        translation = TranslationCache()
        mapper = inspect(TranslationCache)

        # Set all attributes dynamically
        for column in mapper.columns:
            if column.name in data:
                value = data[column.name]

                # Handle date deserialization
                if column.name.endswith('_date') and value:
                    if isinstance(value, str):
                        value = dt.fromisoformat(value).date()

                setattr(translation, column.name, value)

        metrics = {
            'total_tokens': translation.prompt_tokens + translation.completion_tokens,
            'prompt_tokens': translation.prompt_tokens,
            'completion_tokens': translation.completion_tokens,
            'time_elapsed': 0,
            'interaction_type': 'TRANSLATION-CACHE'
        }
        current_event.log_llm_metrics(metrics)

        return translation

    def _should_cache(self, value) -> bool:
        """Validate if the translation should be cached"""
        if value is None:
            return False

        # Handle both TranslationCache objects and serialized data (dict)
        if isinstance(value, TranslationCache):
            return value.cache_key is not None
        elif isinstance(value, dict):
            return value.get('cache_key') is not None

        return False

    def get_translation(self, text: str, target_lang: str, source_lang: Optional[str] = None,
                        context: Optional[str] = None) -> Optional[TranslationCache]:
        """
        Get the translation for a text in a specific language

        Args:
            text: The text to be translated
            target_lang: The target language for the translation
            source_lang: The source language of the text to be translated
            context: Optional context for the translation

        Returns:
            TranslationCache instance from the cache, the database, or a newly generated translation
        """
        if not context:
            context = 'No context provided.'

        def creator_func(hash_key: str) -> Optional[TranslationCache]:
            # Check if the translation already exists in the database
            existing_translation = db.session.query(TranslationCache).filter_by(cache_key=hash_key).first()

            if existing_translation:
                # Update the last-used timestamp and log the saved token usage
                existing_translation.last_used_at = dt.now(tz=tz.utc)
                metrics = {
                    'total_tokens': existing_translation.prompt_tokens + existing_translation.completion_tokens,
                    'prompt_tokens': existing_translation.prompt_tokens,
                    'completion_tokens': existing_translation.completion_tokens,
                    'time_elapsed': 0,
                    'interaction_type': 'TRANSLATION-DB'
                }
                current_event.log_llm_metrics(metrics)
                db.session.commit()
                return existing_translation

            # Not in the database: generate a new translation and capture its metrics
            translated_text, metrics = self.translate_text(
                text_to_translate=text,
                target_lang=target_lang,
                source_lang=source_lang,
                context=context
            )

            # Create a new translation cache record; fall back to None when no
            # authenticated user id is available
            user_id = getattr(current_user, 'id', None)
            new_translation = TranslationCache(
                cache_key=hash_key,
                source_text=text,
                translated_text=translated_text,
                source_language=source_lang,
                target_language=target_lang,
                context=context,
                prompt_tokens=metrics.get('prompt_tokens', 0),
                completion_tokens=metrics.get('completion_tokens', 0),
                created_at=dt.now(tz=tz.utc),
                created_by=user_id,
                updated_at=dt.now(tz=tz.utc),
                updated_by=user_id,
                last_used_at=dt.now(tz=tz.utc)
            )

            # Save to database
            db.session.add(new_translation)
            db.session.commit()

            return new_translation

        # Generate the hash key and resolve the translation through the cache layer
        hash_key = self._generate_cache_key(text, target_lang, source_lang, context)
        return self.get(creator_func, hash_key=hash_key)

    def invalidate_tenant_translations(self, tenant_id: int):
        """Invalidate cached translations for a specific tenant"""
        self.invalidate(tenant_id=tenant_id)

    def _generate_cache_key(self, text: str, target_lang: str, source_lang: Optional[str] = None,
                            context: Optional[str] = None) -> str:
        """Generate a deterministic cache key for a translation request"""
        cache_data = {
            "text": text.strip(),
            "target_lang": target_lang.lower(),
            "source_lang": source_lang.lower() if source_lang else None,
            "context": context.strip() if context else None,
        }

        cache_string = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
        return xxhash.xxh64(cache_string.encode('utf-8')).hexdigest()
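
    # Illustrative only: the key is the xxhash64 hex digest of the normalized JSON
    # payload built above, so equivalent requests collapse to the same key regardless
    # of surrounding whitespace in the text, dict key order, or language-code casing:
    #
    #     self._generate_cache_key("Hello", "DE", "en", None)
    #     self._generate_cache_key("  Hello  ", "de", "EN", None)
    #     # -> both hash json.dumps({"context": None, "source_lang": "en",
    #     #                          "target_lang": "de", "text": "Hello"},
    #     #                         sort_keys=True, ensure_ascii=False)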

    def translate_text(self, text_to_translate: str, target_lang: str, source_lang: Optional[str] = None,
                       context: Optional[str] = None) -> tuple[str, dict[str, int | float]]:
        """Translate text through the configured LLM template and return the translation plus usage metrics."""
        target_language = current_app.config['SUPPORTED_LANGUAGE_ISO639_1_LOOKUP'][target_lang]
        prompt_params = {
            "text_to_translate": text_to_translate,
            "target_language": target_language,
        }
        if context:
            template, llm = get_template("translation_with_context")
            prompt_params["context"] = context
        else:
            template, llm = get_template("translation_without_context")

        # Add a metrics handler to capture token usage; preserve any callbacks already
        # configured on the LLM (fall back to an empty list if there are none)
        metrics_handler = PersistentLLMMetricsHandler()
        existing_callbacks = llm.callbacks or []
        llm.callbacks = existing_callbacks + [metrics_handler]

        translation_prompt = ChatPromptTemplate.from_template(template)

        setup = RunnablePassthrough()
        chain = (setup | translation_prompt | llm | StrOutputParser())

        translation = chain.invoke(prompt_params)

        # Remove double square brackets from the translation
        translation = re.sub(r'\[\[(.*?)\]\]', r'\1', translation)

        metrics = metrics_handler.get_metrics()

        return translation, metrics
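
    # Illustrative direct call (requires an application context; the exact set of keys
    # in 'metrics' depends on PersistentLLMMetricsHandler, and this module only relies
    # on 'prompt_tokens' and 'completion_tokens' being present, see get_translation):
    #
    #     translated, metrics = handler.translate_text(
    #         text_to_translate="Hello",
    #         target_lang="de",
    #         context="Label on the landing page",
    #     )
    #     # translated -> e.g. "Hallo"
    #     # metrics    -> e.g. {'prompt_tokens': 57, 'completion_tokens': 4, ...}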


def register_translation_cache_handlers(cache_manager) -> None:
    """Register translation cache handlers with cache manager"""
    cache_manager.register_handler(
        TranslationCacheHandler,
        'eveai_model'  # Use existing eveai_model region
    )
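
# Example wiring (illustrative sketch; 'cache_manager', its 'get_handler' accessor and
# the surrounding application setup live outside this module and are assumptions here):
#
#     register_translation_cache_handlers(cache_manager)
#     handler = cache_manager.get_handler(TranslationCacheHandler.handler_name)
#     cached = handler.get_translation(
#         text="Welcome back!",
#         target_lang="fr",
#         source_lang="en",
#         context="Greeting shown after login",
#     )
#     print(cached.translated_text)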