Improving chat functionality significantly throughout the application.

This commit is contained in:
Josako
2024-06-12 11:07:18 +02:00
parent 27b6de8734
commit be311c440b
22 changed files with 604 additions and 127 deletions

View File

@@ -1,11 +1,13 @@
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from sqlalchemy import func, and_, or_
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict from typing import Any, Dict
from flask import current_app
from datetime import date
from common.extensions import db from common.extensions import db
from flask import current_app from common.models.document import Document, DocumentVersion, Embedding
from config.logging_config import LOGGING
class EveAIRetriever(BaseRetriever): class EveAIRetriever(BaseRetriever):
@@ -23,26 +25,53 @@ class EveAIRetriever(BaseRetriever):
db_class = self.model_variables['embedding_db_model'] db_class = self.model_variables['embedding_db_model']
similarity_threshold = self.model_variables['similarity_threshold'] similarity_threshold = self.model_variables['similarity_threshold']
k = self.model_variables['k'] k = self.model_variables['k']
try: try:
res = ( current_date = date.today()
# Subquery to find the latest version of each document
subquery = (
db.session.query(
DocumentVersion.doc_id,
func.max(DocumentVersion.id).label('latest_version_id')
)
.group_by(DocumentVersion.doc_id)
.subquery()
)
# Main query to filter embeddings
query_obj = (
db.session.query(db_class, db.session.query(db_class,
db_class.embedding.cosine_distance(query_embedding) db_class.embedding.cosine_distance(query_embedding).label('distance'))
.label('distance')) .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
.filter(db_class.embedding.cosine_distance(query_embedding) < similarity_threshold) .join(Document, DocumentVersion.doc_id == Document.id)
.join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
.filter(
or_(Document.valid_from.is_(None), Document.valid_from <= current_date),
or_(Document.valid_to.is_(None), Document.valid_to >= current_date),
db_class.embedding.cosine_distance(query_embedding) < similarity_threshold
)
.order_by('distance') .order_by('distance')
.limit(k) .limit(k)
.all()
) )
current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
current_app.rag_tuning_logger.debug(f'---------------------------------------') # Print the generated SQL statement for debugging
current_app.logger.debug("SQL Statement:\n")
current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True}))
res = query_obj.all()
# current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents')
# current_app.rag_tuning_logger.debug(f'---------------------------------------')
result = []
for doc in res: for doc in res:
current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') # current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') # current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')
except SQLAlchemyError as e: except SQLAlchemyError as e:
current_app.logger.error(f'Error retrieving relevant documents: {e}') current_app.logger.error(f'Error retrieving relevant documents: {e}')
db.session.rollback()
return [] return []
return res return result
def _get_query_embedding(self, query: str): def _get_query_embedding(self, query: str):
embedding_model = self.model_variables['embedding_model'] embedding_model = self.model_variables['embedding_model']

View File

@@ -6,6 +6,7 @@ from .document import Embedding
class ChatSession(db.Model): class ChatSession(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True) user_id = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
session_id = db.Column(db.String(36), nullable=True)
session_start = db.Column(db.DateTime, nullable=False) session_start = db.Column(db.DateTime, nullable=False)
session_end = db.Column(db.DateTime, nullable=True) session_end = db.Column(db.DateTime, nullable=True)
@@ -21,6 +22,7 @@ class Interaction(db.Model):
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False) chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
question = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True)
algorithm_used = db.Column(db.String(20), nullable=True)
language = db.Column(db.String(2), nullable=False) language = db.Column(db.String(2), nullable=False)
appreciation = db.Column(db.Integer, nullable=True, default=100) appreciation = db.Column(db.Integer, nullable=True, default=100)
@@ -28,6 +30,9 @@ class Interaction(db.Model):
question_at = db.Column(db.DateTime, nullable=False) question_at = db.Column(db.DateTime, nullable=False)
answer_at = db.Column(db.DateTime, nullable=True) answer_at = db.Column(db.DateTime, nullable=True)
# Relations
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
def __repr__(self): def __repr__(self):
return f"<Interaction {self.id}>" return f"<Interaction {self.id}>"

View File

@@ -19,6 +19,7 @@ class Tenant(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(80), unique=True, nullable=False) name = db.Column(db.String(80), unique=True, nullable=False)
website = db.Column(db.String(255), nullable=True) website = db.Column(db.String(255), nullable=True)
timezone = db.Column(db.String(50), nullable=True, default='UTC')
# language information # language information
default_language = db.Column(db.String(2), nullable=True) default_language = db.Column(db.String(2), nullable=True)
@@ -70,7 +71,9 @@ class Tenant(db.Model):
'llm_model': self.llm_model, 'llm_model': self.llm_model,
'license_start_date': self.license_start_date, 'license_start_date': self.license_start_date,
'license_end_date': self.license_end_date, 'license_end_date': self.license_end_date,
'allowed_monthly_interactions': self.allowed_monthly_interactions 'allowed_monthly_interactions': self.allowed_monthly_interactions,
'embed_tuning': self.embed_tuning,
'rag_tuning': self.rag_tuning,
} }

View File

@@ -1,12 +1,32 @@
import langcodes
from flask import current_app from flask import current_app
from langchain.embeddings import OpenAIEmbeddings from langchain_community.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
import ast import ast
from typing import List
from common.models.document import EmbeddingSmallOpenAI from common.models.document import EmbeddingSmallOpenAI
# Structured-output schema for an LLM answer together with the sources it used.
# NOTE: the class docstring is a runtime placeholder rather than documentation;
# set_language_prompt_template() overwrites ``__doc__`` with the real prompt,
# so its text must not be edited here.
class CitedAnswer(BaseModel):
    """Default docstring - to be replaced with actual prompt"""

    # Required: the generated answer text.
    answer: str = Field(
        default=...,
        description="The answer to the user question, based on the given sources",
    )

    # Required: integer IDs of the sources the answer relies on.
    citations: List[int] = Field(
        default=...,
        description="The integer IDs of the SPECIFIC sources that were used to generate the answer",
    )
def set_language_prompt_template(cls, language_prompt):
    """Replace the docstring of ``cls`` with ``language_prompt``.

    Args:
        cls: The class whose ``__doc__`` is overwritten (mutated in place).
        language_prompt: The text to install as the new docstring.

    Returns:
        None.
    """
    # NOTE(review): this mutates the class object itself, so the change is
    # visible to every user of ``cls`` in the process - confirm that is intended.
    setattr(cls, '__doc__', language_prompt)
def select_model_variables(tenant): def select_model_variables(tenant):
embedding_provider = tenant.embedding_model.rsplit('.', 1)[0] embedding_provider = tenant.embedding_model.rsplit('.', 1)[0]
embedding_model = tenant.embedding_model.rsplit('.', 1)[1] embedding_model = tenant.embedding_model.rsplit('.', 1)[1]
@@ -60,7 +80,7 @@ def select_model_variables(tenant):
case 'text-embedding-3-small': case 'text-embedding-3-small':
api_key = current_app.config.get('OPENAI_API_KEY') api_key = current_app.config.get('OPENAI_API_KEY')
model_variables['embedding_model'] = OpenAIEmbeddings(api_key=api_key, model_variables['embedding_model'] = OpenAIEmbeddings(api_key=api_key,
model='text-embedding-3-small') model='text-embedding-3-small')
model_variables['embedding_db_model'] = EmbeddingSmallOpenAI model_variables['embedding_db_model'] = EmbeddingSmallOpenAI
model_variables['min_chunk_size'] = current_app.config.get('OAI_TE3S_MIN_CHUNK_SIZE') model_variables['min_chunk_size'] = current_app.config.get('OAI_TE3S_MIN_CHUNK_SIZE')
model_variables['max_chunk_size'] = current_app.config.get('OAI_TE3S_MAX_CHUNK_SIZE') model_variables['max_chunk_size'] = current_app.config.get('OAI_TE3S_MAX_CHUNK_SIZE')
@@ -78,20 +98,34 @@ def select_model_variables(tenant):
model_variables['llm'] = ChatOpenAI(api_key=api_key, model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model, model=llm_model,
temperature=model_variables['RAG_temperature']) temperature=model_variables['RAG_temperature'])
tool_calling_supported = False
match llm_model: match llm_model:
case 'gpt-4-turbo' | 'gpt-4o': case 'gpt-4-turbo' | 'gpt-4o':
summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE') summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE')
rag_template = current_app.config.get('GPT4_RAG_TEMPLATE') rag_template = current_app.config.get('GPT4_RAG_TEMPLATE')
tool_calling_supported = True
case 'gpt-3-5-turbo': case 'gpt-3-5-turbo':
summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE') summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE')
rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE') rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE')
case _: case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} ' raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat model') f'error: Invalid chat model')
model_variables['summary_prompt'] = ChatPromptTemplate.from_template(summary_template) model_variables['summary_template'] = summary_template
model_variables['rag_prompt'] = ChatPromptTemplate.from_template(rag_template) model_variables['rag_template'] = rag_template
if tool_calling_supported:
model_variables['cited_answer_cls'] = CitedAnswer
case _: case _:
raise Exception(f'Error setting model variables for tenant {tenant.id} ' raise Exception(f'Error setting model variables for tenant {tenant.id} '
f'error: Invalid chat provider') f'error: Invalid chat provider')
return model_variables return model_variables
def create_language_template(template, language):
    """Fill the ``{language}`` placeholder of a prompt template.

    The language code is turned into the display name reported by
    ``langcodes``; if that lookup raises ``ValueError`` the raw code is
    inserted instead.

    Args:
        template: Prompt text containing zero or more ``{language}`` markers.
        language: Language code to resolve (e.g. a two-letter code).

    Returns:
        The template with every ``{language}`` marker substituted.
    """
    try:
        language_name = langcodes.Language.make(language=language).display_name()
    except ValueError:
        # Fall back to the bare code so the prompt is still usable.
        language_name = language
    return template.replace('{language}', language_name)

View File

@@ -73,26 +73,24 @@ class Config(object):
OAI_TE3S_MAX_CHUNK_SIZE = 3000 OAI_TE3S_MAX_CHUNK_SIZE = 3000
# LLM TEMPLATES # LLM TEMPLATES
GPT4_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text. GPT4_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
Text is delimited between triple backquotes.
```{text}```""" ```{text}```"""
GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text. GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes.
Text is delimited between triple backquotes.
```{text}```""" ```{text}```"""
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
in the same language as question. Use the following {language} in your communication, and cite the sources used.
If the question cannot be answered using the text, say "I don't know" in the same language as the question. If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context: Context:
```{context}``` ```{context}```
Question: Question:
```{question}```""" {question}"""
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
in the same language as question. Use the following {language} in your communication.
If the question cannot be answered using the text, say "I don't know" in the same language as the question. If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context: Context:
```{context}``` ```{context}```
Question: Question:
```{question}```""" {question}"""
# SocketIO settings # SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading' # SOCKETIO_ASYNC_MODE = 'threading'
@@ -105,6 +103,14 @@ class Config(object):
PERMANENT_SESSION_LIFETIME = timedelta(minutes=60) PERMANENT_SESSION_LIFETIME = timedelta(minutes=60)
SESSION_REFRESH_EACH_REQUEST = True SESSION_REFRESH_EACH_REQUEST = True
# Interaction algorithms
INTERACTION_ALGORITHMS = {
"RAG_TENANT": {"name": "RAG_TENANT", "description": "Algorithm using only information provided by the tenant"},
"RAG_WIKIPEDIA": {"name": "RAG_WIKIPEDIA", "description": "Algorithm using information provided by Wikipedia"},
"RAG_GOOGLE": {"name": "RAG_GOOGLE", "description": "Algorithm using information provided by Google"},
"LLM": {"name": "LLM", "description": "Algorithm using information integrated in the used LLM"}
}
class DevConfig(Config): class DevConfig(Config):
DEVELOPMENT = True DEVELOPMENT = True

View File

@@ -69,7 +69,7 @@ LOGGING = {
'file_embed_tuning': { 'file_embed_tuning': {
'level': 'DEBUG', 'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler', 'class': 'logging.handlers.RotatingFileHandler',
'filename': 'logs/rag_tuning.log', 'filename': 'logs/embed_tuning.log',
'maxBytes': 1024*1024*5, # 5MB 'maxBytes': 1024*1024*5, # 5MB
'backupCount': 10, 'backupCount': 10,
'formatter': 'standard', 'formatter': 'standard',

View File

@@ -10,10 +10,10 @@
{% block content %} {% block content %}
<div class="container"> <div class="container">
<form method="POST" action="{{ url_for('document_bp.handle_document_version_selection') }}"> <form method="POST" action="{{ url_for('document_bp.handle_document_version_selection') }}">
{{ render_selectable_table(headers=["Document Version ID", "URL", "File Location", "File Name", "File Type", "Processing", "Processing Start", "Proceeing Finish"], rows=rows, selectable=True, id="versionsTable") }} {{ render_selectable_table(headers=["ID", "URL", "File Loc.", "File Name", "File Type", "Process.", "Proces. Start", "Proces. Finish", "Proces. Error"], rows=rows, selectable=True, id="versionsTable") }}
<div class="form-group mt-3"> <div class="form-group mt-3">
<button type="submit" name="action" value="edit_document_version" class="btn btn-primary">Edit Document Version</button> <button type="submit" name="action" value="edit_document_version" class="btn btn-primary">Edit Document Version</button>
<button type="submit" name="action" value="process_document_version" class="btn btn-secondary">Process Document Version</button> <button type="submit" name="action" value="process_document_version" class="btn btn-danger">Process Document Version</button>
</div> </div>
</form> </form>
</div> </div>

View File

@@ -5,7 +5,7 @@
{% block content_title %}Documents{% endblock %} {% block content_title %}Documents{% endblock %}
{% block content_description %}View Documents for Tenant{% endblock %} {% block content_description %}View Documents for Tenant{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto">{% endblock %} {% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto"></div>{% endblock %}
{% block content %} {% block content %}
<div class="container"> <div class="container">
@@ -14,6 +14,7 @@
<div class="form-group mt-3"> <div class="form-group mt-3">
<button type="submit" name="action" value="edit_document" class="btn btn-primary">Edit Document</button> <button type="submit" name="action" value="edit_document" class="btn btn-primary">Edit Document</button>
<button type="submit" name="action" value="document_versions" class="btn btn-secondary">Show Document Versions</button> <button type="submit" name="action" value="document_versions" class="btn btn-secondary">Show Document Versions</button>
<button type="submit" name="action" value="refresh_document" class="btn btn-secondary">Refresh Document (new version)</button>
</div> </div>
</form> </form>
</div> </div>

View File

@@ -0,0 +1,31 @@
{% extends 'base.html' %}
{% block title %}Library Operations{% endblock %}
{% block content_title %}Library Operations{% endblock %}
{% block content_description %}Perform operations on the entire library of documents.{% endblock %}
{% block content_class %}<div class="col-xl-12 col-lg-5 col-md-7 mx-auto"></div>{% endblock %}
{% block content %}
<div class="container">
<form method="POST" action="{{ url_for('document_bp.handle_library_selection') }}">
<div class="form-group mt-3">
<h2>Re-Embed Latest Versions</h2>
<p>This functionality will re-apply embeddings on the latest versions of all documents in the library.
This is useful only while tuning the embedding parameters, or when changing embedding algorithms.
As it is an expensive operation and highly impacts the performance of the system in future use,
use it with caution!
</p>
<button type="submit" name="action" value="re_embed_latest_versions" class="btn btn-danger">Re-embed Latest Versions (expensive)</button>
<h2>Refresh all documents</h2>
<p>This operation will create new versions of all documents in the library with a URL. Documents that were uploaded directly,
cannot be automatically refreshed. This is an expensive operation, and impacts the performance of the system in future use.
Please use it with caution!
</p>
<button type="submit" name="action" value="refresh_all_documents" class="btn btn-danger">Refresh All Documents (expensive)</button>
</div>
</form>
</div>
{% endblock %}

View File

@@ -84,6 +84,7 @@
{'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'All Documents', 'url': '/document/documents', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Documents', 'url': '/document/documents', 'roles': ['Super User', 'Tenant Admin']},
{'name': 'Library Operations', 'url': '/document/library_operations', 'roles': ['Super User', 'Tenant Admin']},
]) }} ]) }}
{% endif %} {% endif %}
{% if current_user.is_authenticated %} {% if current_user.is_authenticated %}

View File

@@ -1,6 +1,8 @@
import ast import ast
import os import os
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
import chardet
from flask import request, redirect, flash, render_template, Blueprint, session, current_app from flask import request, redirect, flash, render_template, Blueprint, session, current_app
from flask_security import roles_accepted, current_user from flask_security import roles_accepted, current_user
from sqlalchemy import desc from sqlalchemy import desc
@@ -89,7 +91,7 @@ def add_url():
url = form.url.data url = form.url.data
html = fetch_html(url) html = fetch_html(url)
file = io.StringIO(html) file = io.BytesIO(html)
parsed_url = urlparse(url) parsed_url = urlparse(url)
path_parts = parsed_url.path.split('/') path_parts = parsed_url.path.split('/')
@@ -148,6 +150,11 @@ def handle_document_selection():
return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id))
case 'document_versions': case 'document_versions':
return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id))
case 'refresh_document':
refresh_document(doc_id)
return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id))
case 're_embed_latest_versions':
re_embed_latest_versions()
# Add more conditions for other actions # Add more conditions for other actions
return redirect(prefixed_url_for('document_bp.documents')) return redirect(prefixed_url_for('document_bp.documents'))
@@ -210,7 +217,6 @@ def edit_document_version(document_version_id):
@document_bp.route('/document_versions/<int:document_id>', methods=['GET', 'POST']) @document_bp.route('/document_versions/<int:document_id>', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin') @roles_accepted('Super User', 'Tenant Admin')
def document_versions(document_id): def document_versions(document_id):
flash(f'Processing documents is a long running process. Please be careful retriggering processing!', 'danger')
doc_vers = DocumentVersion.query.get_or_404(document_id) doc_vers = DocumentVersion.query.get_or_404(document_id)
doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}' doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}'
@@ -227,7 +233,7 @@ def document_versions(document_id):
rows = prepare_table_for_macro(doc_langs, [('id', ''), ('url', ''), ('file_location', ''), rows = prepare_table_for_macro(doc_langs, [('id', ''), ('url', ''), ('file_location', ''),
('file_name', ''), ('file_type', ''), ('file_name', ''), ('file_type', ''),
('processing', ''), ('processing_started_at', ''), ('processing', ''), ('processing_started_at', ''),
('processing_finished_at', '')]) ('processing_finished_at', ''), ('processing_error', '')])
return render_template('document/document_versions.html', rows=rows, pagination=pagination, document=doc_desc) return render_template('document/document_versions.html', rows=rows, pagination=pagination, document=doc_desc)
@@ -248,7 +254,91 @@ def handle_document_version_selection():
# Add more conditions for other actions # Add more conditions for other actions
doc_vers = DocumentVersion.query.get_or_404(doc_vers_id) doc_vers = DocumentVersion.query.get_or_404(doc_vers_id)
return redirect(prefixed_url_for('document_bp.document_versions', document_language_id=doc_vers.doc_lang_id)) return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_vers.doc_id))
@document_bp.route('/library_operations', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def library_operations():
    """Render the page offering operations on the whole document library."""
    template_name = 'document/library_operations.html'
    return render_template(template_name)
@document_bp.route('/handle_library_selection', methods=['GET', 'POST'])
@roles_accepted('Super User', 'Tenant Admin')
def handle_library_selection():
    """Run the library-wide operation selected in the submitted form.

    Reads the ``action`` form field, executes the matching operation (unknown
    actions are ignored) and redirects back to the library operations page.
    """
    # Map each supported form action onto the function implementing it.
    operations = {
        're_embed_latest_versions': re_embed_latest_versions,
        'refresh_all_documents': refresh_all_documents,
    }
    operation = operations.get(request.form['action'])
    if operation is not None:
        operation()

    return redirect(prefixed_url_for('document_bp.library_operations'))
def refresh_all_documents():
    """Call ``refresh_document`` for every document returned by the query."""
    documents = Document.query.all()
    for document in documents:
        refresh_document(document.id)
def refresh_document(doc_id):
    """Create a new version of a URL-based document and queue its embedding.

    The URL, language and user context of the most recent version are copied
    into a new ``DocumentVersion``; the page is fetched again, stored as an
    HTML file for that version, and a ``create_embeddings`` task is sent to
    the ``embeddings`` queue.

    Args:
        doc_id: Primary key of the document to refresh.

    Returns:
        None. Outcomes are reported to the user through ``flash`` messages.

    Raises:
        SQLAlchemyError: If persisting the new version fails (the session is
            rolled back before re-raising).
    """
    doc = Document.query.get_or_404(doc_id)
    doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first()

    # ``first()`` yields None for a document without any version; like a
    # version without a URL, that leaves nothing to fetch again.
    if doc_vers is None or not doc_vers.url:
        current_app.logger.info(f'Document {doc_id} has no URL, skipping refresh')
        flash(f'This document has no URL. I can only refresh documents with a URL. skipping refresh', 'alert')
        return

    new_doc_vers = create_version_for_document(doc, doc_vers.url, doc_vers.language, doc_vers.user_context)

    try:
        db.session.add(new_doc_vers)
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'Error refreshing document {doc_id} for tenant {session["tenant"]["id"]}: {e}')
        flash('Error refreshing document.', 'alert')
        db.session.rollback()
        raise
    except Exception:
        current_app.logger.error('Unknown error')
        raise

    # Fetch the current page content as bytes and store it for the new version.
    html = fetch_html(new_doc_vers.url)
    file = io.BytesIO(html)
    extension = 'html'

    current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, '
                            f'Document Version {new_doc_vers.id}')

    upload_file_for_version(new_doc_vers, file, extension)

    # Embedding creation is handed off to a Celery worker.
    task = current_celery.send_task('create_embeddings', queue='embeddings', args=[
        session['tenant']['id'],
        new_doc_vers.id,
    ])
    current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, '
                            f'Document Version {new_doc_vers.id}. '
                            f'Embedding creation task: {task.id}')
    flash(f'Processing on document {doc.name}, version {new_doc_vers.id} started. Task ID: {task.id}.',
          'success')
def re_embed_latest_versions():
    """Re-process the newest version of every document.

    For each document the version with the highest id is looked up and passed
    to ``process_version``; documents without any version are skipped.
    """
    for document in Document.query.all():
        newest_version = (
            DocumentVersion.query
            .filter_by(doc_id=document.id)
            .order_by(desc(DocumentVersion.id))
            .first()
        )
        if not newest_version:
            continue
        process_version(newest_version.id)
def process_version(version_id): def process_version(version_id):
@@ -283,7 +373,7 @@ def create_document_stack(form, file, filename, extension):
new_doc = create_document(form, filename) new_doc = create_document(form, filename)
# Create the DocumentVersion # Create the DocumentVersion
new_doc_vers = create_version_for_document(new_doc, form.language.data, form.user_context.data) new_doc_vers = create_version_for_document(new_doc, form.url.data, form.language.data, form.user_context.data)
try: try:
db.session.add(new_doc) db.session.add(new_doc)
@@ -329,8 +419,11 @@ def create_document(form, filename):
return new_doc return new_doc
def create_version_for_document(document, language, user_context): def create_version_for_document(document, url, language, user_context):
new_doc_vers = DocumentVersion() new_doc_vers = DocumentVersion()
if url != '':
new_doc_vers.url = url
if language == '': if language == '':
new_doc_vers.language = session['default_language'] new_doc_vers.language = session['default_language']
else: else:
@@ -356,12 +449,11 @@ def upload_file_for_version(doc_vers, file, extension):
os.makedirs(upload_path, exist_ok=True) os.makedirs(upload_path, exist_ok=True)
if isinstance(file, FileStorage): if isinstance(file, FileStorage):
file.save(os.path.join(upload_path, doc_vers.file_name)) file.save(os.path.join(upload_path, doc_vers.file_name))
elif isinstance(file, io.StringIO): elif isinstance(file, io.BytesIO):
# It's a StringIO object, handle accordingly # It's a BytesIO object, handle accordingly
# Example: write content to a file manually # Example: write content to a file manually
content = file.getvalue() with open(os.path.join(upload_path, doc_vers.file_name), 'wb') as f:
with open(os.path.join(upload_path, doc_vers.file_name), 'w', encoding='utf-8') as file: f.write(file.getvalue())
file.write(content)
else: else:
raise TypeError('Unsupported file type.') raise TypeError('Unsupported file type.')
@@ -392,7 +484,7 @@ def fetch_html(url):
response = None response = None
response.raise_for_status() # Will raise an exception for bad requests response.raise_for_status() # Will raise an exception for bad requests
return response.text return response.content
def prepare_document_data(docs): def prepare_document_data(docs):

View File

@@ -267,7 +267,7 @@ def handle_tenant_selection():
case 'edit_tenant': case 'edit_tenant':
return redirect(prefixed_url_for('user_bp.edit_tenant', tenant_id=tenant_id)) return redirect(prefixed_url_for('user_bp.edit_tenant', tenant_id=tenant_id))
case 'select_tenant': case 'select_tenant':
return redirect(prefixed_url_for('basic_bp.session_defaults')) return redirect(prefixed_url_for('user_bp.tenant_overview'))
# Add more conditions for other actions # Add more conditions for other actions
return redirect(prefixed_url_for('select_tenant')) return redirect(prefixed_url_for('select_tenant'))

View File

@@ -1,9 +1,13 @@
import uuid
from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token
from flask_socketio import emit, disconnect from flask_socketio import emit, disconnect
from flask import current_app, request from flask import current_app, request, session
from sqlalchemy.exc import SQLAlchemyError
from common.extensions import socketio, kms_client from common.extensions import socketio, kms_client, db
from common.models.user import Tenant from common.models.user import Tenant
from common.models.interaction import Interaction
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
@@ -26,6 +30,10 @@ def handle_connect():
token = create_access_token(identity={"tenant_id": tenant_id, "api_key": api_key}) token = create_access_token(identity={"tenant_id": tenant_id, "api_key": api_key})
current_app.logger.debug(f'SocketIO: Connection handling created token: {token} for tenant {tenant_id}') current_app.logger.debug(f'SocketIO: Connection handling created token: {token} for tenant {tenant_id}')
# Create a unique session ID
if 'session_id' not in session:
session['session_id'] = str(uuid.uuid4())
# Communicate connection to client # Communicate connection to client
emit('connect', {'status': 'Connected', 'tenant_id': tenant_id}) emit('connect', {'status': 'Connected', 'tenant_id': tenant_id})
emit('authenticated', {'token': token}) # Emit custom event with the token emit('authenticated', {'token': token}) # Emit custom event with the token
@@ -71,14 +79,15 @@ def handle_message(data):
task = current_celery.send_task('ask_question', queue='llm_interactions', args=[ task = current_celery.send_task('ask_question', queue='llm_interactions', args=[
current_tenant_id, current_tenant_id,
data['message'], data['message'],
data['language'],
session['session_id'],
]) ])
current_app.logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, ' current_app.logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, '
f'Question: {task.id}') f'Question: {task.id}')
response = { response = {
'tenantId': data['tenantId'], 'tenantId': data['tenantId'],
'message': 'Processing question ...', 'message': f'Processing question ... Session ID = {session["session_id"]}',
'taskId': task.id, 'taskId': task.id,
'algorithm': 'alg1'
} }
current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}") current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}")
emit('bot_response', response, broadcast=True) emit('bot_response', response, broadcast=True)
@@ -99,16 +108,39 @@ def check_task_status(data):
if task_result.state == 'PENDING': if task_result.state == 'PENDING':
current_app.logger.debug(f'SocketIO: Task {task_id} is pending') current_app.logger.debug(f'SocketIO: Task {task_id} is pending')
emit('task_status', {'status': 'pending', 'taskId': task_id}) emit('task_status', {'status': 'pending', 'taskId': task_id})
elif task_result.state != 'FAILURE': elif task_result.state == 'SUCCESS':
current_app.logger.debug(f'SocketIO: Task {task_id} has finished. Status: {task_result.state}, ' current_app.logger.debug(f'SocketIO: Task {task_id} has finished. Status: {task_result.state}, '
f'Result: {task_result.result}') f'Result: {task_result.result}')
emit('task_status', { result = task_result.result
'status': task_result.state, response = {
'result': task_result.result 'status': 'success',
}) 'taskId': task_id,
'answer': result['answer'],
'citations': result['citations'],
'algorithm': result['algorithm'],
'interaction_id': result['interaction_id'],
}
emit('task_status', response)
else: else:
current_app.logger.error(f'SocketIO: Task {task_id} has failed. Error: {task_result.info}') current_app.logger.error(f'SocketIO: Task {task_id} has failed. Error: {task_result.info}')
emit('task_status', {'status': 'failure', 'message': str(task_result.info)}) emit('task_status', {'status': task_result.state, 'message': str(task_result.info)})
@socketio.on('feedback')
def handle_feedback(data):
    """Store a thumbs-up / thumbs-down rating on an interaction.

    Reads ``interaction_id`` and ``feedback`` from ``data`` and writes 1
    (``'up'``) or 0 (``'down'``) to ``Interaction.feedback``. The outcome is
    reported to the client through a ``feedback_received`` event.

    Any ``feedback`` value other than ``'up'`` or ``'down'`` (including a
    missing one) is rejected instead of being silently stored as positive.

    Raises:
        SQLAlchemyError: re-raised after rollback when the commit fails.
    """
    interaction_id = data.get('interaction_id')
    feedback = data.get('feedback')  # 'up' or 'down'

    # Tuple membership uses ==, so unhashable payloads (lists, dicts) are
    # rejected here as well rather than raising TypeError.
    if feedback not in ('up', 'down'):
        current_app.logger.error(f'SocketIO: Feedback handling failed: invalid feedback value {feedback!r} '
                                 f'for interaction {interaction_id}')
        emit('feedback_received', {'status': 'Could not register feedback', 'interaction_id': interaction_id})
        return

    # Store feedback in the database associated with the interaction_id
    interaction = Interaction.query.get_or_404(interaction_id)
    interaction.feedback = 1 if feedback == 'up' else 0
    try:
        db.session.commit()
    except SQLAlchemyError as e:
        current_app.logger.error(f'SocketIO: Feedback handling failed: {e}')
        db.session.rollback()
        emit('feedback_received', {'status': 'Could not register feedback', 'interaction_id': interaction_id})
        raise  # bare raise keeps the original traceback
    else:
        emit('feedback_received', {'status': 'success', 'interaction_id': interaction_id})
def validate_api_key(tenant_id, api_key): def validate_api_key(tenant_id, api_key):

View File

@@ -1,7 +1,8 @@
from datetime import datetime as dt, timezone as tz from datetime import datetime as dt, timezone as tz
from flask import current_app from flask import current_app
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.globals import set_debug
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from celery import states from celery import states
from celery.exceptions import Ignore from celery.exceptions import Ignore
@@ -15,17 +16,25 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException from langchain_core.exceptions import LangChainException
from common.utils.database import Database from common.utils.database import Database
from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI, Embedding
from common.models.user import Tenant from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db from common.extensions import db
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables from common.utils.model_utils import select_model_variables, create_language_template
from common.langchain.EveAIRetriever import EveAIRetriever from common.langchain.EveAIRetriever import EveAIRetriever
@current_celery.task(name='ask_question', queue='llm_interactions') @current_celery.task(name='ask_question', queue='llm_interactions')
def ask_question(tenant_id, question): def ask_question(tenant_id, question, language, session_id):
current_app.logger.debug('In ask_question') """returns result structured as follows:
result = {
'answer': 'Your answer here',
'citations': ['http://example.com/citation1', 'http://example.com/citation2'],
'algorithm': 'algorithm_name',
'interaction_id': 'interaction_id_value'
}
"""
current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...')
try: try:
@@ -37,17 +46,106 @@ def ask_question(tenant_id, question):
# Ensure we are working in the correct database schema # Ensure we are working in the correct database schema
Database(tenant_id).switch_schema() Database(tenant_id).switch_schema()
chat_session = ChatSession.query.filter_by(session_id=session_id).first()
if not chat_session:
# Initialize a chat_session on the database
try:
chat_session = ChatSession()
chat_session.session_id = session_id
chat_session.session_start = dt.now(tz.utc)
db.session.add(chat_session)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}')
raise
new_interaction = Interaction()
new_interaction.question = question
new_interaction.language = language
new_interaction.chat_session_id = chat_session.id
new_interaction.question_at = dt.now(tz.utc)
new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
# try:
# db.session.add(new_interaction)
# db.session.commit()
# except SQLAlchemyError as e:
# current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
# raise
current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}')
# Select variables to work with depending on tenant model # Select variables to work with depending on tenant model
model_variables = select_model_variables(tenant) model_variables = select_model_variables(tenant)
current_app.logger.debug(f'ask_question: model_variables: {model_variables}') current_app.logger.debug(f'ask_question: model_variables: {model_variables}')
set_debug(True)
retriever = EveAIRetriever(model_variables) retriever = EveAIRetriever(model_variables)
llm = model_variables['llm']
template = model_variables['rag_template']
language_template = create_language_template(template, language)
rag_prompt = ChatPromptTemplate.from_template(language_template)
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
# Search the database for relevant embeddings new_interaction_embeddings = []
relevant_embeddings = retriever.invoke(question) if not model_variables['cited_answer_cls']: # The model doesn't support structured feedback
output_parser = StrOutputParser()
return 'No response yet, check back later.' chain = setup_and_retrieval | rag_prompt | llm | output_parser
# Invoke the chain with the actual question
answer = chain.invoke(question)
new_interaction.answer = answer
result = {
'answer': answer,
'citations': []
}
else: # The model supports structured feedback
structured_llm = llm.with_structured_output(model_variables['cited_answer_cls'])
chain = setup_and_retrieval | rag_prompt | structured_llm
result = chain.invoke(question).dict()
current_app.logger.debug(f'ask_question: result answer: {result['answer']}')
current_app.logger.debug(f'ask_question: result citations: {result["citations"]}')
new_interaction.answer = result['answer']
# Filter out the existing Embedding IDs
given_embedding_ids = [int(emb_id) for emb_id in result['citations']]
embeddings = (
db.session.query(Embedding)
.filter(Embedding.id.in_(given_embedding_ids))
.all()
)
existing_embedding_ids = [emb.id for emb in embeddings]
urls = [emb.document_version.url for emb in embeddings]
for emb_id in existing_embedding_ids:
new_interaction_embedding = InteractionEmbedding(embedding_id=emb_id)
new_interaction_embedding.interaction = new_interaction
new_interaction_embeddings.append(new_interaction_embedding)
result['citations'] = urls
new_interaction.answer_at = dt.now(tz.utc)
chat_session.session_end = dt.now(tz.utc)
try:
db.session.add(chat_session)
db.session.add(new_interaction)
db.session.add_all(new_interaction_embeddings)
db.session.commit()
except SQLAlchemyError as e:
current_app.logger.error(f'ask_question: Error saving interaction to database: {e}')
raise
set_debug(False)
result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name']
result['interaction_id'] = new_interaction.id
return result
except Exception as e: except Exception as e:
current_app.logger.error(f'ask_question: Error processing question: {e}') current_app.logger.error(f'ask_question: Error processing question: {e}')
raise raise

View File

@@ -1,29 +1,26 @@
from datetime import datetime as dt, timezone as tz
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from celery import states
from celery.exceptions import Ignore
import os import os
from datetime import datetime as dt, timezone as tz
from bs4 import BeautifulSoup
import html
from celery import states
from flask import current_app
# OpenAI imports
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.exc import SQLAlchemyError
# Unstructured commercial client imports # Unstructured commercial client imports
from unstructured_client import UnstructuredClient from unstructured_client import UnstructuredClient
from unstructured_client.models import shared from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError from unstructured_client.models.errors import SDKError
# OpenAI imports
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.exceptions import LangChainException
from common.utils.database import Database
from common.models.document import DocumentVersion
from common.models.user import Tenant
from common.extensions import db from common.extensions import db
from common.models.document import DocumentVersion, Embedding
from common.models.user import Tenant
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables from common.utils.database import Database
from common.utils.model_utils import select_model_variables, create_language_template
from bs4 import BeautifulSoup
@current_celery.task(name='create_embeddings', queue='embeddings') @current_celery.task(name='create_embeddings', queue='embeddings')
@@ -65,6 +62,8 @@ def create_embeddings(tenant_id, document_version_id):
# start processing # start processing
document_version.processing = True document_version.processing = True
document_version.processing_started_at = dt.now(tz.utc) document_version.processing_started_at = dt.now(tz.utc)
document_version.processing_finished_at = None
document_version.processing_error = None
db.session.commit() db.session.commit()
except SQLAlchemyError as e: except SQLAlchemyError as e:
@@ -73,6 +72,8 @@ def create_embeddings(tenant_id, document_version_id):
f'for tenant {tenant_id}') f'for tenant {tenant_id}')
raise raise
delete_embeddings_for_document_version(document_version)
try: try:
match document_version.file_type: match document_version.file_type:
case 'pdf': case 'pdf':
@@ -152,6 +153,18 @@ def process_pdf(tenant, model_variables, document_version):
f'on document version {document_version.id} :-)') f'on document version {document_version.id} :-)')
def delete_embeddings_for_document_version(document_version):
    """Delete every Embedding row that belongs to ``document_version``.

    Each embedding is deleted through the ORM session and the deletion is
    committed in a single transaction.

    Raises:
        SQLAlchemyError: re-raised after rollback when the commit fails.
    """
    # Read the id once, before any commit/rollback: after a failed commit the
    # instance may be expired and reading the attribute again could itself
    # fail while the error is being logged.
    doc_vers_id = document_version.id

    embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=doc_vers_id).all()
    for embedding in embeddings_to_delete:
        db.session.delete(embedding)

    try:
        db.session.commit()
    except SQLAlchemyError as e:
        # Leave the session usable for whoever handles the re-raised error.
        db.session.rollback()
        current_app.logger.error(f'Unable to delete embeddings for document version {doc_vers_id}: {e}')
        raise
    current_app.logger.info(f'Deleted embeddings for document version {doc_vers_id}')
def process_html(tenant, model_variables, document_version): def process_html(tenant, model_variables, document_version):
# The tags to be considered can be dependent on the tenant # The tags to be considered can be dependent on the tenant
html_tags = model_variables['html_tags'] html_tags = model_variables['html_tags']
@@ -176,12 +189,15 @@ def process_html(tenant, model_variables, document_version):
extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements, extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements,
excluded_elements=html_excluded_elements) excluded_elements=html_excluded_elements)
potential_chunks = create_potential_chunks(extracted_data, html_end_tags) potential_chunks = create_potential_chunks(extracted_data, html_end_tags)
current_app.embed_tuning_logger.debug(f'Nr of potential chunks: {len(potential_chunks)}')
chunks = combine_chunks(potential_chunks, chunks = combine_chunks(potential_chunks,
model_variables['min_chunk_size'], model_variables['min_chunk_size'],
model_variables['max_chunk_size'] model_variables['max_chunk_size']
) )
current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
if len(chunks) > 0: if len(chunks) > 1:
summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) summary = summarize_chunk(tenant, model_variables, document_version, chunks[0])
document_version.system_context = (f'Title: {title}\n' document_version.system_context = (f'Title: {title}\n'
f'Summary: {summary}\n') f'Summary: {summary}\n')
@@ -210,6 +226,7 @@ def process_html(tenant, model_variables, document_version):
def enrich_chunks(tenant, document_version, chunks): def enrich_chunks(tenant, document_version, chunks):
current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} ' current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} '
f'on document version {document_version.id}') f'on document version {document_version.id}')
current_app.logger.debug(f'Nr of chunks: {len(chunks)}')
chunk_total_context = (f'Filename: {document_version.file_name}\n' chunk_total_context = (f'Filename: {document_version.file_name}\n'
f'User Context:{document_version.user_context}\n' f'User Context:{document_version.user_context}\n'
f'{document_version.system_context}\n\n') f'{document_version.system_context}\n\n')
@@ -233,8 +250,10 @@ def summarize_chunk(tenant, model_variables, document_version, chunk):
current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} ' current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} '
f'on document version {document_version.id}') f'on document version {document_version.id}')
llm = model_variables['llm'] llm = model_variables['llm']
prompt = model_variables['summary_prompt'] template = model_variables['summary_template']
chain = load_summarize_chain(llm, chain_type='stuff', prompt=prompt) language_template = create_language_template(template, document_version.language)
current_app.logger.debug(f'Language prompt: {language_template}')
chain = load_summarize_chain(llm, chain_type='stuff', prompt=ChatPromptTemplate.from_template(language_template))
doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0) doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0)
text_to_summarize = doc_creator.create_documents(chunk) text_to_summarize = doc_creator.create_documents(chunk)
@@ -319,7 +338,6 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}') current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}')
current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}') current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}')
current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse') current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse')
current_app.embed_tuning_logger.debug(f'{elements_to_parse}')
# Iterate through the found included elements # Iterate through the found included elements
for element in elements_to_parse: for element in elements_to_parse:
@@ -327,7 +345,8 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non
for sub_element in element.find_all(tags): for sub_element in element.find_all(tags):
if excluded_elements and sub_element.find_parent(excluded_elements): if excluded_elements and sub_element.find_parent(excluded_elements):
continue # Skip this sub_element if it's within any of the excluded_elements continue # Skip this sub_element if it's within any of the excluded_elements
extracted_content.append((sub_element.name, sub_element.get_text(strip=True))) sub_content = html.unescape(sub_element.get_text(strip=False))
extracted_content.append((sub_element.name, sub_content))
title = soup.find('title').get_text(strip=True) title = soup.find('title').get_text(strip=True)
@@ -362,11 +381,14 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
current_length = 0 current_length = 0
for chunk in potential_chunks: for chunk in potential_chunks:
current_app.embed_tuning_logger.debug(f'chunk: {chunk}')
chunk_content = ''.join(text for _, text in chunk) chunk_content = ''.join(text for _, text in chunk)
current_app.embed_tuning_logger.debug(f'chunk_content: {chunk_content}')
chunk_length = len(chunk_content) chunk_length = len(chunk_content)
if current_length + chunk_length > max_chars: if current_length + chunk_length > max_chars:
if current_length >= min_chars: if current_length >= min_chars:
current_app.embed_tuning_logger.debug(f'Adding chunk to actual_chunks: {current_chunk}')
actual_chunks.append(current_chunk) actual_chunks.append(current_chunk)
current_chunk = chunk_content current_chunk = chunk_content
current_length = chunk_length current_length = chunk_length
@@ -378,8 +400,11 @@ def combine_chunks(potential_chunks, min_chars, max_chars):
current_chunk += chunk_content current_chunk += chunk_content
current_length += chunk_length current_length += chunk_length
current_app.embed_tuning_logger.debug(f'Remaining Chunk: {current_chunk}')
current_app.embed_tuning_logger.debug(f'Remaining Length: {current_length}')
# Handle the last chunk # Handle the last chunk
if current_chunk and current_length >= min_chars: if current_chunk and current_length >= 0:
actual_chunks.append(current_chunk) actual_chunks.append(current_chunk)
return actual_chunks return actual_chunks

View File

@@ -0,0 +1,28 @@
"""Adding algorithm information in the Interaction model
Revision ID: 6fbceab656a8
Revises: f6ecc306055a
Create Date: 2024-06-11 14:24:20.837508
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6fbceab656a8'
down_revision = 'f6ecc306055a'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable ``algorithm_used`` column (max 20 chars) to ``interaction``."""
    algorithm_used = sa.Column('algorithm_used', sa.String(length=20), nullable=True)
    op.add_column('interaction', algorithm_used)
def downgrade():
    """Remove the ``algorithm_used`` column from ``interaction`` again."""
    op.drop_column(table_name='interaction', column_name='algorithm_used')

View File

@@ -0,0 +1,28 @@
"""Adding session id to ChatSession
Revision ID: f6ecc306055a
Revises: 217938792642
Create Date: 2024-06-10 16:01:45.254969
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'f6ecc306055a'
down_revision = '217938792642'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable ``session_id`` column (max 36 chars) to ``chat_session``."""
    session_id = sa.Column('session_id', sa.String(length=36), nullable=True)
    op.add_column('chat_session', session_id)
def downgrade():
    """Remove the ``session_id`` column from ``chat_session`` again."""
    op.drop_column(table_name='chat_session', column_name='session_id')

View File

@@ -6,6 +6,7 @@
<title>Chat Client</title> <title>Chat Client</title>
<link rel="stylesheet" href="/static/css/eveai-chat-style.css"> <link rel="stylesheet" href="/static/css/eveai-chat-style.css">
<script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script> <script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script src="/static/js/eveai-sdk.js" defer></script> <script src="/static/js/eveai-sdk.js" defer></script>
<script src="/static/js/eveai-chat-widget.js" defer></script> <script src="/static/js/eveai-chat-widget.js" defer></script>
</head> </head>
@@ -16,7 +17,8 @@
const eveAI = new EveAI( const eveAI = new EveAI(
'1', '1',
'EveAI-CHAT-8553-7987-2800-9115-6454', 'EveAI-CHAT-8553-7987-2800-9115-6454',
'http://macstudio.ask-eve-ai-local.com' 'http://macstudio.ask-eve-ai-local.com',
'en'
); );
eveAI.initializeChat('chat-container'); eveAI.initializeChat('chat-container');
}); });

View File

@@ -6,6 +6,7 @@
<title>Chat Client AE</title> <title>Chat Client AE</title>
<link rel="stylesheet" href="/static/css/eveai-chat-style.css"> <link rel="stylesheet" href="/static/css/eveai-chat-style.css">
<script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script> <script src="https://cdn.socket.io/4.0.1/socket.io.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script src="/static/js/eveai-sdk.js" defer></script> <script src="/static/js/eveai-sdk.js" defer></script>
<script src="/static/js/eveai-chat-widget.js" defer></script> <script src="/static/js/eveai-chat-widget.js" defer></script>
</head> </head>
@@ -16,7 +17,8 @@
const eveAI = new EveAI( const eveAI = new EveAI(
'39', '39',
'EveAI-CHAT-6919-1265-9848-6655-9870', 'EveAI-CHAT-6919-1265-9848-6655-9870',
'http://macstudio.ask-eve-ai-local.com' 'http://macstudio.ask-eve-ai-local.com',
'en'
); );
eveAI.initializeChat('chat-container'); eveAI.initializeChat('chat-container');
}); });

View File

@@ -14,3 +14,5 @@ google~=3.0.0
redis~=5.0.4 redis~=5.0.4
itsdangerous~=2.2.0 itsdangerous~=2.2.0
pydantic~=2.7.1 pydantic~=2.7.1
chardet~=5.2.0
langcodes~=3.4.0

View File

@@ -1,6 +1,6 @@
class EveAIChatWidget extends HTMLElement { class EveAIChatWidget extends HTMLElement {
static get observedAttributes() { static get observedAttributes() {
return ['tenant-id', 'api-key', 'domain']; return ['tenant-id', 'api-key', 'domain', 'language'];
} }
constructor() { constructor() {
@@ -16,8 +16,14 @@ class EveAIChatWidget extends HTMLElement {
this.innerHTML = this.getTemplate(); this.innerHTML = this.getTemplate();
this.messagesArea = this.querySelector('.messages-area'); this.messagesArea = this.querySelector('.messages-area');
this.questionInput = this.querySelector('.question-area input'); this.questionInput = this.querySelector('.question-area input');
this.statusLine = this.querySelector('.status-line');
this.querySelector('.question-area button').addEventListener('click', () => this.handleSendMessage()); this.querySelector('.question-area button').addEventListener('click', () => this.handleSendMessage());
this.questionInput.addEventListener('keydown', (event) => {
if (event.key === 'Enter') {
this.handleSendMessage();
}
});
if (this.areAllAttributesSet() && !this.socket) { if (this.areAllAttributesSet() && !this.socket) {
console.log('Attributes already set in connectedCallback, initializing socket'); console.log('Attributes already set in connectedCallback, initializing socket');
@@ -31,7 +37,8 @@ class EveAIChatWidget extends HTMLElement {
console.log('Current attributes:', { console.log('Current attributes:', {
tenantId: this.getAttribute('tenant-id'), tenantId: this.getAttribute('tenant-id'),
apiKey: this.getAttribute('api-key'), apiKey: this.getAttribute('api-key'),
domain: this.getAttribute('domain') domain: this.getAttribute('domain'),
language: this.getAttribute('language')
}); });
if (this.areAllAttributesSet() && !this.socket) { if (this.areAllAttributesSet() && !this.socket) {
@@ -46,10 +53,12 @@ class EveAIChatWidget extends HTMLElement {
this.tenantId = this.getAttribute('tenant-id'); this.tenantId = this.getAttribute('tenant-id');
this.apiKey = this.getAttribute('api-key'); this.apiKey = this.getAttribute('api-key');
this.domain = this.getAttribute('domain'); this.domain = this.getAttribute('domain');
this.language = this.getAttribute('language');
console.log('Updated attributes:', { console.log('Updated attributes:', {
tenantId: this.tenantId, tenantId: this.tenantId,
apiKey: this.apiKey, apiKey: this.apiKey,
domain: this.domain domain: this.domain,
language: this.language
}); });
} }
@@ -57,10 +66,12 @@ class EveAIChatWidget extends HTMLElement {
const tenantId = this.getAttribute('tenant-id'); const tenantId = this.getAttribute('tenant-id');
const apiKey = this.getAttribute('api-key'); const apiKey = this.getAttribute('api-key');
const domain = this.getAttribute('domain'); const domain = this.getAttribute('domain');
const language = this.getAttribute('language');
console.log('Checking if all attributes are set:', { console.log('Checking if all attributes are set:', {
tenantId, tenantId,
apiKey, apiKey,
domain domain,
language
}); });
return tenantId && apiKey && domain; return tenantId && apiKey && domain;
} }
@@ -72,6 +83,7 @@ class EveAIChatWidget extends HTMLElement {
} }
if (!this.domain || this.domain === 'null') { if (!this.domain || this.domain === 'null') {
console.error('Domain attribute is missing or invalid'); console.error('Domain attribute is missing or invalid');
this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.')
return; return;
} }
console.log(`Initializing socket connection to ${this.domain}`); console.log(`Initializing socket connection to ${this.domain}`);
@@ -91,10 +103,12 @@ class EveAIChatWidget extends HTMLElement {
this.socket.on('connect', (data) => { this.socket.on('connect', (data) => {
console.log('Socket connected'); console.log('Socket connected');
this.setStatusMessage('Connected to EveAI.')
}); });
this.socket.on('authenticated', (data) => { this.socket.on('authenticated', (data) => {
console.log('Authenticated event received: ', data); console.log('Authenticated event received: ', data);
this.setStatusMessage('Authenticated.')
if (data.token) { if (data.token) {
this.jwtToken = data.token; // Store the JWT token received from the server this.jwtToken = data.token; // Store the JWT token received from the server
} }
@@ -102,42 +116,64 @@ class EveAIChatWidget extends HTMLElement {
this.socket.on('connect_error', (err) => { this.socket.on('connect_error', (err) => {
console.error('Socket connection error:', err); console.error('Socket connection error:', err);
this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.')
}); });
this.socket.on('connect_timeout', () => { this.socket.on('connect_timeout', () => {
console.error('Socket connection timeout'); console.error('Socket connection timeout');
this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.')
}); });
this.socket.on('disconnect', () => { this.socket.on('disconnect', () => {
console.log('Socket disconnected'); console.log('Socket disconnected');
this.setStatusMessage('Disconnected from EveAI. Please refresh the page for further interaction.')
}); });
this.socket.on('bot_response', (data) => { this.socket.on('bot_response', (data) => {
if (data.tenantId === this.tenantId) { if (data.tenantId === this.tenantId) {
console.log('Bot response received:', data) console.log('Initial response received:', data)
console.log('Task ID received:', data.taskId) console.log('Task ID received:', data.taskId)
this.addMessage(data.message, 'bot', data.messageId, data.algorithm);
this.checkTaskStatus(data.taskId) this.checkTaskStatus(data.taskId)
this.setStatusMessage('Processing...')
} }
}); });
this.socket.on('task_status', (data) => { this.socket.on('task_status', (data) => {
console.log('Task status received:', data) console.log('Task status received:', data.status)
console.log('Task ID received:', data.taskId) console.log('Task ID received:', data.taskId)
if (data.status === 'SUCCESS') { console.log('Citations type:', typeof data.citations, 'Citations:', data.citations);
this.addMessage(data.result, 'bot');
} else if (data.status === 'FAILURE') { if (data.status === 'pending') {
this.addMessage('Failed to process message.', 'bot'); this.updateProgress();
} else if (data.status === 'pending') { setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second
console.log('Task is pending') } else if (data.status === 'success') {
setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second this.addBotMessage(data.answer, data.interaction_id, data.algorithm, data.citations);
console.log('New check sent') this.clearProgress(); // Clear progress indicator when done
} else {
this.setStatusMessage('Failed to process message.');
} }
}); });
} }
setStatusMessage(message) {
this.statusLine.textContent = message;
}
updateProgress() {
if (!this.statusLine.textContent) {
this.statusLine.textContent = 'Processing...';
} else {
this.statusLine.textContent += '.'; // Append a dot
}
}
clearProgress() {
this.statusLine.textContent = '';
}
checkTaskStatus(taskId) { checkTaskStatus(taskId) {
this.socket.emit('check_task_status', { task_id: taskId }); this.updateProgress();
this.socket.emit('check_task_status', { task_id: taskId });
} }
getTemplate() { getTemplate() {
@@ -148,21 +184,39 @@ class EveAIChatWidget extends HTMLElement {
<input type="text" placeholder="Type your message here..." /> <input type="text" placeholder="Type your message here..." />
<button>Send</button> <button>Send</button>
</div> </div>
<div class="status-line"></div>
</div> </div>
`; `;
} }
addMessage(text, type = 'user', id = null, algorithm = 'default') { addUserMessage(text) {
const message = document.createElement('div'); const message = document.createElement('div');
message.classList.add('message', type); message.classList.add('message', 'user');
message.innerHTML = `<p>${text}</p>`;
this.messagesArea.appendChild(message);
this.messagesArea.scrollTop = this.messagesArea.scrollHeight;
}
addBotMessage(text, interactionId, algorithm = 'default', citations = []) {
const message = document.createElement('div');
message.classList.add('message', 'bot');
let content = marked.parse(text); // Use marked to convert markdown to HTML
// Ensure citations is an array
if (!Array.isArray(citations)) {
console.error('Expected citations to be an array, but got:', citations);
citations = []; // Default to an empty array
}
let citationsHtml = citations.map(url => `<a href="${url}" target="_blank">${url}</a>`).join('<br>');
message.innerHTML = ` message.innerHTML = `
<p>${text}</p> <p>${content}</p>
${type === 'bot' ? ` ${citationsHtml ? `<p>Citations: ${citationsHtml}</p>` : ''}
<div class="message-icons"> <div class="message-icons">
<span class="algorithm-indicator" style="background-color: var(--algorithm-color-${algorithm});"></span> <span class="algorithm-indicator" style="background-color: var(--algorithm-color-${algorithm});">${algorithm}</span>
<i class="material-icons" onclick="handleFeedback('${id}', 'up')">thumb_up</i> <i class="material-icons" onclick="handleFeedback('up', '${interactionId}')">thumb_up</i>
<i class="material-icons" onclick="handleFeedback('${id}', 'down')">thumb_down</i> <i class="material-icons" onclick="handleFeedback('down', '${interactionId}')">thumb_down</i>
</div>` : ''} </div>
`; `;
this.messagesArea.appendChild(message); this.messagesArea.appendChild(message);
this.messagesArea.scrollTop = this.messagesArea.scrollHeight; this.messagesArea.scrollTop = this.messagesArea.scrollHeight;
@@ -172,7 +226,7 @@ class EveAIChatWidget extends HTMLElement {
console.log('handleSendMessage called'); console.log('handleSendMessage called');
const message = this.questionInput.value.trim(); const message = this.questionInput.value.trim();
if (message) { if (message) {
this.addMessage(message, 'user'); this.addUserMessage(message);
this.questionInput.value = ''; this.questionInput.value = '';
this.sendMessageToBackend(message); this.sendMessageToBackend(message);
} }
@@ -189,14 +243,15 @@ class EveAIChatWidget extends HTMLElement {
return; return;
} }
console.log('Sending message to backend'); console.log('Sending message to backend');
this.socket.emit('user_message', { tenantId: this.tenantId, token: this.jwtToken, message }); this.socket.emit('user_message', { tenantId: this.tenantId, token: this.jwtToken, message, language: this.language });
this.setStatusMessage('Processing started ...')
} }
} }
customElements.define('eveai-chat-widget', EveAIChatWidget); customElements.define('eveai-chat-widget', EveAIChatWidget);
/**
 * Send thumbs-up / thumbs-down feedback for one interaction to the backend.
 *
 * This is a plain global function invoked from inline `onclick` handlers,
 * so `this` is not the chat widget. The socket is therefore resolved
 * explicitly: first from `this` (kept for callers that bind a context that
 * owns a `socket`), otherwise from the `<eveai-chat-widget>` element in the
 * document.
 *
 * @param {string} feedback - Feedback value, e.g. 'up' or 'down'.
 * @param {string} interactionId - Identifier of the interaction being rated.
 */
function handleFeedback(feedback, interactionId) {
    console.log(`Feedback for ${interactionId}: ${feedback}`);

    let socket = this && this.socket;
    if (!socket && typeof document !== 'undefined') {
        const widget = document.querySelector('eveai-chat-widget');
        socket = widget && widget.socket;
    }

    if (!socket) {
        // Without a socket there is nowhere to send the feedback; report it
        // instead of throwing from inside a click handler.
        console.error('Cannot send feedback: chat widget socket is not available');
        return;
    }

    socket.emit('feedback', { feedback, interaction_id: interactionId });
}

View File

@@ -1,9 +1,10 @@
// static/js/eveai-sdk.js // static/js/eveai-sdk.js
class EveAI { class EveAI {
constructor(tenantId, apiKey, domain) { constructor(tenantId, apiKey, domain, language) {
this.tenantId = tenantId; this.tenantId = tenantId;
this.apiKey = apiKey; this.apiKey = apiKey;
this.domain = domain; this.domain = domain;
this.language = language;
console.log('EveAI constructor:', { tenantId, apiKey, domain }); console.log('EveAI constructor:', { tenantId, apiKey, domain });
} }
@@ -17,10 +18,12 @@ class EveAI {
chatWidget.setAttribute('tenant-id', this.tenantId); chatWidget.setAttribute('tenant-id', this.tenantId);
chatWidget.setAttribute('api-key', this.apiKey); chatWidget.setAttribute('api-key', this.apiKey);
chatWidget.setAttribute('domain', this.domain); chatWidget.setAttribute('domain', this.domain);
chatWidget.setAttribute('language', this.language);
console.log('Attributes set in chat widget:', { console.log('Attributes set in chat widget:', {
tenantId: chatWidget.getAttribute('tenant-id'), tenantId: chatWidget.getAttribute('tenant-id'),
apiKey: chatWidget.getAttribute('api-key'), apiKey: chatWidget.getAttribute('api-key'),
domain: chatWidget.getAttribute('domain') domain: chatWidget.getAttribute('domain'),
language: chatWidget.getAttribute('language')
}); });
}); });
} else { } else {