import json
import re
from datetime import datetime as dt, timezone as tz
from typing import Any, Dict, Optional

import xxhash
from flask import current_app
from flask_security import current_user
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy.inspection import inspect

from common.extensions import db
from common.langchain.persistent_llm_metrics_handler import PersistentLLMMetricsHandler
from common.models.user import TranslationCache
from common.utils.business_event_context import current_event
from common.utils.cache.base import CacheHandler, T
from common.utils.model_utils import get_template


class TranslationCacheHandler(CacheHandler[TranslationCache]):
    """Handles caching of translations with fallback to the database and an external translation service."""

    handler_name = 'translation_cache'

    def __init__(self, region):
        super().__init__(region, 'translation')
        self.configure_keys('hash_key')

    def _to_cache_data(self, instance: TranslationCache) -> Dict[str, Any]:
        """Convert a TranslationCache instance to cache data using SQLAlchemy inspection."""
        if not instance:
            return {}

        mapper = inspect(TranslationCache)
        data = {}
        for column in mapper.columns:
            value = getattr(instance, column.name)
            # Serialize datetimes as ISO strings so the cached payload is JSON-safe
            if isinstance(value, dt):
                data[column.name] = value.isoformat()
            else:
                data[column.name] = value
        return data

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Optional[TranslationCache]:
        """Rebuild a TranslationCache instance from cached data and log the stored LLM metrics."""
        if not data:
            return None

        # Create a new TranslationCache instance and populate it dynamically
        translation = TranslationCache()
        mapper = inspect(TranslationCache)

        for column in mapper.columns:
            if column.name in data:
                value = data[column.name]
                # Deserialize date columns that were stored as ISO strings
                if column.name.endswith('_date') and value:
                    if isinstance(value, str):
                        value = dt.fromisoformat(value).date()
                setattr(translation, column.name, value)

        metrics = {
            'total_tokens': translation.prompt_tokens + translation.completion_tokens,
            'prompt_tokens': translation.prompt_tokens,
            'completion_tokens': translation.completion_tokens,
            'time_elapsed': 0,
            'interaction_type': 'LLM',
        }
        current_event.log_llm_metrics(metrics)

        return translation

    def _should_cache(self, value) -> bool:
        """Validate whether the translation should be cached."""
        if value is None:
            return False

        # Handle both TranslationCache objects and serialized data (dict)
        if isinstance(value, TranslationCache):
            return value.cache_key is not None
        elif isinstance(value, dict):
            return value.get('cache_key') is not None
        return False

    def get_translation(self, text: str, target_lang: str, source_lang: str = None,
                        context: str = None) -> Optional[TranslationCache]:
        """
        Get the translation for a text in a specific language.

        Args:
            text: The text to be translated
            target_lang: The target language for the translation
            source_lang: The source language of the text to be translated
            context: Optional context for the translation

        Returns:
            TranslationCache instance if found, None otherwise
        """
        if not context:
            context = 'No context provided.'
        def creator_func(hash_key: str) -> Optional[TranslationCache]:
            # Check whether the translation already exists in the database
            existing_translation = db.session.query(TranslationCache).filter_by(cache_key=hash_key).first()

            if existing_translation:
                # Update the last-used timestamp and log metrics for the cache hit
                existing_translation.last_used_at = dt.now(tz=tz.utc)
                metrics = {
                    'total_tokens': existing_translation.prompt_tokens + existing_translation.completion_tokens,
                    'prompt_tokens': existing_translation.prompt_tokens,
                    'completion_tokens': existing_translation.completion_tokens,
                    'time_elapsed': 0,
                    'interaction_type': 'TRANSLATION',
                }
                current_event.log_llm_metrics(metrics)
                db.session.commit()
                return existing_translation

            # Translation not found in the DB; call the LLM and capture usage metrics
            translated_text, metrics = self.translate_text(
                text_to_translate=text,
                target_lang=target_lang,
                source_lang=source_lang,
                context=context,
            )

            # Create a new translation cache record
            user_id = getattr(current_user, 'id', None)
            now = dt.now(tz=tz.utc)
            new_translation = TranslationCache(
                cache_key=hash_key,
                source_text=text,
                translated_text=translated_text,
                source_language=source_lang,
                target_language=target_lang,
                context=context,
                prompt_tokens=metrics.get('prompt_tokens', 0),
                completion_tokens=metrics.get('completion_tokens', 0),
                created_at=now,
                created_by=user_id,
                updated_at=now,
                updated_by=user_id,
                last_used_at=now,
            )

            # Save to the database
            db.session.add(new_translation)
            db.session.commit()

            return new_translation

        # Generate the hash key and delegate to the generic cache lookup
        hash_key = self._generate_cache_key(text, target_lang, source_lang, context)
        return self.get(creator_func, hash_key=hash_key)

    def invalidate_tenant_translations(self, tenant_id: int):
        """Invalidate cached translations for a specific tenant."""
        self.invalidate(tenant_id=tenant_id)

    def _generate_cache_key(self, text: str, target_lang: str, source_lang: str = None, context: str = None) -> str:
        """Generate a deterministic cache key for a translation."""
        cache_data = {
            "text": text.strip(),
            "target_lang": target_lang.lower(),
            "source_lang": source_lang.lower() if source_lang else None,
            "context": context.strip() if context else None,
        }
        cache_string = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
        return xxhash.xxh64(cache_string.encode('utf-8')).hexdigest()

    def translate_text(self, text_to_translate: str, target_lang: str, source_lang: str = None,
                       context: str = None) -> tuple[str, dict[str, int | float]]:
        """Translate text via the configured LLM template and return the translation with usage metrics."""
        target_language = current_app.config['SUPPORTED_LANGUAGE_ISO639_1_LOOKUP'][target_lang]

        prompt_params = {
            "text_to_translate": text_to_translate,
            "target_language": target_language,
        }

        if context:
            template, llm = get_template("translation_with_context")
            prompt_params["context"] = context
        else:
            template, llm = get_template("translation_without_context")

        # Add a metrics handler to capture token usage
        metrics_handler = PersistentLLMMetricsHandler()
        existing_callbacks = llm.callbacks or []
        llm.callbacks = existing_callbacks + [metrics_handler]

        translation_prompt = ChatPromptTemplate.from_template(template)
        setup = RunnablePassthrough()
        chain = setup | translation_prompt | llm | StrOutputParser()

        translation = chain.invoke(prompt_params)

        # Remove double square brackets from the translation
        translation = re.sub(r'\[\[(.*?)\]\]', r'\1', translation)

        metrics = metrics_handler.get_metrics()
        return translation, metrics


def register_translation_cache_handlers(cache_manager) -> None:
"""Register translation cache handlers with cache manager""" cache_manager.register_handler( TranslationCacheHandler, 'eveai_model' # Use existing eveai_model region )