diff --git a/common/models/user.py b/common/models/user.py
index 9eb5ffa..9b2ac43 100644
--- a/common/models/user.py
+++ b/common/models/user.py
@@ -2,6 +2,7 @@ from common.extensions import db
from flask_security import UserMixin, RoleMixin
from sqlalchemy.dialects.postgresql import ARRAY
import sqlalchemy as sa
+from sqlalchemy import CheckConstraint
class Tenant(db.Model):
@@ -33,6 +34,14 @@ class Tenant(db.Model):
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
+ # Embedding search variables
+ es_k = db.Column(db.Integer, nullable=True, default=5)
+ es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7)
+
+ # Chat variables
+ chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
+ chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
+
# Licensing Information
license_start_date = db.Column(db.Date, nullable=True)
license_end_date = db.Column(db.Date, nullable=True)
diff --git a/common/utils/celery_utils.py b/common/utils/celery_utils.py
index bd3e9dc..8224fbb 100644
--- a/common/utils/celery_utils.py
+++ b/common/utils/celery_utils.py
@@ -7,6 +7,8 @@ celery_app = Celery()
def init_celery(celery, app):
celery_app.main = app.name
+ app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
+ app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
celery_config = {
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
'result_backend': app.config.get('CELERY_RESULT_BACKEND', 'redis://localhost:6379/0'),
diff --git a/config/config.py b/config/config.py
index 7fff6f5..4db4cad 100644
--- a/config/config.py
+++ b/config/config.py
@@ -1,5 +1,6 @@
from os import environ, path
from datetime import timedelta
+import redis
basedir = path.abspath(path.dirname(__file__))
@@ -42,7 +43,7 @@ class Config(object):
# supported LLMs
SUPPORTED_EMBEDDINGS = ['openai.text-embedding-3-small', 'mistral.mistral-embed']
- SUPPORTED_LLMS = ['openai.gpt-4-turbo', 'openai.gpt-3.5-turbo', 'mistral.mistral-large-2402']
+ SUPPORTED_LLMS = ['openai.gpt-4o', 'openai.gpt-4-turbo', 'openai.gpt-3.5-turbo', 'mistral.mistral-large-2402']
# Celery settings
CELERY_TASK_SERIALIZER = 'json'
@@ -62,6 +63,20 @@ class Config(object):
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text.
Text is delimited between triple backquotes.
```{text}```"""
+ GPT4_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes
+ in the same language as question.
+ If the question cannot be answered using the text, say "I don't know" in the same language as the question.
+ Context:
+ ```{context}```
+ Question:
+ ```{question}```"""
+ GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes
+ in the same language as question.
+ If the question cannot be answered using the text, say "I don't know" in the same language as the question.
+ Context:
+ ```{context}```
+ Question:
+ ```{question}```"""
# SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading'
@@ -91,8 +106,13 @@ class DevConfig(Config):
UPLOAD_FOLDER = '/Volumes/OWC4M2_1/Development/eveAI/file_store'
# Celery settings
- CELERY_BROKER_URL = 'redis://localhost:6379/0' # Default Redis configuration
+ # eveai_app Redis Settings
+ CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
+ # eveai_chat Redis Settings
+ CELERY_BROKER_URL_CHAT = 'redis://localhost:6379/3'
+ CELERY_RESULT_BACKEND_CHAT = 'redis://localhost:6379/3'
+
# OpenAI API Keys
OPENAI_API_KEY = 'sk-proj-8R0jWzwjL7PeoPyMhJTZT3BlbkFJLb6HfRB2Hr9cEVFWEhU7'
@@ -118,12 +138,8 @@ class DevConfig(Config):
JWT_SECRET_KEY = 'bsdMkmQ8ObfMD52yAFg4trrvjgjMhuIqg2fjDpD/JqvgY0ccCcmlsEnVFmR79WPiLKEA3i8a5zmejwLZKl4v9Q=='
# Session settings
- SESSION_REDIS = {
- 'host': 'localhost', # Redis server hostname or IP address
- 'port': 6379, # Redis server port
- 'db': 2, # Redis database number (optional)
- 'password': None # Redis password (optional)
- }
+ SESSION_REDIS = redis.from_url('redis://localhost:6379/2')
+
class ProdConfig(Config):
DEVELOPMENT = False
diff --git a/config/logging_config.py b/config/logging_config.py
index 9156dfa..0214108 100644
--- a/config/logging_config.py
+++ b/config/logging_config.py
@@ -26,6 +26,14 @@ LOGGING = {
'backupCount': 10,
'formatter': 'standard',
},
+ 'file_chat_workers': {
+ 'level': 'DEBUG',
+ 'class': 'logging.handlers.RotatingFileHandler',
+ 'filename': 'logs/eveai_chat_workers.log',
+ 'maxBytes': 1024*1024*5, # 5MB
+ 'backupCount': 10,
+ 'formatter': 'standard',
+ },
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
@@ -53,6 +61,11 @@ LOGGING = {
'level': 'DEBUG',
'propagate': False
},
+ 'eveai_chat_workers': { # logger for the eveai_chat_workers
+ 'handlers': ['file_chat_workers', 'console'],
+ 'level': 'DEBUG',
+ 'propagate': False
+ },
'': { # root logger
'handlers': ['console'],
'level': 'WARNING', # Set higher level for root to minimize noise
diff --git a/eveai_app/templates/user/tenant_overview.html b/eveai_app/templates/user/tenant_overview.html
index 51f4e36..6fe7380 100644
--- a/eveai_app/templates/user/tenant_overview.html
+++ b/eveai_app/templates/user/tenant_overview.html
@@ -35,6 +35,11 @@
HTML Chunking
+
+
+ Embedding Search
+
+
Domains
@@ -71,6 +76,14 @@
{{ render_included_field(field, disabled_fields=html_fields, include_fields=html_fields) }}
{% endfor %}
+
+
+ {% set es_fields = ['es_k', 'es_similarity_threshold', ] %}
+ {% for field in form %}
+ {{ render_included_field(field, disabled_fields=es_fields, include_fields=es_fields) }}
+ {% endfor %}
+
+
diff --git a/eveai_app/views/user_forms.py b/eveai_app/views/user_forms.py
index 55d2918..e6d226f 100644
--- a/eveai_app/views/user_forms.py
+++ b/eveai_app/views/user_forms.py
@@ -1,7 +1,7 @@
from flask import current_app
from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
- SelectField, SelectMultipleField, FieldList, FormField)
+ SelectField, SelectMultipleField, FieldList, FormField, FloatField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
from common.models.user import Role
@@ -26,6 +26,10 @@ class TenantForm(FlaskForm):
default='p, li')
html_included_elements = StringField('HTML Included Elements', validators=[Optional()])
html_excluded_elements = StringField('HTML Excluded Elements', validators=[Optional()])
+ # Embedding Search variables
+ es_k = IntegerField('Limit for Searching Embeddings (5)', validators=[NumberRange(min=0)])
+ es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
+ validators=[NumberRange(min=0, max=1)])
submit = SubmitField('Submit')
diff --git a/eveai_chat/__init__.py b/eveai_chat/__init__.py
index f39aad7..ca7fc97 100644
--- a/eveai_chat/__init__.py
+++ b/eveai_chat/__init__.py
@@ -7,6 +7,9 @@ from common.extensions import db, socketio, jwt, kms_client, cors, session
from config.logging_config import LOGGING
from eveai_chat.socket_handlers import chat_handler
from common.utils.cors_utils import create_cors_after_request
+from common.utils.celery_utils import make_celery, init_celery
+
+
def create_app(config_file=None):
@@ -20,6 +23,9 @@ def create_app(config_file=None):
logging.config.dictConfig(LOGGING)
register_extensions(app)
+ app.celery = make_celery(app.name, app.config)
+ init_celery(app.celery, app)
+
# Register Blueprints
register_blueprints(app)
@@ -49,13 +55,6 @@ def register_extensions(app):
cors.init_app(app, resources={r"/chat/*": {"origins": "*"}})
app.after_request(create_cors_after_request('/chat'))
- # Session setup
- # redis_config = app.config['SESSION_REDIS']
- # redis_client = Redis(host=redis_config['host'],
- # port=redis_config['port'],
- # db=redis_config['db'],
- # password=redis_config['password']
- # )
session.init_app(app)
diff --git a/eveai_chat/socket_handlers/chat_handler.py b/eveai_chat/socket_handlers/chat_handler.py
index 91b1c7b..b4ac376 100644
--- a/eveai_chat/socket_handlers/chat_handler.py
+++ b/eveai_chat/socket_handlers/chat_handler.py
@@ -4,6 +4,7 @@ from flask import current_app, request
from common.extensions import socketio, kms_client
from common.models.user import Tenant
+from common.utils.celery_utils import current_celery
@socketio.on('connect')
@@ -66,13 +67,17 @@ def handle_message(data):
if not current_api_key:
raise Exception("Missing api_key")
- # Store interaction in the database
-
+ # Offload actual processing of question
+ task = current_celery.send_task('ask_question', queue='llm_interactions', args=[
+ current_tenant_id,
+ data['message'],
+ ])
+ current_app.logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, '
+ f'Question: {task.id}')
response = {
'tenantId': data['tenantId'],
- 'message': f'This is a bot response. Responding to message {data['message']} '
- f'from tenant {current_tenant_id}',
- 'messageId': 'bot-message-id',
+ 'message': 'Processing question ...',
+ 'taskId': task.id,
'algorithm': 'alg1'
}
current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}")
@@ -82,6 +87,30 @@ def handle_message(data):
disconnect()
@socketio.on('check_task_status')
def check_task_status(data):
    """Report the state of a previously offloaded Celery task to the client.

    Expects ``data`` to carry a ``task_id`` key and answers with a single
    ``task_status`` event whose ``status`` is one of:

    * ``'error'``   - no task id was supplied,
    * ``'SUCCESS'`` - the task finished; ``result`` holds its return value,
    * ``'FAILURE'`` - the task failed or was revoked; ``message`` holds the cause,
    * ``'pending'`` - the task has not finished yet (queued, started, retrying).

    Every task-related reply carries ``taskId`` so the client can keep polling.
    """
    task_id = data.get('task_id')
    current_app.logger.debug(f'SocketIO: Check task status for task_id: {task_id}')
    if not task_id:
        emit('task_status', {'status': 'error', 'message': 'Missing task ID'})
        return

    task_result = current_celery.AsyncResult(task_id)
    # Read the state once: every access asks the result backend again, so the
    # value could otherwise change between two comparisons.
    state = task_result.state

    if state == 'SUCCESS':
        result = task_result.result
        current_app.logger.debug(f'SocketIO: Task {task_id} has finished. Status: {state}, '
                                 f'Result: {result}')
        emit('task_status', {'status': state, 'taskId': task_id, 'result': result})
    elif state in ('FAILURE', 'REVOKED'):
        current_app.logger.error(f'SocketIO: Task {task_id} has failed. Error: {task_result.info}')
        # Upper case, because that is the value the chat widget compares against.
        emit('task_status', {'status': 'FAILURE', 'taskId': task_id, 'message': str(task_result.info)})
    else:
        # PENDING, RECEIVED, STARTED, RETRY, ...: not finished, so the client should poll again.
        current_app.logger.debug(f'SocketIO: Task {task_id} is pending')
        emit('task_status', {'status': 'pending', 'taskId': task_id})
+
+
def validate_api_key(tenant_id, api_key):
tenant = Tenant.query.get_or_404(tenant_id)
decrypted_api_key = kms_client.decrypt_api_key(tenant.encrypted_chat_api_key)
diff --git a/eveai_chat_workers/__init__.py b/eveai_chat_workers/__init__.py
new file mode 100644
index 0000000..a213962
--- /dev/null
+++ b/eveai_chat_workers/__init__.py
@@ -0,0 +1,36 @@
+import logging
+import logging.config
+from flask import Flask
+
+from common.utils.celery_utils import make_celery, init_celery
+from common.extensions import db
+from config.logging_config import LOGGING
+
+
def create_app(config_file=None):
    """Create the Flask app and Celery instance used by the chat workers.

    Args:
        config_file: Object or import string handed to ``app.config.from_object``;
            ``'config.config.DevConfig'`` is used when it is ``None``.

    Returns:
        A ``(app, celery)`` tuple.
    """
    app = Flask(__name__)

    config_source = 'config.config.DevConfig' if config_file is None else config_file
    app.config.from_object(config_source)

    logging.config.dictConfig(LOGGING)

    app.logger.debug('Starting up eveai_chat_workers...')
    register_extensions(app)

    celery = make_celery(app.name, app.config)
    init_celery(celery, app)

    # Imported only after Celery has been initialised.
    # NOTE(review): presumably so the task decorators see the configured app - confirm.
    from eveai_chat_workers import tasks
    print(tasks.tasks_ping())

    return app, celery
+
+
def register_extensions(app):
    """Initialise the shared ``db`` extension for ``app``."""
    db.init_app(app)
+
+
# Built at import time: the worker start script and scripts/run_eveai_chat_workers.py
# both refer to ``eveai_chat_workers.celery``.
app, celery = create_app()
diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py
new file mode 100644
index 0000000..6cbd8c4
--- /dev/null
+++ b/eveai_chat_workers/tasks.py
@@ -0,0 +1,158 @@
+from datetime import datetime as dt, timezone as tz
+from flask import current_app
+from sqlalchemy.exc import SQLAlchemyError
+from celery import states
+from celery.exceptions import Ignore
+import os
+
+# Unstructured commercial client imports
+from unstructured_client import UnstructuredClient
+from unstructured_client.models import shared
+from unstructured_client.models.errors import SDKError
+
+# OpenAI imports
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.chains.summarize import load_summarize_chain
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_core.exceptions import LangChainException
+
+from common.utils.database import Database
+from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
+from common.models.user import Tenant
+from common.extensions import db
+from common.utils.celery_utils import current_celery
+
+from bs4 import BeautifulSoup
+
+
@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question):
    """Collect the stored text chunks that are relevant to a tenant's question.

    Args:
        tenant_id: Primary key of the tenant asking the question.
        question: The question text.

    Returns:
        The ``chunk`` text of every matching embedding row, each followed by a
        newline, concatenated in search-result order (empty string if none match).

    Raises:
        Exception: Any error is logged and re-raised so that Celery records the
            task as failed.
    """
    current_app.logger.debug('In ask_question')
    current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')

    try:
        # Retrieve the tenant
        tenant = Tenant.query.get(tenant_id)
        if not tenant:
            raise Exception(f'Tenant {tenant_id} not found.')

        # Ensure we are working in the correct database schema
        Database(tenant_id).switch_schema()

        # Select variables to work with depending on tenant model
        model_variables = select_model_variables(tenant)

        # Create embedding for the query
        embedded_question = create_embedding(model_variables, question)

        # Search the database for relevant embeddings
        relevant_embeddings = search_embeddings(model_variables, embedded_question)

        # Each result row pairs the embedding object with its distance; only
        # the object's chunk text goes into the response.
        return ''.join(embedding.chunk + '\n' for embedding, _distance in relevant_embeddings)
    except Exception as e:
        current_app.logger.error(f'ask_question: Error processing question: {e}')
        # Re-raise instead of raising Ignore: Ignore suppresses the state
        # update, which would leave the task PENDING and the client polling forever.
        raise
+
+
+def select_model_variables(tenant):
+ embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
+ embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
+
+ llm_provider = tenant.llm_model.rsplit('.', 1)[0]
+ llm_model = tenant.llm_model.rsplit('.', 1)[1]
+
+ # Set model variables
+ model_variables = {}
+ if tenant.es_k:
+ model_variables['k'] = tenant.es_k
+ else:
+ model_variables['k'] = 5
+
+ if tenant.es_similarity_threshold:
+ model_variables['similarity_threshold'] = tenant.es_similarity_threshold
+ else:
+ model_variables['similarity_threshold'] = 0.7
+
+ if tenant.chat_RAG_temperature:
+ model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
+ else:
+ model_variables['RAG_temperature'] = 0.3
+
+ if tenant.chat_no_RAG_temperature:
+ model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
+ else:
+ model_variables['no_RAG_temperature'] = 0.5
+
+ # Set Embedding variables
+ match embedding_provider:
+ case 'openai':
+ match embedding_model:
+ case 'text-embedding-3-small':
+ api_key = current_app.config.get('OPENAI_API_KEY')
+ model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
+ model='text-embedding-3-small')
+ model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
+ case _:
+ raise Exception(f'Error setting model variables for tenant {tenant.id} '
+ f'error: Invalid embedding model')
+ case _:
+ raise Exception(f'Error setting model variables for tenant {tenant.id} '
+ f'error: Invalid embedding provider')
+
+ # Set Chat model variables
+ match llm_provider:
+ case 'openai':
+ api_key = current_app.config.get('OPENAI_API_KEY')
+ model_variables['llm'] = ChatOpenAI(api_key=api_key,
+ model=llm_model,
+ temperature=model_variables['RAG_temperature'])
+ match llm_model:
+ case 'gpt-4-turbo' | 'gpt-4-o':
+ rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
+ case 'gpt-3-5-turbo':
+ rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
+ case _:
+ raise Exception(f'Error setting model variables for tenant {tenant.id} '
+ f'error: Invalid chat model')
+ model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
+ case _:
+ raise Exception(f'Error setting model variables for tenant {tenant.id} '
+ f'error: Invalid chat provider')
+
+ return model_variables
+
+
def create_embedding(model_variables, question):
    """Return the embedding vector for ``question``.

    Args:
        model_variables: Settings dict whose ``'embedding'`` entry is the
            embedding client to use.
        question: The question text, embedded as one whole string.

    Raises:
        Exception: If the embedding client raises a ``LangChainException``.
    """
    try:
        # embed_query takes a single string; embed_documents expects a list of
        # texts and would not treat a bare string as one text.
        return model_variables['embedding'].embed_query(question)
    except LangChainException as e:
        raise Exception(f'Error creating embedding for question (LangChain): {e}') from e
+
+
def search_embeddings(model_variables, embedded_query):
    """Find the stored embeddings closest to ``embedded_query``.

    Reads ``embedding_db_model``, ``similarity_threshold`` and ``k`` from
    ``model_variables`` and returns at most ``k`` rows, each pairing an
    embedding object with its cosine distance, ordered by ascending distance.

    Raises:
        Exception: If the query fails with an ``SQLAlchemyError``.
    """
    current_app.logger.debug(f'In search_embeddings searching for {embedded_query}')
    embedding_cls = model_variables['embedding_db_model']
    try:
        distance = embedding_cls.embedding.cosine_distance(embedded_query)
        query = db.session.query(embedding_cls, distance.label('distance'))
        # NOTE(review): the "similarity" threshold is used as an upper bound on
        # cosine *distance* - confirm this is the intended meaning.
        query = query.filter(distance < model_variables['similarity_threshold'])
        query = query.order_by("distance").limit(model_variables['k'])
        matches = query.all()
    except SQLAlchemyError as e:
        raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')

    current_app.logger.debug(f'Results from embedding search: {matches}')
    return matches
+
+
def tasks_ping():
    """Return the fixed string ``'pong'``, showing that this module was imported."""
    return 'pong'
diff --git a/eveai_workers/celery_utils.py b/eveai_workers/celery_utils.py
deleted file mode 100644
index fd436d0..0000000
--- a/eveai_workers/celery_utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from .tasks import create_embeddings
-from celery import Celery, Task
-
-
-def init_celery(app):
- class ContextTask(Task):
- def __call__(self, *args, **kwargs):
- with app.app_context():
- return self.run(*args, **kwargs)
-
- celery_app = Celery(app.import_name, task_cls=ContextTask)
-
- celery_app.conf.broker_url = app.config.get('CELERY_BROKER_URL')
- celery_app.conf.result_backend = app.config.get('CELERY_RESULT_BACKEND')
- celery_app.conf.accept_content = app.config.get('CELERY_ACCEPT_CONTENT')
- celery_app.conf.task_serializer = app.config.get('CELERY_TASK_SERIALIZER')
- celery_app.conf.timezone = app.config.get('CELERY_TIMEZONE')
- celery_app.conf.enable_utc = app.config.get('CELERY_ENABLE_UTC')
-
- celery_app.set_default()
-
- app.extensions['celery'] = celery_app
-
diff --git a/scripts/run_eveai_chat_workers.py b/scripts/run_eveai_chat_workers.py
new file mode 100644
index 0000000..282500e
--- /dev/null
+++ b/scripts/run_eveai_chat_workers.py
@@ -0,0 +1,4 @@
+from eveai_chat_workers import celery
+
# Start the Celery program only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    celery.start()
diff --git a/scripts/start_eveai_chat_workers.sh b/scripts/start_eveai_chat_workers.sh
new file mode 100755
index 0000000..09faa7f
--- /dev/null
+++ b/scripts/start_eveai_chat_workers.sh
@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# Start the Celery worker that serves the 'llm_interactions' queue of eveai_chat_workers.

# Project root; also holds the virtual environment.
PROJECT_DIR="/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI"

cd "${PROJECT_DIR}/" || exit 1
source "${PROJECT_DIR}/.venv/bin/activate"
export PYTHONPATH="$PYTHONPATH:${PROJECT_DIR}/"

# Start a worker for the 'llm_interactions' queue with auto-scaling.
# --autoscale takes max,min: keep 2 processes and grow to 8 under load.
celery -A eveai_chat_workers.celery worker --loglevel=info -Q llm_interactions --autoscale=8,2 --hostname=interactions_worker@%h &

# Wait for all background processes to finish
wait

deactivate
\ No newline at end of file
diff --git a/scripts/start_eveai_workers.sh b/scripts/start_eveai_workers.sh
index 26bbba3..f674943 100755
--- a/scripts/start_eveai_workers.sh
+++ b/scripts/start_eveai_workers.sh
@@ -6,8 +6,8 @@ source "/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI/.v
# Start a worker for the 'embeddings' queue with higher concurrency
celery -A eveai_workers.celery worker --loglevel=info -Q embeddings --autoscale=1,4 --hostname=embeddings_worker@%h &
-# Start a worker for the 'llm_interactions' queue with auto-scaling
-celery -A eveai_workers.celery worker --loglevel=info - Q llm_interactions --autoscale=2,8 --hostname=interactions_worker@%h &
+# Start a worker for the 'llm_interactions' queue with auto-scaling - not necessary, in eveai_chat_workers
+# celery -A eveai_workers.celery worker --loglevel=info - Q llm_interactions --autoscale=2,8 --hostname=interactions_worker@%h &
# Wait for all background processes to finish
wait
diff --git a/static/js/eveai-chat-widget.js b/static/js/eveai-chat-widget.js
index 73ceb7f..5789a82 100644
--- a/static/js/eveai-chat-widget.js
+++ b/static/js/eveai-chat-widget.js
@@ -114,9 +114,30 @@ class EveAIChatWidget extends HTMLElement {
this.socket.on('bot_response', (data) => {
if (data.tenantId === this.tenantId) {
+ console.log('Bot response received:', data)
+ console.log('Task ID received:', data.taskId)
this.addMessage(data.message, 'bot', data.messageId, data.algorithm);
+ this.checkTaskStatus(data.taskId)
}
});
+
// Handle the server's reply to a 'check_task_status' request.
this.socket.on('task_status', (data) => {
    console.log('Task status received:', data);
    console.log('Task ID received:', data.taskId);
    // The server mixes upper-case Celery states ('SUCCESS') with its own
    // lower-case values ('pending', 'failure', 'error'), so compare
    // case-insensitively.
    const status = String(data.status || '').toUpperCase();
    if (status === 'SUCCESS') {
        this.addMessage(data.result, 'bot');
    } else if (status === 'FAILURE' || status === 'ERROR') {
        // Covers task failures as well as rejected status requests.
        this.addMessage('Failed to process message.', 'bot');
    } else if (status === 'PENDING') {
        console.log('Task is pending');
        setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second
        console.log('New check sent');
    }
});
+ }
+
+ checkTaskStatus(taskId) {
+ this.socket.emit('check_task_status', { task_id: taskId });
}
getTemplate() {