Enable model variables & start working on RAG task

This commit is contained in:
Josako
2024-05-25 20:17:02 +02:00
parent e483d6cf90
commit ce91323dc9
15 changed files with 340 additions and 46 deletions

View File

@@ -2,6 +2,7 @@ from common.extensions import db
from flask_security import UserMixin, RoleMixin
from sqlalchemy.dialects.postgresql import ARRAY
import sqlalchemy as sa
from sqlalchemy import CheckConstraint
class Tenant(db.Model):
@@ -33,6 +34,14 @@ class Tenant(db.Model):
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
# Embedding search variables
es_k = db.Column(db.Integer, nullable=True, default=5)
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7)
# Chat variables
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
# Licensing Information
license_start_date = db.Column(db.Date, nullable=True)
license_end_date = db.Column(db.Date, nullable=True)

View File

@@ -7,6 +7,8 @@ celery_app = Celery()
def init_celery(celery, app):
celery_app.main = app.name
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
celery_config = {
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
'result_backend': app.config.get('CELERY_RESULT_BACKEND', 'redis://localhost:6379/0'),

View File

@@ -1,5 +1,6 @@
from os import environ, path
from datetime import timedelta
import redis
basedir = path.abspath(path.dirname(__file__))
@@ -42,7 +43,7 @@ class Config(object):
# supported LLMs
SUPPORTED_EMBEDDINGS = ['openai.text-embedding-3-small', 'mistral.mistral-embed']
SUPPORTED_LLMS = ['openai.gpt-4-turbo', 'openai.gpt-3.5-turbo', 'mistral.mistral-large-2402']
SUPPORTED_LLMS = ['openai.gpt-4o', 'openai.gpt-4-turbo', 'openai.gpt-3.5-turbo', 'mistral.mistral-large-2402']
# Celery settings
CELERY_TASK_SERIALIZER = 'json'
@@ -62,6 +63,20 @@ class Config(object):
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text.
Text is delimited between triple backquotes.
```{text}```"""
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes
in the same language as question.
If the question cannot be answered using the text, say "I don't know" in the same language as the question.
Context:
```{context}```
Question:
```{question}```"""
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes
in the same language as question.
If the question cannot be answered using the text, say "I don't know" in the same language as the question.
Context:
```{context}```
Question:
```{question}```"""
# SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading'
@@ -91,8 +106,13 @@ class DevConfig(Config):
UPLOAD_FOLDER = '/Volumes/OWC4M2_1/Development/eveAI/file_store'
# Celery settings
CELERY_BROKER_URL = 'redis://localhost:6379/0' # Default Redis configuration
# eveai_app Redis Settings
CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
# eveai_chat Redis Settings
CELERY_BROKER_URL_CHAT = 'redis://localhost:6379/3'
CELERY_RESULT_BACKEND_CHAT = 'redis://localhost:6379/3'
# OpenAI API Keys
OPENAI_API_KEY = 'sk-proj-8R0jWzwjL7PeoPyMhJTZT3BlbkFJLb6HfRB2Hr9cEVFWEhU7'
@@ -118,12 +138,8 @@ class DevConfig(Config):
JWT_SECRET_KEY = 'bsdMkmQ8ObfMD52yAFg4trrvjgjMhuIqg2fjDpD/JqvgY0ccCcmlsEnVFmR79WPiLKEA3i8a5zmejwLZKl4v9Q=='
# Session settings
SESSION_REDIS = {
'host': 'localhost', # Redis server hostname or IP address
'port': 6379, # Redis server port
'db': 2, # Redis database number (optional)
'password': None # Redis password (optional)
}
SESSION_REDIS = redis.from_url('redis://localhost:6379/2')
class ProdConfig(Config):
DEVELOPMENT = False

View File

@@ -26,6 +26,14 @@ LOGGING = {
'backupCount': 10,
'formatter': 'standard',
},
'file_chat_workers': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': 'logs/eveai_chat_workers.log',
'maxBytes': 1024*1024*5, # 5MB
'backupCount': 10,
'formatter': 'standard',
},
'console': {
'class': 'logging.StreamHandler',
'level': 'DEBUG',
@@ -53,6 +61,11 @@ LOGGING = {
'level': 'DEBUG',
'propagate': False
},
'eveai_chat_workers': { # logger for the eveai_chat_workers
'handlers': ['file_chat_workers', 'console'],
'level': 'DEBUG',
'propagate': False
},
'': { # root logger
'handlers': ['console'],
'level': 'WARNING', # Set higher level for root to minimize noise

View File

@@ -35,6 +35,11 @@
HTML Chunking
</a>
</li>
<li class="nav-item">
<a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#embedding-search-tab" role="tab" aria-controls="html-chunking" aria-selected="false">
Embedding Search
</a>
</li>
<li class="nav-item">
<a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#domains-tab" role="tab" aria-controls="domains" aria-selected="false">
Domains
@@ -71,6 +76,14 @@
{{ render_included_field(field, disabled_fields=html_fields, include_fields=html_fields) }}
{% endfor %}
</div>
<!-- Embedding Search Settings Tab -->
<div class="tab-pane fade" id="embedding-search-tab" role="tabpanel">
{% set es_fields = ['es_k', 'es_similarity_threshold', ] %}
{% for field in form %}
{{ render_included_field(field, disabled_fields=es_fields, include_fields=es_fields) }}
{% endfor %}
</div>
<!-- Domains Tab -->
<div class="tab-pane fade" id="domains-tab" role="tabpanel">
<ul>

View File

@@ -1,7 +1,7 @@
from flask import current_app
from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField)
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
from common.models.user import Role
@@ -26,6 +26,10 @@ class TenantForm(FlaskForm):
default='p, li')
html_included_elements = StringField('HTML Included Elements', validators=[Optional()])
html_excluded_elements = StringField('HTML Excluded Elements', validators=[Optional()])
# Embedding Search variables
es_k = IntegerField('Limit for Searching Embeddings (5)', validators=[NumberRange(min=0)])
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
validators=[NumberRange(min=0, max=1)])
submit = SubmitField('Submit')

View File

@@ -7,6 +7,9 @@ from common.extensions import db, socketio, jwt, kms_client, cors, session
from config.logging_config import LOGGING
from eveai_chat.socket_handlers import chat_handler
from common.utils.cors_utils import create_cors_after_request
from common.utils.celery_utils import make_celery, init_celery
def create_app(config_file=None):
@@ -20,6 +23,9 @@ def create_app(config_file=None):
logging.config.dictConfig(LOGGING)
register_extensions(app)
app.celery = make_celery(app.name, app.config)
init_celery(app.celery, app)
# Register Blueprints
register_blueprints(app)
@@ -49,13 +55,6 @@ def register_extensions(app):
cors.init_app(app, resources={r"/chat/*": {"origins": "*"}})
app.after_request(create_cors_after_request('/chat'))
# Session setup
# redis_config = app.config['SESSION_REDIS']
# redis_client = Redis(host=redis_config['host'],
# port=redis_config['port'],
# db=redis_config['db'],
# password=redis_config['password']
# )
session.init_app(app)

View File

@@ -4,6 +4,7 @@ from flask import current_app, request
from common.extensions import socketio, kms_client
from common.models.user import Tenant
from common.utils.celery_utils import current_celery
@socketio.on('connect')
@@ -66,13 +67,17 @@ def handle_message(data):
if not current_api_key:
raise Exception("Missing api_key")
# Store interaction in the database
# Offload actual processing of question
task = current_celery.send_task('ask_question', queue='llm_interactions', args=[
current_tenant_id,
data['message'],
])
current_app.logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, '
f'Question: {task.id}')
response = {
'tenantId': data['tenantId'],
'message': f'This is a bot response. Responding to message {data['message']} '
f'from tenant {current_tenant_id}',
'messageId': 'bot-message-id',
'message': 'Processing question ...',
'taskId': task.id,
'algorithm': 'alg1'
}
current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}")
@@ -82,6 +87,30 @@ def handle_message(data):
disconnect()
@socketio.on('check_task_status')
def check_task_status(data):
task_id = data.get('task_id')
current_app.logger.debug(f'SocketIO: Check task status for task_id: {task_id}')
if not task_id:
emit('task_status', {'status': 'error', 'message': 'Missing task ID'})
return
task_result = current_celery.AsyncResult(task_id)
if task_result.state == 'PENDING':
current_app.logger.debug(f'SocketIO: Task {task_id} is pending')
emit('task_status', {'status': 'pending', 'taskId': task_id})
elif task_result.state != 'FAILURE':
current_app.logger.debug(f'SocketIO: Task {task_id} has finished. Status: {task_result.state}, '
f'Result: {task_result.result}')
emit('task_status', {
'status': task_result.state,
'result': task_result.result
})
else:
current_app.logger.error(f'SocketIO: Task {task_id} has failed. Error: {task_result.info}')
emit('task_status', {'status': 'failure', 'message': str(task_result.info)})
def validate_api_key(tenant_id, api_key):
tenant = Tenant.query.get_or_404(tenant_id)
decrypted_api_key = kms_client.decrypt_api_key(tenant.encrypted_chat_api_key)

View File

@@ -0,0 +1,36 @@
import logging
import logging.config
from flask import Flask
from common.utils.celery_utils import make_celery, init_celery
from common.extensions import db
from config.logging_config import LOGGING
def create_app(config_file=None):
app = Flask(__name__)
if config_file is None:
app.config.from_object('config.config.DevConfig')
else:
app.config.from_object(config_file)
logging.config.dictConfig(LOGGING)
app.logger.debug('Starting up eveai_chat_workers...')
register_extensions(app)
celery = make_celery(app.name, app.config)
init_celery(celery, app)
from eveai_chat_workers import tasks
print(tasks.tasks_ping())
return app, celery
def register_extensions(app):
db.init_app(app)
app, celery = create_app()

158
eveai_chat_workers/tasks.py Normal file
View File

@@ -0,0 +1,158 @@
from datetime import datetime as dt, timezone as tz
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
import os
# Unstructured commercial client imports
from unstructured_client import UnstructuredClient
from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError
# OpenAI imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI
from common.models.user import Tenant
from common.extensions import db
from common.utils.celery_utils import current_celery
from bs4 import BeautifulSoup
@current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question):
current_app.logger.debug('In ask_question')
current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')
try:
# Retrieve the tenant
tenant = Tenant.query.get(tenant_id)
if not tenant:
raise Exception(f'Tenant {tenant_id} not found.')
# Ensure we are working in the correct database schema
Database(tenant_id).switch_schema()
# Select variables to work with depending on tenant model
model_variables = select_model_variables(tenant)
# create embedding for the query
embedded_question = create_embedding(model_variables, question)
# Search the database for relevant embeddings
relevant_embeddings = search_embeddings(model_variables, embedded_question)
response = ""
for embed in relevant_embeddings:
response += relevant_embeddings.chunk + '\n'
return response
except Exception as e:
current_app.logger.error(f'ask_question: Error processing question: {e}')
raise Ignore
def select_model_variables(tenant):
embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
llm_provider = tenant.llm_model.rsplit('.', 1)[0]
llm_model = tenant.llm_model.rsplit('.', 1)[1]
# Set model variables
model_variables = {}
if tenant.es_k:
model_variables['k'] = tenant.es_k
else:
model_variables['k'] = 5
if tenant.es_similarity_threshold:
model_variables['similarity_threshold'] = tenant.es_similarity_threshold
else:
model_variables['similarity_threshold'] = 0.7
if tenant.chat_RAG_temperature:
model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
else:
model_variables['RAG_temperature'] = 0.3
if tenant.chat_no_RAG_temperature:
model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
else:
model_variables['no_RAG_temperature'] = 0.5
# Set Embedding variables
match embedding_provider:
case 'openai':
match embedding_model:
case 'text-embedding-3-small':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
model='text-embedding-3-small')
model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding model')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding provider')
# Set Chat model variables
match llm_provider:
case 'openai':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['RAG_temperature'])
match llm_model:
case 'gpt-4-turbo' | 'gpt-4-o':
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
case 'gpt-3-5-turbo':
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model')
model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat provider')
return model_variables
def create_embedding(model_variables, question):
try:
embeddings = model_variables['embedding'].embed_documents(question)
except LangChainException as e:
raise Exception(f'Error creating embedding for question (LangChain): {e}')
return embeddings[0]
def search_embeddings(model_variables, embedded_query):
current_app.logger.debug(f'In search_embeddings searching for {embedded_query}')
db_class = model_variables['embedding_db_model']
try:
res = (
db.session.query(db_class, db_class.embedding.cosine_distance(embedded_query).label('distance'))
.filter(db_class.embedding.cosine_distance(embedded_query) < model_variables['similarity_threshold'])
.order_by("distance")
.limit(model_variables['k'])
.all()
)
except SQLAlchemyError as e:
raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')
current_app.logger.debug(f'Results from embedding search: {res}')
return res
def tasks_ping():
return 'pong'

View File

@@ -1,23 +0,0 @@
from .tasks import create_embeddings
from celery import Celery, Task
def init_celery(app):
class ContextTask(Task):
def __call__(self, *args, **kwargs):
with app.app_context():
return self.run(*args, **kwargs)
celery_app = Celery(app.import_name, task_cls=ContextTask)
celery_app.conf.broker_url = app.config.get('CELERY_BROKER_URL')
celery_app.conf.result_backend = app.config.get('CELERY_RESULT_BACKEND')
celery_app.conf.accept_content = app.config.get('CELERY_ACCEPT_CONTENT')
celery_app.conf.task_serializer = app.config.get('CELERY_TASK_SERIALIZER')
celery_app.conf.timezone = app.config.get('CELERY_TIMEZONE')
celery_app.conf.enable_utc = app.config.get('CELERY_ENABLE_UTC')
celery_app.set_default()
app.extensions['celery'] = celery_app

View File

@@ -0,0 +1,4 @@
from eveai_chat_workers import celery
if __name__ == '__main__':
celery.start()

View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
cd "/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI/" || exit 1
source "/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI/.venv/bin/activate"
export PYTHONPATH="$PYTHONPATH:/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI/"
# Start a worker for the 'llm_interactions' queue with auto-scaling
celery -A eveai_chat_workers.celery worker --loglevel=info -Q llm_interactions --autoscale=2,8 --hostname=interactions_worker@%h &
# Wait for all background processes to finish
wait
deactivate

View File

@@ -6,8 +6,8 @@ source "/Volumes/OWC4M2_1/Dropbox/Josako's Dev/Josako/EveAI/Development/eveAI/.v
# Start a worker for the 'embeddings' queue with higher concurrency
celery -A eveai_workers.celery worker --loglevel=info -Q embeddings --autoscale=1,4 --hostname=embeddings_worker@%h &
# Start a worker for the 'llm_interactions' queue with auto-scaling
celery -A eveai_workers.celery worker --loglevel=info - Q llm_interactions --autoscale=2,8 --hostname=interactions_worker@%h &
# Start a worker for the 'llm_interactions' queue with auto-scaling - not necessary, in eveai_chat_workers
# celery -A eveai_workers.celery worker --loglevel=info - Q llm_interactions --autoscale=2,8 --hostname=interactions_worker@%h &
# Wait for all background processes to finish
wait

View File

@@ -114,9 +114,30 @@ class EveAIChatWidget extends HTMLElement {
this.socket.on('bot_response', (data) => {
if (data.tenantId === this.tenantId) {
console.log('Bot response received:', data)
console.log('Task ID received:', data.taskId)
this.addMessage(data.message, 'bot', data.messageId, data.algorithm);
this.checkTaskStatus(data.taskId)
}
});
this.socket.on('task_status', (data) => {
console.log('Task status received:', data)
console.log('Task ID received:', data.taskId)
if (data.status === 'SUCCESS') {
this.addMessage(data.result, 'bot');
} else if (data.status === 'FAILURE') {
this.addMessage('Failed to process message.', 'bot');
} else if (data.status === 'pending') {
console.log('Task is pending')
setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second
console.log('New check sent')
}
});
}
checkTaskStatus(taskId) {
this.socket.emit('check_task_status', { task_id: taskId });
}
getTemplate() {