13 Commits

Author SHA1 Message Date
Josako
9e14824249 - Furter refinement of the API, adding functionality for refreshing documents and returning Token expiration time when retrieving token
- Implementation of a first version of a Wordpress plugin
- Adding api service to nginx.conf
2024-09-11 16:31:13 +02:00
Josako
76cb825660 - Full API application, streamlined, de-duplication of document handling code into document_utils.py
- Added meta-data fields to DocumentVersion
- Docker container to support API
2024-09-09 16:11:42 +02:00
Josako
341ba47d1c - Bugfixing 2024-09-05 14:31:54 +02:00
Josako
1fa33c029b - Correcting mistakes in tenant schema migrations 2024-09-03 11:50:25 +02:00
Josako
bcf7d439f3 - Old migration files that were not added to GIT 2024-09-03 11:49:46 +02:00
Josako
b9acf4d2ae - Add CHANGELOG.md 2024-09-02 14:04:44 +02:00
Josako
ae7bf3dbae - Correct default language when adding Documents and URLs 2024-09-02 14:04:22 +02:00
Josako
914c265afe - Improvements on document uploads (accept other files than html-files when entering a URL)
- Introduction of API-functionality (to be continued). Deduplication of document and url uploads between views and api.
- Improvements on document processing - introduction of processor classes to streamline document inputs
- Removed pure Youtube functionality, as Youtube retrieval of documents continuously changes. But added upload of srt, mp3, ogg and mp4
2024-09-02 12:37:44 +02:00
Josako
a158655247 - Add API Key Registration to tenant 2024-08-29 10:42:39 +02:00
Josako
bc350af247 - Allow the chat-widget to connect to multiple servers (e.g. development and production)
- Created a full session overview
2024-08-28 10:11:31 +02:00
Josako
6062b7646c - Allow multiple instances of Evie on 1 website. Shortcode is now parametrized. 2024-08-27 10:31:33 +02:00
Josako
122d1a18df - Allow for more complex and longer PDFs to be uploaded to Evie. First implmentation of a processor for specific file types.
- Allow URLs to contain other information than just HTML information. It can alose refer to e.g. PDF-files.
2024-08-27 07:05:56 +02:00
Josako
2ca006d82c Added excluded element classes to HTML parsing to allow for more complex document parsing
Added chunking to conversion of HTML to markdown in case of large files
2024-08-22 16:41:13 +02:00
87 changed files with 3777 additions and 7714 deletions

29
.gitignore vendored
View File

@@ -12,3 +12,32 @@ docker/tenant_files/
**/.DS_Store **/.DS_Store
__pycache__ __pycache__
**/__pycache__ **/__pycache__
/.idea
*.pyc
*.pyc
common/.DS_Store
common/__pycache__/__init__.cpython-312.pyc
common/__pycache__/extensions.cpython-312.pyc
common/models/__pycache__/__init__.cpython-312.pyc
common/models/__pycache__/document.cpython-312.pyc
common/models/__pycache__/interaction.cpython-312.pyc
common/models/__pycache__/user.cpython-312.pyc
common/utils/.DS_Store
common/utils/__pycache__/__init__.cpython-312.pyc
common/utils/__pycache__/celery_utils.cpython-312.pyc
common/utils/__pycache__/nginx_utils.cpython-312.pyc
common/utils/__pycache__/security.cpython-312.pyc
common/utils/__pycache__/simple_encryption.cpython-312.pyc
common/utils/__pycache__/template_filters.cpython-312.pyc
config/.DS_Store
config/__pycache__/__init__.cpython-312.pyc
config/__pycache__/config.cpython-312.pyc
config/__pycache__/logging_config.cpython-312.pyc
eveai_app/.DS_Store
eveai_app/__pycache__/__init__.cpython-312.pyc
eveai_app/__pycache__/errors.cpython-312.pyc
eveai_chat/.DS_Store
migrations/.DS_Store
migrations/public/.DS_Store
scripts/.DS_Store
scripts/__pycache__/run_eveai_app.cpython-312.pyc

87
CHANGELOG.md Normal file
View File

@@ -0,0 +1,87 @@
# Changelog
All notable changes to EveAI will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- For new features.
### Changed
- For changes in existing functionality.
### Deprecated
- For soon-to-be removed features.
### Removed
- For now removed features.
### Fixed
- Set default language when registering Documents or URLs.
### Security
- In case of vulnerabilities.
-
## [1.0.6-alfa] - 2024-09-03
### Fixed
- Problems with tenant scheme migrations - may have to be revisited
- Correction of default language settings when uploading docs or URLs
- Addition of a CHANGELOG.md file
## [1.0.5-alfa] - 2024-09-02
### Added
- Allow chatwidget to connect to multiple servers (e.g. development and production)
- Start implementation of API
- Add API-key functionality to tenants
- Deduplication of API and Document view code
- Allow URL addition to accept all types of files, not just HTML
- Allow new file types upload: srt, mp3, ogg, mp4
- Improve processing of different file types using Processor classes
### Removed
- Removed direct upload of Youtube URLs, due to continuous changes in Youtube website
## [1.0.4-alfa] - 2024-08-27
Skipped
## [1.0.3-alfa] - 2024-08-27
### Added
- Refinement of HTML processing - allow for excluded classes and elements.
- Allow for multiple instances of Evie on 1 website (pure + Wordpress plugin)
### Changed
- PDF Processing extracted in new PDF Processor class.
- Allow for longer and more complex PDFs to be uploaded.
## [1.0.2-alfa] - 2024-08-22
### Fixed
- Bugfix for ResetPasswordForm in config.py
## [1.0.1-alfa] - 2024-08-21
### Added
- Full Document Version Overview
### Changed
- Improvements to user creation and registration, renewal of passwords, ...
## [1.0.0-alfa] - 2024-08-16
### Added
- Initial release of the project.
### Changed
- None
### Fixed
- None
[Unreleased]: https://github.com/username/repo/compare/v1.0.0...HEAD
[1.0.0]: https://github.com/username/repo/releases/tag/v1.0.0

View File

@@ -9,6 +9,7 @@ from flask_socketio import SocketIO
from flask_jwt_extended import JWTManager from flask_jwt_extended import JWTManager
from flask_session import Session from flask_session import Session
from flask_wtf import CSRFProtect from flask_wtf import CSRFProtect
from flask_restx import Api
from .utils.nginx_utils import prefixed_url_for from .utils.nginx_utils import prefixed_url_for
from .utils.simple_encryption import SimpleEncryption from .utils.simple_encryption import SimpleEncryption
@@ -27,8 +28,6 @@ cors = CORS()
socketio = SocketIO() socketio = SocketIO()
jwt = JWTManager() jwt = JWTManager()
session = Session() session = Session()
api_rest = Api()
# kms_client = JosKMSClient.from_service_account_json('config/gc_sa_eveai.json')
simple_encryption = SimpleEncryption() simple_encryption = SimpleEncryption()
minio_client = MinioClient() minio_client = MinioClient()

2
common/models/README.txt Normal file
View File

@@ -0,0 +1,2 @@
If models are added to the public schema (i.e. in the user domain), ensure to add their corresponding tables to the
env.py, get_public_table_names, for tenant migrations!

View File

@@ -1,6 +1,7 @@
from common.extensions import db from common.extensions import db
from .user import User, Tenant from .user import User, Tenant
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import JSONB
class Document(db.Model): class Document(db.Model):
@@ -12,7 +13,7 @@ class Document(db.Model):
# Versioning Information # Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False) created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now()) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id)) updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
@@ -33,6 +34,8 @@ class DocumentVersion(db.Model):
language = db.Column(db.String(2), nullable=False) language = db.Column(db.String(2), nullable=False)
user_context = db.Column(db.Text, nullable=True) user_context = db.Column(db.Text, nullable=True)
system_context = db.Column(db.Text, nullable=True) system_context = db.Column(db.Text, nullable=True)
user_metadata = db.Column(JSONB, nullable=True)
system_metadata = db.Column(JSONB, nullable=True)
# Versioning Information # Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())

View File

@@ -35,10 +35,11 @@ class Tenant(db.Model):
html_end_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'li']) html_end_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'li'])
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True) html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True) html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_classes = db.Column(ARRAY(sa.String(200)), nullable=True)
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000) min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000) max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
# Embedding search variables # Embedding search variables
es_k = db.Column(db.Integer, nullable=True, default=5) es_k = db.Column(db.Integer, nullable=True, default=5)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7) es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7)
@@ -53,6 +54,8 @@ class Tenant(db.Model):
license_end_date = db.Column(db.Date, nullable=True) license_end_date = db.Column(db.Date, nullable=True)
allowed_monthly_interactions = db.Column(db.Integer, nullable=True) allowed_monthly_interactions = db.Column(db.Integer, nullable=True)
encrypted_chat_api_key = db.Column(db.String(500), nullable=True) encrypted_chat_api_key = db.Column(db.String(500), nullable=True)
encrypted_api_key = db.Column(db.String(500), nullable=True)
# Tuning enablers # Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False) embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
@@ -80,6 +83,7 @@ class Tenant(db.Model):
'html_end_tags': self.html_end_tags, 'html_end_tags': self.html_end_tags,
'html_included_elements': self.html_included_elements, 'html_included_elements': self.html_included_elements,
'html_excluded_elements': self.html_excluded_elements, 'html_excluded_elements': self.html_excluded_elements,
'html_excluded_classes': self.html_excluded_classes,
'min_chunk_size': self.min_chunk_size, 'min_chunk_size': self.min_chunk_size,
'max_chunk_size': self.max_chunk_size, 'max_chunk_size': self.max_chunk_size,
'es_k': self.es_k, 'es_k': self.es_k,

View File

@@ -0,0 +1,338 @@
from datetime import datetime as dt, timezone as tz
from sqlalchemy import desc
from sqlalchemy.exc import SQLAlchemyError
from werkzeug.utils import secure_filename
from common.models.document import Document, DocumentVersion
from common.extensions import db, minio_client
from common.utils.celery_utils import current_celery
from flask import current_app
from flask_security import current_user
import requests
from urllib.parse import urlparse, unquote
import os
from .eveai_exceptions import EveAIInvalidLanguageException, EveAIDoubleURLException, EveAIUnsupportedFileType
def create_document_stack(api_input, file, filename, extension, tenant_id):
# Create the Document
new_doc = create_document(api_input, filename, tenant_id)
db.session.add(new_doc)
# Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc,
api_input.get('url', ''),
api_input.get('language', 'en'),
api_input.get('user_context', ''),
api_input.get('user_metadata'),
)
db.session.add(new_doc_vers)
try:
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'Error adding document for tenant {tenant_id}: {e}')
db.session.rollback()
raise
current_app.logger.info(f'Document added successfully for tenant {tenant_id}, '
f'Document Version {new_doc.id}')
# Upload file to storage
upload_file_for_version(new_doc_vers, file, extension, tenant_id)
return new_doc, new_doc_vers
def create_document(form, filename, tenant_id):
new_doc = Document()
if form['name'] == '':
new_doc.name = filename.rsplit('.', 1)[0]
else:
new_doc.name = form['name']
if form['valid_from'] and form['valid_from'] != '':
new_doc.valid_from = form['valid_from']
else:
new_doc.valid_from = dt.now(tz.utc)
new_doc.tenant_id = tenant_id
set_logging_information(new_doc, dt.now(tz.utc))
return new_doc
def create_version_for_document(document, url, language, user_context, user_metadata):
new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
if language == '':
raise EveAIInvalidLanguageException('Language is required for document creation!')
else:
new_doc_vers.language = language
if user_context != '':
new_doc_vers.user_context = user_context
if user_metadata != '' and user_metadata is not None:
new_doc_vers.user_metadata = user_metadata
new_doc_vers.document = document
set_logging_information(new_doc_vers, dt.now(tz.utc))
return new_doc_vers
def upload_file_for_version(doc_vers, file, extension, tenant_id):
doc_vers.file_type = extension
doc_vers.file_name = doc_vers.calc_file_name()
doc_vers.file_location = doc_vers.calc_file_location()
# Normally, the tenant bucket should exist. But let's be on the safe side if a migration took place.
minio_client.create_tenant_bucket(tenant_id)
try:
minio_client.upload_document_file(
tenant_id,
doc_vers.doc_id,
doc_vers.language,
doc_vers.id,
doc_vers.file_name,
file
)
db.session.commit()
current_app.logger.info(f'Successfully saved document to MinIO for tenant {tenant_id} for '
f'document version {doc_vers.id} while uploading file.')
except Exception as e:
db.session.rollback()
current_app.logger.error(
f'Error saving document to MinIO for tenant {tenant_id}: {e}')
raise
def set_logging_information(obj, timestamp):
obj.created_at = timestamp
obj.updated_at = timestamp
user_id = get_current_user_id()
if user_id:
obj.created_by = user_id
obj.updated_by = user_id
def update_logging_information(obj, timestamp):
obj.updated_at = timestamp
user_id = get_current_user_id()
if user_id:
obj.updated_by = user_id
def get_current_user_id():
try:
if current_user and current_user.is_authenticated:
return current_user.id
else:
return None
except Exception:
# This will catch any errors if current_user is not available (e.g., in API context)
return None
def get_extension_from_content_type(content_type):
content_type_map = {
'text/html': 'html',
'application/pdf': 'pdf',
'text/plain': 'txt',
'application/msword': 'doc',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'docx',
# Add more mappings as needed
}
return content_type_map.get(content_type, 'html') # Default to 'html' if unknown
def process_url(url, tenant_id):
response = requests.head(url, allow_redirects=True)
content_type = response.headers.get('Content-Type', '').split(';')[0]
# Determine file extension based on Content-Type
extension = get_extension_from_content_type(content_type)
# Generate filename
parsed_url = urlparse(url)
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename or '.' not in filename:
# Use the last part of the path or a default name
filename = path.strip('/').split('/')[-1] or 'document'
filename = secure_filename(f"{filename}.{extension}")
else:
filename = secure_filename(filename)
# Check if a document with this URL already exists
existing_doc = DocumentVersion.query.filter_by(url=url).first()
if existing_doc:
raise EveAIDoubleURLException
# Download the content
response = requests.get(url)
response.raise_for_status()
file_content = response.content
return file_content, filename, extension
def process_multiple_urls(urls, tenant_id, api_input):
results = []
for url in urls:
try:
file_content, filename, extension = process_url(url, tenant_id)
url_input = api_input.copy()
url_input.update({
'url': url,
'name': f"{api_input['name']}-{filename}" if api_input['name'] else filename
})
new_doc, new_doc_vers = create_document_stack(url_input, file_content, filename, extension, tenant_id)
task_id = start_embedding_task(tenant_id, new_doc_vers.id)
results.append({
'url': url,
'document_id': new_doc.id,
'document_version_id': new_doc_vers.id,
'task_id': task_id,
'status': 'success'
})
except Exception as e:
current_app.logger.error(f"Error processing URL {url}: {str(e)}")
results.append({
'url': url,
'status': 'error',
'message': str(e)
})
return results
def start_embedding_task(tenant_id, doc_vers_id):
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
tenant_id,
doc_vers_id,
])
current_app.logger.info(f'Embedding creation started for tenant {tenant_id}, '
f'Document Version {doc_vers_id}. '
f'Embedding creation task: {task.id}')
return task.id
def validate_file_type(extension):
current_app.logger.debug(f'Validating file type {extension}')
current_app.logger.debug(f'Supported file types: {current_app.config["SUPPORTED_FILE_TYPES"]}')
if extension not in current_app.config['SUPPORTED_FILE_TYPES']:
raise EveAIUnsupportedFileType(f"Filetype {extension} is currently not supported. "
f"Supported filetypes: {', '.join(current_app.config['SUPPORTED_FILE_TYPES'])}")
def get_filename_from_url(url):
parsed_url = urlparse(url)
path_parts = parsed_url.path.split('/')
filename = path_parts[-1]
if filename == '':
filename = 'index'
if not filename.endswith('.html'):
filename += '.html'
return filename
def get_documents_list(page, per_page):
query = Document.query.order_by(desc(Document.created_at))
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
return pagination
def edit_document(document_id, name, valid_from, valid_to):
doc = Document.query.get_or_404(document_id)
doc.name = name
doc.valid_from = valid_from
doc.valid_to = valid_to
update_logging_information(doc, dt.now(tz.utc))
try:
db.session.add(doc)
db.session.commit()
return doc, None
except SQLAlchemyError as e:
db.session.rollback()
return None, str(e)
def edit_document_version(version_id, user_context):
doc_vers = DocumentVersion.query.get_or_404(version_id)
doc_vers.user_context = user_context
update_logging_information(doc_vers, dt.now(tz.utc))
try:
db.session.add(doc_vers)
db.session.commit()
return doc_vers, None
except SQLAlchemyError as e:
db.session.rollback()
return None, str(e)
def refresh_document_with_info(doc_id, api_input):
doc = Document.query.get_or_404(doc_id)
old_doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first()
if not old_doc_vers.url:
return None, "This document has no URL. Only documents with a URL can be refreshed."
new_doc_vers = create_version_for_document(
doc,
old_doc_vers.url,
api_input.get('language', old_doc_vers.language),
api_input.get('user_context', old_doc_vers.user_context),
api_input.get('user_metadata', old_doc_vers.user_metadata)
)
set_logging_information(new_doc_vers, dt.now(tz.utc))
try:
db.session.add(new_doc_vers)
db.session.commit()
except SQLAlchemyError as e:
db.session.rollback()
return None, str(e)
response = requests.head(old_doc_vers.url, allow_redirects=True)
content_type = response.headers.get('Content-Type', '').split(';')[0]
extension = get_extension_from_content_type(content_type)
response = requests.get(old_doc_vers.url)
response.raise_for_status()
file_content = response.content
upload_file_for_version(new_doc_vers, file_content, extension, doc.tenant_id)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
doc.tenant_id,
new_doc_vers.id,
])
return new_doc_vers, task.id
# Update the existing refresh_document function to use the new refresh_document_with_info
def refresh_document(doc_id):
doc = Document.query.get_or_404(doc_id)
old_doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first()
api_input = {
'language': old_doc_vers.language,
'user_context': old_doc_vers.user_context,
'user_metadata': old_doc_vers.user_metadata
}
return refresh_document_with_info(doc_id, api_input)

View File

@@ -0,0 +1,36 @@
class EveAIException(Exception):
"""Base exception class for EveAI API"""
def __init__(self, message, status_code=400, payload=None):
super().__init__()
self.message = message
self.status_code = status_code
self.payload = payload
def to_dict(self):
rv = dict(self.payload or ())
rv['message'] = self.message
return rv
class EveAIInvalidLanguageException(EveAIException):
"""Raised when an invalid language is provided"""
def __init__(self, message="Langage is required", status_code=400, payload=None):
super().__init__(message, status_code, payload)
class EveAIDoubleURLException(EveAIException):
"""Raised when an existing url is provided"""
def __init__(self, message="URL already exists", status_code=400, payload=None):
super().__init__(message, status_code, payload)
class EveAIUnsupportedFileType(EveAIException):
"""Raised when an invalid file type is provided"""
def __init__(self, message="Filetype is not supported", status_code=400, payload=None):
super().__init__(message, status_code, payload)

View File

@@ -86,6 +86,7 @@ def select_model_variables(tenant):
model_variables['html_end_tags'] = tenant.html_end_tags model_variables['html_end_tags'] = tenant.html_end_tags
model_variables['html_included_elements'] = tenant.html_included_elements model_variables['html_included_elements'] = tenant.html_included_elements
model_variables['html_excluded_elements'] = tenant.html_excluded_elements model_variables['html_excluded_elements'] = tenant.html_excluded_elements
model_variables['html_excluded_classes'] = tenant.html_excluded_classes
# Set Chunk Size variables # Set Chunk Size variables
model_variables['min_chunk_size'] = tenant.min_chunk_size model_variables['min_chunk_size'] = tenant.min_chunk_size
@@ -144,8 +145,12 @@ def select_model_variables(tenant):
default_headers=portkey_headers) default_headers=portkey_headers)
tool_calling_supported = False tool_calling_supported = False
match llm_model: match llm_model:
case 'gpt-4-turbo' | 'gpt-4o' | 'gpt-4o-mini': case 'gpt-4o' | 'gpt-4o-mini':
tool_calling_supported = True tool_calling_supported = True
PDF_chunk_size = 10000
PDF_chunk_overlap = 200
PDF_min_chunk_size = 8000
PDF_max_chunk_size = 12000
case _: case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} ' raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model') f'error: Invalid chat model')
@@ -160,10 +165,19 @@ def select_model_variables(tenant):
model=llm_model_ext, model=llm_model_ext,
temperature=model_variables['RAG_temperature']) temperature=model_variables['RAG_temperature'])
tool_calling_supported = True tool_calling_supported = True
PDF_chunk_size = 10000
PDF_chunk_overlap = 200
PDF_min_chunk_size = 8000
PDF_max_chunk_size = 12000
case _: case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} ' raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat provider') f'error: Invalid chat provider')
model_variables['PDF_chunk_size'] = PDF_chunk_size
model_variables['PDF_chunk_overlap'] = PDF_chunk_overlap
model_variables['PDF_min_chunk_size'] = PDF_min_chunk_size
model_variables['PDF_max_chunk_size'] = PDF_max_chunk_size
if tool_calling_supported: if tool_calling_supported:
model_variables['cited_answer_cls'] = CitedAnswer model_variables['cited_answer_cls'] = CitedAnswer

View File

@@ -53,7 +53,7 @@ class Config(object):
WTF_CSRF_CHECK_DEFAULT = False WTF_CSRF_CHECK_DEFAULT = False
# file upload settings # file upload settings
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 MAX_CONTENT_LENGTH = 50 * 1024 * 1024
UPLOAD_EXTENSIONS = ['.txt', '.pdf', '.png', '.jpg', '.jpeg', '.gif'] UPLOAD_EXTENSIONS = ['.txt', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
# supported languages # supported languages
@@ -107,6 +107,7 @@ class Config(object):
# JWT settings # JWT settings
JWT_SECRET_KEY = environ.get('JWT_SECRET_KEY') JWT_SECRET_KEY = environ.get('JWT_SECRET_KEY')
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=1) # Set token expiry to 1 hour
# API Encryption # API Encryption
API_ENCRYPTION_KEY = environ.get('API_ENCRYPTION_KEY') API_ENCRYPTION_KEY = environ.get('API_ENCRYPTION_KEY')
@@ -136,6 +137,10 @@ class Config(object):
MAIL_PASSWORD = environ.get('MAIL_PASSWORD') MAIL_PASSWORD = environ.get('MAIL_PASSWORD')
MAIL_DEFAULT_SENDER = ('eveAI Admin', MAIL_USERNAME) MAIL_DEFAULT_SENDER = ('eveAI Admin', MAIL_USERNAME)
SUPPORTED_FILE_TYPES = ['pdf', 'html', 'md', 'txt', 'mp3', 'mp4', 'ogg', 'srt']
class DevConfig(Config): class DevConfig(Config):
DEVELOPMENT = True DEVELOPMENT = True

View File

@@ -60,6 +60,14 @@ LOGGING = {
'backupCount': 10, 'backupCount': 10,
'formatter': 'standard', 'formatter': 'standard',
}, },
'file_api': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': 'logs/eveai_api.log',
'maxBytes': 1024 * 1024 * 5, # 5MB
'backupCount': 10,
'formatter': 'standard',
},
'file_sqlalchemy': { 'file_sqlalchemy': {
'level': 'DEBUG', 'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler', 'class': 'logging.handlers.RotatingFileHandler',
@@ -146,6 +154,11 @@ LOGGING = {
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'eveai_api': { # logger for the eveai_chat_workers
'handlers': ['file_api', 'graylog', ] if env == 'production' else ['file_api', ],
'level': 'DEBUG',
'propagate': False
},
'sqlalchemy.engine': { # logger for the sqlalchemy 'sqlalchemy.engine': { # logger for the sqlalchemy
'handlers': ['file_sqlalchemy', 'graylog', ] if env == 'production' else ['file_sqlalchemy', ], 'handlers': ['file_sqlalchemy', 'graylog', ] if env == 'production' else ['file_sqlalchemy', ],
'level': 'DEBUG', 'level': 'DEBUG',

View File

@@ -15,11 +15,12 @@ html_parse: |
pdf_parse: | pdf_parse: |
You are a top administrative aid specialized in transforming given PDF-files into markdown formatted files. The generated files will be used to generate embeddings in a RAG-system. You are a top administrative aid specialized in transforming given PDF-files into markdown formatted files. The generated files will be used to generate embeddings in a RAG-system.
The content you get is already processed (some markdown already generated), but needs to be corrected. For large files, you may receive only portions of the full file. Consider this when processing the content.
# Best practices are: # Best practices are:
- Respect wordings and language(s) used in the PDF. - Respect wordings and language(s) used in the provided content.
- The following items need to be considered: headings, paragraphs, listed items (numbered or not) and tables. Images can be neglected. - The following items need to be considered: headings, paragraphs, listed items (numbered or not) and tables. Images can be neglected.
- When headings are numbered, show the numbering and define the header level. - When headings are numbered, show the numbering and define the header level. You may have to correct current header levels, as preprocessing is known to make errors.
- A new item is started when a <return> is found before a full line is reached. In order to know the number of characters in a line, please check the document and the context within the document (e.g. an image could limit the number of characters temporarily). - A new item is started when a <return> is found before a full line is reached. In order to know the number of characters in a line, please check the document and the context within the document (e.g. an image could limit the number of characters temporarily).
- Paragraphs are to be stripped of newlines so they become easily readable. - Paragraphs are to be stripped of newlines so they become easily readable.
- Be careful of encoding of the text. Everything needs to be human readable. - Be careful of encoding of the text. Everything needs to be human readable.

View File

@@ -57,6 +57,9 @@ services:
- ../nginx/sites-enabled:/etc/nginx/sites-enabled - ../nginx/sites-enabled:/etc/nginx/sites-enabled
- ../nginx/static:/etc/nginx/static - ../nginx/static:/etc/nginx/static
- ../nginx/public:/etc/nginx/public - ../nginx/public:/etc/nginx/public
- ../integrations/Wordpress/eveai-chat-widget/css/eveai-chat-style.css:/etc/nginx/static/css/eveai-chat-style.css
- ../integrations/Wordpress/eveai-chat-widget/js/eveai-chat-widget.js:/etc/nginx/static/js/eveai-chat-widget.js
- ../integrations/Wordpress/eveai-chat-widget/js/eveai-sdk.js:/etc/nginx/static/js/eveai-sdk.js
- ./logs/nginx:/var/log/nginx - ./logs/nginx:/var/log/nginx
depends_on: depends_on:
- eveai_app - eveai_app
@@ -209,6 +212,43 @@ services:
networks: networks:
- eveai-network - eveai-network
eveai_api:
image: josakola/eveai_api:latest
build:
context: ..
dockerfile: ./docker/eveai_api/Dockerfile
platforms:
- linux/amd64
- linux/arm64
ports:
- 5003:5003
environment:
<<: *common-variables
COMPONENT_NAME: eveai_api
volumes:
- ../eveai_api:/app/eveai_api
- ../common:/app/common
- ../config:/app/config
- ../scripts:/app/scripts
- ../patched_packages:/app/patched_packages
- eveai_logs:/app/logs
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
minio:
condition: service_healthy
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:5003/health" ]
interval: 10s
timeout: 5s
retries: 5
# entrypoint: ["scripts/entrypoint.sh"]
# command: ["scripts/start_eveai_api.sh"]
networks:
- eveai-network
db: db:
hostname: db hostname: db
image: ankane/pgvector image: ankane/pgvector

View File

@@ -142,6 +142,24 @@ services:
networks: networks:
- eveai-network - eveai-network
eveai_api:
platform: linux/amd64
image: josakola/eveai_api:latest
ports:
- 5003:5003
environment:
<<: *common-variables
COMPONENT_NAME: eveai_api
volumes:
- eveai_logs:/app/logs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:5001/health" ]
interval: 10s
timeout: 5s
retries: 5
networks:
- eveai-network
volumes: volumes:
eveai_logs: eveai_logs:
# miniAre theo_data: # miniAre theo_data:

View File

@@ -0,0 +1,69 @@
ARG PYTHON_VERSION=3.12.3
FROM python:${PYTHON_VERSION}-slim as base
# Prevents Python from writing pyc files.
ENV PYTHONDONTWRITEBYTECODE=1
# Keeps Python from buffering stdout and stderr to avoid situations where
# the application crashes without emitting any logs due to buffering.
ENV PYTHONUNBUFFERED=1
# Create directory for patched packages and set permissions
RUN mkdir -p /app/patched_packages && \
chmod 777 /app/patched_packages
# Ensure patches are applied to the application.
ENV PYTHONPATH=/app/patched_packages:$PYTHONPATH
WORKDIR /app
# Create a non-privileged user that the app will run under.
# See https://docs.docker.com/go/dockerfile-user-best-practices/
ARG UID=10001
RUN adduser \
--disabled-password \
--gecos "" \
--home "/nonexistent" \
--shell "/bin/bash" \
--no-create-home \
--uid "${UID}" \
appuser
# Install necessary packages and build tools
RUN apt-get update && apt-get install -y \
build-essential \
gcc \
postgresql-client \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Create logs directory and set permissions
RUN mkdir -p /app/logs && chown -R appuser:appuser /app/logs
# Download dependencies as a separate step to take advantage of Docker's caching.
# Leverage a cache mount to /root/.cache/pip to speed up subsequent builds.
# Leverage a bind mount to requirements.txt to avoid having to copy them into
# into this layer.
COPY requirements.txt /app/
RUN python -m pip install -r /app/requirements.txt
# Copy the source code into the container.
COPY eveai_api /app/eveai_api
COPY common /app/common
COPY config /app/config
COPY scripts /app/scripts
COPY patched_packages /app/patched_packages
# Set permissions for entrypoint script
RUN chmod 777 /app/scripts/entrypoint.sh
# Set ownership of the application directory to the non-privileged user
RUN chown -R appuser:appuser /app
# Expose the port that the application listens on.
EXPOSE 5003
# Set entrypoint and command
ENTRYPOINT ["/app/scripts/entrypoint.sh"]
CMD ["/app/scripts/start_eveai_api.sh"]

View File

@@ -45,7 +45,7 @@ RUN mkdir -p /app/logs && chown -R appuser:appuser /app/logs
# Leverage a bind mount to requirements.txt to avoid having to copy them into # Leverage a bind mount to requirements.txt to avoid having to copy them into
# into this layer. # into this layer.
COPY ../../requirements.txt /app/ COPY requirements.txt /app/
RUN python -m pip install -r requirements.txt RUN python -m pip install -r requirements.txt
# Copy the source code into the container. # Copy the source code into the container.

View File

@@ -10,6 +10,9 @@ COPY ../../nginx/mime.types /etc/nginx/mime.types
# Copy static & public files # Copy static & public files
RUN mkdir -p /etc/nginx/static /etc/nginx/public RUN mkdir -p /etc/nginx/static /etc/nginx/public
COPY ../../nginx/static /etc/nginx/static COPY ../../nginx/static /etc/nginx/static
COPY ../../integrations/Wordpress/eveai-chat-widget/css/eveai-chat-style.css /etc/nginx/static/css/
COPY ../../integrations/Wordpress/eveai-chat-widget/js/eveai-chat-widget.js /etc/nginx/static/js/
COPY ../../integrations/Wordpress/eveai-chat-widget/js/eveai-sdk.js /etc/nginx/static/js
COPY ../../nginx/public /etc/nginx/public COPY ../../nginx/public /etc/nginx/public
# Copy site-specific configurations # Copy site-specific configurations

View File

@@ -1,4 +1,104 @@
# from flask import Blueprint, request from flask import Flask, jsonify, request
# from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request
# public_api_bp = Blueprint("public", __name__, url_prefix="/api/v1") from common.extensions import db, api_rest, jwt, minio_client, simple_encryption
# tenant_api_bp = Blueprint("tenant", __name__, url_prefix="/api/v1/tenant") import os
import logging.config
from common.utils.database import Database
from config.logging_config import LOGGING
from .api.document_api import document_ns
from .api.auth import auth_ns
from config.config import get_config
from common.utils.celery_utils import make_celery, init_celery
from common.utils.eveai_exceptions import EveAIException
def create_app(config_file=None):
app = Flask(__name__)
environment = os.getenv('FLASK_ENV', 'development')
match environment:
case 'development':
app.config.from_object(get_config('dev'))
case 'production':
app.config.from_object(get_config('prod'))
case _:
app.config.from_object(get_config('dev'))
app.config['SESSION_KEY_PREFIX'] = 'eveai_api_'
app.celery = make_celery(app.name, app.config)
init_celery(app.celery, app)
logging.config.dictConfig(LOGGING)
logger = logging.getLogger(__name__)
logger.info("eveai_api starting up")
# Register Necessary Extensions
register_extensions(app)
# register Blueprints
register_namespaces(api_rest)
# Error handler for the API
@app.errorhandler(EveAIException)
def handle_eveai_exception(error):
return {'message': str(error)}, error.status_code
@app.before_request
def before_request():
app.logger.debug(f'Before request: {request.method} {request.path}')
app.logger.debug(f'Request URL: {request.url}')
app.logger.debug(f'Request headers: {dict(request.headers)}')
# Log request arguments
app.logger.debug(f'Request args: {request.args}')
# Log form data if it's a POST request
if request.method == 'POST':
app.logger.debug(f'Form data: {request.form}')
# Log JSON data if the content type is application/json
if request.is_json:
app.logger.debug(f'JSON data: {request.json}')
# Log raw data for other content types
if request.data:
app.logger.debug(f'Raw data: {request.data}')
# Check if this is a request to the token endpoint
if request.path == '/api/v1/auth/token' and request.method == 'POST':
app.logger.debug('Token request detected, skipping JWT verification')
return
try:
verify_jwt_in_request(optional=True)
tenant_id = get_jwt_identity()
app.logger.debug(f'Tenant ID from JWT: {tenant_id}')
if tenant_id:
Database(tenant_id).switch_schema()
app.logger.debug(f'Switched to schema for tenant {tenant_id}')
else:
app.logger.debug('No tenant ID found in JWT')
except Exception as e:
app.logger.error(f'Error in before_request: {str(e)}')
# Don't raise the exception here, let the request continue
# The appropriate error handling will be done in the specific endpoints
return app
def register_extensions(app):
db.init_app(app)
api_rest.init_app(app, title='EveAI API', version='1.0', description='EveAI API')
jwt.init_app(app)
minio_client.init_app(app)
simple_encryption.init_app(app)
def register_namespaces(app):
api_rest.add_namespace(document_ns, path='/api/v1/documents')
api_rest.add_namespace(auth_ns, path='/api/v1/auth')

75
eveai_api/api/auth.py Normal file
View File

@@ -0,0 +1,75 @@
from datetime import timedelta
from flask_restx import Namespace, Resource, fields
from flask_jwt_extended import create_access_token
from common.models.user import Tenant
from common.extensions import simple_encryption
from flask import current_app, request
auth_ns = Namespace('auth', description='Authentication related operations')
token_model = auth_ns.model('Token', {
'tenant_id': fields.Integer(required=True, description='Tenant ID'),
'api_key': fields.String(required=True, description='API Key')
})
token_response = auth_ns.model('TokenResponse', {
'access_token': fields.String(description='JWT access token'),
'expires_in': fields.Integer(description='Token expiration time in seconds')
})
@auth_ns.route('/token')
class Token(Resource):
@auth_ns.expect(token_model)
@auth_ns.response(200, 'Success', token_response)
@auth_ns.response(400, 'Validation Error')
@auth_ns.response(401, 'Unauthorized')
@auth_ns.response(404, 'Tenant Not Found')
def post(self):
"""
Get JWT token
"""
current_app.logger.debug(f"Token endpoint called with data: {request.json}")
try:
tenant_id = auth_ns.payload['tenant_id']
api_key = auth_ns.payload['api_key']
except KeyError as e:
current_app.logger.error(f"Missing required field: {e}")
return {'message': f"Missing required field: {e}"}, 400
current_app.logger.debug(f"Querying database for tenant: {tenant_id}")
tenant = Tenant.query.get(tenant_id)
if not tenant:
current_app.logger.error(f"Tenant not found: {tenant_id}")
return {'message': "Tenant not found"}, 404
current_app.logger.debug(f"Tenant found: {tenant.id}")
try:
current_app.logger.debug("Attempting to decrypt API key")
decrypted_api_key = simple_encryption.decrypt_api_key(tenant.encrypted_api_key)
except Exception as e:
current_app.logger.error(f"Error decrypting API key: {e}")
return {'message': "Internal server error"}, 500
if api_key != decrypted_api_key:
current_app.logger.error(f"Invalid API key for tenant: {tenant_id}")
return {'message': "Invalid API key"}, 401
# Get the JWT_ACCESS_TOKEN_EXPIRES setting from the app config
expires_delta = current_app.config.get('JWT_ACCESS_TOKEN_EXPIRES', timedelta(minutes=15))
try:
current_app.logger.debug(f"Creating access token for tenant: {tenant_id}")
access_token = create_access_token(identity=tenant_id, expires_delta=expires_delta)
current_app.logger.debug("Access token created successfully")
return {
'access_token': access_token,
'expires_in': expires_delta.total_seconds()
}, 200
except Exception as e:
current_app.logger.error(f"Error creating access token: {e}")
return {'message': "Internal server error"}, 500

View File

@@ -0,0 +1,313 @@
import json
from datetime import datetime
import pytz
from flask import current_app, request
from flask_restx import Namespace, Resource, fields, reqparse
from flask_jwt_extended import jwt_required, get_jwt_identity
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename
from common.utils.document_utils import (
create_document_stack, process_url, start_embedding_task,
validate_file_type, EveAIInvalidLanguageException, EveAIDoubleURLException, EveAIUnsupportedFileType,
process_multiple_urls, get_documents_list, edit_document, refresh_document, edit_document_version,
refresh_document_with_info
)
def validate_date(date_str):
try:
return datetime.fromisoformat(date_str).replace(tzinfo=pytz.UTC)
except ValueError:
raise ValueError("Invalid date format. Use ISO format (YYYY-MM-DDTHH:MM:SS).")
def validate_json(json_str):
try:
return json.loads(json_str)
except json.JSONDecodeError:
raise ValueError("Invalid JSON format for user_metadata.")
document_ns = Namespace('documents', description='Document related operations')
# Define models for request parsing and response serialization
upload_parser = reqparse.RequestParser()
upload_parser.add_argument('file', location='files', type=FileStorage, required=True, help='The file to upload')
upload_parser.add_argument('name', location='form', type=str, required=False, help='Name of the document')
upload_parser.add_argument('language', location='form', type=str, required=True, help='Language of the document')
upload_parser.add_argument('user_context', location='form', type=str, required=False,
help='User context for the document')
upload_parser.add_argument('valid_from', location='form', type=validate_date, required=False,
help='Valid from date for the document (ISO format)')
upload_parser.add_argument('user_metadata', location='form', type=validate_json, required=False,
help='User metadata for the document (JSON format)')
add_document_response = document_ns.model('AddDocumentResponse', {
'message': fields.String(description='Status message'),
'document_id': fields.Integer(description='ID of the created document'),
'document_version_id': fields.Integer(description='ID of the created document version'),
'task_id': fields.String(description='ID of the embedding task')
})
@document_ns.route('/add_document')
class AddDocument(Resource):
@jwt_required()
@document_ns.expect(upload_parser)
@document_ns.response(201, 'Document added successfully', add_document_response)
@document_ns.response(400, 'Validation Error')
@document_ns.response(500, 'Internal Server Error')
def post(self):
"""
Add a new document
"""
tenant_id = get_jwt_identity()
current_app.logger.info(f'Adding document for tenant {tenant_id}')
try:
args = upload_parser.parse_args()
file = args['file']
filename = secure_filename(file.filename)
extension = filename.rsplit('.', 1)[1].lower()
validate_file_type(extension)
api_input = {
'name': args.get('name') or filename,
'language': args.get('language'),
'user_context': args.get('user_context'),
'valid_from': args.get('valid_from'),
'user_metadata': args.get('user_metadata'),
}
new_doc, new_doc_vers = create_document_stack(api_input, file, filename, extension, tenant_id)
task_id = start_embedding_task(tenant_id, new_doc_vers.id)
return {
'message': f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task_id}.',
'document_id': new_doc.id,
'document_version_id': new_doc_vers.id,
'task_id': task_id
}, 201
except (EveAIInvalidLanguageException, EveAIUnsupportedFileType) as e:
current_app.logger.error(f'Error adding document: {str(e)}')
document_ns.abort(400, str(e))
except Exception as e:
current_app.logger.error(f'Error adding document: {str(e)}')
document_ns.abort(500, 'Error adding document')
# Models for AddURL
add_url_model = document_ns.model('AddURL', {
'url': fields.String(required=True, description='URL of the document to add'),
'name': fields.String(required=False, description='Name of the document'),
'language': fields.String(required=True, description='Language of the document'),
'user_context': fields.String(required=False, description='User context for the document'),
'valid_from': fields.String(required=False, description='Valid from date for the document'),
'user_metadata': fields.String(required=False, description='User metadata for the document'),
'system_metadata': fields.String(required=False, description='System metadata for the document')
})
add_url_response = document_ns.model('AddURLResponse', {
'message': fields.String(description='Status message'),
'document_id': fields.Integer(description='ID of the created document'),
'document_version_id': fields.Integer(description='ID of the created document version'),
'task_id': fields.String(description='ID of the embedding task')
})
@document_ns.route('/add_url')
class AddURL(Resource):
@jwt_required()
@document_ns.expect(add_url_model)
@document_ns.response(201, 'Document added successfully', add_url_response)
@document_ns.response(400, 'Validation Error')
@document_ns.response(500, 'Internal Server Error')
def post(self):
"""
Add a new document from URL
"""
tenant_id = get_jwt_identity()
current_app.logger.info(f'Adding document from URL for tenant {tenant_id}')
try:
args = document_ns.payload
file_content, filename, extension = process_url(args['url'], tenant_id)
api_input = {
'url': args['url'],
'name': args.get('name') or filename,
'language': args['language'],
'user_context': args.get('user_context'),
'valid_from': args.get('valid_from'),
'user_metadata': args.get('user_metadata'),
}
new_doc, new_doc_vers = create_document_stack(api_input, file_content, filename, extension, tenant_id)
task_id = start_embedding_task(tenant_id, new_doc_vers.id)
return {
'message': f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task_id}.',
'document_id': new_doc.id,
'document_version_id': new_doc_vers.id,
'task_id': task_id
}, 201
except EveAIDoubleURLException:
document_ns.abort(400, f'A document with URL {args["url"]} already exists.')
except (EveAIInvalidLanguageException, EveAIUnsupportedFileType) as e:
document_ns.abort(400, str(e))
except Exception as e:
current_app.logger.error(f'Error adding document from URL: {str(e)}')
document_ns.abort(500, 'Error adding document from URL')
document_list_model = document_ns.model('DocumentList', {
'id': fields.Integer(description='Document ID'),
'name': fields.String(description='Document name'),
'valid_from': fields.DateTime(description='Valid from date'),
'valid_to': fields.DateTime(description='Valid to date'),
})
@document_ns.route('/list')
class DocumentList(Resource):
@jwt_required()
@document_ns.doc('list_documents')
@document_ns.marshal_list_with(document_list_model, envelope='documents')
def get(self):
"""List all documents"""
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
pagination = get_documents_list(page, per_page)
return pagination.items, 200
edit_document_model = document_ns.model('EditDocument', {
'name': fields.String(required=True, description='New name for the document'),
'valid_from': fields.DateTime(required=False, description='New valid from date'),
'valid_to': fields.DateTime(required=False, description='New valid to date'),
})
@document_ns.route('/<int:document_id>')
class DocumentResource(Resource):
@jwt_required()
@document_ns.doc('edit_document')
@document_ns.expect(edit_document_model)
@document_ns.response(200, 'Document updated successfully')
def put(self, document_id):
"""Edit a document"""
data = request.json
updated_doc, error = edit_document(document_id, data['name'], data.get('valid_from'), data.get('valid_to'))
if updated_doc:
return {'message': f'Document {updated_doc.id} updated successfully'}, 200
else:
return {'message': f'Error updating document: {error}'}, 400
@jwt_required()
@document_ns.doc('refresh_document')
@document_ns.response(200, 'Document refreshed successfully')
def post(self, document_id):
"""Refresh a document"""
new_version, result = refresh_document(document_id)
if new_version:
return {'message': f'Document refreshed. New version: {new_version.id}. Task ID: {result}'}, 200
else:
return {'message': f'Error refreshing document: {result}'}, 400
edit_document_version_model = document_ns.model('EditDocumentVersion', {
'user_context': fields.String(required=True, description='New user context for the document version'),
})
@document_ns.route('/version/<int:version_id>')
class DocumentVersionResource(Resource):
@jwt_required()
@document_ns.doc('edit_document_version')
@document_ns.expect(edit_document_version_model)
@document_ns.response(200, 'Document version updated successfully')
def put(self, version_id):
"""Edit a document version"""
data = request.json
updated_version, error = edit_document_version(version_id, data['user_context'])
if updated_version:
return {'message': f'Document Version {updated_version.id} updated successfully'}, 200
else:
return {'message': f'Error updating document version: {error}'}, 400
# Define the model for the request body of refresh_with_info
refresh_document_model = document_ns.model('RefreshDocument', {
'name': fields.String(required=False, description='New name for the document'),
'language': fields.String(required=False, description='Language of the document'),
'user_context': fields.String(required=False, description='User context for the document'),
'user_metadata': fields.Raw(required=False, description='User metadata for the document')
})
@document_ns.route('/<int:document_id>/refresh')
class RefreshDocument(Resource):
@jwt_required()
@document_ns.response(200, 'Document refreshed successfully')
@document_ns.response(404, 'Document not found')
def post(self, document_id):
"""
Refresh a document without additional information
"""
tenant_id = get_jwt_identity()
current_app.logger.info(f'Refreshing document {document_id} for tenant {tenant_id}')
try:
new_version, result = refresh_document(document_id)
if new_version:
return {
'message': f'Document refreshed successfully. New version: {new_version.id}. Task ID: {result}',
'document_id': document_id,
'document_version_id': new_version.id,
'task_id': result
}, 200
else:
return {'message': f'Error refreshing document: {result}'}, 400
except Exception as e:
current_app.logger.error(f'Error refreshing document: {str(e)}')
return {'message': 'Internal server error'}, 500
@document_ns.route('/<int:document_id>/refresh_with_info')
class RefreshDocumentWithInfo(Resource):
@jwt_required()
@document_ns.expect(refresh_document_model)
@document_ns.response(200, 'Document refreshed successfully')
@document_ns.response(400, 'Validation Error')
@document_ns.response(404, 'Document not found')
def post(self, document_id):
"""
Refresh a document with new information
"""
tenant_id = get_jwt_identity()
current_app.logger.info(f'Refreshing document {document_id} with info for tenant {tenant_id}')
try:
api_input = request.json
new_version, result = refresh_document_with_info(document_id, api_input)
if new_version:
return {
'message': f'Document refreshed successfully with new info. New version: {new_version.id}. Task ID: {result}',
'document_id': document_id,
'document_version_id': new_version.id,
'task_id': result
}, 200
else:
return {'message': f'Error refreshing document with info: {result}'}, 400
except Exception as e:
current_app.logger.error(f'Error refreshing document with info: {str(e)}')
return {'message': 'Internal server error'}, 500

View File

@@ -1,7 +0,0 @@
from flask import request
from flask.views import MethodView
class RegisterAPI(MethodView):
def post(self):
username = request.json['username']

View File

@@ -27,7 +27,6 @@ def create_app(config_file=None):
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
environment = os.getenv('FLASK_ENV', 'development') environment = os.getenv('FLASK_ENV', 'development')
print(environment)
match environment: match environment:
case 'development': case 'development':
@@ -49,8 +48,6 @@ def create_app(config_file=None):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info("eveai_app starting up") logger.info("eveai_app starting up")
logger.debug("start config")
logger.debug(app.config)
# Register extensions # Register extensions
@@ -95,14 +92,11 @@ def create_app(config_file=None):
} }
return jsonify(response), 500 return jsonify(response), 500
@app.before_request # @app.before_request
def before_request(): # def before_request():
# app.logger.debug(f"Before request - Session ID: {session.sid}") # # app.logger.debug(f"Before request - Session ID: {session.sid}")
app.logger.debug(f"Before request - Session data: {session}") # app.logger.debug(f"Before request - Session data: {session}")
app.logger.debug(f"Before request - Request headers: {request.headers}") # app.logger.debug(f"Before request - Request headers: {request.headers}")
# Register API
register_api(app)
# Register template filters # Register template filters
register_filters(app) register_filters(app)
@@ -138,9 +132,3 @@ def register_blueprints(app):
app.register_blueprint(security_bp) app.register_blueprint(security_bp)
from .views.interaction_views import interaction_bp from .views.interaction_views import interaction_bp
app.register_blueprint(interaction_bp) app.register_blueprint(interaction_bp)
def register_api(app):
pass
# from . import api
# app.register_blueprint(api.bp, url_prefix='/api')

View File

@@ -1,24 +0,0 @@
{% extends 'base.html' %}
{% from "macros.html" import render_field %}
{% block title %}Add Youtube Document{% endblock %}
{% block content_title %}Add Youtube Document{% endblock %}
{% block content_description %}Add a youtube url and the corresponding document to EveAI. In some cases, url's cannot be loaded directly. Download the html and add it as a document in that case.{% endblock %}
{% block content %}
<form method="post">
{{ form.hidden_tag() }}
{% set disabled_fields = [] %}
{% set exclude_fields = [] %}
{% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }}
{% endfor %}
<button type="submit" class="btn btn-primary">Add Youtube Document</button>
</form>
{% endblock %}
{% block content_footer %}
{% endblock %}

View File

@@ -8,7 +8,7 @@
{% block content %} {% block content %}
<form method="post"> <form method="post">
{{ form.hidden_tag() }} {{ form.hidden_tag() }}
{% set disabled_fields = ['language', 'system_context'] %} {% set disabled_fields = ['language', 'system_context', 'system_metadata'] %}
{% set exclude_fields = [] %} {% set exclude_fields = [] %}
{% for field in form %} {% for field in form %}
{{ render_field(field, disabled_fields, exclude_fields) }} {{ render_field(field, disabled_fields, exclude_fields) }}

View File

@@ -1,126 +1,80 @@
{% extends "base.html" %} {% extends "base.html" %}
{% from "macros.html" import render_field %}
{% block title %}Session Overview{% endblock %}
{% block content_title %}Session Overview{% endblock %}
{% block content_description %}An overview of the chat session.{% endblock %}
{% block content %} {% block content %}
<div class="container mt-5"> <div class="container mt-5">
<h2>Chat Session Details</h2> <h2>Chat Session Details</h2>
<!-- Session Information -->
<div class="card mb-4"> <div class="card mb-4">
<div class="card-header"> <div class="card-header">
<h5>Session Information</h5> <h5>Session Information</h5>
<!-- Timezone Toggle Buttons -->
<div class="btn-group" role="group">
<button type="button" class="btn btn-primary" id="toggle-interaction-timezone">Interaction Timezone</button>
<button type="button" class="btn btn-secondary" id="toggle-admin-timezone">Admin Timezone</button>
</div>
</div> </div>
<div class="card-body"> <div class="card-body">
<dl class="row"> <p><strong>Session ID:</strong> {{ chat_session.session_id }}</p>
<dt class="col-sm-3">Session ID:</dt> <p><strong>User:</strong> {{ chat_session.user.user_name if chat_session.user else 'Anonymous' }}</p>
<dd class="col-sm-9">{{ chat_session.session_id }}</dd> <p><strong>Start:</strong> {{ chat_session.session_start | to_local_time(chat_session.timezone) }}</p>
<p><strong>End:</strong> {{ chat_session.session_end | to_local_time(chat_session.timezone) if chat_session.session_end else 'Ongoing' }}</p>
<dt class="col-sm-3">Session Start:</dt>
<dd class="col-sm-9">
<span class="timezone interaction-timezone">{{ chat_session.session_start | to_local_time(chat_session.timezone) }}</span>
<span class="timezone admin-timezone d-none">{{ chat_session.session_start | to_local_time(session['admin_user_timezone']) }}</span>
</dd>
<dt class="col-sm-3">Session End:</dt>
<dd class="col-sm-9">
{% if chat_session.session_end %}
<span class="timezone interaction-timezone">{{ chat_session.session_end | to_local_time(chat_session.timezone) }}</span>
<span class="timezone admin-timezone d-none">{{ chat_session.session_end | to_local_time(session['admin_user_timezone']) }}</span>
{% else %}
Ongoing
{% endif %}
</dd>
</dl>
</div> </div>
</div> </div>
<!-- Interactions List --> <h3>Interactions</h3>
<div class="card mb-4"> <div class="accordion" id="interactionsAccordion">
<div class="card-header"> {% for interaction in interactions %}
<h5>Interactions</h5> <div class="accordion-item">
</div> <h2 class="accordion-header" id="heading{{ loop.index }}">
<div class="card-body"> <button class="accordion-button collapsed" type="button" data-bs-toggle="collapse"
{% for interaction in interactions %} data-bs-target="#collapse{{ loop.index }}" aria-expanded="false"
<div class="interaction mb-3"> aria-controls="collapse{{ loop.index }}">
<div class="card"> <div class="d-flex justify-content-between align-items-center w-100">
<div class="card-header d-flex justify-content-between"> <span class="interaction-question">{{ interaction.question | truncate(50) }}</span>
<span>Question:</span> <span class="interaction-icons">
<span class="text-muted"> <i class="material-icons algorithm-icon {{ interaction.algorithm_used | lower }}">fingerprint</i>
<span class="timezone interaction-timezone">{{ interaction.question_at | to_local_time(interaction.timezone) }}</span> <i class="material-icons thumb-icon {% if interaction.appreciation == 100 %}filled{% else %}outlined{% endif %}">thumb_up</i>
<span class="timezone admin-timezone d-none">{{ interaction.question_at | to_local_time(session['admin_user_timezone']) }}</span> <i class="material-icons thumb-icon {% if interaction.appreciation == 0 %}filled{% else %}outlined{% endif %}">thumb_down</i>
- </span>
<span class="timezone interaction-timezone">{{ interaction.answer_at | to_local_time(interaction.timezone) }}</span>
<span class="timezone admin-timezone d-none">{{ interaction.answer_at | to_local_time(session['admin_user_timezone']) }}</span>
({{ interaction.question_at | time_difference(interaction.answer_at) }})
</span>
</div>
<div class="card-body">
<p><strong>Question:</strong> {{ interaction.question }}</p>
<p><strong>Answer:</strong> {{ interaction.answer }}</p>
<p>
<strong>Algorithm Used:</strong>
<i class="material-icons {{ 'fingerprint-rag-' ~ interaction.algorithm_used.lower() }}">
fingerprint
</i> {{ interaction.algorithm_used }}
</p>
<p>
<strong>Appreciation:</strong>
<i class="material-icons thumb-icon {{ 'thumb_up' if interaction.appreciation == 1 else 'thumb_down' }}">
{{ 'thumb_up' if interaction.appreciation == 1 else 'thumb_down' }}
</i>
</p>
<p><strong>Embeddings:</strong>
{% if interaction.embeddings %}
{% for embedding in interaction.embeddings %}
<a href="{{ url_for('interaction_bp.view_embedding', embedding_id=embedding.embedding_id) }}" class="badge badge-info">
{{ embedding.embedding_id }}
</a>
{% endfor %}
{% else %}
None
{% endif %}
</p>
</div>
</div> </div>
</button>
</h2>
<div id="collapse{{ loop.index }}" class="accordion-collapse collapse" aria-labelledby="heading{{ loop.index }}"
data-bs-parent="#interactionsAccordion">
<div class="accordion-body">
<h6>Detailed Question:</h6>
<p>{{ interaction.detailed_question }}</p>
<h6>Answer:</h6>
<div class="markdown-content">{{ interaction.answer | safe }}</div>
{% if embeddings_dict.get(interaction.id) %}
<h6>Related Documents:</h6>
<ul>
{% for embedding in embeddings_dict[interaction.id] %}
<li>
{% if embedding.url %}
<a href="{{ embedding.url }}" target="_blank">{{ embedding.url }}</a>
{% else %}
{{ embedding.file_name }}
{% endif %}
</li>
{% endfor %}
</ul>
{% endif %}
</div> </div>
{% endfor %} </div>
</div> </div>
{% endfor %}
</div> </div>
</div> </div>
{% endblock %} {% endblock %}
{% block scripts %} {% block scripts %}
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script> <script>
document.addEventListener('DOMContentLoaded', function() { document.addEventListener('DOMContentLoaded', function() {
// Elements to toggle var markdownElements = document.querySelectorAll('.markdown-content');
const interactionTimes = document.querySelectorAll('.interaction-timezone'); markdownElements.forEach(function(el) {
const adminTimes = document.querySelectorAll('.admin-timezone'); el.innerHTML = marked.parse(el.textContent);
// Buttons
const interactionButton = document.getElementById('toggle-interaction-timezone');
const adminButton = document.getElementById('toggle-admin-timezone');
// Toggle to Interaction Timezone
interactionButton.addEventListener('click', function() {
interactionTimes.forEach(el => el.classList.remove('d-none'));
adminTimes.forEach(el => el.classList.add('d-none'));
interactionButton.classList.add('btn-primary');
interactionButton.classList.remove('btn-secondary');
adminButton.classList.add('btn-secondary');
adminButton.classList.remove('btn-primary');
});
// Toggle to Admin Timezone
adminButton.addEventListener('click', function() {
interactionTimes.forEach(el => el.classList.add('d-none'));
adminTimes.forEach(el => el.classList.remove('d-none'));
interactionButton.classList.add('btn-secondary');
interactionButton.classList.remove('btn-primary');
adminButton.classList.add('btn-primary');
adminButton.classList.remove('btn-secondary');
}); });
}); });
</script> </script>

View File

@@ -84,7 +84,6 @@
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add a list of URLs', 'url': '/document/add_urls', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add Youtube Document' , 'url': '/document/add_youtube', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Documents', 'url': '/document/documents', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Documents', 'url': '/document/documents', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Document Versions', 'url': '/document/document_versions_list', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Document Versions', 'url': '/document/document_versions_list', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Library Operations', 'url': '/document/library_operations', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Library Operations', 'url': '/document/library_operations', 'roles': ['Super User', 'Tenant Admin']},

View File

@@ -62,18 +62,25 @@
{{ render_included_field(field, disabled_fields=license_fields, include_fields=license_fields) }} {{ render_included_field(field, disabled_fields=license_fields, include_fields=license_fields) }}
{% endfor %} {% endfor %}
<!-- Register API Key Button --> <!-- Register API Key Button -->
<button type="button" class="btn btn-primary" onclick="generateNewChatApiKey()">Register Chat API Key</button>
<button type="button" class="btn btn-primary" onclick="generateNewApiKey()">Register API Key</button> <button type="button" class="btn btn-primary" onclick="generateNewApiKey()">Register API Key</button>
<!-- API Key Display Field --> <!-- API Key Display Field -->
<div id="chat-api-key-field" style="display:none;">
<label for="chat-api-key">Chat API Key:</label>
<input type="text" id="chat-api-key" class="form-control" readonly>
<button type="button" id="copy-chat-button" class="btn btn-primary">Copy to Clipboard</button>
<p id="copy-chat-message" style="display:none;color:green;">Chat API key copied to clipboard</p>
</div>
<div id="api-key-field" style="display:none;"> <div id="api-key-field" style="display:none;">
<label for="api-key">API Key:</label> <label for="api-key">API Key:</label>
<input type="text" id="api-key" class="form-control" readonly> <input type="text" id="api-key" class="form-control" readonly>
<button type="button" id="copy-button" class="btn btn-primary">Copy to Clipboard</button> <button type="button" id="copy-api-button" class="btn btn-primary">Copy to Clipboard</button>
<p id="copy-message" style="display:none;color:green;">API key copied to clipboard</p> <p id="copy-message" style="display:none;color:green;">API key copied to clipboard</p>
</div> </div>
</div> </div>
<!-- Chunking Settings Tab --> <!-- Chunking Settings Tab -->
<div class="tab-pane fade" id="chunking-tab" role="tabpanel"> <div class="tab-pane fade" id="chunking-tab" role="tabpanel">
{% set html_fields = ['html_tags', 'html_end_tags', 'html_included_elements', 'html_excluded_elements', 'min_chunk_size', 'max_chunk_size'] %} {% set html_fields = ['html_tags', 'html_end_tags', 'html_included_elements', 'html_excluded_elements', 'html_excluded_classes', 'min_chunk_size', 'max_chunk_size'] %}
{% for field in form %} {% for field in form %}
{{ render_included_field(field, disabled_fields=html_fields, include_fields=html_fields) }} {{ render_included_field(field, disabled_fields=html_fields, include_fields=html_fields) }}
{% endfor %} {% endfor %}
@@ -105,14 +112,25 @@
{% block scripts %} {% block scripts %}
<script> <script>
// Function to generate a new Chat API Key
function generateNewChatApiKey() {
generateApiKey('/admin/user/generate_chat_api_key', '#chat-api-key', '#chat-api-key-field');
}
// Function to generate a new general API Key
function generateNewApiKey() { function generateNewApiKey() {
generateApiKey('/admin/user/generate_api_api_key', '#api-key', '#api-key-field');
}
// Reusable function to handle API key generation
function generateApiKey(url, inputSelector, fieldSelector) {
$.ajax({ $.ajax({
url: '/user/generate_chat_api_key', url: url,
type: 'POST', type: 'POST',
contentType: 'application/json', contentType: 'application/json',
success: function(response) { success: function(response) {
$('#api-key').val(response.api_key); $(inputSelector).val(response.api_key);
$('#api-key-field').show(); $(fieldSelector).show();
}, },
error: function(error) { error: function(error) {
alert('Error generating new API key: ' + error.responseText); alert('Error generating new API key: ' + error.responseText);
@@ -120,25 +138,27 @@
}); });
} }
function copyToClipboard(selector) { // Function to copy text to clipboard
function copyToClipboard(selector, messageSelector) {
const element = document.querySelector(selector); const element = document.querySelector(selector);
if (element) { if (element) {
const text = element.value; const text = element.value;
if (navigator.clipboard && navigator.clipboard.writeText) { if (navigator.clipboard && navigator.clipboard.writeText) {
navigator.clipboard.writeText(text).then(function() { navigator.clipboard.writeText(text).then(function() {
showCopyMessage(); showCopyMessage(messageSelector);
}).catch(function(error) { }).catch(function(error) {
alert('Failed to copy text: ' + error); alert('Failed to copy text: ' + error);
}); });
} else { } else {
fallbackCopyToClipboard(text); fallbackCopyToClipboard(text, messageSelector);
} }
} else { } else {
console.error('Element not found for selector:', selector); console.error('Element not found for selector:', selector);
} }
} }
function fallbackCopyToClipboard(text) { // Fallback method for copying text to clipboard
function fallbackCopyToClipboard(text, messageSelector) {
const textArea = document.createElement('textarea'); const textArea = document.createElement('textarea');
textArea.value = text; textArea.value = text;
document.body.appendChild(textArea); document.body.appendChild(textArea);
@@ -146,15 +166,16 @@
textArea.select(); textArea.select();
try { try {
document.execCommand('copy'); document.execCommand('copy');
showCopyMessage(); showCopyMessage(messageSelector);
} catch (err) { } catch (err) {
alert('Fallback: Oops, unable to copy', err); alert('Fallback: Oops, unable to copy', err);
} }
document.body.removeChild(textArea); document.body.removeChild(textArea);
} }
function showCopyMessage() { // Function to show copy confirmation message
const message = document.getElementById('copy-message'); function showCopyMessage(messageSelector) {
const message = document.querySelector(messageSelector);
if (message) { if (message) {
message.style.display = 'block'; message.style.display = 'block';
setTimeout(function() { setTimeout(function() {
@@ -163,8 +184,13 @@
} }
} }
document.getElementById('copy-button').addEventListener('click', function() { // Event listeners for copy buttons
copyToClipboard('#api-key'); document.getElementById('copy-chat-button').addEventListener('click', function() {
copyToClipboard('#chat-api-key', '#copy-chat-message');
});
document.getElementById('copy-api-button').addEventListener('click', function() {
copyToClipboard('#api-key', '#copy-message');
}); });
</script> </script>
<script> <script>

View File

@@ -1,18 +1,36 @@
from flask import session from flask import session, current_app
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import (StringField, BooleanField, SubmitField, DateField, from wtforms import (StringField, BooleanField, SubmitField, DateField,
SelectField, FieldList, FormField, TextAreaField, URLField) SelectField, FieldList, FormField, TextAreaField, URLField)
from wtforms.validators import DataRequired, Length, Optional, URL from wtforms.validators import DataRequired, Length, Optional, URL, ValidationError
from flask_wtf.file import FileField, FileAllowed, FileRequired from flask_wtf.file import FileField, FileAllowed, FileRequired
import json
def allowed_file(form, field):
if field.data:
filename = field.data.filename
allowed_extensions = current_app.config.get('SUPPORTED_FILE_TYPES', [])
if not ('.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions):
raise ValidationError('Unsupported file type.')
def validate_json(form, field):
if field.data:
try:
json.loads(field.data)
except json.JSONDecodeError:
raise ValidationError('Invalid JSON format')
class AddDocumentForm(FlaskForm): class AddDocumentForm(FlaskForm):
file = FileField('File', validators=[FileAllowed(['pdf', 'txt', 'html']), file = FileField('File', validators=[FileRequired(), allowed_file])
FileRequired()])
name = StringField('Name', validators=[Length(max=100)]) name = StringField('Name', validators=[Length(max=100)])
language = SelectField('Language', choices=[], validators=[Optional()]) language = SelectField('Language', choices=[], validators=[Optional()])
user_context = TextAreaField('User Context', validators=[Optional()]) user_context = TextAreaField('User Context', validators=[Optional()])
valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()]) valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()])
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
submit = SubmitField('Submit') submit = SubmitField('Submit')
@@ -20,6 +38,7 @@ class AddDocumentForm(FlaskForm):
super().__init__() super().__init__()
self.language.choices = [(language, language) for language in self.language.choices = [(language, language) for language in
session.get('tenant').get('allowed_languages')] session.get('tenant').get('allowed_languages')]
self.language.data = session.get('tenant').get('default_language')
class AddURLForm(FlaskForm): class AddURLForm(FlaskForm):
@@ -28,6 +47,8 @@ class AddURLForm(FlaskForm):
language = SelectField('Language', choices=[], validators=[Optional()]) language = SelectField('Language', choices=[], validators=[Optional()])
user_context = TextAreaField('User Context', validators=[Optional()]) user_context = TextAreaField('User Context', validators=[Optional()])
valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()]) valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()])
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
submit = SubmitField('Submit') submit = SubmitField('Submit')
@@ -35,6 +56,7 @@ class AddURLForm(FlaskForm):
super().__init__() super().__init__()
self.language.choices = [(language, language) for language in self.language.choices = [(language, language) for language in
session.get('tenant').get('allowed_languages')] session.get('tenant').get('allowed_languages')]
self.language.data = session.get('tenant').get('default_language')
class AddURLsForm(FlaskForm): class AddURLsForm(FlaskForm):
@@ -50,21 +72,7 @@ class AddURLsForm(FlaskForm):
super().__init__() super().__init__()
self.language.choices = [(language, language) for language in self.language.choices = [(language, language) for language in
session.get('tenant').get('allowed_languages')] session.get('tenant').get('allowed_languages')]
self.language.data = session.get('tenant').get('default_language')
class AddYoutubeForm(FlaskForm):
url = URLField('Youtube URL', validators=[DataRequired(), URL()])
name = StringField('Name', validators=[Length(max=100)])
language = SelectField('Language', choices=[], validators=[Optional()])
user_context = TextAreaField('User Context', validators=[Optional()])
valid_from = DateField('Valid from', id='form-control datepicker', validators=[Optional()])
submit = SubmitField('Submit')
def __init__(self):
super().__init__()
self.language.choices = [(language, language) for language in
session.get('tenant').get('allowed_languages')]
class EditDocumentForm(FlaskForm): class EditDocumentForm(FlaskForm):
@@ -79,8 +87,7 @@ class EditDocumentVersionForm(FlaskForm):
language = StringField('Language') language = StringField('Language')
user_context = TextAreaField('User Context', validators=[Optional()]) user_context = TextAreaField('User Context', validators=[Optional()])
system_context = TextAreaField('System Context', validators=[Optional()]) system_context = TextAreaField('System Context', validators=[Optional()])
user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json])
system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json])
submit = SubmitField('Submit') submit = SubmitField('Submit')

View File

@@ -1,25 +1,25 @@
import ast import ast
import os
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
import chardet
from flask import request, redirect, flash, render_template, Blueprint, session, current_app from flask import request, redirect, flash, render_template, Blueprint, session, current_app
from flask_security import roles_accepted, current_user from flask_security import roles_accepted, current_user
from sqlalchemy import desc from sqlalchemy import desc
from sqlalchemy.orm import joinedload
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
import requests import requests
from requests.exceptions import SSLError from requests.exceptions import SSLError
from urllib.parse import urlparse from urllib.parse import urlparse, unquote
import io import io
from minio.error import S3Error import json
from common.models.document import Document, DocumentVersion from common.models.document import Document, DocumentVersion
from common.extensions import db, minio_client from common.extensions import db, minio_client
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddYoutubeForm, \ from common.utils.document_utils import validate_file_type, create_document_stack, start_embedding_task, process_url, \
AddURLsForm process_multiple_urls, get_documents_list, edit_document, \
edit_document_version, refresh_document
from common.utils.eveai_exceptions import EveAIInvalidLanguageException, EveAIUnsupportedFileType, \
EveAIDoubleURLException
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm, AddURLsForm
from common.utils.middleware import mw_before_request from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for from common.utils.nginx_utils import prefixed_url_for
@@ -56,30 +56,40 @@ def before_request():
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def add_document(): def add_document():
form = AddDocumentForm() form = AddDocumentForm()
current_app.logger.debug('Adding document')
# If the form is submitted
if form.validate_on_submit(): if form.validate_on_submit():
current_app.logger.info(f'Adding document for tenant {session["tenant"]["id"]}') try:
file = form.file.data current_app.logger.debug('Validating file type')
filename = secure_filename(file.filename) tenant_id = session['tenant']['id']
extension = filename.rsplit('.', 1)[1].lower() file = form.file.data
form_dict = form_to_dict(form) filename = secure_filename(file.filename)
extension = filename.rsplit('.', 1)[1].lower()
new_doc, new_doc_vers = create_document_stack(form_dict, file, filename, extension) validate_file_type(extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ api_input = {
session['tenant']['id'], 'name': form.name.data,
new_doc_vers.id, 'language': form.language.data,
]) 'user_context': form.user_context.data,
current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, ' 'valid_from': form.valid_from.data,
f'Document Version {new_doc_vers.id}. ' 'user_metadata': json.loads(form.user_metadata.data) if form.user_metadata.data else None,
f'Embedding creation task: {task.id}') 'system_metadata': json.loads(form.system_metadata.data) if form.system_metadata.data else None
flash(f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
'success')
return redirect(prefixed_url_for('document_bp.documents')) }
else:
form_validation_failed(request, form) new_doc, new_doc_vers = create_document_stack(api_input, file, filename, extension, tenant_id)
task_id = start_embedding_task(tenant_id, new_doc_vers.id)
flash(f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task_id}.',
'success')
return redirect(prefixed_url_for('document_bp.documents'))
except (EveAIInvalidLanguageException, EveAIUnsupportedFileType) as e:
flash(str(e), 'error')
except Exception as e:
current_app.logger.error(f'Error adding document: {str(e)}')
flash('An error occurred while adding the document.', 'error')
return render_template('document/add_document.html', form=form) return render_template('document/add_document.html', form=form)
@@ -89,45 +99,37 @@ def add_document():
def add_url(): def add_url():
form = AddURLForm() form = AddURLForm()
# If the form is submitted
if form.validate_on_submit(): if form.validate_on_submit():
current_app.logger.info(f'Adding url for tenant {session["tenant"]["id"]}') try:
url = form.url.data tenant_id = session['tenant']['id']
url = form.url.data
doc_vers = DocumentVersion.query.filter_by(url=url).all() file_content, filename, extension = process_url(url, tenant_id)
if doc_vers:
current_app.logger.info(f'A document with url {url} already exists. No new document created.') api_input = {
flash(f'A document with url {url} already exists. No new document created.', 'info') 'name': form.name.data or filename,
'url': url,
'language': form.language.data,
'user_context': form.user_context.data,
'valid_from': form.valid_from.data,
'user_metadata': json.loads(form.user_metadata.data) if form.user_metadata.data else None,
'system_metadata': json.loads(form.system_metadata.data) if form.system_metadata.data else None
}
new_doc, new_doc_vers = create_document_stack(api_input, file_content, filename, extension, tenant_id)
task_id = start_embedding_task(tenant_id, new_doc_vers.id)
flash(f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task_id}.',
'success')
return redirect(prefixed_url_for('document_bp.documents')) return redirect(prefixed_url_for('document_bp.documents'))
# Only when no document with URL exists
html = fetch_html(url)
file = io.BytesIO(html)
parsed_url = urlparse(url) except EveAIDoubleURLException:
path_parts = parsed_url.path.split('/') flash(f'A document with url {url} already exists. No new document created.', 'info')
filename = path_parts[-1] except (EveAIInvalidLanguageException, EveAIUnsupportedFileType) as e:
if filename == '': flash(str(e), 'error')
filename = 'index' except Exception as e:
if not filename.endswith('.html'): current_app.logger.error(f'Error adding document: {str(e)}')
filename += '.html' flash('An error occurred while adding the document.', 'error')
extension = 'html'
form_dict = form_to_dict(form)
new_doc, new_doc_vers = create_document_stack(form_dict, file, filename, extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'],
new_doc_vers.id,
])
current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}. '
f'Embedding creation task: {task.id}')
flash(f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
'success')
return redirect(prefixed_url_for('document_bp.documents'))
else:
form_validation_failed(request, form)
return render_template('document/add_url.html', form=form) return render_template('document/add_url.html', form=form)
@@ -138,100 +140,36 @@ def add_urls():
form = AddURLsForm() form = AddURLsForm()
if form.validate_on_submit(): if form.validate_on_submit():
urls = form.urls.data.split('\n') try:
urls = [url.strip() for url in urls if url.strip()] tenant_id = session['tenant']['id']
urls = form.urls.data.split('\n')
urls = [url.strip() for url in urls if url.strip()]
for i, url in enumerate(urls): api_input = {
try: 'name': form.name.data,
doc_vers = DocumentVersion.query.filter_by(url=url).all() 'language': form.language.data,
if doc_vers: 'user_context': form.user_context.data,
current_app.logger.info(f'A document with url {url} already exists. No new document created.') 'valid_from': form.valid_from.data
flash(f'A document with url {url} already exists. No new document created.', 'info') }
continue
html = fetch_html(url) results = process_multiple_urls(urls, tenant_id, api_input)
file = io.BytesIO(html)
parsed_url = urlparse(url) for result in results:
path_parts = parsed_url.path.split('/') if result['status'] == 'success':
filename = path_parts[-1] if path_parts[-1] else 'index' flash(
if not filename.endswith('.html'): f"Processed URL: {result['url']} - Document ID: {result['document_id']}, Version ID: {result['document_version_id']}",
filename += '.html' 'success')
else:
flash(f"Error processing URL: {result['url']} - {result['message']}", 'error')
# Use the name prefix if provided, otherwise use the filename return redirect(prefixed_url_for('document_bp.documents'))
doc_name = f"{form.name.data}-{filename}" if form.name.data else filename
new_doc, new_doc_vers = create_document_stack({ except Exception as e:
'name': doc_name, current_app.logger.error(f'Error adding multiple URLs: {str(e)}')
'url': url, flash('An error occurred while adding the URLs.', 'error')
'language': form.language.data,
'user_context': form.user_context.data,
'valid_from': form.valid_from.data
}, file, filename, 'html')
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'],
new_doc_vers.id,
])
current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}. '
f'Embedding creation task: {task.id}')
flash(f'Processing on document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
'success')
except Exception as e:
current_app.logger.error(f"Error processing URL {url}: {str(e)}")
flash(f'Error processing URL {url}: {str(e)}', 'danger')
return redirect(prefixed_url_for('document_bp.documents'))
else:
form_validation_failed(request, form)
return render_template('document/add_urls.html', form=form) return render_template('document/add_urls.html', form=form)
@document_bp.route('/add_youtube', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def add_youtube():
form = AddYoutubeForm()
if form.validate_on_submit():
current_app.logger.info(f'Adding Youtube document for tenant {session["tenant"]["id"]}')
url = form.url.data
current_app.logger.debug(f'Value of language field: {form.language.data}')
doc_vers = DocumentVersion.query.filter_by(url=url).all()
if doc_vers:
current_app.logger.info(f'A document with url {url} already exists. No new document created.')
flash(f'A document with url {url} already exists. No new document created.', 'info')
return redirect(prefixed_url_for('document_bp.documents'))
# As downloading a Youtube document can take quite some time, we offload this downloading to the worker
# We just pass a simple file to get things conform
file = "Youtube placeholder file"
filename = 'placeholder.youtube'
extension = 'youtube'
form_dict = form_to_dict(form)
current_app.logger.debug(f'Form data: {form_dict}')
new_doc, new_doc_vers = create_document_stack(form_dict, file, filename, extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'],
new_doc_vers.id,
])
current_app.logger.info(f'Processing and Embedding on Youtube document started for tenant '
f'{session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}. '
f'Processing and Embedding Youtube task: {task.id}')
flash(f'Processing on Youtube document {new_doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
'success')
return redirect(prefixed_url_for('document_bp.documents'))
else:
form_validation_failed(request, form)
return render_template('document/add_youtube.html', form=form)
@document_bp.route('/documents', methods=['GET', 'POST']) @document_bp.route('/documents', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
@@ -239,9 +177,7 @@ def documents():
page = request.args.get('page', 1, type=int) page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int) per_page = request.args.get('per_page', 10, type=int)
query = Document.query.order_by(desc(Document.created_at)) pagination = get_documents_list(page, per_page)
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
docs = pagination.items docs = pagination.items
rows = prepare_table_for_macro(docs, [('id', ''), ('name', ''), ('valid_from', ''), ('valid_to', '')]) rows = prepare_table_for_macro(docs, [('id', ''), ('name', ''), ('valid_from', ''), ('valid_to', '')])
@@ -259,11 +195,11 @@ def handle_document_selection():
match action: match action:
case 'edit_document': case 'edit_document':
return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.edit_document_view', document_id=doc_id))
case 'document_versions': case 'document_versions':
return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id))
case 'refresh_document': case 'refresh_document':
refresh_document(doc_id) refresh_document_view(doc_id)
return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id))
case 're_embed_latest_versions': case 're_embed_latest_versions':
re_embed_latest_versions() re_embed_latest_versions()
@@ -274,25 +210,22 @@ def handle_document_selection():
@document_bp.route('/edit_document/<int:document_id>', methods=['GET', 'POST']) @document_bp.route('/edit_document/<int:document_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def edit_document(document_id): def edit_document_view(document_id):
doc = Document.query.get_or_404(document_id) doc = Document.query.get_or_404(document_id)
form = EditDocumentForm(obj=doc) form = EditDocumentForm(obj=doc)
if form.validate_on_submit(): if form.validate_on_submit():
doc.name = form.name.data updated_doc, error = edit_document(
doc.valid_from = form.valid_from.data document_id,
doc.valid_to = form.valid_to.data form.name.data,
form.valid_from.data,
update_logging_information(doc, dt.now(tz.utc)) form.valid_to.data
)
try: if updated_doc:
db.session.add(doc) flash(f'Document {updated_doc.id} updated successfully', 'success')
db.session.commit() return redirect(prefixed_url_for('document_bp.documents'))
flash(f'Document {doc.id} updated successfully', 'success') else:
except SQLAlchemyError as e: flash(f'Error updating document: {error}', 'danger')
db.session.rollback()
flash(f'Error updating document: {e}', 'danger')
current_app.logger.error(f'Error updating document: {e}')
else: else:
form_validation_failed(request, form) form_validation_failed(request, form)
@@ -301,24 +234,20 @@ def edit_document(document_id):
@document_bp.route('/edit_document_version/<int:document_version_id>', methods=['GET', 'POST']) @document_bp.route('/edit_document_version/<int:document_version_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def edit_document_version(document_version_id): def edit_document_version_view(document_version_id):
doc_vers = DocumentVersion.query.get_or_404(document_version_id) doc_vers = DocumentVersion.query.get_or_404(document_version_id)
form = EditDocumentVersionForm(obj=doc_vers) form = EditDocumentVersionForm(obj=doc_vers)
if form.validate_on_submit(): if form.validate_on_submit():
doc_vers.user_context = form.user_context.data updated_version, error = edit_document_version(
document_version_id,
update_logging_information(doc_vers, dt.now(tz.utc)) form.user_context.data
)
try: if updated_version:
db.session.add(doc_vers) flash(f'Document Version {updated_version.id} updated successfully', 'success')
db.session.commit() return redirect(prefixed_url_for('document_bp.document_versions', document_id=updated_version.doc_id))
flash(f'Document Version {doc_vers.id} updated successfully', 'success') else:
except SQLAlchemyError as e: flash(f'Error updating document version: {error}', 'danger')
db.session.rollback()
flash(f'Error updating document version: {e}', 'danger')
current_app.logger.error(f'Error updating document version {doc_vers.id} '
f'for tenant {session['tenant']['id']}: {e}')
else: else:
form_validation_failed(request, form) form_validation_failed(request, form)
@@ -329,8 +258,8 @@ def edit_document_version(document_version_id):
@document_bp.route('/document_versions/<int:document_id>', methods=['GET', 'POST']) @document_bp.route('/document_versions/<int:document_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def document_versions(document_id): def document_versions(document_id):
doc_vers = DocumentVersion.query.get_or_404(document_id) doc = Document.query.get_or_404(document_id)
doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}' doc_desc = f'Document {doc.name}'
page = request.args.get('page', 1, type=int) page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int) per_page = request.args.get('per_page', 10, type=int)
@@ -358,9 +287,11 @@ def handle_document_version_selection():
action = request.form['action'] action = request.form['action']
current_app.logger.debug(f'Triggered Document Version Action: {action}')
match action: match action:
case 'edit_document_version': case 'edit_document_version':
return redirect(prefixed_url_for('document_bp.edit_document_version', document_version_id=doc_vers_id)) return redirect(prefixed_url_for('document_bp.edit_document_version_view', document_version_id=doc_vers_id))
case 'process_document_version': case 'process_document_version':
process_version(doc_vers_id) process_version(doc_vers_id)
# Add more conditions for other actions # Add more conditions for other actions
@@ -403,55 +334,13 @@ def refresh_all_documents():
refresh_document(doc.id) refresh_document(doc.id)
def refresh_document(doc_id): def refresh_document_view(document_id):
doc = Document.query.get_or_404(doc_id) new_version, result = refresh_document(document_id)
doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first() if new_version:
if not doc_vers.url: flash(f'Document refreshed. New version: {new_version.id}. Task ID: {result}', 'success')
current_app.logger.info(f'Document {doc_id} has no URL, skipping refresh') else:
flash(f'This document has no URL. I can only refresh documents with a URL. skipping refresh', 'alert') flash(f'Error refreshing document: {result}', 'danger')
return return redirect(prefixed_url_for('document_bp.documents'))
new_doc_vers = create_version_for_document(doc, doc_vers.url, doc_vers.language, doc_vers.user_context)
try:
db.session.add(new_doc_vers)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'Error refreshing document {doc_id} for tenant {session["tenant"]["id"]}: {e}')
flash('Error refreshing document.', 'alert')
db.session.rollback()
error = e.args
raise
except Exception as e:
current_app.logger.error('Unknown error')
raise
html = fetch_html(new_doc_vers.url)
file = io.BytesIO(html)
parsed_url = urlparse(new_doc_vers.url)
path_parts = parsed_url.path.split('/')
filename = path_parts[-1]
if filename == '':
filename = 'index'
if not filename.endswith('.html'):
filename += '.html'
extension = 'html'
current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}')
upload_file_for_version(new_doc_vers, file, extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'],
new_doc_vers.id,
])
current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc_vers.id}. '
f'Embedding creation task: {task.id}')
flash(f'Processing on document {doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
'success')
def re_embed_latest_versions(): def re_embed_latest_versions():
@@ -489,116 +378,11 @@ def update_logging_information(obj, timestamp):
obj.updated_by = current_user.id obj.updated_by = current_user.id
def create_document_stack(form, file, filename, extension):
# Create the Document
new_doc = create_document(form, filename)
# Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc,
form.get('url', ''),
form.get('language', 'en'),
form.get('user_context', '')
)
try:
db.session.add(new_doc)
db.session.add(new_doc_vers)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {e}')
flash('Error adding document.', 'alert')
db.session.rollback()
error = e.args
raise
except Exception as e:
current_app.logger.error('Unknown error')
raise
current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc.id}')
upload_file_for_version(new_doc_vers, file, extension)
return new_doc, new_doc_vers
def log_session_state(session, msg=""): def log_session_state(session, msg=""):
current_app.logger.debug(f"{msg} - Session dirty: {session.dirty}") current_app.logger.debug(f"{msg} - Session dirty: {session.dirty}")
current_app.logger.debug(f"{msg} - Session new: {session.new}") current_app.logger.debug(f"{msg} - Session new: {session.new}")
def create_document(form, filename):
new_doc = Document()
if form['name'] == '':
new_doc.name = filename.rsplit('.', 1)[0]
else:
new_doc.name = form['name']
if form['valid_from'] and form['valid_from'] != '':
new_doc.valid_from = form['valid_from']
else:
new_doc.valid_from = dt.now(tz.utc)
new_doc.tenant_id = session['tenant']['id']
set_logging_information(new_doc, dt.now(tz.utc))
return new_doc
def create_version_for_document(document, url, language, user_context):
new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
if language == '':
new_doc_vers.language = session['default_language']
else:
new_doc_vers.language = language
if user_context != '':
new_doc_vers.user_context = user_context
new_doc_vers.document = document
set_logging_information(new_doc_vers, dt.now(tz.utc))
return new_doc_vers
def upload_file_for_version(doc_vers, file, extension):
doc_vers.file_type = extension
doc_vers.file_name = doc_vers.calc_file_name()
doc_vers.file_location = doc_vers.calc_file_location()
# Normally, the tenant bucket should exist. But let's be on the safe side if a migration took place.
tenant_id = session['tenant']['id']
minio_client.create_tenant_bucket(tenant_id)
try:
minio_client.upload_document_file(
tenant_id,
doc_vers.doc_id,
doc_vers.language,
doc_vers.id,
doc_vers.file_name,
file
)
db.session.commit()
current_app.logger.info(f'Successfully saved document to MinIO for tenant {tenant_id} for '
f'document version {doc_vers.id} while uploading file.')
except S3Error as e:
db.session.rollback()
flash('Error saving document to MinIO.', 'error')
current_app.logger.error(
f'Error saving document to MinIO for tenant {tenant_id}: {e}')
raise
except SQLAlchemyError as e:
db.session.rollback()
flash('Error saving document metadata.', 'error')
current_app.logger.error(
f'Error saving document metadata for tenant {tenant_id}: {e}')
raise
def fetch_html(url): def fetch_html(url):
# Fetches HTML content from a URL # Fetches HTML content from a URL
try: try:

View File

@@ -15,7 +15,8 @@ from requests.exceptions import SSLError
from urllib.parse import urlparse from urllib.parse import urlparse
import io import io
from common.models.interaction import ChatSession, Interaction from common.models.document import Embedding, DocumentVersion
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db from common.extensions import db
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm
from common.utils.middleware import mw_before_request from common.utils.middleware import mw_before_request
@@ -80,11 +81,34 @@ def handle_chat_session_selection():
return redirect(prefixed_url_for('interaction_bp.chat_sessions')) return redirect(prefixed_url_for('interaction_bp.chat_sessions'))
@interaction_bp.route('/view_chat_session/<chat_session_id>', methods=['GET']) @interaction_bp.route('/view_chat_session/<int:chat_session_id>', methods=['GET'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def view_chat_session(chat_session_id): def view_chat_session(chat_session_id):
chat_session = ChatSession.query.get_or_404(chat_session_id) chat_session = ChatSession.query.get_or_404(chat_session_id)
show_chat_session(chat_session) interactions = (Interaction.query
.filter_by(chat_session_id=chat_session.id)
.order_by(Interaction.question_at)
.all())
# Fetch all related embeddings for the interactions in this session
embedding_query = (db.session.query(InteractionEmbedding.interaction_id,
DocumentVersion.url,
DocumentVersion.file_name)
.join(Embedding, InteractionEmbedding.embedding_id == Embedding.id)
.join(DocumentVersion, Embedding.doc_vers_id == DocumentVersion.id)
.filter(InteractionEmbedding.interaction_id.in_([i.id for i in interactions])))
# Create a dictionary to store embeddings for each interaction
embeddings_dict = {}
for interaction_id, url, file_name in embedding_query:
if interaction_id not in embeddings_dict:
embeddings_dict[interaction_id] = []
embeddings_dict[interaction_id].append({'url': url, 'file_name': file_name})
return render_template('interaction/view_chat_session.html',
chat_session=chat_session,
interactions=interactions,
embeddings_dict=embeddings_dict)
@interaction_bp.route('/view_chat_session_by_session_id/<session_id>', methods=['GET']) @interaction_bp.route('/view_chat_session_by_session_id/<session_id>', methods=['GET'])

View File

@@ -32,6 +32,7 @@ class TenantForm(FlaskForm):
default='p, li') default='p, li')
html_included_elements = StringField('HTML Included Elements', validators=[Optional()]) html_included_elements = StringField('HTML Included Elements', validators=[Optional()])
html_excluded_elements = StringField('HTML Excluded Elements', validators=[Optional()]) html_excluded_elements = StringField('HTML Excluded Elements', validators=[Optional()])
html_excluded_classes = StringField('HTML Excluded Classes', validators=[Optional()])
min_chunk_size = IntegerField('Minimum Chunk Size (2000)', validators=[NumberRange(min=0), Optional()], default=2000) min_chunk_size = IntegerField('Minimum Chunk Size (2000)', validators=[NumberRange(min=0), Optional()], default=2000)
max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], default=3000) max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], default=3000)
# Embedding Search variables # Embedding Search variables

View File

@@ -68,6 +68,8 @@ def tenant():
if form.html_included_elements.data else [] if form.html_included_elements.data else []
new_tenant.html_excluded_elements = [tag.strip() for tag in form.html_excluded_elements.data.split(',')] \ new_tenant.html_excluded_elements = [tag.strip() for tag in form.html_excluded_elements.data.split(',')] \
if form.html_excluded_elements.data else [] if form.html_excluded_elements.data else []
new_tenant.html_excluded_classes = [cls.strip() for cls in form.html_excluded_classes.data.split(',')] \
if form.html_excluded_classes.data else []
current_app.logger.debug(f'html_tags: {new_tenant.html_tags},' current_app.logger.debug(f'html_tags: {new_tenant.html_tags},'
f'html_end_tags: {new_tenant.html_end_tags},' f'html_end_tags: {new_tenant.html_end_tags},'
@@ -123,6 +125,8 @@ def edit_tenant(tenant_id):
form.html_included_elements.data = ', '.join(tenant.html_included_elements) form.html_included_elements.data = ', '.join(tenant.html_included_elements)
if tenant.html_excluded_elements: if tenant.html_excluded_elements:
form.html_excluded_elements.data = ', '.join(tenant.html_excluded_elements) form.html_excluded_elements.data = ', '.join(tenant.html_excluded_elements)
if tenant.html_excluded_classes:
form.html_excluded_classes.data = ', '.join(tenant.html_excluded_classes)
if form.validate_on_submit(): if form.validate_on_submit():
# Populate the tenant with form data # Populate the tenant with form data
@@ -134,6 +138,8 @@ def edit_tenant(tenant_id):
elem.strip()] elem.strip()]
tenant.html_excluded_elements = [elem.strip() for elem in form.html_excluded_elements.data.split(',') if tenant.html_excluded_elements = [elem.strip() for elem in form.html_excluded_elements.data.split(',') if
elem.strip()] elem.strip()]
tenant.html_excluded_classes = [elem.strip() for elem in form.html_excluded_classes.data.split(',') if
elem.strip()]
db.session.commit() db.session.commit()
flash('Tenant updated successfully.', 'success') flash('Tenant updated successfully.', 'success')
@@ -429,6 +435,36 @@ def generate_chat_api_key():
tenant.encrypted_chat_api_key = simple_encryption.encrypt_api_key(new_api_key) tenant.encrypted_chat_api_key = simple_encryption.encrypt_api_key(new_api_key)
update_logging_information(tenant, dt.now(tz.utc)) update_logging_information(tenant, dt.now(tz.utc))
try:
db.session.add(tenant)
db.session.commit()
except SQLAlchemyError as e:
db.session.rollback()
current_app.logger.error(f'Unable to store chat api key for tenant {tenant.id}. Error: {str(e)}')
return jsonify({'api_key': new_api_key}), 200
@user_bp.route('/check_api_api_key', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def check_api_api_key():
tenant_id = session['tenant']['id']
tenant = Tenant.query.get_or_404(tenant_id)
if tenant.encrypted_api_key:
return jsonify({'api_key_exists': True})
return jsonify({'api_key_exists': False})
@user_bp.route('/generate_api_api_key', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def generate_api_api_key():
tenant = Tenant.query.get_or_404(session['tenant']['id'])
new_api_key = generate_api_key(prefix="EveAI-API")
tenant.encrypted_api_key = simple_encryption.encrypt_api_key(new_api_key)
update_logging_information(tenant, dt.now(tz.utc))
try: try:
db.session.add(tenant) db.session.add(tenant)
db.session.commit() db.session.commit()

View File

@@ -0,0 +1,187 @@
import io
import os
from pydub import AudioSegment
import tempfile
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from common.extensions import minio_client
from common.utils.model_utils import create_language_template
from .processor import Processor
import subprocess
class AudioProcessor(Processor):
def __init__(self, tenant, model_variables, document_version):
super().__init__(tenant, model_variables, document_version)
self.transcription_client = model_variables['transcription_client']
self.transcription_model = model_variables['transcription_model']
self.ffmpeg_path = 'ffmpeg'
def process(self):
self._log("Starting Audio processing")
try:
file_data = minio_client.download_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
self.document_version.file_name
)
compressed_audio = self._compress_audio(file_data)
transcription = self._transcribe_audio(compressed_audio)
markdown, title = self._generate_markdown_from_transcription(transcription)
self._save_markdown(markdown)
self._log("Finished processing Audio")
return markdown, title
except Exception as e:
self._log(f"Error processing Audio: {str(e)}", level='error')
raise
def _compress_audio(self, audio_data):
self._log("Compressing audio")
with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{self.document_version.file_type}') as temp_input:
temp_input.write(audio_data)
temp_input.flush()
# Use a unique filename for the output to avoid conflicts
output_filename = f'compressed_{os.urandom(8).hex()}.mp3'
output_path = os.path.join(tempfile.gettempdir(), output_filename)
try:
result = subprocess.run(
[self.ffmpeg_path, '-y', '-i', temp_input.name, '-b:a', '64k', '-f', 'mp3', output_path],
capture_output=True,
text=True,
check=True
)
with open(output_path, 'rb') as f:
compressed_data = f.read()
# Save compressed audio to MinIO
compressed_filename = f"{self.document_version.id}_compressed.mp3"
minio_client.upload_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
compressed_filename,
compressed_data
)
self._log(f"Saved compressed audio to MinIO: {compressed_filename}")
return compressed_data
except subprocess.CalledProcessError as e:
error_message = f"Compression failed: {e.stderr}"
self._log(error_message, level='error')
raise Exception(error_message)
finally:
# Clean up temporary files
os.unlink(temp_input.name)
if os.path.exists(output_path):
os.unlink(output_path)
def _transcribe_audio(self, audio_data):
self._log("Starting audio transcription")
audio = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3")
segment_length = 10 * 60 * 1000 # 10 minutes in milliseconds
transcriptions = []
for i, chunk in enumerate(audio[::segment_length]):
self._log(f'Processing chunk {i + 1} of {len(audio) // segment_length + 1}')
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
chunk.export(temp_audio.name, format="mp3")
temp_audio.flush()
try:
file_size = os.path.getsize(temp_audio.name)
self._log(f"Temporary audio file size: {file_size} bytes")
with open(temp_audio.name, 'rb') as audio_file:
file_start = audio_file.read(100)
self._log(f"First 100 bytes of audio file: {file_start}")
audio_file.seek(0) # Reset file pointer to the beginning
self._log("Calling transcription API")
transcription = self.transcription_client.audio.transcriptions.create(
file=audio_file,
model=self.transcription_model,
language=self.document_version.language,
response_format='verbose_json',
)
self._log("Transcription API call completed")
if transcription:
# Handle the transcription result based on its type
if isinstance(transcription, str):
self._log(f"Transcription result (string): {transcription[:100]}...")
transcriptions.append(transcription)
elif hasattr(transcription, 'text'):
self._log(
f"Transcription result (object with 'text' attribute): {transcription.text[:100]}...")
transcriptions.append(transcription.text)
else:
self._log(f"Transcription result (unknown type): {str(transcription)[:100]}...")
transcriptions.append(str(transcription))
else:
self._log("Warning: Received empty transcription", level='warning')
except Exception as e:
self._log(f"Error during transcription: {str(e)}", level='error')
finally:
os.unlink(temp_audio.name)
full_transcription = " ".join(filter(None, transcriptions))
if not full_transcription:
self._log("Warning: No transcription was generated", level='warning')
full_transcription = "No transcription available."
# Save transcription to MinIO
transcription_filename = f"{self.document_version.id}_transcription.txt"
minio_client.upload_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
transcription_filename,
full_transcription.encode('utf-8')
)
self._log(f"Saved transcription to MinIO: {transcription_filename}")
return full_transcription
def _generate_markdown_from_transcription(self, transcription):
self._log("Generating markdown from transcription")
llm = self.model_variables['llm']
template = self.model_variables['transcript_template']
language_template = create_language_template(template, self.document_version.language)
transcript_prompt = ChatPromptTemplate.from_template(language_template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | transcript_prompt | llm | output_parser
input_transcript = {'transcript': transcription}
markdown = chain.invoke(input_transcript)
# Extract title from the markdown
title = self._extract_title_from_markdown(markdown)
return markdown, title
def _extract_title_from_markdown(self, markdown):
# Simple extraction of the first header as the title
lines = markdown.split('\n')
for line in lines:
if line.startswith('# '):
return line[2:].strip()
return "Untitled Audio Transcription"

View File

@@ -0,0 +1,143 @@
from bs4 import BeautifulSoup
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from common.extensions import db, minio_client
from common.utils.model_utils import create_language_template
from .processor import Processor
class HTMLProcessor(Processor):
def __init__(self, tenant, model_variables, document_version):
super().__init__(tenant, model_variables, document_version)
self.html_tags = model_variables['html_tags']
self.html_end_tags = model_variables['html_end_tags']
self.html_included_elements = model_variables['html_included_elements']
self.html_excluded_elements = model_variables['html_excluded_elements']
def process(self):
self._log("Starting HTML processing")
try:
file_data = minio_client.download_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
self.document_version.file_name
)
html_content = file_data.decode('utf-8')
extracted_html, title = self._parse_html(html_content)
markdown = self._generate_markdown_from_html(extracted_html)
self._save_markdown(markdown)
self._log("Finished processing HTML")
return markdown, title
except Exception as e:
self._log(f"Error processing HTML: {str(e)}", level='error')
raise
def _parse_html(self, html_content):
self._log(f'Parsing HTML for tenant {self.tenant.id}')
soup = BeautifulSoup(html_content, 'html.parser')
extracted_html = ''
excluded_classes = self._parse_excluded_classes(self.tenant.html_excluded_classes)
if self.html_included_elements:
elements_to_parse = soup.find_all(self.html_included_elements)
else:
elements_to_parse = [soup]
for element in elements_to_parse:
for sub_element in element.find_all(self.html_tags):
if self._should_exclude_element(sub_element, excluded_classes):
continue
extracted_html += self._extract_element_content(sub_element)
title = soup.find('title').get_text(strip=True) if soup.find('title') else ''
self._log(f'Finished parsing HTML for tenant {self.tenant.id}')
return extracted_html, title
def _generate_markdown_from_html(self, html_content):
self._log(f'Generating markdown from HTML for tenant {self.tenant.id}')
llm = self.model_variables['llm']
template = self.model_variables['html_parse_template']
parse_prompt = ChatPromptTemplate.from_template(template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | parse_prompt | llm | output_parser
soup = BeautifulSoup(html_content, 'lxml')
chunks = self._split_content(soup)
markdown_chunks = []
for chunk in chunks:
if self.embed_tuning:
self._log(f'Processing chunk: \n{chunk}\n')
input_html = {"html": chunk}
markdown_chunk = chain.invoke(input_html)
markdown_chunks.append(markdown_chunk)
if self.embed_tuning:
self._log(f'Processed markdown chunk: \n{markdown_chunk}\n')
markdown = "\n\n".join(markdown_chunks)
self._log(f'Finished generating markdown from HTML for tenant {self.tenant.id}')
return markdown
def _split_content(self, soup, max_size=20000):
chunks = []
current_chunk = []
current_size = 0
for element in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p', 'div', 'span', 'table']):
element_html = str(element)
element_size = len(element_html)
if current_size + element_size > max_size and current_chunk:
chunks.append(''.join(map(str, current_chunk)))
current_chunk = []
current_size = 0
current_chunk.append(element)
current_size += element_size
if element.name in ['h1', 'h2', 'h3'] and current_size > max_size:
chunks.append(''.join(map(str, current_chunk)))
current_chunk = []
current_size = 0
if current_chunk:
chunks.append(''.join(map(str, current_chunk)))
return chunks
def _parse_excluded_classes(self, excluded_classes):
parsed = {}
if excluded_classes:
for rule in excluded_classes:
element, cls = rule.split('.', 1)
parsed.setdefault(element, set()).add(cls)
return parsed
def _should_exclude_element(self, element, excluded_classes):
if self.html_excluded_elements and element.find_parent(self.html_excluded_elements):
return True
return self._is_element_excluded_by_class(element, excluded_classes)
def _is_element_excluded_by_class(self, element, excluded_classes):
for parent in element.parents:
if self._element_matches_exclusion(parent, excluded_classes):
return True
return self._element_matches_exclusion(element, excluded_classes)
def _element_matches_exclusion(self, element, excluded_classes):
if '*' in excluded_classes and any(cls in excluded_classes['*'] for cls in element.get('class', [])):
return True
return element.name in excluded_classes and \
any(cls in excluded_classes[element.name] for cls in element.get('class', []))
def _extract_element_content(self, element):
content = ' '.join(child.strip() for child in element.stripped_strings)
return f'<{element.name}>{content}</{element.name}>\n'

View File

@@ -0,0 +1,239 @@
import io
import pdfplumber
from flask import current_app
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
import re
from langchain_core.runnables import RunnablePassthrough
from common.extensions import minio_client
from common.utils.model_utils import create_language_template
from .processor import Processor
class PDFProcessor(Processor):
def __init__(self, tenant, model_variables, document_version):
super().__init__(tenant, model_variables, document_version)
# PDF-specific initialization
self.chunk_size = model_variables['PDF_chunk_size']
self.chunk_overlap = model_variables['PDF_chunk_overlap']
self.min_chunk_size = model_variables['PDF_min_chunk_size']
self.max_chunk_size = model_variables['PDF_max_chunk_size']
def process(self):
self._log("Starting PDF processing")
try:
file_data = minio_client.download_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
self.document_version.file_name
)
extracted_content = self._extract_content(file_data)
structured_content, title = self._structure_content(extracted_content)
llm_chunks = self._split_content_for_llm(structured_content)
markdown = self._process_chunks_with_llm(llm_chunks)
self._save_markdown(markdown)
self._log("Finished processing PDF")
return markdown, title
except Exception as e:
self._log(f"Error processing PDF: {str(e)}", level='error')
raise
def _extract_content(self, file_data):
extracted_content = []
with pdfplumber.open(io.BytesIO(file_data)) as pdf:
figure_counter = 1
for page_num, page in enumerate(pdf.pages):
self._log(f"Extracting content from page {page_num + 1}")
page_content = {
'text': page.extract_text(),
'figures': self._extract_figures(page, page_num, figure_counter),
'tables': self._extract_tables(page)
}
if self.embed_tuning:
self._log(f'Extracted PDF Content for page {page_num + 1}')
self._log(f"{page_content }")
figure_counter += len(page_content['figures'])
extracted_content.append(page_content)
# if self.embed_tuning:
# current_app.embed_tuning_logger.debug(f'Extracted PDF Content')
# current_app.embed_tuning_logger.debug(f'---------------------')
# current_app.embed_tuning_logger.debug(f'Page: {page_content}')
# current_app.embed_tuning_logger.debug(f'End of Extracted PDF Content')
# current_app.embed_tuning_logger.debug(f'----------------------------')
return extracted_content
def _extract_figures(self, page, page_num, figure_counter):
figures = []
# Omit figure processing for now!
# for img in page.images:
# try:
# # Try to get the bbox, use full page dimensions if not available
# bbox = img.get('bbox', (0, 0, page.width, page.height))
#
# figure = {
# 'figure_number': figure_counter,
# 'filename': f"figure_{page_num + 1}_{figure_counter}.png",
# 'caption': self._find_figure_caption(page, bbox)
# }
#
# # Extract the figure as an image
# figure_image = page.within_bbox(bbox).to_image()
#
# # Save the figure using MinIO
# with io.BytesIO() as output:
# figure_image.save(output, format='PNG')
# output.seek(0)
# minio_client.upload_document_file(
# self.tenant.id,
# self.document_version.doc_id,
# self.document_version.language,
# self.document_version.id,
# figure['filename'],
# output.getvalue()
# )
#
# figures.append(figure)
# figure_counter += 1
# except Exception as e:
# self._log(f"Error processing figure on page {page_num + 1}: {str(e)}", level='error')
return figures
def _find_figure_caption(self, page, bbox):
try:
# Look for text below the figure
caption_bbox = (bbox[0], bbox[3], bbox[2], min(bbox[3] + 50, page.height))
caption_text = page.crop(caption_bbox).extract_text()
if caption_text and caption_text.lower().startswith('figure'):
return caption_text
except Exception as e:
self._log(f"Error finding figure caption: {str(e)}", level='error')
return None
def _extract_tables(self, page):
tables = []
try:
for table in page.extract_tables():
if table:
markdown_table = self._table_to_markdown(table)
if markdown_table: # Only add non-empty tables
tables.append(markdown_table)
except Exception as e:
self._log(f"Error extracting tables from page: {str(e)}", level='error')
return tables
def _table_to_markdown(self, table):
if not table or not table[0]: # Check if table is empty or first row is empty
return "" # Return empty string for empty tables
def clean_cell(cell):
if cell is None:
return "" # Convert None to empty string
return str(cell).replace("|", "\\|") # Escape pipe characters and convert to string
header = [clean_cell(cell) for cell in table[0]]
markdown = "| " + " | ".join(header) + " |\n"
markdown += "| " + " | ".join(["---"] * len(header)) + " |\n"
for row in table[1:]:
cleaned_row = [clean_cell(cell) for cell in row]
markdown += "| " + " | ".join(cleaned_row) + " |\n"
return markdown
def _structure_content(self, extracted_content):
structured_content = ""
title = "Untitled Document"
current_heading_level = 0
heading_pattern = re.compile(r'^(\d+(\.\d+)*\.?\s*)?(.+)$')
def identify_heading(text):
match = heading_pattern.match(text.strip())
if match:
numbering, _, content = match.groups()
if numbering:
level = numbering.count('.') + 1
return level, f"{numbering}{content}"
else:
return 1, content # Assume it's a top-level heading if no numbering
return 0, text # Not a heading
for page in extracted_content:
# Assume the title is on the first page
if page == extracted_content[0]:
lines = page.get('text', '').split('\n')
if lines:
title = lines[0].strip() # Use the first non-empty line as the title
# Process text
paragraphs = page['text'].split('\n\n')
for para in paragraphs:
lines = para.strip().split('\n')
if len(lines) == 1: # Potential heading
level, text = identify_heading(lines[0])
if level > 0:
heading_marks = '#' * level
structured_content += f"\n\n{heading_marks} {text}\n\n"
if level == 1 and not title:
title = text # Use the first top-level heading as the title if not set
else:
structured_content += f"{para}\n\n" # Treat as normal paragraph
else:
structured_content += f"{para}\n\n" # Multi-line paragraph
# Process figures
for figure in page.get('figures', []):
structured_content += f"\n\n![Figure {figure['figure_number']}]({figure['filename']})\n\n"
if figure['caption']:
structured_content += f"*Figure {figure['figure_number']}: {figure['caption']}*\n\n"
# Add tables
if 'tables' in page:
for table in page['tables']:
structured_content += f"\n{table}\n"
if self.embed_tuning:
self._save_intermediate(structured_content, "structured_content.md")
return structured_content, title
def _split_content_for_llm(self, content):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
length_function=len,
separators=["\n\n", "\n", " ", ""]
)
return text_splitter.split_text(content)
def _process_chunks_with_llm(self, chunks):
llm = self.model_variables['llm']
template = self.model_variables['pdf_parse_template']
pdf_prompt = ChatPromptTemplate.from_template(template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | pdf_prompt | llm | output_parser
markdown_chunks = []
for chunk in chunks:
input = {"pdf_content": chunk}
result = chain.invoke(input)
# Remove Markdown code block delimiters if present
result = result.strip()
if result.startswith("```markdown"):
result = result[len("```markdown"):].strip()
if result.endswith("```"):
result = result[:-3].strip()
markdown_chunks.append(result)
return "\n\n".join(markdown_chunks)

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from flask import current_app
from common.extensions import minio_client
class Processor(ABC):
def __init__(self, tenant, model_variables, document_version):
self.tenant = tenant
self.model_variables = model_variables
self.document_version = document_version
self.embed_tuning = model_variables['embed_tuning']
@abstractmethod
def process(self):
pass
def _save_markdown(self, markdown):
markdown_filename = f"{self.document_version.id}.md"
minio_client.upload_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
markdown_filename,
markdown.encode('utf-8')
)
def _log(self, message, level='debug'):
logger = current_app.logger
log_method = getattr(logger, level)
log_method(
f"{self.__class__.__name__} - Tenant {self.tenant.id}, Document {self.document_version.id}: {message}")
def _save_intermediate(self, content, filename):
minio_client.upload_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
filename,
content.encode('utf-8')
)

View File

@@ -0,0 +1,80 @@
import re
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from common.extensions import minio_client
from common.utils.model_utils import create_language_template
from .processor import Processor
class SRTProcessor(Processor):
def __init__(self, tenant, model_variables, document_version):
super().__init__(tenant, model_variables, document_version)
def process(self):
self._log("Starting SRT processing")
try:
file_data = minio_client.download_document_file(
self.tenant.id,
self.document_version.doc_id,
self.document_version.language,
self.document_version.id,
self.document_version.file_name
)
srt_content = file_data.decode('utf-8')
cleaned_transcription = self._clean_srt(srt_content)
markdown, title = self._generate_markdown_from_transcription(cleaned_transcription)
self._save_markdown(markdown)
self._log("Finished processing SRT")
return markdown, title
except Exception as e:
self._log(f"Error processing SRT: {str(e)}", level='error')
raise
def _clean_srt(self, srt_content):
# Remove timecodes and subtitle numbers
cleaned_lines = []
for line in srt_content.split('\n'):
# Skip empty lines, subtitle numbers, and timecodes
if line.strip() and not line.strip().isdigit() and not re.match(
r'\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}', line):
cleaned_lines.append(line.strip())
# Join the cleaned lines
cleaned_text = ' '.join(cleaned_lines)
# Remove any extra spaces
cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
return cleaned_text
def _generate_markdown_from_transcription(self, transcription):
self._log("Generating markdown from transcription")
llm = self.model_variables['llm']
template = self.model_variables['transcript_template']
language_template = create_language_template(template, self.document_version.language)
transcript_prompt = ChatPromptTemplate.from_template(language_template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | transcript_prompt | llm | output_parser
input_transcript = {'transcript': transcription}
markdown = chain.invoke(input_transcript)
# Extract title from the markdown
title = self._extract_title_from_markdown(markdown)
return markdown, title
def _extract_title_from_markdown(self, markdown):
# Simple extraction of the first header as the title
lines = markdown.split('\n')
for line in lines:
if line.startswith('# '):
return line[2:].strip()
return "Untitled SRT Transcription"

View File

@@ -1,26 +1,16 @@
import io import io
import os import os
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
import subprocess
import gevent
from bs4 import BeautifulSoup
import html
from celery import states from celery import states
from flask import current_app from flask import current_app
# OpenAI imports # OpenAI imports
from langchain.chains.summarize import load_summarize_chain from langchain.text_splitter import MarkdownHeaderTextSplitter
from langchain.text_splitter import CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_core.exceptions import LangChainException from langchain_core.exceptions import LangChainException
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables import RunnablePassthrough
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from pytube import YouTube
import PyPDF2
from pydub import AudioSegment
import tempfile
from common.extensions import db, minio_client from common.extensions import db, minio_client
from common.models.document import DocumentVersion, Embedding from common.models.document import DocumentVersion, Embedding
@@ -29,6 +19,10 @@ from common.utils.celery_utils import current_celery
from common.utils.database import Database from common.utils.database import Database
from common.utils.model_utils import select_model_variables, create_language_template from common.utils.model_utils import select_model_variables, create_language_template
from common.utils.os_utils import safe_remove, sync_folder from common.utils.os_utils import safe_remove, sync_folder
from eveai_workers.Processors.audio_processor import AudioProcessor
from eveai_workers.Processors.html_processor import HTMLProcessor
from eveai_workers.Processors.pdf_processor import PDFProcessor
from eveai_workers.Processors.srt_processor import SRTProcessor
@current_celery.task(name='create_embeddings', queue='embeddings') @current_celery.task(name='create_embeddings', queue='embeddings')
@@ -84,8 +78,10 @@ def create_embeddings(tenant_id, document_version_id):
process_pdf(tenant, model_variables, document_version) process_pdf(tenant, model_variables, document_version)
case 'html': case 'html':
process_html(tenant, model_variables, document_version) process_html(tenant, model_variables, document_version)
case 'youtube': case 'srt':
process_youtube(tenant, model_variables, document_version) process_srt(tenant, model_variables, document_version)
case 'mp4' | 'mp3' | 'ogg':
process_audio(tenant, model_variables, document_version)
case _: case _:
raise Exception(f'No functionality defined for file type {document_version.file_type} ' raise Exception(f'No functionality defined for file type {document_version.file_type} '
f'for tenant {tenant_id} ' f'for tenant {tenant_id} '
@@ -103,49 +99,6 @@ def create_embeddings(tenant_id, document_version_id):
raise raise
def process_pdf(tenant, model_variables, document_version):
file_data = minio_client.download_document_file(tenant.id, document_version.doc_id, document_version.language,
document_version.id, document_version.file_name)
pdf_text = ''
pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_data))
for page in pdf_reader.pages:
pdf_text += page.extract_text()
markdown = generate_markdown_from_pdf(tenant, model_variables, document_version, pdf_text)
markdown_file_name = f'{document_version.id}.md'
minio_client.upload_document_file(tenant.id, document_version.doc_id, document_version.language, document_version.id,
markdown_file_name, markdown.encode())
potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, markdown_file_name)
chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'],
model_variables['max_chunk_size'])
if len(chunks) > 1:
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
document_version.system_context = f'Summary: {summary}\n'
else:
document_version.system_context = ''
enriched_chunks = enrich_chunks(tenant, document_version, chunks)
embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)
try:
db.session.add(document_version)
document_version.processing_finished_at = dt.now(tz.utc)
document_version.processing = False
db.session.add_all(embeddings)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
f'on HTML, document version {document_version.id}'
f'error: {e}')
raise
current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
f'on document version {document_version.id} :-)')
def delete_embeddings_for_document_version(document_version): def delete_embeddings_for_document_version(document_version):
embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all() embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all()
for embedding in embeddings_to_delete: for embedding in embeddings_to_delete:
@@ -158,43 +111,53 @@ def delete_embeddings_for_document_version(document_version):
raise raise
def process_pdf(tenant, model_variables, document_version):
processor = PDFProcessor(tenant, model_variables, document_version)
markdown, title = processor.process()
# Process markdown and embed
embed_markdown(tenant, model_variables, document_version, markdown, title)
def process_html(tenant, model_variables, document_version): def process_html(tenant, model_variables, document_version):
file_data = minio_client.download_document_file(tenant.id, document_version.doc_id, document_version.language, processor = HTMLProcessor(tenant, model_variables, document_version)
document_version.id, document_version.file_name) markdown, title = processor.process()
html_content = file_data.decode('utf-8')
# The tags to be considered can be dependent on the tenant # Process markdown and embed
html_tags = model_variables['html_tags'] embed_markdown(tenant, model_variables, document_version, markdown, title)
html_end_tags = model_variables['html_end_tags']
html_included_elements = model_variables['html_included_elements']
html_excluded_elements = model_variables['html_excluded_elements']
extracted_html, title = parse_html(tenant, html_content, html_tags, included_elements=html_included_elements,
excluded_elements=html_excluded_elements)
extracted_file_name = f'{document_version.id}-extracted.html' def process_audio(tenant, model_variables, document_version):
minio_client.upload_document_file(tenant.id, document_version.doc_id, document_version.language, document_version.id, processor = AudioProcessor(tenant, model_variables, document_version)
extracted_file_name, extracted_html.encode()) markdown, title = processor.process()
markdown = generate_markdown_from_html(tenant, model_variables, document_version, extracted_html) # Process markdown and embed
markdown_file_name = f'{document_version.id}.md' embed_markdown(tenant, model_variables, document_version, markdown, title)
minio_client.upload_document_file(tenant.id, document_version.doc_id, document_version.language, document_version.id,
markdown_file_name, markdown.encode())
potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, markdown_file_name)
def process_srt(tenant, model_variables, document_version):
processor = SRTProcessor(tenant, model_variables, document_version)
markdown, title = processor.process()
# Process markdown and embed
embed_markdown(tenant, model_variables, document_version, markdown, title)
def embed_markdown(tenant, model_variables, document_version, markdown, title):
# Create potential chunks
potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, f"{document_version.id}.md")
# Combine chunks for embedding
chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'], chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'],
model_variables['max_chunk_size']) model_variables['max_chunk_size'])
if len(chunks) > 1: # Enrich chunks
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) enriched_chunks = enrich_chunks(tenant, model_variables, document_version, title, chunks)
document_version.system_context = (f'Title: {title}\n'
f'Summary: {summary}\n')
else:
document_version.system_context = (f'Title: {title}\n')
enriched_chunks = enrich_chunks(tenant, document_version, title, chunks) # Create embeddings
embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)
# Update document version and save embeddings
try: try:
db.session.add(document_version) db.session.add(document_version)
document_version.processing_finished_at = dt.now(tz.utc) document_version.processing_finished_at = dt.now(tz.utc)
@@ -211,12 +174,18 @@ def process_html(tenant, model_variables, document_version):
f'on document version {document_version.id} :-)') f'on document version {document_version.id} :-)')
def enrich_chunks(tenant, document_version, title, chunks): def enrich_chunks(tenant, model_variables, document_version, title, chunks):
current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} ' current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
f'on document version {document_version.id}') f'on document version {document_version.id}')
current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
summary = ''
if len(chunks) > 1:
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
chunk_total_context = (f'Filename: {document_version.file_name}\n' chunk_total_context = (f'Filename: {document_version.file_name}\n'
f'User Context:\n{document_version.user_context}\n\n' f'User Context:\n{document_version.user_context}\n\n'
f'Title: {title}\n'
f'{summary}\n'
f'{document_version.system_context}\n\n') f'{document_version.system_context}\n\n')
enriched_chunks = [] enriched_chunks = []
initial_chunk = (f'Filename: {document_version.file_name}\n' initial_chunk = (f'Filename: {document_version.file_name}\n'
@@ -235,40 +204,6 @@ def enrich_chunks(tenant, document_version, title, chunks):
return enriched_chunks return enriched_chunks
def generate_markdown_from_html(tenant, model_variables, document_version, html_content):
current_app.logger.debug(f'Generating Markdown from HTML for tenant {tenant.id} '
f'on document version {document_version.id}')
llm = model_variables['llm']
template = model_variables['html_parse_template']
parse_prompt = ChatPromptTemplate.from_template(template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | parse_prompt | llm | output_parser
input_html = {"html": html_content}
markdown = chain.invoke(input_html)
return markdown
def generate_markdown_from_pdf(tenant, model_variables, document_version, pdf_content):
current_app.logger.debug(f'Generating Markdown from PDF for tenant {tenant.id} '
f'on document version {document_version.id}')
llm = model_variables['llm']
template = model_variables['pdf_parse_template']
parse_prompt = ChatPromptTemplate.from_template(template)
setup = RunnablePassthrough()
output_parser = StrOutputParser()
chain = setup | parse_prompt | llm | output_parser
input_pdf = {"pdf_content": pdf_content}
markdown = chain.invoke(input_pdf)
return markdown
def summarize_chunk(tenant, model_variables, document_version, chunk): def summarize_chunk(tenant, model_variables, document_version, chunk):
current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} ' current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} '
f'on document version {document_version.id}') f'on document version {document_version.id}')
@@ -323,274 +258,252 @@ def embed_chunks(tenant, model_variables, document_version, chunks):
return new_embeddings return new_embeddings
def parse_html(tenant, html_content, tags, included_elements=None, excluded_elements=None): def log_parsing_info(tenant, tags, included_elements, excluded_elements, excluded_classes, elements_to_parse):
soup = BeautifulSoup(html_content, 'html.parser')
extracted_html = ''
if included_elements:
elements_to_parse = soup.find_all(included_elements)
else:
elements_to_parse = [soup] # parse the entire document if no included_elements specified
if tenant.embed_tuning: if tenant.embed_tuning:
current_app.embed_tuning_logger.debug(f'Tags to parse: {tags}') current_app.embed_tuning_logger.debug(f'Tags to parse: {tags}')
current_app.embed_tuning_logger.debug(f'Included Elements: {included_elements}') current_app.embed_tuning_logger.debug(f'Included Elements: {included_elements}')
current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}')
current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}') current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
current_app.embed_tuning_logger.debug(f'Excluded Classes: {excluded_classes}')
current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse') current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
current_app.embed_tuning_logger.debug(f'First element to parse: {elements_to_parse[0]}') current_app.embed_tuning_logger.debug(f'First element to parse: {elements_to_parse[0]}')
# Iterate through the found included elements
for element in elements_to_parse:
# Find all specified tags within each included element
for sub_element in element.find_all(tags):
if tenant.embed_tuning:
current_app.embed_tuning_logger.debug(f'Found element: {sub_element.name}')
if excluded_elements and sub_element.find_parent(excluded_elements):
continue # Skip this sub_element if it's within any of the excluded_elements
extracted_html += f'<{sub_element.name}>{sub_element.get_text(strip=True)}</{sub_element.name}>\n'
title = soup.find('title').get_text(strip=True) # def process_youtube(tenant, model_variables, document_version):
# download_file_name = f'{document_version.id}.mp4'
return extracted_html, title # compressed_file_name = f'{document_version.id}.mp3'
# transcription_file_name = f'{document_version.id}.txt'
# markdown_file_name = f'{document_version.id}.md'
def process_youtube(tenant, model_variables, document_version): #
base_path = os.path.join(current_app.config['UPLOAD_FOLDER'], # # Remove existing files (in case of a re-processing of the file
document_version.file_location) # minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language,
download_file_name = f'{document_version.id}.mp4' # document_version.id, download_file_name)
compressed_file_name = f'{document_version.id}.mp3' # minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language,
transcription_file_name = f'{document_version.id}.txt' # document_version.id, compressed_file_name)
markdown_file_name = f'{document_version.id}.md' # minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language,
# document_version.id, transcription_file_name)
# Remove existing files (in case of a re-processing of the file # minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language,
minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language, # document_version.id, markdown_file_name)
document_version.id, download_file_name) #
minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language, # of, title, description, author = download_youtube(document_version.url, tenant.id, document_version,
document_version.id, compressed_file_name) # download_file_name)
minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language, # document_version.system_context = f'Title: {title}\nDescription: {description}\nAuthor: {author}'
document_version.id, transcription_file_name) # compress_audio(tenant.id, document_version, download_file_name, compressed_file_name)
minio_client.delete_document_file(tenant.id, document_version.doc_id, document_version.language, # transcribe_audio(tenant.id, document_version, compressed_file_name, transcription_file_name, model_variables)
document_version.id, markdown_file_name) # annotate_transcription(tenant, document_version, transcription_file_name, markdown_file_name, model_variables)
#
of, title, description, author = download_youtube(document_version.url, tenant.id, document_version, # potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, markdown_file_name)
download_file_name) # actual_chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'],
document_version.system_context = f'Title: {title}\nDescription: {description}\nAuthor: {author}' # model_variables['max_chunk_size'])
compress_audio(tenant.id, document_version, download_file_name, compressed_file_name) #
transcribe_audio(tenant.id, document_version, compressed_file_name, transcription_file_name, model_variables) # enriched_chunks = enrich_chunks(tenant, document_version, actual_chunks)
annotate_transcription(tenant, document_version, transcription_file_name, markdown_file_name, model_variables) # embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)
#
potential_chunks = create_potential_chunks_for_markdown(tenant.id, document_version, markdown_file_name) # try:
actual_chunks = combine_chunks_for_markdown(potential_chunks, model_variables['min_chunk_size'], # db.session.add(document_version)
model_variables['max_chunk_size']) # document_version.processing_finished_at = dt.now(tz.utc)
# document_version.processing = False
enriched_chunks = enrich_chunks(tenant, document_version, actual_chunks) # db.session.add_all(embeddings)
embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) # db.session.commit()
# except SQLAlchemyError as e:
try: # current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} '
db.session.add(document_version) # f'on Youtube document version {document_version.id}'
document_version.processing_finished_at = dt.now(tz.utc) # f'error: {e}')
document_version.processing = False # raise
db.session.add_all(embeddings) #
db.session.commit() # current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} '
except SQLAlchemyError as e: # f'on Youtube document version {document_version.id} :-)')
current_app.logger.error(f'Error saving embedding information for tenant {tenant.id} ' #
f'on Youtube document version {document_version.id}' #
f'error: {e}') # def download_youtube(url, tenant_id, document_version, file_name):
raise # try:
# current_app.logger.info(f'Downloading YouTube video: {url} for tenant: {tenant_id}')
current_app.logger.info(f'Embeddings created successfully for tenant {tenant.id} ' # yt = YouTube(url)
f'on Youtube document version {document_version.id} :-)') # stream = yt.streams.get_audio_only()
#
# with tempfile.NamedTemporaryFile(delete=False) as temp_file:
def download_youtube(url, tenant_id, document_version, file_name): # stream.download(output_path=temp_file.name)
try: # with open(temp_file.name, 'rb') as f:
current_app.logger.info(f'Downloading YouTube video: {url} for tenant: {tenant_id}') # file_data = f.read()
yt = YouTube(url) #
stream = yt.streams.get_audio_only() # minio_client.upload_document_file(tenant_id, document_version.doc_id, document_version.language,
# document_version.id,
with tempfile.NamedTemporaryFile(delete=False) as temp_file: # file_name, file_data)
stream.download(output_path=temp_file.name) #
with open(temp_file.name, 'rb') as f: # current_app.logger.info(f'Downloaded YouTube video: {url} for tenant: {tenant_id}')
file_data = f.read() # return file_name, yt.title, yt.description, yt.author
# except Exception as e:
minio_client.upload_document_file(tenant_id, document_version.doc_id, document_version.language, document_version.id, # current_app.logger.error(f'Error downloading YouTube video: {url} for tenant: {tenant_id} with error: {e}')
file_name, file_data) # raise
#
current_app.logger.info(f'Downloaded YouTube video: {url} for tenant: {tenant_id}') #
return file_name, yt.title, yt.description, yt.author # def compress_audio(tenant_id, document_version, input_file, output_file):
except Exception as e: # try:
current_app.logger.error(f'Error downloading YouTube video: {url} for tenant: {tenant_id} with error: {e}') # current_app.logger.info(f'Compressing audio for tenant: {tenant_id}')
raise #
# input_data = minio_client.download_document_file(tenant_id, document_version.doc_id, document_version.language,
# document_version.id, input_file)
def compress_audio(tenant_id, document_version, input_file, output_file): #
try: # with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_input:
current_app.logger.info(f'Compressing audio for tenant: {tenant_id}') # temp_input.write(input_data)
# temp_input.flush()
input_data = minio_client.download_document_file(tenant_id, document_version.doc_id, document_version.language, #
document_version.id, input_file) # with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_output:
# result = subprocess.run(
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_input: # ['ffmpeg', '-i', temp_input.name, '-b:a', '64k', '-f', 'mp3', temp_output.name],
temp_input.write(input_data) # capture_output=True,
temp_input.flush() # text=True
# )
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_output: #
result = subprocess.run( # if result.returncode != 0:
['ffmpeg', '-i', temp_input.name, '-b:a', '64k', '-f', 'mp3', temp_output.name], # raise Exception(f"Compression failed: {result.stderr}")
capture_output=True, #
text=True # with open(temp_output.name, 'rb') as f:
) # compressed_data = f.read()
#
if result.returncode != 0: # minio_client.upload_document_file(tenant_id, document_version.doc_id, document_version.language,
raise Exception(f"Compression failed: {result.stderr}") # document_version.id,
# output_file, compressed_data)
with open(temp_output.name, 'rb') as f: #
compressed_data = f.read() # current_app.logger.info(f'Compressed audio for tenant: {tenant_id}')
# except Exception as e:
minio_client.upload_document_file(tenant_id, document_version.doc_id, document_version.language, document_version.id, # current_app.logger.error(f'Error compressing audio for tenant: {tenant_id} with error: {e}')
output_file, compressed_data) # raise
#
current_app.logger.info(f'Compressed audio for tenant: {tenant_id}') #
except Exception as e: # def transcribe_audio(tenant_id, document_version, input_file, output_file, model_variables):
current_app.logger.error(f'Error compressing audio for tenant: {tenant_id} with error: {e}') # try:
raise # current_app.logger.info(f'Transcribing audio for tenant: {tenant_id}')
# client = model_variables['transcription_client']
# model = model_variables['transcription_model']
def transcribe_audio(tenant_id, document_version, input_file, output_file, model_variables): #
try: # # Download the audio file from MinIO
current_app.logger.info(f'Transcribing audio for tenant: {tenant_id}') # audio_data = minio_client.download_document_file(tenant_id, document_version.doc_id, document_version.language,
client = model_variables['transcription_client'] # document_version.id, input_file)
model = model_variables['transcription_model'] #
# # Load the audio data into pydub
# Download the audio file from MinIO # audio = AudioSegment.from_mp3(io.BytesIO(audio_data))
audio_data = minio_client.download_document_file(tenant_id, document_version.doc_id, document_version.language, #
document_version.id, input_file) # # Define segment length (e.g., 10 minutes)
# segment_length = 10 * 60 * 1000 # 10 minutes in milliseconds
# Load the audio data into pydub #
audio = AudioSegment.from_mp3(io.BytesIO(audio_data)) # transcriptions = []
#
# Define segment length (e.g., 10 minutes) # # Split audio into segments and transcribe each
segment_length = 10 * 60 * 1000 # 10 minutes in milliseconds # for i, chunk in enumerate(audio[::segment_length]):
# current_app.logger.debug(f'Transcribing chunk {i + 1} of {len(audio) // segment_length + 1}')
transcriptions = [] #
# with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
# Split audio into segments and transcribe each # chunk.export(temp_audio.name, format="mp3")
for i, chunk in enumerate(audio[::segment_length]): #
current_app.logger.debug(f'Transcribing chunk {i + 1} of {len(audio) // segment_length + 1}') # with open(temp_audio.name, 'rb') as audio_segment:
# transcription = client.audio.transcriptions.create(
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio: # file=audio_segment,
chunk.export(temp_audio.name, format="mp3") # model=model,
# language=document_version.language,
with open(temp_audio.name, 'rb') as audio_segment: # response_format='verbose_json',
transcription = client.audio.transcriptions.create( # )
file=audio_segment, #
model=model, # transcriptions.append(transcription.text)
language=document_version.language, #
response_format='verbose_json', # os.unlink(temp_audio.name) # Delete the temporary file
) #
# # Combine all transcriptions
transcriptions.append(transcription.text) # full_transcription = " ".join(transcriptions)
#
os.unlink(temp_audio.name) # Delete the temporary file # # Upload the full transcription to MinIO
# minio_client.upload_document_file(
# Combine all transcriptions # tenant_id,
full_transcription = " ".join(transcriptions) # document_version.doc_id,
# document_version.language,
# Upload the full transcription to MinIO # document_version.id,
minio_client.upload_document_file( # output_file,
tenant_id, # full_transcription.encode('utf-8')
document_version.doc_id, # )
document_version.language, #
document_version.id, # current_app.logger.info(f'Transcribed audio for tenant: {tenant_id}')
output_file, # except Exception as e:
full_transcription.encode('utf-8') # current_app.logger.error(f'Error transcribing audio for tenant: {tenant_id}, with error: {e}')
) # raise
#
current_app.logger.info(f'Transcribed audio for tenant: {tenant_id}') #
except Exception as e: # def annotate_transcription(tenant, document_version, input_file, output_file, model_variables):
current_app.logger.error(f'Error transcribing audio for tenant: {tenant_id}, with error: {e}') # try:
raise # current_app.logger.debug(f'Annotating transcription for tenant {tenant.id}')
#
# char_splitter = CharacterTextSplitter(separator='.',
def annotate_transcription(tenant, document_version, input_file, output_file, model_variables): # chunk_size=model_variables['annotation_chunk_length'],
try: # chunk_overlap=0)
current_app.logger.debug(f'Annotating transcription for tenant {tenant.id}') #
# headers_to_split_on = [
char_splitter = CharacterTextSplitter(separator='.', # ("#", "Header 1"),
chunk_size=model_variables['annotation_chunk_length'], # ("##", "Header 2"),
chunk_overlap=0) # ]
# markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False)
headers_to_split_on = [ #
("#", "Header 1"), # llm = model_variables['llm']
("##", "Header 2"), # template = model_variables['transcript_template']
] # language_template = create_language_template(template, document_version.language)
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False) # transcript_prompt = ChatPromptTemplate.from_template(language_template)
# setup = RunnablePassthrough()
llm = model_variables['llm'] # output_parser = StrOutputParser()
template = model_variables['transcript_template'] #
language_template = create_language_template(template, document_version.language) # # Download the transcription file from MinIO
transcript_prompt = ChatPromptTemplate.from_template(language_template) # transcript_data = minio_client.download_document_file(tenant.id, document_version.doc_id,
setup = RunnablePassthrough() # document_version.language, document_version.id,
output_parser = StrOutputParser() # input_file)
# transcript = transcript_data.decode('utf-8')
# Download the transcription file from MinIO #
transcript_data = minio_client.download_document_file(tenant.id, document_version.doc_id, # chain = setup | transcript_prompt | llm | output_parser
document_version.language, document_version.id, #
input_file) # chunks = char_splitter.split_text(transcript)
transcript = transcript_data.decode('utf-8') # all_markdown_chunks = []
# last_markdown_chunk = ''
chain = setup | transcript_prompt | llm | output_parser # for chunk in chunks:
# current_app.logger.debug(f'Annotating next chunk of {len(chunks)} for tenant {tenant.id}')
chunks = char_splitter.split_text(transcript) # full_input = last_markdown_chunk + '\n' + chunk
all_markdown_chunks = [] # if tenant.embed_tuning:
last_markdown_chunk = '' # current_app.embed_tuning_logger.debug(f'Annotating chunk: \n '
for chunk in chunks: # f'------------------\n'
current_app.logger.debug(f'Annotating next chunk of {len(chunks)} for tenant {tenant.id}') # f'{full_input}\n'
full_input = last_markdown_chunk + '\n' + chunk # f'------------------\n')
if tenant.embed_tuning: # input_transcript = {'transcript': full_input}
current_app.embed_tuning_logger.debug(f'Annotating chunk: \n ' # markdown = chain.invoke(input_transcript)
f'------------------\n' # # GPT-4o returns some kind of content description: ```markdown <text> ```
f'{full_input}\n' # if markdown.startswith("```markdown"):
f'------------------\n') # markdown = "\n".join(markdown.strip().split("\n")[1:-1])
input_transcript = {'transcript': full_input} # if tenant.embed_tuning:
markdown = chain.invoke(input_transcript) # current_app.embed_tuning_logger.debug(f'Markdown Received: \n '
# GPT-4o returns some kind of content description: ```markdown <text> ``` # f'------------------\n'
if markdown.startswith("```markdown"): # f'{markdown}\n'
markdown = "\n".join(markdown.strip().split("\n")[1:-1]) # f'------------------\n')
if tenant.embed_tuning: # md_header_splits = markdown_splitter.split_text(markdown)
current_app.embed_tuning_logger.debug(f'Markdown Received: \n ' # markdown_chunks = [doc.page_content for doc in md_header_splits]
f'------------------\n' # # claude-3.5-sonnet returns introductory text
f'{markdown}\n' # if not markdown_chunks[0].startswith('#'):
f'------------------\n') # markdown_chunks.pop(0)
md_header_splits = markdown_splitter.split_text(markdown) # last_markdown_chunk = markdown_chunks[-1]
markdown_chunks = [doc.page_content for doc in md_header_splits] # last_markdown_chunk = "\n".join(markdown.strip().split("\n")[1:])
# claude-3.5-sonnet returns introductory text # markdown_chunks.pop()
if not markdown_chunks[0].startswith('#'): # all_markdown_chunks += markdown_chunks
markdown_chunks.pop(0) #
last_markdown_chunk = markdown_chunks[-1] # all_markdown_chunks += [last_markdown_chunk]
last_markdown_chunk = "\n".join(markdown.strip().split("\n")[1:]) #
markdown_chunks.pop() # annotated_transcript = '\n'.join(all_markdown_chunks)
all_markdown_chunks += markdown_chunks #
# # Upload the annotated transcript to MinIO
all_markdown_chunks += [last_markdown_chunk] # minio_client.upload_document_file(
# tenant.id,
annotated_transcript = '\n'.join(all_markdown_chunks) # document_version.doc_id,
# document_version.language,
# Upload the annotated transcript to MinIO # document_version.id,
minio_client.upload_document_file( # output_file,
tenant.id, # annotated_transcript.encode('utf-8')
document_version.doc_id, # )
document_version.language, #
document_version.id, # current_app.logger.info(f'Annotated transcription for tenant {tenant.id}')
output_file, # except Exception as e:
annotated_transcript.encode('utf-8') # current_app.logger.error(f'Error annotating transcription for tenant {tenant.id}, with error: {e}')
) # raise
current_app.logger.info(f'Annotated transcription for tenant {tenant.id}')
except Exception as e:
current_app.logger.error(f'Error annotating transcription for tenant {tenant.id}, with error: {e}')
raise
def create_potential_chunks_for_markdown(tenant_id, document_version, input_file): def create_potential_chunks_for_markdown(tenant_id, document_version, input_file):
@@ -648,7 +561,3 @@ def combine_chunks_for_markdown(potential_chunks, min_chars, max_chars):
actual_chunks.append(current_chunk) actual_chunks.append(current_chunk)
return actual_chunks return actual_chunks
pass

View File

@@ -3,7 +3,7 @@
Plugin Name: EveAI Chat Widget Plugin Name: EveAI Chat Widget
Plugin URI: https://askeveai.com/ Plugin URI: https://askeveai.com/
Description: Integrates the EveAI chat interface into your WordPress site. Description: Integrates the EveAI chat interface into your WordPress site.
Version: 1.3.21 Version: 1.5.0
Author: Josako, Pieter Laroy Author: Josako, Pieter Laroy
Author URI: https://askeveai.com/about/ Author URI: https://askeveai.com/about/
*/ */
@@ -18,15 +18,30 @@ function eveai_chat_enqueue_scripts() {
wp_enqueue_style('eveai-chat-style', plugin_dir_url(__FILE__) . 'css/eveai-chat-style.css'); wp_enqueue_style('eveai-chat-style', plugin_dir_url(__FILE__) . 'css/eveai-chat-style.css');
} }
add_action('wp_enqueue_scripts', 'eveai_chat_enqueue_scripts'); add_action('wp_enqueue_scripts', 'eveai_chat_enqueue_scripts');
add_action('admin_enqueue_scripts', 'eveai_chat_enqueue_scripts');
// Shortcode function // Shortcode function
function eveai_chat_shortcode($atts) { function eveai_chat_shortcode($atts) {
$options = get_option('eveai_chat_options'); // Default values
$tenant_id = esc_js($options['tenant_id']); $defaults = array(
$api_key = esc_js($options['api_key']); 'tenant_id' => '',
$domain = esc_js($options['domain']); 'api_key' => '',
$language = esc_js($options['language']); 'domain' => '',
$supported_languages = esc_js($options['supported_languages']); 'language' => 'en',
'supported_languages' => 'en,fr,de,es',
'server_url' => 'https://evie.askeveai.com'
);
// Merge provided attributes with defaults
$atts = shortcode_atts($defaults, $atts, 'eveai_chat');
// Sanitize inputs
$tenant_id = sanitize_text_field($atts['tenant_id']);
$api_key = sanitize_text_field($atts['api_key']);
$domain = esc_url_raw($atts['domain']);
$language = sanitize_text_field($atts['language']);
$supported_languages = sanitize_text_field($atts['supported_languages']);
$server_url = esc_url_raw($atts['server_url']);
// Generate a unique ID for this instance of the chat widget // Generate a unique ID for this instance of the chat widget
$chat_id = 'chat-container-' . uniqid(); $chat_id = 'chat-container-' . uniqid();
@@ -39,7 +54,8 @@ function eveai_chat_shortcode($atts) {
'$api_key', '$api_key',
'$domain', '$domain',
'$language', '$language',
'$supported_languages' '$supported_languages',
'$server_url'
); );
eveAI.initializeChat('$chat_id'); eveAI.initializeChat('$chat_id');
}); });
@@ -49,80 +65,3 @@ function eveai_chat_shortcode($atts) {
} }
add_shortcode('eveai_chat', 'eveai_chat_shortcode'); add_shortcode('eveai_chat', 'eveai_chat_shortcode');
// Add admin menu
function eveai_chat_admin_menu() {
add_options_page('EveAI Chat Settings', 'EveAI Chat', 'manage_options', 'eveai-chat-settings', 'eveai_chat_settings_page');
}
add_action('admin_menu', 'eveai_chat_admin_menu');
// Settings page
function eveai_chat_settings_page() {
?>
<div class="wrap">
<h1>EveAI Chat Settings</h1>
<form method="post" action="options.php">
<?php
settings_fields('eveai_chat_options');
do_settings_sections('eveai-chat-settings');
submit_button();
?>
</form>
</div>
<?php
}
// Register settings
function eveai_chat_register_settings() {
register_setting('eveai_chat_options', 'eveai_chat_options', 'eveai_chat_options_validate');
add_settings_section('eveai_chat_main', 'Main Settings', 'eveai_chat_section_text', 'eveai-chat-settings');
add_settings_field('eveai_chat_tenant_id', 'Tenant ID', 'eveai_chat_tenant_id_input', 'eveai-chat-settings', 'eveai_chat_main');
add_settings_field('eveai_chat_api_key', 'API Key', 'eveai_chat_api_key_input', 'eveai-chat-settings', 'eveai_chat_main');
add_settings_field('eveai_chat_domain', 'Domain', 'eveai_chat_domain_input', 'eveai-chat-settings', 'eveai_chat_main');
add_settings_field('eveai_chat_language', 'Default Language', 'eveai_chat_language_input', 'eveai-chat-settings', 'eveai_chat_main');
add_settings_field('eveai_chat_supported_languages', 'Supported Languages', 'eveai_chat_supported_languages_input', 'eveai-chat-settings', 'eveai_chat_main');
}
add_action('admin_init', 'eveai_chat_register_settings');
function eveai_chat_section_text() {
echo '<p>Enter your EveAI Chat configuration details below:</p>';
}
function eveai_chat_tenant_id_input() {
$options = get_option('eveai_chat_options');
echo "<input id='eveai_chat_tenant_id' name='eveai_chat_options[tenant_id]' type='text' value='" . esc_attr($options['tenant_id']) . "' />";
}
function eveai_chat_api_key_input() {
$options = get_option('eveai_chat_options');
echo "<input id='eveai_chat_api_key' name='eveai_chat_options[api_key]' type='password' value='" . esc_attr($options['api_key']) . "' />";
}
function eveai_chat_domain_input() {
$options = get_option('eveai_chat_options');
echo "<input id='eveai_chat_domain' name='eveai_chat_options[domain]' type='text' value='" . esc_attr($options['domain']) . "' />";
}
function eveai_chat_language_input() {
$options = get_option('eveai_chat_options');
echo "<input id='eveai_chat_language' name='eveai_chat_options[language]' type='text' value='" . esc_attr($options['language']) . "' />";
}
function eveai_chat_supported_languages_input() {
$options = get_option('eveai_chat_options');
$supported_languages = isset($options['supported_languages']) ? $options['supported_languages'] : 'en,fr,de,es';
echo "<input id='eveai_chat_supported_languages' name='eveai_chat_options[supported_languages]' type='text' value='" . esc_attr($supported_languages) . "' />";
echo "<p class='description'>Enter comma-separated language codes (e.g., en,fr,de,es)</p>";
}
function eveai_chat_options_validate($input) {
$new_input = array();
$new_input['tenant_id'] = sanitize_text_field($input['tenant_id']);
$new_input['api_key'] = sanitize_text_field($input['api_key']);
$new_input['domain'] = esc_url_raw($input['domain']);
$new_input['language'] = sanitize_text_field($input['language']);
$new_input['supported_languages'] = sanitize_text_field($input['supported_languages']);
return $new_input;
}

View File

@@ -1,6 +1,6 @@
class EveAIChatWidget extends HTMLElement { class EveAIChatWidget extends HTMLElement {
static get observedAttributes() { static get observedAttributes() {
return ['tenant-id', 'api-key', 'domain', 'language', 'languages']; return ['tenant-id', 'api-key', 'domain', 'language', 'languages', 'server-url'];
} }
constructor() { constructor() {
@@ -87,6 +87,7 @@ class EveAIChatWidget extends HTMLElement {
this.language = this.getAttribute('language'); this.language = this.getAttribute('language');
const languageAttr = this.getAttribute('languages'); const languageAttr = this.getAttribute('languages');
this.languages = languageAttr ? languageAttr.split(',') : []; this.languages = languageAttr ? languageAttr.split(',') : [];
this.serverUrl = this.getAttribute('server-url');
this.currentLanguage = this.language; this.currentLanguage = this.language;
console.log('Updated attributes:', { console.log('Updated attributes:', {
tenantId: this.tenantId, tenantId: this.tenantId,
@@ -94,7 +95,8 @@ class EveAIChatWidget extends HTMLElement {
domain: this.domain, domain: this.domain,
language: this.language, language: this.language,
currentLanguage: this.currentLanguage, currentLanguage: this.currentLanguage,
languages: this.languages languages: this.languages,
serverUrl: this.serverUrl
}); });
} }
@@ -104,14 +106,16 @@ class EveAIChatWidget extends HTMLElement {
const domain = this.getAttribute('domain'); const domain = this.getAttribute('domain');
const language = this.getAttribute('language'); const language = this.getAttribute('language');
const languages = this.getAttribute('languages'); const languages = this.getAttribute('languages');
const serverUrl = this.getAttribute('server-url');
console.log('Checking if all attributes are set:', { console.log('Checking if all attributes are set:', {
tenantId, tenantId,
apiKey, apiKey,
domain, domain,
language, language,
languages languages,
serverUrl
}); });
return tenantId && apiKey && domain && language && languages; return tenantId && apiKey && domain && language && languages && serverUrl;
} }
createLanguageDropdown() { createLanguageDropdown() {
@@ -142,7 +146,7 @@ class EveAIChatWidget extends HTMLElement {
console.log(`Initializing socket connection to Evie`); console.log(`Initializing socket connection to Evie`);
// Ensure apiKey is passed in the query parameters // Ensure apiKey is passed in the query parameters
this.socket = io('https://evie.askeveai.com', { this.socket = io(this.serverUrl, {
path: '/chat/socket.io/', path: '/chat/socket.io/',
transports: ['websocket', 'polling'], transports: ['websocket', 'polling'],
query: { query: {
@@ -161,24 +165,32 @@ class EveAIChatWidget extends HTMLElement {
this.socket.on('connect', (data) => { this.socket.on('connect', (data) => {
console.log('Socket connected OK'); console.log('Socket connected OK');
console.log('Connect event data:', data);
console.log('Connect event this:', this);
this.setStatusMessage('Connected to EveAI.'); this.setStatusMessage('Connected to EveAI.');
this.updateConnectionStatus(true); this.updateConnectionStatus(true);
this.startHeartbeat(); this.startHeartbeat();
if (data.room) { if (data && data.room) {
this.room = data.room; this.room = data.room;
console.log(`Joined room: ${this.room}`); console.log(`Joined room: ${this.room}`);
} else {
console.log('Room information not received on connect');
} }
}); });
this.socket.on('authenticated', (data) => { this.socket.on('authenticated', (data) => {
console.log('Authenticated event received: ', data); console.log('Authenticated event received');
console.log('Authentication event data:', data);
console.log('Authentication event this:', this);
this.setStatusMessage('Authenticated.'); this.setStatusMessage('Authenticated.');
if (data.token) { if (data && data.token) {
this.jwtToken = data.token; // Store the JWT token received from the server this.jwtToken = data.token;
} }
if (data.room) { if (data && data.room) {
this.room = data.room; this.room = data.room;
console.log(`Confirmed room: ${this.room}`); console.log(`Confirmed room: ${this.room}`);
} else {
console.log('Room information not received on authentication');
} }
}); });

View File

@@ -1,13 +1,14 @@
// static/js/eveai-sdk.js // static/js/eveai-sdk.js
class EveAI { class EveAI {
constructor(tenantId, apiKey, domain, language, languages) { constructor(tenantId, apiKey, domain, language, languages, serverUrl) {
this.tenantId = tenantId; this.tenantId = tenantId;
this.apiKey = apiKey; this.apiKey = apiKey;
this.domain = domain; this.domain = domain;
this.language = language; this.language = language;
this.languages = languages; this.languages = languages;
this.serverUrl = serverUrl;
console.log('EveAI constructor:', { tenantId, apiKey, domain, language, languages }); console.log('EveAI constructor:', { tenantId, apiKey, domain, language, languages, serverUrl });
} }
initializeChat(containerId) { initializeChat(containerId) {
@@ -21,6 +22,7 @@ class EveAI {
chatWidget.setAttribute('domain', this.domain); chatWidget.setAttribute('domain', this.domain);
chatWidget.setAttribute('language', this.language); chatWidget.setAttribute('language', this.language);
chatWidget.setAttribute('languages', this.languages); chatWidget.setAttribute('languages', this.languages);
chatWidget.setAttribute('server-url', this.serverUrl);
}); });
} else { } else {
console.error('Container not found'); console.error('Container not found');

View File

@@ -3,7 +3,7 @@ Contributors: Josako
Tags: chat, ai Tags: chat, ai
Requires at least: 5.0 Requires at least: 5.0
Tested up to: 5.9 Tested up to: 5.9
Stable tag: 1.3.0 Stable tag: 1.5.0
License: GPLv2 or later License: GPLv2 or later
License URI: http://www.gnu.org/licenses/gpl-2.0.html License URI: http://www.gnu.org/licenses/gpl-2.0.html
@@ -17,7 +17,18 @@ This plugin allows you to easily add the EveAI chat widget to your WordPress sit
1. Upload the `eveai-chat-widget` folder to the `/wp-content/plugins/` directory 1. Upload the `eveai-chat-widget` folder to the `/wp-content/plugins/` directory
2. Activate the plugin through the 'Plugins' menu in WordPress 2. Activate the plugin through the 'Plugins' menu in WordPress
3. Go to Settings > EveAI Chat to configure your chat widget parameters 3. Add EveAI Chat Widget to your page or post using the instructions below.
== Usage ==
To add an EveAI Chat Widget to your page or post, use the following shortcode:
[eveai_chat tenant_id="YOUR_TENANT_ID" api_key="YOUR_API_KEY" domain="YOUR_DOMAIN" language="LANGUAGE_CODE" supported_languages="COMMA_SEPARATED_LANGUAGE_CODES" server_url="Server URL for Evie"]
Example:
[eveai_chat tenant_id="123456" api_key="your_api_key_here" domain="https://your-domain.com" language="en" supported_languages="en,fr,de,es" server_url="https://evie.askeveai.com"]
You can add multiple chat widgets with different configurations by using the shortcode multiple times with different parameters.
== Frequently Asked Questions == == Frequently Asked Questions ==
@@ -27,6 +38,16 @@ Contact your EveAI service provider to obtain your Tenant ID, API Key, and Domai
== Changelog == == Changelog ==
= 1.5.0 =
* Allow for multiple servers to serve Evie
= 1.4.1 - 1.4...=
* Bug fixes
= 1.4.0 =
* Allow for multiple instances of Evie on the same website
* Parametrization of the shortcode
= 1.3.3 - = = 1.3.3 - =
* ensure all attributes (also height and supportedLanguages) are set before initializing the socket * ensure all attributes (also height and supportedLanguages) are set before initializing the socket
* Bugfixing * Bugfixing

View File

@@ -0,0 +1,83 @@
# EveAI Sync WordPress Plugin
## Description
EveAI Sync is a WordPress plugin that synchronizes your WordPress content (posts and pages) with the EveAI platform. It allows for seamless integration between your WordPress site and EveAI, ensuring that your content is always up-to-date on both platforms.
## Features
- Automatic synchronization of posts and pages with EveAI
- Support for excluding specific categories or individual posts/pages from syncing
- Bulk synchronization of existing content
- Custom metadata synchronization
- Easy-to-use admin interface for configuration
## Installation
1. Download the plugin zip file.
2. Log in to your WordPress admin panel.
3. Go to Plugins > Add New.
4. Click on the "Upload Plugin" button.
5. Select the downloaded zip file and click "Install Now".
6. After installation, click "Activate Plugin".
## Configuration
1. Go to Settings > EveAI Sync in your WordPress admin panel.
2. Enter your EveAI API URL, Tenant ID, and API Key.
3. Configure any additional settings as needed.
4. Click "Save Changes".
## Usage
- New posts and pages will automatically sync to EveAI when published.
- Existing content can be synced using the "Bulk Sync" option in the settings.
- To exclude a post or page from syncing, use the "Exclude from EveAI sync" checkbox in the post editor.
- Categories can be excluded from syncing in the plugin settings.
## Action Scheduler Dependency
This plugin uses Action Scheduler for efficient background processing of synchronization tasks. Action Scheduler is typically included with WooCommerce, but the plugin can also function without it.
### With Action Scheduler
If Action Scheduler is available (either through WooCommerce or included with this plugin), EveAI Sync will use it for more reliable and efficient scheduling of synchronization tasks.
### Without Action Scheduler
If Action Scheduler is not available, the plugin will automatically fall back to using WordPress cron for scheduling tasks. This fallback ensures that the plugin remains functional, although with potentially less precise timing for background tasks.
No additional configuration is needed; the plugin will automatically detect the presence or absence of Action Scheduler and adjust its behavior accordingly.
## Versions
### 1.0.x - Bugfixing Releases
### 1.0.0 - Initial Release
## Frequently Asked Questions
**Q: How often does the plugin sync content?**
A: The plugin syncs content immediately when a post or page is published or updated. For bulk syncs or when Action Scheduler is not available, the timing may vary based on WordPress cron execution.
**Q: Can I sync only certain types of content?**
A: By default, the plugin syncs all posts and pages. You can exclude specific categories or individual posts/pages from syncing.
**Q: What happens if the sync fails?**
A: The plugin will log any sync failures and attempt to retry. You can view sync status in the plugin's admin interface.
**Q: Do I need to install Action Scheduler separately?**
A: No, the plugin will work with or without Action Scheduler. If you have WooCommerce installed, Action Scheduler will be available automatically.
## Support
For support, please open an issue on our GitHub repository or contact our support team at support@eveai.com.
## Contributing
We welcome contributions to the EveAI Sync plugin. Please feel free to submit pull requests or open issues on our GitHub repository.
## License
This plugin is licensed under the GPL v2 or later.

View File

@@ -0,0 +1,70 @@
.eveai-admin-wrap {
max-width: 800px;
margin: 20px auto;
}
.eveai-admin-header {
background-color: #fff;
padding: 20px;
border: 1px solid #ccc;
border-radius: 5px;
margin-bottom: 20px;
}
.eveai-admin-header h1 {
margin: 0;
color: #23282d;
}
.eveai-admin-content {
background-color: #fff;
padding: 20px;
border: 1px solid #ccc;
border-radius: 5px;
}
.eveai-form-group {
margin-bottom: 15px;
}
.eveai-form-group label {
display: block;
margin-bottom: 5px;
font-weight: bold;
}
.eveai-form-group input[type="text"],
.eveai-form-group input[type="password"] {
width: 100%;
padding: 8px;
border: 1px solid #ddd;
border-radius: 4px;
}
.eveai-button {
background-color: #0085ba;
border-color: #0073aa #006799 #006799;
color: #fff;
text-decoration: none;
text-shadow: 0 -1px 1px #006799, 1px 0 1px #006799, 0 1px 1px #006799, -1px 0 1px #006799;
display: inline-block;
padding: 8px 12px;
border-radius: 3px;
cursor: pointer;
}
.eveai-button:hover {
background-color: #008ec2;
}
.eveai-category-list {
margin-top: 20px;
}
.eveai-category-item {
margin-bottom: 10px;
}
.eveai-category-item label {
font-weight: normal;
}

View File

@@ -0,0 +1,49 @@
jQuery(document).ready(function($) {
// Handle bulk sync button click
$('#eveai-bulk-sync').on('click', function(e) {
e.preventDefault();
if (confirm('Are you sure you want to start a bulk sync? This may take a while.')) {
$.ajax({
url: ajaxurl,
type: 'POST',
data: {
action: 'eveai_bulk_sync',
nonce: eveai_admin.nonce
},
success: function(response) {
alert(response.data.message);
},
error: function() {
alert('An error occurred. Please try again.');
}
});
}
});
// Handle category exclusion checkboxes
$('.eveai-category-exclude').on('change', function() {
var categoryId = $(this).val();
var isExcluded = $(this).is(':checked');
$.ajax({
url: ajaxurl,
type: 'POST',
data: {
action: 'eveai_toggle_category_exclusion',
category_id: categoryId,
is_excluded: isExcluded ? 1 : 0,
nonce: eveai_admin.nonce
},
success: function(response) {
if (response.success) {
console.log('Category exclusion updated');
} else {
alert('Failed to update category exclusion');
}
},
error: function() {
alert('An error occurred. Please try again.');
}
});
});
});

View File

@@ -0,0 +1,64 @@
<?php
/**
* Plugin Name: EveAI Sync
* Plugin URI: https://askeveai.com/
* Description: Synchronizes WordPress content with EveAI API.
* Version: 1.0.16
* Author: Josako, Pieter Laroy
* Author URI: https://askeveai.com/about/
* License: GPL v2 or later
* License URI: https://www.gnu.org/licenses/gpl-2.0.html
* Text Domain: eveai-sync
* Domain Path: /languages
*/
if (!defined('ABSPATH')) {
exit; // Exit if accessed directly
}
// Define plugin constants
define('EVEAI_SYNC_VERSION', '1.0.0');
define('EVEAI_SYNC_PLUGIN_DIR', plugin_dir_path(__FILE__));
define('EVEAI_SYNC_PLUGIN_URL', plugin_dir_url(__FILE__));
// Include the main plugin class
require_once EVEAI_SYNC_PLUGIN_DIR . 'includes/class-eveai-sync.php';
// Initialize the plugin
function eveai_sync_init() {
$plugin = new EveAI_Sync();
$plugin->init();
}
add_action('plugins_loaded', 'eveai_sync_init');
// Set up activation and deactivation hooks
register_activation_hook(__FILE__, 'eveai_sync_activation');
register_deactivation_hook(__FILE__, 'eveai_sync_deactivation');
function eveai_sync_activation() {
// Other activation tasks...
}
function eveai_sync_deactivation() {
// Other deactivation tasks...
}
// Clean up meta when a post is permanently deleted
function eveai_delete_post_meta($post_id) {
delete_post_meta($post_id, '_eveai_document_id');
delete_post_meta($post_id, '_eveai_document_version_id');
}
add_action('before_delete_post', 'eveai_delete_post_meta');
// Display sync info in post
function eveai_display_sync_info($post) {
$document_id = get_post_meta($post->ID, '_eveai_document_id', true);
$document_version_id = get_post_meta($post->ID, '_eveai_document_version_id', true);
echo '<div class="misc-pub-section">';
echo '<h4>EveAI Sync Info:</h4>';
echo 'Document ID: ' . ($document_id ? esc_html($document_id) : 'Not set') . '<br>';
echo 'Document Version ID: ' . ($document_version_id ? esc_html($document_version_id) : 'Not set');
echo '</div>';
}
add_action('post_submitbox_misc_actions', 'eveai_display_sync_info');

View File

@@ -0,0 +1,156 @@
<?php
class EveAI_Admin {
private $api;
public function __construct($api) {
$this->api = $api;
}
public function register_settings() {
register_setting('eveai_settings', 'eveai_api_url');
register_setting('eveai_settings', 'eveai_tenant_id');
register_setting('eveai_settings', 'eveai_api_key');
register_setting('eveai_settings', 'eveai_default_language');
register_setting('eveai_settings', 'eveai_excluded_categories');
register_setting('eveai_settings', 'eveai_excluded_categories');
register_setting('eveai_settings', 'eveai_access_token');
register_setting('eveai_settings', 'eveai_token_expiry');
}
public function add_admin_menu() {
add_options_page(
'EveAI Sync Settings',
'EveAI Sync',
'manage_options',
'eveai-sync',
array($this, 'render_settings_page')
);
}
public function render_settings_page() {
?>
<div class="wrap">
<h1><?php echo esc_html(get_admin_page_title()); ?></h1>
<form action="options.php" method="post">
<?php
settings_fields('eveai_settings');
do_settings_sections('eveai-sync');
?>
<table class="form-table">
<tr valign="top">
<th scope="row">API URL</th>
<td><input type="text" name="eveai_api_url" value="<?php echo esc_attr(get_option('eveai_api_url')); ?>" style="width: 100%;" /></td>
</tr>
<tr valign="top">
<th scope="row">Tenant ID</th>
<td><input type="text" name="eveai_tenant_id" value="<?php echo esc_attr(get_option('eveai_tenant_id')); ?>" style="width: 100%;" /></td>
</tr>
<tr valign="top">
<th scope="row">API Key</th>
<td><input type="text" name="eveai_api_key" value="<?php echo esc_attr(get_option('eveai_api_key')); ?>" style="width: 100%;" /></td>
</tr>
<tr valign="top">
<th scope="row">Default Language</th>
<td><input type="text" name="eveai_default_language" value="<?php echo esc_attr(get_option('eveai_default_language', 'en')); ?>" style="width: 100%;" /></td>
</tr>
<tr valign="top">
<th scope="row">Excluded Categories</th>
<td>
<input type="text" name="eveai_excluded_categories" value="<?php echo esc_attr(get_option('eveai_excluded_categories')); ?>" style="width: 100%;" />
<p class="description">Enter a comma-separated list of category names to exclude from syncing.</p>
</td>
</tr>
</table>
<?php submit_button('Save Settings'); ?>
</form>
<h2>Bulk Sync</h2>
<p>Click the button below to start a bulk sync of all posts and pages to EveAI.</p>
<form method="post" action="">
<?php wp_nonce_field('eveai_bulk_sync', 'eveai_bulk_sync_nonce'); ?>
<input type="submit" name="eveai_bulk_sync" class="button button-primary" value="Start Bulk Sync">
</form>
<div id="eveai-sync-results" style="margin-top: 20px;"></div>
</div>
<script>
jQuery(document).ready(function($) {
$('form').on('submit', function(e) {
if ($(this).find('input[name="eveai_bulk_sync"]').length) {
e.preventDefault();
var $results = $('#eveai-sync-results');
$results.html('<p>Starting bulk sync...</p>');
$.ajax({
url: ajaxurl,
type: 'POST',
data: {
action: 'eveai_bulk_sync',
nonce: '<?php echo wp_create_nonce('eveai_bulk_sync_ajax'); ?>'
},
success: function(response) {
if (response.success) {
var resultsHtml = '<h3>Sync Results:</h3><ul>';
response.data.forEach(function(item) {
resultsHtml += '<li>' + item.title + ' (' + item.type + '): ' + item.status + '</li>';
});
resultsHtml += '</ul>';
$results.html(resultsHtml);
} else {
$results.html('<p>Error: ' + response.data + '</p>');
}
},
error: function() {
$results.html('<p>An error occurred. Please try again.</p>');
}
});
}
});
});
</script>
<?php
}
public function handle_bulk_sync_ajax() {
check_ajax_referer('eveai_bulk_sync_ajax', 'nonce');
if (!current_user_can('manage_options')) {
wp_send_json_error('Insufficient permissions');
return;
}
$post_handler = new EveAI_Post_Handler($this->api);
$bulk_sync = new EveAI_Bulk_Sync($this->api, $post_handler);
$results = $bulk_sync->init_bulk_sync();
wp_send_json_success($results);
}
public function add_sync_meta_box() {
add_meta_box(
'eveai_sync_meta_box',
'EveAI Sync',
array($this, 'render_sync_meta_box'),
array('post', 'page'),
'side',
'default'
);
}
public function render_sync_meta_box($post) {
$excluded = get_post_meta($post->ID, '_eveai_exclude_sync', true);
wp_nonce_field('eveai_sync_meta_box', 'eveai_sync_meta_box_nonce');
?>
<label>
<input type="checkbox" name="eveai_exclude_sync" value="1" <?php checked($excluded, '1'); ?>>
Exclude from EveAI sync
</label>
<?php
}
public function handle_bulk_sync() {
$post_handler = new EveAI_Post_Handler($this->api);
$bulk_sync = new EveAI_Bulk_Sync($this->api, $post_handler);
$bulk_sync->init_bulk_sync();
add_settings_error('eveai_messages', 'eveai_message', 'Bulk sync initiated successfully.', 'updated');
}
}

View File

@@ -0,0 +1,135 @@
<?php
class EveAI_API {
private $api_url;
private $tenant_id;
private $api_key;
private $access_token;
private $token_expiry;
public function __construct() {
$this->api_url = get_option('eveai_api_url');
$this->tenant_id = get_option('eveai_tenant_id');
$this->api_key = get_option('eveai_api_key');
$this->access_token = get_option('eveai_access_token');
$this->token_expiry = get_option('eveai_token_expiry', 0);
}
private function ensure_valid_token() {
if (empty($this->access_token) || time() > $this->token_expiry) {
$this->get_new_token();
}
}
private function get_new_token() {
$response = wp_remote_post($this->api_url . '/api/v1/auth/token', [
'body' => json_encode([
'tenant_id' => $this->tenant_id,
'api_key' => $this->api_key,
]),
'headers' => [
'Content-Type' => 'application/json',
],
]);
if (is_wp_error($response)) {
throw new Exception('Failed to get token: ' . $response->get_error_message());
}
$body = wp_remote_retrieve_body($response);
// Check if the body is already an array (decoded JSON)
if (!is_array($body)) {
$body = json_decode($body, true);
}
if (empty($body['access_token'])) {
throw new Exception('Invalid token response');
}
$this->access_token = $body['access_token'];
// Use the expiration time from the API response, or default to 1 hour if not provided
$expires_in = isset($body['expires_in']) ? $body['expires_in'] : 3600;
$this->token_expiry = time() + $expires_in - 10; // Subtract 10 seconds to be safe
update_option('eveai_access_token', $this->access_token);
update_option('eveai_token_expiry', $this->token_expiry);
}
private function make_request($method, $endpoint, $data = null) {
$this->ensure_valid_token();
error_log('EveAI API Request: ' . $method . ' ' . $this->api_url . $endpoint);
$url = $this->api_url . $endpoint;
$args = array(
'method' => $method,
'headers' => array(
'Content-Type' => 'application/json',
'Authorization' => 'Bearer ' . $this->access_token,
)
);
if ($data !== null) {
$args['body'] = json_encode($data);
}
$response = wp_remote_request($url, $args);
if (is_wp_error($response)) {
$error_message = $response->get_error_message();
error_log('EveAI API Error: ' . $error_message);
throw new Exception('API request failed: ' . $error_message);
}
$body = wp_remote_retrieve_body($response);
$status_code = wp_remote_retrieve_response_code($response);
error_log('EveAI API Response: ' . print_r($body, true));
error_log('EveAI API Status Code: ' . $status_code);
// Check if the body is already an array (decoded JSON)
if (!is_array($body)) {
$body = json_decode($body, true);
}
if ($status_code == 401) {
// Token might have expired, try to get a new one and retry the request
error_log('Token expired, trying to get a new one...');
$this->get_new_token();
return $this->make_request($method, $endpoint, $data);
}
if ($status_code >= 400) {
$error_message = isset($body['message']) ? $body['message'] : 'Unknown error';
error_log('EveAI API Error: ' . $error_message);
throw new Exception('API error: ' . $error_message);
}
return $body;
}
public function add_url($data) {
return $this->make_request('POST', '/api/v1/documents/add_url', $data);
}
public function update_document($document_id, $data) {
return $this->make_request('PUT', "/api/v1/documents/{$document_id}", $data);
}
public function invalidate_document($document_id) {
$data = array(
'valid_to' => gmdate('Y-m-d\TH:i:s\Z') // Current UTC time in ISO 8601 format
);
return $this->make_request('PUT', "/api/v1/documents/{$document_id}", $data);
}
public function refresh_document($document_id) {
return $this->make_request('POST', "/api/v1/documents/{$document_id}/refresh");
}
public function refresh_document_with_info($document_id, $data) {
return $this->make_request('POST', "/api/v1/documents/{$document_id}/refresh_with_info", $data);
}
}

View File

@@ -0,0 +1,37 @@
<?php
class EveAI_Bulk_Sync {
private $api;
private $post_handler;
public function __construct($api, $post_handler) {
$this->api = $api;
$this->post_handler = $post_handler;
}
public function init_bulk_sync() {
$posts = get_posts(array(
'post_type' => array('post', 'page'),
'post_status' => 'publish',
'posts_per_page' => -1,
));
$sync_results = array();
foreach ($posts as $post) {
$evie_id = get_post_meta($post->ID, '_eveai_document_id', true);
$evie_version_id = get_post_meta($post->ID, '_eveai_document_version_id', true);
$is_update = ($evie_id && $evie_version_id);
$result = $this->post_handler->sync_post($post->ID, $is_update);
$sync_results[] = array(
'id' => $post->ID,
'title' => $post->post_title,
'type' => $post->post_type,
'status' => $result ? 'success' : 'failed'
);
}
return $sync_results;
}
}

View File

@@ -0,0 +1,214 @@
<?php
class EveAI_Post_Handler {
private $api;
public function __construct($api) {
$this->api = $api;
}
public function handle_post_save($post_id, $post, $update) {
// Verify if this is not an auto save routine.
if (defined('DOING_AUTOSAVE') && DOING_AUTOSAVE) return;
// Check if this is a revision
if (wp_is_post_revision($post_id)) return;
// Check if post type is one we want to sync
if (!in_array($post->post_type, ['post', 'page'])) return;
// Check if post status is published
if ($post->post_status != 'publish') return;
// Verify nonce
if (!isset($_POST['eveai_sync_meta_box_nonce']) || !wp_verify_nonce($_POST['eveai_sync_meta_box_nonce'], 'eveai_sync_meta_box')) {
return;
}
// Check permissions
if ('page' == $_POST['post_type']) {
if (!current_user_can('edit_page', $post_id)) return;
} else {
if (!current_user_can('edit_post', $post_id)) return;
}
// Check if we should sync this post
if (!$this->should_sync_post($post_id)) return;
// Check if this is a REST API request
if (defined('REST_REQUEST') && REST_REQUEST) {
error_log("EveAI: REST API request detected for post $post_id");
}
error_log('Handling post' . $post_id . 'save event with update: ' . $update);
// Check if we've already synced this post in this request
if (get_post_meta($post_id, '_eveai_syncing', true)) {
return;
}
// Set a flag to indicate we're syncing
update_post_meta($post_id, '_eveai_syncing', true);
$this->sync_post($post_id, $update);
// Remove the flag after syncing
delete_post_meta($post_id, '_eveai_syncing');
}
public function sync_post($post_id, $is_update) {
$evie_id = get_post_meta($post_id, '_eveai_document_id', true);
try {
if ($evie_id && $is_update) {
$old_data = $this->get_old_post_data($post_id);
$new_data = $this->prepare_post_data($post_id);
if ($this->has_metadata_changed($old_data, $new_data)) {
$result = $this->refresh_document_with_info($evie_id, $new_data);
} else {
$result = $this->refresh_document($evie_id);
}
} else {
$data = $this->prepare_post_data($post_id);
$result = $this->api->add_url($data);
}
if (isset($result['document_id']) && isset($result['document_version_id'])) {
update_post_meta($post_id, '_eveai_document_id', $result['document_id']);
update_post_meta($post_id, '_eveai_document_version_id', $result['document_version_id']);
// Add debugging
error_log("EveAI: Set document_id {$result['document_id']} and document_version_id {$result['document_version_id']} for post {$post_id}");
}
return true;
} catch (Exception $e) {
error_log('EveAI Sync Error: ' . $e->getMessage());
// Optionally, you can add an admin notice here
add_action('admin_notices', function() use ($e) {
echo '<div class="notice notice-error is-dismissible">';
echo '<p>EveAI Sync Error: ' . esc_html($e->getMessage()) . '</p>';
echo '</div>';
});
return false;
}
return false;
}
private function get_old_post_data($post_id) {
$post = get_post($post_id);
return array(
'name' => $post->post_title,
'system_metadata' => json_encode([
'post_id' => $post_id,
'type' => $post->post_type,
'author' => get_the_author_meta('display_name', $post->post_author),
'categories' => $post->post_type === 'post' ? wp_get_post_categories($post_id, array('fields' => 'names')) : [],
'tags' => $post->post_type === 'post' ? wp_get_post_tags($post_id, array('fields' => 'names')) : [],
]),
);
}
private function has_metadata_changed($old_data, $new_data) {
return $old_data['name'] !== $new_data['name'] ||
$old_data['user_metadata'] !== $new_data['user_metadata'];
}
private function refresh_document_with_info($evie_id, $data) {
try {
return $this->api->refresh_document_with_info($evie_id, $data);
} catch (Exception $e) {
error_log('EveAI refresh with info error: ' . $e->getMessage());
return false;
}
}
private function refresh_document($evie_id) {
try {
return $this->api->refresh_document($evie_id);
} catch (Exception $e) {
error_log('EveAI refresh error: ' . $e->getMessage());
add_action('admin_notices', function() use ($e) {
echo '<div class="notice notice-error is-dismissible">';
echo '<p>EveAI Sync Error: ' . esc_html($e->getMessage()) . '</p>';
echo '</div>';
});
return false;
}
}
public function handle_post_delete($post_id) {
if ($evie_id) {
try {
$this->api->invalidate_document($evie_id);
} catch (Exception $e) {
error_log('EveAI invalidate error: ' . $e->getMessage());
add_action('admin_notices', function() use ($e) {
echo '<div class="notice notice-error is-dismissible">';
echo '<p>EveAI Sync Error: ' . esc_html($e->getMessage()) . '</p>';
echo '</div>';
});
}
}
}
public function process_sync_queue() {
$queue = get_option('eveai_sync_queue', array());
foreach ($queue as $key => $item) {
$this->sync_post($item['post_id'], $item['is_update']);
unset($queue[$key]);
}
update_option('eveai_sync_queue', $queue);
}
private function should_sync_post($post_id) {
if (get_post_meta($post_id, '_eveai_exclude_sync', true)) {
return false;
}
$post_type = get_post_type($post_id);
if ($post_type === 'page') {
// Pages are always synced unless individually excluded
return true;
}
if ($post_type === 'post') {
$excluded_categories_string = get_option('eveai_excluded_categories', '');
$excluded_categories = array_map('trim', explode(',', $excluded_categories_string));
$post_categories = wp_get_post_categories($post_id, array('fields' => 'names'));
$post_tags = wp_get_post_tags($post_id, array('fields' => 'names'));
// Check if any of the post's categories or tags are not in the excluded list
$all_terms = array_merge($post_categories, $post_tags);
foreach ($all_terms as $term) {
if (!in_array($term, $excluded_categories)) {
return true;
}
}
}
return false;
}
private function prepare_post_data($post_id) {
$post = get_post($post_id);
$data = array(
'url' => get_permalink($post_id),
'name' => $post->post_title,
'language' => get_option('eveai_default_language', 'en'),
'valid_from' => get_gmt_from_date($post->post_date, 'Y-m-d\TH:i:s\Z'),
'user_metadata' => json_encode([
'post_id' => $post_id,
'type' => $post->post_type,
'author' => get_the_author_meta('display_name', $post->post_author),
'categories' => $post->post_type === 'post' ? wp_get_post_categories($post_id, array('fields' => 'names')) : [],
'tags' => $post->post_type === 'post' ? wp_get_post_tags($post_id, array('fields' => 'names')) : [],
]),
);
return $data;
}
}

View File

@@ -0,0 +1,33 @@
<?php
class EveAI_Sync {
private $api;
private $post_handler;
private $admin;
public function init() {
$this->load_dependencies();
$this->setup_actions();
}
private function load_dependencies() {
require_once EVEAI_SYNC_PLUGIN_DIR . 'includes/class-eveai-api.php';
require_once EVEAI_SYNC_PLUGIN_DIR . 'includes/class-eveai-post-handler.php';
require_once EVEAI_SYNC_PLUGIN_DIR . 'includes/class-eveai-admin.php';
require_once EVEAI_SYNC_PLUGIN_DIR . 'includes/class-eveai-bulk-sync.php';
$this->api = new EveAI_API();
$this->post_handler = new EveAI_Post_Handler($this->api);
$this->admin = new EveAI_Admin($this->api);
}
private function setup_actions() {
add_action('save_post', array($this->post_handler, 'handle_post_save'), 10, 3);
add_action('before_delete_post', array($this->post_handler, 'handle_post_delete'));
add_action('admin_init', array($this->admin, 'register_settings'));
add_action('admin_menu', array($this->admin, 'add_admin_menu'));
add_action('add_meta_boxes', array($this->admin, 'add_sync_meta_box'));
add_action('eveai_sync_post', array($this->post_handler, 'sync_post'), 10, 2);
add_action('wp_ajax_eveai_bulk_sync', array($this->admin, 'handle_bulk_sync_ajax'));
}
}

View File

@@ -0,0 +1,68 @@
"""Include excluded_classes for Tenant
Revision ID: 110be45f6e44
Revises: 229774547fed
Create Date: 2024-08-22 07:37:40.591220
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '110be45f6e44'
down_revision = '229774547fed'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('roles_users', schema=None) as batch_op:
batch_op.drop_constraint('roles_users_user_id_fkey', type_='foreignkey')
batch_op.drop_constraint('roles_users_role_id_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'role', ['role_id'], ['id'], referent_schema='public', ondelete='CASCADE')
batch_op.create_foreign_key(None, 'user', ['user_id'], ['id'], referent_schema='public', ondelete='CASCADE')
with op.batch_alter_table('tenant', schema=None) as batch_op:
batch_op.add_column(sa.Column('html_excluded_classes', postgresql.ARRAY(sa.String(length=200)), nullable=True))
with op.batch_alter_table('tenant_domain', schema=None) as batch_op:
batch_op.drop_constraint('tenant_domain_updated_by_fkey', type_='foreignkey')
batch_op.drop_constraint('tenant_domain_tenant_id_fkey', type_='foreignkey')
batch_op.drop_constraint('tenant_domain_created_by_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'user', ['updated_by'], ['id'], referent_schema='public')
batch_op.create_foreign_key(None, 'tenant', ['tenant_id'], ['id'], referent_schema='public')
batch_op.create_foreign_key(None, 'user', ['created_by'], ['id'], referent_schema='public')
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.drop_constraint('user_tenant_id_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'tenant', ['tenant_id'], ['id'], referent_schema='public')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('user_tenant_id_fkey', 'tenant', ['tenant_id'], ['id'])
with op.batch_alter_table('tenant_domain', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('tenant_domain_created_by_fkey', 'user', ['created_by'], ['id'])
batch_op.create_foreign_key('tenant_domain_tenant_id_fkey', 'tenant', ['tenant_id'], ['id'])
batch_op.create_foreign_key('tenant_domain_updated_by_fkey', 'user', ['updated_by'], ['id'])
with op.batch_alter_table('tenant', schema=None) as batch_op:
batch_op.drop_column('html_excluded_classes')
with op.batch_alter_table('roles_users', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('roles_users_role_id_fkey', 'role', ['role_id'], ['id'], ondelete='CASCADE')
batch_op.create_foreign_key('roles_users_user_id_fkey', 'user', ['user_id'], ['id'], ondelete='CASCADE')
# ### end Alembic commands ###

View File

@@ -0,0 +1,68 @@
"""Add API Key to tenant - next to Chat API key
Revision ID: ce6f5b62bbfb
Revises: 110be45f6e44
Create Date: 2024-08-29 07:43:20.662983
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ce6f5b62bbfb'
down_revision = '110be45f6e44'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('roles_users', schema=None) as batch_op:
batch_op.drop_constraint('roles_users_role_id_fkey', type_='foreignkey')
batch_op.drop_constraint('roles_users_user_id_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'user', ['user_id'], ['id'], referent_schema='public', ondelete='CASCADE')
batch_op.create_foreign_key(None, 'role', ['role_id'], ['id'], referent_schema='public', ondelete='CASCADE')
with op.batch_alter_table('tenant', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_api_key', sa.String(length=500), nullable=True))
with op.batch_alter_table('tenant_domain', schema=None) as batch_op:
batch_op.drop_constraint('tenant_domain_tenant_id_fkey', type_='foreignkey')
batch_op.drop_constraint('tenant_domain_updated_by_fkey', type_='foreignkey')
batch_op.drop_constraint('tenant_domain_created_by_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'user', ['updated_by'], ['id'], referent_schema='public')
batch_op.create_foreign_key(None, 'user', ['created_by'], ['id'], referent_schema='public')
batch_op.create_foreign_key(None, 'tenant', ['tenant_id'], ['id'], referent_schema='public')
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.drop_constraint('user_tenant_id_fkey', type_='foreignkey')
batch_op.create_foreign_key(None, 'tenant', ['tenant_id'], ['id'], referent_schema='public')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('user_tenant_id_fkey', 'tenant', ['tenant_id'], ['id'])
with op.batch_alter_table('tenant_domain', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('tenant_domain_created_by_fkey', 'user', ['created_by'], ['id'])
batch_op.create_foreign_key('tenant_domain_updated_by_fkey', 'user', ['updated_by'], ['id'])
batch_op.create_foreign_key('tenant_domain_tenant_id_fkey', 'tenant', ['tenant_id'], ['id'])
with op.batch_alter_table('tenant', schema=None) as batch_op:
batch_op.drop_column('encrypted_api_key')
with op.batch_alter_table('roles_users', schema=None) as batch_op:
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.drop_constraint(None, type_='foreignkey')
batch_op.create_foreign_key('roles_users_user_id_fkey', 'user', ['user_id'], ['id'], ondelete='CASCADE')
batch_op.create_foreign_key('roles_users_role_id_fkey', 'role', ['role_id'], ['id'], ondelete='CASCADE')
# ### end Alembic commands ###

View File

@@ -1,3 +1,4 @@
import inspect
import logging import logging
from logging.config import fileConfig from logging.config import fileConfig
@@ -5,12 +6,12 @@ from flask import current_app
from alembic import context from alembic import context
from sqlalchemy import NullPool, engine_from_config, text from sqlalchemy import NullPool, engine_from_config, text
import sqlalchemy as sa
from common.models.user import Tenant from sqlalchemy.sql import schema
import pgvector import pgvector
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from common.models import document, interaction, user
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@@ -46,17 +47,34 @@ def get_engine_url():
config.set_main_option('sqlalchemy.url', get_engine_url()) config.set_main_option('sqlalchemy.url', get_engine_url())
target_db = current_app.extensions['migrate'].db target_db = current_app.extensions['migrate'].db
# other values from the config, defined by the needs of env.py,
# can be acquired: def get_public_table_names():
# my_important_option = config.get_main_option("my_important_option") # TODO: This function should include the necessary functionality to automatically retrieve table names
# ... etc. return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain']
PUBLIC_TABLES = get_public_table_names()
logger.info(f"Public tables: {PUBLIC_TABLES}")
def include_object(object, name, type_, reflected, compare_to):
if type_ == "table" and name in PUBLIC_TABLES:
logger.info(f"Excluding public table: {name}")
return False
logger.info(f"Including object: {type_} {name}")
return True
# List of Tenants # List of Tenants
tenants_ = Tenant.query.all() def get_tenant_ids():
if tenants_: from common.models.user import Tenant
tenants = [tenant.id for tenant in tenants_]
else: tenants_ = Tenant.query.all()
tenants = [] if tenants_:
tenants = [tenant.id for tenant in tenants_]
else:
tenants = []
return tenants
def get_metadata(): def get_metadata():
@@ -79,7 +97,10 @@ def run_migrations_offline():
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure( context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True url=url,
target_metadata=get_metadata(),
# literal_binds=True,
include_object=include_object,
) )
with context.begin_transaction(): with context.begin_transaction():
@@ -99,18 +120,8 @@ def run_migrations_online():
poolclass=NullPool, poolclass=NullPool,
) )
def process_revision_directives(context, revision, directives):
if config.cmd_opts.autogenerate:
script = directives[0]
if script.upgrade_ops is not None:
# Add import for pgvector and set search path
script.upgrade_ops.ops.insert(0, sa.schema.ExecuteSQLOp(
"SET search_path TO CURRENT_SCHEMA(), public; IMPORT pgvector",
execution_options=None
))
with connectable.connect() as connection: with connectable.connect() as connection:
print(tenants) tenants = get_tenant_ids()
for tenant in tenants: for tenant in tenants:
logger.info(f"Migrating tenant: {tenant}") logger.info(f"Migrating tenant: {tenant}")
# set search path on the connection, which ensures that # set search path on the connection, which ensures that
@@ -127,7 +138,8 @@ def run_migrations_online():
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=get_metadata(), target_metadata=get_metadata(),
process_revision_directives=process_revision_directives, # literal_binds=True,
include_object=include_object,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@@ -0,0 +1,27 @@
"""Ensure logging information in document domain does not require user
Revision ID: 43eac8a7a00b
Revises: 5d5437d81041
Create Date: 2024-09-03 09:36:06.541938
"""
from alembic import op
import sqlalchemy as sa
import pgvector
# revision identifiers, used by Alembic.
revision = '43eac8a7a00b'
down_revision = '5d5437d81041'
branch_labels = None
depends_on = None
def upgrade():
# Manual upgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by DROP NOT NULL')
def downgrade():
# Manual downgrade commands
op.execute('ALTER TABLE document ALTER COLUMN created_by SET NOT NULL')

View File

@@ -0,0 +1,33 @@
"""Adding metadata to DocumentVersion
Revision ID: 711a09a77680
Revises: 43eac8a7a00b
Create Date: 2024-09-06 14:37:39.647455
"""
from alembic import op
import sqlalchemy as sa
import pgvector
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '711a09a77680'
down_revision = '43eac8a7a00b'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('document_version', sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
op.add_column('document_version', sa.Column('system_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('document_version', 'system_metadata')
op.drop_column('document_version', 'user_metadata')
# ### end Alembic commands ###

View File

@@ -32,7 +32,7 @@ http {
#keepalive_timeout 0; #keepalive_timeout 0;
keepalive_timeout 65; keepalive_timeout 65;
client_max_body_size 16M; client_max_body_size 50M;
#gzip on; #gzip on;
@@ -137,6 +137,27 @@ http {
} }
location /api/ {
proxy_pass http://eveai_api:5003;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Prefix /api;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_buffering off;
proxy_buffer_size 16k;
proxy_buffers 4 32k;
proxy_busy_buffers_size 64k;
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
send_timeout 60s;
}
location /flower/ { location /flower/ {
proxy_pass http://127.0.0.1:5555/; proxy_pass http://127.0.0.1:5555/;
proxy_set_header Host $host; proxy_set_header Host $host;

View File

@@ -0,0 +1,30 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Chat Client Evie</title>
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
<script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script src="/static/js/eveai-sdk.js" defer></script>
<script src="/static/js/eveai-chat-widget.js" defer></script>
<link rel="stylesheet" href="/static/css/eveai-chat-style.css">
</head>
<body>
<div id="chat-container"></div>
<script>
document.addEventListener('DOMContentLoaded', function() {
const eveAI = new EveAI(
'2',
'EveAI-CHAT-9079-8604-7441-6496-7604',
'http://macstudio.ask-eve-ai-local.com',
'en',
'en,fr,nl',
'http://macstudio.ask-eve-ai-local.com:8080/'
);
eveAI.initializeChat('chat-container');
});
</script>
</body>
</html>

View File

@@ -386,6 +386,7 @@ input[type="radio"] {
.btn-danger:hover { .btn-danger:hover {
background-color: darken(var(--bs-danger), 10%) !important; /* Darken the background on hover */ background-color: darken(var(--bs-danger), 10%) !important; /* Darken the background on hover */
border-color: darken(var(--bs-danger), 10%) !important; /* Darken the border on hover */ border-color: darken(var(--bs-danger), 10%) !important; /* Darken the border on hover */
color: var(--bs-white) !important; /* Ensure the text remains white and readable */
} }
/* Success Alert Styling */ /* Success Alert Styling */
@@ -420,3 +421,90 @@ input[type="radio"] {
box-shadow: none; box-shadow: none;
} }
/* Custom styles for chat session view */
.accordion-button:not(.collapsed) {
background-color: var(--bs-primary);
color: var(--bs-white);
}
.accordion-button:focus {
box-shadow: 0 0 0 0.25rem rgba(118, 89, 154, 0.25);
}
.interaction-question {
font-size: 1rem; /* Normal text size */
}
.interaction-icons {
display: flex;
align-items: center;
}
.interaction-icons .material-icons {
font-size: 24px;
margin-left: 8px;
}
.thumb-icon.filled {
color: var(--bs-success);
}
.thumb-icon.outlined {
color: var(--thumb-icon-outlined);
}
/* Algorithm icon colors */
.algorithm-icon.rag_tenant {
color: var(--algorithm-color-rag-tenant);
}
.algorithm-icon.rag_wikipedia {
color: var(--algorithm-color-rag-wikipedia);
}
.algorithm-icon.rag_google {
color: var(--algorithm-color-rag-google);
}
.algorithm-icon.llm {
color: var(--algorithm-color-llm);
}
.accordion-body {
background-color: var(--bs-light);
}
/* Markdown content styles */
.markdown-content {
font-size: 1rem;
line-height: 1.5;
}
.markdown-content p {
margin-bottom: 1rem;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3,
.markdown-content h4, .markdown-content h5, .markdown-content h6 {
margin-top: 1.5rem;
margin-bottom: 1rem;
}
.markdown-content ul, .markdown-content ol {
margin-bottom: 1rem;
padding-left: 2rem;
}
.markdown-content code {
background-color: #f8f9fa;
padding: 0.2em 0.4em;
border-radius: 3px;
}
.markdown-content pre {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 5px;
overflow-x: auto;
}

View File

@@ -1,294 +0,0 @@
/* eveai_chat.css */
:root {
--user-message-bg: #292929; /* Default user message background color */
--bot-message-bg: #1e1e1e; /* Default bot message background color */
--chat-bg: #1e1e1e; /* Default chat background color */
--status-line-color: #e9e9e9; /* Color for the status line text */
--status-line-bg: #1e1e1e; /* Background color for the status line */
--status-line-height: 30px; /* Fixed height for the status line */
--algorithm-color-rag-tenant: #0f0; /* Green for RAG_TENANT */
--algorithm-color-rag-wikipedia: #00f; /* Blue for RAG_WIKIPEDIA */
--algorithm-color-rag-google: #ff0; /* Yellow for RAG_GOOGLE */
--algorithm-color-llm: #800080; /* Purple for RAG_LLM */
/*--font-family: 'Arial, sans-serif'; !* Default font family *!*/
--font-family: 'Segoe UI, Roboto, Cantarell, Noto Sans, Apple Color Emoji, Segoe UI Emoji, Segoe UI Symbol';
--font-color: #e9e9e9; /* Default font color */
--user-message-font-color: #e9e9e9; /* User message font color */
--bot-message-font-color: #e9e9e9; /* Bot message font color */
--input-bg: #292929; /* Input background color */
--input-border: #ccc; /* Input border color */
--input-text-color: #e9e9e9; /* Input text color */
--button-color: #007bff; /* Button text color */
/* Variables for hyperlink backgrounds */
--link-bg: #1e1e1e; /* Default background color for hyperlinks */
--link-hover-bg: #1e1e1e; /* Background color on hover for hyperlinks */
--link-color: #dec981; /* Default text color for hyperlinks */
--link-hover-color: #D68F53; /* Text color on hover for hyperlinks */
/* New scrollbar variables */
--scrollbar-bg: #292929; /* Background color for the scrollbar track */
--scrollbar-thumb: #4b4b4b; /* Color for the scrollbar thumb */
--scrollbar-thumb-hover: #dec981; /* Color for the thumb on hover */
--scrollbar-thumb-active: #D68F53; /* Color for the thumb when active (dragged) */
/* Thumb colors */
--thumb-icon-outlined: #4b4b4b;
--thumb-icon-filled: #e9e9e9;
/* Connection Status colors */
--status-connected-color: #28a745; /* Green color for connected status */
--status-disconnected-color: #ffc107; /* Orange color for disconnected status */
}
/* Connection status styles */
.connection-status-icon {
vertical-align: middle;
font-size: 24px;
margin-right: 8px;
}
.status-connected {
color: var(--status-connected-color);
}
.status-disconnected {
color: var(--status-disconnected-color);
}
/* Custom scrollbar styles */
.messages-area::-webkit-scrollbar {
width: 12px; /* Width of the scrollbar */
}
.messages-area::-webkit-scrollbar-track {
background: var(--scrollbar-bg); /* Background color for the track */
}
.messages-area::-webkit-scrollbar-thumb {
background-color: var(--scrollbar-thumb); /* Color of the thumb */
border-radius: 10px; /* Rounded corners for the thumb */
border: 3px solid var(--scrollbar-bg); /* Space around the thumb */
}
.messages-area::-webkit-scrollbar-thumb:hover {
background-color: var(--scrollbar-thumb-hover); /* Color when hovering over the thumb */
}
.messages-area::-webkit-scrollbar-thumb:active {
background-color: var(--scrollbar-thumb-active); /* Color when active (dragging) */
}
/* For Firefox */
.messages-area {
scrollbar-width: thin; /* Make scrollbar thinner */
scrollbar-color: var(--scrollbar-thumb) var(--scrollbar-bg); /* Thumb and track colors */
}
/* General Styles */
.chat-container {
display: flex;
flex-direction: column;
height: 75vh;
/*max-height: 100vh;*/
max-width: 600px;
margin: auto;
border: 1px solid #ccc;
border-radius: 8px;
overflow: hidden;
background-color: var(--chat-bg);
font-family: var(--font-family); /* Apply the default font family */
color: var(--font-color); /* Apply the default font color */
}
.disclaimer {
font-size: 0.7em;
text-align: right;
padding: 5px 20px 5px 5px;
margin-bottom: 5px;
}
.messages-area {
flex: 1;
overflow-y: auto;
padding: 10px;
background-color: var(--bot-message-bg);
}
.message {
max-width: 90%;
margin-bottom: 10px;
padding: 10px;
border-radius: 15px;
font-size: 1rem;
}
.message.user {
margin-left: auto;
background-color: var(--user-message-bg);
color: var(--user-message-font-color); /* Apply user message font color */
}
.message.bot {
background-color: var(--bot-message-bg);
color: var(--bot-message-font-color); /* Apply bot message font color */
}
.message-icons {
display: flex;
align-items: center;
}
/* Scoped styles for thumb icons */
.thumb-icon.outlined {
color: var(--thumb-icon-outlined); /* Color for outlined state */
}
.thumb-icon.filled {
color: var(--thumb-icon-filled); /* Color for filled state */
}
/* Default styles for material icons */
.material-icons {
font-size: 24px;
vertical-align: middle;
cursor: pointer;
}
.question-area {
display: flex;
flex-direction: row;
align-items: center;
background-color: var(--user-message-bg);
padding: 10px;
}
.language-select-container {
width: 100%;
margin-bottom: 10px; /* Spacing between the dropdown and the textarea */
}
.language-select {
width: 100%;
margin-bottom: 5px; /* Space between the dropdown and the send button */
padding: 8px;
border-radius: 5px;
border: 1px solid var(--input-border);
background-color: var(--input-bg);
color: var(--input-text-color);
font-size: 1rem;
}
.question-area textarea {
flex: 1;
border: none;
padding: 10px;
border-radius: 15px;
background-color: var(--input-bg);
border: 1px solid var(--input-border);
color: var(--input-text-color);
font-family: var(--font-family); /* Apply the default font family */
font-size: 1rem;
resize: vertical;
min-height: 60px;
max-height: 150px;
overflow-y: auto;
margin-right: 10px; /* Space between textarea and right-side container */
}
.right-side {
display: flex;
flex-direction: column;
align-items: center;
}
.question-area button {
background: none;
border: none;
cursor: pointer;
color: var(--button-color);
}
/* Styles for the send icon */
.send-icon {
font-size: 24px;
color: var(--button-color);
cursor: pointer;
}
.send-icon.disabled {
color: grey; /* Color for the disabled state */
cursor: not-allowed; /* Change cursor to indicate disabled state */
}
/* New CSS for the status-line */
.status-line {
height: var(--status-line-height); /* Fixed height for the status line */
padding: 5px 10px;
background-color: var(--status-line-bg); /* Background color */
color: var(--status-line-color); /* Text color */
font-size: 0.9rem; /* Slightly smaller font size */
text-align: center; /* Centered text */
border-top: 1px solid #ccc; /* Subtle top border */
display: flex;
align-items: center;
justify-content: flex-start;
}
/* Algorithm-specific colors for fingerprint icon */
.fingerprint-rag-tenant {
color: var(--algorithm-color-rag-tenant);
}
.fingerprint-rag-wikipedia {
color: var(--algorithm-color-rag-wikipedia);
}
.fingerprint-rag-google {
color: var(--algorithm-color-rag-google);
}
.fingerprint-llm {
color: var(--algorithm-color-llm);
}
/* Styling for citation links */
.citations a {
background-color: var(--link-bg); /* Apply default background color */
color: var(--link-color); /* Apply default link color */
padding: 2px 4px; /* Add padding for better appearance */
border-radius: 3px; /* Add slight rounding for a modern look */
text-decoration: none; /* Remove default underline */
transition: background-color 0.3s, color 0.3s; /* Smooth transition for hover effects */
}
.citations a:hover {
background-color: var(--link-hover-bg); /* Background color on hover */
color: var(--link-hover-color); /* Text color on hover */
}
/* Media queries for responsiveness */
@media (max-width: 768px) {
.chat-container {
max-width: 90%; /* Reduce max width on smaller screens */
}
}
@media (max-width: 480px) {
.chat-container {
max-width: 95%; /* Further reduce max width on very small screens */
}
.question-area input {
font-size: 0.9rem; /* Adjust input font size for smaller screens */
}
.status-line {
font-size: 0.8rem; /* Adjust status line font size for smaller screens */
}
}

View File

@@ -1,470 +0,0 @@
class EveAIChatWidget extends HTMLElement {
static get observedAttributes() {
return ['tenant-id', 'api-key', 'domain', 'language', 'languages'];
}
constructor() {
super();
this.socket = null; // Initialize socket to null
this.attributesSet = false; // Flag to check if all attributes are set
this.jwtToken = null; // Initialize jwtToken to null
this.userTimezone = Intl.DateTimeFormat().resolvedOptions().timeZone; // Detect user's timezone
this.heartbeatInterval = null;
this.idleTime = 0; // in milliseconds
this.maxConnectionIdleTime = 1 * 60 * 60 * 1000; // 1 hours in milliseconds
this.languages = []
this.room = null;
console.log('EveAIChatWidget constructor called');
}
connectedCallback() {
console.log('connectedCallback called');
this.innerHTML = this.getTemplate();
this.messagesArea = this.querySelector('.messages-area');
this.questionInput = this.querySelector('.question-area textarea');
this.sendButton = this.querySelector('.send-icon');
this.languageSelect = this.querySelector('.language-select');
this.statusLine = this.querySelector('.status-line');
this.statusMessage = this.querySelector('.status-message');
this.connectionStatusIcon = this.querySelector('.connection-status-icon');
this.sendButton.addEventListener('click', () => this.handleSendMessage());
this.questionInput.addEventListener('keydown', (event) => {
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault(); // Prevent adding a new line
this.handleSendMessage();
}
});
if (this.areAllAttributesSet() && !this.socket) {
console.log('Attributes already set in connectedCallback, initializing socket');
this.initializeSocket();
}
}
populateLanguageDropdown() {
// Clear existing options
this.languageSelect.innerHTML = '';
console.log(`languages for options: ${this.languages}`)
// Populate with new options
this.languages.forEach(lang => {
const option = document.createElement('option');
option.value = lang;
option.textContent = lang.toUpperCase();
if (lang === this.currentLanguage) {
option.selected = true;
}
console.log(`Adding option for language: ${lang}`)
this.languageSelect.appendChild(option);
});
// Add event listener for language change
this.languageSelect.addEventListener('change', (e) => {
this.currentLanguage = e.target.value;
// You might want to emit an event or update the backend about the language change
});
}
attributeChangedCallback(name, oldValue, newValue) {
console.log(`attributeChangedCallback called: ${name} changed from ${oldValue} to ${newValue}`);
this.updateAttributes();
if (this.areAllAttributesSet() && !this.socket) {
console.log('All attributes set in attributeChangedCallback, initializing socket');
this.attributesSet = true;
console.log('All attributes are set, populating language dropdown');
this.populateLanguageDropdown();
console.log('All attributes are set, initializing socket')
this.initializeSocket();
}
}
updateAttributes() {
this.tenantId = this.getAttribute('tenant-id');
this.apiKey = this.getAttribute('api-key');
this.domain = this.getAttribute('domain');
this.language = this.getAttribute('language');
const languageAttr = this.getAttribute('languages');
this.languages = languageAttr ? languageAttr.split(',') : [];
this.currentLanguage = this.language;
console.log('Updated attributes:', {
tenantId: this.tenantId,
apiKey: this.apiKey,
domain: this.domain,
language: this.language,
currentLanguage: this.currentLanguage,
languages: this.languages
});
}
areAllAttributesSet() {
const tenantId = this.getAttribute('tenant-id');
const apiKey = this.getAttribute('api-key');
const domain = this.getAttribute('domain');
const language = this.getAttribute('language');
const languages = this.getAttribute('languages');
console.log('Checking if all attributes are set:', {
tenantId,
apiKey,
domain,
language,
languages
});
return tenantId && apiKey && domain && language && languages;
}
createLanguageDropdown() {
const select = document.createElement('select');
select.id = 'languageSelect';
this.languages.forEach(lang => {
const option = document.createElement('option');
option.value = lang;
option.textContent = lang.toUpperCase();
if (lang === this.currentLanguage) {
option.selected = true;
}
select.appendChild(option);
});
select.addEventListener('change', (e) => {
this.currentLanguage = e.target.value;
// You might want to emit an event or update the backend about the language change
});
return select;
}
initializeSocket() {
if (this.socket) {
console.log('Socket already initialized');
return;
}
console.log(`Initializing socket connection to Evie`);
// Ensure apiKey is passed in the query parameters
this.socket = io('https://evie.askeveai.com', {
path: '/chat/socket.io/',
transports: ['websocket', 'polling'],
query: {
tenantId: this.tenantId,
apiKey: this.apiKey // Ensure apiKey is included here
},
auth: {
token: 'Bearer ' + this.apiKey // Ensure token is included here
},
reconnectionAttempts: Infinity, // Infinite reconnection attempts
reconnectionDelay: 5000, // Delay between reconnections
timeout: 20000 // Connection timeout
});
console.log(`Finished initializing socket connection to Evie`);
this.socket.on('connect', (data) => {
console.log('Socket connected OK');
this.setStatusMessage('Connected to EveAI.');
this.updateConnectionStatus(true);
this.startHeartbeat();
if (data.room) {
this.room = data.room;
console.log(`Joined room: ${this.room}`);
}
});
this.socket.on('authenticated', (data) => {
console.log('Authenticated event received: ', data);
this.setStatusMessage('Authenticated.');
if (data.token) {
this.jwtToken = data.token; // Store the JWT token received from the server
}
if (data.room) {
this.room = data.room;
console.log(`Confirmed room: ${this.room}`);
}
});
this.socket.on('connect_error', (err) => {
console.error('Socket connection error:', err);
this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.');
this.updateConnectionStatus(false);
});
this.socket.on('connect_timeout', () => {
console.error('Socket connection timeout');
this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.');
this.updateConnectionStatus(false);
});
this.socket.on('disconnect', (reason) => {
console.log('Socket disconnected: ', reason);
if (reason === 'io server disconnect') {
// Server disconnected the socket
this.socket.connect(); // Attempt to reconnect
}
this.setStatusMessage('Disconnected from EveAI. Please refresh the page for further interaction.');
this.updateConnectionStatus(false);
this.stopHeartbeat();
});
this.socket.on('reconnect_attempt', () => {
console.log('Attempting to reconnect to the server...');
this.setStatusMessage('Attempting to reconnect...');
});
this.socket.on('reconnect', () => {
console.log('Successfully reconnected to the server');
this.setStatusMessage('Reconnected to EveAI.');
this.updateConnectionStatus(true);
this.startHeartbeat();
});
this.socket.on('bot_response', (data) => {
if (data.tenantId === this.tenantId) {
console.log('Initial response received:', data);
console.log('Task ID received:', data.taskId);
this.checkTaskStatus(data.taskId);
this.setStatusMessage('Processing...');
}
});
this.socket.on('task_status', (data) => {
console.log('Task status received:', data.status);
console.log('Task ID received:', data.taskId);
console.log('Citations type:', typeof data.citations, 'Citations:', data.citations);
if (data.status === 'pending') {
this.updateProgress();
setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second
} else if (data.status === 'success') {
this.addBotMessage(data.answer, data.interaction_id, data.algorithm, data.citations);
this.clearProgress(); // Clear progress indicator when done
} else {
this.setStatusMessage('Failed to process message.');
}
});
}
setStatusMessage(message) {
this.statusMessage.textContent = message;
}
updateConnectionStatus(isConnected) {
if (isConnected) {
this.connectionStatusIcon.textContent = 'link';
this.connectionStatusIcon.classList.remove('status-disconnected');
this.connectionStatusIcon.classList.add('status-connected');
} else {
this.connectionStatusIcon.textContent = 'link_off';
this.connectionStatusIcon.classList.remove('status-connected');
this.connectionStatusIcon.classList.add('status-disconnected');
}
}
startHeartbeat() {
this.stopHeartbeat(); // Clear any existing interval
this.heartbeatInterval = setInterval(() => {
if (this.socket && this.socket.connected) {
this.socket.emit('heartbeat');
this.idleTime += 30000;
if (this.idleTime >= this.maxConnectionIdleTime) {
this.socket.disconnect();
this.setStatusMessage('Disconnected due to inactivity.');
this.updateConnectionStatus(false);
this.stopHeartbeat();
}
}
}, 30000); // Send a heartbeat every 30 seconds
}
stopHeartbeat() {
if (this.heartbeatInterval) {
clearInterval(this.heartbeatInterval);
this.heartbeatInterval = null;
}
}
updateProgress() {
if (!this.statusMessage.textContent) {
this.statusMessage.textContent = 'Processing...';
} else {
this.statusMessage.textContent += '.'; // Append a dot
}
}
clearProgress() {
this.statusMessage.textContent = '';
this.toggleSendButton(false); // Re-enable and revert send button to outlined version
}
checkTaskStatus(taskId) {
this.updateProgress();
this.socket.emit('check_task_status', { task_id: taskId });
}
getTemplate() {
return `
<div class="chat-container">
<div class="messages-area"></div>
<div class="disclaimer">Evie can make mistakes. Please double-check responses.</div>
<div class="question-area">
<textarea placeholder="Type your message here..." rows="3"></textarea>
<div class="right-side">
<select class="language-select"></select>
<i class="material-icons send-icon outlined">send</i>
</div>
</div>
<div class="status-line">
<i class="material-icons connection-status-icon">link_off</i>
<span class="status-message"></span>
</div>
</div>
`;
}
addUserMessage(text) {
const message = document.createElement('div');
message.classList.add('message', 'user');
message.innerHTML = `<p>${text}</p>`;
this.messagesArea.appendChild(message);
this.messagesArea.scrollTop = this.messagesArea.scrollHeight;
}
handleFeedback(feedback, interactionId) {
// Send feedback to the backend
console.log('handleFeedback called');
if (!this.socket) {
console.error('Socket is not initialized');
return;
}
if (!this.jwtToken) {
console.error('JWT token is not available');
return;
}
console.log('Sending message to backend');
console.log(`Feedback for ${interactionId}: ${feedback}`);
this.socket.emit('feedback', { tenantId: this.tenantId, token: this.jwtToken, feedback, interactionId });
this.setStatusMessage('Feedback sent.');
}
addBotMessage(text, interactionId, algorithm = 'default', citations = []) {
const message = document.createElement('div');
message.classList.add('message', 'bot');
let content = marked.parse(text);
let citationsHtml = citations.map(url => `<a href="${url}" target="_blank">${url}</a>`).join('<br>');
let algorithmClass;
switch (algorithm) {
case 'RAG_TENANT':
algorithmClass = 'fingerprint-rag-tenant';
break;
case 'RAG_WIKIPEDIA':
algorithmClass = 'fingerprint-rag-wikipedia';
break;
case 'RAG_GOOGLE':
algorithmClass = 'fingerprint-rag-google';
break;
case 'LLM':
algorithmClass = 'fingerprint-llm';
break;
default:
algorithmClass = '';
}
message.innerHTML = `
<p>${content}</p>
${citationsHtml ? `<p class="citations">${citationsHtml}</p>` : ''}
<div class="message-icons">
<i class="material-icons ${algorithmClass}">fingerprint</i>
<i class="material-icons thumb-icon outlined" data-feedback="up" data-interaction-id="${interactionId}">thumb_up_off_alt</i>
<i class="material-icons thumb-icon outlined" data-feedback="down" data-interaction-id="${interactionId}">thumb_down_off_alt</i>
</div>
`;
this.messagesArea.appendChild(message);
// Add event listeners for feedback buttons
const thumbsUp = message.querySelector('i[data-feedback="up"]');
const thumbsDown = message.querySelector('i[data-feedback="down"]');
thumbsUp.addEventListener('click', () => this.toggleFeedback(thumbsUp, thumbsDown, 'up', interactionId));
thumbsDown.addEventListener('click', () => this.toggleFeedback(thumbsUp, thumbsDown, 'down', interactionId));
this.messagesArea.scrollTop = this.messagesArea.scrollHeight;
}
toggleFeedback(thumbsUp, thumbsDown, feedback, interactionId) {
console.log('feedback called');
this.idleTime = 0; // Reset idle time
if (feedback === 'up') {
thumbsUp.textContent = 'thumb_up'; // Change to filled icon
thumbsUp.classList.remove('outlined');
thumbsUp.classList.add('filled');
thumbsDown.textContent = 'thumb_down_off_alt'; // Keep the other icon outlined
thumbsDown.classList.add('outlined');
thumbsDown.classList.remove('filled');
} else {
thumbsDown.textContent = 'thumb_down'; // Change to filled icon
thumbsDown.classList.remove('outlined');
thumbsDown.classList.add('filled');
thumbsUp.textContent = 'thumb_up_off_alt'; // Keep the other icon outlined
thumbsUp.classList.add('outlined');
thumbsUp.classList.remove('filled');
}
// Send feedback to the backend
this.handleFeedback(feedback, interactionId);
}
handleSendMessage() {
console.log('handleSendMessage called');
this.idleTime = 0; // Reset idle time
const message = this.questionInput.value.trim();
if (message) {
this.addUserMessage(message);
this.questionInput.value = '';
this.sendMessageToBackend(message);
this.toggleSendButton(true); // Disable and change send button to filled version
}
}
sendMessageToBackend(message) {
console.log('sendMessageToBackend called');
if (!this.socket) {
console.error('Socket is not initialized');
return;
}
if (!this.jwtToken) {
console.error('JWT token is not available');
return;
}
const selectedLanguage = this.languageSelect.value;
console.log('Sending message to backend');
this.socket.emit('user_message', {
tenantId: this.tenantId,
token: this.jwtToken,
message,
language: selectedLanguage,
timezone: this.userTimezone
});
this.setStatusMessage('Processing started ...')
}
toggleSendButton(isProcessing) {
if (isProcessing) {
this.sendButton.textContent = 'send'; // Filled send icon
this.sendButton.classList.remove('outlined');
this.sendButton.classList.add('filled');
this.sendButton.classList.add('disabled'); // Add disabled class for styling
this.sendButton.style.pointerEvents = 'none'; // Disable click events
} else {
this.sendButton.textContent = 'send'; // Outlined send icon
this.sendButton.classList.add('outlined');
this.sendButton.classList.remove('filled');
this.sendButton.classList.remove('disabled'); // Remove disabled class
this.sendButton.style.pointerEvents = 'auto'; // Re-enable click events
}
}
}
customElements.define('eveai-chat-widget', EveAIChatWidget);

View File

@@ -1,29 +0,0 @@
// static/js/eveai-sdk.js
class EveAI {
constructor(tenantId, apiKey, domain, language, languages) {
this.tenantId = tenantId;
this.apiKey = apiKey;
this.domain = domain;
this.language = language;
this.languages = languages;
console.log('EveAI constructor:', { tenantId, apiKey, domain, language, languages });
}
initializeChat(containerId) {
const container = document.getElementById(containerId);
if (container) {
container.innerHTML = '<eveai-chat-widget></eveai-chat-widget>';
customElements.whenDefined('eveai-chat-widget').then(() => {
const chatWidget = container.querySelector('eveai-chat-widget');
chatWidget.setAttribute('tenant-id', this.tenantId);
chatWidget.setAttribute('api-key', this.apiKey);
chatWidget.setAttribute('domain', this.domain);
chatWidget.setAttribute('language', this.language);
chatWidget.setAttribute('languages', this.languages);
});
} else {
console.error('Container not found');
}
}
}

View File

@@ -1,19 +0,0 @@
# flake8: noqa: F401
# noreorder
"""
Pytube: a very serious Python library for downloading YouTube Videos.
"""
__title__ = "pytube"
__author__ = "Ronnie Ghose, Taylor Fox Dahlin, Nick Ficano"
__license__ = "The Unlicense (Unlicense)"
__js__ = None
__js_url__ = None
from pytube.version import __version__
from pytube.streams import Stream
from pytube.captions import Caption
from pytube.query import CaptionQuery, StreamQuery
from pytube.__main__ import YouTube
from pytube.contrib.playlist import Playlist
from pytube.contrib.channel import Channel
from pytube.contrib.search import Search

View File

@@ -1,479 +0,0 @@
"""
This module implements the core developer interface for pytube.
The problem domain of the :class:`YouTube <YouTube> class focuses almost
exclusively on the developer interface. Pytube offloads the heavy lifting to
smaller peripheral modules and functions.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import pytube
import pytube.exceptions as exceptions
from pytube import extract, request
from pytube import Stream, StreamQuery
from pytube.helpers import install_proxy
from pytube.innertube import InnerTube
from pytube.metadata import YouTubeMetadata
from pytube.monostate import Monostate
logger = logging.getLogger(__name__)
class YouTube:
"""Core developer interface for pytube."""
def __init__(
self,
url: str,
on_progress_callback: Optional[Callable[[Any, bytes, int], None]] = None,
on_complete_callback: Optional[Callable[[Any, Optional[str]], None]] = None,
proxies: Dict[str, str] = None,
use_oauth: bool = False,
allow_oauth_cache: bool = True
):
"""Construct a :class:`YouTube <YouTube>`.
:param str url:
A valid YouTube watch URL.
:param func on_progress_callback:
(Optional) User defined callback function for stream download
progress events.
:param func on_complete_callback:
(Optional) User defined callback function for stream download
complete events.
:param dict proxies:
(Optional) A dict mapping protocol to proxy address which will be used by pytube.
:param bool use_oauth:
(Optional) Prompt the user to authenticate to YouTube.
If allow_oauth_cache is set to True, the user should only be prompted once.
:param bool allow_oauth_cache:
(Optional) Cache OAuth tokens locally on the machine. Defaults to True.
These tokens are only generated if use_oauth is set to True as well.
"""
self._js: Optional[str] = None # js fetched by js_url
self._js_url: Optional[str] = None # the url to the js, parsed from watch html
self._vid_info: Optional[Dict] = None # content fetched from innertube/player
self._watch_html: Optional[str] = None # the html of /watch?v=<video_id>
self._embed_html: Optional[str] = None
self._player_config_args: Optional[Dict] = None # inline js in the html containing
self._age_restricted: Optional[bool] = None
self._fmt_streams: Optional[List[Stream]] = None
self._initial_data = None
self._metadata: Optional[YouTubeMetadata] = None
# video_id part of /watch?v=<video_id>
self.video_id = extract.video_id(url)
self.watch_url = f"https://youtube.com/watch?v={self.video_id}"
self.embed_url = f"https://www.youtube.com/embed/{self.video_id}"
# Shared between all instances of `Stream` (Borg pattern).
self.stream_monostate = Monostate(
on_progress=on_progress_callback, on_complete=on_complete_callback
)
if proxies:
install_proxy(proxies)
self._author = None
self._title = None
self._publish_date = None
self.use_oauth = use_oauth
self.allow_oauth_cache = allow_oauth_cache
def __repr__(self):
return f'<pytube.__main__.YouTube object: videoId={self.video_id}>'
def __eq__(self, o: object) -> bool:
# Compare types and urls, if they're same return true, else return false.
return type(o) == type(self) and o.watch_url == self.watch_url
@property
def watch_html(self):
if self._watch_html:
return self._watch_html
self._watch_html = request.get(url=self.watch_url)
return self._watch_html
@property
def embed_html(self):
if self._embed_html:
return self._embed_html
self._embed_html = request.get(url=self.embed_url)
return self._embed_html
@property
def age_restricted(self):
if self._age_restricted:
return self._age_restricted
self._age_restricted = extract.is_age_restricted(self.watch_html)
return self._age_restricted
@property
def js_url(self):
if self._js_url:
return self._js_url
if self.age_restricted:
self._js_url = extract.js_url(self.embed_html)
else:
self._js_url = extract.js_url(self.watch_html)
return self._js_url
@property
def js(self):
if self._js:
return self._js
# If the js_url doesn't match the cached url, fetch the new js and update
# the cache; otherwise, load the cache.
if pytube.__js_url__ != self.js_url:
self._js = request.get(self.js_url)
pytube.__js__ = self._js
pytube.__js_url__ = self.js_url
else:
self._js = pytube.__js__
return self._js
@property
def initial_data(self):
if self._initial_data:
return self._initial_data
self._initial_data = extract.initial_data(self.watch_html)
return self._initial_data
@property
def streaming_data(self):
"""Return streamingData from video info."""
if 'streamingData' in self.vid_info:
return self.vid_info['streamingData']
else:
self.bypass_age_gate()
return self.vid_info['streamingData']
@property
def fmt_streams(self):
"""Returns a list of streams if they have been initialized.
If the streams have not been initialized, finds all relevant
streams and initializes them.
"""
self.check_availability()
if self._fmt_streams:
return self._fmt_streams
self._fmt_streams = []
stream_manifest = extract.apply_descrambler(self.streaming_data)
# If the cached js doesn't work, try fetching a new js file
# https://github.com/pytube/pytube/issues/1054
try:
extract.apply_signature(stream_manifest, self.vid_info, self.js)
except exceptions.ExtractError:
# To force an update to the js file, we clear the cache and retry
self._js = None
self._js_url = None
pytube.__js__ = None
pytube.__js_url__ = None
extract.apply_signature(stream_manifest, self.vid_info, self.js)
# build instances of :class:`Stream <Stream>`
# Initialize stream objects
for stream in stream_manifest:
video = Stream(
stream=stream,
monostate=self.stream_monostate,
)
self._fmt_streams.append(video)
self.stream_monostate.title = self.title
self.stream_monostate.duration = self.length
return self._fmt_streams
def check_availability(self):
"""Check whether the video is available.
Raises different exceptions based on why the video is unavailable,
otherwise does nothing.
"""
status, messages = extract.playability_status(self.watch_html)
for reason in messages:
if status == 'UNPLAYABLE':
if reason == (
'Join this channel to get access to members-only content '
'like this video, and other exclusive perks.'
):
raise exceptions.MembersOnly(video_id=self.video_id)
elif reason == 'This live stream recording is not available.':
raise exceptions.RecordingUnavailable(video_id=self.video_id)
else:
raise exceptions.VideoUnavailable(video_id=self.video_id)
elif status == 'LOGIN_REQUIRED':
if reason == (
'This is a private video. '
'Please sign in to verify that you may see it.'
):
raise exceptions.VideoPrivate(video_id=self.video_id)
elif status == 'ERROR':
if reason == 'Video unavailable':
raise exceptions.VideoUnavailable(video_id=self.video_id)
elif status == 'LIVE_STREAM':
raise exceptions.LiveStreamError(video_id=self.video_id)
@property
def vid_info(self):
"""Parse the raw vid info and return the parsed result.
:rtype: Dict[Any, Any]
"""
if self._vid_info:
return self._vid_info
innertube = InnerTube(use_oauth=self.use_oauth, allow_cache=self.allow_oauth_cache)
innertube_response = innertube.player(self.video_id)
self._vid_info = innertube_response
return self._vid_info
def bypass_age_gate(self):
"""Attempt to update the vid_info by bypassing the age gate."""
innertube = InnerTube(
client='ANDROID_EMBED',
use_oauth=self.use_oauth,
allow_cache=self.allow_oauth_cache
)
innertube_response = innertube.player(self.video_id)
playability_status = innertube_response['playabilityStatus'].get('status', None)
# If we still can't access the video, raise an exception
# (tier 3 age restriction)
if playability_status == 'UNPLAYABLE':
raise exceptions.AgeRestrictedError(self.video_id)
self._vid_info = innertube_response
@property
def caption_tracks(self) -> List[pytube.Caption]:
"""Get a list of :class:`Caption <Caption>`.
:rtype: List[Caption]
"""
raw_tracks = (
self.vid_info.get("captions", {})
.get("playerCaptionsTracklistRenderer", {})
.get("captionTracks", [])
)
return [pytube.Caption(track) for track in raw_tracks]
@property
def captions(self) -> pytube.CaptionQuery:
"""Interface to query caption tracks.
:rtype: :class:`CaptionQuery <CaptionQuery>`.
"""
return pytube.CaptionQuery(self.caption_tracks)
@property
def streams(self) -> StreamQuery:
"""Interface to query both adaptive (DASH) and progressive streams.
:rtype: :class:`StreamQuery <StreamQuery>`.
"""
self.check_availability()
return StreamQuery(self.fmt_streams)
@property
def thumbnail_url(self) -> str:
"""Get the thumbnail url image.
:rtype: str
"""
thumbnail_details = (
self.vid_info.get("videoDetails", {})
.get("thumbnail", {})
.get("thumbnails")
)
if thumbnail_details:
thumbnail_details = thumbnail_details[-1] # last item has max size
return thumbnail_details["url"]
return f"https://img.youtube.com/vi/{self.video_id}/maxresdefault.jpg"
@property
def publish_date(self):
"""Get the publish date.
:rtype: datetime
"""
if self._publish_date:
return self._publish_date
self._publish_date = extract.publish_date(self.watch_html)
return self._publish_date
@publish_date.setter
def publish_date(self, value):
"""Sets the publish date."""
self._publish_date = value
@property
def title(self) -> str:
"""Get the video title.
:rtype: str
"""
if self._title:
return self._title
try:
self._title = self.vid_info['videoDetails']['title']
except KeyError:
# Check_availability will raise the correct exception in most cases
# if it doesn't, ask for a report.
self.check_availability()
raise exceptions.PytubeError(
(
f'Exception while accessing title of {self.watch_url}. '
'Please file a bug report at https://github.com/pytube/pytube'
)
)
return self._title
@title.setter
def title(self, value):
"""Sets the title value."""
self._title = value
@property
def description(self) -> str:
"""Get the video description.
:rtype: str
"""
return self.vid_info.get("videoDetails", {}).get("shortDescription")
@property
def rating(self) -> float:
"""Get the video average rating.
:rtype: float
"""
return self.vid_info.get("videoDetails", {}).get("averageRating")
@property
def length(self) -> int:
"""Get the video length in seconds.
:rtype: int
"""
return int(self.vid_info.get('videoDetails', {}).get('lengthSeconds'))
@property
def views(self) -> int:
"""Get the number of the times the video has been viewed.
:rtype: int
"""
return int(self.vid_info.get("videoDetails", {}).get("viewCount"))
@property
def author(self) -> str:
"""Get the video author.
:rtype: str
"""
if self._author:
return self._author
self._author = self.vid_info.get("videoDetails", {}).get(
"author", "unknown"
)
return self._author
@author.setter
def author(self, value):
"""Set the video author."""
self._author = value
@property
def keywords(self) -> List[str]:
"""Get the video keywords.
:rtype: List[str]
"""
return self.vid_info.get('videoDetails', {}).get('keywords', [])
@property
def channel_id(self) -> str:
"""Get the video poster's channel id.
:rtype: str
"""
return self.vid_info.get('videoDetails', {}).get('channelId', None)
@property
def channel_url(self) -> str:
"""Construct the channel url for the video's poster from the channel id.
:rtype: str
"""
return f'https://www.youtube.com/channel/{self.channel_id}'
@property
def metadata(self) -> Optional[YouTubeMetadata]:
"""Get the metadata for the video.
:rtype: YouTubeMetadata
"""
if self._metadata:
return self._metadata
else:
self._metadata = extract.metadata(self.initial_data)
return self._metadata
def register_on_progress_callback(self, func: Callable[[Any, bytes, int], None]):
"""Register a download progress callback function post initialization.
:param callable func:
A callback function that takes ``stream``, ``chunk``,
and ``bytes_remaining`` as parameters.
:rtype: None
"""
self.stream_monostate.on_progress = func
def register_on_complete_callback(self, func: Callable[[Any, Optional[str]], None]):
"""Register a download complete callback function post initialization.
:param callable func:
A callback function that takes ``stream`` and ``file_path``.
:rtype: None
"""
self.stream_monostate.on_complete = func
@staticmethod
def from_id(video_id: str) -> "YouTube":
"""Construct a :class:`YouTube <YouTube>` object from a video id.
:param str video_id:
The video id of the YouTube video.
:rtype: :class:`YouTube <YouTube>`
"""
return YouTube(f"https://www.youtube.com/watch?v={video_id}")

View File

@@ -1,164 +0,0 @@
import math
import os
import time
import json
import xml.etree.ElementTree as ElementTree
from html import unescape
from typing import Dict, Optional
from pytube import request
from pytube.helpers import safe_filename, target_directory
class Caption:
"""Container for caption tracks."""
def __init__(self, caption_track: Dict):
"""Construct a :class:`Caption <Caption>`.
:param dict caption_track:
Caption track data extracted from ``watch_html``.
"""
self.url = caption_track.get("baseUrl")
# Certain videos have runs instead of simpleText
# this handles that edge case
name_dict = caption_track['name']
if 'simpleText' in name_dict:
self.name = name_dict['simpleText']
else:
for el in name_dict['runs']:
if 'text' in el:
self.name = el['text']
# Use "vssId" instead of "languageCode", fix issue #779
self.code = caption_track["vssId"]
# Remove preceding '.' for backwards compatibility, e.g.:
# English -> vssId: .en, languageCode: en
# English (auto-generated) -> vssId: a.en, languageCode: en
self.code = self.code.strip('.')
@property
def xml_captions(self) -> str:
"""Download the xml caption tracks."""
return request.get(self.url)
@property
def json_captions(self) -> dict:
"""Download and parse the json caption tracks."""
json_captions_url = self.url.replace('fmt=srv3','fmt=json3')
text = request.get(json_captions_url)
parsed = json.loads(text)
assert parsed['wireMagic'] == 'pb3', 'Unexpected captions format'
return parsed
def generate_srt_captions(self) -> str:
"""Generate "SubRip Subtitle" captions.
Takes the xml captions from :meth:`~pytube.Caption.xml_captions` and
recompiles them into the "SubRip Subtitle" format.
"""
return self.xml_caption_to_srt(self.xml_captions)
@staticmethod
def float_to_srt_time_format(d: float) -> str:
"""Convert decimal durations into proper srt format.
:rtype: str
:returns:
SubRip Subtitle (str) formatted time duration.
float_to_srt_time_format(3.89) -> '00:00:03,890'
"""
fraction, whole = math.modf(d)
time_fmt = time.strftime("%H:%M:%S,", time.gmtime(whole))
ms = f"{fraction:.3f}".replace("0.", "")
return time_fmt + ms
def xml_caption_to_srt(self, xml_captions: str) -> str:
"""Convert xml caption tracks to "SubRip Subtitle (srt)".
:param str xml_captions:
XML formatted caption tracks.
"""
segments = []
root = ElementTree.fromstring(xml_captions)
for i, child in enumerate(list(root)):
text = child.text or ""
caption = unescape(text.replace("\n", " ").replace(" ", " "),)
try:
duration = float(child.attrib["dur"])
except KeyError:
duration = 0.0
start = float(child.attrib["start"])
end = start + duration
sequence_number = i + 1 # convert from 0-indexed to 1.
line = "{seq}\n{start} --> {end}\n{text}\n".format(
seq=sequence_number,
start=self.float_to_srt_time_format(start),
end=self.float_to_srt_time_format(end),
text=caption,
)
segments.append(line)
return "\n".join(segments).strip()
def download(
self,
title: str,
srt: bool = True,
output_path: Optional[str] = None,
filename_prefix: Optional[str] = None,
) -> str:
"""Write the media stream to disk.
:param title:
Output filename (stem only) for writing media file.
If one is not specified, the default filename is used.
:type title: str
:param srt:
Set to True to download srt, false to download xml. Defaults to True.
:type srt bool
:param output_path:
(optional) Output path for writing media file. If one is not
specified, defaults to the current working directory.
:type output_path: str or None
:param filename_prefix:
(optional) A string that will be prepended to the filename.
For example a number in a playlist or the name of a series.
If one is not specified, nothing will be prepended
This is separate from filename so you can use the default
filename but still add a prefix.
:type filename_prefix: str or None
:rtype: str
"""
if title.endswith(".srt") or title.endswith(".xml"):
filename = ".".join(title.split(".")[:-1])
else:
filename = title
if filename_prefix:
filename = f"{safe_filename(filename_prefix)}{filename}"
filename = safe_filename(filename)
filename += f" ({self.code})"
if srt:
filename += ".srt"
else:
filename += ".xml"
file_path = os.path.join(target_directory(output_path), filename)
with open(file_path, "w", encoding="utf-8") as file_handle:
if srt:
file_handle.write(self.generate_srt_captions())
else:
file_handle.write(self.xml_captions)
return file_path
def __repr__(self):
"""Printable object representation."""
return '<Caption lang="{s.name}" code="{s.code}">'.format(s=self)

View File

@@ -1,698 +0,0 @@
"""
This module contains all logic necessary to decipher the signature.
YouTube's strategy to restrict downloading videos is to send a ciphered version
of the signature to the client, along with the decryption algorithm obfuscated
in JavaScript. For the clients to play the videos, JavaScript must take the
ciphered version, cycle it through a series of "transform functions," and then
signs the media URL with the output.
This module is responsible for (1) finding and extracting those "transform
functions" (2) maps them to Python equivalents and (3) taking the ciphered
signature and decoding it.
"""
import logging
import re
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple
from pytube.exceptions import ExtractError, RegexMatchError
from pytube.helpers import cache, regex_search
from pytube.parser import find_object_from_startpoint, throttling_array_split
logger = logging.getLogger(__name__)
class Cipher:
def __init__(self, js: str):
self.transform_plan: List[str] = get_transform_plan(js)
var_regex = re.compile(r"^\w+\W")
var_match = var_regex.search(self.transform_plan[0])
if not var_match:
raise RegexMatchError(
caller="__init__", pattern=var_regex.pattern
)
var = var_match.group(0)[:-1]
self.transform_map = get_transform_map(js, var)
self.js_func_patterns = [
r"\w+\.(\w+)\(\w,(\d+)\)",
r"\w+\[(\"\w+\")\]\(\w,(\d+)\)"
]
self.throttling_plan = get_throttling_plan(js)
self.throttling_array = get_throttling_function_array(js)
self.calculated_n = None
def calculate_n(self, initial_n: list):
"""Converts n to the correct value to prevent throttling."""
if self.calculated_n:
return self.calculated_n
# First, update all instances of 'b' with the list(initial_n)
for i in range(len(self.throttling_array)):
if self.throttling_array[i] == 'b':
self.throttling_array[i] = initial_n
for step in self.throttling_plan:
curr_func = self.throttling_array[int(step[0])]
if not callable(curr_func):
logger.debug(f'{curr_func} is not callable.')
logger.debug(f'Throttling array:\n{self.throttling_array}\n')
raise ExtractError(f'{curr_func} is not callable.')
first_arg = self.throttling_array[int(step[1])]
if len(step) == 2:
curr_func(first_arg)
elif len(step) == 3:
second_arg = self.throttling_array[int(step[2])]
curr_func(first_arg, second_arg)
self.calculated_n = ''.join(initial_n)
return self.calculated_n
def get_signature(self, ciphered_signature: str) -> str:
"""Decipher the signature.
Taking the ciphered signature, applies the transform functions.
:param str ciphered_signature:
The ciphered signature sent in the ``player_config``.
:rtype: str
:returns:
Decrypted signature required to download the media content.
"""
signature = list(ciphered_signature)
for js_func in self.transform_plan:
name, argument = self.parse_function(js_func) # type: ignore
signature = self.transform_map[name](signature, argument)
logger.debug(
"applied transform function\n"
"output: %s\n"
"js_function: %s\n"
"argument: %d\n"
"function: %s",
"".join(signature),
name,
argument,
self.transform_map[name],
)
return "".join(signature)
@cache
def parse_function(self, js_func: str) -> Tuple[str, int]:
"""Parse the Javascript transform function.
Break a JavaScript transform function down into a two element ``tuple``
containing the function name and some integer-based argument.
:param str js_func:
The JavaScript version of the transform function.
:rtype: tuple
:returns:
two element tuple containing the function name and an argument.
**Example**:
parse_function('DE.AJ(a,15)')
('AJ', 15)
"""
logger.debug("parsing transform function")
for pattern in self.js_func_patterns:
regex = re.compile(pattern)
parse_match = regex.search(js_func)
if parse_match:
fn_name, fn_arg = parse_match.groups()
return fn_name, int(fn_arg)
raise RegexMatchError(
caller="parse_function", pattern="js_func_patterns"
)
def get_initial_function_name(js: str) -> str:
"""Extract the name of the function responsible for computing the signature.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
Function name from regex match
"""
function_patterns = [
r"\b[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*encodeURIComponent\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[a-zA-Z0-9]+\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*encodeURIComponent\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r'(?:\b|[^a-zA-Z0-9$])(?P<sig>[a-zA-Z0-9$]{2})\s*=\s*function\(\s*a\s*\)\s*{\s*a\s*=\s*a\.split\(\s*""\s*\)', # noqa: E501
r'(?P<sig>[a-zA-Z0-9$]+)\s*=\s*function\(\s*a\s*\)\s*{\s*a\s*=\s*a\.split\(\s*""\s*\)', # noqa: E501
r'(["\'])signature\1\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(',
r"\.sig\|\|(?P<sig>[a-zA-Z0-9$]+)\(",
r"yt\.akamaized\.net/\)\s*\|\|\s*.*?\s*[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*(?:encodeURIComponent\s*\()?\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[a-zA-Z0-9]+\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*a\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
]
logger.debug("finding initial function name")
for pattern in function_patterns:
regex = re.compile(pattern)
function_match = regex.search(js)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
return function_match.group(1)
raise RegexMatchError(
caller="get_initial_function_name", pattern="multiple"
)
def get_transform_plan(js: str) -> List[str]:
"""Extract the "transform plan".
The "transform plan" is the functions that the ciphered signature is
cycled through to obtain the actual signature.
:param str js:
The contents of the base.js asset file.
**Example**:
['DE.AJ(a,15)',
'DE.VR(a,3)',
'DE.AJ(a,51)',
'DE.VR(a,3)',
'DE.kT(a,51)',
'DE.kT(a,8)',
'DE.VR(a,3)',
'DE.kT(a,21)']
"""
name = re.escape(get_initial_function_name(js))
pattern = r"%s=function\(\w\){[a-z=\.\(\"\)]*;(.*);(?:.+)}" % name
logger.debug("getting transform plan")
return regex_search(pattern, js, group=1).split(";")
def get_transform_object(js: str, var: str) -> List[str]:
"""Extract the "transform object".
The "transform object" contains the function definitions referenced in the
"transform plan". The ``var`` argument is the obfuscated variable name
which contains these functions, for example, given the function call
``DE.AJ(a,15)`` returned by the transform plan, "DE" would be the var.
:param str js:
The contents of the base.js asset file.
:param str var:
The obfuscated variable name that stores an object with all functions
that descrambles the signature.
**Example**:
>>> get_transform_object(js, 'DE')
['AJ:function(a){a.reverse()}',
'VR:function(a,b){a.splice(0,b)}',
'kT:function(a,b){var c=a[0];a[0]=a[b%a.length];a[b]=c}']
"""
pattern = r"var %s={(.*?)};" % re.escape(var)
logger.debug("getting transform object")
regex = re.compile(pattern, flags=re.DOTALL)
transform_match = regex.search(js)
if not transform_match:
raise RegexMatchError(caller="get_transform_object", pattern=pattern)
return transform_match.group(1).replace("\n", " ").split(", ")
def get_transform_map(js: str, var: str) -> Dict:
"""Build a transform function lookup.
Build a lookup table of obfuscated JavaScript function names to the
Python equivalents.
:param str js:
The contents of the base.js asset file.
:param str var:
The obfuscated variable name that stores an object with all functions
that descrambles the signature.
"""
transform_object = get_transform_object(js, var)
mapper = {}
for obj in transform_object:
# AJ:function(a){a.reverse()} => AJ, function(a){a.reverse()}
name, function = obj.split(":", 1)
fn = map_functions(function)
mapper[name] = fn
return mapper
def get_throttling_function_name(js: str) -> str:
"""Extract the name of the function that computes the throttling parameter.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
The name of the function used to compute the throttling parameter.
"""
function_patterns = [
# https://github.com/ytdl-org/youtube-dl/issues/29326#issuecomment-865985377
# https://github.com/yt-dlp/yt-dlp/commit/48416bc4a8f1d5ff07d5977659cb8ece7640dcd8
# var Bpa = [iha];
# ...
# a.C && (b = a.get("n")) && (b = Bpa[0](b), a.set("n", b),
# Bpa.length || iha("")) }};
# In the above case, `iha` is the relevant function name
r'a\.[a-zA-Z]\s*&&\s*\([a-z]\s*=\s*a\.get\("n"\)\)\s*&&\s*'
r'\([a-z]\s*=\s*([a-zA-Z0-9$]+)(\[\d+\])?\([a-z]\)',
r'\([a-z]\s*=\s*([a-zA-Z0-9$]+)(\[\d+\])\([a-z]\)',
]
logger.debug('Finding throttling function name')
for pattern in function_patterns:
regex = re.compile(pattern)
function_match = regex.search(js)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
if len(function_match.groups()) == 1:
return function_match.group(1)
idx = function_match.group(2)
if idx:
idx = idx.strip("[]")
array = re.search(
r'var {nfunc}\s*=\s*(\[.+?\]);'.format(
nfunc=re.escape(function_match.group(1))),
js
)
if array:
array = array.group(1).strip("[]").split(",")
array = [x.strip() for x in array]
return array[int(idx)]
raise RegexMatchError(
caller="get_throttling_function_name", pattern="multiple"
)
def get_throttling_function_code(js: str) -> str:
"""Extract the raw code for the throttling function.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
The name of the function used to compute the throttling parameter.
"""
# Begin by extracting the correct function name
name = re.escape(get_throttling_function_name(js))
# Identify where the function is defined
pattern_start = r"%s=function\(\w\)" % name
regex = re.compile(pattern_start)
match = regex.search(js)
# Extract the code within curly braces for the function itself, and merge any split lines
code_lines_list = find_object_from_startpoint(js, match.span()[1]).split('\n')
joined_lines = "".join(code_lines_list)
# Prepend function definition (e.g. `Dea=function(a)`)
return match.group(0) + joined_lines
def get_throttling_function_array(js: str) -> List[Any]:
"""Extract the "c" array.
:param str js:
The contents of the base.js asset file.
:returns:
The array of various integers, arrays, and functions.
"""
raw_code = get_throttling_function_code(js)
array_start = r",c=\["
array_regex = re.compile(array_start)
match = array_regex.search(raw_code)
array_raw = find_object_from_startpoint(raw_code, match.span()[1] - 1)
str_array = throttling_array_split(array_raw)
converted_array = []
for el in str_array:
try:
converted_array.append(int(el))
continue
except ValueError:
# Not an integer value.
pass
if el == 'null':
converted_array.append(None)
continue
if el.startswith('"') and el.endswith('"'):
# Convert e.g. '"abcdef"' to string without quotation marks, 'abcdef'
converted_array.append(el[1:-1])
continue
if el.startswith('function'):
mapper = (
(r"{for\(\w=\(\w%\w\.length\+\w\.length\)%\w\.length;\w--;\)\w\.unshift\(\w.pop\(\)\)}", throttling_unshift), # noqa:E501
(r"{\w\.reverse\(\)}", throttling_reverse),
(r"{\w\.push\(\w\)}", throttling_push),
(r";var\s\w=\w\[0\];\w\[0\]=\w\[\w\];\w\[\w\]=\w}", throttling_swap),
(r"case\s\d+", throttling_cipher_function),
(r"\w\.splice\(0,1,\w\.splice\(\w,1,\w\[0\]\)\[0\]\)", throttling_nested_splice), # noqa:E501
(r";\w\.splice\(\w,1\)}", js_splice),
(r"\w\.splice\(-\w\)\.reverse\(\)\.forEach\(function\(\w\){\w\.unshift\(\w\)}\)", throttling_prepend), # noqa:E501
(r"for\(var \w=\w\.length;\w;\)\w\.push\(\w\.splice\(--\w,1\)\[0\]\)}", throttling_reverse), # noqa:E501
)
found = False
for pattern, fn in mapper:
if re.search(pattern, el):
converted_array.append(fn)
found = True
if found:
continue
converted_array.append(el)
# Replace null elements with array itself
for i in range(len(converted_array)):
if converted_array[i] is None:
converted_array[i] = converted_array
return converted_array
def get_throttling_plan(js: str):
"""Extract the "throttling plan".
The "throttling plan" is a list of tuples used for calling functions
in the c array. The first element of the tuple is the index of the
function to call, and any remaining elements of the tuple are arguments
to pass to that function.
:param str js:
The contents of the base.js asset file.
:returns:
The full function code for computing the throttlign parameter.
"""
raw_code = get_throttling_function_code(js)
transform_start = r"try{"
plan_regex = re.compile(transform_start)
match = plan_regex.search(raw_code)
transform_plan_raw = find_object_from_startpoint(raw_code, match.span()[1] - 1)
# Steps are either c[x](c[y]) or c[x](c[y],c[z])
step_start = r"c\[(\d+)\]\(c\[(\d+)\](,c(\[(\d+)\]))?\)"
step_regex = re.compile(step_start)
matches = step_regex.findall(transform_plan_raw)
transform_steps = []
for match in matches:
if match[4] != '':
transform_steps.append((match[0],match[1],match[4]))
else:
transform_steps.append((match[0],match[1]))
return transform_steps
def reverse(arr: List, _: Optional[Any]):
"""Reverse elements in a list.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { a.reverse() }
This method takes an unused ``b`` variable as their transform functions
universally sent two arguments.
**Example**:
>>> reverse([1, 2, 3, 4])
[4, 3, 2, 1]
"""
return arr[::-1]
def splice(arr: List, b: int):
"""Add/remove items to/from a list.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { a.splice(0, b) }
**Example**:
>>> splice([1, 2, 3, 4], 2)
[1, 2]
"""
return arr[b:]
def swap(arr: List, b: int):
"""Swap positions at b modulus the list length.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { var c=a[0];a[0]=a[b%a.length];a[b]=c }
**Example**:
>>> swap([1, 2, 3, 4], 2)
[3, 2, 1, 4]
"""
r = b % len(arr)
return list(chain([arr[r]], arr[1:r], [arr[0]], arr[r + 1 :]))
def throttling_reverse(arr: list):
"""Reverses the input list.
Needs to do an in-place reversal so that the passed list gets changed.
To accomplish this, we create a reversed copy, and then change each
indvidual element.
"""
reverse_copy = arr.copy()[::-1]
for i in range(len(reverse_copy)):
arr[i] = reverse_copy[i]
def throttling_push(d: list, e: Any):
"""Pushes an element onto a list."""
d.append(e)
def throttling_mod_func(d: list, e: int):
"""Perform the modular function from the throttling array functions.
In the javascript, the modular operation is as follows:
e = (e % d.length + d.length) % d.length
We simply translate this to python here.
"""
return (e % len(d) + len(d)) % len(d)
def throttling_unshift(d: list, e: int):
"""Rotates the elements of the list to the right.
In the javascript, the operation is as follows:
for(e=(e%d.length+d.length)%d.length;e--;)d.unshift(d.pop())
"""
e = throttling_mod_func(d, e)
new_arr = d[-e:] + d[:-e]
d.clear()
for el in new_arr:
d.append(el)
def throttling_cipher_function(d: list, e: str):
"""This ciphers d with e to generate a new list.
In the javascript, the operation is as follows:
var h = [A-Za-z0-9-_], f = 96; // simplified from switch-case loop
d.forEach(
function(l,m,n){
this.push(
n[m]=h[
(h.indexOf(l)-h.indexOf(this[m])+m-32+f--)%h.length
]
)
},
e.split("")
)
"""
h = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_')
f = 96
# by naming it "this" we can more closely reflect the js
this = list(e)
# This is so we don't run into weirdness with enumerate while
# we change the input list
copied_list = d.copy()
for m, l in enumerate(copied_list):
bracket_val = (h.index(l) - h.index(this[m]) + m - 32 + f) % len(h)
this.append(
h[bracket_val]
)
d[m] = h[bracket_val]
f -= 1
def throttling_nested_splice(d: list, e: int):
"""Nested splice function in throttling js.
In the javascript, the operation is as follows:
function(d,e){
e=(e%d.length+d.length)%d.length;
d.splice(
0,
1,
d.splice(
e,
1,
d[0]
)[0]
)
}
While testing, all this seemed to do is swap element 0 and e,
but the actual process is preserved in case there was an edge
case that was not considered.
"""
e = throttling_mod_func(d, e)
inner_splice = js_splice(
d,
e,
1,
d[0]
)
js_splice(
d,
0,
1,
inner_splice[0]
)
def throttling_prepend(d: list, e: int):
"""
In the javascript, the operation is as follows:
function(d,e){
e=(e%d.length+d.length)%d.length;
d.splice(-e).reverse().forEach(
function(f){
d.unshift(f)
}
)
}
Effectively, this moves the last e elements of d to the beginning.
"""
start_len = len(d)
# First, calculate e
e = throttling_mod_func(d, e)
# Then do the prepending
new_arr = d[-e:] + d[:-e]
# And update the input list
d.clear()
for el in new_arr:
d.append(el)
end_len = len(d)
assert start_len == end_len
def throttling_swap(d: list, e: int):
"""Swap positions of the 0'th and e'th elements in-place."""
e = throttling_mod_func(d, e)
f = d[0]
d[0] = d[e]
d[e] = f
def js_splice(arr: list, start: int, delete_count=None, *items):
"""Implementation of javascript's splice function.
:param list arr:
Array to splice
:param int start:
Index at which to start changing the array
:param int delete_count:
Number of elements to delete from the array
:param *items:
Items to add to the array
Reference: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/splice # noqa:E501
"""
# Special conditions for start value
try:
if start > len(arr):
start = len(arr)
# If start is negative, count backwards from end
if start < 0:
start = len(arr) - start
except TypeError:
# Non-integer start values are treated as 0 in js
start = 0
# Special condition when delete_count is greater than remaining elements
if not delete_count or delete_count >= len(arr) - start:
delete_count = len(arr) - start # noqa: N806
deleted_elements = arr[start:start + delete_count]
# Splice appropriately.
new_arr = arr[:start] + list(items) + arr[start + delete_count:]
# Replace contents of input array
arr.clear()
for el in new_arr:
arr.append(el)
return deleted_elements
def map_functions(js_func: str) -> Callable:
"""For a given JavaScript transform function, return the Python equivalent.
:param str js_func:
The JavaScript version of the transform function.
"""
mapper = (
# function(a){a.reverse()}
(r"{\w\.reverse\(\)}", reverse),
# function(a,b){a.splice(0,b)}
(r"{\w\.splice\(0,\w\)}", splice),
# function(a,b){var c=a[0];a[0]=a[b%a.length];a[b]=c}
(r"{var\s\w=\w\[0\];\w\[0\]=\w\[\w\%\w.length\];\w\[\w\]=\w}", swap),
# function(a,b){var c=a[0];a[0]=a[b%a.length];a[b%a.length]=c}
(
r"{var\s\w=\w\[0\];\w\[0\]=\w\[\w\%\w.length\];\w\[\w\%\w.length\]=\w}",
swap,
),
)
for pattern, fn in mapper:
if re.search(pattern, js_func):
return fn
raise RegexMatchError(caller="map_functions", pattern="multiple")

View File

@@ -1,560 +0,0 @@
#!/usr/bin/env python3
"""A simple command line application to download youtube videos."""
import argparse
import gzip
import json
import logging
import os
import shutil
import sys
import datetime as dt
import subprocess # nosec
from typing import List, Optional
import pytube.exceptions as exceptions
from pytube import __version__
from pytube import CaptionQuery, Playlist, Stream, YouTube
from pytube.helpers import safe_filename, setup_logger
logger = logging.getLogger(__name__)
def main():
"""Command line application to download youtube videos."""
# noinspection PyTypeChecker
parser = argparse.ArgumentParser(description=main.__doc__)
args = _parse_args(parser)
if args.verbose:
log_filename = None
if args.logfile:
log_filename = args.logfile
setup_logger(logging.DEBUG, log_filename=log_filename)
logger.debug(f'Pytube version: {__version__}')
if not args.url or "youtu" not in args.url:
parser.print_help()
sys.exit(1)
if "/playlist" in args.url:
print("Loading playlist...")
playlist = Playlist(args.url)
if not args.target:
args.target = safe_filename(playlist.title)
for youtube_video in playlist.videos:
try:
_perform_args_on_youtube(youtube_video, args)
except exceptions.PytubeError as e:
print(f"There was an error with video: {youtube_video}")
print(e)
else:
print("Loading video...")
youtube = YouTube(args.url)
_perform_args_on_youtube(youtube, args)
def _perform_args_on_youtube(
youtube: YouTube, args: argparse.Namespace
) -> None:
if len(sys.argv) == 2 : # no arguments parsed
download_highest_resolution_progressive(
youtube=youtube, resolution="highest", target=args.target
)
if args.list_captions:
_print_available_captions(youtube.captions)
if args.list:
display_streams(youtube)
if args.build_playback_report:
build_playback_report(youtube)
if args.itag:
download_by_itag(youtube=youtube, itag=args.itag, target=args.target)
if args.caption_code:
download_caption(
youtube=youtube, lang_code=args.caption_code, target=args.target
)
if args.resolution:
download_by_resolution(
youtube=youtube, resolution=args.resolution, target=args.target
)
if args.audio:
download_audio(
youtube=youtube, filetype=args.audio, target=args.target
)
if args.ffmpeg:
ffmpeg_process(
youtube=youtube, resolution=args.ffmpeg, target=args.target
)
def _parse_args(
parser: argparse.ArgumentParser, args: Optional[List] = None
) -> argparse.Namespace:
parser.add_argument(
"url", help="The YouTube /watch or /playlist url", nargs="?"
)
parser.add_argument(
"--version", action="version", version="%(prog)s " + __version__,
)
parser.add_argument(
"--itag", type=int, help="The itag for the desired stream",
)
parser.add_argument(
"-r",
"--resolution",
type=str,
help="The resolution for the desired stream",
)
parser.add_argument(
"-l",
"--list",
action="store_true",
help=(
"The list option causes pytube cli to return a list of streams "
"available to download"
),
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
dest="verbose",
help="Set logger output to verbose output.",
)
parser.add_argument(
"--logfile",
action="store",
help="logging debug and error messages into a log file",
)
parser.add_argument(
"--build-playback-report",
action="store_true",
help="Save the html and js to disk",
)
parser.add_argument(
"-c",
"--caption-code",
type=str,
help=(
"Download srt captions for given language code. "
"Prints available language codes if no argument given"
),
)
parser.add_argument(
'-lc',
'--list-captions',
action='store_true',
help=(
"List available caption codes for a video"
)
)
parser.add_argument(
"-t",
"--target",
help=(
"The output directory for the downloaded stream. "
"Default is current working directory"
),
)
parser.add_argument(
"-a",
"--audio",
const="mp4",
nargs="?",
help=(
"Download the audio for a given URL at the highest bitrate available. "
"Defaults to mp4 format if none is specified"
),
)
parser.add_argument(
"-f",
"--ffmpeg",
const="best",
nargs="?",
help=(
"Downloads the audio and video stream for resolution provided. "
"If no resolution is provided, downloads the best resolution. "
"Runs the command line program ffmpeg to combine the audio and video"
),
)
return parser.parse_args(args)
def build_playback_report(youtube: YouTube) -> None:
"""Serialize the request data to json for offline debugging.
:param YouTube youtube:
A YouTube object.
"""
ts = int(dt.datetime.utcnow().timestamp())
fp = os.path.join(os.getcwd(), f"yt-video-{youtube.video_id}-{ts}.json.gz")
js = youtube.js
watch_html = youtube.watch_html
vid_info = youtube.vid_info
with gzip.open(fp, "wb") as fh:
fh.write(
json.dumps(
{
"url": youtube.watch_url,
"js": js,
"watch_html": watch_html,
"video_info": vid_info,
}
).encode("utf8"),
)
def display_progress_bar(
bytes_received: int, filesize: int, ch: str = "", scale: float = 0.55
) -> None:
"""Display a simple, pretty progress bar.
Example:
~~~~~~~~
PSY - GANGNAM STYLE(강남스타일) MV.mp4
↳ |███████████████████████████████████████| 100.0%
:param int bytes_received:
The delta between the total file size (bytes) and bytes already
written to disk.
:param int filesize:
File size of the media stream in bytes.
:param str ch:
Character to use for presenting progress segment.
:param float scale:
Scale multiplier to reduce progress bar size.
"""
columns = shutil.get_terminal_size().columns
max_width = int(columns * scale)
filled = int(round(max_width * bytes_received / float(filesize)))
remaining = max_width - filled
progress_bar = ch * filled + " " * remaining
percent = round(100.0 * bytes_received / float(filesize), 1)
text = f" ↳ |{progress_bar}| {percent}%\r"
sys.stdout.write(text)
sys.stdout.flush()
# noinspection PyUnusedLocal
def on_progress(
stream: Stream, chunk: bytes, bytes_remaining: int
) -> None: # pylint: disable=W0613
filesize = stream.filesize
bytes_received = filesize - bytes_remaining
display_progress_bar(bytes_received, filesize)
def _download(
stream: Stream,
target: Optional[str] = None,
filename: Optional[str] = None,
) -> None:
filesize_megabytes = stream.filesize // 1048576
print(f"{filename or stream.default_filename} | {filesize_megabytes} MB")
file_path = stream.get_file_path(filename=filename, output_path=target)
if stream.exists_at_path(file_path):
print(f"Already downloaded at:\n{file_path}")
return
stream.download(output_path=target, filename=filename)
sys.stdout.write("\n")
def _unique_name(base: str, subtype: str, media_type: str, target: str) -> str:
"""
Given a base name, the file format, and the target directory, will generate
a filename unique for that directory and file format.
:param str base:
The given base-name.
:param str subtype:
The filetype of the video which will be downloaded.
:param str media_type:
The media_type of the file, ie. "audio" or "video"
:param Path target:
Target directory for download.
"""
counter = 0
while True:
file_name = f"{base}_{media_type}_{counter}"
file_path = os.path.join(target, f"{file_name}.{subtype}")
if not os.path.exists(file_path):
return file_name
counter += 1
def ffmpeg_process(
youtube: YouTube, resolution: str, target: Optional[str] = None
) -> None:
"""
Decides the correct video stream to download, then calls _ffmpeg_downloader.
:param YouTube youtube:
A valid YouTube object.
:param str resolution:
YouTube video resolution.
:param str target:
Target directory for download
"""
youtube.register_on_progress_callback(on_progress)
target = target or os.getcwd()
if resolution == "best":
highest_quality_stream = (
youtube.streams.filter(progressive=False)
.order_by("resolution")
.last()
)
mp4_stream = (
youtube.streams.filter(progressive=False, subtype="mp4")
.order_by("resolution")
.last()
)
if highest_quality_stream.resolution == mp4_stream.resolution:
video_stream = mp4_stream
else:
video_stream = highest_quality_stream
else:
video_stream = youtube.streams.filter(
progressive=False, resolution=resolution, subtype="mp4"
).first()
if not video_stream:
video_stream = youtube.streams.filter(
progressive=False, resolution=resolution
).first()
if video_stream is None:
print(f"Could not find a stream with resolution: {resolution}")
print("Try one of these:")
display_streams(youtube)
sys.exit()
audio_stream = youtube.streams.get_audio_only(video_stream.subtype)
if not audio_stream:
audio_stream = (
youtube.streams.filter(only_audio=True).order_by("abr").last()
)
if not audio_stream:
print("Could not find an audio only stream")
sys.exit()
_ffmpeg_downloader(
audio_stream=audio_stream, video_stream=video_stream, target=target
)
def _ffmpeg_downloader(
audio_stream: Stream, video_stream: Stream, target: str
) -> None:
"""
Given a YouTube Stream object, finds the correct audio stream, downloads them both
giving them a unique name, them uses ffmpeg to create a new file with the audio
and video from the previously downloaded files. Then deletes the original adaptive
streams, leaving the combination.
:param Stream audio_stream:
A valid Stream object representing the audio to download
:param Stream video_stream:
A valid Stream object representing the video to download
:param Path target:
A valid Path object
"""
video_unique_name = _unique_name(
safe_filename(video_stream.title),
video_stream.subtype,
"video",
target=target,
)
audio_unique_name = _unique_name(
safe_filename(video_stream.title),
audio_stream.subtype,
"audio",
target=target,
)
_download(stream=video_stream, target=target, filename=video_unique_name)
print("Loading audio...")
_download(stream=audio_stream, target=target, filename=audio_unique_name)
video_path = os.path.join(
target, f"{video_unique_name}.{video_stream.subtype}"
)
audio_path = os.path.join(
target, f"{audio_unique_name}.{audio_stream.subtype}"
)
final_path = os.path.join(
target, f"{safe_filename(video_stream.title)}.{video_stream.subtype}"
)
subprocess.run( # nosec
[
"ffmpeg",
"-i",
video_path,
"-i",
audio_path,
"-codec",
"copy",
final_path,
]
)
os.unlink(video_path)
os.unlink(audio_path)
def download_by_itag(
youtube: YouTube, itag: int, target: Optional[str] = None
) -> None:
"""Start downloading a YouTube video.
:param YouTube youtube:
A valid YouTube object.
:param int itag:
YouTube format identifier code.
:param str target:
Target directory for download
"""
stream = youtube.streams.get_by_itag(itag)
if stream is None:
print(f"Could not find a stream with itag: {itag}")
print("Try one of these:")
display_streams(youtube)
sys.exit()
youtube.register_on_progress_callback(on_progress)
try:
_download(stream, target=target)
except KeyboardInterrupt:
sys.exit()
def download_by_resolution(
youtube: YouTube, resolution: str, target: Optional[str] = None
) -> None:
"""Start downloading a YouTube video.
:param YouTube youtube:
A valid YouTube object.
:param str resolution:
YouTube video resolution.
:param str target:
Target directory for download
"""
# TODO(nficano): allow dash itags to be selected
stream = youtube.streams.get_by_resolution(resolution)
if stream is None:
print(f"Could not find a stream with resolution: {resolution}")
print("Try one of these:")
display_streams(youtube)
sys.exit()
youtube.register_on_progress_callback(on_progress)
try:
_download(stream, target=target)
except KeyboardInterrupt:
sys.exit()
def download_highest_resolution_progressive(
youtube: YouTube, resolution: str, target: Optional[str] = None
) -> None:
"""Start downloading the highest resolution progressive stream.
:param YouTube youtube:
A valid YouTube object.
:param str resolution:
YouTube video resolution.
:param str target:
Target directory for download
"""
youtube.register_on_progress_callback(on_progress)
try:
stream = youtube.streams.get_highest_resolution()
except exceptions.VideoUnavailable as err:
print(f"No video streams available: {err}")
else:
try:
_download(stream, target=target)
except KeyboardInterrupt:
sys.exit()
def display_streams(youtube: YouTube) -> None:
"""Probe YouTube video and lists its available formats.
:param YouTube youtube:
A valid YouTube watch URL.
"""
for stream in youtube.streams:
print(stream)
def _print_available_captions(captions: CaptionQuery) -> None:
print(
f"Available caption codes are: {', '.join(c.code for c in captions)}"
)
def download_caption(
youtube: YouTube, lang_code: Optional[str], target: Optional[str] = None
) -> None:
"""Download a caption for the YouTube video.
:param YouTube youtube:
A valid YouTube object.
:param str lang_code:
Language code desired for caption file.
Prints available codes if the value is None
or the desired code is not available.
:param str target:
Target directory for download
"""
try:
caption = youtube.captions[lang_code]
downloaded_path = caption.download(
title=youtube.title, output_path=target
)
print(f"Saved caption file to: {downloaded_path}")
except KeyError:
print(f"Unable to find caption with code: {lang_code}")
_print_available_captions(youtube.captions)
def download_audio(
youtube: YouTube, filetype: str, target: Optional[str] = None
) -> None:
"""
Given a filetype, downloads the highest quality available audio stream for a
YouTube video.
:param YouTube youtube:
A valid YouTube object.
:param str filetype:
Desired file format to download.
:param str target:
Target directory for download
"""
audio = (
youtube.streams.filter(only_audio=True, subtype=filetype)
.order_by("abr")
.last()
)
if audio is None:
print("No audio only stream found. Try one of these:")
display_streams(youtube)
sys.exit()
youtube.register_on_progress_callback(on_progress)
try:
_download(audio, target=target)
except KeyboardInterrupt:
sys.exit()
if __name__ == "__main__":
main()

View File

@@ -1,201 +0,0 @@
# -*- coding: utf-8 -*-
"""Module for interacting with a user's youtube channel."""
import json
import logging
from typing import Dict, List, Optional, Tuple
from pytube import extract, Playlist, request
from pytube.helpers import uniqueify
logger = logging.getLogger(__name__)
class Channel(Playlist):
def __init__(self, url: str, proxies: Optional[Dict[str, str]] = None):
"""Construct a :class:`Channel <Channel>`.
:param str url:
A valid YouTube channel URL.
:param proxies:
(Optional) A dictionary of proxies to use for web requests.
"""
super().__init__(url, proxies)
self.channel_uri = extract.channel_name(url)
self.channel_url = (
f"https://www.youtube.com{self.channel_uri}"
)
self.videos_url = self.channel_url + '/videos'
self.playlists_url = self.channel_url + '/playlists'
self.community_url = self.channel_url + '/community'
self.featured_channels_url = self.channel_url + '/channels'
self.about_url = self.channel_url + '/about'
# Possible future additions
self._playlists_html = None
self._community_html = None
self._featured_channels_html = None
self._about_html = None
@property
def channel_name(self):
"""Get the name of the YouTube channel.
:rtype: str
"""
return self.initial_data['metadata']['channelMetadataRenderer']['title']
@property
def channel_id(self):
"""Get the ID of the YouTube channel.
This will return the underlying ID, not the vanity URL.
:rtype: str
"""
return self.initial_data['metadata']['channelMetadataRenderer']['externalId']
@property
def vanity_url(self):
"""Get the vanity URL of the YouTube channel.
Returns None if it doesn't exist.
:rtype: str
"""
return self.initial_data['metadata']['channelMetadataRenderer'].get('vanityChannelUrl', None) # noqa:E501
@property
def html(self):
"""Get the html for the /videos page.
:rtype: str
"""
if self._html:
return self._html
self._html = request.get(self.videos_url)
return self._html
@property
def playlists_html(self):
"""Get the html for the /playlists page.
Currently unused for any functionality.
:rtype: str
"""
if self._playlists_html:
return self._playlists_html
else:
self._playlists_html = request.get(self.playlists_url)
return self._playlists_html
@property
def community_html(self):
"""Get the html for the /community page.
Currently unused for any functionality.
:rtype: str
"""
if self._community_html:
return self._community_html
else:
self._community_html = request.get(self.community_url)
return self._community_html
@property
def featured_channels_html(self):
"""Get the html for the /channels page.
Currently unused for any functionality.
:rtype: str
"""
if self._featured_channels_html:
return self._featured_channels_html
else:
self._featured_channels_html = request.get(self.featured_channels_url)
return self._featured_channels_html
@property
def about_html(self):
"""Get the html for the /about page.
Currently unused for any functionality.
:rtype: str
"""
if self._about_html:
return self._about_html
else:
self._about_html = request.get(self.about_url)
return self._about_html
@staticmethod
def _extract_videos(raw_json: str) -> Tuple[List[str], Optional[str]]:
"""Extracts videos from a raw json page
:param str raw_json: Input json extracted from the page or the last
server response
:rtype: Tuple[List[str], Optional[str]]
:returns: Tuple containing a list of up to 100 video watch ids and
a continuation token, if more videos are available
"""
initial_data = json.loads(raw_json)
# this is the json tree structure, if the json was extracted from
# html
try:
videos = initial_data["contents"][
"twoColumnBrowseResultsRenderer"][
"tabs"][1]["tabRenderer"]["content"][
"sectionListRenderer"]["contents"][0][
"itemSectionRenderer"]["contents"][0][
"gridRenderer"]["items"]
except (KeyError, IndexError, TypeError):
try:
# this is the json tree structure, if the json was directly sent
# by the server in a continuation response
important_content = initial_data[1]['response']['onResponseReceivedActions'][
0
]['appendContinuationItemsAction']['continuationItems']
videos = important_content
except (KeyError, IndexError, TypeError):
try:
# this is the json tree structure, if the json was directly sent
# by the server in a continuation response
# no longer a list and no longer has the "response" key
important_content = initial_data['onResponseReceivedActions'][0][
'appendContinuationItemsAction']['continuationItems']
videos = important_content
except (KeyError, IndexError, TypeError) as p:
logger.info(p)
return [], None
try:
continuation = videos[-1]['continuationItemRenderer'][
'continuationEndpoint'
]['continuationCommand']['token']
videos = videos[:-1]
except (KeyError, IndexError):
# if there is an error, no continuation is available
continuation = None
# remove duplicates
return (
uniqueify(
list(
# only extract the video ids from the video data
map(
lambda x: (
f"/watch?v="
f"{x['gridVideoRenderer']['videoId']}"
),
videos
)
),
),
continuation,
)

View File

@@ -1,419 +0,0 @@
"""Module to download a complete playlist from a youtube channel."""
import json
import logging
from collections.abc import Sequence
from datetime import date, datetime
from typing import Dict, Iterable, List, Optional, Tuple, Union
from pytube import extract, request, YouTube
from pytube.helpers import cache, DeferredGeneratorList, install_proxy, uniqueify
logger = logging.getLogger(__name__)
class Playlist(Sequence):
"""Load a YouTube playlist with URL"""
def __init__(self, url: str, proxies: Optional[Dict[str, str]] = None):
if proxies:
install_proxy(proxies)
self._input_url = url
# These need to be initialized as None for the properties.
self._html = None
self._ytcfg = None
self._initial_data = None
self._sidebar_info = None
self._playlist_id = None
@property
def playlist_id(self):
"""Get the playlist id.
:rtype: str
"""
if self._playlist_id:
return self._playlist_id
self._playlist_id = extract.playlist_id(self._input_url)
return self._playlist_id
@property
def playlist_url(self):
"""Get the base playlist url.
:rtype: str
"""
return f"https://www.youtube.com/playlist?list={self.playlist_id}"
@property
def html(self):
"""Get the playlist page html.
:rtype: str
"""
if self._html:
return self._html
self._html = request.get(self.playlist_url)
return self._html
@property
def ytcfg(self):
"""Extract the ytcfg from the playlist page html.
:rtype: dict
"""
if self._ytcfg:
return self._ytcfg
self._ytcfg = extract.get_ytcfg(self.html)
return self._ytcfg
@property
def initial_data(self):
"""Extract the initial data from the playlist page html.
:rtype: dict
"""
if self._initial_data:
return self._initial_data
else:
self._initial_data = extract.initial_data(self.html)
return self._initial_data
@property
def sidebar_info(self):
"""Extract the sidebar info from the playlist page html.
:rtype: dict
"""
if self._sidebar_info:
return self._sidebar_info
else:
self._sidebar_info = self.initial_data['sidebar'][
'playlistSidebarRenderer']['items']
return self._sidebar_info
@property
def yt_api_key(self):
"""Extract the INNERTUBE_API_KEY from the playlist ytcfg.
:rtype: str
"""
return self.ytcfg['INNERTUBE_API_KEY']
def _paginate(
self, until_watch_id: Optional[str] = None
) -> Iterable[List[str]]:
"""Parse the video links from the page source, yields the /watch?v=
part from video link
:param until_watch_id Optional[str]: YouTube Video watch id until
which the playlist should be read.
:rtype: Iterable[List[str]]
:returns: Iterable of lists of YouTube watch ids
"""
videos_urls, continuation = self._extract_videos(
json.dumps(extract.initial_data(self.html))
)
if until_watch_id:
try:
trim_index = videos_urls.index(f"/watch?v={until_watch_id}")
yield videos_urls[:trim_index]
return
except ValueError:
pass
yield videos_urls
# Extraction from a playlist only returns 100 videos at a time
# if self._extract_videos returns a continuation there are more
# than 100 songs inside a playlist, so we need to add further requests
# to gather all of them
if continuation:
load_more_url, headers, data = self._build_continuation_url(continuation)
else:
load_more_url, headers, data = None, None, None
while load_more_url and headers and data: # there is an url found
logger.debug("load more url: %s", load_more_url)
# requesting the next page of videos with the url generated from the
# previous page, needs to be a post
req = request.post(load_more_url, extra_headers=headers, data=data)
# extract up to 100 songs from the page loaded
# returns another continuation if more videos are available
videos_urls, continuation = self._extract_videos(req)
if until_watch_id:
try:
trim_index = videos_urls.index(f"/watch?v={until_watch_id}")
yield videos_urls[:trim_index]
return
except ValueError:
pass
yield videos_urls
if continuation:
load_more_url, headers, data = self._build_continuation_url(
continuation
)
else:
load_more_url, headers, data = None, None, None
def _build_continuation_url(self, continuation: str) -> Tuple[str, dict, dict]:
"""Helper method to build the url and headers required to request
the next page of videos
:param str continuation: Continuation extracted from the json response
of the last page
:rtype: Tuple[str, dict, dict]
:returns: Tuple of an url and required headers for the next http
request
"""
return (
(
# was changed to this format (and post requests)
# between 2021.03.02 and 2021.03.03
"https://www.youtube.com/youtubei/v1/browse?key="
f"{self.yt_api_key}"
),
{
"X-YouTube-Client-Name": "1",
"X-YouTube-Client-Version": "2.20200720.00.02",
},
# extra data required for post request
{
"continuation": continuation,
"context": {
"client": {
"clientName": "WEB",
"clientVersion": "2.20200720.00.02"
}
}
}
)
@staticmethod
def _extract_videos(raw_json: str) -> Tuple[List[str], Optional[str]]:
"""Extracts videos from a raw json page
:param str raw_json: Input json extracted from the page or the last
server response
:rtype: Tuple[List[str], Optional[str]]
:returns: Tuple containing a list of up to 100 video watch ids and
a continuation token, if more videos are available
"""
initial_data = json.loads(raw_json)
try:
# this is the json tree structure, if the json was extracted from
# html
section_contents = initial_data["contents"][
"twoColumnBrowseResultsRenderer"][
"tabs"][0]["tabRenderer"]["content"][
"sectionListRenderer"]["contents"]
try:
# Playlist without submenus
important_content = section_contents[
0]["itemSectionRenderer"][
"contents"][0]["playlistVideoListRenderer"]
except (KeyError, IndexError, TypeError):
# Playlist with submenus
important_content = section_contents[
1]["itemSectionRenderer"][
"contents"][0]["playlistVideoListRenderer"]
videos = important_content["contents"]
except (KeyError, IndexError, TypeError):
try:
# this is the json tree structure, if the json was directly sent
# by the server in a continuation response
# no longer a list and no longer has the "response" key
important_content = initial_data['onResponseReceivedActions'][0][
'appendContinuationItemsAction']['continuationItems']
videos = important_content
except (KeyError, IndexError, TypeError) as p:
logger.info(p)
return [], None
try:
continuation = videos[-1]['continuationItemRenderer'][
'continuationEndpoint'
]['continuationCommand']['token']
videos = videos[:-1]
except (KeyError, IndexError):
# if there is an error, no continuation is available
continuation = None
# remove duplicates
return (
uniqueify(
list(
# only extract the video ids from the video data
map(
lambda x: (
f"/watch?v="
f"{x['playlistVideoRenderer']['videoId']}"
),
videos
)
),
),
continuation,
)
def trimmed(self, video_id: str) -> Iterable[str]:
"""Retrieve a list of YouTube video URLs trimmed at the given video ID
i.e. if the playlist has video IDs 1,2,3,4 calling trimmed(3) returns
[1,2]
:type video_id: str
video ID to trim the returned list of playlist URLs at
:rtype: List[str]
:returns:
List of video URLs from the playlist trimmed at the given ID
"""
for page in self._paginate(until_watch_id=video_id):
yield from (self._video_url(watch_path) for watch_path in page)
def url_generator(self):
"""Generator that yields video URLs.
:Yields: Video URLs
"""
for page in self._paginate():
for video in page:
yield self._video_url(video)
@property # type: ignore
@cache
def video_urls(self) -> DeferredGeneratorList:
"""Complete links of all the videos in playlist
:rtype: List[str]
:returns: List of video URLs
"""
return DeferredGeneratorList(self.url_generator())
def videos_generator(self):
for url in self.video_urls:
yield YouTube(url)
@property
def videos(self) -> Iterable[YouTube]:
"""Yields YouTube objects of videos in this playlist
:rtype: List[YouTube]
:returns: List of YouTube
"""
return DeferredGeneratorList(self.videos_generator())
def __getitem__(self, i: Union[slice, int]) -> Union[str, List[str]]:
return self.video_urls[i]
def __len__(self) -> int:
return len(self.video_urls)
def __repr__(self) -> str:
return f"{repr(self.video_urls)}"
@property
@cache
def last_updated(self) -> Optional[date]:
"""Extract the date that the playlist was last updated.
For some playlists, this will be a specific date, which is returned as a datetime
object. For other playlists, this is an estimate such as "1 week ago". Due to the
fact that this value is returned as a string, pytube does a best-effort parsing
where possible, and returns the raw string where it is not possible.
:return: Date of last playlist update where possible, else the string provided
:rtype: datetime.date
"""
last_updated_text = self.sidebar_info[0]['playlistSidebarPrimaryInfoRenderer'][
'stats'][2]['runs'][1]['text']
try:
date_components = last_updated_text.split()
month = date_components[0]
day = date_components[1].strip(',')
year = date_components[2]
return datetime.strptime(
f"{month} {day:0>2} {year}", "%b %d %Y"
).date()
except (IndexError, KeyError):
return last_updated_text
@property
@cache
def title(self) -> Optional[str]:
"""Extract playlist title
:return: playlist title (name)
:rtype: Optional[str]
"""
return self.sidebar_info[0]['playlistSidebarPrimaryInfoRenderer'][
'title']['runs'][0]['text']
@property
def description(self) -> str:
return self.sidebar_info[0]['playlistSidebarPrimaryInfoRenderer'][
'description']['simpleText']
@property
def length(self):
"""Extract the number of videos in the playlist.
:return: Playlist video count
:rtype: int
"""
count_text = self.sidebar_info[0]['playlistSidebarPrimaryInfoRenderer'][
'stats'][0]['runs'][0]['text']
count_text = count_text.replace(',','')
return int(count_text)
@property
def views(self):
"""Extract view count for playlist.
:return: Playlist view count
:rtype: int
"""
# "1,234,567 views"
views_text = self.sidebar_info[0]['playlistSidebarPrimaryInfoRenderer'][
'stats'][1]['simpleText']
# "1,234,567"
count_text = views_text.split()[0]
# "1234567"
count_text = count_text.replace(',', '')
return int(count_text)
@property
def owner(self):
"""Extract the owner of the playlist.
:return: Playlist owner name.
:rtype: str
"""
return self.sidebar_info[1]['playlistSidebarSecondaryInfoRenderer'][
'videoOwner']['videoOwnerRenderer']['title']['runs'][0]['text']
@property
def owner_id(self):
"""Extract the channel_id of the owner of the playlist.
:return: Playlist owner's channel ID.
:rtype: str
"""
return self.sidebar_info[1]['playlistSidebarSecondaryInfoRenderer'][
'videoOwner']['videoOwnerRenderer']['title']['runs'][0][
'navigationEndpoint']['browseEndpoint']['browseId']
@property
def owner_url(self):
"""Create the channel url of the owner of the playlist.
:return: Playlist owner's channel url.
:rtype: str
"""
return f'https://www.youtube.com/channel/{self.owner_id}'
@staticmethod
def _video_url(watch_path: str):
return f"https://www.youtube.com{watch_path}"

View File

@@ -1,225 +0,0 @@
"""Module for interacting with YouTube search."""
# Native python imports
import logging
# Local imports
from pytube import YouTube
from pytube.innertube import InnerTube
logger = logging.getLogger(__name__)
class Search:
def __init__(self, query):
"""Initialize Search object.
:param str query:
Search query provided by the user.
"""
self.query = query
self._innertube_client = InnerTube(client='WEB')
# The first search, without a continuation, is structured differently
# and contains completion suggestions, so we must store this separately
self._initial_results = None
self._results = None
self._completion_suggestions = None
# Used for keeping track of query continuations so that new results
# are always returned when get_next_results() is called
self._current_continuation = None
@property
def completion_suggestions(self):
"""Return query autocompletion suggestions for the query.
:rtype: list
:returns:
A list of autocomplete suggestions provided by YouTube for the query.
"""
if self._completion_suggestions:
return self._completion_suggestions
if self.results:
self._completion_suggestions = self._initial_results['refinements']
return self._completion_suggestions
@property
def results(self):
"""Return search results.
On first call, will generate and return the first set of results.
Additional results can be generated using ``.get_next_results()``.
:rtype: list
:returns:
A list of YouTube objects.
"""
if self._results:
return self._results
videos, continuation = self.fetch_and_parse()
self._results = videos
self._current_continuation = continuation
return self._results
def get_next_results(self):
"""Use the stored continuation string to fetch the next set of results.
This method does not return the results, but instead updates the results property.
"""
if self._current_continuation:
videos, continuation = self.fetch_and_parse(self._current_continuation)
self._results.extend(videos)
self._current_continuation = continuation
else:
raise IndexError
def fetch_and_parse(self, continuation=None):
"""Fetch from the innertube API and parse the results.
:param str continuation:
Continuation string for fetching results.
:rtype: tuple
:returns:
A tuple of a list of YouTube objects and a continuation string.
"""
# Begin by executing the query and identifying the relevant sections
# of the results
raw_results = self.fetch_query(continuation)
# Initial result is handled by try block, continuations by except block
try:
sections = raw_results['contents']['twoColumnSearchResultsRenderer'][
'primaryContents']['sectionListRenderer']['contents']
except KeyError:
sections = raw_results['onResponseReceivedCommands'][0][
'appendContinuationItemsAction']['continuationItems']
item_renderer = None
continuation_renderer = None
for s in sections:
if 'itemSectionRenderer' in s:
item_renderer = s['itemSectionRenderer']
if 'continuationItemRenderer' in s:
continuation_renderer = s['continuationItemRenderer']
# If the continuationItemRenderer doesn't exist, assume no further results
if continuation_renderer:
next_continuation = continuation_renderer['continuationEndpoint'][
'continuationCommand']['token']
else:
next_continuation = None
# If the itemSectionRenderer doesn't exist, assume no results.
if item_renderer:
videos = []
raw_video_list = item_renderer['contents']
for video_details in raw_video_list:
# Skip over ads
if video_details.get('searchPyvRenderer', {}).get('ads', None):
continue
# Skip "recommended" type videos e.g. "people also watched" and "popular X"
# that break up the search results
if 'shelfRenderer' in video_details:
continue
# Skip auto-generated "mix" playlist results
if 'radioRenderer' in video_details:
continue
# Skip playlist results
if 'playlistRenderer' in video_details:
continue
# Skip channel results
if 'channelRenderer' in video_details:
continue
# Skip 'people also searched for' results
if 'horizontalCardListRenderer' in video_details:
continue
# Can't seem to reproduce, probably related to typo fix suggestions
if 'didYouMeanRenderer' in video_details:
continue
# Seems to be the renderer used for the image shown on a no results page
if 'backgroundPromoRenderer' in video_details:
continue
if 'videoRenderer' not in video_details:
logger.warn('Unexpected renderer encountered.')
logger.warn(f'Renderer name: {video_details.keys()}')
logger.warn(f'Search term: {self.query}')
logger.warn(
'Please open an issue at '
'https://github.com/pytube/pytube/issues '
'and provide this log output.'
)
continue
# Extract relevant video information from the details.
# Some of this can be used to pre-populate attributes of the
# YouTube object.
vid_renderer = video_details['videoRenderer']
vid_id = vid_renderer['videoId']
vid_url = f'https://www.youtube.com/watch?v={vid_id}'
vid_title = vid_renderer['title']['runs'][0]['text']
vid_channel_name = vid_renderer['ownerText']['runs'][0]['text']
vid_channel_uri = vid_renderer['ownerText']['runs'][0][
'navigationEndpoint']['commandMetadata']['webCommandMetadata']['url']
# Livestreams have "runs", non-livestreams have "simpleText",
# and scheduled releases do not have 'viewCountText'
if 'viewCountText' in vid_renderer:
if 'runs' in vid_renderer['viewCountText']:
vid_view_count_text = vid_renderer['viewCountText']['runs'][0]['text']
else:
vid_view_count_text = vid_renderer['viewCountText']['simpleText']
# Strip ' views' text, then remove commas
stripped_text = vid_view_count_text.split()[0].replace(',','')
if stripped_text == 'No':
vid_view_count = 0
else:
vid_view_count = int(stripped_text)
else:
vid_view_count = 0
if 'lengthText' in vid_renderer:
vid_length = vid_renderer['lengthText']['simpleText']
else:
vid_length = None
vid_metadata = {
'id': vid_id,
'url': vid_url,
'title': vid_title,
'channel_name': vid_channel_name,
'channel_url': vid_channel_uri,
'view_count': vid_view_count,
'length': vid_length
}
# Construct YouTube object from metadata and append to results
vid = YouTube(vid_metadata['url'])
vid.author = vid_metadata['channel_name']
vid.title = vid_metadata['title']
videos.append(vid)
else:
videos = None
return videos, next_continuation
def fetch_query(self, continuation=None):
"""Fetch raw results from the innertube API.
:param str continuation:
Continuation string for fetching results.
:rtype: dict
:returns:
The raw json object returned by the innertube API.
"""
query_results = self._innertube_client.search(self.query, continuation)
if not self._initial_results:
self._initial_results = query_results
return query_results # noqa:R504

View File

@@ -1,145 +0,0 @@
"""Library specific exception definitions."""
from typing import Pattern, Union
class PytubeError(Exception):
"""Base pytube exception that all others inherit.
This is done to not pollute the built-in exceptions, which *could* result
in unintended errors being unexpectedly and incorrectly handled within
implementers code.
"""
class MaxRetriesExceeded(PytubeError):
"""Maximum number of retries exceeded."""
class HTMLParseError(PytubeError):
"""HTML could not be parsed"""
class ExtractError(PytubeError):
"""Data extraction based exception."""
class RegexMatchError(ExtractError):
"""Regex pattern did not return any matches."""
def __init__(self, caller: str, pattern: Union[str, Pattern]):
"""
:param str caller:
Calling function
:param str pattern:
Pattern that failed to match
"""
super().__init__(f"{caller}: could not find match for {pattern}")
self.caller = caller
self.pattern = pattern
class VideoUnavailable(PytubeError):
"""Base video unavailable error."""
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.error_string)
@property
def error_string(self):
return f'{self.video_id} is unavailable'
class AgeRestrictedError(VideoUnavailable):
"""Video is age restricted, and cannot be accessed without OAuth."""
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f"{self.video_id} is age restricted, and can't be accessed without logging in."
class LiveStreamError(VideoUnavailable):
"""Video is a live stream."""
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f'{self.video_id} is streaming live and cannot be loaded'
class VideoPrivate(VideoUnavailable):
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f'{self.video_id} is a private video'
class RecordingUnavailable(VideoUnavailable):
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f'{self.video_id} does not have a live stream recording available'
class MembersOnly(VideoUnavailable):
"""Video is members-only.
YouTube has special videos that are only viewable to users who have
subscribed to a content creator.
ref: https://support.google.com/youtube/answer/7544492?hl=en
"""
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f'{self.video_id} is a members-only video'
class VideoRegionBlocked(VideoUnavailable):
def __init__(self, video_id: str):
"""
:param str video_id:
A YouTube video identifier.
"""
self.video_id = video_id
super().__init__(self.video_id)
@property
def error_string(self):
return f'{self.video_id} is not available in your region'

View File

@@ -1,579 +0,0 @@
"""This module contains all non-cipher related data extraction logic."""
import logging
import urllib.parse
import re
from collections import OrderedDict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, quote, urlencode, urlparse
from pytube.cipher import Cipher
from pytube.exceptions import HTMLParseError, LiveStreamError, RegexMatchError
from pytube.helpers import regex_search
from pytube.metadata import YouTubeMetadata
from pytube.parser import parse_for_object, parse_for_all_objects
logger = logging.getLogger(__name__)
def publish_date(watch_html: str):
"""Extract publish date
:param str watch_html:
The html contents of the watch page.
:rtype: str
:returns:
Publish date of the video.
"""
try:
result = regex_search(
r"(?<=itemprop=\"datePublished\" content=\")\d{4}-\d{2}-\d{2}",
watch_html, group=0
)
except RegexMatchError:
return None
return datetime.strptime(result, '%Y-%m-%d')
def recording_available(watch_html):
"""Check if live stream recording is available.
:param str watch_html:
The html contents of the watch page.
:rtype: bool
:returns:
Whether or not the content is private.
"""
unavailable_strings = [
'This live stream recording is not available.'
]
for string in unavailable_strings:
if string in watch_html:
return False
return True
def is_private(watch_html):
"""Check if content is private.
:param str watch_html:
The html contents of the watch page.
:rtype: bool
:returns:
Whether or not the content is private.
"""
private_strings = [
"This is a private video. Please sign in to verify that you may see it.",
"\"simpleText\":\"Private video\"",
"This video is private."
]
for string in private_strings:
if string in watch_html:
return True
return False
def is_age_restricted(watch_html: str) -> bool:
"""Check if content is age restricted.
:param str watch_html:
The html contents of the watch page.
:rtype: bool
:returns:
Whether or not the content is age restricted.
"""
try:
regex_search(r"og:restrictions:age", watch_html, group=0)
except RegexMatchError:
return False
return True
def playability_status(watch_html: str) -> (str, str):
"""Return the playability status and status explanation of a video.
For example, a video may have a status of LOGIN_REQUIRED, and an explanation
of "This is a private video. Please sign in to verify that you may see it."
This explanation is what gets incorporated into the media player overlay.
:param str watch_html:
The html contents of the watch page.
:rtype: bool
:returns:
Playability status and reason of the video.
"""
player_response = initial_player_response(watch_html)
status_dict = player_response.get('playabilityStatus', {})
if 'liveStreamability' in status_dict:
return 'LIVE_STREAM', 'Video is a live stream.'
if 'status' in status_dict:
if 'reason' in status_dict:
return status_dict['status'], [status_dict['reason']]
if 'messages' in status_dict:
return status_dict['status'], status_dict['messages']
return None, [None]
def video_id(url: str) -> str:
"""Extract the ``video_id`` from a YouTube url.
This function supports the following patterns:
- :samp:`https://youtube.com/watch?v={video_id}`
- :samp:`https://youtube.com/embed/{video_id}`
- :samp:`https://youtu.be/{video_id}`
:param str url:
A YouTube url containing a video id.
:rtype: str
:returns:
YouTube video id.
"""
return regex_search(r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", url, group=1)
def playlist_id(url: str) -> str:
"""Extract the ``playlist_id`` from a YouTube url.
This function supports the following patterns:
- :samp:`https://youtube.com/playlist?list={playlist_id}`
- :samp:`https://youtube.com/watch?v={video_id}&list={playlist_id}`
:param str url:
A YouTube url containing a playlist id.
:rtype: str
:returns:
YouTube playlist id.
"""
parsed = urllib.parse.urlparse(url)
return parse_qs(parsed.query)['list'][0]
def channel_name(url: str) -> str:
"""Extract the ``channel_name`` or ``channel_id`` from a YouTube url.
This function supports the following patterns:
- :samp:`https://youtube.com/c/{channel_name}/*`
- :samp:`https://youtube.com/channel/{channel_id}/*
- :samp:`https://youtube.com/u/{channel_name}/*`
- :samp:`https://youtube.com/user/{channel_id}/*
:param str url:
A YouTube url containing a channel name.
:rtype: str
:returns:
YouTube channel name.
"""
patterns = [
r"(?:\/(c)\/([%\d\w_\-]+)(\/.*)?)",
r"(?:\/(channel)\/([%\w\d_\-]+)(\/.*)?)",
r"(?:\/(u)\/([%\d\w_\-]+)(\/.*)?)",
r"(?:\/(user)\/([%\w\d_\-]+)(\/.*)?)"
]
for pattern in patterns:
regex = re.compile(pattern)
function_match = regex.search(url)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
uri_style = function_match.group(1)
uri_identifier = function_match.group(2)
return f'/{uri_style}/{uri_identifier}'
raise RegexMatchError(
caller="channel_name", pattern="patterns"
)
def video_info_url(video_id: str, watch_url: str) -> str:
"""Construct the video_info url.
:param str video_id:
A YouTube video identifier.
:param str watch_url:
A YouTube watch url.
:rtype: str
:returns:
:samp:`https://youtube.com/get_video_info` with necessary GET
parameters.
"""
params = OrderedDict(
[
("video_id", video_id),
("ps", "default"),
("eurl", quote(watch_url)),
("hl", "en_US"),
("html5", "1"),
("c", "TVHTML5"),
("cver", "7.20201028"),
]
)
return _video_info_url(params)
def video_info_url_age_restricted(video_id: str, embed_html: str) -> str:
"""Construct the video_info url.
:param str video_id:
A YouTube video identifier.
:param str embed_html:
The html contents of the embed page (for age restricted videos).
:rtype: str
:returns:
:samp:`https://youtube.com/get_video_info` with necessary GET
parameters.
"""
try:
sts = regex_search(r'"sts"\s*:\s*(\d+)', embed_html, group=1)
except RegexMatchError:
sts = ""
# Here we use ``OrderedDict`` so that the output is consistent between
# Python 2.7+.
eurl = f"https://youtube.googleapis.com/v/{video_id}"
params = OrderedDict(
[
("video_id", video_id),
("eurl", eurl),
("sts", sts),
("html5", "1"),
("c", "TVHTML5"),
("cver", "7.20201028"),
]
)
return _video_info_url(params)
def _video_info_url(params: OrderedDict) -> str:
return "https://www.youtube.com/get_video_info?" + urlencode(params)
def js_url(html: str) -> str:
"""Get the base JavaScript url.
Construct the base JavaScript url, which contains the decipher
"transforms".
:param str html:
The html contents of the watch page.
"""
try:
base_js = get_ytplayer_config(html)['assets']['js']
except (KeyError, RegexMatchError):
base_js = get_ytplayer_js(html)
return "https://youtube.com" + base_js
def mime_type_codec(mime_type_codec: str) -> Tuple[str, List[str]]:
"""Parse the type data.
Breaks up the data in the ``type`` key of the manifest, which contains the
mime type and codecs serialized together, and splits them into separate
elements.
**Example**:
mime_type_codec('audio/webm; codecs="opus"') -> ('audio/webm', ['opus'])
:param str mime_type_codec:
String containing mime type and codecs.
:rtype: tuple
:returns:
The mime type and a list of codecs.
"""
pattern = r"(\w+\/\w+)\;\scodecs=\"([a-zA-Z-0-9.,\s]*)\""
regex = re.compile(pattern)
results = regex.search(mime_type_codec)
if not results:
raise RegexMatchError(caller="mime_type_codec", pattern=pattern)
mime_type, codecs = results.groups()
return mime_type, [c.strip() for c in codecs.split(",")]
def get_ytplayer_js(html: str) -> Any:
"""Get the YouTube player base JavaScript path.
:param str html
The html contents of the watch page.
:rtype: str
:returns:
Path to YouTube's base.js file.
"""
js_url_patterns = [
r"(/s/player/[\w\d]+/[\w\d_/.]+/base\.js)"
]
for pattern in js_url_patterns:
regex = re.compile(pattern)
function_match = regex.search(html)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
yt_player_js = function_match.group(1)
return yt_player_js
raise RegexMatchError(
caller="get_ytplayer_js", pattern="js_url_patterns"
)
def get_ytplayer_config(html: str) -> Any:
"""Get the YouTube player configuration data from the watch html.
Extract the ``ytplayer_config``, which is json data embedded within the
watch html and serves as the primary source of obtaining the stream
manifest data.
:param str html:
The html contents of the watch page.
:rtype: str
:returns:
Substring of the html containing the encoded manifest data.
"""
logger.debug("finding initial function name")
config_patterns = [
r"ytplayer\.config\s*=\s*",
r"ytInitialPlayerResponse\s*=\s*"
]
for pattern in config_patterns:
# Try each pattern consecutively if they don't find a match
try:
return parse_for_object(html, pattern)
except HTMLParseError as e:
logger.debug(f'Pattern failed: {pattern}')
logger.debug(e)
continue
# setConfig() needs to be handled a little differently.
# We want to parse the entire argument to setConfig()
# and use then load that as json to find PLAYER_CONFIG
# inside of it.
setconfig_patterns = [
r"yt\.setConfig\(.*['\"]PLAYER_CONFIG['\"]:\s*"
]
for pattern in setconfig_patterns:
# Try each pattern consecutively if they don't find a match
try:
return parse_for_object(html, pattern)
except HTMLParseError:
continue
raise RegexMatchError(
caller="get_ytplayer_config", pattern="config_patterns, setconfig_patterns"
)
def get_ytcfg(html: str) -> str:
"""Get the entirety of the ytcfg object.
This is built over multiple pieces, so we have to find all matches and
combine the dicts together.
:param str html:
The html contents of the watch page.
:rtype: str
:returns:
Substring of the html containing the encoded manifest data.
"""
ytcfg = {}
ytcfg_patterns = [
r"ytcfg\s=\s",
r"ytcfg\.set\("
]
for pattern in ytcfg_patterns:
# Try each pattern consecutively and try to build a cohesive object
try:
found_objects = parse_for_all_objects(html, pattern)
for obj in found_objects:
ytcfg.update(obj)
except HTMLParseError:
continue
if len(ytcfg) > 0:
return ytcfg
raise RegexMatchError(
caller="get_ytcfg", pattern="ytcfg_pattenrs"
)
def apply_signature(stream_manifest: Dict, vid_info: Dict, js: str) -> None:
"""Apply the decrypted signature to the stream manifest.
:param dict stream_manifest:
Details of the media streams available.
:param str js:
The contents of the base.js asset file.
"""
cipher = Cipher(js=js)
for i, stream in enumerate(stream_manifest):
try:
url: str = stream["url"]
except KeyError:
live_stream = (
vid_info.get("playabilityStatus", {},)
.get("liveStreamability")
)
if live_stream:
raise LiveStreamError("UNKNOWN")
# 403 Forbidden fix.
if "signature" in url or (
"s" not in stream and ("&sig=" in url or "&lsig=" in url)
):
# For certain videos, YouTube will just provide them pre-signed, in
# which case there's no real magic to download them and we can skip
# the whole signature descrambling entirely.
logger.debug("signature found, skip decipher")
continue
signature = cipher.get_signature(ciphered_signature=stream["s"])
logger.debug(
"finished descrambling signature for itag=%s", stream["itag"]
)
parsed_url = urlparse(url)
# Convert query params off url to dict
query_params = parse_qs(urlparse(url).query)
query_params = {
k: v[0] for k,v in query_params.items()
}
query_params['sig'] = signature
if 'ratebypass' not in query_params.keys():
# Cipher n to get the updated value
initial_n = list(query_params['n'])
new_n = cipher.calculate_n(initial_n)
query_params['n'] = new_n
url = f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}?{urlencode(query_params)}' # noqa:E501
# 403 forbidden fix
stream_manifest[i]["url"] = url
def apply_descrambler(stream_data: Dict) -> None:
"""Apply various in-place transforms to YouTube's media stream data.
Creates a ``list`` of dictionaries by string splitting on commas, then
taking each list item, parsing it as a query string, converting it to a
``dict`` and unquoting the value.
:param dict stream_data:
Dictionary containing query string encoded values.
**Example**:
>>> d = {'foo': 'bar=1&var=test,em=5&t=url%20encoded'}
>>> apply_descrambler(d, 'foo')
>>> print(d)
{'foo': [{'bar': '1', 'var': 'test'}, {'em': '5', 't': 'url encoded'}]}
"""
if 'url' in stream_data:
return None
# Merge formats and adaptiveFormats into a single list
formats = []
if 'formats' in stream_data.keys():
formats.extend(stream_data['formats'])
if 'adaptiveFormats' in stream_data.keys():
formats.extend(stream_data['adaptiveFormats'])
# Extract url and s from signatureCiphers as necessary
for data in formats:
if 'url' not in data:
if 'signatureCipher' in data:
cipher_url = parse_qs(data['signatureCipher'])
data['url'] = cipher_url['url'][0]
data['s'] = cipher_url['s'][0]
data['is_otf'] = data.get('type') == 'FORMAT_STREAM_TYPE_OTF'
logger.debug("applying descrambler")
return formats
def initial_data(watch_html: str) -> str:
"""Extract the ytInitialData json from the watch_html page.
This mostly contains metadata necessary for rendering the page on-load,
such as video information, copyright notices, etc.
@param watch_html: Html of the watch page
@return:
"""
patterns = [
r"window\[['\"]ytInitialData['\"]]\s*=\s*",
r"ytInitialData\s*=\s*"
]
for pattern in patterns:
try:
return parse_for_object(watch_html, pattern)
except HTMLParseError:
pass
raise RegexMatchError(caller='initial_data', pattern='initial_data_pattern')
def initial_player_response(watch_html: str) -> str:
"""Extract the ytInitialPlayerResponse json from the watch_html page.
This mostly contains metadata necessary for rendering the page on-load,
such as video information, copyright notices, etc.
@param watch_html: Html of the watch page
@return:
"""
patterns = [
r"window\[['\"]ytInitialPlayerResponse['\"]]\s*=\s*",
r"ytInitialPlayerResponse\s*=\s*"
]
for pattern in patterns:
try:
return parse_for_object(watch_html, pattern)
except HTMLParseError:
pass
raise RegexMatchError(
caller='initial_player_response',
pattern='initial_player_response_pattern'
)
def metadata(initial_data) -> Optional[YouTubeMetadata]:
"""Get the informational metadata for the video.
e.g.:
[
{
'Song': '강남스타일(Gangnam Style)',
'Artist': 'PSY',
'Album': 'PSY SIX RULES Pt.1',
'Licensed to YouTube by': 'YG Entertainment Inc. [...]'
}
]
:rtype: YouTubeMetadata
"""
try:
metadata_rows: List = initial_data["contents"]["twoColumnWatchNextResults"][
"results"]["results"]["contents"][1]["videoSecondaryInfoRenderer"][
"metadataRowContainer"]["metadataRowContainerRenderer"]["rows"]
except (KeyError, IndexError):
# If there's an exception accessing this data, it probably doesn't exist.
return YouTubeMetadata([])
# Rows appear to only have "metadataRowRenderer" or "metadataRowHeaderRenderer"
# and we only care about the former, so we filter the others
metadata_rows = filter(
lambda x: "metadataRowRenderer" in x.keys(),
metadata_rows
)
# We then access the metadataRowRenderer key in each element
# and build a metadata object from this new list
metadata_rows = [x["metadataRowRenderer"] for x in metadata_rows]
return YouTubeMetadata(metadata_rows)

View File

@@ -1,335 +0,0 @@
"""Various helper functions implemented by pytube."""
import functools
import gzip
import json
import logging
import os
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, TypeVar
from urllib import request
from pytube.exceptions import RegexMatchError
logger = logging.getLogger(__name__)
class DeferredGeneratorList:
"""A wrapper class for deferring list generation.
Pytube has some continuation generators that create web calls, which means
that any time a full list is requested, all of those web calls must be
made at once, which could lead to slowdowns. This will allow individual
elements to be queried, so that slowdowns only happen as necessary. For
example, you can iterate over elements in the list without accessing them
all simultaneously. This should allow for speed improvements for playlist
and channel interactions.
"""
def __init__(self, generator):
"""Construct a :class:`DeferredGeneratorList <DeferredGeneratorList>`.
:param generator generator:
The deferrable generator to create a wrapper for.
:param func func:
(Optional) A function to call on the generator items to produce the list.
"""
self.gen = generator
self._elements = []
def __eq__(self, other):
"""We want to mimic list behavior for comparison."""
return list(self) == other
def __getitem__(self, key) -> Any:
"""Only generate items as they're asked for."""
# We only allow querying with indexes.
if not isinstance(key, (int, slice)):
raise TypeError('Key must be either a slice or int.')
# Convert int keys to slice
key_slice = key
if isinstance(key, int):
key_slice = slice(key, key + 1, 1)
# Generate all elements up to the final item
while len(self._elements) < key_slice.stop:
try:
next_item = next(self.gen)
except StopIteration:
# If we can't find enough elements for the slice, raise an IndexError
raise IndexError
else:
self._elements.append(next_item)
return self._elements[key]
def __iter__(self):
"""Custom iterator for dynamically generated list."""
iter_index = 0
while True:
try:
curr_item = self[iter_index]
except IndexError:
return
else:
yield curr_item
iter_index += 1
def __next__(self) -> Any:
"""Fetch next element in iterator."""
try:
curr_element = self[self.iter_index]
except IndexError:
raise StopIteration
self.iter_index += 1
return curr_element # noqa:R504
def __len__(self) -> int:
"""Return length of list of all items."""
self.generate_all()
return len(self._elements)
def __repr__(self) -> str:
"""String representation of all items."""
self.generate_all()
return str(self._elements)
def __reversed__(self):
self.generate_all()
return self._elements[::-1]
def generate_all(self):
"""Generate all items."""
while True:
try:
next_item = next(self.gen)
except StopIteration:
break
else:
self._elements.append(next_item)
def regex_search(pattern: str, string: str, group: int) -> str:
"""Shortcut method to search a string for a given pattern.
:param str pattern:
A regular expression pattern.
:param str string:
A target string to search.
:param int group:
Index of group to return.
:rtype:
str or tuple
:returns:
Substring pattern matches.
"""
regex = re.compile(pattern)
results = regex.search(string)
if not results:
raise RegexMatchError(caller="regex_search", pattern=pattern)
logger.debug("matched regex search: %s", pattern)
return results.group(group)
def safe_filename(s: str, max_length: int = 255) -> str:
"""Sanitize a string making it safe to use as a filename.
This function was based off the limitations outlined here:
https://en.wikipedia.org/wiki/Filename.
:param str s:
A string to make safe for use as a file name.
:param int max_length:
The maximum filename character length.
:rtype: str
:returns:
A sanitized string.
"""
# Characters in range 0-31 (0x00-0x1F) are not allowed in ntfs filenames.
ntfs_characters = [chr(i) for i in range(0, 31)]
characters = [
r'"',
r"\#",
r"\$",
r"\%",
r"'",
r"\*",
r"\,",
r"\.",
r"\/",
r"\:",
r'"',
r"\;",
r"\<",
r"\>",
r"\?",
r"\\",
r"\^",
r"\|",
r"\~",
r"\\\\",
]
pattern = "|".join(ntfs_characters + characters)
regex = re.compile(pattern, re.UNICODE)
filename = regex.sub("", s)
return filename[:max_length].rsplit(" ", 0)[0]
def setup_logger(level: int = logging.ERROR, log_filename: Optional[str] = None) -> None:
"""Create a configured instance of logger.
:param int level:
Describe the severity level of the logs to handle.
"""
fmt = "[%(asctime)s] %(levelname)s in %(module)s: %(message)s"
date_fmt = "%H:%M:%S"
formatter = logging.Formatter(fmt, datefmt=date_fmt)
# https://github.com/pytube/pytube/issues/163
logger = logging.getLogger("pytube")
logger.setLevel(level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_filename is not None:
file_handler = logging.FileHandler(log_filename)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
GenericType = TypeVar("GenericType")
def cache(func: Callable[..., GenericType]) -> GenericType:
""" mypy compatible annotation wrapper for lru_cache"""
return functools.lru_cache()(func) # type: ignore
def deprecated(reason: str) -> Callable:
"""
This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used.
"""
def decorator(func1):
message = "Call to deprecated function {name} ({reason})."
@functools.wraps(func1)
def new_func1(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
message.format(name=func1.__name__, reason=reason),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func1(*args, **kwargs)
return new_func1
return decorator
def target_directory(output_path: Optional[str] = None) -> str:
"""
Function for determining target directory of a download.
Returns an absolute path (if relative one given) or the current
path (if none given). Makes directory if it does not exist.
:type output_path: str
:rtype: str
:returns:
An absolute directory path as a string.
"""
if output_path:
if not os.path.isabs(output_path):
output_path = os.path.join(os.getcwd(), output_path)
else:
output_path = os.getcwd()
os.makedirs(output_path, exist_ok=True)
return output_path
def install_proxy(proxy_handler: Dict[str, str]) -> None:
proxy_support = request.ProxyHandler(proxy_handler)
opener = request.build_opener(proxy_support)
request.install_opener(opener)
def uniqueify(duped_list: List) -> List:
"""Remove duplicate items from a list, while maintaining list order.
:param List duped_list
List to remove duplicates from
:return List result
De-duplicated list
"""
seen: Dict[Any, bool] = {}
result = []
for item in duped_list:
if item in seen:
continue
seen[item] = True
result.append(item)
return result
def generate_all_html_json_mocks():
"""Regenerate the video mock json files for all current test videos.
This should automatically output to the test/mocks directory.
"""
test_vid_ids = [
'2lAe1cqCOXo',
'5YceQ8YqYMc',
'irauhITDrsE',
'm8uHb5jIGN8',
'QRS8MkLhQmM',
'WXxV9g7lsFE'
]
for vid_id in test_vid_ids:
create_mock_html_json(vid_id)
def create_mock_html_json(vid_id) -> Dict[str, Any]:
"""Generate a json.gz file with sample html responses.
:param str vid_id
YouTube video id
:return dict data
Dict used to generate the json.gz file
"""
from pytube import YouTube
gzip_filename = 'yt-video-%s-html.json.gz' % vid_id
# Get the pytube directory in order to navigate to /tests/mocks
pytube_dir_path = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
os.path.pardir
)
)
pytube_mocks_path = os.path.join(pytube_dir_path, 'tests', 'mocks')
gzip_filepath = os.path.join(pytube_mocks_path, gzip_filename)
yt = YouTube(f'https://www.youtube.com/watch?v={vid_id}')
html_data = {
'url': yt.watch_url,
'js': yt.js,
'embed_html': yt.embed_html,
'watch_html': yt.watch_html,
'vid_info': yt.vid_info
}
logger.info(f'Outputing json.gz file to {gzip_filepath}')
with gzip.open(gzip_filepath, 'wb') as f:
f.write(json.dumps(html_data).encode('utf-8'))
return html_data

View File

@@ -1,507 +0,0 @@
"""This module is designed to interact with the innertube API.
This module is NOT intended to be used directly by end users, as each of the
interfaces returns raw results. These should instead be parsed to extract
the useful information for the end user.
"""
# Native python imports
import json
import os
import pathlib
import time
from urllib import parse
# Local imports
from pytube import request
# YouTube on TV client secrets
_client_id = '861556708454-d6dlm3lh05idd8npek18k6be8ba3oc68.apps.googleusercontent.com'
_client_secret = 'SboVhoG9s0rNafixCSGGKXAT'
# Extracted API keys -- unclear what these are linked to.
_api_keys = [
'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8',
'AIzaSyCtkvNIR1HCEwzsqK6JuE6KqpyjusIRI30',
'AIzaSyA8eiZmM1FaDVjRy-df2KTyQ_vz_yYM39w',
'AIzaSyC8UYZpvA2eknNex0Pjid0_eTLJoDu6los',
'AIzaSyCjc_pVEDi4qsv5MtC2dMXzpIaDoRFLsxw',
'AIzaSyDHQ9ipnphqTzDqZsbtd8_Ru4_kiKVQe2k'
]
_default_clients = {
'WEB': {
'context': {
'client': {
'clientName': 'WEB',
'clientVersion': '2.20200720.00.02'
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'ANDROID': {
'context': {
'client': {
'clientName': 'ANDROID',
'clientVersion': '17.31.35',
'androidSdkVersion': 30
}
},
'header': {
'User-Agent': 'com.google.android.youtube/',
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'IOS': {
'context': {
'client': {
'clientName': 'IOS',
'clientVersion': '17.33.2',
'deviceModel': 'iPhone14,3'
}
},
'header': {
'User-Agent': 'com.google.ios.youtube/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'WEB_EMBED': {
'context': {
'client': {
'clientName': 'WEB_EMBEDDED_PLAYER',
'clientVersion': '2.20210721.00.00',
'clientScreen': 'EMBED'
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'ANDROID_EMBED': {
'context': {
'client': {
'clientName': 'ANDROID_EMBEDDED_PLAYER',
'clientVersion': '17.31.35',
'clientScreen': 'EMBED',
'androidSdkVersion': 30,
}
},
'header': {
'User-Agent': 'com.google.android.youtube/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'IOS_EMBED': {
'context': {
'client': {
'clientName': 'IOS_MESSAGES_EXTENSION',
'clientVersion': '17.33.2',
'deviceModel': 'iPhone14,3'
}
},
'header': {
'User-Agent': 'com.google.ios.youtube/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'WEB_MUSIC': {
'context': {
'client': {
'clientName': 'WEB_REMIX',
'clientVersion': '1.20220727.01.00',
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'ANDROID_MUSIC': {
'context': {
'client': {
'clientName': 'ANDROID_MUSIC',
'clientVersion': '5.16.51',
'androidSdkVersion': 30
}
},
'header': {
'User-Agent': 'com.google.android.apps.youtube.music/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'IOS_MUSIC': {
'context': {
'client': {
'clientName': 'IOS_MUSIC',
'clientVersion': '5.21',
'deviceModel': 'iPhone14,3'
}
},
'header': {
'User-Agent': 'com.google.ios.youtubemusic/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'WEB_CREATOR': {
'context': {
'client': {
'clientName': 'WEB_CREATOR',
'clientVersion': '1.20220726.00.00',
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'ANDROID_CREATOR': {
'context': {
'client': {
'clientName': 'ANDROID_CREATOR',
'clientVersion': '22.30.100',
'androidSdkVersion': 30,
}
},
'header': {
'User-Agent': 'com.google.android.apps.youtube.creator/',
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'IOS_CREATOR': {
'context': {
'client': {
'clientName': 'IOS_CREATOR',
'clientVersion': '22.33.101',
'deviceModel': 'iPhone14,3',
}
},
'header': {
'User-Agent': 'com.google.ios.ytcreator/'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'MWEB': {
'context': {
'client': {
'clientName': 'MWEB',
'clientVersion': '2.20220801.00.00',
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
'TV_EMBED': {
'context': {
'client': {
'clientName': 'TVHTML5_SIMPLY_EMBEDDED_PLAYER',
'clientVersion': '2.0',
}
},
'header': {
'User-Agent': 'Mozilla/5.0'
},
'api_key': 'AIzaSyAO_FJ2SlqU8Q4STEHLGCilw_Y9_11qcW8'
},
}
_token_timeout = 1800
_cache_dir = pathlib.Path(__file__).parent.resolve() / '__cache__'
_token_file = os.path.join(_cache_dir, 'tokens.json')
class InnerTube:
"""Object for interacting with the innertube API."""
def __init__(self, client='ANDROID_MUSIC', use_oauth=False, allow_cache=True):
"""Initialize an InnerTube object.
:param str client:
Client to use for the object.
Default to web because it returns the most playback types.
:param bool use_oauth:
Whether or not to authenticate to YouTube.
:param bool allow_cache:
Allows caching of oauth tokens on the machine.
"""
self.context = _default_clients[client]['context']
self.header = _default_clients[client]['header']
self.api_key = _default_clients[client]['api_key']
self.access_token = None
self.refresh_token = None
self.use_oauth = use_oauth
self.allow_cache = allow_cache
# Stored as epoch time
self.expires = None
# Try to load from file if specified
if self.use_oauth and self.allow_cache:
# Try to load from file if possible
if os.path.exists(_token_file):
with open(_token_file) as f:
data = json.load(f)
self.access_token = data['access_token']
self.refresh_token = data['refresh_token']
self.expires = data['expires']
self.refresh_bearer_token()
def cache_tokens(self):
"""Cache tokens to file if allowed."""
if not self.allow_cache:
return
data = {
'access_token': self.access_token,
'refresh_token': self.refresh_token,
'expires': self.expires
}
if not os.path.exists(_cache_dir):
os.mkdir(_cache_dir)
with open(_token_file, 'w') as f:
json.dump(data, f)
def refresh_bearer_token(self, force=False):
"""Refreshes the OAuth token if necessary.
:param bool force:
Force-refresh the bearer token.
"""
if not self.use_oauth:
return
# Skip refresh if it's not necessary and not forced
if self.expires > time.time() and not force:
return
# Subtracting 30 seconds is arbitrary to avoid potential time discrepencies
start_time = int(time.time() - 30)
data = {
'client_id': _client_id,
'client_secret': _client_secret,
'grant_type': 'refresh_token',
'refresh_token': self.refresh_token
}
response = request._execute_request(
'https://oauth2.googleapis.com/token',
'POST',
headers={
'Content-Type': 'application/json'
},
data=data
)
response_data = json.loads(response.read())
self.access_token = response_data['access_token']
self.expires = start_time + response_data['expires_in']
self.cache_tokens()
def fetch_bearer_token(self):
"""Fetch an OAuth token."""
# Subtracting 30 seconds is arbitrary to avoid potential time discrepencies
start_time = int(time.time() - 30)
data = {
'client_id': _client_id,
'scope': 'https://www.googleapis.com/auth/youtube'
}
response = request._execute_request(
'https://oauth2.googleapis.com/device/code',
'POST',
headers={
'Content-Type': 'application/json'
},
data=data
)
response_data = json.loads(response.read())
verification_url = response_data['verification_url']
user_code = response_data['user_code']
print(f'Please open {verification_url} and input code {user_code}')
input('Press enter when you have completed this step.')
data = {
'client_id': _client_id,
'client_secret': _client_secret,
'device_code': response_data['device_code'],
'grant_type': 'urn:ietf:params:oauth:grant-type:device_code'
}
response = request._execute_request(
'https://oauth2.googleapis.com/token',
'POST',
headers={
'Content-Type': 'application/json'
},
data=data
)
response_data = json.loads(response.read())
self.access_token = response_data['access_token']
self.refresh_token = response_data['refresh_token']
self.expires = start_time + response_data['expires_in']
self.cache_tokens()
@property
def base_url(self):
"""Return the base url endpoint for the innertube API."""
return 'https://www.youtube.com/youtubei/v1'
@property
def base_data(self):
"""Return the base json data to transmit to the innertube API."""
return {
'context': self.context
}
@property
def base_params(self):
"""Return the base query parameters to transmit to the innertube API."""
return {
'key': self.api_key,
'contentCheckOk': True,
'racyCheckOk': True
}
def _call_api(self, endpoint, query, data):
"""Make a request to a given endpoint with the provided query parameters and data."""
# Remove the API key if oauth is being used.
if self.use_oauth:
del query['key']
endpoint_url = f'{endpoint}?{parse.urlencode(query)}'
headers = {
'Content-Type': 'application/json',
}
# Add the bearer token if applicable
if self.use_oauth:
if self.access_token:
self.refresh_bearer_token()
headers['Authorization'] = f'Bearer {self.access_token}'
else:
self.fetch_bearer_token()
headers['Authorization'] = f'Bearer {self.access_token}'
headers.update(self.header)
response = request._execute_request(
endpoint_url,
'POST',
headers=headers,
data=data
)
return json.loads(response.read())
def browse(self):
"""Make a request to the browse endpoint.
TODO: Figure out how we can use this
"""
# endpoint = f'{self.base_url}/browse' # noqa:E800
...
# return self._call_api(endpoint, query, self.base_data) # noqa:E800
def config(self):
"""Make a request to the config endpoint.
TODO: Figure out how we can use this
"""
# endpoint = f'{self.base_url}/config' # noqa:E800
...
# return self._call_api(endpoint, query, self.base_data) # noqa:E800
def guide(self):
"""Make a request to the guide endpoint.
TODO: Figure out how we can use this
"""
# endpoint = f'{self.base_url}/guide' # noqa:E800
...
# return self._call_api(endpoint, query, self.base_data) # noqa:E800
def next(self):
"""Make a request to the next endpoint.
TODO: Figure out how we can use this
"""
# endpoint = f'{self.base_url}/next' # noqa:E800
...
# return self._call_api(endpoint, query, self.base_data) # noqa:E800
def player(self, video_id):
"""Make a request to the player endpoint.
:param str video_id:
The video id to get player info for.
:rtype: dict
:returns:
Raw player info results.
"""
endpoint = f'{self.base_url}/player'
query = {
'videoId': video_id,
}
query.update(self.base_params)
return self._call_api(endpoint, query, self.base_data)
def search(self, search_query, continuation=None):
"""Make a request to the search endpoint.
:param str search_query:
The query to search.
:rtype: dict
:returns:
Raw search query results.
"""
endpoint = f'{self.base_url}/search'
query = {
'query': search_query
}
query.update(self.base_params)
data = {}
if continuation:
data['continuation'] = continuation
data.update(self.base_data)
return self._call_api(endpoint, query, data)
def verify_age(self, video_id):
"""Make a request to the age_verify endpoint.
Notable examples of the types of video this verification step is for:
* https://www.youtube.com/watch?v=QLdAhwSBZ3w
* https://www.youtube.com/watch?v=hc0ZDaAZQT0
:param str video_id:
The video id to get player info for.
:rtype: dict
:returns:
Returns information that includes a URL for bypassing certain restrictions.
"""
endpoint = f'{self.base_url}/verify_age'
data = {
'nextEndpoint': {
'urlEndpoint': {
'url': f'/watch?v={video_id}'
}
},
'setControvercy': True
}
data.update(self.base_data)
result = self._call_api(endpoint, self.base_params, data)
return result
def get_transcript(self, video_id):
"""Make a request to the get_transcript endpoint.
This is likely related to captioning for videos, but is currently untested.
"""
endpoint = f'{self.base_url}/get_transcript'
query = {
'videoId': video_id,
}
query.update(self.base_params)
result = self._call_api(endpoint, query, self.base_data)
return result

View File

@@ -1,153 +0,0 @@
"""This module contains a lookup table of YouTube's itag values."""
from typing import Dict
PROGRESSIVE_VIDEO = {
5: ("240p", "64kbps"),
6: ("270p", "64kbps"),
13: ("144p", None),
17: ("144p", "24kbps"),
18: ("360p", "96kbps"),
22: ("720p", "192kbps"),
34: ("360p", "128kbps"),
35: ("480p", "128kbps"),
36: ("240p", None),
37: ("1080p", "192kbps"),
38: ("3072p", "192kbps"),
43: ("360p", "128kbps"),
44: ("480p", "128kbps"),
45: ("720p", "192kbps"),
46: ("1080p", "192kbps"),
59: ("480p", "128kbps"),
78: ("480p", "128kbps"),
82: ("360p", "128kbps"),
83: ("480p", "128kbps"),
84: ("720p", "192kbps"),
85: ("1080p", "192kbps"),
91: ("144p", "48kbps"),
92: ("240p", "48kbps"),
93: ("360p", "128kbps"),
94: ("480p", "128kbps"),
95: ("720p", "256kbps"),
96: ("1080p", "256kbps"),
100: ("360p", "128kbps"),
101: ("480p", "192kbps"),
102: ("720p", "192kbps"),
132: ("240p", "48kbps"),
151: ("720p", "24kbps"),
300: ("720p", "128kbps"),
301: ("1080p", "128kbps"),
}
DASH_VIDEO = {
# DASH Video
133: ("240p", None), # MP4
134: ("360p", None), # MP4
135: ("480p", None), # MP4
136: ("720p", None), # MP4
137: ("1080p", None), # MP4
138: ("2160p", None), # MP4
160: ("144p", None), # MP4
167: ("360p", None), # WEBM
168: ("480p", None), # WEBM
169: ("720p", None), # WEBM
170: ("1080p", None), # WEBM
212: ("480p", None), # MP4
218: ("480p", None), # WEBM
219: ("480p", None), # WEBM
242: ("240p", None), # WEBM
243: ("360p", None), # WEBM
244: ("480p", None), # WEBM
245: ("480p", None), # WEBM
246: ("480p", None), # WEBM
247: ("720p", None), # WEBM
248: ("1080p", None), # WEBM
264: ("1440p", None), # MP4
266: ("2160p", None), # MP4
271: ("1440p", None), # WEBM
272: ("4320p", None), # WEBM
278: ("144p", None), # WEBM
298: ("720p", None), # MP4
299: ("1080p", None), # MP4
302: ("720p", None), # WEBM
303: ("1080p", None), # WEBM
308: ("1440p", None), # WEBM
313: ("2160p", None), # WEBM
315: ("2160p", None), # WEBM
330: ("144p", None), # WEBM
331: ("240p", None), # WEBM
332: ("360p", None), # WEBM
333: ("480p", None), # WEBM
334: ("720p", None), # WEBM
335: ("1080p", None), # WEBM
336: ("1440p", None), # WEBM
337: ("2160p", None), # WEBM
394: ("144p", None), # MP4
395: ("240p", None), # MP4
396: ("360p", None), # MP4
397: ("480p", None), # MP4
398: ("720p", None), # MP4
399: ("1080p", None), # MP4
400: ("1440p", None), # MP4
401: ("2160p", None), # MP4
402: ("4320p", None), # MP4
571: ("4320p", None), # MP4
694: ("144p", None), # MP4
695: ("240p", None), # MP4
696: ("360p", None), # MP4
697: ("480p", None), # MP4
698: ("720p", None), # MP4
699: ("1080p", None), # MP4
700: ("1440p", None), # MP4
701: ("2160p", None), # MP4
702: ("4320p", None), # MP4
}
DASH_AUDIO = {
# DASH Audio
139: (None, "48kbps"), # MP4
140: (None, "128kbps"), # MP4
141: (None, "256kbps"), # MP4
171: (None, "128kbps"), # WEBM
172: (None, "256kbps"), # WEBM
249: (None, "50kbps"), # WEBM
250: (None, "70kbps"), # WEBM
251: (None, "160kbps"), # WEBM
256: (None, "192kbps"), # MP4
258: (None, "384kbps"), # MP4
325: (None, None), # MP4
328: (None, None), # MP4
}
ITAGS = {
**PROGRESSIVE_VIDEO,
**DASH_VIDEO,
**DASH_AUDIO,
}
HDR = [330, 331, 332, 333, 334, 335, 336, 337]
_3D = [82, 83, 84, 85, 100, 101, 102]
LIVE = [91, 92, 93, 94, 95, 96, 132, 151]
def get_format_profile(itag: int) -> Dict:
"""Get additional format information for a given itag.
:param str itag:
YouTube format identifier code.
"""
itag = int(itag)
if itag in ITAGS:
res, bitrate = ITAGS[itag]
else:
res, bitrate = None, None
return {
"resolution": res,
"abr": bitrate,
"is_live": itag in LIVE,
"is_3d": itag in _3D,
"is_hdr": itag in HDR,
"is_dash": (
itag in DASH_AUDIO
or itag in DASH_VIDEO
),
}

View File

@@ -1,48 +0,0 @@
"""This module contains the YouTubeMetadata class."""
import json
from typing import Dict, List, Optional
class YouTubeMetadata:
def __init__(self, metadata: List):
self._raw_metadata: List = metadata
self._metadata = [{}]
for el in metadata:
# We only add metadata to the dict if it has a simpleText title.
if 'title' in el and 'simpleText' in el['title']:
metadata_title = el['title']['simpleText']
else:
continue
contents = el['contents'][0]
if 'simpleText' in contents:
self._metadata[-1][metadata_title] = contents['simpleText']
elif 'runs' in contents:
self._metadata[-1][metadata_title] = contents['runs'][0]['text']
# Upon reaching a dividing line, create a new grouping
if el.get('hasDividerLine', False):
self._metadata.append({})
# If we happen to create an empty dict at the end, drop it
if self._metadata[-1] == {}:
self._metadata = self._metadata[:-1]
def __getitem__(self, key):
return self._metadata[key]
def __iter__(self):
for el in self._metadata:
yield el
def __str__(self):
return json.dumps(self._metadata)
@property
def raw_metadata(self) -> Optional[Dict]:
return self._raw_metadata
@property
def metadata(self):
return self._metadata

View File

@@ -1,15 +0,0 @@
from typing import Any, Callable, Optional
class Monostate:
def __init__(
self,
on_progress: Optional[Callable[[Any, bytes, int], None]],
on_complete: Optional[Callable[[Any, Optional[str]], None]],
title: Optional[str] = None,
duration: Optional[int] = None,
):
self.on_progress = on_progress
self.on_complete = on_complete
self.title = title
self.duration = duration

View File

@@ -1,185 +0,0 @@
import ast
import json
import re
from pytube.exceptions import HTMLParseError
def parse_for_all_objects(html, preceding_regex):
"""Parses input html to find all matches for the input starting point.
:param str html:
HTML to be parsed for an object.
:param str preceding_regex:
Regex to find the string preceding the object.
:rtype list:
:returns:
A list of dicts created from parsing the objects.
"""
result = []
regex = re.compile(preceding_regex)
match_iter = regex.finditer(html)
for match in match_iter:
if match:
start_index = match.end()
try:
obj = parse_for_object_from_startpoint(html, start_index)
except HTMLParseError:
# Some of the instances might fail because set is technically
# a method of the ytcfg object. We'll skip these since they
# don't seem relevant at the moment.
continue
else:
result.append(obj)
if len(result) == 0:
raise HTMLParseError(f'No matches for regex {preceding_regex}')
return result
def parse_for_object(html, preceding_regex):
"""Parses input html to find the end of a JavaScript object.
:param str html:
HTML to be parsed for an object.
:param str preceding_regex:
Regex to find the string preceding the object.
:rtype dict:
:returns:
A dict created from parsing the object.
"""
regex = re.compile(preceding_regex)
result = regex.search(html)
if not result:
raise HTMLParseError(f'No matches for regex {preceding_regex}')
start_index = result.end()
return parse_for_object_from_startpoint(html, start_index)
def find_object_from_startpoint(html, start_point):
"""Parses input html to find the end of a JavaScript object.
:param str html:
HTML to be parsed for an object.
:param int start_point:
Index of where the object starts.
:rtype dict:
:returns:
A dict created from parsing the object.
"""
html = html[start_point:]
if html[0] not in ['{','[']:
raise HTMLParseError(f'Invalid start point. Start of HTML:\n{html[:20]}')
# First letter MUST be a open brace, so we put that in the stack,
# and skip the first character.
last_char = '{'
curr_char = None
stack = [html[0]]
i = 1
context_closers = {
'{': '}',
'[': ']',
'"': '"',
'/': '/' # javascript regex
}
while i < len(html):
if len(stack) == 0:
break
if curr_char not in [' ', '\n']:
last_char = curr_char
curr_char = html[i]
curr_context = stack[-1]
# If we've reached a context closer, we can remove an element off the stack
if curr_char == context_closers[curr_context]:
stack.pop()
i += 1
continue
# Strings and regex expressions require special context handling because they can contain
# context openers *and* closers
if curr_context in ['"', '/']:
# If there's a backslash in a string or regex expression, we skip a character
if curr_char == '\\':
i += 2
continue
else:
# Non-string contexts are when we need to look for context openers.
if curr_char in context_closers.keys():
# Slash starts a regular expression depending on context
if not (curr_char == '/' and last_char not in ['(', ',', '=', ':', '[', '!', '&', '|', '?', '{', '}', ';']):
stack.append(curr_char)
i += 1
full_obj = html[:i]
return full_obj # noqa: R504
def parse_for_object_from_startpoint(html, start_point):
"""JSONifies an object parsed from HTML.
:param str html:
HTML to be parsed for an object.
:param int start_point:
Index of where the object starts.
:rtype dict:
:returns:
A dict created from parsing the object.
"""
full_obj = find_object_from_startpoint(html, start_point)
try:
return json.loads(full_obj)
except json.decoder.JSONDecodeError:
try:
return ast.literal_eval(full_obj)
except (ValueError, SyntaxError):
raise HTMLParseError('Could not parse object.')
def throttling_array_split(js_array):
"""Parses the throttling array into a python list of strings.
Expects input to begin with `[` and close with `]`.
:param str js_array:
The javascript array, as a string.
:rtype: list:
:returns:
A list of strings representing splits on `,` in the throttling array.
"""
results = []
curr_substring = js_array[1:]
comma_regex = re.compile(r",")
func_regex = re.compile(r"function\([^)]*\)")
while len(curr_substring) > 0:
if curr_substring.startswith('function'):
# Handle functions separately. These can contain commas
match = func_regex.search(curr_substring)
match_start, match_end = match.span()
function_text = find_object_from_startpoint(curr_substring, match.span()[1])
full_function_def = curr_substring[:match_end + len(function_text)]
results.append(full_function_def)
curr_substring = curr_substring[len(full_function_def) + 1:]
else:
match = comma_regex.search(curr_substring)
# Try-catch to capture end of array
try:
match_start, match_end = match.span()
except AttributeError:
match_start = len(curr_substring) - 1
match_end = match_start + 1
curr_el = curr_substring[:match_start]
results.append(curr_el)
curr_substring = curr_substring[match_end:]
return results

View File

@@ -1,424 +0,0 @@
"""This module provides a query interface for media streams and captions."""
from collections.abc import Mapping, Sequence
from typing import Callable, List, Optional, Union
from pytube import Caption, Stream
from pytube.helpers import deprecated
class StreamQuery(Sequence):
"""Interface for querying the available media streams."""
def __init__(self, fmt_streams):
"""Construct a :class:`StreamQuery <StreamQuery>`.
param list fmt_streams:
list of :class:`Stream <Stream>` instances.
"""
self.fmt_streams = fmt_streams
self.itag_index = {int(s.itag): s for s in fmt_streams}
def filter(
self,
fps=None,
res=None,
resolution=None,
mime_type=None,
type=None,
subtype=None,
file_extension=None,
abr=None,
bitrate=None,
video_codec=None,
audio_codec=None,
only_audio=None,
only_video=None,
progressive=None,
adaptive=None,
is_dash=None,
custom_filter_functions=None,
):
"""Apply the given filtering criterion.
:param fps:
(optional) The frames per second.
:type fps:
int or None
:param resolution:
(optional) Alias to ``res``.
:type res:
str or None
:param res:
(optional) The video resolution.
:type resolution:
str or None
:param mime_type:
(optional) Two-part identifier for file formats and format contents
composed of a "type", a "subtype".
:type mime_type:
str or None
:param type:
(optional) Type part of the ``mime_type`` (e.g.: audio, video).
:type type:
str or None
:param subtype:
(optional) Sub-type part of the ``mime_type`` (e.g.: mp4, mov).
:type subtype:
str or None
:param file_extension:
(optional) Alias to ``sub_type``.
:type file_extension:
str or None
:param abr:
(optional) Average bitrate (ABR) refers to the average amount of
data transferred per unit of time (e.g.: 64kbps, 192kbps).
:type abr:
str or None
:param bitrate:
(optional) Alias to ``abr``.
:type bitrate:
str or None
:param video_codec:
(optional) Video compression format.
:type video_codec:
str or None
:param audio_codec:
(optional) Audio compression format.
:type audio_codec:
str or None
:param bool progressive:
Excludes adaptive streams (one file contains both audio and video
tracks).
:param bool adaptive:
Excludes progressive streams (audio and video are on separate
tracks).
:param bool is_dash:
Include/exclude dash streams.
:param bool only_audio:
Excludes streams with video tracks.
:param bool only_video:
Excludes streams with audio tracks.
:param custom_filter_functions:
(optional) Interface for defining complex filters without
subclassing.
:type custom_filter_functions:
list or None
"""
filters = []
if res or resolution:
if isinstance(res, str) or isinstance(resolution, str):
filters.append(lambda s: s.resolution == (res or resolution))
elif isinstance(res, list) or isinstance(resolution, list):
filters.append(lambda s: s.resolution in (res or resolution))
if fps:
filters.append(lambda s: s.fps == fps)
if mime_type:
filters.append(lambda s: s.mime_type == mime_type)
if type:
filters.append(lambda s: s.type == type)
if subtype or file_extension:
filters.append(lambda s: s.subtype == (subtype or file_extension))
if abr or bitrate:
filters.append(lambda s: s.abr == (abr or bitrate))
if video_codec:
filters.append(lambda s: s.video_codec == video_codec)
if audio_codec:
filters.append(lambda s: s.audio_codec == audio_codec)
if only_audio:
filters.append(
lambda s: (
s.includes_audio_track and not s.includes_video_track
),
)
if only_video:
filters.append(
lambda s: (
s.includes_video_track and not s.includes_audio_track
),
)
if progressive:
filters.append(lambda s: s.is_progressive)
if adaptive:
filters.append(lambda s: s.is_adaptive)
if custom_filter_functions:
filters.extend(custom_filter_functions)
if is_dash is not None:
filters.append(lambda s: s.is_dash == is_dash)
return self._filter(filters)
def _filter(self, filters: List[Callable]) -> "StreamQuery":
fmt_streams = self.fmt_streams
for filter_lambda in filters:
fmt_streams = filter(filter_lambda, fmt_streams)
return StreamQuery(list(fmt_streams))
def order_by(self, attribute_name: str) -> "StreamQuery":
"""Apply a sort order. Filters out stream the do not have the attribute.
:param str attribute_name:
The name of the attribute to sort by.
"""
has_attribute = [
s
for s in self.fmt_streams
if getattr(s, attribute_name) is not None
]
# Check that the attributes have string values.
if has_attribute and isinstance(
getattr(has_attribute[0], attribute_name), str
):
# Try to return a StreamQuery sorted by the integer representations
# of the values.
try:
return StreamQuery(
sorted(
has_attribute,
key=lambda s: int(
"".join(
filter(str.isdigit, getattr(s, attribute_name))
)
), # type: ignore # noqa: E501
)
)
except ValueError:
pass
return StreamQuery(
sorted(has_attribute, key=lambda s: getattr(s, attribute_name))
)
def desc(self) -> "StreamQuery":
"""Sort streams in descending order.
:rtype: :class:`StreamQuery <StreamQuery>`
"""
return StreamQuery(self.fmt_streams[::-1])
def asc(self) -> "StreamQuery":
"""Sort streams in ascending order.
:rtype: :class:`StreamQuery <StreamQuery>`
"""
return self
def get_by_itag(self, itag: int) -> Optional[Stream]:
"""Get the corresponding :class:`Stream <Stream>` for a given itag.
:param int itag:
YouTube format identifier code.
:rtype: :class:`Stream <Stream>` or None
:returns:
The :class:`Stream <Stream>` matching the given itag or None if
not found.
"""
return self.itag_index.get(int(itag))
def get_by_resolution(self, resolution: str) -> Optional[Stream]:
"""Get the corresponding :class:`Stream <Stream>` for a given resolution.
Stream must be a progressive mp4.
:param str resolution:
Video resolution i.e. "720p", "480p", "360p", "240p", "144p"
:rtype: :class:`Stream <Stream>` or None
:returns:
The :class:`Stream <Stream>` matching the given itag or None if
not found.
"""
return self.filter(
progressive=True, subtype="mp4", resolution=resolution
).first()
def get_lowest_resolution(self) -> Optional[Stream]:
"""Get lowest resolution stream that is a progressive mp4.
:rtype: :class:`Stream <Stream>` or None
:returns:
The :class:`Stream <Stream>` matching the given itag or None if
not found.
"""
return (
self.filter(progressive=True, subtype="mp4")
.order_by("resolution")
.first()
)
def get_highest_resolution(self) -> Optional[Stream]:
"""Get highest resolution stream that is a progressive video.
:rtype: :class:`Stream <Stream>` or None
:returns:
The :class:`Stream <Stream>` matching the given itag or None if
not found.
"""
return self.filter(progressive=True).order_by("resolution").last()
def get_audio_only(self, subtype: str = "mp4") -> Optional[Stream]:
"""Get highest bitrate audio stream for given codec (defaults to mp4)
:param str subtype:
Audio subtype, defaults to mp4
:rtype: :class:`Stream <Stream>` or None
:returns:
The :class:`Stream <Stream>` matching the given itag or None if
not found.
"""
return (
self.filter(only_audio=True, subtype=subtype)
.order_by("abr")
.last()
)
def otf(self, is_otf: bool = False) -> "StreamQuery":
"""Filter stream by OTF, useful if some streams have 404 URLs
:param bool is_otf: Set to False to retrieve only non-OTF streams
:rtype: :class:`StreamQuery <StreamQuery>`
:returns: A StreamQuery object with otf filtered streams
"""
return self._filter([lambda s: s.is_otf == is_otf])
def first(self) -> Optional[Stream]:
"""Get the first :class:`Stream <Stream>` in the results.
:rtype: :class:`Stream <Stream>` or None
:returns:
the first result of this query or None if the result doesn't
contain any streams.
"""
try:
return self.fmt_streams[0]
except IndexError:
return None
def last(self):
"""Get the last :class:`Stream <Stream>` in the results.
:rtype: :class:`Stream <Stream>` or None
:returns:
Return the last result of this query or None if the result
doesn't contain any streams.
"""
try:
return self.fmt_streams[-1]
except IndexError:
pass
@deprecated("Get the size of this list directly using len()")
def count(self, value: Optional[str] = None) -> int: # pragma: no cover
"""Get the count of items in the list.
:rtype: int
"""
if value:
return self.fmt_streams.count(value)
return len(self)
@deprecated("This object can be treated as a list, all() is useless")
def all(self) -> List[Stream]: # pragma: no cover
"""Get all the results represented by this query as a list.
:rtype: list
"""
return self.fmt_streams
def __getitem__(self, i: Union[slice, int]):
return self.fmt_streams[i]
def __len__(self) -> int:
return len(self.fmt_streams)
def __repr__(self) -> str:
return f"{self.fmt_streams}"
class CaptionQuery(Mapping):
"""Interface for querying the available captions."""
def __init__(self, captions: List[Caption]):
"""Construct a :class:`Caption <Caption>`.
param list captions:
list of :class:`Caption <Caption>` instances.
"""
self.lang_code_index = {c.code: c for c in captions}
@deprecated(
"This object can be treated as a dictionary, i.e. captions['en']"
)
def get_by_language_code(
self, lang_code: str
) -> Optional[Caption]: # pragma: no cover
"""Get the :class:`Caption <Caption>` for a given ``lang_code``.
:param str lang_code:
The code that identifies the caption language.
:rtype: :class:`Caption <Caption>` or None
:returns:
The :class:`Caption <Caption>` matching the given ``lang_code`` or
None if it does not exist.
"""
return self.lang_code_index.get(lang_code)
@deprecated("This object can be treated as a dictionary")
def all(self) -> List[Caption]: # pragma: no cover
"""Get all the results represented by this query as a list.
:rtype: list
"""
return list(self.lang_code_index.values())
def __getitem__(self, i: str):
return self.lang_code_index[i]
def __len__(self) -> int:
return len(self.lang_code_index)
def __iter__(self):
return iter(self.lang_code_index.values())
def __repr__(self) -> str:
return f"{self.lang_code_index}"

View File

@@ -1,269 +0,0 @@
"""Implements a simple wrapper around urlopen."""
import http.client
import json
import logging
import re
import socket
from functools import lru_cache
from urllib import parse
from urllib.error import URLError
from urllib.request import Request, urlopen
from pytube.exceptions import RegexMatchError, MaxRetriesExceeded
from pytube.helpers import regex_search
logger = logging.getLogger(__name__)
default_range_size = 9437184 # 9MB
def _execute_request(
url,
method=None,
headers=None,
data=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT
):
base_headers = {"User-Agent": "Mozilla/5.0", "accept-language": "en-US,en"}
if headers:
base_headers.update(headers)
if data:
# encode data for request
if not isinstance(data, bytes):
data = bytes(json.dumps(data), encoding="utf-8")
if url.lower().startswith("http"):
request = Request(url, headers=base_headers, method=method, data=data)
else:
raise ValueError("Invalid URL")
return urlopen(request, timeout=timeout) # nosec
def get(url, extra_headers=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
"""Send an http GET request.
:param str url:
The URL to perform the GET request for.
:param dict extra_headers:
Extra headers to add to the request
:rtype: str
:returns:
UTF-8 encoded string of response
"""
if extra_headers is None:
extra_headers = {}
response = _execute_request(url, headers=extra_headers, timeout=timeout)
return response.read().decode("utf-8")
def post(url, extra_headers=None, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
"""Send an http POST request.
:param str url:
The URL to perform the POST request for.
:param dict extra_headers:
Extra headers to add to the request
:param dict data:
The data to send on the POST request
:rtype: str
:returns:
UTF-8 encoded string of response
"""
# could technically be implemented in get,
# but to avoid confusion implemented like this
if extra_headers is None:
extra_headers = {}
if data is None:
data = {}
# required because the youtube servers are strict on content type
# raises HTTPError [400]: Bad Request otherwise
extra_headers.update({"Content-Type": "application/json"})
response = _execute_request(
url,
headers=extra_headers,
data=data,
timeout=timeout
)
return response.read().decode("utf-8")
def seq_stream(
url,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
max_retries=0
):
"""Read the response in sequence.
:param str url: The URL to perform the GET request for.
:rtype: Iterable[bytes]
"""
# YouTube expects a request sequence number as part of the parameters.
split_url = parse.urlsplit(url)
base_url = '%s://%s/%s?' % (split_url.scheme, split_url.netloc, split_url.path)
querys = dict(parse.parse_qsl(split_url.query))
# The 0th sequential request provides the file headers, which tell us
# information about how the file is segmented.
querys['sq'] = 0
url = base_url + parse.urlencode(querys)
segment_data = b''
for chunk in stream(url, timeout=timeout, max_retries=max_retries):
yield chunk
segment_data += chunk
# We can then parse the header to find the number of segments
stream_info = segment_data.split(b'\r\n')
segment_count_pattern = re.compile(b'Segment-Count: (\\d+)')
for line in stream_info:
match = segment_count_pattern.search(line)
if match:
segment_count = int(match.group(1).decode('utf-8'))
# We request these segments sequentially to build the file.
seq_num = 1
while seq_num <= segment_count:
# Create sequential request URL
querys['sq'] = seq_num
url = base_url + parse.urlencode(querys)
yield from stream(url, timeout=timeout, max_retries=max_retries)
seq_num += 1
return # pylint: disable=R1711
def stream(
url,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
max_retries=0
):
"""Read the response in chunks.
:param str url: The URL to perform the GET request for.
:rtype: Iterable[bytes]
"""
file_size: int = default_range_size # fake filesize to start
downloaded = 0
while downloaded < file_size:
stop_pos = min(downloaded + default_range_size, file_size) - 1
range_header = f"bytes={downloaded}-{stop_pos}"
tries = 0
# Attempt to make the request multiple times as necessary.
while True:
# If the max retries is exceeded, raise an exception
if tries >= 1 + max_retries:
raise MaxRetriesExceeded()
# Try to execute the request, ignoring socket timeouts
try:
response = _execute_request(
url + f"&range={downloaded}-{stop_pos}",
method="GET",
timeout=timeout
)
except URLError as e:
# We only want to skip over timeout errors, and
# raise any other URLError exceptions
if isinstance(e.reason, socket.timeout):
pass
else:
raise
except http.client.IncompleteRead:
# Allow retries on IncompleteRead errors for unreliable connections
pass
else:
# On a successful request, break from loop
break
tries += 1
if file_size == default_range_size:
try:
resp = _execute_request(
url + f"&range={0}-{99999999999}",
method="GET",
timeout=timeout
)
content_range = resp.info()["Content-Length"]
file_size = int(content_range)
except (KeyError, IndexError, ValueError) as e:
logger.error(e)
while True:
chunk = response.read()
if not chunk:
break
downloaded += len(chunk)
yield chunk
return # pylint: disable=R1711
@lru_cache()
def filesize(url):
"""Fetch size in bytes of file at given URL
:param str url: The URL to get the size of
:returns: int: size in bytes of remote file
"""
return int(head(url)["content-length"])
@lru_cache()
def seq_filesize(url):
"""Fetch size in bytes of file at given URL from sequential requests
:param str url: The URL to get the size of
:returns: int: size in bytes of remote file
"""
total_filesize = 0
# YouTube expects a request sequence number as part of the parameters.
split_url = parse.urlsplit(url)
base_url = '%s://%s/%s?' % (split_url.scheme, split_url.netloc, split_url.path)
querys = dict(parse.parse_qsl(split_url.query))
# The 0th sequential request provides the file headers, which tell us
# information about how the file is segmented.
querys['sq'] = 0
url = base_url + parse.urlencode(querys)
response = _execute_request(
url, method="GET"
)
response_value = response.read()
# The file header must be added to the total filesize
total_filesize += len(response_value)
# We can then parse the header to find the number of segments
segment_count = 0
stream_info = response_value.split(b'\r\n')
segment_regex = b'Segment-Count: (\\d+)'
for line in stream_info:
# One of the lines should contain the segment count, but we don't know
# which, so we need to iterate through the lines to find it
try:
segment_count = int(regex_search(segment_regex, line, 1))
except RegexMatchError:
pass
if segment_count == 0:
raise RegexMatchError('seq_filesize', segment_regex)
# We make HEAD requests to the segments sequentially to find the total filesize.
seq_num = 1
while seq_num <= segment_count:
# Create sequential request URL
querys['sq'] = seq_num
url = base_url + parse.urlencode(querys)
total_filesize += int(head(url)['content-length'])
seq_num += 1
return total_filesize
def head(url):
"""Fetch headers returned http GET request.
:param str url:
The URL to perform the GET request for.
:rtype: dict
:returns:
dictionary of lowercase headers
"""
response_headers = _execute_request(url, method="HEAD").info()
return {k.lower(): v for k, v in response_headers.items()}

View File

@@ -1,436 +0,0 @@
"""
This module contains a container for stream manifest data.
A container object for the media stream (video only / audio only / video+audio
combined). This was referred to as ``Video`` in the legacy pytube version, but
has been renamed to accommodate DASH (which serves the audio and video
separately).
"""
import logging
import os
from math import ceil
from datetime import datetime
from typing import BinaryIO, Dict, Optional, Tuple
from urllib.error import HTTPError
from urllib.parse import parse_qs
from pytube import extract, request
from pytube.helpers import safe_filename, target_directory
from pytube.itags import get_format_profile
from pytube.monostate import Monostate
logger = logging.getLogger(__name__)
class Stream:
"""Container for stream manifest data."""
def __init__(
self, stream: Dict, monostate: Monostate
):
"""Construct a :class:`Stream <Stream>`.
:param dict stream:
The unscrambled data extracted from YouTube.
:param dict monostate:
Dictionary of data shared across all instances of
:class:`Stream <Stream>`.
"""
# A dictionary shared between all instances of :class:`Stream <Stream>`
# (Borg pattern).
self._monostate = monostate
self.url = stream["url"] # signed download url
self.itag = int(
stream["itag"]
) # stream format id (youtube nomenclature)
# set type and codec info
# 'video/webm; codecs="vp8, vorbis"' -> 'video/webm', ['vp8', 'vorbis']
self.mime_type, self.codecs = extract.mime_type_codec(stream["mimeType"])
# 'video/webm' -> 'video', 'webm'
self.type, self.subtype = self.mime_type.split("/")
# ['vp8', 'vorbis'] -> video_codec: vp8, audio_codec: vorbis. DASH
# streams return NoneType for audio/video depending.
self.video_codec, self.audio_codec = self.parse_codecs()
self.is_otf: bool = stream["is_otf"]
self.bitrate: Optional[int] = stream["bitrate"]
# filesize in bytes
self._filesize: Optional[int] = int(stream.get('contentLength', 0))
# filesize in kilobytes
self._filesize_kb: Optional[float] = float(ceil(float(stream.get('contentLength', 0)) / 1024 * 1000) / 1000)
# filesize in megabytes
self._filesize_mb: Optional[float] = float(ceil(float(stream.get('contentLength', 0)) / 1024 / 1024 * 1000) / 1000)
# filesize in gigabytes(fingers crossed we don't need terabytes going forward though)
self._filesize_gb: Optional[float] = float(ceil(float(stream.get('contentLength', 0)) / 1024 / 1024 / 1024 * 1000) / 1000)
# Additional information about the stream format, such as resolution,
# frame rate, and whether the stream is live (HLS) or 3D.
itag_profile = get_format_profile(self.itag)
self.is_dash = itag_profile["is_dash"]
self.abr = itag_profile["abr"] # average bitrate (audio streams only)
if 'fps' in stream:
self.fps = stream['fps'] # Video streams only
self.resolution = itag_profile[
"resolution"
] # resolution (e.g.: "480p")
self.is_3d = itag_profile["is_3d"]
self.is_hdr = itag_profile["is_hdr"]
self.is_live = itag_profile["is_live"]
@property
def is_adaptive(self) -> bool:
"""Whether the stream is DASH.
:rtype: bool
"""
# if codecs has two elements (e.g.: ['vp8', 'vorbis']): 2 % 2 = 0
# if codecs has one element (e.g.: ['vp8']) 1 % 2 = 1
return bool(len(self.codecs) % 2)
@property
def is_progressive(self) -> bool:
"""Whether the stream is progressive.
:rtype: bool
"""
return not self.is_adaptive
@property
def includes_audio_track(self) -> bool:
"""Whether the stream only contains audio.
:rtype: bool
"""
return self.is_progressive or self.type == "audio"
@property
def includes_video_track(self) -> bool:
"""Whether the stream only contains video.
:rtype: bool
"""
return self.is_progressive or self.type == "video"
def parse_codecs(self) -> Tuple[Optional[str], Optional[str]]:
"""Get the video/audio codecs from list of codecs.
Parse a variable length sized list of codecs and returns a
constant two element tuple, with the video codec as the first element
and audio as the second. Returns None if one is not available
(adaptive only).
:rtype: tuple
:returns:
A two element tuple with audio and video codecs.
"""
video = None
audio = None
if not self.is_adaptive:
video, audio = self.codecs
elif self.includes_video_track:
video = self.codecs[0]
elif self.includes_audio_track:
audio = self.codecs[0]
return video, audio
@property
def filesize(self) -> int:
"""File size of the media stream in bytes.
:rtype: int
:returns:
Filesize (in bytes) of the stream.
"""
if self._filesize == 0:
try:
self._filesize = request.filesize(self.url)
except HTTPError as e:
if e.code != 404:
raise
self._filesize = request.seq_filesize(self.url)
return self._filesize
@property
def filesize_kb(self) -> float:
"""File size of the media stream in kilobytes.
:rtype: float
:returns:
Rounded filesize (in kilobytes) of the stream.
"""
if self._filesize_kb == 0:
try:
self._filesize_kb = float(ceil(request.filesize(self.url)/1024 * 1000) / 1000)
except HTTPError as e:
if e.code != 404:
raise
self._filesize_kb = float(ceil(request.seq_filesize(self.url)/1024 * 1000) / 1000)
return self._filesize_kb
@property
def filesize_mb(self) -> float:
"""File size of the media stream in megabytes.
:rtype: float
:returns:
Rounded filesize (in megabytes) of the stream.
"""
if self._filesize_mb == 0:
try:
self._filesize_mb = float(ceil(request.filesize(self.url)/1024/1024 * 1000) / 1000)
except HTTPError as e:
if e.code != 404:
raise
self._filesize_mb = float(ceil(request.seq_filesize(self.url)/1024/1024 * 1000) / 1000)
return self._filesize_mb
@property
def filesize_gb(self) -> float:
"""File size of the media stream in gigabytes.
:rtype: float
:returns:
Rounded filesize (in gigabytes) of the stream.
"""
if self._filesize_gb == 0:
try:
self._filesize_gb = float(ceil(request.filesize(self.url)/1024/1024/1024 * 1000) / 1000)
except HTTPError as e:
if e.code != 404:
raise
self._filesize_gb = float(ceil(request.seq_filesize(self.url)/1024/1024/1024 * 1000) / 1000)
return self._filesize_gb
@property
def title(self) -> str:
"""Get title of video
:rtype: str
:returns:
Youtube video title
"""
return self._monostate.title or "Unknown YouTube Video Title"
@property
def filesize_approx(self) -> int:
"""Get approximate filesize of the video
Falls back to HTTP call if there is not sufficient information to approximate
:rtype: int
:returns: size of video in bytes
"""
if self._monostate.duration and self.bitrate:
bits_in_byte = 8
return int(
(self._monostate.duration * self.bitrate) / bits_in_byte
)
return self.filesize
@property
def expiration(self) -> datetime:
expire = parse_qs(self.url.split("?")[1])["expire"][0]
return datetime.utcfromtimestamp(int(expire))
@property
def default_filename(self) -> str:
"""Generate filename based on the video title.
:rtype: str
:returns:
An os file system compatible filename.
"""
filename = safe_filename(self.title)
return f"{filename}.{self.subtype}"
def download(
self,
output_path: Optional[str] = None,
filename: Optional[str] = None,
filename_prefix: Optional[str] = None,
skip_existing: bool = True,
timeout: Optional[int] = None,
max_retries: Optional[int] = 0
) -> str:
"""Write the media stream to disk.
:param output_path:
(optional) Output path for writing media file. If one is not
specified, defaults to the current working directory.
:type output_path: str or None
:param filename:
(optional) Output filename (stem only) for writing media file.
If one is not specified, the default filename is used.
:type filename: str or None
:param filename_prefix:
(optional) A string that will be prepended to the filename.
For example a number in a playlist or the name of a series.
If one is not specified, nothing will be prepended
This is separate from filename so you can use the default
filename but still add a prefix.
:type filename_prefix: str or None
:param skip_existing:
(optional) Skip existing files, defaults to True
:type skip_existing: bool
:param timeout:
(optional) Request timeout length in seconds. Uses system default.
:type timeout: int
:param max_retries:
(optional) Number of retries to attempt after socket timeout. Defaults to 0.
:type max_retries: int
:returns:
Path to the saved video
:rtype: str
"""
file_path = self.get_file_path(
filename=filename,
output_path=output_path,
filename_prefix=filename_prefix,
)
if skip_existing and self.exists_at_path(file_path):
logger.debug(f'file {file_path} already exists, skipping')
self.on_complete(file_path)
return file_path
bytes_remaining = self.filesize
logger.debug(f'downloading ({self.filesize} total bytes) file to {file_path}')
with open(file_path, "wb") as fh:
try:
for chunk in request.stream(
self.url,
timeout=timeout,
max_retries=max_retries
):
# reduce the (bytes) remainder by the length of the chunk.
bytes_remaining -= len(chunk)
# send to the on_progress callback.
self.on_progress(chunk, fh, bytes_remaining)
except HTTPError as e:
if e.code != 404:
raise
# Some adaptive streams need to be requested with sequence numbers
for chunk in request.seq_stream(
self.url,
timeout=timeout,
max_retries=max_retries
):
# reduce the (bytes) remainder by the length of the chunk.
bytes_remaining -= len(chunk)
# send to the on_progress callback.
self.on_progress(chunk, fh, bytes_remaining)
self.on_complete(file_path)
return file_path
def get_file_path(
self,
filename: Optional[str] = None,
output_path: Optional[str] = None,
filename_prefix: Optional[str] = None,
) -> str:
if not filename:
filename = self.default_filename
if filename_prefix:
filename = f"{filename_prefix}{filename}"
return os.path.join(target_directory(output_path), filename)
def exists_at_path(self, file_path: str) -> bool:
return (
os.path.isfile(file_path)
and os.path.getsize(file_path) == self.filesize
)
def stream_to_buffer(self, buffer: BinaryIO) -> None:
"""Write the media stream to buffer
:rtype: io.BytesIO buffer
"""
bytes_remaining = self.filesize
logger.info(
"downloading (%s total bytes) file to buffer", self.filesize,
)
for chunk in request.stream(self.url):
# reduce the (bytes) remainder by the length of the chunk.
bytes_remaining -= len(chunk)
# send to the on_progress callback.
self.on_progress(chunk, buffer, bytes_remaining)
self.on_complete(None)
def on_progress(
self, chunk: bytes, file_handler: BinaryIO, bytes_remaining: int
):
"""On progress callback function.
This function writes the binary data to the file, then checks if an
additional callback is defined in the monostate. This is exposed to
allow things like displaying a progress bar.
:param bytes chunk:
Segment of media file binary data, not yet written to disk.
:param file_handler:
The file handle where the media is being written to.
:type file_handler:
:py:class:`io.BufferedWriter`
:param int bytes_remaining:
The delta between the total file size in bytes and amount already
downloaded.
:rtype: None
"""
file_handler.write(chunk)
logger.debug("download remaining: %s", bytes_remaining)
if self._monostate.on_progress:
self._monostate.on_progress(self, chunk, bytes_remaining)
def on_complete(self, file_path: Optional[str]):
"""On download complete handler function.
:param file_path:
The file handle where the media is being written to.
:type file_path: str
:rtype: None
"""
logger.debug("download finished")
on_complete = self._monostate.on_complete
if on_complete:
logger.debug("calling on_complete callback %s", on_complete)
on_complete(self, file_path)
def __repr__(self) -> str:
"""Printable object representation.
:rtype: str
:returns:
A string representation of a :class:`Stream <Stream>` object.
"""
parts = ['itag="{s.itag}"', 'mime_type="{s.mime_type}"']
if self.includes_video_track:
parts.extend(['res="{s.resolution}"', 'fps="{s.fps}fps"'])
if not self.is_adaptive:
parts.extend(
['vcodec="{s.video_codec}"', 'acodec="{s.audio_codec}"',]
)
else:
parts.extend(['vcodec="{s.video_codec}"'])
else:
parts.extend(['abr="{s.abr}"', 'acodec="{s.audio_codec}"'])
parts.extend(['progressive="{s.is_progressive}"', 'type="{s.type}"'])
return f"<Stream: {' '.join(parts).format(s=self)}>"

View File

@@ -1,4 +0,0 @@
__version__ = "15.0.0"
if __name__ == "__main__":
print(__version__)

View File

@@ -9,7 +9,7 @@ cors~=1.0.1
Flask~=3.0.3 Flask~=3.0.3
Flask-BabelEx~=0.9.4 Flask-BabelEx~=0.9.4
Flask-Bootstrap~=3.3.7.1 Flask-Bootstrap~=3.3.7.1
Flask-Cors~=4.0.1 Flask-Cors~=5.0.0
Flask-JWT-Extended~=4.6.0 Flask-JWT-Extended~=4.6.0
Flask-Login~=0.6.3 Flask-Login~=0.6.3
flask-mailman~=1.1.1 flask-mailman~=1.1.1
@@ -43,7 +43,6 @@ pgvector~=0.2.5
pycryptodome~=3.20.0 pycryptodome~=3.20.0
pydantic~=2.7.4 pydantic~=2.7.4
PyJWT~=2.8.0 PyJWT~=2.8.0
pypdf~=4.2.0
PySocks~=1.7.1 PySocks~=1.7.1
python-dateutil~=2.9.0.post0 python-dateutil~=2.9.0.post0
python-engineio~=4.9.1 python-engineio~=4.9.1
@@ -61,16 +60,18 @@ urllib3~=2.2.2
WTForms~=3.1.2 WTForms~=3.1.2
wtforms-html5~=0.6.1 wtforms-html5~=0.6.1
zxcvbn~=4.4.28 zxcvbn~=4.4.28
pytube~=15.0.0
PyPDF2~=3.0.1
groq~=0.9.0 groq~=0.9.0
pydub~=0.25.1 pydub~=0.25.1
argparse~=1.4.0 argparse~=1.4.0
portkey_ai~=1.7.0 portkey_ai~=1.8.2
minio~=7.2.7 minio~=7.2.7
Werkzeug~=3.0.3 Werkzeug~=3.0.3
itsdangerous~=2.2.0 itsdangerous~=2.2.0
cryptography~=43.0.0 cryptography~=43.0.0
graypy~=2.1.0 graypy~=2.1.0
lxml~=5.3.0
pillow~=10.4.0
pdfplumber~=0.11.4
PyPDF2~=3.0.1
flask-restx~=1.3.0

6
scripts/run_eveai_api.py Normal file
View File

@@ -0,0 +1,6 @@
from eveai_api import create_app
app = create_app()
if __name__ == '__main__':
app.run(debug=True)

15
scripts/start_eveai_api.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
cd "/app" || exit 1
export PROJECT_DIR="/app"
export PYTHONPATH="$PROJECT_DIR/patched_packages:$PYTHONPATH:$PROJECT_DIR" # Include the app directory in the Python path & patched packages
# Set FLASK_APP environment variables
export FLASK_APP=${PROJECT_DIR}/scripts/run_eveai_app.py # Adjust the path to your Flask app entry point
# Ensure we can write the logs
chown -R appuser:appuser /app/logs
# Start Flask app
gunicorn -w 4 -k gevent -b 0.0.0.0:5003 scripts.run_eveai_api:app