- Created a new eveai_chat plugin to support the new dynamic possibilities of the Specialists. Currently only supports standard Rag retrievers (i.e. no extra arguments).
This commit is contained in:
19
common/utils/cache/eveai_cache_manager.py
vendored
19
common/utils/cache/eveai_cache_manager.py
vendored
@@ -9,21 +9,28 @@ class EveAICacheManager:
|
||||
"""Cache manager with registration capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_region = None
|
||||
self.eveai_chat_workers_region = None
|
||||
self.eveai_workers_region = None
|
||||
self._regions = {}
|
||||
self._handlers = {}
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
"""Initialize cache regions"""
|
||||
from common.utils.cache.regions import create_cache_regions
|
||||
self.model_region, self.eveai_chat_workers_region, self.eveai_workers_region = create_cache_regions(app)
|
||||
self._regions = create_cache_regions(app)
|
||||
|
||||
# Store regions in instance
|
||||
for region_name, region in self._regions.items():
|
||||
setattr(self, f"{region_name}_region", region)
|
||||
|
||||
# Initialize all registered handlers with their regions
|
||||
for handler_class, region_name in self._handlers.items():
|
||||
region = getattr(self, f"{region_name}_region")
|
||||
region = self._regions[region_name]
|
||||
handler_instance = handler_class(region)
|
||||
setattr(self, handler_class.handler_name, handler_instance)
|
||||
handler_name = getattr(handler_class, 'handler_name', None)
|
||||
if handler_name:
|
||||
app.logger.debug(f"{handler_name} is registered")
|
||||
setattr(self, handler_name, handler_instance)
|
||||
|
||||
app.logger.info('Cache regions initialized: ' + ', '.join(self._regions.keys()))
|
||||
|
||||
def register_handler(self, handler_class: Type[CacheHandler], region: str):
|
||||
"""Register a cache handler class with its region"""
|
||||
|
||||
13
common/utils/cache/regions.py
vendored
13
common/utils/cache/regions.py
vendored
@@ -1,7 +1,6 @@
|
||||
# common/utils/cache/regions.py
|
||||
|
||||
from dogpile.cache import make_region
|
||||
from flask import current_app
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
|
||||
@@ -36,27 +35,31 @@ def get_redis_config(app):
|
||||
def create_cache_regions(app):
|
||||
"""Initialize all cache regions with app config"""
|
||||
redis_config = get_redis_config(app)
|
||||
regions = {}
|
||||
|
||||
# Region for model-related caching (ModelVariables etc)
|
||||
model_region = make_region(name='model').configure(
|
||||
model_region = make_region(name='eveai_model').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config,
|
||||
replace_existing_backend=True
|
||||
)
|
||||
regions['eveai_model'] = model_region
|
||||
|
||||
# Region for eveai_chat_workers components (Specialists, Retrievers, ...)
|
||||
eveai_chat_workers_region = make_region(name='chat_workers').configure(
|
||||
eveai_chat_workers_region = make_region(name='eveai_chat_workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # arguments={**redis_config, 'db': 4}, # Different DB
|
||||
replace_existing_backend=True
|
||||
)
|
||||
regions['eveai_chat_workers'] = eveai_chat_workers_region
|
||||
|
||||
# Region for eveai_workers components (Processors, ...)
|
||||
eveai_workers_region = make_region(name='workers').configure(
|
||||
eveai_workers_region = make_region(name='eveai_workers').configure(
|
||||
'dogpile.cache.redis',
|
||||
arguments=redis_config, # Same config for now
|
||||
replace_existing_backend=True
|
||||
)
|
||||
regions['eveai_workers'] = eveai_workers_region
|
||||
|
||||
return model_region, eveai_chat_workers_region, eveai_workers_region
|
||||
return regions
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from flask import request, current_app, session
|
||||
from flask_jwt_extended import decode_token, verify_jwt_in_request, get_jwt_identity
|
||||
|
||||
from common.models.user import Tenant, TenantDomain
|
||||
|
||||
|
||||
@@ -23,31 +25,45 @@ def cors_after_request(response, prefix):
|
||||
response.headers.add('Access-Control-Allow-Methods', '*')
|
||||
return response
|
||||
|
||||
# Handle OPTIONS preflight requests
|
||||
if request.method == 'OPTIONS':
|
||||
response.headers.add('Access-Control-Allow-Origin', '*')
|
||||
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization,X-Tenant-ID')
|
||||
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
|
||||
response.headers.add('Access-Control-Allow-Credentials', 'true')
|
||||
return response
|
||||
|
||||
tenant_id = None
|
||||
allowed_origins = []
|
||||
|
||||
# Try to get tenant_id from JSON payload
|
||||
json_data = request.get_json(silent=True)
|
||||
|
||||
if json_data and 'tenant_id' in json_data:
|
||||
tenant_id = json_data['tenant_id']
|
||||
# Check Socket.IO connection
|
||||
if 'socket.io' in request.path:
|
||||
token = request.args.get('token')
|
||||
if token:
|
||||
try:
|
||||
decoded = decode_token(token)
|
||||
tenant_id = decoded['sub']
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'Error decoding token: {e}')
|
||||
return response
|
||||
else:
|
||||
# Fallback to get tenant_id from query parameters or headers if JSON is not available
|
||||
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
|
||||
# Regular API requests
|
||||
try:
|
||||
if verify_jwt_in_request(optional=True):
|
||||
tenant_id = get_jwt_identity()
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'Error verifying JWT: {e}')
|
||||
return response
|
||||
|
||||
if tenant_id:
|
||||
origin = request.headers.get('Origin')
|
||||
allowed_origins = get_allowed_origins(tenant_id)
|
||||
else:
|
||||
current_app.logger.warning('tenant_id not found in request')
|
||||
|
||||
origin = request.headers.get('Origin')
|
||||
if origin in allowed_origins:
|
||||
response.headers.add('Access-Control-Allow-Origin', origin)
|
||||
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
|
||||
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
|
||||
response.headers.add('Access-Control-Allow-Credentials', 'true')
|
||||
else:
|
||||
current_app.logger.warning(f'Origin {origin} not allowed')
|
||||
if origin in allowed_origins:
|
||||
response.headers.add('Access-Control-Allow-Origin', origin)
|
||||
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
|
||||
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
|
||||
response.headers.add('Access-Control-Allow-Credentials', 'true')
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -118,3 +118,10 @@ class EveAIInvalidDocumentVersion(EveAIException):
|
||||
# Construct the message dynamically
|
||||
message = f"Tenant with ID '{tenant_id}' has no document version with ID {document_version_id}."
|
||||
super().__init__(message, status_code, payload)
|
||||
|
||||
|
||||
class EveAISocketInputException(EveAIException):
|
||||
"""Raised when a socket call receives an invalid payload"""
|
||||
|
||||
def __init__(self, message, status_code=400, payload=None):
|
||||
super.__init__(message, status_code, payload)
|
||||
@@ -252,7 +252,7 @@ class ModelVariablesCacheHandler(CacheHandler[ModelVariables]):
|
||||
|
||||
|
||||
# Register the handler with the cache manager
|
||||
cache_manager.register_handler(ModelVariablesCacheHandler, 'model')
|
||||
cache_manager.register_handler(ModelVariablesCacheHandler, 'eveai_model')
|
||||
|
||||
|
||||
# Helper function to get cached model variables
|
||||
|
||||
60
common/utils/token_validation.py
Normal file
60
common/utils/token_validation.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from flask_jwt_extended import decode_token, verify_jwt_in_request
|
||||
from flask import current_app
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenValidationResult:
|
||||
"""Clean, simple validation result"""
|
||||
is_valid: bool
|
||||
tenant_id: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class TokenValidator:
|
||||
"""Simplified token validator focused on JWT validation"""
|
||||
|
||||
def validate_token(self, token: str) -> TokenValidationResult:
|
||||
"""
|
||||
Validate JWT token
|
||||
|
||||
Args:
|
||||
token: The JWT token to validate
|
||||
|
||||
Returns:
|
||||
TokenValidationResult with validation status and tenant_id if valid
|
||||
"""
|
||||
try:
|
||||
# Decode and validate token
|
||||
decoded_token = decode_token(token)
|
||||
|
||||
# Extract tenant_id from token subject
|
||||
tenant_id = decoded_token.get('sub')
|
||||
if not tenant_id:
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Missing tenant ID in token"
|
||||
)
|
||||
|
||||
# Verify token timestamps
|
||||
now = datetime.utcnow().timestamp()
|
||||
if not (decoded_token.get('exp', 0) > now >= decoded_token.get('nbf', 0)):
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Token expired or not yet valid"
|
||||
)
|
||||
|
||||
# Token is valid
|
||||
return TokenValidationResult(
|
||||
is_valid=True,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Token validation error: {str(e)}")
|
||||
return TokenValidationResult(
|
||||
is_valid=False,
|
||||
error_message=str(e)
|
||||
)
|
||||
Reference in New Issue
Block a user