API key working, CORS working, SocketIO working (but no JWT), Chat client v1, Session implemented (server side)

This commit is contained in:
Josako
2024-05-22 21:32:09 +02:00
parent 883988dbab
commit 364da812ba
21 changed files with 763 additions and 69 deletions

View File

@@ -6,6 +6,8 @@ from flask_mailman import Mail
from flask_login import LoginManager
from flask_cors import CORS
from flask_socketio import SocketIO
from flask_jwt_extended import JWTManager
from flask_session import Session
from .utils.key_encryption import JosKMSClient
@@ -18,4 +20,7 @@ mail = Mail()
login_manager = LoginManager()
cors = CORS()
socketio = SocketIO()
kms_client = JosKMSClient()
jwt = JWTManager()
session = Session()
kms_client = JosKMSClient.from_service_account_json('config/gc_sa_eveai.json')

View File

@@ -1,16 +1,16 @@
from ..extensions import db
from .user import User, Tenant
from .document import Embedding
class ChatSession(db.Model):
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
user_id = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
session_start = db.Column(db.DateTime, nullable=False)
session_end = db.Column(db.DateTime, nullable=True)
# Relations
chat_interactions = db.relationship('Interaction', backref='chat_session', lazy=True)
user = db.relationship('User', backref='chat_sessions', lazy=True)
interactions = db.relationship('Interaction', backref='chat_session', lazy=True)
def __repr__(self):
return f"<ChatSession {self.id} by {self.user_id}>"
@@ -18,7 +18,7 @@ class ChatSession(db.Model):
class Interaction(db.Model):
id = db.Column(db.Integer, primary_key=True)
chat_session_id = db.Column(db.Integer, db.ForeignKey('public.chat_session.id'), nullable=False)
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
question = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True)
language = db.Column(db.String(2), nullable=False)
@@ -33,5 +33,5 @@ class Interaction(db.Model):
class InteractionEmbedding(db.Model):
interaction_id = db.Column(db.Integer, db.ForeignKey('interaction.id', ondelete='CASCADE'), primary_key=True)
embedding_id = db.Column(db.Integer, db.ForeignKey('embedding.id', ondelete='CASCADE'), primary_key=True)
interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)

View File

@@ -0,0 +1,75 @@
from flask import request, current_app, session
from common.models.user import Tenant, TenantDomain
def get_allowed_origins(tenant_id):
session_key = f"allowed_origins_{tenant_id}"
if session_key in session:
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
return session[session_key]
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
allowed_origins = [domain.domain for domain in tenant_domains]
# Cache the result in the session
session[session_key] = allowed_origins
return allowed_origins
def cors_after_request(response, prefix):
current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
current_app.logger.debug(f'request.headers: {request.headers}')
current_app.logger.debug(f'request.args: {request.args}')
current_app.logger.debug(f'request is json?: {request.is_json}')
tenant_id = None
allowed_origins = []
# Try to get tenant_id from JSON payload
json_data = request.get_json(silent=True)
current_app.logger.debug(f'request.get_json(silent=True): {json_data}')
if json_data and 'tenant_id' in json_data:
tenant_id = json_data['tenant_id']
else:
# Fallback to get tenant_id from query parameters or headers if JSON is not available
tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')
current_app.logger.debug(f'Identified tenant_id: {tenant_id}')
if tenant_id:
allowed_origins = get_allowed_origins(tenant_id)
current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
else:
current_app.logger.warning('tenant_id not found in request')
origin = request.headers.get('Origin')
current_app.logger.debug(f'Origin: {origin}')
if origin in allowed_origins:
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
current_app.logger.debug(f'CORS headers set for origin: {origin}')
else:
current_app.logger.warning(f'Origin {origin} not allowed')
return response
def create_cors_after_request(prefix):
def wrapped_cors_after_request(response):
return cors_after_request(response, prefix)
return wrapped_cors_after_request
def create_multiple_cors_after_requests(prefixes):
def wrapped_cors_after_requests(response):
for prefix, cors_function in prefixes:
response = cors_function(response)
return response
return wrapped_cors_after_requests

View File

@@ -1,9 +1,11 @@
from google.cloud import kms
from google.cloud import kms_v1
from base64 import b64encode, b64decode
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import random
import time
from flask import Flask
import os
def generate_api_key(prefix="EveAI-Chat"):
@@ -11,7 +13,7 @@ def generate_api_key(prefix="EveAI-Chat"):
return f"{prefix}-{'-'.join(parts)}"
class JosKMSClient(kms.KeyManagementServiceClient):
class JosKMSClient(kms_v1.KeyManagementServiceClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = None
@@ -26,18 +28,36 @@ class JosKMSClient(kms.KeyManagementServiceClient):
self.key_ring = app.config.get('GC_KEY_RING')
self.crypto_key = app.config.get('GC_CRYPTO_KEY')
self.key_name = self.crypto_key_path(self.project_id, self.location, self.key_ring, self.crypto_key)
app.logger.info(f'Project ID: {self.project_id}')
app.logger.info(f'Location: {self.location}')
app.logger.info(f'Key Ring: {self.key_ring}')
app.logger.info(f'Crypto Key: {self.crypto_key}')
app.logger.info(f'Key Name: {self.key_name}')
app.logger.info(f'Service Account Key Path: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}')
os.environ["GOOGLE_CLOUD_PROJECT"] = self.project_id
def encrypt_api_key(self, api_key):
"""Encrypts the API key using the latest version of the KEK."""
dek = get_random_bytes(32) # AES 256-bit key
cipher = AES.new(dek, AES.MODE_GCM)
ciphertext, tag = cipher.encrypt_and_digest(api_key.encode())
# print(f'Dek: {dek}')
# Encrypt the DEK using the latest version of the Google Cloud KMS key
encrypt_response = self.encrypt(
request={'name': self.key_name, 'plaintext': dek}
)
encrypted_dek = encrypt_response.ciphertext
# print(f"Encrypted DEK: {encrypted_dek}")
#
# # Check
# decrypt_response = self.decrypt(
# request={'name': self.key_name, 'ciphertext': encrypted_dek}
# )
# decrypted_dek = decrypt_response.plaintext
# print(f"Decrypted DEK: {decrypted_dek}")
# Store the version of the key used
key_version = encrypt_response.name
@@ -53,17 +73,35 @@ class JosKMSClient(kms.KeyManagementServiceClient):
def decrypt_api_key(self, encrypted_data):
"""Decrypts the API key using the specified key version."""
key_version = encrypted_data['key_version']
encrypted_dek = b64decode(encrypted_data['encrypted_dek'])
nonce = b64decode(encrypted_data['nonce'])
tag = b64decode(encrypted_data['tag'])
ciphertext = b64decode(encrypted_data['ciphertext'])
key_name = self.key_name
encrypted_dek = b64decode(encrypted_data['encrypted_dek'].encode('utf-8'))
nonce = b64decode(encrypted_data['nonce'].encode('utf-8'))
tag = b64decode(encrypted_data['tag'].encode('utf-8'))
ciphertext = b64decode(encrypted_data['ciphertext'].encode('utf-8'))
# Decrypt the DEK using the specified version of the Google Cloud KMS key
decrypt_response = self.decrypt(
request={'name': key_version, 'ciphertext': encrypted_dek}
)
dek = decrypt_response.plaintext
try:
decrypt_response = self.decrypt(
request={'name': key_name, 'ciphertext': encrypted_dek}
)
dek = decrypt_response.plaintext
except Exception as e:
print(f"Failed to decrypt DEK: {e}")
return None
cipher = AES.new(dek, AES.MODE_GCM, nonce=nonce)
api_key = cipher.decrypt_and_verify(ciphertext, tag)
return api_key.decode()
def check_kms_access_and_latency(self):
# key_name = self.crypto_key_path(self.project_id, self.location, self.key_ring, self.crypto_key)
#
# start_time = time.time()
# try:
# response = self.get_crypto_key(name=key_name)
# end_time = time.time()
# print(f"Response Time: {end_time - start_time} seconds")
# print("Access to KMS is successful.")
# except Exception as e:
# print(f"Failed to access KMS: {e}")
pass