- Improvements to enable deployment in the cloud, mainly changing file access to Minio

- Improvements on RAG logging, and some debugging in that area
This commit is contained in:
Josako
2024-08-01 17:35:54 +02:00
parent 88ca04136d
commit 64cf8df3a9
19 changed files with 617 additions and 206 deletions

View File

@@ -11,6 +11,7 @@ from flask_session import Session
from flask_wtf import CSRFProtect
from .utils.key_encryption import JosKMSClient
from .utils.minio_utils import MinioClient
# Create extensions
db = SQLAlchemy()
@@ -26,3 +27,4 @@ jwt = JWTManager()
session = Session()
kms_client = JosKMSClient.from_service_account_json('config/gc_sa_eveai.json')
minio_client = MinioClient()

View File

@@ -1,5 +1,5 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_
from sqlalchemy import func, and_, or_, desc
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field
from typing import Any, Dict
@@ -20,12 +20,56 @@ class EveAIRetriever(BaseRetriever):
self.tenant_info = tenant_info
def _get_relevant_documents(self, query: str):
current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
query_embedding = self._get_query_embedding(query)
db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k']
if self.tenant_info['rag_tuning']:
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')
# Debug query to show similarity for all valid documents (without chunk text)
debug_query = (
db.session.query(
Document.id.label('document_id'),
DocumentVersion.id.label('version_id'),
db_class.id.label('embedding_id'),
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.filter(
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
)
.order_by(desc('similarity'))
)
debug_results = debug_query.all()
current_app.logger.debug("Debug: Similarity for all valid documents:")
for row in debug_results:
current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
f"Version ID: {row.version_id}, "
f"Embedding ID: {row.embedding_id}, "
f"Similarity: {row.similarity}")
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
except SQLAlchemyError as e:
current_app.logger.error(f'Error generating overview: {e}')
db.session.rollback()
if self.tenant_info['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
current_app.rag_tuning_logger.debug(f'K: {k}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
try:
current_date = get_date_in_timezone(self.tenant_info['timezone'])
# Subquery to find the latest version of each document
@@ -40,24 +84,31 @@ class EveAIRetriever(BaseRetriever):
# Main query to filter embeddings
query_obj = (
db.session.query(db_class,
db_class.embedding.cosine_distance(query_embedding).label('distance'))
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), Document.valid_from <= current_date),
or_(Document.valid_to.is_(None), Document.valid_to >= current_date),
db_class.embedding.cosine_distance(query_embedding) < similarity_threshold
or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
(1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
)
.order_by('distance')
.order_by(desc('similarity'))
.limit(k)
)
if self.tenant_info['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
res = query_obj.all()
if self.tenant_info['rag_tuning']:
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
current_app.rag_tuning_logger.debug(f'---------------------------------------')
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
current_app.rag_tuning_logger.debug(f'{res}\n')
current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
result = []
for doc in res:

View File

@@ -82,7 +82,6 @@ class Tenant(db.Model):
'html_excluded_elements': self.html_excluded_elements,
'min_chunk_size': self.min_chunk_size,
'max_chunk_size': self.max_chunk_size,
'es_k'
'es_k': self.es_k,
'es_similarity_threshold': self.es_similarity_threshold,
'chat_RAG_temperature': self.chat_RAG_temperature,

View File

@@ -0,0 +1,86 @@
from minio import Minio
from minio.error import S3Error
from flask import Flask
import io
from werkzeug.datastructures import FileStorage
class MinioClient:
def __init__(self):
self.client = None
def init_app(self, app: Flask):
self.client = Minio(
app.config['MINIO_ENDPOINT'],
access_key=app.config['MINIO_ACCESS_KEY'],
secret_key=app.config['MINIO_SECRET_KEY'],
secure=app.config.get('MINIO_USE_HTTPS', False)
)
app.logger.info(f"MinIO client initialized with endpoint: {app.config['MINIO_ENDPOINT']}")
def generate_bucket_name(self, tenant_id):
return f"tenant-{tenant_id}-bucket"
def create_tenant_bucket(self, tenant_id):
bucket_name = self.generate_bucket_name(tenant_id)
try:
if not self.client.bucket_exists(bucket_name):
self.client.make_bucket(bucket_name)
return bucket_name
return bucket_name
except S3Error as err:
raise Exception(f"Error occurred while creating bucket: {err}")
def generate_object_name(self, document_id, language, version_id, filename):
return f"{document_id}/{language}/{version_id}/{filename}"
def upload_document_file(self, tenant_id, document_id, language, version_id, filename, file_data):
bucket_name = self.generate_bucket_name(tenant_id)
object_name = self.generate_object_name(document_id, language, version_id, filename)
try:
if isinstance(file_data, FileStorage):
file_data = file_data.read()
elif isinstance(file_data, io.BytesIO):
file_data = file_data.getvalue()
elif isinstance(file_data, str):
file_data = file_data.encode('utf-8')
elif not isinstance(file_data, bytes):
raise TypeError('Unsupported file type. Expected FileStorage, BytesIO, str, or bytes.')
self.client.put_object(
bucket_name, object_name, io.BytesIO(file_data), len(file_data)
)
return True
except S3Error as err:
raise Exception(f"Error occurred while uploading file: {err}")
def download_document_file(self, tenant_id, document_id, language, version_id, filename):
bucket_name = self.generate_bucket_name(tenant_id)
object_name = self.generate_object_name(document_id, language, version_id, filename)
try:
response = self.client.get_object(bucket_name, object_name)
return response.read()
except S3Error as err:
raise Exception(f"Error occurred while downloading file: {err}")
def list_document_files(self, tenant_id, document_id, language=None, version_id=None):
bucket_name = self.generate_bucket_name(tenant_id)
prefix = f"{document_id}/"
if language:
prefix += f"{language}/"
if version_id:
prefix += f"{version_id}/"
try:
objects = self.client.list_objects(bucket_name, prefix=prefix, recursive=True)
return [obj.object_name for obj in objects]
except S3Error as err:
raise Exception(f"Error occurred while listing files: {err}")
def delete_document_file(self, tenant_id, document_id, language, version_id, filename):
bucket_name = self.generate_bucket_name(tenant_id)
object_name = self.generate_object_name(document_id, language, version_id, filename)
try:
self.client.remove_object(bucket_name, object_name)
return True
except S3Error as err:
raise Exception(f"Error occurred while deleting file: {err}")

View File

@@ -141,7 +141,7 @@ def select_model_variables(tenant):
default_headers=portkey_headers)
tool_calling_supported = False
match llm_model:
case 'gpt-4-turbo' | 'gpt-4o':
case 'gpt-4-turbo' | 'gpt-4o' | 'gpt-4o-mini':
tool_calling_supported = True
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '