Removing DocumentLanguage, as both System Context and User Context are to be defined on DocumentVersion level.

Finetuning of embedding workers.
This commit is contained in:
Josako
2024-06-06 15:26:49 +02:00
parent 1a25313673
commit 27b6de8734
21 changed files with 301 additions and 295 deletions

View File

@@ -0,0 +1,50 @@
from langchain_core.retrievers import BaseRetriever
from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field
from typing import Any, Dict
from common.extensions import db
from flask import current_app
from config.logging_config import LOGGING
class EveAIRetriever(BaseRetriever):
    """Retriever that looks up stored embeddings closest to a query.

    The behaviour is driven entirely by the ``model_variables`` dictionary,
    from which the following keys are read:

    * ``embedding_model``    - object exposing ``embed_query(text)``.
    * ``embedding_db_model`` - mapped class with an ``embedding`` column
      supporting ``cosine_distance`` and with ``id`` / ``chunk`` attributes.
    * ``similarity_threshold`` - upper bound applied to the cosine *distance*.
    * ``k`` - maximum number of rows to return.
    """

    model_variables: Dict[str, Any] = Field(...)

    def __init__(self, model_variables: Dict[str, Any]):
        """Create the retriever for one set of tenant model variables."""
        current_app.logger.debug('Initializing EveAIRetriever')
        # ``model_variables`` is declared as a required field, so it has to be
        # handed to the base-class constructor; calling it without the value
        # and assigning the attribute afterwards fails field validation.
        super().__init__(model_variables=model_variables)
        current_app.logger.debug('EveAIRetriever initialized')

    def _get_relevant_documents(self, query: str):
        """Return the stored chunks whose embedding is closest to ``query``.

        Returns the rows produced by the query, i.e. ``(db_class instance,
        distance)`` pairs ordered by ascending distance, or an empty list when
        the database query fails.
        NOTE(review): these are result rows, not langchain ``Document``
        objects - confirm this is what callers of the retriever expect.
        """
        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        # Build the distance expression once and reuse it for the selected
        # column and for the filter.
        distance = db_class.embedding.cosine_distance(query_embedding)

        try:
            res = (
                db.session.query(db_class, distance.label('distance'))
                # The threshold is compared against a distance: only rows
                # strictly closer than the threshold are kept.
                .filter(distance < similarity_threshold)
                .order_by('distance')
                .limit(k)
                .all()
            )
            current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
            current_app.rag_tuning_logger.debug('---------------------------------------')
            for doc in res:
                current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
                current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
        except SQLAlchemyError as e:
            # Reset the session so that later queries on it are not stuck in
            # the failed transaction.
            db.session.rollback()
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            return []
        return res

    def _get_query_embedding(self, query: str):
        """Embed ``query`` with the configured ``embedding_model``."""
        embedding_model = self.model_variables['embedding_model']
        return embedding_model.embed_query(query)

View File

@@ -17,50 +17,22 @@ class Document(db.Model):
updated_by = db.Column(db.Integer, db.ForeignKey(User.id)) updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
# Relations # Relations
languages = db.relationship('DocumentLanguage', backref='document', lazy=True) versions = db.relationship('DocumentVersion', backref='document', lazy=True)
def __repr__(self): def __repr__(self):
return f"<Document {self.id}: {self.name}>" return f"<Document {self.id}: {self.name}>"
class DocumentLanguage(db.Model):
id = db.Column(db.Integer, primary_key=True)
document_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False)
language = db.Column(db.String(2), nullable=False)
user_context = db.Column(db.Text, nullable=True)
system_context = db.Column(db.Text, nullable=True)
latest_version_id = db.Column(db.Integer, db.ForeignKey('document_version.id'), nullable=True)
# Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
# Relations
versions = db.relationship(
'DocumentVersion',
backref='document_language',
lazy='joined',
foreign_keys='DocumentVersion.doc_lang_id'
)
latest_version = db.relationship(
'DocumentVersion',
uselist=False,
foreign_keys=[latest_version_id]
)
def __repr__(self):
return f"<DocumentLanguage {self.document_id}.{self.language}>"
class DocumentVersion(db.Model): class DocumentVersion(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
doc_lang_id = db.Column(db.Integer, db.ForeignKey(DocumentLanguage.id), nullable=False) doc_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False)
url = db.Column(db.String(200), nullable=True) url = db.Column(db.String(200), nullable=True)
file_location = db.Column(db.String(255), nullable=True) file_location = db.Column(db.String(255), nullable=True)
file_name = db.Column(db.String(200), nullable=True) file_name = db.Column(db.String(200), nullable=True)
file_type = db.Column(db.String(20), nullable=True) file_type = db.Column(db.String(20), nullable=True)
language = db.Column(db.String(2), nullable=False)
user_context = db.Column(db.Text, nullable=True)
system_context = db.Column(db.Text, nullable=True)
# Versioning Information # Versioning Information
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
@@ -81,7 +53,7 @@ class DocumentVersion(db.Model):
return f"<DocumentVersion {self.document_language.document_id}.{self.document_language.language}>.{self.id}>" return f"<DocumentVersion {self.document_language.document_id}.{self.document_language.language}>.{self.id}>"
def calc_file_location(self): def calc_file_location(self):
return f"{self.document_language.document.tenant_id}/{self.document_language.document.id}/{self.document_language.language}" return f"{self.document.tenant_id}/{self.document.id}/{self.language}"
def calc_file_name(self): def calc_file_name(self):
return f"{self.id}.{self.file_type}" return f"{self.id}.{self.file_type}"

View File

@@ -48,6 +48,10 @@ class Tenant(db.Model):
allowed_monthly_interactions = db.Column(db.Integer, nullable=True) allowed_monthly_interactions = db.Column(db.Integer, nullable=True)
encrypted_chat_api_key = db.Column(db.String(500), nullable=True) encrypted_chat_api_key = db.Column(db.String(500), nullable=True)
# Tuning enablers
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
rag_tuning = db.Column(db.Boolean, nullable=True, default=False)
# Relations # Relations
users = db.relationship('User', backref='tenant') users = db.relationship('User', backref='tenant')
domains = db.relationship('TenantDomain', backref='tenant') domains = db.relationship('TenantDomain', backref='tenant')
@@ -133,7 +137,10 @@ class TenantDomain(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False) tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
domain = db.Column(db.String(255), unique=True, nullable=False) # Originally, domain was required to be unique.
# However, several tenants can run from the same domain (e.g. for demo purposes,
# but also for internal and external chat clients).
domain = db.Column(db.String(255), nullable=False)
valid_to = db.Column(db.Date, nullable=True) valid_to = db.Column(db.Date, nullable=True)
# Versioning Information # Versioning Information

View File

@@ -2,6 +2,7 @@ from flask import current_app
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import ast
from common.models.document import EmbeddingSmallOpenAI from common.models.document import EmbeddingSmallOpenAI
@@ -35,6 +36,23 @@ def select_model_variables(tenant):
else: else:
model_variables['no_RAG_temperature'] = 0.5 model_variables['no_RAG_temperature'] = 0.5
# Set Tuning variables
if tenant.embed_tuning:
model_variables['embed_tuning'] = tenant.embed_tuning
else:
model_variables['embed_tuning'] = False
if tenant.rag_tuning:
model_variables['rag_tuning'] = tenant.rag_tuning
else:
model_variables['rag_tuning'] = False
# Set HTML Chunking Variables
model_variables['html_tags'] = tenant.html_tags
model_variables['html_end_tags'] = tenant.html_end_tags
model_variables['html_included_elements'] = tenant.html_included_elements
model_variables['html_excluded_elements'] = tenant.html_excluded_elements
# Set Embedding variables # Set Embedding variables
match embedding_provider: match embedding_provider:
case 'openai': case 'openai':

View File

@@ -58,6 +58,22 @@ LOGGING = {
'backupCount': 10, 'backupCount': 10,
'formatter': 'standard', 'formatter': 'standard',
}, },
'file_rag_tuning': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': 'logs/rag_tuning.log',
'maxBytes': 1024*1024*5, # 5MB
'backupCount': 10,
'formatter': 'standard',
},
'file_embed_tuning': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': 'logs/rag_tuning.log',
'maxBytes': 1024*1024*5, # 5MB
'backupCount': 10,
'formatter': 'standard',
},
'console': { 'console': {
'class': 'logging.StreamHandler', 'class': 'logging.StreamHandler',
'level': 'DEBUG', 'level': 'DEBUG',
@@ -71,27 +87,27 @@ LOGGING = {
}, },
'loggers': { 'loggers': {
'eveai_app': { # logger for the eveai_app 'eveai_app': { # logger for the eveai_app
'handlers': ['file_app', 'console'], 'handlers': ['file_app',],
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'eveai_workers': { # logger for the eveai_workers 'eveai_workers': { # logger for the eveai_workers
'handlers': ['file_workers', 'console'], 'handlers': ['file_workers',],
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'eveai_chat': { # logger for the eveai_chat 'eveai_chat': { # logger for the eveai_chat
'handlers': ['file_chat', 'console'], 'handlers': ['file_chat',],
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'eveai_chat_workers': { # logger for the eveai_chat_workers 'eveai_chat_workers': { # logger for the eveai_chat_workers
'handlers': ['file_chat_workers', 'console'], 'handlers': ['file_chat_workers',],
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'sqlalchemy.engine': { # logger for the sqlalchemy 'sqlalchemy.engine': { # logger for the sqlalchemy
'handlers': ['file_sqlalchemy', 'console'], 'handlers': ['file_sqlalchemy',],
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
@@ -105,10 +121,20 @@ LOGGING = {
'level': 'DEBUG', 'level': 'DEBUG',
'propagate': False 'propagate': False
}, },
'rag_tuning': { # logger for the rag_tuning
'handlers': ['file_rag_tuning', 'console'],
'level': 'DEBUG',
'propagate': False
},
'embed_tuning': { # logger for the embed_tuning
'handlers': ['file_embed_tuning', 'console'],
'level': 'DEBUG',
'propagate': False
},
'': { # root logger '': { # root logger
'handlers': ['console'], 'handlers': ['console'],
'level': 'WARNING', # Set higher level for root to minimize noise 'level': 'WARNING', # Set higher level for root to minimize noise
'propagate': False 'propagate': False
} },
} }
} }

View File

@@ -1,24 +0,0 @@
{% extends 'base.html' %}
{% from 'macros.html' import render_selectable_table, render_pagination %}
{% block title %}Document Languages{% endblock %}
{% block content_title %}Document Languages{% endblock %}
{% block content_description %}View Languages for {{ document }}{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto">{% endblock %}
{% block content %}
<div class="container">
<form method="POST" action="{{ url_for('document_bp.handle_document_language_selection') }}">
{{ render_selectable_table(headers=["Document Language ID", "Language", "User Context", "System Context"], rows=rows, selectable=True, id="documentsTable") }}
<div class="form-group mt-3">
<button type="submit" name="action" value="edit_document_language" class="btn btn-primary">Edit Document Language</button>
<button type="submit" name="action" value="document_versions" class="btn btn-secondary">Show Document Versions</button>
</div>
</form>
</div>
{% endblock %}
{% block content_footer %}
{{ render_pagination(pagination, 'document_bp.documents') }}
{% endblock %}

View File

@@ -12,7 +12,8 @@
<form method="POST" action="{{ url_for('document_bp.handle_document_version_selection') }}"> <form method="POST" action="{{ url_for('document_bp.handle_document_version_selection') }}">
{{ render_selectable_table(headers=["Document Version ID", "URL", "File Location", "File Name", "File Type", "Processing", "Processing Start", "Proceeing Finish"], rows=rows, selectable=True, id="versionsTable") }} {{ render_selectable_table(headers=["Document Version ID", "URL", "File Location", "File Name", "File Type", "Processing", "Processing Start", "Proceeing Finish"], rows=rows, selectable=True, id="versionsTable") }}
<div class="form-group mt-3"> <div class="form-group mt-3">
<button type="submit" name="action" value="process_document_version" class="btn btn-primary">Process Document Version</button> <button type="submit" name="action" value="edit_document_version" class="btn btn-primary">Edit Document Version</button>
<button type="submit" name="action" value="process_document_version" class="btn btn-secondary">Process Document Version</button>
</div> </div>
</form> </form>
</div> </div>

View File

@@ -13,7 +13,7 @@
{{ render_selectable_table(headers=["Document ID", "Name", "Valid From", "Valid To"], rows=rows, selectable=True, id="documentsTable") }} {{ render_selectable_table(headers=["Document ID", "Name", "Valid From", "Valid To"], rows=rows, selectable=True, id="documentsTable") }}
<div class="form-group mt-3"> <div class="form-group mt-3">
<button type="submit" name="action" value="edit_document" class="btn btn-primary">Edit Document</button> <button type="submit" name="action" value="edit_document" class="btn btn-primary">Edit Document</button>
<button type="submit" name="action" value="document_languages" class="btn btn-secondary">Show Document Languages</button> <button type="submit" name="action" value="document_versions" class="btn btn-secondary">Show Document Versions</button>
</div> </div>
</form> </form>
</div> </div>

View File

@@ -72,6 +72,7 @@
{'name': 'Tenant List', 'url': '/user/select_tenant', 'roles': ['Super User']}, {'name': 'Tenant List', 'url': '/user/select_tenant', 'roles': ['Super User']},
{'name': 'Tenant Registration', 'url': '/user/tenant', 'roles': ['Super User']}, {'name': 'Tenant Registration', 'url': '/user/tenant', 'roles': ['Super User']},
{'name': 'Tenant Overview', 'url': '/user/tenant_overview', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Tenant Overview', 'url': '/user/tenant_overview', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Edit Tenant', 'url': '/user/tenant/' ~ session['tenant'].get('id'), 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Tenant Domains', 'url': '/user/view_tenant_domains', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Tenant Domains', 'url': '/user/view_tenant_domains', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Tenant Domain Registration', 'url': '/user/tenant_domain', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Tenant Domain Registration', 'url': '/user/tenant_domain', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'User List', 'url': '/user/view_users', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'User List', 'url': '/user/view_users', 'roles': ['Super User', 'Tenant Admin']},

View File

@@ -40,6 +40,11 @@
Embedding Search Embedding Search
</a> </a>
</li> </li>
<li class="nav-item">
<a class="nav-link mb-0 px-0 py-1" data-toggle="tab" href="#tuning-tab" role="tab" aria-controls="html-chunking" aria-selected="false">
Tuning
</a>
</li>
</ul> </ul>
</div> </div>
<div class="tab-content tab-space"> <div class="tab-content tab-space">
@@ -73,6 +78,13 @@
{{ render_included_field(field, disabled_fields=es_fields, include_fields=es_fields) }} {{ render_included_field(field, disabled_fields=es_fields, include_fields=es_fields) }}
{% endfor %} {% endfor %}
</div> </div>
<!-- Tuning Settings Tab -->
<div class="tab-pane fade" id="tuning-tab" role="tabpanel">
{% set tuning_fields = ['embed_tuning', 'rag_tuning', ] %}
{% for field in form %}
{{ render_included_field(field, disabled_fields=tuning_fields, include_fields=tuning_fields) }}
{% endfor %}
</div>
</div> </div>
</div> </div>
</div> </div>

View File

@@ -47,7 +47,7 @@ class EditDocumentForm(FlaskForm):
submit = SubmitField('Submit') submit = SubmitField('Submit')
class EditDocumentLanguageForm(FlaskForm): class EditDocumentVersionForm(FlaskForm):
language = StringField('Language') language = StringField('Language')
user_context = TextAreaField('User Context', validators=[Optional()]) user_context = TextAreaField('User Context', validators=[Optional()])
system_context = TextAreaField('System Context', validators=[Optional()]) system_context = TextAreaField('System Context', validators=[Optional()])

View File

@@ -13,9 +13,9 @@ from requests.exceptions import SSLError
from urllib.parse import urlparse from urllib.parse import urlparse
import io import io
from common.models.document import Document, DocumentLanguage, DocumentVersion from common.models.document import Document, DocumentVersion
from common.extensions import db from common.extensions import db
from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentLanguageForm from .document_forms import AddDocumentForm, AddURLForm, EditDocumentForm, EditDocumentVersionForm
from common.utils.middleware import mw_before_request from common.utils.middleware import mw_before_request
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.nginx_utils import prefixed_url_for from common.utils.nginx_utils import prefixed_url_for
@@ -59,7 +59,7 @@ def add_document():
filename = secure_filename(file.filename) filename = secure_filename(file.filename)
extension = filename.rsplit('.', 1)[1].lower() extension = filename.rsplit('.', 1)[1].lower()
new_doc, new_doc_lang, new_doc_vers = create_document_stack(form, file, filename, extension) new_doc, new_doc_vers = create_document_stack(form, file, filename, extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'], session['tenant']['id'],
@@ -100,7 +100,7 @@ def add_url():
filename += '.html' filename += '.html'
extension = 'html' extension = 'html'
new_doc, new_doc_lang, new_doc_vers = create_document_stack(form, file, filename, extension) new_doc, new_doc_vers = create_document_stack(form, file, filename, extension)
task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
session['tenant']['id'], session['tenant']['id'],
@@ -146,8 +146,8 @@ def handle_document_selection():
match action: match action:
case 'edit_document': case 'edit_document':
return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id))
case 'document_languages': case 'document_versions':
return redirect(prefixed_url_for('document_bp.document_languages', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id))
# Add more conditions for other actions # Add more conditions for other actions
return redirect(prefixed_url_for('document_bp.documents')) return redirect(prefixed_url_for('document_bp.documents'))
@@ -180,81 +180,46 @@ def edit_document(document_id):
return render_template('document/edit_document.html', form=form, document_id=document_id) return render_template('document/edit_document.html', form=form, document_id=document_id)
@document_bp.route('/document_languages/<int:document_id>', methods=['GET', 'POST']) @document_bp.route('/edit_document_version/<int:document_version_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def document_languages(document_id): def edit_document_version(document_version_id):
doc = Document.query.get_or_404(document_id) doc_vers = DocumentVersion.query.get_or_404(document_version_id)
doc_desc = f'Document {doc.id}: {doc.name}' form = EditDocumentVersionForm(obj=doc_vers)
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
query = DocumentLanguage.query.filter_by(document_id=document_id).order_by(DocumentLanguage.language)
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
doc_langs = pagination.items
rows = prepare_table_for_macro(doc_langs, [('id', ''), ('language', ''), ('user_context', ''),
('system_context', '')])
return render_template('document/document_languages.html', rows=rows, pagination=pagination, document=doc_desc)
@document_bp.route('/handle_document_language_selection', methods=['POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_document_language_selection():
document_language_identification = request.form['selected_row']
doc_lang_id = ast.literal_eval(document_language_identification).get('value')
action = request.form['action']
match action:
case 'edit_document_language':
return redirect(prefixed_url_for('document_bp.edit_document_language', document_language_id=doc_lang_id))
case 'document_versions':
return redirect(prefixed_url_for('document_bp.document_versions', document_language_id=doc_lang_id))
# Add more conditions for other actions
return redirect(prefixed_url_for('document_bp.document_languages'))
@document_bp.route('/edit_document_language/<int:document_language_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def edit_document_language(document_language_id):
doc_lang = DocumentLanguage.query.get_or_404(document_language_id)
form = EditDocumentLanguageForm(obj=doc_lang)
if form.validate_on_submit(): if form.validate_on_submit():
doc_lang.user_context = form.user_context.data doc_vers.user_context = form.user_context.data
update_logging_information(doc_lang, dt.now(tz.utc)) update_logging_information(doc_vers, dt.now(tz.utc))
try: try:
db.session.add(doc_lang) db.session.add(doc_vers)
db.session.commit() db.session.commit()
flash(f'Document Language {doc_lang.id} updated successfully', 'success') flash(f'Document Version {doc_vers.id} updated successfully', 'success')
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.session.rollback() db.session.rollback()
flash(f'Error updating document language: {e}', 'danger') flash(f'Error updating document version: {e}', 'danger')
current_app.logger.error(f'Error updating document language {doc_lang.id} ' current_app.logger.error(f'Error updating document version {doc_vers.id} '
f'for tenant {session['tenant']['id']}: {e}') f'for tenant {session['tenant']['id']}: {e}')
else: else:
form_validation_failed(request, form) form_validation_failed(request, form)
return render_template('document/edit_document_language.html', form=form, document_langauge_id=document_language_id, return render_template('document/edit_document_version.html', form=form, document_version_id=document_version_id,
doc_details=f'Document {doc_lang.document.name}') doc_details=f'Document {doc_vers.document.name}')
@document_bp.route('/document_versions/<int:document_language_id>', methods=['GET', 'POST']) @document_bp.route('/document_versions/<int:document_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def document_versions(document_language_id): def document_versions(document_id):
flash(f'Processing documents is a long running process. Please be careful retriggering processing!', 'danger') flash(f'Processing documents is a long running process. Please be careful retriggering processing!', 'danger')
doc_lang = DocumentLanguage.query.get_or_404(document_language_id) doc_vers = DocumentVersion.query.get_or_404(document_id)
doc_desc = f'Document {doc_lang.document.name}, Language {doc_lang.language}' doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}'
page = request.args.get('page', 1, type=int) page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int) per_page = request.args.get('per_page', 10, type=int)
query = DocumentVersion.query.filter_by(doc_lang_id=document_language_id).order_by(desc(DocumentVersion.id)) query = (DocumentVersion.query.filter_by(doc_id=document_id)
.order_by(DocumentVersion.language)
.order_by(desc(DocumentVersion.id)))
pagination = query.paginate(page=page, per_page=per_page, error_out=False) pagination = query.paginate(page=page, per_page=per_page, error_out=False)
doc_langs = pagination.items doc_langs = pagination.items
@@ -276,6 +241,8 @@ def handle_document_version_selection():
action = request.form['action'] action = request.form['action']
match action: match action:
case 'edit_document_version':
return redirect(prefixed_url_for('document_bp.edit_document_version', document_version_id=doc_vers_id))
case 'process_document_version': case 'process_document_version':
process_version(doc_vers_id) process_version(doc_vers_id)
# Add more conditions for other actions # Add more conditions for other actions
@@ -315,17 +282,11 @@ def create_document_stack(form, file, filename, extension):
# Create the Document # Create the Document
new_doc = create_document(form, filename) new_doc = create_document(form, filename)
# Create the DocumentLanguage
new_doc_lang = create_language_for_document(new_doc, form.language.data, form.user_context.data)
# Create the DocumentVersion # Create the DocumentVersion
new_doc_vers = DocumentVersion() new_doc_vers = create_version_for_document(new_doc, form.language.data, form.user_context.data)
new_doc_vers.document_language = new_doc_lang
set_logging_information(new_doc_vers, dt.now(tz.utc))
try: try:
db.session.add(new_doc) db.session.add(new_doc)
db.session.add(new_doc_lang)
db.session.add(new_doc_vers) db.session.add(new_doc_vers)
db.session.commit() db.session.commit()
except SQLAlchemyError as e: except SQLAlchemyError as e:
@@ -338,30 +299,12 @@ def create_document_stack(form, file, filename, extension):
current_app.logger.error('Unknown error') current_app.logger.error('Unknown error')
raise raise
try:
new_doc_lang = db.session.merge(new_doc_lang)
new_doc_vers = db.session.merge(new_doc_vers)
new_doc_lang.latest_version_id = new_doc_vers.id
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {e}')
flash('Error adding document.', 'error')
db.session.rollback()
error = e.args
raise
except Exception as e:
current_app.logger.error(f'Error adding document for tenant {session["tenant"]["id"]}: {e}')
flash('Error adding document.', 'error')
db.session.rollback()
error = e.args
raise
current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, ' current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, '
f'Document Version {new_doc.id}') f'Document Version {new_doc.id}')
upload_file_for_version(new_doc_vers, file, extension) upload_file_for_version(new_doc_vers, file, extension)
return new_doc, new_doc_lang, new_doc_vers return new_doc, new_doc_vers
def log_session_state(session, msg=""): def log_session_state(session, msg=""):
@@ -386,21 +329,21 @@ def create_document(form, filename):
return new_doc return new_doc
def create_language_for_document(document, language, user_context): def create_version_for_document(document, language, user_context):
new_doc_lang = DocumentLanguage() new_doc_vers = DocumentVersion()
if language == '': if language == '':
new_doc_lang.language = session['default_language'] new_doc_vers.language = session['default_language']
else: else:
new_doc_lang.language = language new_doc_vers.language = language
if user_context != '': if user_context != '':
new_doc_lang.user_context = user_context new_doc_vers.user_context = user_context
new_doc_lang.document = document new_doc_vers.document = document
set_logging_information(new_doc_lang, dt.now(tz.utc)) set_logging_information(new_doc_vers, dt.now(tz.utc))
return new_doc_lang return new_doc_vers
def upload_file_for_version(doc_vers, file, extension): def upload_file_for_version(doc_vers, file, extension):

View File

@@ -33,6 +33,9 @@ class TenantForm(FlaskForm):
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)', es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5, default=0.5,
validators=[NumberRange(min=0, max=1)]) validators=[NumberRange(min=0, max=1)])
# Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
submit = SubmitField('Submit') submit = SubmitField('Submit')

View File

@@ -53,7 +53,9 @@ def tenant():
llm_model=form.llm_model.data, llm_model=form.llm_model.data,
license_start_date=form.license_start_date.data, license_start_date=form.license_start_date.data,
license_end_date=form.license_end_date.data, license_end_date=form.license_end_date.data,
allowed_monthly_interactions=form.allowed_monthly_interactions.data) allowed_monthly_interactions=form.allowed_monthly_interactions.data,
embed_tuning=form.embed_tuning.data,
rag_tuning=form.rag_tuning.data)
# Handle Embedding Variables # Handle Embedding Variables
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else [] new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []

View File

@@ -23,6 +23,8 @@ def create_app(config_file=None):
celery = make_celery(app.name, app.config) celery = make_celery(app.name, app.config)
init_celery(celery, app) init_celery(celery, app)
app.rag_tuning_logger = logging.getLogger('rag_tuning')
from eveai_chat_workers import tasks from eveai_chat_workers import tasks
print(tasks.tasks_ping()) print(tasks.tasks_ping())

View File

@@ -1,15 +1,12 @@
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
from flask import current_app from flask import current_app
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from celery import states from celery import states
from celery.exceptions import Ignore from celery.exceptions import Ignore
import os import os
# Unstructured commercial client imports
from unstructured_client import UnstructuredClient
from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError
# OpenAI imports # OpenAI imports
from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
@@ -22,14 +19,14 @@ from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingS
from common.models.user import Tenant from common.models.user import Tenant
from common.extensions import db from common.extensions import db
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables
from bs4 import BeautifulSoup from common.langchain.EveAIRetriever import EveAIRetriever
@current_celery.task(name='ask_question', queue='llm_interactions') @current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question): def ask_question(tenant_id, question):
current_app.logger.debug('In ask_question') current_app.logger.debug('In ask_question')
current_app.logger.debug(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')
try: try:
# Retrieve the tenant # Retrieve the tenant
@@ -43,115 +40,17 @@ def ask_question(tenant_id, question):
# Select variables to work with depending on tenant model # Select variables to work with depending on tenant model
model_variables = select_model_variables(tenant) model_variables = select_model_variables(tenant)
# create embedding for the query current_app.logger.debug(f'ask_question: model_variables: {model_variables}')
embedded_question = create_embedding(model_variables, question)
retriever = EveAIRetriever(model_variables)
# Search the database for relevant embeddings # Search the database for relevant embeddings
relevant_embeddings = search_embeddings(model_variables, embedded_question) relevant_embeddings = retriever.invoke(question)
response = "" return 'No response yet, check back later.'
for embed in relevant_embeddings:
response += relevant_embeddings.chunk + '\n'
return response
except Exception as e: except Exception as e:
current_app.logger.error(f'ask_question: Error processing question: {e}') current_app.logger.error(f'ask_question: Error processing question: {e}')
raise Ignore raise
def select_model_variables(tenant):
embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
llm_provider = tenant.llm_model.rsplit('.', 1)[0]
llm_model = tenant.llm_model.rsplit('.', 1)[1]
# Set model variables
model_variables = {}
if tenant.es_k:
model_variables['k'] = tenant.es_k
else:
model_variables['k'] = 5
if tenant.es_similarity_threshold:
model_variables['similarity_threshold'] = tenant.es_similarity_threshold
else:
model_variables['similarity_threshold'] = 0.7
if tenant.chat_RAG_temperature:
model_variables['RAG_temperature'] = tenant.chat_RAG_temperature
else:
model_variables['RAG_temperature'] = 0.3
if tenant.chat_no_RAG_temperature:
model_variables['no_RAG_temperature'] = tenant.chat_no_RAG_temperature
else:
model_variables['no_RAG_temperature'] = 0.5
# Set Embedding variables
match embedding_provider:
case 'openai':
match embedding_model:
case 'text-embedding-3-small':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['embedding'] = OpenAIEmbeddings(api_key=api_key,
model='text-embedding-3-small')
model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding model')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid embedding provider')
# Set Chat model variables
match llm_provider:
case 'openai':
api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['RAG_temperature'])
match llm_model:
case 'gpt-4-turbo' | 'gpt-4-o':
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
case 'gpt-3-5-turbo':
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model')
model_variables['prompt'] = ChatPromptTemplate.from_template(rag_template)
case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat provider')
return model_variables
def create_embedding(model_variables, question):
    """Return the embedding vector for a single question string.

    Args:
        model_variables: dict whose ``'embedding'`` entry exposes
            ``embed_documents(list_of_texts)``.
        question: the question text to embed.

    Returns:
        The embedding produced for ``question`` (the first and only entry of
        the batch result).

    Raises:
        Exception: if the embedding backend raises ``LangChainException``.
    """
    try:
        # embed_documents works on a batch of texts. Passing the bare string
        # would make it iterate character by character and return the vector
        # of the first character only, so wrap the question in a list.
        embeddings = model_variables['embedding'].embed_documents([question])
    except LangChainException as e:
        raise Exception(f'Error creating embedding for question (LangChain): {e}') from e

    return embeddings[0]
def search_embeddings(model_variables, embedded_query):
    """Query the embedding table for rows close to ``embedded_query``.

    Uses ``model_variables['embedding_db_model']`` as the mapped class,
    keeps rows whose cosine distance to the query is below
    ``model_variables['similarity_threshold']``, orders them by that distance
    and returns at most ``model_variables['k']`` of them.

    Returns:
        The result of ``.all()``: one ``(row, distance)`` pair per match.

    Raises:
        Exception: wrapping any ``SQLAlchemyError`` raised by the query.
    """
    current_app.logger.debug(f'In search_embeddings searching for {embedded_query}')
    embedding_cls = model_variables['embedding_db_model']

    try:
        # Build the query step by step; nothing is sent until .all().
        ranked = db.session.query(
            embedding_cls,
            embedding_cls.embedding.cosine_distance(embedded_query).label('distance'),
        )
        ranked = ranked.filter(
            embedding_cls.embedding.cosine_distance(embedded_query) < model_variables['similarity_threshold']
        )
        ranked = ranked.order_by("distance")
        ranked = ranked.limit(model_variables['k'])
        matches = ranked.all()
    except SQLAlchemyError as e:
        raise Exception(f'Error searching embeddings (SQLAlchemy): {e}')

    current_app.logger.debug(f'Results from embedding search: {matches}')
    return matches
def tasks_ping(): def tasks_ping():

View File

@@ -16,6 +16,8 @@ def create_app(config_file=None):
app.config.from_object(config_file) app.config.from_object(config_file)
logging.config.dictConfig(LOGGING) logging.config.dictConfig(LOGGING)
app.embed_tuning_logger = logging.getLogger('embed_tuning')
register_extensions(app) register_extensions(app)
celery = make_celery(app.name, app.config) celery = make_celery(app.name, app.config)

View File

@@ -130,13 +130,11 @@ def process_pdf(tenant, model_variables, document_version):
raise raise
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
doc_lang = document_version.document_language document_version.system_context = f'Summary: {summary}\n'
doc_lang.system_context = f'Summary: {summary}\n'
enriched_chunks = enrich_chunks(tenant, document_version, chunks) enriched_chunks = enrich_chunks(tenant, document_version, chunks)
embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)
try: try:
db.session.add(doc_lang)
db.session.add(document_version) db.session.add(document_version)
document_version.processing_finished_at = dt.now(tz.utc) document_version.processing_finished_at = dt.now(tz.utc)
document_version.processing = False document_version.processing = False
@@ -156,10 +154,10 @@ def process_pdf(tenant, model_variables, document_version):
def process_html(tenant, model_variables, document_version): def process_html(tenant, model_variables, document_version):
# The tags to be considered can be dependent on the tenant # The tags to be considered can be dependent on the tenant
html_tags = tenant.html_tags html_tags = model_variables['html_tags']
end_tags = tenant.html_end_tags html_end_tags = model_variables['html_end_tags']
included_elements = tenant.html_included_elements html_included_elements = model_variables['html_included_elements']
excluded_elements = tenant.html_excluded_elements html_excluded_elements = model_variables['html_excluded_elements']
file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], file_path = os.path.join(current_app.config['UPLOAD_FOLDER'],
document_version.file_location, document_version.file_location,
@@ -175,23 +173,25 @@ def process_html(tenant, model_variables, document_version):
create_embeddings.update_state(state=states.FAILURE) create_embeddings.update_state(state=states.FAILURE)
raise raise
extracted_data, title = parse_html(html_content, html_tags, included_elements=included_elements, extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements,
excluded_elements=excluded_elements) excluded_elements=html_excluded_elements)
potential_chunks = create_potential_chunks(extracted_data, end_tags) potential_chunks = create_potential_chunks(extracted_data, html_end_tags)
chunks = combine_chunks(potential_chunks, chunks = combine_chunks(potential_chunks,
model_variables['min_chunk_size'], model_variables['min_chunk_size'],
model_variables['max_chunk_size'] model_variables['max_chunk_size']
) )
if len(chunks) > 0:
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
doc_lang = document_version.document_language document_version.system_context = (f'Title: {title}\n'
doc_lang.system_context = (f'Title: {title}\n'
f'Summary: {summary}\n') f'Summary: {summary}\n')
else:
document_version.system_context = (f'Title: {title}\n')
enriched_chunks = enrich_chunks(tenant, document_version, chunks) enriched_chunks = enrich_chunks(tenant, document_version, chunks)
embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks)
try: try:
db.session.add(doc_lang)
db.session.add(document_version) db.session.add(document_version)
document_version.processing_finished_at = dt.now(tz.utc) document_version.processing_finished_at = dt.now(tz.utc)
document_version.processing = False document_version.processing = False
@@ -210,12 +210,14 @@ def process_html(tenant, model_variables, document_version):
def enrich_chunks(tenant, document_version, chunks): def enrich_chunks(tenant, document_version, chunks):
current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} ' current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
f'on document version {document_version.id}') f'on document version {document_version.id}')
doc_lang = document_version.document_language
chunk_total_context = (f'Filename: {document_version.file_name}\n' chunk_total_context = (f'Filename: {document_version.file_name}\n'
f'{doc_lang.system_context}\n' f'User Context:{document_version.user_context}\n'
f'User Context:\n{doc_lang.user_context}') f'{document_version.system_context}\n\n')
enriched_chunks = [] enriched_chunks = []
initial_chunk = f'Filename: {document_version.file_name}\n User Context:\n{doc_lang.user_context}\n{chunks[0]}' initial_chunk = (f'Filename: {document_version.file_name}\n'
f'User Context:\n{document_version.user_context}\n\n'
f'{chunks[0]}')
enriched_chunks.append(initial_chunk) enriched_chunks.append(initial_chunk)
for chunk in chunks[1:]: for chunk in chunks[1:]:
enriched_chunk = f'{chunk_total_context}\n{chunk}' enriched_chunk = f'{chunk_total_context}\n{chunk}'
@@ -313,6 +315,12 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
else: else:
elements_to_parse = [soup] # parse the entire document if no included_elements specified elements_to_parse = [soup] # parse the entire document if no included_elements specified
current_app.embed_tuning_logger.debug(f'Included Elements: {included_elements}')
current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}')
current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
current_app.embed_tuning_logger.debug(f'{elements_to_parse}')
# Iterate through the found included elements # Iterate through the found included elements
for element in elements_to_parse: for element in elements_to_parse:
# Find all specified tags within each included element # Find all specified tags within each included element

View File

@@ -0,0 +1,58 @@
"""Removing DocumentLanguage (user_context and system_context correct on DocumentVersion)
Revision ID: 217938792642
Revises: 53fee8f13cdb
Create Date: 2024-06-06 14:16:03.288734
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '217938792642'
down_revision = '53fee8f13cdb'
branch_labels = None
depends_on = None
def upgrade():
    """Fold document_language into document_version.

    Adds ``doc_id``, ``language``, ``user_context`` and ``system_context`` to
    ``document_version``, copies their values from the ``document_language``
    row each version pointed at, then removes ``doc_lang_id`` and drops
    ``document_language``.
    """
    # Add the new columns as nullable first: adding a NOT NULL column without
    # a default fails as soon as document_version already contains rows.
    op.add_column('document_version', sa.Column('doc_id', sa.Integer(), nullable=True))
    op.add_column('document_version', sa.Column('language', sa.String(length=2), nullable=True))
    op.add_column('document_version', sa.Column('user_context', sa.Text(), nullable=True))
    op.add_column('document_version', sa.Column('system_context', sa.Text(), nullable=True))

    # Carry the data over before document_language disappears; otherwise the
    # document link, language and both contexts would be lost.
    op.execute(
        """
        UPDATE document_version AS dv
        SET doc_id = dl.document_id,
            language = dl.language,
            user_context = dl.user_context,
            system_context = dl.system_context
        FROM document_language AS dl
        WHERE dv.doc_lang_id = dl.id
        """
    )

    # Every row is populated now, so the final constraints can be enforced.
    op.alter_column('document_version', 'doc_id', existing_type=sa.Integer(), nullable=False)
    op.alter_column('document_version', 'language', existing_type=sa.String(length=2), nullable=False)
    # Name the constraint explicitly so that it can be dropped by name later.
    op.create_foreign_key('document_version_doc_id_fkey', 'document_version', 'document', ['doc_id'], ['id'])

    # Finally remove the old link and the table it pointed to.
    op.drop_constraint('document_version_doc_lang_id_fkey', 'document_version', type_='foreignkey')
    op.drop_column('document_version', 'doc_lang_id')
    op.drop_table('document_language')
def downgrade():
    """Restore the document_language table and the doc_lang_id link.

    Recreates ``document_language``, re-adds ``document_version.doc_lang_id``
    with its foreign key, and removes the columns added by ``upgrade``.
    """
    # document_language must exist before a foreign key can reference it, so
    # it is recreated first.
    op.create_table('document_language',
    sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
    sa.Column('document_id', sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column('language', sa.VARCHAR(length=2), autoincrement=False, nullable=False),
    sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False),
    sa.Column('created_by', sa.INTEGER(), autoincrement=False, nullable=False),
    sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False),
    sa.Column('updated_by', sa.INTEGER(), autoincrement=False, nullable=True),
    sa.Column('latest_version_id', sa.INTEGER(), autoincrement=False, nullable=True),
    sa.Column('user_context', sa.TEXT(), autoincrement=False, nullable=True),
    sa.Column('system_context', sa.TEXT(), autoincrement=False, nullable=True),
    sa.ForeignKeyConstraint(['created_by'], ['public.user.id'], name='document_language_created_by_fkey'),
    sa.ForeignKeyConstraint(['document_id'], ['document.id'], name='document_language_document_id_fkey'),
    sa.ForeignKeyConstraint(['latest_version_id'], ['document_version.id'], name='document_language_latest_version_id_fkey'),
    sa.ForeignKeyConstraint(['updated_by'], ['public.user.id'], name='document_language_updated_by_fkey'),
    sa.PrimaryKeyConstraint('id', name='document_language_pkey')
    )

    # NOTE(review): a NOT NULL column without a default can only be added
    # while document_version is empty; populated tables need a backfill of
    # document_language rows first.
    op.add_column('document_version', sa.Column('doc_lang_id', sa.INTEGER(), autoincrement=False, nullable=False))
    op.create_foreign_key('document_version_doc_lang_id_fkey', 'document_version', 'document_language', ['doc_lang_id'], ['id'])

    # A constraint can only be dropped by name. This is the name PostgreSQL
    # assigns by default to the doc_id foreign key created in upgrade().
    op.drop_constraint('document_version_doc_id_fkey', 'document_version', type_='foreignkey')
    op.drop_column('document_version', 'system_context')
    op.drop_column('document_version', 'user_context')
    op.drop_column('document_version', 'language')
    op.drop_column('document_version', 'doc_id')

25
public/chat_ae.html Normal file
View File

@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html lang="en">
<!-- Minimal host page: loads the EveAI SDK and chat widget scripts and mounts the chat in #chat-container. -->
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Chat Client AE</title>
<link rel="stylesheet" href="/static/css/eveai-chat-style.css">
<!-- Socket.IO client is loaded from the public CDN; the two EveAI scripts are deferred. -->
<script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script>
<script src="/static/js/eveai-sdk.js" defer></script>
<script src="/static/js/eveai-chat-widget.js" defer></script>
</head>
<body>
<!-- Element whose id is passed to initializeChat below. -->
<div id="chat-container"></div>
<script>
// Wait for DOMContentLoaded so the deferred SDK scripts have run before EveAI is used.
document.addEventListener('DOMContentLoaded', function() {
// NOTE(review): the three arguments look like a tenant id, an API key and the
// server base URL - confirm against the EveAI constructor in eveai-sdk.js.
// NOTE(review): these values are hard-coded in a file under public/ and the
// URL is plain http pointing at a local host name; verify this page is meant
// for local testing only.
const eveAI = new EveAI(
'39',
'EveAI-CHAT-6919-1265-9848-6655-9870',
'http://macstudio.ask-eve-ai-local.com'
);
eveAI.initializeChat('chat-container');
});
</script>
</body>
</html>

View File

@@ -13,3 +13,4 @@ beautifulsoup4~=4.12.3
google~=3.0.0 google~=3.0.0
redis~=5.0.4 redis~=5.0.4
itsdangerous~=2.2.0 itsdangerous~=2.2.0
pydantic~=2.7.1