Refactoring part 1
Some changes for workers, but stopped due to refactoring
This commit is contained in:
23
eveai_workers/celery_utils.py
Normal file
23
eveai_workers/celery_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .tasks import create_embeddings
|
||||
from celery import Celery, Task
|
||||
|
||||
|
||||
def init_celery(app):
|
||||
class ContextTask(Task):
|
||||
def __call__(self, *args, **kwargs):
|
||||
with app.app_context():
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
celery_app = Celery(app.import_name, task_cls=ContextTask)
|
||||
|
||||
celery_app.conf.broker_url = app.config.get('CELERY_BROKER_URL')
|
||||
celery_app.conf.result_backend = app.config.get('CELERY_RESULT_BACKEND')
|
||||
celery_app.conf.accept_content = app.config.get('CELERY_ACCEPT_CONTENT')
|
||||
celery_app.conf.task_serializer = app.config.get('CELERY_TASK_SERIALIZER')
|
||||
celery_app.conf.timezone = app.config.get('CELERY_TIMEZONE')
|
||||
celery_app.conf.enable_utc = app.config.get('CELERY_ENABLE_UTC')
|
||||
|
||||
celery_app.set_default()
|
||||
|
||||
app.extensions['celery'] = celery_app
|
||||
|
||||
67
eveai_workers/tasks.py
Normal file
67
eveai_workers/tasks.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from flask import current_app
|
||||
from langchain_community.document_loaders.unstructured import UnstructuredAPIFileLoader
|
||||
import os
|
||||
from celery import shared_task
|
||||
|
||||
from common.utils.database import Database
|
||||
from common.models import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
|
||||
from eveai_app import db
|
||||
|
||||
|
||||
@shared_task(name='create_embeddings', queue='embeddings')
|
||||
def create_embeddings(tenant_id, document_version_id, default_embedding_model):
|
||||
current_app.logger.info(f'Creating embeddings for tenant {tenant_id} on document version {document_version_id} '
|
||||
f'with model {default_embedding_model}')
|
||||
|
||||
# Ensure we are working in the correct database schema
|
||||
Database(tenant_id).switch_schema()
|
||||
|
||||
# Retrieve document version to process
|
||||
document_version = DocumentVersion.query.get(document_version_id)
|
||||
if document_version is None:
|
||||
current_app.logger.error(f'Cannot create embeddings for tenant {tenant_id}. '
|
||||
f'Document version {document_version_id} not found')
|
||||
return
|
||||
db.session.add(document_version)
|
||||
|
||||
# start processing
|
||||
document_version.processing = True
|
||||
document_version.processing_started_at = dt.now(tz.utc)
|
||||
db.session.commit()
|
||||
|
||||
embedding_provider = default_embedding_model.rsplit('.', 1)[0]
|
||||
embedding_model = default_embedding_model.rsplit('.', 1)[1]
|
||||
# define embedding variables
|
||||
match (embedding_provider, embedding_model):
|
||||
case ('openai', 'text-embedding-3-small'):
|
||||
embedding_model = EmbeddingSmallOpenAI()
|
||||
case ('mistral', 'text-embedding-3-small'):
|
||||
embedding_model = EmbeddingMistral()
|
||||
|
||||
match document_version.file_type:
|
||||
case 'pdf':
|
||||
url = current_app.config.get('UNSTRUCTURED_FULL_URL')
|
||||
api_key = current_app.config.get('UNSTRUCTURED_API_KEY')
|
||||
file_path = os.path.join(current_app.config['UPLOAD_FOLDER'],
|
||||
document_version.file_location,
|
||||
document_version.file_path)
|
||||
with open(file_path, 'rb') as f:
|
||||
loader = UnstructuredAPIFileLoader(f,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
mode='elements',
|
||||
strategy='hi-res',
|
||||
include_page_breaks=True,
|
||||
unique_element_ids=True,
|
||||
chunking_strategy='by_title',
|
||||
max_characters=3000,
|
||||
)
|
||||
documents = loader.load()
|
||||
print(documents)
|
||||
|
||||
|
||||
@shared_task(name='ask_eve_ai', queue='llm_interactions')
|
||||
def ask_eve_ai(query):
|
||||
# Interaction logic with LLMs like GPT (Langchain API calls, etc.)
|
||||
pass
|
||||
Reference in New Issue
Block a user