From be311c440b868dfbd7e328dfbf4e8486a260c252 Mon Sep 17 00:00:00 2001 From: Josako Date: Wed, 12 Jun 2024 11:07:18 +0200 Subject: [PATCH] Improving chat functionality significantly throughout the application. --- common/langchain/EveAIRetriever.py | 53 ++++++-- common/models/interaction.py | 5 + common/models/user.py | 5 +- common/utils/model_utils.py | 44 ++++++- config/config.py | 36 +++--- config/logging_config.py | 2 +- .../templates/document/document_versions.html | 4 +- eveai_app/templates/document/documents.html | 3 +- .../document/library_operations.html | 31 +++++ eveai_app/templates/navbar.html | 1 + eveai_app/views/document_views.py | 116 ++++++++++++++++-- eveai_app/views/user_views.py | 2 +- eveai_chat/socket_handlers/chat_handler.py | 52 ++++++-- eveai_chat_workers/tasks.py | 114 +++++++++++++++-- eveai_workers/tasks.py | 73 +++++++---- ...a8_adding_algorithm_information_in_the_.py | 28 +++++ ...06055a_adding_session_id_to_chatsession.py | 28 +++++ public/chat.html | 4 +- public/chat_ae.html | 4 +- requirements.txt | 4 +- static/js/eveai-chat-widget.js | 115 ++++++++++++----- static/js/eveai-sdk.js | 7 +- 22 files changed, 604 insertions(+), 127 deletions(-) create mode 100644 eveai_app/templates/document/library_operations.html create mode 100644 migrations/tenant/versions/6fbceab656a8_adding_algorithm_information_in_the_.py create mode 100644 migrations/tenant/versions/f6ecc306055a_adding_session_id_to_chatsession.py diff --git a/common/langchain/EveAIRetriever.py b/common/langchain/EveAIRetriever.py index 4232454..efbacf3 100644 --- a/common/langchain/EveAIRetriever.py +++ b/common/langchain/EveAIRetriever.py @@ -1,11 +1,13 @@ from langchain_core.retrievers import BaseRetriever +from sqlalchemy import func, and_, or_ from sqlalchemy.exc import SQLAlchemyError from pydantic import BaseModel, Field from typing import Any, Dict +from flask import current_app +from datetime import date from common.extensions import db -from flask import current_app 
-from config.logging_config import LOGGING +from common.models.document import Document, DocumentVersion, Embedding class EveAIRetriever(BaseRetriever): @@ -23,26 +25,53 @@ class EveAIRetriever(BaseRetriever): db_class = self.model_variables['embedding_db_model'] similarity_threshold = self.model_variables['similarity_threshold'] k = self.model_variables['k'] + try: - res = ( + current_date = date.today() + # Subquery to find the latest version of each document + subquery = ( + db.session.query( + DocumentVersion.doc_id, + func.max(DocumentVersion.id).label('latest_version_id') + ) + .group_by(DocumentVersion.doc_id) + .subquery() + ) + # Main query to filter embeddings + query_obj = ( db.session.query(db_class, - db_class.embedding.cosine_distance(query_embedding) - .label('distance')) - .filter(db_class.embedding.cosine_distance(query_embedding) < similarity_threshold) + db_class.embedding.cosine_distance(query_embedding).label('distance')) + .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id) + .join(Document, DocumentVersion.doc_id == Document.id) + .join(subquery, DocumentVersion.id == subquery.c.latest_version_id) + .filter( + or_(Document.valid_from.is_(None), Document.valid_from <= current_date), + or_(Document.valid_to.is_(None), Document.valid_to >= current_date), + db_class.embedding.cosine_distance(query_embedding) < similarity_threshold + ) .order_by('distance') .limit(k) - .all() ) - current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') - current_app.rag_tuning_logger.debug(f'---------------------------------------') + + # Print the generated SQL statement for debugging + current_app.logger.debug("SQL Statement:\n") + current_app.logger.debug(query_obj.statement.compile(compile_kwargs={"literal_binds": True})) + + res = query_obj.all() + + # current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents') + # current_app.rag_tuning_logger.debug(f'---------------------------------------') + result 
= [] for doc in res: - current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') - current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') + # current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n') + # current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n') + result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n') except SQLAlchemyError as e: current_app.logger.error(f'Error retrieving relevant documents: {e}') + db.session.rollback() return [] - return res + return result def _get_query_embedding(self, query: str): embedding_model = self.model_variables['embedding_model'] diff --git a/common/models/interaction.py b/common/models/interaction.py index 156eb03..1c1a144 100644 --- a/common/models/interaction.py +++ b/common/models/interaction.py @@ -6,6 +6,7 @@ from .document import Embedding class ChatSession(db.Model): id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True) + session_id = db.Column(db.String(36), nullable=True) session_start = db.Column(db.DateTime, nullable=False) session_end = db.Column(db.DateTime, nullable=True) @@ -21,6 +22,7 @@ class Interaction(db.Model): chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False) question = db.Column(db.Text, nullable=False) answer = db.Column(db.Text, nullable=True) + algorithm_used = db.Column(db.String(20), nullable=True) language = db.Column(db.String(2), nullable=False) appreciation = db.Column(db.Integer, nullable=True, default=100) @@ -28,6 +30,9 @@ class Interaction(db.Model): question_at = db.Column(db.DateTime, nullable=False) answer_at = db.Column(db.DateTime, nullable=True) + # Relations + embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True) + def __repr__(self): return f"" diff --git a/common/models/user.py b/common/models/user.py index 53ec4b3..0d8400f 100644 --- 
a/common/models/user.py +++ b/common/models/user.py @@ -19,6 +19,7 @@ class Tenant(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(80), unique=True, nullable=False) website = db.Column(db.String(255), nullable=True) + timezone = db.Column(db.String(50), nullable=True, default='UTC') # language information default_language = db.Column(db.String(2), nullable=True) @@ -70,7 +71,9 @@ class Tenant(db.Model): 'llm_model': self.llm_model, 'license_start_date': self.license_start_date, 'license_end_date': self.license_end_date, - 'allowed_monthly_interactions': self.allowed_monthly_interactions + 'allowed_monthly_interactions': self.allowed_monthly_interactions, + 'embed_tuning': self.embed_tuning, + 'rag_tuning': self.rag_tuning, } diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py index 339e076..fcb344e 100644 --- a/common/utils/model_utils.py +++ b/common/utils/model_utils.py @@ -1,12 +1,32 @@ +import langcodes from flask import current_app -from langchain.embeddings import OpenAIEmbeddings -from langchain.chat_models import ChatOpenAI +from langchain_community.embeddings import OpenAIEmbeddings +from langchain_openai import ChatOpenAI +from langchain_core.pydantic_v1 import BaseModel, Field from langchain.prompts import ChatPromptTemplate import ast +from typing import List from common.models.document import EmbeddingSmallOpenAI +class CitedAnswer(BaseModel): + """Default docstring - to be replaced with actual prompt""" + + answer: str = Field( + ..., + description="The answer to the user question, based on the given sources", + ) + citations: List[int] = Field( + ..., + description="The integer IDs of the SPECIFIC sources that were used to generate the answer" + ) + + +def set_language_prompt_template(cls, language_prompt): + cls.__doc__ = language_prompt + + def select_model_variables(tenant): embedding_provider = tenant.embedding_model.rsplit('.', 1)[0] embedding_model = tenant.embedding_model.rsplit('.', 1)[1] 
@@ -60,7 +80,7 @@ def select_model_variables(tenant): case 'text-embedding-3-small': api_key = current_app.config.get('OPENAI_API_KEY') model_variables['embedding_model'] = OpenAIEmbeddings(api_key=api_key, - model='text-embedding-3-small') + model='text-embedding-3-small') model_variables['embedding_db_model'] = EmbeddingSmallOpenAI model_variables['min_chunk_size'] = current_app.config.get('OAI_TE3S_MIN_CHUNK_SIZE') model_variables['max_chunk_size'] = current_app.config.get('OAI_TE3S_MAX_CHUNK_SIZE') @@ -78,20 +98,34 @@ def select_model_variables(tenant): model_variables['llm'] = ChatOpenAI(api_key=api_key, model=llm_model, temperature=model_variables['RAG_temperature']) + tool_calling_supported = False match llm_model: case 'gpt-4-turbo' | 'gpt-4o': summary_template = current_app.config.get('GPT4_SUMMARY_TEMPLATE') rag_template = current_app.config.get('GPT4_RAG_TEMPLATE') + tool_calling_supported = True case 'gpt-3-5-turbo': summary_template = current_app.config.get('GPT3_5_SUMMARY_TEMPLATE') rag_template = current_app.config.get('GPT3_5_RAG_TEMPLATE') case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid chat model') - model_variables['summary_prompt'] = ChatPromptTemplate.from_template(summary_template) - model_variables['rag_prompt'] = ChatPromptTemplate.from_template(rag_template) + model_variables['summary_template'] = summary_template + model_variables['rag_template'] = rag_template + if tool_calling_supported: + model_variables['cited_answer_cls'] = CitedAnswer case _: raise Exception(f'Error setting model variables for tenant {tenant.id} ' f'error: Invalid chat provider') return model_variables + + +def create_language_template(template, language): + try: + full_language = langcodes.Language.make(language=language) + language_template = template.replace('{language}', full_language.display_name()) + except ValueError: + language_template = template.replace('{language}', language) + + return language_template 
diff --git a/config/config.py b/config/config.py index 6908a30..a53a26f 100644 --- a/config/config.py +++ b/config/config.py @@ -73,26 +73,24 @@ class Config(object): OAI_TE3S_MAX_CHUNK_SIZE = 3000 # LLM TEMPLATES - GPT4_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text. - Text is delimited between triple backquotes. + GPT4_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes. ```{text}```""" - GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in the same language as the provided text. - Text is delimited between triple backquotes. + GPT3_5_SUMMARY_TEMPLATE = """Write a concise summary of the text in {language}. The text is delimited between triple backquotes. ```{text}```""" - GPT4_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes - in the same language as question. - If the question cannot be answered using the text, say "I don't know" in the same language as the question. + GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. + Use the following {language} in your communication, and cite the sources used. + If the question cannot be answered using the given context, say "I have insufficient information to answer this question." Context: ```{context}``` Question: - ```{question}```""" - GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, both delimited between triple backquotes - in the same language as question. - If the question cannot be answered using the text, say "I don't know" in the same language as the question. - Context: - ```{context}``` - Question: - ```{question}```""" + {question}""" + GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. + Use the following {language} in your communication. 
+ If the question cannot be answered using the given context, say "I have insufficient information to answer this question." + Context: + ```{context}``` + Question: + {question}""" # SocketIO settings # SOCKETIO_ASYNC_MODE = 'threading' @@ -105,6 +103,14 @@ class Config(object): PERMANENT_SESSION_LIFETIME = timedelta(minutes=60) SESSION_REFRESH_EACH_REQUEST = True + # Interaction algorithms + INTERACTION_ALGORITHMS = { + "RAG_TENANT": {"name": "RAG_TENANT", "description": "Algorithm using only information provided by the tenant"}, + "RAG_WIKIPEDIA": {"name": "RAG_WIKIPEDIA", "description": "Algorithm using information provided by Wikipedia"}, + "RAG_GOOGLE": {"name": "RAG_GOOGLE", "description": "Algorithm using information provided by Google"}, + "LLM": {"name": "LLM", "description": "Algorithm using information integrated in the used LLM"} + } + class DevConfig(Config): DEVELOPMENT = True diff --git a/config/logging_config.py b/config/logging_config.py index b5c420d..60fb4c2 100644 --- a/config/logging_config.py +++ b/config/logging_config.py @@ -69,7 +69,7 @@ LOGGING = { 'file_embed_tuning': { 'level': 'DEBUG', 'class': 'logging.handlers.RotatingFileHandler', - 'filename': 'logs/rag_tuning.log', + 'filename': 'logs/embed_tuning.log', 'maxBytes': 1024*1024*5, # 5MB 'backupCount': 10, 'formatter': 'standard', diff --git a/eveai_app/templates/document/document_versions.html b/eveai_app/templates/document/document_versions.html index 7bd07f2..15adbf7 100644 --- a/eveai_app/templates/document/document_versions.html +++ b/eveai_app/templates/document/document_versions.html @@ -10,10 +10,10 @@ {% block content %}
- {{ render_selectable_table(headers=["Document Version ID", "URL", "File Location", "File Name", "File Type", "Processing", "Processing Start", "Proceeing Finish"], rows=rows, selectable=True, id="versionsTable") }} + {{ render_selectable_table(headers=["ID", "URL", "File Loc.", "File Name", "File Type", "Process.", "Proces. Start", "Proces. Finish", "Proces. Error"], rows=rows, selectable=True, id="versionsTable") }}
- +
diff --git a/eveai_app/templates/document/documents.html b/eveai_app/templates/document/documents.html index cebc150..89da53a 100644 --- a/eveai_app/templates/document/documents.html +++ b/eveai_app/templates/document/documents.html @@ -5,7 +5,7 @@ {% block content_title %}Documents{% endblock %} {% block content_description %}View Documents for Tenant{% endblock %} -{% block content_class %}
{% endblock %} +{% block content_class %}
{% endblock %} {% block content %}
@@ -14,6 +14,7 @@
+
diff --git a/eveai_app/templates/document/library_operations.html b/eveai_app/templates/document/library_operations.html new file mode 100644 index 0000000..e1f732c --- /dev/null +++ b/eveai_app/templates/document/library_operations.html @@ -0,0 +1,31 @@ +{% extends 'base.html' %} + +{% block title %}Library Operations{% endblock %} + +{% block content_title %}Library Operations{% endblock %} +{% block content_description %}Perform operations on the entire library of documents.{% endblock %} +{% block content_class %}
{% endblock %} + +{% block content %} +
+
+
+

Re-Embed Latest Versions

+

This functionality will re-apply embeddings on the latest versions of all documents in the library. + This is useful only while tuning the embedding parameters, or when changing embedding algorithms. + As it is an expensive operation that strongly impacts the performance of the system in future use, + use it with caution! +

+ +

Refresh all documents

+

This operation will create new versions of all documents in the library that have a URL. Documents that were uploaded directly + cannot be automatically refreshed. This is an expensive operation and impacts the performance of the system in future use. + Please use it with caution! +

+ +

+
+
+
+{% endblock %} + diff --git a/eveai_app/templates/navbar.html b/eveai_app/templates/navbar.html index 54287ee..fa39732 100644 --- a/eveai_app/templates/navbar.html +++ b/eveai_app/templates/navbar.html @@ -84,6 +84,7 @@ {'name': 'Add Document', 'url': '/document/add_document', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'Add URL', 'url': '/document/add_url', 'roles': ['Super User', 'Tenant Admin']}, {'name': 'All Documents', 'url': '/document/documents', 'roles': ['Super User', 'Tenant Admin']}, + {'name': 'Library Operations', 'url': '/document/library_operations', 'roles': ['Super User', 'Tenant Admin']}, ]) }} {% endif %} {% if current_user.is_authenticated %} diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py index b60defc..79f6935 100644 --- a/eveai_app/views/document_views.py +++ b/eveai_app/views/document_views.py @@ -1,6 +1,8 @@ import ast import os from datetime import datetime as dt, timezone as tz + +import chardet from flask import request, redirect, flash, render_template, Blueprint, session, current_app from flask_security import roles_accepted, current_user from sqlalchemy import desc @@ -89,7 +91,7 @@ def add_url(): url = form.url.data html = fetch_html(url) - file = io.StringIO(html) + file = io.BytesIO(html) parsed_url = urlparse(url) path_parts = parsed_url.path.split('/') @@ -148,6 +150,11 @@ def handle_document_selection(): return redirect(prefixed_url_for('document_bp.edit_document', document_id=doc_id)) case 'document_versions': return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) + case 'refresh_document': + refresh_document(doc_id) + return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_id)) + case 're_embed_latest_versions': + re_embed_latest_versions() # Add more conditions for other actions return redirect(prefixed_url_for('document_bp.documents')) @@ -210,7 +217,6 @@ def edit_document_version(document_version_id): 
@document_bp.route('/document_versions/', methods=['GET', 'POST']) @roles_accepted('Super User', 'Tenant Admin') def document_versions(document_id): - flash(f'Processing documents is a long running process. Please be careful retriggering processing!', 'danger') doc_vers = DocumentVersion.query.get_or_404(document_id) doc_desc = f'Document {doc_vers.document.name}, Language {doc_vers.language}' @@ -227,7 +233,7 @@ def document_versions(document_id): rows = prepare_table_for_macro(doc_langs, [('id', ''), ('url', ''), ('file_location', ''), ('file_name', ''), ('file_type', ''), ('processing', ''), ('processing_started_at', ''), - ('processing_finished_at', '')]) + ('processing_finished_at', ''), ('processing_error', '')]) return render_template('document/document_versions.html', rows=rows, pagination=pagination, document=doc_desc) @@ -248,7 +254,91 @@ def handle_document_version_selection(): # Add more conditions for other actions doc_vers = DocumentVersion.query.get_or_404(doc_vers_id) - return redirect(prefixed_url_for('document_bp.document_versions', document_language_id=doc_vers.doc_lang_id)) + return redirect(prefixed_url_for('document_bp.document_versions', document_id=doc_vers.doc_id)) + + +@document_bp.route('/library_operations', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def library_operations(): + return render_template('document/library_operations.html') + + +@document_bp.route('/handle_library_selection', methods=['GET', 'POST']) +@roles_accepted('Super User', 'Tenant Admin') +def handle_library_selection(): + action = request.form['action'] + + match action: + case 're_embed_latest_versions': + re_embed_latest_versions() + case 'refresh_all_documents': + refresh_all_documents() + + return redirect(prefixed_url_for('document_bp.library_operations')) + + +def refresh_all_documents(): + for doc in Document.query.all(): + refresh_document(doc.id) + + +def refresh_document(doc_id): + doc = Document.query.get_or_404(doc_id) + 
doc_vers = DocumentVersion.query.filter_by(doc_id=doc_id).order_by(desc(DocumentVersion.id)).first() + if not doc_vers.url: + current_app.logger.info(f'Document {doc_id} has no URL, skipping refresh') + flash(f'This document has no URL. I can only refresh documents with a URL. skipping refresh', 'alert') + return + + new_doc_vers = create_version_for_document(doc, doc_vers.url, doc_vers.language, doc_vers.user_context) + + try: + db.session.add(new_doc_vers) + db.session.commit() + except SQLAlchemyError as e: + current_app.logger.error(f'Error refreshing document {doc_id} for tenant {session["tenant"]["id"]}: {e}') + flash('Error refreshing document.', 'alert') + db.session.rollback() + error = e.args + raise + except Exception as e: + current_app.logger.error('Unknown error') + raise + + html = fetch_html(new_doc_vers.url) + file = io.BytesIO(html) + + parsed_url = urlparse(new_doc_vers.url) + path_parts = parsed_url.path.split('/') + filename = path_parts[-1] + if filename == '': + filename = 'index' + if not filename.endswith('.html'): + filename += '.html' + extension = 'html' + + current_app.logger.info(f'Document added successfully for tenant {session["tenant"]["id"]}, ' + f'Document Version {new_doc_vers.id}') + + upload_file_for_version(new_doc_vers, file, extension) + + task = current_celery.send_task('create_embeddings', queue='embeddings', args=[ + session['tenant']['id'], + new_doc_vers.id, + ]) + current_app.logger.info(f'Embedding creation started for tenant {session["tenant"]["id"]}, ' + f'Document Version {new_doc_vers.id}. ' + f'Embedding creation task: {task.id}') + flash(f'Processing on document {doc.name}, version {new_doc_vers.id} started. 
Task ID: {task.id}.', + 'success') + + +def re_embed_latest_versions(): + docs = Document.query.all() + for doc in docs: + latest_doc_version = DocumentVersion.query.filter_by(doc_id=doc.id).order_by(desc(DocumentVersion.id)).first() + if latest_doc_version: + process_version(latest_doc_version.id) def process_version(version_id): @@ -283,7 +373,7 @@ def create_document_stack(form, file, filename, extension): new_doc = create_document(form, filename) # Create the DocumentVersion - new_doc_vers = create_version_for_document(new_doc, form.language.data, form.user_context.data) + new_doc_vers = create_version_for_document(new_doc, form.url.data, form.language.data, form.user_context.data) try: db.session.add(new_doc) @@ -329,8 +419,11 @@ def create_document(form, filename): return new_doc -def create_version_for_document(document, language, user_context): +def create_version_for_document(document, url, language, user_context): new_doc_vers = DocumentVersion() + if url != '': + new_doc_vers.url = url + if language == '': new_doc_vers.language = session['default_language'] else: @@ -356,12 +449,11 @@ def upload_file_for_version(doc_vers, file, extension): os.makedirs(upload_path, exist_ok=True) if isinstance(file, FileStorage): file.save(os.path.join(upload_path, doc_vers.file_name)) - elif isinstance(file, io.StringIO): - # It's a StringIO object, handle accordingly + elif isinstance(file, io.BytesIO): + # It's a BytesIO object, handle accordingly # Example: write content to a file manually - content = file.getvalue() - with open(os.path.join(upload_path, doc_vers.file_name), 'w', encoding='utf-8') as file: - file.write(content) + with open(os.path.join(upload_path, doc_vers.file_name), 'wb') as f: + f.write(file.getvalue()) else: raise TypeError('Unsupported file type.') @@ -392,7 +484,7 @@ def fetch_html(url): response = None response.raise_for_status() # Will raise an exception for bad requests - return response.text + return response.content def 
prepare_document_data(docs): diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py index 0da0679..5d43d9b 100644 --- a/eveai_app/views/user_views.py +++ b/eveai_app/views/user_views.py @@ -267,7 +267,7 @@ def handle_tenant_selection(): case 'edit_tenant': return redirect(prefixed_url_for('user_bp.edit_tenant', tenant_id=tenant_id)) case 'select_tenant': - return redirect(prefixed_url_for('basic_bp.session_defaults')) + return redirect(prefixed_url_for('user_bp.tenant_overview')) # Add more conditions for other actions return redirect(prefixed_url_for('select_tenant')) diff --git a/eveai_chat/socket_handlers/chat_handler.py b/eveai_chat/socket_handlers/chat_handler.py index b4ac376..a5aaf38 100644 --- a/eveai_chat/socket_handlers/chat_handler.py +++ b/eveai_chat/socket_handlers/chat_handler.py @@ -1,9 +1,13 @@ +import uuid + from flask_jwt_extended import create_access_token, get_jwt_identity, verify_jwt_in_request, decode_token from flask_socketio import emit, disconnect -from flask import current_app, request +from flask import current_app, request, session +from sqlalchemy.exc import SQLAlchemyError -from common.extensions import socketio, kms_client +from common.extensions import socketio, kms_client, db from common.models.user import Tenant +from common.models.interaction import Interaction from common.utils.celery_utils import current_celery @@ -26,6 +30,10 @@ def handle_connect(): token = create_access_token(identity={"tenant_id": tenant_id, "api_key": api_key}) current_app.logger.debug(f'SocketIO: Connection handling created token: {token} for tenant {tenant_id}') + # Create a unique session ID + if 'session_id' not in session: + session['session_id'] = str(uuid.uuid4()) + # Communicate connection to client emit('connect', {'status': 'Connected', 'tenant_id': tenant_id}) emit('authenticated', {'token': token}) # Emit custom event with the token @@ -71,14 +79,15 @@ def handle_message(data): task = current_celery.send_task('ask_question', 
queue='llm_interactions', args=[ current_tenant_id, data['message'], + data['language'], + session['session_id'], ]) current_app.logger.debug(f'SocketIO: Message offloading for tenant {current_tenant_id}, ' f'Question: {task.id}') response = { 'tenantId': data['tenantId'], - 'message': 'Processing question ...', + 'message': f'Processing question ... Session ID = {session["session_id"]}', 'taskId': task.id, - 'algorithm': 'alg1' } current_app.logger.debug(f"SocketIO: Message handling sent bot response: {response}") emit('bot_response', response, broadcast=True) @@ -99,16 +108,39 @@ def check_task_status(data): if task_result.state == 'PENDING': current_app.logger.debug(f'SocketIO: Task {task_id} is pending') emit('task_status', {'status': 'pending', 'taskId': task_id}) - elif task_result.state != 'FAILURE': + elif task_result.state == 'SUCCESS': current_app.logger.debug(f'SocketIO: Task {task_id} has finished. Status: {task_result.state}, ' f'Result: {task_result.result}') - emit('task_status', { - 'status': task_result.state, - 'result': task_result.result - }) + result = task_result.result + response = { + 'status': 'success', + 'taskId': task_id, + 'answer': result['answer'], + 'citations': result['citations'], + 'algorithm': result['algorithm'], + 'interaction_id': result['interaction_id'], + } + emit('task_status', response) else: current_app.logger.error(f'SocketIO: Task {task_id} has failed. 
Error: {task_result.info}') - emit('task_status', {'status': 'failure', 'message': str(task_result.info)}) + emit('task_status', {'status': task_result.state, 'message': str(task_result.info)}) + + +@socketio.on('feedback') +def handle_feedback(data): + interaction_id = data.get('interaction_id') + feedback = data.get('feedback') # 'up' or 'down' + # Store feedback in the database associated with the interaction_id + interaction = Interaction.query.get_or_404(interaction_id) + interaction.feedback = 0 if feedback == 'down' else 1 + try: + db.session.commit() + emit('feedback_received', {'status': 'success', 'interaction_id': interaction_id}) + except SQLAlchemyError as e: + current_app.logger.error(f'SocketIO: Feedback handling failed: {e}') + db.session.rollback() + emit('feedback_received', {'status': 'Could not register feedback', 'interaction_id': interaction_id}) + raise e def validate_api_key(tenant_id, api_key): diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py index 379f969..48adbd3 100644 --- a/eveai_chat_workers/tasks.py +++ b/eveai_chat_workers/tasks.py @@ -1,7 +1,8 @@ from datetime import datetime as dt, timezone as tz from flask import current_app from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableParallel +from langchain_core.runnables import RunnableParallel, RunnablePassthrough +from langchain.globals import set_debug from sqlalchemy.exc import SQLAlchemyError from celery import states from celery.exceptions import Ignore @@ -15,17 +16,25 @@ from langchain.text_splitter import CharacterTextSplitter from langchain_core.exceptions import LangChainException from common.utils.database import Database -from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI +from common.models.document import DocumentVersion, EmbeddingMistral, EmbeddingSmallOpenAI, Embedding from common.models.user import Tenant +from common.models.interaction import ChatSession, 
Interaction, InteractionEmbedding from common.extensions import db from common.utils.celery_utils import current_celery -from common.utils.model_utils import select_model_variables +from common.utils.model_utils import select_model_variables, create_language_template from common.langchain.EveAIRetriever import EveAIRetriever @current_celery.task(name='ask_question', queue='llm_interactions') -def ask_question(tenant_id, question): - current_app.logger.debug('In ask_question') +def ask_question(tenant_id, question, language, session_id): + """returns result structured as follows: + result = { + 'answer': 'Your answer here', + 'citations': ['http://example.com/citation1', 'http://example.com/citation2'], + 'algorithm': 'algorithm_name', + 'interaction_id': 'interaction_id_value' + } + """ current_app.logger.info(f'ask_question: Received question for tenant {tenant_id}: {question}. Processing...') try: @@ -37,17 +46,106 @@ def ask_question(tenant_id, question): # Ensure we are working in the correct database schema Database(tenant_id).switch_schema() + chat_session = ChatSession.query.filter_by(session_id=session_id).first() + if not chat_session: + # Initialize a chat_session on the database + try: + chat_session = ChatSession() + chat_session.session_id = session_id + chat_session.session_start = dt.now(tz.utc) + db.session.add(chat_session) + db.session.commit() + except SQLAlchemyError as e: + current_app.logger.error(f'ask_question: Error initializing chat session in database: {e}') + raise + + new_interaction = Interaction() + new_interaction.question = question + new_interaction.language = language + new_interaction.chat_session_id = chat_session.id + new_interaction.question_at = dt.now(tz.utc) + new_interaction.algorithm_used = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] + + # try: + # db.session.add(new_interaction) + # db.session.commit() + # except SQLAlchemyError as e: + # current_app.logger.error(f'ask_question: Error saving 
interaction to database: {e}') + # raise + + current_app.logger.debug(f'ask_question: new_interaction: {new_interaction}') + # Select variables to work with depending on tenant model model_variables = select_model_variables(tenant) current_app.logger.debug(f'ask_question: model_variables: {model_variables}') + set_debug(True) retriever = EveAIRetriever(model_variables) + llm = model_variables['llm'] + template = model_variables['rag_template'] + language_template = create_language_template(template, language) + rag_prompt = ChatPromptTemplate.from_template(language_template) + setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()}) - # Search the database for relevant embeddings - relevant_embeddings = retriever.invoke(question) + new_interaction_embeddings = [] + if not model_variables['cited_answer_cls']: # The model doesn't support structured feedback + output_parser = StrOutputParser() - return 'No response yet, check back later.' + chain = setup_and_retrieval | rag_prompt | llm | output_parser + + # Invoke the chain with the actual question + answer = chain.invoke(question) + new_interaction.answer = answer + result = { + 'answer': answer, + 'citations': [] + } + + else: # The model supports structured feedback + structured_llm = llm.with_structured_output(model_variables['cited_answer_cls']) + + chain = setup_and_retrieval | rag_prompt | structured_llm + + result = chain.invoke(question).dict() + current_app.logger.debug(f'ask_question: result answer: {result['answer']}') + current_app.logger.debug(f'ask_question: result citations: {result["citations"]}') + new_interaction.answer = result['answer'] + + # Filter out the existing Embedding IDs + given_embedding_ids = [int(emb_id) for emb_id in result['citations']] + embeddings = ( + db.session.query(Embedding) + .filter(Embedding.id.in_(given_embedding_ids)) + .all() + ) + existing_embedding_ids = [emb.id for emb in embeddings] + urls = [emb.document_version.url for emb 
in embeddings] + + for emb_id in existing_embedding_ids: + new_interaction_embedding = InteractionEmbedding(embedding_id=emb_id) + new_interaction_embedding.interaction = new_interaction + new_interaction_embeddings.append(new_interaction_embedding) + + result['citations'] = urls + + new_interaction.answer_at = dt.now(tz.utc) + chat_session.session_end = dt.now(tz.utc) + + try: + db.session.add(chat_session) + db.session.add(new_interaction) + db.session.add_all(new_interaction_embeddings) + db.session.commit() + except SQLAlchemyError as e: + current_app.logger.error(f'ask_question: Error saving interaction to database: {e}') + raise + + set_debug(False) + + result['algorithm'] = current_app.config['INTERACTION_ALGORITHMS']['RAG_TENANT']['name'] + result['interaction_id'] = new_interaction.id + return result except Exception as e: current_app.logger.error(f'ask_question: Error processing question: {e}') raise diff --git a/eveai_workers/tasks.py b/eveai_workers/tasks.py index a20f202..8905a34 100644 --- a/eveai_workers/tasks.py +++ b/eveai_workers/tasks.py @@ -1,29 +1,26 @@ -from datetime import datetime as dt, timezone as tz -from flask import current_app -from sqlalchemy.exc import SQLAlchemyError -from celery import states -from celery.exceptions import Ignore import os - +from datetime import datetime as dt, timezone as tz +from bs4 import BeautifulSoup +import html +from celery import states +from flask import current_app +# OpenAI imports +from langchain.chains.summarize import load_summarize_chain +from langchain.text_splitter import CharacterTextSplitter +from langchain_core.exceptions import LangChainException +from langchain_core.prompts import ChatPromptTemplate +from sqlalchemy.exc import SQLAlchemyError # Unstructured commercial client imports from unstructured_client import UnstructuredClient from unstructured_client.models import shared from unstructured_client.models.errors import SDKError -# OpenAI imports -from langchain_core.prompts import 
ChatPromptTemplate -from langchain.chains.summarize import load_summarize_chain -from langchain.text_splitter import CharacterTextSplitter -from langchain_core.exceptions import LangChainException - -from common.utils.database import Database -from common.models.document import DocumentVersion -from common.models.user import Tenant from common.extensions import db +from common.models.document import DocumentVersion, Embedding +from common.models.user import Tenant from common.utils.celery_utils import current_celery -from common.utils.model_utils import select_model_variables - -from bs4 import BeautifulSoup +from common.utils.database import Database +from common.utils.model_utils import select_model_variables, create_language_template @current_celery.task(name='create_embeddings', queue='embeddings') @@ -65,6 +62,8 @@ def create_embeddings(tenant_id, document_version_id): # start processing document_version.processing = True document_version.processing_started_at = dt.now(tz.utc) + document_version.processing_finished_at = None + document_version.processing_error = None db.session.commit() except SQLAlchemyError as e: @@ -73,6 +72,8 @@ def create_embeddings(tenant_id, document_version_id): f'for tenant {tenant_id}') raise + delete_embeddings_for_document_version(document_version) + try: match document_version.file_type: case 'pdf': @@ -152,6 +153,18 @@ def process_pdf(tenant, model_variables, document_version): f'on document version {document_version.id} :-)') +def delete_embeddings_for_document_version(document_version): + embeddings_to_delete = db.session.query(Embedding).filter_by(doc_vers_id=document_version.id).all() + for embedding in embeddings_to_delete: + db.session.delete(embedding) + try: + db.session.commit() + current_app.logger.info(f'Deleted embeddings for document version {document_version.id}') + except SQLAlchemyError as e: + current_app.logger.error(f'Unable to delete embeddings for document version {document_version.id}') + raise + + def 
process_html(tenant, model_variables, document_version): # The tags to be considered can be dependent on the tenant html_tags = model_variables['html_tags'] @@ -176,12 +189,15 @@ def process_html(tenant, model_variables, document_version): extracted_data, title = parse_html(html_content, html_tags, included_elements=html_included_elements, excluded_elements=html_excluded_elements) potential_chunks = create_potential_chunks(extracted_data, html_end_tags) + current_app.embed_tuning_logger.debug(f'Nr of potential chunks: {len(potential_chunks)}') + chunks = combine_chunks(potential_chunks, model_variables['min_chunk_size'], model_variables['max_chunk_size'] ) + current_app.logger.debug(f'Nr of chunks: {len(chunks)}') - if len(chunks) > 0: + if len(chunks) > 1: summary = summarize_chunk(tenant, model_variables, document_version, chunks[0]) document_version.system_context = (f'Title: {title}\n' f'Summary: {summary}\n') @@ -210,6 +226,7 @@ def process_html(tenant, model_variables, document_version): def enrich_chunks(tenant, document_version, chunks): current_app.logger.debug(f'Enriching chunks for tenant {tenant.id} ' f'on document version {document_version.id}') + current_app.logger.debug(f'Nr of chunks: {len(chunks)}') chunk_total_context = (f'Filename: {document_version.file_name}\n' f'User Context:{document_version.user_context}\n' f'{document_version.system_context}\n\n') @@ -233,8 +250,10 @@ def summarize_chunk(tenant, model_variables, document_version, chunk): current_app.logger.debug(f'Summarizing chunk for tenant {tenant.id} ' f'on document version {document_version.id}') llm = model_variables['llm'] - prompt = model_variables['summary_prompt'] - chain = load_summarize_chain(llm, chain_type='stuff', prompt=prompt) + template = model_variables['summary_template'] + language_template = create_language_template(template, document_version.language) + current_app.logger.debug(f'Language prompt: {language_template}') + chain = load_summarize_chain(llm, 
chain_type='stuff', prompt=ChatPromptTemplate.from_template(language_template)) doc_creator = CharacterTextSplitter(chunk_size=model_variables['max_chunk_size'] * 2, chunk_overlap=0) text_to_summarize = doc_creator.create_documents(chunk) @@ -319,7 +338,6 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non current_app.embed_tuning_logger.debug(f'Included Elements: {len(included_elements)}') current_app.embed_tuning_logger.debug(f'Excluded Elements: {excluded_elements}') current_app.embed_tuning_logger.debug(f'Found {len(elements_to_parse)} elements to parse') - current_app.embed_tuning_logger.debug(f'{elements_to_parse}') # Iterate through the found included elements for element in elements_to_parse: @@ -327,7 +345,8 @@ def parse_html(html_content, tags, included_elements=None, excluded_elements=Non for sub_element in element.find_all(tags): if excluded_elements and sub_element.find_parent(excluded_elements): continue # Skip this sub_element if it's within any of the excluded_elements - extracted_content.append((sub_element.name, sub_element.get_text(strip=True))) + sub_content = html.unescape(sub_element.get_text(strip=False)) + extracted_content.append((sub_element.name, sub_content)) title = soup.find('title').get_text(strip=True) @@ -362,11 +381,14 @@ def combine_chunks(potential_chunks, min_chars, max_chars): current_length = 0 for chunk in potential_chunks: + current_app.embed_tuning_logger.debug(f'chunk: {chunk}') chunk_content = ''.join(text for _, text in chunk) + current_app.embed_tuning_logger.debug(f'chunk_content: {chunk_content}') chunk_length = len(chunk_content) if current_length + chunk_length > max_chars: if current_length >= min_chars: + current_app.embed_tuning_logger.debug(f'Adding chunk to actual_chunks: {current_chunk}') actual_chunks.append(current_chunk) current_chunk = chunk_content current_length = chunk_length @@ -378,8 +400,11 @@ def combine_chunks(potential_chunks, min_chars, max_chars): current_chunk += 
chunk_content current_length += chunk_length + current_app.embed_tuning_logger.debug(f'Remaining Chunk: {current_chunk}') + current_app.embed_tuning_logger.debug(f'Remaining Length: {current_length}') + # Handle the last chunk - if current_chunk and current_length >= min_chars: + if current_chunk and current_length >= 0: actual_chunks.append(current_chunk) return actual_chunks diff --git a/migrations/tenant/versions/6fbceab656a8_adding_algorithm_information_in_the_.py b/migrations/tenant/versions/6fbceab656a8_adding_algorithm_information_in_the_.py new file mode 100644 index 0000000..730f730 --- /dev/null +++ b/migrations/tenant/versions/6fbceab656a8_adding_algorithm_information_in_the_.py @@ -0,0 +1,28 @@ +"""Adding algorithm information in the Interaction model + +Revision ID: 6fbceab656a8 +Revises: f6ecc306055a +Create Date: 2024-06-11 14:24:20.837508 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6fbceab656a8' +down_revision = 'f6ecc306055a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('interaction', sa.Column('algorithm_used', sa.String(length=20), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('interaction', 'algorithm_used') + # ### end Alembic commands ### diff --git a/migrations/tenant/versions/f6ecc306055a_adding_session_id_to_chatsession.py b/migrations/tenant/versions/f6ecc306055a_adding_session_id_to_chatsession.py new file mode 100644 index 0000000..fb48ccf --- /dev/null +++ b/migrations/tenant/versions/f6ecc306055a_adding_session_id_to_chatsession.py @@ -0,0 +1,28 @@ +"""Adding session id to ChatSession + +Revision ID: f6ecc306055a +Revises: 217938792642 +Create Date: 2024-06-10 16:01:45.254969 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f6ecc306055a' +down_revision = '217938792642' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('chat_session', sa.Column('session_id', sa.String(length=36), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('chat_session', 'session_id') + # ### end Alembic commands ### diff --git a/public/chat.html b/public/chat.html index 0fb59f1..099174e 100644 --- a/public/chat.html +++ b/public/chat.html @@ -6,6 +6,7 @@ Chat Client + @@ -16,7 +17,8 @@ const eveAI = new EveAI( '1', 'EveAI-CHAT-8553-7987-2800-9115-6454', - 'http://macstudio.ask-eve-ai-local.com' + 'http://macstudio.ask-eve-ai-local.com', + 'en' ); eveAI.initializeChat('chat-container'); }); diff --git a/public/chat_ae.html b/public/chat_ae.html index 8603b44..181395f 100644 --- a/public/chat_ae.html +++ b/public/chat_ae.html @@ -6,6 +6,7 @@ Chat Client AE + @@ -16,7 +17,8 @@ const eveAI = new EveAI( '39', 'EveAI-CHAT-6919-1265-9848-6655-9870', - 'http://macstudio.ask-eve-ai-local.com' + 'http://macstudio.ask-eve-ai-local.com', + 'en' ); eveAI.initializeChat('chat-container'); }); diff --git a/requirements.txt b/requirements.txt index a5a0063..d5cf086 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ beautifulsoup4~=4.12.3 google~=3.0.0 redis~=5.0.4 itsdangerous~=2.2.0 -pydantic~=2.7.1 \ No newline at end of file +pydantic~=2.7.1 +chardet~=5.2.0 +langcodes~=3.4.0 \ No newline at end of file diff --git a/static/js/eveai-chat-widget.js b/static/js/eveai-chat-widget.js index 5789a82..4ccc81d 100644 --- a/static/js/eveai-chat-widget.js +++ b/static/js/eveai-chat-widget.js @@ -1,6 +1,6 @@ class EveAIChatWidget extends HTMLElement { static get observedAttributes() { - return ['tenant-id', 'api-key', 'domain']; + return ['tenant-id', 'api-key', 'domain', 'language']; } constructor() { @@ -16,8 +16,14 @@ class EveAIChatWidget extends HTMLElement { this.innerHTML = this.getTemplate(); this.messagesArea = this.querySelector('.messages-area'); this.questionInput = this.querySelector('.question-area input'); + this.statusLine = this.querySelector('.status-line'); this.querySelector('.question-area button').addEventListener('click', () => this.handleSendMessage()); + 
this.questionInput.addEventListener('keydown', (event) => { + if (event.key === 'Enter') { + this.handleSendMessage(); + } + }); if (this.areAllAttributesSet() && !this.socket) { console.log('Attributes already set in connectedCallback, initializing socket'); @@ -31,7 +37,8 @@ class EveAIChatWidget extends HTMLElement { console.log('Current attributes:', { tenantId: this.getAttribute('tenant-id'), apiKey: this.getAttribute('api-key'), - domain: this.getAttribute('domain') + domain: this.getAttribute('domain'), + language: this.getAttribute('language') }); if (this.areAllAttributesSet() && !this.socket) { @@ -46,10 +53,12 @@ class EveAIChatWidget extends HTMLElement { this.tenantId = this.getAttribute('tenant-id'); this.apiKey = this.getAttribute('api-key'); this.domain = this.getAttribute('domain'); + this.language = this.getAttribute('language'); console.log('Updated attributes:', { tenantId: this.tenantId, apiKey: this.apiKey, - domain: this.domain + domain: this.domain, + language: this.language }); } @@ -57,10 +66,12 @@ class EveAIChatWidget extends HTMLElement { const tenantId = this.getAttribute('tenant-id'); const apiKey = this.getAttribute('api-key'); const domain = this.getAttribute('domain'); + const language = this.getAttribute('language'); console.log('Checking if all attributes are set:', { tenantId, apiKey, - domain + domain, + language }); return tenantId && apiKey && domain; } @@ -72,6 +83,7 @@ class EveAIChatWidget extends HTMLElement { } if (!this.domain || this.domain === 'null') { console.error('Domain attribute is missing or invalid'); + this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.') return; } console.log(`Initializing socket connection to ${this.domain}`); @@ -91,10 +103,12 @@ class EveAIChatWidget extends HTMLElement { this.socket.on('connect', (data) => { console.log('Socket connected'); + this.setStatusMessage('Connected to EveAI.') }); this.socket.on('authenticated', (data) => { 
console.log('Authenticated event received: ', data); + this.setStatusMessage('Authenticated.') if (data.token) { this.jwtToken = data.token; // Store the JWT token received from the server } @@ -102,42 +116,64 @@ class EveAIChatWidget extends HTMLElement { this.socket.on('connect_error', (err) => { console.error('Socket connection error:', err); + this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.') }); this.socket.on('connect_timeout', () => { console.error('Socket connection timeout'); + this.setStatusMessage('EveAI Chat Widget needs further configuration by site administrator.') }); this.socket.on('disconnect', () => { console.log('Socket disconnected'); + this.setStatusMessage('Disconnected from EveAI. Please refresh the page for further interaction.') }); this.socket.on('bot_response', (data) => { if (data.tenantId === this.tenantId) { - console.log('Bot response received:', data) + console.log('Initial response received:', data) console.log('Task ID received:', data.taskId) - this.addMessage(data.message, 'bot', data.messageId, data.algorithm); this.checkTaskStatus(data.taskId) + this.setStatusMessage('Processing...') } }); this.socket.on('task_status', (data) => { - console.log('Task status received:', data) + console.log('Task status received:', data.status) console.log('Task ID received:', data.taskId) - if (data.status === 'SUCCESS') { - this.addMessage(data.result, 'bot'); - } else if (data.status === 'FAILURE') { - this.addMessage('Failed to process message.', 'bot'); - } else if (data.status === 'pending') { - console.log('Task is pending') - setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second - console.log('New check sent') + console.log('Citations type:', typeof data.citations, 'Citations:', data.citations); + + if (data.status === 'pending') { + this.updateProgress(); + setTimeout(() => this.checkTaskStatus(data.taskId), 1000); // Poll every second + } else if (data.status === 
'success') { + this.addBotMessage(data.answer, data.interaction_id, data.algorithm, data.citations); + this.clearProgress(); // Clear progress indicator when done + } else { + this.setStatusMessage('Failed to process message.'); } }); } + setStatusMessage(message) { + this.statusLine.textContent = message; + } + + updateProgress() { + if (!this.statusLine.textContent) { + this.statusLine.textContent = 'Processing...'; + } else { + this.statusLine.textContent += '.'; // Append a dot + } + } + + clearProgress() { + this.statusLine.textContent = ''; + } + checkTaskStatus(taskId) { - this.socket.emit('check_task_status', { task_id: taskId }); + this.updateProgress(); + this.socket.emit('check_task_status', { task_id: taskId }); } getTemplate() { @@ -148,21 +184,39 @@ class EveAIChatWidget extends HTMLElement {
+
`; } - addMessage(text, type = 'user', id = null, algorithm = 'default') { + addUserMessage(text) { const message = document.createElement('div'); - message.classList.add('message', type); + message.classList.add('message', 'user'); + message.innerHTML = `

${text}

`; + this.messagesArea.appendChild(message); + this.messagesArea.scrollTop = this.messagesArea.scrollHeight; + } + + addBotMessage(text, interactionId, algorithm = 'default', citations = []) { + const message = document.createElement('div'); + message.classList.add('message', 'bot'); + + let content = marked.parse(text); // Use marked to convert markdown to HTML + // Ensure citations is an array + if (!Array.isArray(citations)) { + console.error('Expected citations to be an array, but got:', citations); + citations = []; // Default to an empty array + } + let citationsHtml = citations.map(url => `${url}`).join('
'); + message.innerHTML = ` -

${text}

- ${type === 'bot' ? ` -
- - thumb_up - thumb_down -
` : ''} +

${content}

+ ${citationsHtml ? `

Citations: ${citationsHtml}

` : ''} +
+ ${algorithm} + thumb_up + thumb_down +
`; this.messagesArea.appendChild(message); this.messagesArea.scrollTop = this.messagesArea.scrollHeight; @@ -172,7 +226,7 @@ class EveAIChatWidget extends HTMLElement { console.log('handleSendMessage called'); const message = this.questionInput.value.trim(); if (message) { - this.addMessage(message, 'user'); + this.addUserMessage(message); this.questionInput.value = ''; this.sendMessageToBackend(message); } @@ -189,14 +243,15 @@ class EveAIChatWidget extends HTMLElement { return; } console.log('Sending message to backend'); - this.socket.emit('user_message', { tenantId: this.tenantId, token: this.jwtToken, message }); + this.socket.emit('user_message', { tenantId: this.tenantId, token: this.jwtToken, message, language: this.language }); + this.setStatusMessage('Processing started ...') } } customElements.define('eveai-chat-widget', EveAIChatWidget); -function handleFeedback(messageId, feedback) { +function handleFeedback(feedback, interactionId) { // Send feedback to the backend - console.log(`Feedback for ${messageId}: ${feedback}`); - // Implement the actual feedback mechanism + console.log(`Feedback for ${interactionId}: ${feedback}`); + this.socket.emit('feedback', { feedback, interaction_id: interactionId }); } diff --git a/static/js/eveai-sdk.js b/static/js/eveai-sdk.js index aceb522..ac08330 100644 --- a/static/js/eveai-sdk.js +++ b/static/js/eveai-sdk.js @@ -1,9 +1,10 @@ // static/js/eveai-sdk.js class EveAI { - constructor(tenantId, apiKey, domain) { + constructor(tenantId, apiKey, domain, language) { this.tenantId = tenantId; this.apiKey = apiKey; this.domain = domain; + this.language = language; console.log('EveAI constructor:', { tenantId, apiKey, domain }); } @@ -17,10 +18,12 @@ class EveAI { chatWidget.setAttribute('tenant-id', this.tenantId); chatWidget.setAttribute('api-key', this.apiKey); chatWidget.setAttribute('domain', this.domain); + chatWidget.setAttribute('language', this.language); console.log('Attributes set in chat widget:', { tenantId: 
chatWidget.getAttribute('tenant-id'), apiKey: chatWidget.getAttribute('api-key'), - domain: chatWidget.getAttribute('domain') + domain: chatWidget.getAttribute('domain'), + language: chatWidget.getAttribute('language') }); }); } else {