- Created a new eveai_chat plugin to support the new dynamic possibilities of the Specialists. Currently only supports standard Rag retrievers (i.e. no extra arguments).

This commit is contained in:
Josako
2024-11-26 13:35:29 +01:00
parent 7702a6dfcc
commit 07d89d204f
42 changed files with 1771 additions and 989 deletions

View File

@@ -9,21 +9,28 @@ class EveAICacheManager:
"""Cache manager with registration capabilities"""
def __init__(self):
self.model_region = None
self.eveai_chat_workers_region = None
self.eveai_workers_region = None
self._regions = {}
self._handlers = {}
def init_app(self, app: Flask):
"""Initialize cache regions"""
from common.utils.cache.regions import create_cache_regions
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
self._regions = create_cache_regions(app)
# Store regions in instance
for region_name, region in self._regions.items():
setattr(self, f"{region_name}_region", region)
# Initialize all registered handlers with their regions
for handler_class, region_name in self._handlers.items():
region = getattr(self, f"{region_name}_region")
region = self._regions[region_name]
handler_instance = handler_class(region)
setattr(self, handler_class.handler_name, handler_instance)
handler_name = getattr(handler_class, 'handler_name', None)
if handler_name:
app.logger.debug(f"{handler_name} is registered")
setattr(self, handler_name, handler_instance)
app.logger.info('Cache regions initialized: ' + ', '.join(self._regions.keys()))
def register_handler(self, handler_class: Type[CacheHandler], region: str):
"""Register a cache handler class with its region"""

View File

@@ -1,7 +1,6 @@
# common/utils/cache/regions.py
from dogpile.cache import make_region
from flask import current_app
from urllib.parse import urlparse
import os
@@ -36,27 +35,31 @@ def get_redis_config(app):
def create_cache_regions(app):
"""Initialize all cache regions with app config"""
redis_config = get_redis_config(app)
regions = {}
# Region for model-related caching (ModelVariables etc)
model_region = make_region(name='model').configure(
model_region = make_region(name='eveai_model').configure(
'dogpile.cache.redis',
arguments=redis_config,
replace_existing_backend=True
)
regions['eveai_model'] = model_region
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
eveai_chat_workers_region = make_region(name='chat_workers').configure(
eveai_chat_workers_region = make_region(name='eveai_chat_workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
replace_existing_backend=True
)
regions['eveai_chat_workers'] = eveai_chat_workers_region
# Region for eveai_workers components (Processors, ...)
eveai_workers_region = make_region(name='workers').configure(
eveai_workers_region = make_region(name='eveai_workers').configure(
'dogpile.cache.redis',
arguments=redis_config, # Same config for now
replace_existing_backend=True
)
regions['eveai_workers'] = eveai_workers_region
return model_region, eveai_chat_workers_region, eveai_workers_region
return regions

View File

@@ -1,4 +1,6 @@
from flask import request, current_app, session
from flask_jwt_extended import decode_token, verify_jwt_in_request, get_jwt_identity
from common.models.user import Tenant, TenantDomain
@@ -23,31 +25,45 @@ def cors_after_request(response, prefix):
response.headers.add('Access-Control-Allow-Methods', '*')
return response
# Handle OPTIONS preflight requests
if request.method == 'OPTIONS':
response.headers.add('Access-Control-Allow-Origin', '*')
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization,X-Tenant-ID')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
return response
tenant_id = None
allowed_origins = []
# Try to get tenant_id from JSON payload
json_data = request.get_json(silent=True)
if json_data and 'tenant_id' in json_data:
tenant_id = json_data['tenant_id']
# Check Socket.IO connection
if 'socket.io' in request.path:
token = request.args.get('token')
if token:
try:
decoded = decode_token(token)
tenant_id = decoded['sub']
except Exception as e:
current_app.logger.error(f'Error decoding token: {e}')
return response
else:
# Fallback to get tenant_id from query parameters or headers if JSON is not available
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
# Regular API requests
try:
if verify_jwt_in_request(optional=True):
tenant_id = get_jwt_identity()
except Exception as e:
current_app.logger.error(f'Error verifying JWT: {e}')
return response
if tenant_id:
origin = request.headers.get('Origin')
allowed_origins = get_allowed_origins(tenant_id)
else:
current_app.logger.warning('tenant_id not found in request')
origin = request.headers.get('Origin')
if origin in allowed_origins:
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
else:
current_app.logger.warning(f'Origin {origin} not allowed')
if origin in allowed_origins:
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
return response

View File

@@ -118,3 +118,10 @@ class EveAIInvalidDocumentVersion(EveAIException):
# Construct the message dynamically
message = f"Tenant with ID '{tenant_id}' has no document version with ID {document_version_id}."
super().__init__(message, status_code, payload)
class EveAISocketInputException(EveAIException):
"""Raised when a socket call receives an invalid payload"""
def __init__(self, message, status_code=400, payload=None):
super.__init__(message, status_code, payload)

View File

@@ -252,7 +252,7 @@ class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
# Register the handler with the cache manager
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
cache_manager.register_handler(ModelVariablesCacheHandler, 'eveai_model')
# Helper function to get cached model variables

View File

@@ -0,0 +1,60 @@
from dataclasses import dataclass
from typing import Optional
from datetime import datetime
from flask_jwt_extended import decode_token, verify_jwt_in_request
from flask import current_app
@dataclass
class TokenValidationResult:
"""Clean, simple validation result"""
is_valid: bool
tenant_id: Optional[int] = None
error_message: Optional[str] = None
class TokenValidator:
"""Simplified token validator focused on JWT validation"""
def validate_token(self, token: str) -> TokenValidationResult:
"""
Validate JWT token
Args:
token: The JWT token to validate
Returns:
TokenValidationResult with validation status and tenant_id if valid
"""
try:
# Decode and validate token
decoded_token = decode_token(token)
# Extract tenant_id from token subject
tenant_id = decoded_token.get('sub')
if not tenant_id:
return TokenValidationResult(
is_valid=False,
error_message="Missing tenant ID in token"
)
# Verify token timestamps
now = datetime.utcnow().timestamp()
if not (decoded_token.get('exp', 0) > now >= decoded_token.get('nbf', 0)):
return TokenValidationResult(
is_valid=False,
error_message="Token expired or not yet valid"
)
# Token is valid
return TokenValidationResult(
is_valid=True,
tenant_id=tenant_id
)
except Exception as e:
current_app.logger.error(f"Token validation error: {str(e)}")
return TokenValidationResult(
is_valid=False,
error_message=str(e)
)