- Introduction of dynamic Retrievers & Specialists
- Introduction of dynamic Processors
- Introduction of caching system
- Introduction of a better template manager
- Adaptation of ModelVariables to support dynamic Processors / Retrievers / Specialists
- Start adaptation of chat client
eveai_chat_workers/__init__.py

@@ -3,11 +3,14 @@ import logging.config
from flask import Flask
import os

from common.langchain.templates.template_manager import TemplateManager
from common.utils.celery_utils import make_celery, init_celery
from common.extensions import db
from common.extensions import db, template_manager, cache_manager
from config.logging_config import LOGGING
from config.config import get_config

from . import specialists, retrievers


def create_app(config_file=None):
    app = Flask(__name__)
@@ -24,14 +27,12 @@ def create_app(config_file=None):

    logging.config.dictConfig(LOGGING)

    app.logger.debug('Starting up eveai_chat_workers...')
    app.logger.info('Starting up eveai_chat_workers...')
    register_extensions(app)

    celery = make_celery(app.name, app.config)
    init_celery(celery, app)

    app.rag_tuning_logger = logging.getLogger('rag_tuning')

    from eveai_chat_workers import tasks
    print(tasks.tasks_ping())

@@ -40,6 +41,9 @@ def create_app(config_file=None):

def register_extensions(app):
    db.init_app(app)
    cache_manager.init_app(app)
    template_manager.init_app(app)


app, celery = create_app()
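The db, template_manager, and cache_manager objects imported above are shared extension singletons. common/extensions.py is not part of this diff, so the following is only a plausible sketch of its shape, with assumed constructor names and module paths:

    # common/extensions.py -- hypothetical sketch, not part of this commit
    from flask_sqlalchemy import SQLAlchemy
    from common.langchain.templates.template_manager import TemplateManager
    from common.utils.cache.manager import CacheManager  # assumed module path

    # Module-level singletons bound to the app later via init_app(),
    # the standard Flask extension pattern used by register_extensions().
    db = SQLAlchemy()
    template_manager = TemplateManager()
    cache_manager = CacheManager()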
eveai_chat_workers/chat_session_cache.py (new file, +193)

@@ -0,0 +1,193 @@
# common/utils/cache/chat_session_handler.py
from typing import Dict, List, Any, Optional
from datetime import datetime as dt, timezone as tz
from dataclasses import dataclass

from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload

from common.extensions import db, cache_manager
from common.models.interaction import ChatSession, Interaction
from common.utils.cache.base import CacheHandler


@dataclass
class CachedInteraction:
    """Lightweight representation of an interaction for history purposes"""
    specialist_arguments: Dict[str, Any]  # Contains the original question and other arguments
    specialist_results: Dict[str, Any]  # Contains detailed question, answer and other results


@dataclass
class CachedSession:
    """Cached representation of a chat session with its interactions"""
    id: int
    session_id: str
    interactions: List[CachedInteraction]
    timezone: str


class ChatSessionCacheHandler(CacheHandler[CachedSession]):
    """Handles caching of chat sessions focused on interaction history"""
    handler_name = 'chat_session_cache'

    def __init__(self, region):
        super().__init__(region, 'chat_session')
        self.configure_keys('session_id')

    def get_cached_session(self, session_id: str, *, create_params: Optional[Dict[str, Any]] = None) -> CachedSession:
        """
        Get or create a cached session with its interaction history.
        If not in cache, loads from database and caches it.

        Args:
            session_id: The session identifier
            create_params: Optional parameters for session creation if it doesn't exist.
                Must include 'timezone' if provided.

        Returns:
            CachedSession with interaction history
        """

        def creator_func(session_id: str) -> CachedSession:
            # Load session and interactions from database
            session = (
                ChatSession.query
                .options(joinedload(ChatSession.interactions))
                .filter_by(session_id=session_id)
                .first()
            )

            if not session:
                if not create_params:
                    raise ValueError(f"Chat session {session_id} not found and no creation parameters provided")

                if 'timezone' not in create_params:
                    raise ValueError("timezone is required in create_params for new session creation")

                # Create new session
                session = ChatSession(
                    session_id=session_id,
                    session_start=dt.now(tz.utc),
                    timezone=create_params['timezone']
                )
                try:
                    db.session.add(session)
                    db.session.commit()
                except SQLAlchemyError as e:
                    db.session.rollback()
                    raise ValueError(f"Failed to create new session: {str(e)}")

            # Convert to cached format
            cached_interactions = [
                CachedInteraction(
                    specialist_arguments=interaction.specialist_arguments,
                    specialist_results=interaction.specialist_results
                )
                for interaction in session.interactions
                if interaction.specialist_results is not None  # Only include completed interactions
            ]

            return CachedSession(
                id=session.id,
                session_id=session_id,
                interactions=cached_interactions,
                timezone=session.timezone
            )

        return self.get(creator_func, session_id=session_id)

    def add_completed_interaction(self, session_id: str, interaction: Interaction) -> None:
        """
        Add a completed interaction to the cached session history.
        Should only be called once the interaction has an answer.

        Args:
            session_id: The session identifier
            interaction: The completed interaction to add

        Note:
            Only adds the interaction if it has an answer
        """
        if not interaction.specialist_results:
            return  # Skip incomplete interactions

        try:
            cached_session = self.get_cached_session(session_id)

            # Add new interaction to cache
            cached_session.interactions.append(
                CachedInteraction(
                    specialist_arguments=interaction.specialist_arguments,
                    specialist_results=interaction.specialist_results,
                )
            )

            # Force cache update
            self.invalidate(session_id=session_id)

        except ValueError:
            # If session not in cache yet, load it fresh from DB
            self.get_cached_session(session_id)

    def to_cache_data(self, instance: CachedSession) -> Dict[str, Any]:
        """Convert CachedSession to cache data"""
        return {
            'id': instance.id,
            'session_id': instance.session_id,
            'timezone': instance.timezone,
            'interactions': [
                {
                    'specialist_arguments': interaction.specialist_arguments,
                    'specialist_results': interaction.specialist_results,
                }
                for interaction in instance.interactions
            ],
            'last_updated': dt.now(tz=tz.utc).isoformat()
        }

    def from_cache_data(self, data: Dict[str, Any], session_id: str, **kwargs) -> CachedSession:
        """Create CachedSession from cache data"""
        interactions = [
            CachedInteraction(
                specialist_arguments=int_data['specialist_arguments'],
                specialist_results=int_data['specialist_results']
            )
            for int_data in data['interactions']
        ]

        return CachedSession(
            id=data['id'],
            session_id=data['session_id'],
            interactions=interactions,
            timezone=data['timezone']
        )

    def should_cache(self, value: Dict[str, Any]) -> bool:
        """Validate cache data"""
        required_fields = {'id', 'session_id', 'timezone', 'interactions'}
        return all(field in value for field in required_fields)


# Register the handler with the cache manager
cache_manager.register_handler(ChatSessionCacheHandler, 'eveai_chat_workers')


# Helper function similar to get_model_variables
def get_chat_history(session_id: str) -> CachedSession:
    """
    Get cached chat history for a session, loading from database if needed

    Args:
        session_id: Session ID to look up

    Returns:
        CachedSession with interaction history

    Raises:
        ValueError: If session doesn't exist
    """
    return cache_manager.chat_session_cache.get_cached_session(session_id)
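A minimal usage sketch of the helper above (assumes an app context and the tenant schema already selected; the session ID is hypothetical):

    from eveai_chat_workers.chat_session_cache import get_chat_history

    # First call loads the session from the database and caches it;
    # subsequent calls are served from the cache region.
    history = get_chat_history('session-1234')
    for interaction in history.interactions:
        question = interaction.specialist_arguments.get('query')
        answer = interaction.specialist_results.get('answer')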
eveai_chat_workers/retrievers/__init__.py (new file, +5)

@@ -0,0 +1,5 @@
# Import all retriever implementations here to ensure registration
from . import standard_rag

# List of all available retriever implementations
__all__ = ['standard_rag']
eveai_chat_workers/retrievers/base.py (new file, +57)

@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List

from flask import current_app
from eveai_chat_workers.retrievers.retriever_typing import RetrieverResult, RetrieverArguments
from config.logging_config import TuningLogger


class BaseRetriever(ABC):
    """Base class for all retrievers"""

    def __init__(self, tenant_id: int, retriever_id: int):
        self.tenant_id = tenant_id
        self.retriever_id = retriever_id
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type of the retriever"""
        pass

    def _setup_tuning_logger(self):
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                retriever_id=self.retriever_id,
            )
            # Verify logger is working with a test message
            if self.tuning:
                self.tuning_logger.log_tuning('retriever', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def _log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
        if self.tuning and self.tuning_logger:
            try:
                self.tuning_logger.log_tuning('retriever', message, data)
            except Exception as e:
                current_app.logger.error(f"Retriever: Error in tuning logging: {e}")

    @abstractmethod
    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve relevant documents based on provided arguments

        Args:
            arguments: RetrieverArguments for the retrieval operation

        Returns:
            List[RetrieverResult]: List of retrieved documents/content
        """
        pass
eveai_chat_workers/retrievers/registry.py (new file, +20)

@@ -0,0 +1,20 @@
from typing import Dict, Type
from .base import BaseRetriever


class RetrieverRegistry:
    """Registry for retriever types"""

    _registry: Dict[str, Type[BaseRetriever]] = {}

    @classmethod
    def register(cls, retriever_type: str, retriever_class: Type[BaseRetriever]):
        """Register a new retriever type"""
        cls._registry[retriever_type] = retriever_class

    @classmethod
    def get_retriever_class(cls, retriever_type: str) -> Type[BaseRetriever]:
        """Get the retriever class for a given type"""
        if retriever_type not in cls._registry:
            raise ValueError(f"Unknown retriever type: {retriever_type}")
        return cls._registry[retriever_type]
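The registry maps a type string from configuration to an implementation class. A sketch of how an additional retriever type would plug in (the KEYWORD type is purely illustrative and would also need an entry in RETRIEVER_TYPES):

    from eveai_chat_workers.retrievers.base import BaseRetriever
    from eveai_chat_workers.retrievers.registry import RetrieverRegistry

    class KeywordRetriever(BaseRetriever):  # hypothetical example type
        @property
        def type(self) -> str:
            return "KEYWORD"

        def retrieve(self, arguments):
            return []  # real matching logic would go here

    RetrieverRegistry.register("KEYWORD", KeywordRetriever)
    retriever_cls = RetrieverRegistry.get_retriever_class("KEYWORD")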
eveai_chat_workers/retrievers/retriever_typing.py (new file, +60)

@@ -0,0 +1,60 @@
from typing import List, Dict, Any
from pydantic import BaseModel, Field, model_validator
from common.utils.config_field_types import ArgumentDefinition, TaggingFields
from config.retriever_types import RETRIEVER_TYPES


class RetrieverMetadata(BaseModel):
    """Metadata structure for retrieved documents"""
    document_id: int = Field(..., description="ID of the source document")
    version_id: int = Field(..., description="Version ID of the source document")
    document_name: str = Field(..., description="Name of the source document")
    user_metadata: Dict[str, Any] = Field(
        default_factory=dict,  # Defaults to an empty dict when the field is omitted
        description="User-defined metadata"
    )


class RetrieverResult(BaseModel):
    """Standard result format for all retrievers"""
    id: int = Field(..., description="ID of the retrieved embedding")
    chunk: str = Field(..., description="Retrieved text chunk")
    similarity: float = Field(..., description="Similarity score (0-1)")
    metadata: RetrieverMetadata = Field(..., description="Associated metadata")


class RetrieverArguments(BaseModel):
    """
    Dynamic arguments for retrievers, allowing arbitrary fields but validating required ones
    based on RETRIEVER_TYPES configuration.
    """
    type: str = Field(..., description="Type of retriever (e.g. STANDARD_RAG)")

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'RetrieverArguments':
        """Validate that all required arguments for this retriever type are present"""
        retriever_config = RETRIEVER_TYPES.get(self.type)
        if not retriever_config:
            raise ValueError(f"Unknown retriever type: {self.type}")

        # Check required arguments from configuration
        for arg_name, arg_config in retriever_config['arguments'].items():
            if arg_config.get('required', False):
                if not hasattr(self, arg_name):
                    raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

                # Type validation
                value = getattr(self, arg_name)
                expected_type = arg_config['type']
                if expected_type == 'str' and not isinstance(value, str):
                    raise ValueError(f"Argument '{arg_name}' must be a string")
                elif expected_type == 'int' and not isinstance(value, int):
                    raise ValueError(f"Argument '{arg_name}' must be an integer")
                # Add other type validations as needed

        return self
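Because extra fields are allowed, retriever-specific arguments ride along on the model and the validator checks them against RETRIEVER_TYPES. A sketch, assuming the STANDARD_RAG entry requires a string 'query' argument (the config itself is not shown in this diff):

    args = RetrieverArguments(type="STANDARD_RAG", query="What is the refund policy?")
    print(args.query)  # the extra field is kept and validated

    RetrieverArguments(type="STANDARD_RAG")  # would raise: missing required 'query'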
eveai_chat_workers/retrievers/standard_rag.py (new file, +140)

@@ -0,0 +1,140 @@
# retrievers/standard_rag.py
from datetime import datetime as dt, timezone as tz
from typing import Dict, Any, List
from sqlalchemy import func, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from flask import current_app

from common.extensions import db
from common.models.document import Document, DocumentVersion, Catalog, Retriever
from common.models.user import Tenant
from common.utils.datetime_utils import get_date_in_timezone
from common.utils.model_utils import get_model_variables
from .base import BaseRetriever

from .registry import RetrieverRegistry
from .retriever_typing import RetrieverArguments, RetrieverResult, RetrieverMetadata


class StandardRAGRetriever(BaseRetriever):
    """Standard RAG retriever implementation"""

    def __init__(self, tenant_id: int, retriever_id: int):
        super().__init__(tenant_id, retriever_id)

        retriever = Retriever.query.get_or_404(retriever_id)
        self.catalog_id = retriever.catalog_id
        self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
        self.k = retriever.configuration.get('es_k', 8)
        self.tuning = retriever.tuning
        self.model_variables = get_model_variables(self.tenant_id)

        self._log_tuning("Standard RAG retriever initialized")

    @property
    def type(self) -> str:
        return "STANDARD_RAG"

    def retrieve(self, arguments: RetrieverArguments) -> List[RetrieverResult]:
        """
        Retrieve documents based on query

        Args:
            arguments: Validated RetrieverArguments containing at minimum:
                - query: str - The search query

        Returns:
            List[RetrieverResult]: List of retrieved documents with similarity scores
        """
        try:
            query = arguments.query

            # Get query embedding
            query_embedding = self._get_query_embedding(query)

            # Get the appropriate embedding database model
            db_class = self.model_variables.embedding_model_class

            # Get current date for validity checks
            current_date = dt.now(tz=tz.utc).date()

            # Create subquery for latest versions
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )

            # Main query
            query_obj = (
                db.session.query(
                    db_class,
                    (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
                )
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    (1 - db_class.embedding.cosine_distance(query_embedding)) > self.similarity_threshold,
                    Document.catalog_id == self.catalog_id
                )
                .order_by(desc('similarity'))
                .limit(self.k)
            )

            results = query_obj.all()

            # Transform results into standard format
            processed_results = []
            for doc, similarity in results:
                processed_results.append(
                    RetrieverResult(
                        id=doc.id,
                        chunk=doc.chunk,
                        similarity=float(similarity),
                        metadata=RetrieverMetadata(
                            document_id=doc.document_version.doc_id,
                            version_id=doc.document_version.id,
                            document_name=doc.document_version.document.name,
                            user_metadata=doc.document_version.user_metadata or {},
                        )
                    )
                )

            # Log the retrieval
            if self.tuning:
                compiled_query = str(query_obj.statement.compile(
                    compile_kwargs={"literal_binds": True}  # This will include the actual values in the SQL
                ))
                self._log_tuning('retrieve', {
                    "arguments": arguments.model_dump(),
                    "similarity_threshold": self.similarity_threshold,
                    "k": self.k,
                    "query": compiled_query,
                    "Raw Results": str(results),
                    "Processed Results": [r.model_dump() for r in processed_results],
                })

            return processed_results

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error in RAG retrieval: {e}')
            db.session.rollback()
            raise
        except Exception as e:
            current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
            raise

    def _get_query_embedding(self, query: str):
        """Get embedding for the query text"""
        embedding_model = self.model_variables.embedding_model
        return embedding_model.embed_query(query)


# Register the retriever type
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
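Note that pgvector's cosine_distance is 1 minus cosine similarity, so the 1 - cosine_distance(...) expression recovers the similarity that the threshold and ordering operate on. A usage sketch (assumes an app context, the tenant schema selected, and an existing Retriever row; IDs are hypothetical):

    retriever = StandardRAGRetriever(tenant_id=1, retriever_id=42)
    results = retriever.retrieve(
        RetrieverArguments(type="STANDARD_RAG", query="refund policy")
    )
    for r in results:
        print(r.similarity, r.metadata.document_name)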
eveai_chat_workers/specialists/__init__.py (new file, +5)

@@ -0,0 +1,5 @@
# Import all specialist implementations here to ensure registration
from . import rag_specialist

# List of all available specialist implementations
__all__ = ['rag_specialist']
eveai_chat_workers/specialists/base.py (new file, +50)

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Dict, Any
from flask import current_app

from config.logging_config import TuningLogger
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments, SpecialistResult


class BaseSpecialist(ABC):
    """Base class for all specialists"""

    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        self.tenant_id = tenant_id
        self.specialist_id = specialist_id
        self.session_id = session_id
        self.tuning = False
        self.tuning_logger = None
        self._setup_tuning_logger()

    @property
    @abstractmethod
    def type(self) -> str:
        """The type of the specialist"""
        pass

    def _setup_tuning_logger(self):
        try:
            self.tuning_logger = TuningLogger(
                'tuning',
                tenant_id=self.tenant_id,
                specialist_id=self.specialist_id,
            )
            # Verify logger is working with a test message
            if self.tuning:
                self.tuning_logger.log_tuning('specialist', "Tuning logger initialized")
        except Exception as e:
            current_app.logger.error(f"Failed to setup tuning logger: {str(e)}")
            raise

    def _log_tuning(self, message: str, data: Dict[str, Any] = None) -> None:
        if self.tuning and self.tuning_logger:
            try:
                self.tuning_logger.log_tuning('specialist', message, data)
            except Exception as e:
                current_app.logger.error(f"Specialist: Error in tuning logging: {e}")

    @abstractmethod
    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """Execute the specialist's logic"""
        pass
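Subclasses only need to declare a type and implement execute. A minimal sketch (the ECHO type is illustrative and would also need an entry in SPECIALIST_TYPES for result validation to pass):

    class EchoSpecialist(BaseSpecialist):  # hypothetical example
        @property
        def type(self) -> str:
            return "ECHO"

        def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
            # A real specialist would run retrieval and LLM calls here.
            return SpecialistResult(type="ECHO", answer=arguments.query)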
eveai_chat_workers/specialists/rag_specialist.py (new file, +289)

@@ -0,0 +1,289 @@
from datetime import datetime
from typing import Dict, Any, List
from flask import current_app
from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from common.langchain.outputs.base import OutputRegistry
from common.langchain.outputs.rag import RAGOutput
from common.utils.business_event_context import current_event
from .specialist_typing import SpecialistArguments, SpecialistResult
from ..chat_session_cache import CachedInteraction, get_chat_history
from ..retrievers.registry import RetrieverRegistry
from ..retrievers.base import BaseRetriever
from common.models.interaction import SpecialistRetriever, Specialist
from common.utils.model_utils import get_model_variables, create_language_template, replace_variable_in_template
from .base import BaseSpecialist
from .registry import SpecialistRegistry
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments, RetrieverResult


class RAGSpecialist(BaseSpecialist):
    """
    Standard Q&A RAG Specialist implementation that combines retriever results
    with LLM processing to generate answers.
    """
    def __init__(self, tenant_id: int, specialist_id: int, session_id: str):
        super().__init__(tenant_id, specialist_id, session_id)

        # Check and load the specialist
        specialist = Specialist.query.get_or_404(specialist_id)
        # Set the specific configuration for the RAG Specialist
        self.specialist_context = specialist.configuration.get('specialist_context', '')
        self.temperature = specialist.configuration.get('temperature', 0.3)
        self.tuning = specialist.tuning

        # Initialize retrievers
        self.retrievers = self._initialize_retrievers()

        # Initialize model variables
        self.model_variables = get_model_variables(tenant_id)

    @property
    def type(self) -> str:
        return "STANDARD_RAG"

    def _initialize_retrievers(self) -> List[BaseRetriever]:
        """Initialize all retrievers associated with this specialist"""
        retrievers = []

        # Get retriever associations from database
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=self.specialist_id)
            .all()
        )

        self._log_tuning("_initialize_retrievers", {"Nr of retrievers": len(specialist_retrievers)})

        for spec_retriever in specialist_retrievers:
            # Get retriever configuration from database
            retriever = spec_retriever.retriever
            retriever_class = RetrieverRegistry.get_retriever_class(retriever.type)
            self._log_tuning("_initialize_retrievers", {
                "Retriever id": spec_retriever.retriever_id,
                "Retriever Type": retriever.type,
                "Retriever Class": str(retriever_class),
            })

            # Initialize retriever with its configuration
            retrievers.append(
                retriever_class(
                    tenant_id=self.tenant_id,
                    retriever_id=retriever.id,
                )
            )

        return retrievers

    @property
    def required_templates(self) -> List[str]:
        """List of required templates for this specialist"""
        return ['rag', 'history']

    # def _detail_question(question, language, model_variables, session_id):
    #     retriever = EveAIHistoryRetriever(model_variables=model_variables, session_id=session_id)
    #     llm = model_variables['llm']
    #     template = model_variables['history_template']
    #     language_template = create_language_template(template, language)
    #     full_template = replace_variable_in_template(language_template, "{tenant_context}",
    #                                                  model_variables['rag_context'])
    #     history_prompt = ChatPromptTemplate.from_template(full_template)
    #     setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
    #     output_parser = StrOutputParser()
    #
    #     chain = setup_and_retrieval | history_prompt | llm | output_parser
    #
    #     try:
    #         answer = chain.invoke(question)
    #         return answer
    #     except LangChainException as e:
    #         current_app.logger.error(f'Error detailing question: {e}')
    #         raise

    def _detail_question(self, language: str, question: str) -> str:
        """Detail question based on conversation history"""
        try:
            # Get cached session history
            cached_session = get_chat_history(self.session_id)

            # Format history for the prompt
            formatted_history = "\n\n".join([
                f"HUMAN:\n{interaction.specialist_results.get('detailed_query')}\n\n"
                f"AI:\n{interaction.specialist_results.get('answer')}"
                for interaction in cached_session.interactions
            ])

            # Get LLM and template
            llm = self.model_variables.get_llm(temperature=0.3)
            template = self.model_variables.get_template('history')
            language_template = create_language_template(template, language)

            # Create prompt
            history_prompt = ChatPromptTemplate.from_template(language_template)

            # Create chain
            chain = (
                history_prompt |
                llm |
                StrOutputParser()
            )

            # Execute chain
            detailed_question = chain.invoke({
                "history": formatted_history,
                "question": question
            })

            if self.tuning:
                self._log_tuning("_detail_question", {
                    "cached_session_id": cached_session.session_id,
                    "cached_session.interactions": str(cached_session.interactions),
                    "original_question": question,
                    "history_used": formatted_history,
                    "detailed_question": detailed_question,
                })

            return detailed_question

        except Exception as e:
            current_app.logger.error(f"Error detailing question: {e}")
            return question  # Fallback to original question

    def execute(self, arguments: SpecialistArguments) -> SpecialistResult:
        """
        Execute the RAG specialist to generate an answer
        """
        start_time = datetime.now()

        try:
            with current_event.create_span("Specialist Detail Question"):
                # Get required arguments
                language = arguments.language
                query = arguments.query
                detailed_question = self._detail_question(language, query)

            # Log the start of retrieval process if tuning is enabled
            with current_event.create_span("Specialist Retrieval"):
                self._log_tuning("Starting context retrieval", {
                    "num_retrievers": len(self.retrievers),
                    "all arguments": arguments.model_dump(),
                })

                # Get retriever-specific arguments
                retriever_arguments = arguments.retriever_arguments

                # Collect context from all retrievers
                all_context = []
                for retriever in self.retrievers:
                    # Get arguments for this specific retriever
                    retriever_id = str(retriever.retriever_id)
                    if retriever_id not in retriever_arguments:
                        current_app.logger.error(f"Missing arguments for retriever {retriever_id}")
                        continue

                    # Get the retriever's arguments and update the query
                    current_retriever_args = retriever_arguments[retriever_id]
                    if isinstance(retriever_arguments[retriever_id], RetrieverArguments):
                        updated_args = current_retriever_args.model_dump()
                        updated_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**updated_args)
                    else:
                        # Create a new RetrieverArguments instance from the dictionary
                        current_retriever_args['query'] = detailed_question
                        retriever_args = RetrieverArguments(**current_retriever_args)

                    # Each retriever gets its own specific arguments
                    retriever_result = retriever.retrieve(retriever_args)
                    all_context.extend(retriever_result)

                # Sort by similarity if available and get unique contexts
                all_context.sort(key=lambda x: x.similarity, reverse=True)
                unique_contexts = []
                seen_chunks = set()
                for ctx in all_context:
                    if ctx.chunk not in seen_chunks:
                        unique_contexts.append(ctx)
                        seen_chunks.add(ctx.chunk)

                self._log_tuning("Context retrieval completed", {
                    "total_contexts": len(all_context),
                    "unique_contexts": len(unique_contexts),
                    "average_similarity": sum(ctx.similarity for ctx in unique_contexts) / len(
                        unique_contexts) if unique_contexts else 0
                })

                # Prepare context for LLM
                formatted_context = "\n\n".join([
                    f"SOURCE: {ctx.metadata.document_id}\n\n{ctx.chunk}"
                    for ctx in unique_contexts
                ])

            with current_event.create_span("Specialist RAG invocation"):
                try:
                    # Get LLM with specified temperature
                    llm = self.model_variables.get_llm(temperature=self.temperature)

                    # Get template
                    template = self.model_variables.get_template('rag')
                    language_template = create_language_template(template, language)
                    full_template = replace_variable_in_template(
                        language_template,
                        "{tenant_context}",
                        self.specialist_context
                    )

                    if self.tuning:
                        self._log_tuning("Template preparation completed", {
                            "template": full_template,
                            "context": formatted_context,
                            "tenant_context": self.specialist_context,
                        })

                    # Create prompt
                    rag_prompt = ChatPromptTemplate.from_template(full_template)

                    # Setup chain components
                    setup_and_retrieval = RunnableParallel({
                        "context": lambda x: formatted_context,
                        "question": lambda x: x
                    })

                    # Get output schema for structured output
                    output_schema = OutputRegistry.get_schema(self.type)
                    structured_llm = llm.with_structured_output(output_schema)
                    chain = setup_and_retrieval | rag_prompt | structured_llm

                    raw_result = chain.invoke(query)
                    result = SpecialistResult.create_for_type(
                        "STANDARD_RAG",
                        detailed_query=detailed_question,
                        answer=raw_result.answer,
                        citations=[ctx.metadata.document_id for ctx in unique_contexts
                                   if ctx.id in raw_result.citations],
                        insufficient_info=raw_result.insufficient_info
                    )

                    if self.tuning:
                        self._log_tuning("LLM chain execution completed", {
                            "Result": result.model_dump()
                        })

                except Exception as e:
                    current_app.logger.error(f"Error in LLM processing: {e}")
                    if self.tuning:
                        self._log_tuning("LLM processing error", {"error": str(e)})
                    raise

            return result

        except Exception as e:
            current_app.logger.error(f'Error in RAG specialist execution: {str(e)}')
            raise


# Register the specialist type
SpecialistRegistry.register("STANDARD_RAG", RAGSpecialist)
OutputRegistry.register("STANDARD_RAG", RAGOutput)
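End to end, callers resolve the class through the registry instead of importing it directly. A condensed sketch of that call path (IDs and the arguments object are illustrative):

    specialist_cls = SpecialistRegistry.get_specialist_class("STANDARD_RAG")
    specialist = specialist_cls(tenant_id=1, specialist_id=7, session_id="session-1234")
    result = specialist.execute(arguments)  # arguments: a validated SpecialistArguments
    print(result.answer, result.citations)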
eveai_chat_workers/specialists/registry.py (new file, +21)

@@ -0,0 +1,21 @@
from typing import Dict, Type
from .base import BaseSpecialist


class SpecialistRegistry:
    """Registry for specialist types"""

    _registry: Dict[str, Type[BaseSpecialist]] = {}

    @classmethod
    def register(cls, specialist_type: str, specialist_class: Type[BaseSpecialist]):
        """Register a new specialist type"""
        cls._registry[specialist_type] = specialist_class

    @classmethod
    def get_specialist_class(cls, specialist_type: str) -> Type[BaseSpecialist]:
        """Get the specialist class for a given type"""
        if specialist_type not in cls._registry:
            raise ValueError(f"Unknown specialist type: {specialist_type}")
        return cls._registry[specialist_type]
eveai_chat_workers/specialists/specialist_typing.py (new file, +144)

@@ -0,0 +1,144 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field, model_validator
from config.specialist_types import SPECIALIST_TYPES
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments


class SpecialistArguments(BaseModel):
    """
    Dynamic arguments for specialists, allowing arbitrary fields but validating required ones
    based on SPECIALIST_TYPES configuration.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")
    retriever_arguments: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for each retriever, keyed by retriever ID"
    )

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_arguments(self) -> 'SpecialistArguments':
        """Validate that all required arguments for this specialist type are present"""
        specialist_config = SPECIALIST_TYPES.get(self.type)
        if not specialist_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Check required arguments from configuration
        for arg_name, arg_config in specialist_config['arguments'].items():
            if arg_config.get('required', False):
                if not hasattr(self, arg_name):
                    raise ValueError(f"Missing required argument '{arg_name}' for {self.type}")

                # Type validation
                value = getattr(self, arg_name)
                expected_type = arg_config['type']
                if expected_type == 'str' and not isinstance(value, str):
                    raise ValueError(f"Argument '{arg_name}' must be a string")
                elif expected_type == 'int' and not isinstance(value, int):
                    raise ValueError(f"Argument '{arg_name}' must be an integer")

        return self

    @classmethod
    def create(cls, type_name: str, specialist_args: Dict[str, Any],
               retriever_args: Dict[str, Dict[str, Any]]) -> 'SpecialistArguments':
        """
        Factory method to create SpecialistArguments with validated retriever arguments

        Args:
            type_name: The specialist type (e.g., 'STANDARD_RAG')
            specialist_args: Arguments specific to the specialist
            retriever_args: Dictionary of retriever arguments keyed by retriever ID

        Returns:
            Validated SpecialistArguments instance
        """
        # Convert raw retriever arguments to RetrieverArguments instances
        validated_retriever_args = {}
        for retriever_id, args in retriever_args.items():
            # Ensure type is included in retriever arguments
            if 'type' not in args:
                raise ValueError(f"Retriever arguments for {retriever_id} must include 'type'")

            validated_retriever_args[retriever_id] = RetrieverArguments(**args)

        # Combine everything into the specialist arguments
        return cls(
            type=type_name,
            **specialist_args,
            retriever_arguments=validated_retriever_args
        )


class SpecialistResult(BaseModel):
    """
    Dynamic results from specialists, validating required fields based on
    SPECIALIST_TYPES configuration.
    """
    type: str = Field(..., description="Type of specialist (e.g. STANDARD_RAG)")

    # Allow any additional fields
    model_config = {
        "extra": "allow"
    }

    @model_validator(mode='after')
    def validate_required_results(self) -> 'SpecialistResult':
        """Validate that all required result fields for this specialist type are present"""
        specialist_config = SPECIALIST_TYPES.get(self.type)
        if not specialist_config:
            raise ValueError(f"Unknown specialist type: {self.type}")

        # Check required results from configuration
        required_results = specialist_config.get('results', {})
        for result_name, result_config in required_results.items():
            if result_config.get('required', False):
                if not hasattr(self, result_name):
                    raise ValueError(f"Missing required result '{result_name}' for {self.type}")

                # Type validation
                value = getattr(self, result_name)
                expected_type = result_config['type']

                # Validate based on type annotation
                if expected_type == 'str' and not isinstance(value, str):
                    raise ValueError(f"Result '{result_name}' must be a string")
                elif expected_type == 'bool' and not isinstance(value, bool):
                    raise ValueError(f"Result '{result_name}' must be a boolean")
                elif expected_type == 'List[str]' and not (
                        isinstance(value, list) and all(isinstance(x, str) for x in value)):
                    raise ValueError(f"Result '{result_name}' must be a list of strings")
                # Add other type validations as needed

        return self

    @classmethod
    def create_for_type(cls, specialist_type: str, **results) -> 'SpecialistResult':
        """
        Factory method to create a type-specific result

        Args:
            specialist_type: The type of specialist (e.g., 'STANDARD_RAG')
            **results: The result values to include

        Returns:
            Validated SpecialistResult instance

        Example:
            For STANDARD_RAG:
            result = SpecialistResult.create_for_type(
                'STANDARD_RAG',
                answer="The answer text",
                citations=["doc1", "doc2"],
                insufficient_info=False
            )
        """
        # Add the type to the results
        results['type'] = specialist_type

        # Create and validate the result
        return cls(**results)
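A sketch of building arguments with the factory, with retriever arguments keyed by retriever ID as prepare_arguments in tasks.py produces them (IDs and values are illustrative):

    arguments = SpecialistArguments.create(
        type_name='STANDARD_RAG',
        specialist_args={'query': 'What changed in v2?', 'language': 'en'},
        retriever_args={
            '42': {'type': 'STANDARD_RAG', 'query': 'What changed in v2?'},  # hypothetical retriever ID
        },
    )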
eveai_chat_workers/tasks.py

@@ -1,24 +1,23 @@
from datetime import datetime as dt, timezone as tz
from typing import Dict, Any, Optional

from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from sqlalchemy.exc import SQLAlchemyError

# OpenAI imports
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.exceptions import LangChainException

from common.utils.config_field_types import TaggingFields
from common.utils.database import Database
from common.models.document import Embedding
from common.models.document import Embedding, Catalog
from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding, Specialist, SpecialistRetriever
from common.extensions import db, cache_manager
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.retrievers.eveai_default_rag_retriever import EveAIDefaultRagRetriever
from common.langchain.retrievers.eveai_history_retriever import EveAIHistoryRetriever
from common.utils.business_event import BusinessEvent
from common.utils.business_event_context import current_event
from config.specialist_types import SPECIALIST_TYPES
from eveai_chat_workers.chat_session_cache import get_chat_history
from eveai_chat_workers.specialists.registry import SpecialistRegistry
from config.retriever_types import RETRIEVER_TYPES
from eveai_chat_workers.specialists.specialist_typing import SpecialistArguments


# Healthcheck task
@@ -27,41 +26,207 @@ def ping():
    return 'pong'


def detail_question(question, language, model_variables, session_id):
    current_app.logger.debug(f'Detail question: {question}')
    current_app.logger.debug(f'model_variables: {model_variables}')
    current_app.logger.debug(f'session_id: {session_id}')
    retriever = EveAIHistoryRetriever(model_variables=model_variables, session_id=session_id)
    llm = model_variables['llm']
    template = model_variables['history_template']
    language_template = create_language_template(template, language)
    full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
    history_prompt = ChatPromptTemplate.from_template(full_template)
    setup_and_retrieval = RunnableParallel({"history": retriever, "question": RunnablePassthrough()})
    output_parser = StrOutputParser()

    chain = setup_and_retrieval | history_prompt | llm | output_parser

    try:
        answer = chain.invoke(question)
        return answer
    except LangChainException as e:
        current_app.logger.error(f'Error detailing question: {e}')
        raise


@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question, language, session_id, user_timezone, room):
    """returns result structured as follows:
    result = {
        'answer': 'Your answer here',
        'citations': ['http://example.com/citation1', 'http://example.com/citation2'],
        'algorithm': 'algorithm_name',
        'interaction_id': 'interaction_id_value'
    }
    """
    with BusinessEvent("Ask Question", tenant_id=tenant_id, chat_session_id=session_id):
        current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')


class ArgumentPreparationError(Exception):
    """Custom exception for argument preparation errors"""
    pass


def validate_specialist_arguments(specialist_type: str, arguments: Dict[str, Any]) -> None:
    """
    Validate specialist-specific arguments

    Args:
        specialist_type: Type of specialist
        arguments: Arguments to validate (excluding retriever-specific arguments)

    Raises:
        ArgumentPreparationError: If validation fails
    """
    specialist_config = SPECIALIST_TYPES.get(specialist_type)
    if not specialist_config:
        raise ArgumentPreparationError(f"Unknown specialist type: {specialist_type}")

    required_args = specialist_config.get('arguments', {})

    # Check for required arguments
    for arg_name, arg_config in required_args.items():
        if arg_config.get('required', False) and arg_name not in arguments:
            raise ArgumentPreparationError(f"Missing required argument '{arg_name}' for specialist")

        if arg_name in arguments:
            # Type checking
            expected_type = arg_config.get('type')
            if expected_type == 'str' and not isinstance(arguments[arg_name], str):
                raise ArgumentPreparationError(f"Argument '{arg_name}' must be a string")
            elif expected_type == 'int' and not isinstance(arguments[arg_name], int):
                raise ArgumentPreparationError(f"Argument '{arg_name}' must be an integer")


def validate_retriever_arguments(retriever_type: str, arguments: Dict[str, Any],
                                 catalog_config: Optional[Dict[str, Any]] = None) -> None:
    """
    Validate retriever-specific arguments

    Args:
        retriever_type: Type of retriever
        arguments: Arguments to validate
        catalog_config: Optional catalog configuration for metadata validation

    Raises:
        ArgumentPreparationError: If validation fails
    """
    retriever_config = RETRIEVER_TYPES.get(retriever_type)
    if not retriever_config:
        raise ArgumentPreparationError(f"Unknown retriever type: {retriever_type}")

    # Validate standard retriever arguments
    required_args = retriever_config.get('arguments', {})
    for arg_name, arg_config in required_args.items():
        if arg_config.get('required', False) and arg_name not in arguments:
            raise ArgumentPreparationError(f"Missing required argument '{arg_name}' for retriever")

    # Only validate metadata filters if catalog configuration is provided
    if catalog_config and 'metadata_filters' in arguments:
        if 'tagging_fields' in catalog_config:
            tagging_fields = TaggingFields.from_dict(catalog_config['tagging_fields'])
            errors = tagging_fields.validate_argument_values(arguments['metadata_filters'])
            if errors:
                raise ArgumentPreparationError(f"Invalid metadata filters: {errors}")


def is_retriever_id(key: str) -> bool:
    """
    Check if a key represents a valid retriever ID.
    Valid formats: positive integers, including leading zeros

    Args:
        key: String to check

    Returns:
        bool: True if the key represents a valid retriever ID
    """
    try:
        # Convert to int to handle leading zeros
        value = int(key)
        # Ensure it's a positive number
        return value > 0
    except ValueError:
        return False
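Retriever overrides are keyed by the retriever's numeric ID, so the split between specialist and retriever arguments is purely key-shaped:

    is_retriever_id("42")     # True  -> treated as retriever-specific overrides
    is_retriever_id("007")    # True  -> leading zeros are accepted
    is_retriever_id("query")  # False -> treated as a specialist argument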


def prepare_arguments(specialist: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """
    Prepare complete argument dictionary for specialist execution with inheritance

    Args:
        specialist: Specialist model instance
        arguments: Dictionary containing:
            - Specialist arguments
            - Retriever-specific arguments keyed by retriever ID

    Returns:
        Dict containing prepared arguments with inheritance applied

    Raises:
        ArgumentPreparationError: If argument preparation or validation fails
    """
    try:
        # Separate specialist arguments from retriever arguments
        retriever_args = {}
        specialist_args = {}

        for key, value in arguments.items():
            if isinstance(key, str) and is_retriever_id(key):  # Retriever ID
                retriever_args[key] = value
            else:
                specialist_args[key] = value

        # Validate specialist arguments
        validate_specialist_arguments(specialist.type, specialist_args)

        # Get all retrievers associated with this specialist
        specialist_retrievers = (
            SpecialistRetriever.query
            .filter_by(specialist_id=specialist.id)
            .all()
        )

        # Process each retriever
        prepared_retriever_args = {}
        for spec_retriever in specialist_retrievers:
            retriever = spec_retriever.retriever
            retriever_id = str(retriever.id)

            # Get catalog configuration if it exists
            catalog_config = None
            if retriever.catalog_id:
                try:
                    catalog = Catalog.query.get(retriever.catalog_id)
                    if catalog:
                        catalog_config = catalog.configuration
                except SQLAlchemyError:
                    current_app.logger.warning(
                        f"Could not fetch catalog {retriever.catalog_id} for retriever {retriever_id}"
                    )

            # Start with specialist arguments (inheritance)
            inherited_args = specialist_args.copy()

            # Override with retriever-specific arguments if provided
            if retriever_id in retriever_args:
                inherited_args.update(retriever_args[retriever_id])

            # Always include the retriever type
            inherited_args['type'] = retriever.type

            # Validate the combined arguments
            validate_retriever_arguments(
                retriever.type,
                inherited_args,
                catalog_config
            )

            prepared_retriever_args[retriever_id] = inherited_args

        # Construct final argument structure
        final_arguments = {
            **specialist_args,
            'retriever_arguments': prepared_retriever_args
        }
        return final_arguments

    except SQLAlchemyError as e:
        current_app.logger.error(f'Database error during argument preparation: {e}')
        raise ArgumentPreparationError(f"Database error: {str(e)}")
    except Exception as e:
        current_app.logger.error(f'Error during argument preparation: {e}')
        raise ArgumentPreparationError(str(e))
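A sketch of the inheritance this produces for a specialist with one attached STANDARD_RAG retriever (IDs and values are illustrative):

    arguments = {
        'query': 'refund policy', 'language': 'en',
        '42': {'metadata_filters': {'region': 'EU'}},  # hypothetical retriever ID 42
    }
    prepared = prepare_arguments(specialist, arguments)
    # prepared['retriever_arguments']['42'] ==
    # {'query': 'refund policy', 'language': 'en',
    #  'metadata_filters': {'region': 'EU'}, 'type': 'STANDARD_RAG'}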
|
||||
|
||||
|
||||
@current_celery.task(name='execute_specialist', queue='llm_interactions')
|
||||
def execute_specialist(tenant_id: int, specialist_id: int, arguments: Dict[str, Any],
|
||||
session_id: str, user_timezone: str, room: str) -> dict:
|
||||
"""
|
||||
Execute a specialist with given arguments
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant
|
||||
specialist_id: ID of the specialist to use
|
||||
arguments: Dictionary containing all required arguments for specialist and retrievers
|
||||
session_id: Chat session ID
|
||||
user_timezone: User's timezone
|
||||
room: Socket.IO room for the response
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
'result': Dict - Specialist execution result
|
||||
'interaction_id': int - Created interaction ID
|
||||
'room': str - Socket.IO room
|
||||
}
|
||||
"""
|
||||
with BusinessEvent("Execute Specialist", tenant_id=tenant_id, chat_session_id=session_id) as event:
|
||||
current_app.logger.info(
|
||||
f'execute_specialist: Processing request for tenant {tenant_id} using specialist {specialist_id}')
|
||||
|
||||
try:
|
||||
# Retrieve the tenant
|
||||
@@ -69,208 +234,85 @@ def ask_question(tenant_id, question, language, session_id, user_timezone, room)
|
||||
if not tenant:
|
||||
raise Exception(f'Tenant {tenant_id} not found.')
|
||||
|
||||
# Ensure we are working in the correct database schema
|
||||
# Switch to correct database schema
|
||||
Database(tenant_id).switch_schema()
|
||||
|
||||
# Ensure we have a session to story history
|
||||
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
|
||||
if not chat_session:
|
||||
# Ensure we have a session
|
||||
cached_session = cache_manager.chat_session_cache.get_cached_session(
|
||||
session_id,
|
||||
create_params={'timezone': user_timezone}
|
||||
)
|
||||
|
||||
# Get specialist from database
|
||||
specialist = Specialist.query.get_or_404(specialist_id)
|
||||
|
||||
# Prepare complete arguments
|
||||
try:
|
||||
raw_arguments = prepare_arguments(specialist, arguments)
|
||||
# Convert the prepared arguments into a SpecialistArguments instance
|
||||
complete_arguments = SpecialistArguments.create(
|
||||
type_name=specialist.type,
|
||||
specialist_args={k: v for k, v in raw_arguments.items() if k != 'retriever_arguments'},
|
||||
retriever_args=raw_arguments.get('retriever_arguments', {})
|
||||
)
|
||||
except ValueError as e:
|
||||
current_app.logger.error(f'execute_specialist: Error preparing arguments: {e}')
|
||||
raise
|
||||
|
||||
# Create new interaction record
|
||||
new_interaction = Interaction()
|
||||
new_interaction.chat_session_id = cached_session.id
|
||||
new_interaction.timezone = user_timezone
|
||||
new_interaction.question_at = dt.now(tz.utc)
|
||||
new_interaction.specialist_id = specialist.id
|
||||
new_interaction.specialist_arguments = complete_arguments.model_dump(mode='json')
|
||||
|
||||
try:
|
||||
db.session.add(new_interaction)
|
||||
db.session.commit()
|
||||
event.update_attribute('interaction_id', new_interaction.id)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'execute_specialist: Error creating interaction: {e}')
|
||||
raise
|
||||
|
||||
with current_event.create_span("Specialist invocation"):
|
||||
# Initialize specialist instance
|
||||
specialist_class = SpecialistRegistry.get_specialist_class(specialist.type)
|
||||
specialist_instance = specialist_class(
|
||||
tenant_id=tenant_id,
|
||||
specialist_id=specialist_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute specialist
|
||||
result = specialist_instance.execute(complete_arguments)
|
||||
|
||||
# Update interaction record
|
||||
new_interaction.specialist_results = result.model_dump(mode='json') # Store complete result
|
||||
new_interaction.answer_at = dt.now(tz.utc)
|
||||
|
||||
try:
|
||||
chat_session = ChatSession()
|
||||
chat_session.session_id = session_id
|
||||
chat_session.session_start = dt.now(tz.utc)
|
||||
chat_session.timezone = user_timezone
|
||||
db.session.add(chat_session)
|
||||
db.session.add(new_interaction)
|
||||
db.session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}')
|
||||
current_app.logger.error(f'execute_specialist: Error updating interaction: {e}')
|
||||
raise
|
||||
|
||||
with current_event.create_span("RAG Answer"):
|
||||
result, interaction = answer_using_tenant_rag(question, language, tenant, chat_session)
|
||||
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
|
||||
result['interaction_id'] = interaction.id
|
||||
result['room'] = room # Include the room in the result
|
||||
# Now that we have a complete interaction with an answer, add it to the cache
|
||||
cache_manager.chat_session_cache.add_completed_interaction(session_id, new_interaction)
|
||||
|
||||
if result['insufficient_info']:
|
||||
if 'LLM' in tenant.fallback_algorithms:
|
||||
with current_event.create_span("Fallback Algorithm LLM"):
|
||||
result, interaction = answer_using_llm(question, language, tenant, chat_session)
|
||||
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['LLM']['name']
|
||||
result['interaction_id'] = interaction.id
|
||||
result['room'] = room # Include the room in the result
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'ask_question: Error processing question: {e}')
|
||||
raise
|
||||
|
||||
|
||||
def answer_using_tenant_rag(question, language, tenant, chat_session):
    new_interaction = Interaction()
    new_interaction.question = question
    new_interaction.language = language
    new_interaction.timezone = chat_session.timezone
    new_interaction.appreciation = None
    new_interaction.chat_session_id = chat_session.id
    new_interaction.question_at = dt.now(tz.utc)
    new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']

    # Select variables to work with depending on tenant model
    model_variables = select_model_variables(tenant)
    tenant_info = tenant.to_dict()

    # Langchain debugging if required
    # set_debug(True)

    with current_event.create_span("Detail Question"):
        detailed_question = detail_question(question, language, model_variables, chat_session.session_id)
        if model_variables['rag_tuning']:
            current_app.rag_tuning_logger.debug(f'Detailed Question for tenant {tenant.id}:\n{detailed_question}.')
            current_app.rag_tuning_logger.debug('-------------------------------------------------------------------')
        new_interaction.detailed_question = detailed_question
        new_interaction.detailed_question_at = dt.now(tz.utc)

    with current_event.create_span("Generate Answer using RAG"):
        retriever = EveAIDefaultRagRetriever(1, model_variables, tenant_info)
        llm = model_variables['llm']
        template = model_variables['rag_template']
        language_template = create_language_template(template, language)
        full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
        rag_prompt = ChatPromptTemplate.from_template(full_template)
        # Run retrieval and question pass-through in parallel to feed the prompt
        setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
        if model_variables['rag_tuning']:
            current_app.rag_tuning_logger.debug(f'Full prompt for tenant {tenant.id}:\n{full_template}.')
            current_app.rag_tuning_logger.debug('-------------------------------------------------------------------')

        new_interaction_embeddings = []
        if not model_variables['cited_answer_cls']:  # The model doesn't support structured feedback
            output_parser = StrOutputParser()

            chain = setup_and_retrieval | rag_prompt | llm | output_parser

            # Invoke the chain with the actual question
            answer = chain.invoke(detailed_question)
            new_interaction.answer = answer
            result = {
                'answer': answer,
                'citations': [],
                'insufficient_info': False
            }

        else:  # The model supports structured feedback
            structured_llm = llm.with_structured_output(model_variables['cited_answer_cls'])

            chain = setup_and_retrieval | rag_prompt | structured_llm

            # Invoke the chain with the actual question
            result = chain.invoke(detailed_question).dict()
            current_app.logger.debug(f'answer_using_tenant_rag: result answer: {result["answer"]}')
            current_app.logger.debug(f'answer_using_tenant_rag: result citations: {result["citations"]}')
            current_app.logger.debug(f'answer_using_tenant_rag: insufficient information: {result["insufficient_info"]}')
            if model_variables['rag_tuning']:
                current_app.rag_tuning_logger.debug(f'answer_using_tenant_rag: result answer: {result["answer"]}')
                current_app.rag_tuning_logger.debug(f'answer_using_tenant_rag: result citations: {result["citations"]}')
                current_app.rag_tuning_logger.debug(f'answer_using_tenant_rag: insufficient information: {result["insufficient_info"]}')
                current_app.rag_tuning_logger.debug('-------------------------------------------------------------------')
            new_interaction.answer = result['answer']

            # Filter out the existing Embedding IDs (the model may cite IDs that no longer exist)
            given_embedding_ids = [int(emb_id) for emb_id in result['citations']]
            embeddings = (
                db.session.query(Embedding)
                .filter(Embedding.id.in_(given_embedding_ids))
                .all()
            )
            existing_embedding_ids = [emb.id for emb in embeddings]
            urls = list(set(emb.document_version.url for emb in embeddings))
            if model_variables['rag_tuning']:
                current_app.rag_tuning_logger.debug(f'Referenced documents for answer for tenant {tenant.id}:\n')
                current_app.rag_tuning_logger.debug(f'{urls}')
                current_app.rag_tuning_logger.debug('-------------------------------------------------------------------')

            for emb_id in existing_embedding_ids:
                new_interaction_embedding = InteractionEmbedding(embedding_id=emb_id)
                new_interaction_embedding.interaction = new_interaction
                new_interaction_embeddings.append(new_interaction_embedding)

            result['citations'] = urls

    # Disable langchain debugging if set above.
    # set_debug(False)

    new_interaction.answer_at = dt.now(tz.utc)
    chat_session.session_end = dt.now(tz.utc)

    try:
        db.session.add(chat_session)
        db.session.add(new_interaction)
        db.session.add_all(new_interaction_embeddings)
        db.session.commit()
        return result, new_interaction
    except SQLAlchemyError as e:
        current_app.logger.error(f'answer_using_tenant_rag: Error saving interaction to database: {e}')
        raise


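# The cited_answer_cls bound via llm.with_structured_output() above is not shown
# in this excerpt. Judging from the result keys ('answer', 'citations',
# 'insufficient_info') and the int(emb_id) cast applied to citations, a minimal
# sketch of such a schema could be (field names match the usage above; the
# descriptions are assumptions):

from typing import List

from pydantic import BaseModel, Field


class CitedAnswer(BaseModel):
    """Illustrative structured RAG answer citing embedding IDs."""
    answer: str = Field(description="The answer to the user's question")
    citations: List[str] = Field(default_factory=list,
                                 description="IDs of the embeddings that support the answer")
    insufficient_info: bool = Field(default=False,
                                    description="True when the retrieved context cannot answer the question")

# A schema like this is what lets chain.invoke(...).dict() in
# answer_using_tenant_rag return a plain dict with exactly these keys.
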
def answer_using_llm(question, language, tenant, chat_session):
    new_interaction = Interaction()
    new_interaction.question = question
    new_interaction.language = language
    new_interaction.timezone = chat_session.timezone
    new_interaction.appreciation = None
    new_interaction.chat_session_id = chat_session.id
    new_interaction.question_at = dt.now(tz.utc)
    new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['LLM']['name']

    # Select variables to work with depending on tenant model
    model_variables = select_model_variables(tenant)
    tenant_info = tenant.to_dict()

    # Langchain debugging if required
    # set_debug(True)

    with current_event.create_span("Detail Question"):
        detailed_question = detail_question(question, language, model_variables, chat_session.session_id)
        current_app.logger.debug(f'Original question:\n {question}\n\nDetailed question: {detailed_question}')
        new_interaction.detailed_question = detailed_question
        new_interaction.detailed_question_at = dt.now(tz.utc)

    with current_event.create_span("Detail Answer using LLM"):
        retriever = EveAIDefaultRagRetriever(1, model_variables, tenant_info)
        llm = model_variables['llm_no_rag']
        template = model_variables['encyclopedia_template']
        language_template = create_language_template(template, language)
        rag_prompt = ChatPromptTemplate.from_template(language_template)
        setup = RunnablePassthrough()
        output_parser = StrOutputParser()

        new_interaction_embeddings = []

        chain = setup | rag_prompt | llm | output_parser
        input_question = {"question": detailed_question}

        # Invoke the chain with the actual question
        answer = chain.invoke(input_question)
        new_interaction.answer = answer
        result = {
            'answer': answer,
            'citations': [],
            'insufficient_info': False
        }

    # Disable langchain debugging if set above.
    # set_debug(False)

    new_interaction.answer_at = dt.now(tz.utc)
    chat_session.session_end = dt.now(tz.utc)

    try:
        db.session.add(chat_session)
        db.session.add(new_interaction)
        db.session.commit()
        return result, new_interaction
    except SQLAlchemyError as e:
        current_app.logger.error(f'answer_using_llm: Error saving interaction to database: {e}')
        raise

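# Neither create_language_template nor replace_variable_in_template is defined in
# this excerpt. A minimal sketch under the assumption that templates are either
# plain strings or dicts keyed by language code, and that variables are literal
# placeholders such as '{tenant_context}' (the real template manager may differ):

from typing import Dict, Union


def create_language_template(template: Union[Dict[str, str], str], language: str) -> str:
    """Pick the template variant for the requested language, falling back to English."""
    if isinstance(template, dict):
        return template.get(language, template.get('en', ''))
    return template


def replace_variable_in_template(template: str, variable: str, value: str) -> str:
    """Substitute a literal placeholder like '{tenant_context}' with its value."""
    return template.replace(variable, value)


# Likewise, the INTERACTION_ALGORITHMS config consulted throughout is not shown.
# From the lookups config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] and the
# 'LLM' membership test on tenant.fallback_algorithms, its assumed shape is
# roughly (keys inferred from usage, values hypothetical):
INTERACTION_ALGORITHMS = {
    'RAG_TENANT': {'name': 'RAG_TENANT'},
    'LLM': {'name': 'LLM'},
}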