From 27b6de87343b087011c474ede7412cd32abba3b0 Mon Sep 17 00:00:00 2001 From: Josako Date: Thu, 6 Jun 2024 15:26:49 +0200 Subject: [PATCH] Removing DocumentLanguage, as both System Context and User Context are to be defined on DocumentVersion level. Finetuning of embedding workers. --- common/langchain/EveAIRetriever.py | 50 +++++++ common/models/document.py | 40 +----- common/models/user.py | 9 +- common/utils/model_utils.py | 18 +++ config/logging_config.py | 38 ++++- .../document/document_languages.html | 24 ---- .../templates/document/document_versions.html | 3 +- eveai_app/templates/document/documents.html | 2 +- eveai_app/templates/navbar.html | 1 + eveai_app/templates/user/tenant_overview.html | 12 ++ eveai_app/views/document_forms.py | 2 +- eveai_app/views/document_views.py | 131 +++++------------- eveai_app/views/user_forms.py | 3 + eveai_app/views/user_views.py | 4 +- eveai_chat_workers/__init__.py | 2 + eveai_chat_workers/tasks.py | 123 ++-------------- eveai_workers/__init__.py | 2 + eveai_workers/tasks.py | 46 +++--- ...removing_documentlanguage_user_context_.py | 58 ++++++++ public/chat_ae.html | 25 ++++ requirements.txt | 3 +- 21 files changed, 301 insertions(+), 295 deletions(-) create mode 100644 common/langchain/EveAIRetriever.py delete mode 100644 eveai_app/templates/document/document_languages.html create mode 100644 migrations/tenant/versions/217938792642_removing_documentlanguage_user_context_.py create mode 100644 public/chat_ae.html diff --git a/common/langchain/EveAIRetriever.py b/common/langchain/EveAIRetriever.py new file mode 100644 index 0000000..4232454 --- /dev/null +++ b/common/langchain/EveAIRetriever.py @@ -0,0 +1,50 @@ +from langchain_core.retrievers import BaseRetriever +from sqlalchemy.exc import SQLAlchemyError +from pydantic import BaseModel, Field +from typing import Any, Dict + +from common.extensions import db +from flask import current_app +from config.logging_config import LOGGING + + +class EveAIRetriever(BaseRetriever): + model_variables: Dict[str, Any] = Field(...) + + def __init__(self, model_variables: Dict[str, Any]): + super().__init__() + current_app.logger.debug('Initializing EveAIRetriever') + self.model_variables = model_variables + current_app.logger.debug('EveAIRetriever initialized') + + def _get_relevant_documents(self, query: str): + current_app.logger.debug(f'Retrieving relevant documents for query: {query}') + query_embedding = self._get_query_embedding(query) + db_class = self.model_variables['embedding_db_model'] + similarity_threshold = self.model_variables['similarity_threshold'] + k = self.model_variables['k'] + try: + res = ( + db.session.query(db_class, + db_class.embedding.cosine_distance(query_embedding) + .label('distance')) + .filter(db_class.embedding.cosine_distance(query_embedding) < similarity_threshold) + .order_by('distance') + .limit(k) + .all() + ) + current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') + current_app.rag_tuning_logger.debug(f'---------------------------------------') + for doc in res: + current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') + current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') + + except SQLAlchemyError as e: + current_app.logger.error(f'Error retrieving relevant documents: {e}') + return [] + return res + + def _get_query_embedding(self, query: str): + embedding_model = self.model_variables['embedding_model'] + query_embedding = embedding_model.embed_query(query) + return query_embedding diff --git a/common/models/document.py b/common/models/document.py index 2ab70f2..4b06a49 100644 --- a/common/models/document.py +++ b/common/models/document.py @@ -17,50 +17,22 @@ class Document(db.Model): updated_by = db.Column(db.Integer, db.ForeignKey(User.id)) # Relations - languages = db.relationship('DocumentLanguage', backref='document', lazy=True) + versions = db.relationship('DocumentVersion', backref='document', lazy=True) def __repr__(self): return f"" -class DocumentLanguage(db.Model): - id = db.Column(db.Integer, primary_key=True) - document_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False) - language = db.Column(db.String(2), nullable=False) - user_context = db.Column(db.Text, nullable=True) - system_context = db.Column(db.Text, nullable=True) - latest_version_id = db.Column(db.Integer, db.ForeignKey('document_version.id'), nullable=True) - - # Versioning Information - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) - created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now()) - updated_by = db.Column(db.Integer, db.ForeignKey(User.id)) - - # Relations - versions = db.relationship( - 'DocumentVersion', - backref='document_language', - lazy='joined', - foreign_keys='DocumentVersion.doc_lang_id' - ) - latest_version = db.relationship( - 'DocumentVersion', - uselist=False, - foreign_keys=[latest_version_id] - ) - - def __repr__(self): - return f"" - - class DocumentVersion(db.Model): id = db.Column(db.Integer, primary_key=True) - doc_lang_id = db.Column(db.Integer, db.ForeignKey(DocumentLanguage.id), nullable=False) + doc_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False) url = db.Column(db.String(200), nullable=True) file_location = db.Column(db.String(255), nullable=True) file_name = db.Column(db.String(200), nullable=True) file_type = db.Column(db.String(20), nullable=True) + language = db.Column(db.String(2), nullable=False) + user_context = db.Column(db.Text, nullable=True) + system_context = db.Column(db.Text, nullable=True) # Versioning Information created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) @@ -81,7 +53,7 @@ class DocumentVersion(db.Model): return f".{self.id}>" def calc_file_location(self): - return f"{self.document_language.document.tenant_id}/{self.document_language.document.id}/{self.document_language.language}" + return f"{self.document.tenant_id}/{self.document.id}/{self.language}" def calc_file_name(self): return f"{self.id}.{self.file_type}" diff --git a/common/models/user.py b/common/models/user.py index 0b148f3..53ec4b3 100644 --- a/common/models/user.py +++ b/common/models/user.py @@ -48,6 +48,10 @@ class Tenant(db.Model): allowed_monthly_interactions = db.Column(db.Integer, nullable=True) encrypted_chat_api_key = db.Column(db.String(500), nullable=True) + # Tuning enablers + embed_tuning = db.Column(db.Boolean, nullable=True, default=False) + rag_tuning = db.Column(db.Boolean, nullable=True, default=False) + # Relations users = db.relationship('User', backref='tenant') domains = db.relationship('TenantDomain', backref='tenant') @@ -133,7 +137,10 @@ class TenantDomain(db.Model): id = db.Column(db.Integer, primary_key=True) tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False) - domain = db.Column(db.String(255), unique=True, nullable=False) + # Originally, domain was required to be unique. + # However, several tenants can run from the same domain (e.g. for demo purposes, + # but also internal and external chat clients. + domain = db.Column(db.String(255), nullable=False) valid_to = db.Column(db.Date, nullable=True) # Versioning Information diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py index 939e741..339e076 100644 --- a/common/utils/model_utils.py +++ b/common/utils/model_utils.py @@ -2,6 +2,7 @@ from flask import current_app from langchain.embeddings import OpenAIEmbeddings from langchain.chat_models import ChatOpenAI from langchain.prompts import ChatPromptTemplate +import ast from common.models.document import EmbeddingSmallOpenAI @@ -35,6 +36,23 @@ def select_model_variables(tenant): else: model_variables['no_RAG_temperature'] = 0.5 + # Set Tuning variables + if tenant.embed_tuning: + model_variables['embed_tuning'] = tenant.embed_tuning + else: + model_variables['embed_tuning'] = False + + if tenant.rag_tuning: + model_variables['rag_tuning'] = tenant.rag_tuning + else: + model_variables['rag_tuning'] = False + + # Set HTML Chunking Variables + model_variables['html_tags'] = tenant.html_tags + model_variables['html_end_tags'] = tenant.html_end_tags + model_variables['html_included_elements'] = tenant.html_included_elements + model_variables['html_excluded_elements'] = tenant.html_excluded_elements + # Set Embedding variables match embedding_provider: case 'openai': diff --git a/config/logging_config.py b/config/logging_config.py index c1a72dd..b5c420d 100644 --- a/config/logging_config.py +++ b/config/logging_config.py @@ -58,6 +58,22 @@ LOGGING = { 'backupCount': 10, 'formatter': 'standard', }, + 'file_rag_tuning': { + 'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'filename': 'logs/rag_tuning.log', + 'maxBytes': 1024*1024*5, # 5MB + 'backupCount': 10, + 'formatter': 'standard', + }, + 'file_embed_tuning': { + 'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'filename': 'logs/rag_tuning.log', + 'maxBytes': 1024*1024*5, # 5MB + 'backupCount': 10, + 'formatter': 'standard', + }, 'console': { 'class': 'logging.StreamHandler', 'level': 'DEBUG', @@ -71,27 +87,27 @@ LOGGING = { }, 'loggers': { 'eveai_app': { # logger for the eveai_app - 'handlers': ['file_app', 'console'], + 'handlers': ['file_app',], 'level': 'DEBUG', 'propagate': False }, 'eveai_workers': { # logger for the eveai_workers - 'handlers': ['file_workers', 'console'], + 'handlers': ['file_workers',], 'level': 'DEBUG', 'propagate': False }, 'eveai_chat': { # logger for the eveai_chat - 'handlers': ['file_chat', 'console'], + 'handlers': ['file_chat',], 'level': 'DEBUG', 'propagate': False }, 'eveai_chat_workers': { # logger for the eveai_chat_workers - 'handlers': ['file_chat_workers', 'console'], + 'handlers': ['file_chat_workers',], 'level': 'DEBUG', 'propagate': False }, 'sqlalchemy.engine': { # logger for the sqlalchemy - 'handlers': ['file_sqlalchemy', 'console'], + 'handlers': ['file_sqlalchemy',], 'level': 'DEBUG', 'propagate': False }, @@ -105,10 +121,20 @@ LOGGING = { 'level': 'DEBUG', 'propagate': False }, + 'rag_tuning': { # logger for the rag_tuning + 'handlers': ['file_rag_tuning', 'console'], + 'level': 'DEBUG', + 'propagate': False + }, + 'embed_tuning': { # logger for the embed_tuning + 'handlers': ['file_embed_tuning', 'console'], + 'level': 'DEBUG', + 'propagate': False + }, '': { # root logger 'handlers': ['console'], 'level': 'WARNING', # Set higher level for root to minimize noise 'propagate': False - } + }, } } \ No newline at end of file diff --git a/eveai_app/templates/document/document_languages.html b/eveai_app/templates/document/document_languages.html deleted file mode 100644 index 5e7ef69..0000000 --- a/eveai_app/templates/document/document_languages.html +++ /dev/null @@ -1,24 +0,0 @@ -{% extends 'base.html' %} -{% from 'macros.html' import render_selectable_table, render_pagination %} - -{% block title %}Document Languages{% endblock %} - -{% block content_title %}Document Languages{% endblock %} -{% block content_description %}View Languages for {{ document }}{% endblock %} -{% block content_class %}
{% endblock %} - -{% block content %} -
-
- {{ render_selectable_table(headers=["Document Language ID", "Language", "User Context", "System Context"], rows=rows, selectable=True, id="documentsTable") }} -
- - -
-
-
-{% endblock %} - -{% block content_footer %} - {{ render_pagination(pagination, 'document_bp.documents') }} -{% endblock %} \ No newline at end of file diff --git a/eveai_app/templates/document/document_versions.html b/eveai_app/templates/document/document_versions.html index e02adc1..7bd07f2 100644 --- a/eveai_app/templates/document/document_versions.html +++ b/eveai_app/templates/document/document_versions.html @@ -12,7 +12,8 @@
{{ render_selectable_table(headers=["Document Version ID", "URL", "File Location", "File Name", "File Type", "Processing", "Processing Start", "Proceeing Finish"], rows=rows, selectable=True, id="versionsTable") }}
- + +
diff --git a/eveai_app/templates/document/documents.html b/eveai_app/templates/document/documents.html index 483e023..cebc150 100644 --- a/eveai_app/templates/document/documents.html +++ b/eveai_app/templates/document/documents.html @@ -13,7 +13,7 @@ {{ render_selectable_table(headers=["Document ID", "Name", "Valid From", "Valid To"], rows=rows, selectable=True, id="documentsTable") }}
- +
diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html index b745fd3..54287ee 100644 --- a/eveai_app/templates/navbar.html +++ b/eveai_app/templates/navbar.html @@ -72,6 +72,7 @@ {'name': 'Tenant List', 'url': '/user/select_tenant', 'roles': ['Super User']}, {'name': 'Tenant Registration', 'url': '/user/tenant', 'roles': ['Super User']}, {'name': 'Tenant Overview', 'url': '/user/tenant_overview', 'roles': ['Super User', 'Tenant Admin']}, + {'name': 'Edit Tenant', 'url': '/user/tenant/' ~ session['tenant'].get('id'), 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Tenant Domains', 'url': '/user/view_tenant_domains', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Tenant Domain Registration', 'url': '/user/tenant_domain', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'User List', 'url': '/user/view_users', 'roles': ['Super User', 'Tenant Admin']}, diff --git a/eveai_app/templates/user/tenant_overview.html b/eveai_app/templates/user/tenant_overview.html index 4b30e87..33d39e1 100644 --- a/eveai_app/templates/user/tenant_overview.html +++ b/eveai_app/templates/user/tenant_overview.html @@ -40,6 +40,11 @@ Embedding Search +
@@ -73,6 +78,13 @@ {{ render_included_field(field, disabled_fields=es_fields, include_fields=es_fields) }} {% endfor %}
+ +
+ {% set tuning_fields = ['embed_tuning', 'rag_tuning', ] %} + {% for field in form %} + {{ render_included_field(field, disabled_fields=tuning_fields, include_fields=tuning_fields) }} + {% endfor %} +
diff --git a/eveai_app/views/document_forms.py b/eveai_app/views/document_forms.py index 7b682b9..eb52e7a 100644 --- a/eveai_app/views/document_forms.py +++ b/eveai_app/views/document_forms.py @@ -47,7 +47,7 @@ class EditDocumentForm(FlaskForm): submit = SubmitField('Submit') -class EditDocumentLanguageForm(FlaskForm): +class EditDocumentVersionForm(FlaskForm): language = StringField('Language') user_context = TextAreaField('User Context', validators=[Optional()]) system_context = TextAreaField('System Context', validators=[Optional()]) diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py index 79da665..b60defc 100644 --- a/eveai_app/views/document_views.py +++ b/eveai_app/views/document_views.py @@ -13,9 +13,9 @@ from requests.exceptions import SSLError from urllib.parse import urlparse import io -from common.models.document import Document, DocumentLanguage, DocumentVersion +from common.models.document import Document, DocumentVersion from common.extensions import db -from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentLanguageForm +from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm from common.utils.middleware import mw_before_request from common.utils.celery_utils import current_celery from common.utils.nginx_utils import prefixed_url_for @@ -59,7 +59,7 @@ def add_document(): filename = secure_filename(file.filename) extension = filename.rsplit('.', 1)[1].lower() - new_doc, new_doc_lang, new_doc_vers = create_document_stack(form, file, filename, extension) + new_doc, new_doc_vers = create_document_stack(form, file, filename, extension) task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ session['tenant']['id'], @@ -100,7 +100,7 @@ def add_url(): filename += '.html' extension = 'html' - new_doc, new_doc_lang, new_doc_vers = create_document_stack(form, file, filename, extension) + new_doc, new_doc_vers = create_document_stack(form, file, filename, extension) task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ session['tenant']['id'], @@ -146,8 +146,8 @@ def handle_document_selection(): match action: case 'edit_document': return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id)) - case 'document_languages': - return redirect(prefixed_url_for('document_bp.document_languages', document_id=doc_id)) + case 'document_versions': + return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) # Add more conditions for other actions return redirect(prefixed_url_for('document_bp.documents')) @@ -180,81 +180,46 @@ def edit_document(document_id): return render_template('document/edit_document.html', form=form, document_id=document_id) -@document_bp.route('/document_languages/', methods=['GET', 'POST']) +@document_bp.route('/edit_document_version/', methods=['GET', 'POST']) @roles_accepted('Super User', 'Tenant Admin') -def document_languages(document_id): - doc = Document.query.get_or_404(document_id) - doc_desc = f'Document {doc.id}: {doc.name}' - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - - query = DocumentLanguage.query.filter_by(document_id=document_id).order_by(DocumentLanguage.language) - - pagination = query.paginate(page=page, per_page=per_page, error_out=False) - doc_langs = pagination.items - - rows = prepare_table_for_macro(doc_langs, [('id', ''), ('language', ''), ('user_context', ''), - ('system_context', '')]) - - return render_template('document/document_languages.html', rows=rows, pagination=pagination, document=doc_desc) - - -@document_bp.route('/handle_document_language_selection', methods=['POST']) -@roles_accepted('Super User', 'Tenant Admin') -def handle_document_language_selection(): - document_language_identification = request.form['selected_row'] - doc_lang_id = ast.literal_eval(document_language_identification).get('value') - - action = request.form['action'] - - match action: - case 'edit_document_language': - return redirect(prefixed_url_for('document_bp.edit_document_language', document_language_id=doc_lang_id)) - case 'document_versions': - return redirect(prefixed_url_for('document_bp.document_versions', document_language_id=doc_lang_id)) - - # Add more conditions for other actions - return redirect(prefixed_url_for('document_bp.document_languages')) - - -@document_bp.route('/edit_document_language/', methods=['GET', 'POST']) -@roles_accepted('Super User', 'Tenant Admin') -def edit_document_language(document_language_id): - doc_lang = DocumentLanguage.query.get_or_404(document_language_id) - form = EditDocumentLanguageForm(obj=doc_lang) +def edit_document_version(document_version_id): + doc_vers = DocumentVersion.query.get_or_404(document_version_id) + form = EditDocumentVersionForm(obj=doc_vers) if form.validate_on_submit(): - doc_lang.user_context = form.user_context.data + doc_vers.user_context = form.user_context.data - update_logging_information(doc_lang, dt.now(tz.utc)) + update_logging_information(doc_vers, dt.now(tz.utc)) try: - db.session.add(doc_lang) + db.session.add(doc_vers) db.session.commit() - flash(f'Document Language {doc_lang.id} updated successfully', 'success') + flash(f'Document Version {doc_vers.id} updated successfully', 'success') except SQLAlchemyError as e: db.session.rollback() - flash(f'Error updating document language: {e}', 'danger') - current_app.logger.error(f'Error updating document language {doc_lang.id} ' + flash(f'Error updating document version: {e}', 'danger') + current_app.logger.error(f'Error updating document version {doc_vers.id} ' f'for tenant {session['tenant']['id']}: {e}') else: form_validation_failed(request, form) - return render_template('document/edit_document_language.html', form=form, document_langauge_id=document_language_id, - doc_details=f'Document {doc_lang.document.name}') + return render_template('document/edit_document_version.html', form=form, document_version_id=document_version_id, + doc_details=f'Document {doc_vers.document.name}') -@document_bp.route('/document_versions/', methods=['GET', 'POST']) +@document_bp.route('/document_versions/', methods=['GET', 'POST']) @roles_accepted('Super User', 'Tenant Admin') -def document_versions(document_language_id): +def document_versions(document_id): flash(f'Processing documents is a long running process. Please be careful retriggering processing!', 'danger') - doc_lang = DocumentLanguage.query.get_or_404(document_language_id) - doc_desc = f'Document {doc_lang.document.name}, Language {doc_lang.language}' + doc_vers = DocumentVersion.query.get_or_404(document_id) + doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}' page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 10, type=int) - query = DocumentVersion.query.filter_by(doc_lang_id=document_language_id).order_by(desc(DocumentVersion.id)) + query = (DocumentVersion.query.filter_by(doc_id=document_id) + .order_by(DocumentVersion.language) + .order_by(desc(DocumentVersion.id))) pagination = query.paginate(page=page, per_page=per_page, error_out=False) doc_langs = pagination.items @@ -276,6 +241,8 @@ def handle_document_version_selection(): action = request.form['action'] match action: + case 'edit_document_version': + return redirect(prefixed_url_for('document_bp.edit_document_version', document_version_id=doc_vers_id)) case 'process_document_version': process_version(doc_vers_id) # Add more conditions for other actions @@ -315,17 +282,11 @@ def create_document_stack(form, file, filename, extension): # Create the Document new_doc = create_document(form, filename) - # Create the DocumentLanguage - new_doc_lang = create_language_for_document(new_doc, form.language.data, form.user_context.data) - # Create the DocumentVersion - new_doc_vers = DocumentVersion() - new_doc_vers.document_language = new_doc_lang - set_logging_information(new_doc_vers, dt.now(tz.utc)) + new_doc_vers = create_version_for_document(new_doc, form.language.data, form.user_context.data) try: db.session.add(new_doc) - db.session.add(new_doc_lang) db.session.add(new_doc_vers) db.session.commit() except SQLAlchemyError as e: @@ -338,30 +299,12 @@ def create_document_stack(form, file, filename, extension): current_app.logger.error('Unknown error') raise - try: - new_doc_lang = db.session.merge(new_doc_lang) - new_doc_vers = db.session.merge(new_doc_vers) - new_doc_lang.latest_version_id = new_doc_vers.id - db.session.commit() - except SQLAlchemyError as e: - current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {e}') - flash('Error adding document.', 'error') - db.session.rollback() - error = e.args - raise - except Exception as e: - current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {e}') - flash('Error adding document.', 'error') - db.session.rollback() - error = e.args - raise - current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, ' f'Document Version {new_doc.id}') upload_file_for_version(new_doc_vers, file, extension) - return new_doc, new_doc_lang, new_doc_vers + return new_doc, new_doc_vers def log_session_state(session, msg=""): @@ -386,21 +329,21 @@ def create_document(form, filename): return new_doc -def create_language_for_document(document, language, user_context): - new_doc_lang = DocumentLanguage() +def create_version_for_document(document, language, user_context): + new_doc_vers = DocumentVersion() if language == '': - new_doc_lang.language = session['default_language'] + new_doc_vers.language = session['default_language'] else: - new_doc_lang.language = language + new_doc_vers.language = language if user_context != '': - new_doc_lang.user_context = user_context + new_doc_vers.user_context = user_context - new_doc_lang.document = document + new_doc_vers.document = document - set_logging_information(new_doc_lang, dt.now(tz.utc)) + set_logging_information(new_doc_vers, dt.now(tz.utc)) - return new_doc_lang + return new_doc_vers def upload_file_for_version(doc_vers, file, extension): diff --git a/eveai_app/views/user_forms.py b/eveai_app/views/user_forms.py index 64ef439..692d1a6 100644 --- a/eveai_app/views/user_forms.py +++ b/eveai_app/views/user_forms.py @@ -33,6 +33,9 @@ class TenantForm(FlaskForm): es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)', default=0.5, validators=[NumberRange(min=0, max=1)]) + # Tuning variables + embed_tuning = BooleanField('Enable Embedding Tuning', default=False) + rag_tuning = BooleanField('Enable RAG Tuning', default=False) submit = SubmitField('Submit') diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py index 72b3737..0da0679 100644 --- a/eveai_app/views/user_views.py +++ b/eveai_app/views/user_views.py @@ -53,7 +53,9 @@ def tenant(): llm_model=form.llm_model.data, license_start_date=form.license_start_date.data, license_end_date=form.license_end_date.data, - allowed_monthly_interactions=form.allowed_monthly_interactions.data) + allowed_monthly_interactions=form.allowed_monthly_interactions.data, + embed_tuning=form.embed_tuning.data, + rag_tuning=form.rag_tuning.data) # Handle Embedding Variables new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else [] diff --git a/eveai_chat_workers/__init__.py b/eveai_chat_workers/__init__.py index a213962..da932bb 100644 --- a/eveai_chat_workers/__init__.py +++ b/eveai_chat_workers/__init__.py @@ -23,6 +23,8 @@ def create_app(config_file=None): celery = make_celery(app.name, app.config) init_celery(celery, app) + app.rag_tuning_logger = logging.getLogger('rag_tuning') + from eveai_chat_workers import tasks print(tasks.tasks_ping()) diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py index 6cbd8c4..379f969 100644 --- a/eveai_chat_workers/tasks.py +++ b/eveai_chat_workers/tasks.py @@ -1,15 +1,12 @@ from datetime import datetime as dt, timezone as tz from flask import current_app +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableParallel from sqlalchemy.exc import SQLAlchemyError from celery import states from celery.exceptions import Ignore import os -# Unstructured commercial client imports -from unstructured_client import UnstructuredClient -from unstructured_client.models import shared -from unstructured_client.models.errors import SDKError - # OpenAI imports from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_core.prompts import ChatPromptTemplate @@ -22,14 +19,14 @@ from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingS from common.models.user import Tenant from common.extensions import db from common.utils.celery_utils import current_celery - -from bs4 import BeautifulSoup +from common.utils.model_utils import select_model_variables +from common.langchain.EveAIRetriever import EveAIRetriever @current_celery.task(name='ask_question', queue='llm_interactions') def ask_question(tenant_id, question): current_app.logger.debug('In ask_question') - current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') + current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') try: # Retrieve the tenant @@ -43,115 +40,17 @@ def ask_question(tenant_id, question): # Select variables to work with depending on tenant model model_variables = select_model_variables(tenant) - # create embedding for the query - embedded_question = create_embedding(model_variables, question) + current_app.logger.debug(f'ask_question: model_variables: {model_variables}') + + retriever = EveAIRetriever(model_variables) # Search the database for relevant embeddings - relevant_embeddings = search_embeddings(model_variables, embedded_question) + relevant_embeddings = retriever.invoke(question) - response = "" - for embed in relevant_embeddings: - response += relevant_embeddings.chunk + '\n' - - return response + return 'No response yet, check back later.' except Exception as e: current_app.logger.error(f'ask_question: Error processing question: {e}') - raise Ignore - - -def select_model_variables(tenant): - embedding_provider = tenant.embedding_model.rsplit('.', 1)[0] - embedding_model = tenant.embedding_model.rsplit('.', 1)[1] - - llm_provider = tenant.llm_model.rsplit('.', 1)[0] - llm_model = tenant.llm_model.rsplit('.', 1)[1] - - # Set model variables - model_variables = {} - if tenant.es_k: - model_variables['k'] = tenant.es_k - else: - model_variables['k'] = 5 - - if tenant.es_similarity_threshold: - model_variables['similarity_threshold'] = tenant.es_similarity_threshold - else: - model_variables['similarity_threshold'] = 0.7 - - if tenant.chat_RAG_temperature: - model_variables['RAG_temperature'] = tenant.chat_RAG_temperature - else: - model_variables['RAG_temperature'] = 0.3 - - if tenant.chat_no_RAG_temperature: - model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature - else: - model_variables['no_RAG_temperature'] = 0.5 - - # Set Embedding variables - match embedding_provider: - case 'openai': - match embedding_model: - case 'text-embedding-3-small': - api_key = current_app.config.get('OPENAI_API_KEY') - model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key, - model='text-embedding-3-small') - model_variables['embedding_db_model'] = EmbeddingSmallOpenAI - case _: - raise Exception(f'Error setting model variables for tenant {tenant.id} ' - f'error: Invalid embedding model') - case _: - raise Exception(f'Error setting model variables for tenant {tenant.id} ' - f'error: Invalid embedding provider') - - # Set Chat model variables - match llm_provider: - case 'openai': - api_key = current_app.config.get('OPENAI_API_KEY') - model_variables['llm'] = ChatOpenAI(api_key=api_key, - model=llm_model, - temperature=model_variables['RAG_temperature']) - match llm_model: - case 'gpt-4-turbo' | 'gpt-4-o': - rag_template = current_app.config.get('GPT4_RAG_TEMPLATE') - case 'gpt-3-5-turbo': - rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE') - case _: - raise Exception(f'Error setting model variables for tenant {tenant.id} ' - f'error: Invalid chat model') - model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template) - case _: - raise Exception(f'Error setting model variables for tenant {tenant.id} ' - f'error: Invalid chat provider') - - return model_variables - - -def create_embedding(model_variables, question): - try: - embeddings = model_variables['embedding'].embed_documents(question) - except LangChainException as e: - raise Exception(f'Error creating embedding for question (LangChain): {e}') - - return embeddings[0] - - -def search_embeddings(model_variables, embedded_query): - current_app.logger.debug(f'In search_embeddings searching for {embedded_query}') - db_class = model_variables['embedding_db_model'] - try: - res = ( - db.session.query(db_class, db_class.embedding.cosine_distance(embedded_query).label('distance')) - .filter(db_class.embedding.cosine_distance(embedded_query) < model_variables['similarity_threshold']) - .order_by("distance") - .limit(model_variables['k']) - .all() - ) - except SQLAlchemyError as e: - raise Exception(f'Error searching embeddings (SQLAlchemy): {e}') - - current_app.logger.debug(f'Results from embedding search: {res}') - return res + raise def tasks_ping(): diff --git a/eveai_workers/__init__.py b/eveai_workers/__init__.py index 932a81c..95ef7e9 100644 --- a/eveai_workers/__init__.py +++ b/eveai_workers/__init__.py @@ -16,6 +16,8 @@ def create_app(config_file=None): app.config.from_object(config_file) logging.config.dictConfig(LOGGING) + app.embed_tuning_logger = logging.getLogger('embed_tuning') + register_extensions(app) celery = make_celery(app.name, app.config) diff --git a/eveai_workers/tasks.py b/eveai_workers/tasks.py index f9d1412..a20f202 100644 --- a/eveai_workers/tasks.py +++ b/eveai_workers/tasks.py @@ -130,13 +130,11 @@ def process_pdf(tenant, model_variables, document_version): raise summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) - doc_lang = document_version.document_language - doc_lang.system_context = f'Summary: {summary}\n' + document_version.system_context = f'Summary: {summary}\n' enriched_chunks = enrich_chunks(tenant, document_version, chunks) embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) try: - db.session.add(doc_lang) db.session.add(document_version) document_version.processing_finished_at = dt.now(tz.utc) document_version.processing = False @@ -156,10 +154,10 @@ def process_pdf(tenant, model_variables, document_version): def process_html(tenant, model_variables, document_version): # The tags to be considered can be dependent on the tenant - html_tags = tenant.html_tags - end_tags = tenant.html_end_tags - included_elements = tenant.html_included_elements - excluded_elements = tenant.html_excluded_elements + html_tags = model_variables['html_tags'] + html_end_tags = model_variables['html_end_tags'] + html_included_elements = model_variables['html_included_elements'] + html_excluded_elements = model_variables['html_excluded_elements'] file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], document_version.file_location, @@ -175,23 +173,25 @@ def process_html(tenant, model_variables, document_version): create_embeddings.update_state(state=states.FAILURE) raise - extracted_data, title = parse_html(html_content, html_tags, included_elements=included_elements, - excluded_elements=excluded_elements) - potential_chunks = create_potential_chunks(extracted_data, end_tags) + extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements, + excluded_elements=html_excluded_elements) + potential_chunks = create_potential_chunks(extracted_data, html_end_tags) chunks = combine_chunks(potential_chunks, model_variables['min_chunk_size'], model_variables['max_chunk_size'] ) - summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) - doc_lang = document_version.document_language - doc_lang.system_context = (f'Title: {title}\n' - f'Summary: {summary}\n') + + if len(chunks) > 0: + summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) + document_version.system_context = (f'Title: {title}\n' + f'Summary: {summary}\n') + else: + document_version.system_context = (f'Title: {title}\n') enriched_chunks = enrich_chunks(tenant, document_version, chunks) embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) try: - db.session.add(doc_lang) db.session.add(document_version) document_version.processing_finished_at = dt.now(tz.utc) document_version.processing = False @@ -210,12 +210,14 @@ def process_html(tenant, model_variables, document_version): def enrich_chunks(tenant, document_version, chunks): current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} ' f'on document version {document_version.id}') - doc_lang = document_version.document_language chunk_total_context = (f'Filename: {document_version.file_name}\n' - f'{doc_lang.system_context}\n' - f'User Context:\n{doc_lang.user_context}') + f'User Context:{document_version.user_context}\n' + f'{document_version.system_context}\n\n') enriched_chunks = [] - initial_chunk = f'Filename: {document_version.file_name}\n User Context:\n{doc_lang.user_context}\n{chunks[0]}' + initial_chunk = (f'Filename: {document_version.file_name}\n' + f'User Context:\n{document_version.user_context}\n\n' + f'{chunks[0]}') + enriched_chunks.append(initial_chunk) for chunk in chunks[1:]: enriched_chunk = f'{chunk_total_context}\n{chunk}' @@ -313,6 +315,12 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non else: elements_to_parse = [soup] # parse the entire document if no included_elements specified + current_app.embed_tuning_logger.debug(f'Included Elements: {included_elements}') + current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}') + current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}') + current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse') + current_app.embed_tuning_logger.debug(f'{elements_to_parse}') + # Iterate through the found included elements for element in elements_to_parse: # Find all specified tags within each included element diff --git a/migrations/tenant/versions/217938792642_removing_documentlanguage_user_context_.py b/migrations/tenant/versions/217938792642_removing_documentlanguage_user_context_.py new file mode 100644 index 0000000..2661fe1 --- /dev/null +++ b/migrations/tenant/versions/217938792642_removing_documentlanguage_user_context_.py @@ -0,0 +1,58 @@ +"""Removing DocumentLanguage (user_context and system_context correct on DocumentVersion) + +Revision ID: 217938792642 +Revises: 53fee8f13cdb +Create Date: 2024-06-06 14:16:03.288734 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '217938792642' +down_revision = '53fee8f13cdb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('document_version_doc_lang_id_fkey', 'document_version', type_='foreignkey') + op.drop_table('document_language') + op.add_column('document_version', sa.Column('doc_id', sa.Integer(), nullable=False)) + op.add_column('document_version', sa.Column('language', sa.String(length=2), nullable=False)) + op.add_column('document_version', sa.Column('user_context', sa.Text(), nullable=True)) + op.add_column('document_version', sa.Column('system_context', sa.Text(), nullable=True)) + op.create_foreign_key(None, 'document_version', 'document', ['doc_id'], ['id']) + op.drop_column('document_version', 'doc_lang_id') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('document_version', sa.Column('doc_lang_id', sa.INTEGER(), autoincrement=False, nullable=False)) + op.drop_constraint(None, 'document_version', type_='foreignkey') + op.create_foreign_key('document_version_doc_lang_id_fkey', 'document_version', 'document_language', ['doc_lang_id'], ['id']) + op.drop_column('document_version', 'system_context') + op.drop_column('document_version', 'user_context') + op.drop_column('document_version', 'language') + op.drop_column('document_version', 'doc_id') + op.create_table('document_language', + sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('document_id', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('language', sa.VARCHAR(length=2), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), + sa.Column('created_by', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), + sa.Column('updated_by', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('latest_version_id', sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column('user_context', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('system_context', sa.TEXT(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], name='document_language_created_by_fkey'), + sa.ForeignKeyConstraint(['document_id'], ['document.id'], name='document_language_document_id_fkey'), + sa.ForeignKeyConstraint(['latest_version_id'], ['document_version.id'], name='document_language_latest_version_id_fkey'), + sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], name='document_language_updated_by_fkey'), + sa.PrimaryKeyConstraint('id', name='document_language_pkey') + ) + # ### end Alembic commands ### diff --git a/public/chat_ae.html b/public/chat_ae.html new file mode 100644 index 0000000..8603b44 --- /dev/null +++ b/public/chat_ae.html @@ -0,0 +1,25 @@ + + + + + + Chat Client AE + + + + + + +
+ + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8b2abc7..a5a0063 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ requests~=2.31.0 beautifulsoup4~=4.12.3 google~=3.0.0 redis~=5.0.4 -itsdangerous~=2.2.0 \ No newline at end of file +itsdangerous~=2.2.0 +pydantic~=2.7.1 \ No newline at end of file