- Introduction of dynamic Retrievers & Specialists

- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
Josako
2024-11-15 10:00:53 +01:00
parent 55a8a95f79
commit 1807435339
101 changed files with 4181 additions and 1764 deletions

View File

@@ -0,0 +1,23 @@
# Output Schema Management - common/langchain/outputs/base.py
from typing import Dict, Type, Any
from pydantic import BaseModel
class BaseSpecialistOutput(BaseModel):
"""Base class for all specialist outputs"""
pass
class OutputRegistry:
"""Registry for specialist output schemas"""
_schemas: Dict[str, Type[BaseSpecialistOutput]] = {}
@classmethod
def register(cls, specialist_type: str, schema_class: Type[BaseSpecialistOutput]):
cls._schemas[specialist_type] = schema_class
@classmethod
def get_schema(cls, specialist_type: str) -> Type[BaseSpecialistOutput]:
if specialist_type not in cls._schemas:
raise ValueError(f"No output schema registered for {specialist_type}")
return cls._schemas[specialist_type]
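
A minimal usage sketch for the registry above; SummaryOutput is a hypothetical schema used only for illustration and is not part of this commit:

# Hypothetical specialist output, for illustration only
class SummaryOutput(BaseSpecialistOutput):
    summary: str

OutputRegistry.register("summary", SummaryOutput)
schema = OutputRegistry.get_schema("summary")    # -> SummaryOutput
# OutputRegistry.get_schema("unknown")           # -> raises ValueError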

View File

@@ -0,0 +1,22 @@
# RAG Specialist Output - common/langchain/outputs/rag.py
from typing import List
from pydantic import Field
from .base import BaseSpecialistOutput
class RAGOutput(BaseSpecialistOutput):
"""Output schema for RAG specialist"""
"""Default docstring - to be replaced with actual prompt"""
answer: str = Field(
...,
description="The answer to the user question, based on the given sources",
)
citations: List[int] = Field(
...,
description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
)
insufficient_info: bool = Field(
False, # Default value is set to False
description="A boolean indicating whether given sources were sufficient or not to generate the answer"
)
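
A validation sketch, assuming the schema is populated from a structured LLM response; the values are illustrative:

out = RAGOutput(
    answer="The warranty period is two years.",
    citations=[12, 57],   # IDs of the sources actually used
)
assert out.insufficient_info is False   # defaults to False when omitted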

View File

@@ -1,145 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.document import Document, DocumentVersion
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDefaultRagRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
current_app.logger.debug(f'Model variables type: {type(model_variables)}')
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
@property
def catalog_id(self) -> int:
return self._catalog_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
current_app.logger.debug(f'Model Variables Private: {type(self._model_variables)}')
current_app.logger.debug(f'Model Variables Property: {type(self.model_variables)}')
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.model_variables['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
.order_by(desc('similarity'))
.limit(k)
)
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding
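
A construction sketch for the retriever removed above. Similarity is computed as 1 - cosine distance, so the threshold keeps only chunks sufficiently close to the query embedding; the catalog ID, timezone and variable names are illustrative, and model_variables is assumed to be in scope:

retriever = EveAIDefaultRagRetriever(
    catalog_id=42,
    model_variables=model_variables,   # must supply 'embedding_model', 'embedding_db_model',
                                       # 'similarity_threshold', 'k' and 'rag_tuning'
    tenant_info={"timezone": "Europe/Brussels"},
)
chunks = retriever.get_relevant_documents("What is the refund policy?")
# -> ["SOURCE: <embedding id>\n\n<chunk text>\n\n", ...]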

View File

@@ -1,154 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_, desc, cast, JSON
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field, PrivateAttr
from typing import Any, Dict, List, Optional
from flask import current_app
from contextlib import contextmanager
from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import ModelVariables
class EveAIDossierRetriever(BaseRetriever, BaseModel):
_catalog_id: int = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
_active_filters: Optional[Dict[str, Any]] = PrivateAttr()
def __init__(self, catalog_id: int, model_variables: ModelVariables, tenant_info: Dict[str, Any]):
super().__init__()
self._catalog_id = catalog_id
self._model_variables = model_variables
self._tenant_info = tenant_info
self._active_filters = None
@contextmanager
def filtering(self, metadata_filters: Dict[str, Any]):
"""Context manager for temporarily setting metadata filters"""
previous_filters = self._active_filters
self._active_filters = metadata_filters
try:
yield self
finally:
self._active_filters = previous_filters
def _build_metadata_filter_conditions(self, query):
"""Build SQL conditions for metadata filtering"""
if not self._active_filters:
return query
conditions = []
for field, value in self._active_filters.items():
if value is None:
continue
# Handle both single values and lists of values
if isinstance(value, (list, tuple)):
# Multiple values - create OR condition
or_conditions = []
for val in value:
or_conditions.append(
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(val)
)
if or_conditions:
conditions.append(or_(*or_conditions))
else:
# Single value - direct comparison
conditions.append(
cast(DocumentVersion.user_metadata[field].astext, JSON) == str(value)
)
if conditions:
query = query.filter(and_(*conditions))
return query
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for dossier query: {query}')
if self._active_filters:
current_app.logger.debug(f'Using metadata filters: {self._active_filters}')
query_embedding = self._get_query_embedding(query)
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Build base query
query_obj = (
db.session.query(db_class,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold,
Document.catalog_id == self._catalog_id
)
)
# Apply metadata filters
query_obj = self._build_metadata_filter_conditions(query_obj)
# Order and limit results
query_obj = query_obj.order_by(desc('similarity')).limit(k)
# Debug logging for RAG tuning if enabled
if self.model_variables['rag_tuning']:
self._log_rag_tuning(query_obj, query_embedding)
res = query_obj.all()
result = []
for doc in res:
if self.model_variables['rag_tuning']:
current_app.logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return []
return result
def _log_rag_tuning(self, query_obj, query_embedding):
"""Log debug information for RAG tuning"""
current_app.rag_tuning_logger.debug("Debug: Query execution plan:")
current_app.rag_tuning_logger.debug(f"{query_obj.statement}")
if self._active_filters:
current_app.rag_tuning_logger.debug("Debug: Active metadata filters:")
current_app.rag_tuning_logger.debug(f"{self._active_filters}")
def _get_query_embedding(self, query: str):
"""Get embedding for the query text"""
embedding_model = self.model_variables['embedding_model']
query_embedding = embedding_model.embed_query(query)
return query_embedding
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def tenant_info(self) -> Dict[str, Any]:
return self._tenant_info
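
A sketch of the filtering context manager in use; the filter keys must match fields in DocumentVersion.user_metadata, and the values here are illustrative:

dossier_retriever = EveAIDossierRetriever(
    catalog_id=42,
    model_variables=model_variables,
    tenant_info={"timezone": "Europe/Brussels"},
)
with dossier_retriever.filtering({"dossier_id": 17, "doc_type": ["invoice", "contract"]}):
    chunks = dossier_retriever.get_relevant_documents("outstanding amount")
# on exit the previous filters are restored (here: None), even if retrieval raised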

View File

@@ -1,52 +0,0 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import asc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import Field, BaseModel, PrivateAttr
from typing import Any, Dict
from flask import current_app
from common.extensions import db
from common.models.interaction import ChatSession, Interaction
from common.utils.model_utils import ModelVariables
class EveAIHistoryRetriever(BaseRetriever, BaseModel):
_model_variables: ModelVariables = PrivateAttr()
_session_id: str = PrivateAttr()
def __init__(self, model_variables: ModelVariables, session_id: str):
super().__init__()
self._model_variables = model_variables
self._session_id = session_id
@property
def model_variables(self) -> ModelVariables:
return self._model_variables
@property
def session_id(self) -> str:
return self._session_id
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving history of interactions for query: {query}')
try:
query_obj = (
db.session.query(Interaction)
.join(ChatSession, Interaction.chat_session_id == ChatSession.id)
.filter(ChatSession.session_id == self.session_id)
.order_by(asc(Interaction.id))
)
interactions = query_obj.all()
result = []
for interaction in interactions:
result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving history of interactions: {e}')
db.session.rollback()
return []
return result
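
A usage sketch; note that this retriever replays the whole session in interaction order and never uses the query text itself:

history = EveAIHistoryRetriever(model_variables, session_id="session-123")  # illustrative ID
turns = history.get_relevant_documents("")   # query is logged but not used for filtering
# -> ["HUMAN:\n<question>\n\nAI: \n<answer>\n\n", ...]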

View File

@@ -1,40 +0,0 @@
from pydantic import BaseModel, PrivateAttr
from typing import Dict, Any
from common.utils.model_utils import ModelVariables
class EveAIRetriever(BaseModel):
_catalog_id: int = PrivateAttr()
_user_metadata: Dict[str, Any] = PrivateAttr()
_system_metadata: Dict[str, Any] = PrivateAttr()
_configuration: Dict[str, Any] = PrivateAttr()
_tenant_info: Dict[str, Any] = PrivateAttr()
_model_variables: ModelVariables = PrivateAttr()
_tuning: bool = PrivateAttr()
def __init__(self, catalog_id: int, user_metadata: Dict[str, Any], system_metadata: Dict[str, Any],
configuration: Dict[str, Any]):
super().__init__()
self._catalog_id = catalog_id
self._user_metadata = user_metadata
self._system_metadata = system_metadata
self._configuration = configuration
@property
def catalog_id(self):
return self._catalog_id
@property
def user_metadata(self):
return self._user_metadata
@property
def system_metadata(self):
return self._system_metadata
@property
def configuration(self):
return self._configuration
# Any common methods that should be shared among retrievers can go here.

View File

@@ -0,0 +1,154 @@
import os
import yaml
from typing import Dict, Optional, Any
from packaging import version
from dataclasses import dataclass
from flask import current_app, Flask
from common.utils.os_utils import get_project_root
@dataclass
class PromptTemplate:
"""Represents a versioned prompt template"""
content: str
version: str
metadata: Dict[str, Any]
class TemplateManager:
"""Manages versioned prompt templates"""
def __init__(self):
self.templates_dir = None
self._templates = None
self.app = None
def init_app(self, app: Flask) -> None:
# Initialize template manager
base_dir = "/app"
self.templates_dir = os.path.join(base_dir, 'config', 'prompts')
app.logger.debug(f'Loading templates from {self.templates_dir}')
self.app = app
self._templates = self._load_templates()
# Log available templates for each supported model
for llm in app.config['SUPPORTED_LLMS']:
try:
available_templates = self.list_templates(llm)
app.logger.info(f"Loaded templates for {llm}: {available_templates}")
except ValueError:
app.logger.warning(f"No templates found for {llm}")
def _load_templates(self) -> Dict[str, Dict[str, Dict[str, PromptTemplate]]]:
"""
Load all template versions from the templates directory.
Structure: {provider.model -> {template_name -> {version -> template}}}
Directory structure:
prompts/
├── provider/
│ └── model/
│ └── template_name/
│ └── version.yaml
"""
templates = {}
# Iterate through providers (anthropic, openai)
for provider in os.listdir(self.templates_dir):
provider_path = os.path.join(self.templates_dir, provider)
if not os.path.isdir(provider_path):
continue
# Iterate through models (claude-3, gpt-4o)
for model in os.listdir(provider_path):
model_path = os.path.join(provider_path, model)
if not os.path.isdir(model_path):
continue
provider_model = f"{provider}.{model}"
templates[provider_model] = {}
# Iterate through template types (rag, summary, etc.)
for template_name in os.listdir(model_path):
template_path = os.path.join(model_path, template_name)
if not os.path.isdir(template_path):
continue
template_versions = {}
# Load all version files for this template
for version_file in os.listdir(template_path):
if not version_file.endswith('.yaml'):
continue
version_str = version_file[:-5] # Remove .yaml
if not self._is_valid_version(version_str):
self.app.logger.warning(  # use self.app: no app context is guaranteed during init_app
f"Invalid version format for {template_name}: {version_str}")
continue
try:
with open(os.path.join(template_path, version_file)) as f:
template_data = yaml.safe_load(f)
# Verify required fields
if not template_data.get('content'):
raise ValueError("Template content is required")
template_versions[version_str] = PromptTemplate(
content=template_data['content'],
version=version_str,
metadata=template_data.get('metadata', {})
)
except Exception as e:
self.app.logger.error(  # use self.app: no app context is guaranteed during init_app
f"Error loading template {template_name} version {version_str}: {e}")
continue
if template_versions:
templates[provider_model][template_name] = template_versions
return templates
def _is_valid_version(self, version_str: str) -> bool:
"""Validate semantic versioning string"""
try:
version.parse(version_str)
return True
except version.InvalidVersion:
return False
def get_template(self,
provider_model: str,
template_name: str,
template_version: Optional[str] = None) -> PromptTemplate:
"""
Get a specific template version. If version not specified,
returns the latest version.
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
if template_name not in self._templates[provider_model]:
raise ValueError(f"Unknown template: {template_name}")
versions = self._templates[provider_model][template_name]
if template_version:
if template_version not in versions:
raise ValueError(f"Template version {template_version} not found")
return versions[template_version]
# Return latest version
latest = max(versions.keys(), key=version.parse)
return versions[latest]
def list_templates(self, provider_model: str) -> Dict[str, list]:
"""
List all available templates and their versions for a provider.model
Returns: {template_name: [version1, version2, ...]}
"""
if provider_model not in self._templates:
raise ValueError(f"Unknown provider.model: {provider_model}")
return {
template_name: sorted(versions.keys(), key=version.parse)
for template_name, versions in self._templates[provider_model].items()
}
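
A sketch of the expected on-disk layout and lookups, inferred from _load_templates; the provider, model, template name and metadata keys are illustrative:

# config/prompts/openai/gpt-4o/rag/1.0.0.yaml  (illustrative file)
#   content: |
#     Answer the question using only the numbered sources below...
#   metadata:
#     author: josako
manager = TemplateManager()
manager.init_app(app)   # a configured Flask app with SUPPORTED_LLMS set
latest = manager.get_template("openai.gpt-4o", "rag")            # highest semver wins
pinned = manager.get_template("openai.gpt-4o", "rag", "1.0.0")   # explicit version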

View File

@@ -1,27 +0,0 @@
import time
from common.utils.business_event_context import current_event
def tracked_transcribe(client, *args, **kwargs):
start_time = time.time()
# Extract the file and model from kwargs if present, otherwise use defaults
file = kwargs.get('file')
model = kwargs.get('model', 'whisper-1')
duration = kwargs.pop('duration', 600)
result = client.audio.transcriptions.create(*args, **kwargs)
end_time = time.time()
# Token usage for transcriptions is actually the duration in seconds we pass, as the whisper model is priced per second transcribed
metrics = {
'total_tokens': duration,
'prompt_tokens': 0, # For transcriptions, all tokens are considered "completion"
'completion_tokens': duration,
'time_elapsed': end_time - start_time,
'interaction_type': 'ASR',
}
current_event.log_llm_metrics(metrics)
return result

View File

@@ -0,0 +1,77 @@
# common/langchain/tracked_transcription.py
from typing import Any, Optional, Dict
import time
from openai import OpenAI
from common.utils.business_event_context import current_event
class TrackedOpenAITranscription:
"""Wrapper for OpenAI transcription with metric tracking"""
def __init__(self, api_key: str, **kwargs: Any):
"""Initialize with OpenAI client settings"""
self.client = OpenAI(api_key=api_key)
self.model = kwargs.get('model', 'whisper-1')
def transcribe(self,
file: Any,
model: Optional[str] = None,
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Optional[str] = None,
temperature: Optional[float] = None,
duration: Optional[int] = None) -> str:
"""
Transcribe audio with metrics tracking
Args:
file: Audio file to transcribe
model: Model to use (defaults to whisper-1)
language: Optional language of the audio
prompt: Optional prompt to guide transcription
response_format: Response format (json, text, etc)
temperature: Sampling temperature
duration: Duration of audio in seconds for metrics
Returns:
Transcription text
"""
start_time = time.time()
try:
# Create transcription options
options = {
"file": file,
"model": model or self.model,
}
if language:
options["language"] = language
if prompt:
options["prompt"] = prompt
if response_format:
options["response_format"] = response_format
if temperature is not None:  # 0.0 is a valid temperature and must not be dropped
options["temperature"] = temperature
response = self.client.audio.transcriptions.create(**options)
# Calculate metrics
end_time = time.time()
# Token usage for transcriptions is based on audio duration
metrics = {
'total_tokens': duration or 600, # Default to 10 minutes if duration not provided
'prompt_tokens': 0, # For transcriptions, all tokens are completion
'completion_tokens': duration or 600,
'time_elapsed': end_time - start_time,
'interaction_type': 'ASR',
}
current_event.log_llm_metrics(metrics)
# Return text from response
if isinstance(response, str):
return response
return response.text
except Exception as e:
raise RuntimeError(f"Transcription failed: {e}") from e  # preserve the original cause
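
A usage sketch for the new wrapper; the file name and duration are illustrative. Passing the real audio duration matters, since the logged metrics bill per second transcribed and otherwise fall back to 600 seconds:

import os

transcriber = TrackedOpenAITranscription(api_key=os.environ["OPENAI_API_KEY"])
with open("meeting.mp3", "rb") as audio:
    text = transcriber.transcribe(audio, language="en", duration=432)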