Add per-tenant RAG context, chat temperatures and fallback algorithms.

Review notes (text before the first "diff --git" header is ignored by patch tools):
- This patch was restored from a copy whose line breaks and indentation had been
  collapsed. Hunk line counts are consistent, but the indentation and wrap points
  of context lines are inferred; if a hunk is rejected, retry with
  `git apply --ignore-whitespace`.
- replace_variable_in_template() gained a keyword-only `escape_braces` option and
  both prompt builders enable it, so braces typed into a tenant's RAG context are
  not parsed as prompt-template placeholders.
- The commented-out Tenant(...) constructor is no longer added, and
  common/utils/model_utils.py now ends with a newline.
- The post-image blob ids on the `index` lines of the files changed during review
  (model_utils.py, user_views.py, tasks.py) are those of the original patch.

diff --git a/common/models/user.py b/common/models/user.py
index 9d698b1..e1caa2e 100644
--- a/common/models/user.py
+++ b/common/models/user.py
@@ -20,6 +20,7 @@ class Tenant(db.Model):
     name = db.Column(db.String(80), unique=True, nullable=False)
     website = db.Column(db.String(255), nullable=True)
     timezone = db.Column(db.String(50), nullable=True, default='UTC')
+    rag_context = db.Column(db.Text, nullable=True)
 
     # language information
     default_language = db.Column(db.String(2), nullable=True)
@@ -42,6 +43,7 @@ class Tenant(db.Model):
     # Chat variables
     chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
     chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
+    fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)
 
     # Licensing Information
     license_start_date = db.Column(db.Date, nullable=True)
@@ -65,11 +67,21 @@ class Tenant(db.Model):
             'id': self.id,
             'name': self.name,
             'website': self.website,
+            'timezone': self.timezone,
+            'rag_context': self.rag_context,
             'default_language': self.default_language,
             'allowed_languages': self.allowed_languages,
-            'timezone': self.timezone,
             'embedding_model': self.embedding_model,
             'llm_model': self.llm_model,
+            'html_tags': self.html_tags,
+            'html_end_tags': self.html_end_tags,
+            'html_included_elements': self.html_included_elements,
+            'html_excluded_elements': self.html_excluded_elements,
+            'es_k': self.es_k,
+            'es_similarity_threshold': self.es_similarity_threshold,
+            'chat_RAG_temperature': self.chat_RAG_temperature,
+            'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
+            'fallback_algorithms': self.fallback_algorithms,
             'license_start_date': self.license_start_date,
             'license_end_date': self.license_end_date,
             'allowed_monthly_interactions': self.allowed_monthly_interactions,
diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py
index afcd7b9..ec9203d 100644
--- a/common/utils/model_utils.py
+++ b/common/utils/model_utils.py
@@ -67,6 +67,11 @@ def select_model_variables(tenant):
     else:
         model_variables['rag_tuning'] = False
 
+    if tenant.rag_context:
+        model_variables['rag_context'] = tenant.rag_context
+    else:
+        model_variables['rag_context'] = " "
+
     # Set HTML Chunking Variables
     model_variables['html_tags'] = tenant.html_tags
     model_variables['html_end_tags'] = tenant.html_end_tags
@@ -98,6 +103,9 @@ def select_model_variables(tenant):
             model_variables['llm'] = ChatOpenAI(api_key=api_key,
                                                 model=llm_model,
                                                 temperature=model_variables['RAG_temperature'])
+            model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
+                                                       model=llm_model,
+                                                       temperature=model_variables['no_RAG_temperature'])
             tool_calling_supported = False
             match llm_model:
                 case 'gpt-4-turbo' | 'gpt-4o':
@@ -132,3 +140,15 @@ def create_language_template(template, language):
     language_template = template.replace('{language}', language)
 
     return language_template
+
+
+def replace_variable_in_template(template, variable, value, *, escape_braces=False):
+    """Return `template` with every occurrence of `variable` replaced by `value`.
+
+    Set `escape_braces` when the result is parsed as an f-string style prompt
+    template and `value` is free text: literal '{' and '}' are doubled so they
+    are not mistaken for template placeholders.
+    """
+    if escape_braces:
+        value = value.replace('{', '{{').replace('}', '}}')
+    return template.replace(variable, value)
diff --git a/config/config.py b/config/config.py
index f00fba3..632dac9 100644
--- a/config/config.py
+++ b/config/config.py
@@ -79,6 +79,7 @@ class Config(object):
     ```{text}```"""
 
     GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
+    {tenant_context}
     Use the following {language} in your communication, and cite the sources used.
     If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
     Context:
@@ -86,6 +87,7 @@ class Config(object):
     Question:
     {question}"""
     GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
+    {tenant_context}
     Use the following {language} in your communication.
     If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
     Context:
@@ -95,6 +97,7 @@ class Config(object):
 
     GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
     in such a way that the question is understandable without the previous context.
+    {tenant_context}
     The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
     The history is delimited between triple backquotes.
     Your answer by stating the question in {language}.
@@ -103,6 +106,13 @@ class Config(object):
     Question to be detailed:
     {question}"""
 
+    # Fallback Algorithms
+    FALLBACK_ALGORITHMS = [
+        "RAG_TENANT",
+        "RAG_WIKIPEDIA",
+        "RAG_GOOGLE",
+        "LLM"
+    ]
 
     # SocketIO settings
     # SOCKETIO_ASYNC_MODE = 'threading'
diff --git a/eveai_app/views/user_forms.py b/eveai_app/views/user_forms.py
index 326b250..4a56cf8 100644
--- a/eveai_app/views/user_forms.py
+++ b/eveai_app/views/user_forms.py
@@ -1,7 +1,7 @@
 from flask import current_app
 from flask_wtf import FlaskForm
 from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
-                     SelectField, SelectMultipleField, FieldList, FormField, FloatField)
+                     SelectField, SelectMultipleField, FieldList, FormField, FloatField, TextAreaField)
 from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
 import pytz
 
@@ -16,6 +16,8 @@ class TenantForm(FlaskForm):
     allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
     # Timezone
     timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
+    # RAG context
+    rag_context = TextAreaField('RAG Context', validators=[Optional()])
     # LLM fields
     embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
     llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
@@ -37,6 +39,10 @@ class TenantForm(FlaskForm):
     es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
                                          default=0.5,
                                          validators=[NumberRange(min=0, max=1)])
+    # Chat Variables
+    chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
+    chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
+    fallback_algorithms = SelectMultipleField('Fallback Algorithms', choices=[], validators=[Optional()])
     # Tuning variables
     embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
     rag_tuning = BooleanField('Enable RAG Tuning', default=False)
@@ -53,6 +59,8 @@ class TenantForm(FlaskForm):
         # initialise LLM fields
         self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
         self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
+        # Initialize fallback algorithms
+        self.fallback_algorithms.choices = [(algorithm, algorithm.lower()) for algorithm in current_app.config['FALLBACK_ALGORITHMS']]
 
 
 class BaseUserForm(FlaskForm):
diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py
index d5ae6f0..9bf5f91 100644
--- a/eveai_app/views/user_views.py
+++ b/eveai_app/views/user_views.py
@@ -45,18 +45,9 @@ def tenant():
     if form.validate_on_submit():
         current_app.logger.debug('Creating new tenant')
         # Handle the required attributes
-        new_tenant = Tenant(name=form.name.data,
-                            website=form.website.data,
-                            default_language=form.default_language.data,
-                            allowed_languages=form.allowed_languages.data,
-                            timezone=form.timezone.data,
-                            embedding_model=form.embedding_model.data,
-                            llm_model=form.llm_model.data,
-                            license_start_date=form.license_start_date.data,
-                            license_end_date=form.license_end_date.data,
-                            allowed_monthly_interactions=form.allowed_monthly_interactions.data,
-                            embed_tuning=form.embed_tuning.data,
-                            rag_tuning=form.rag_tuning.data)
+        new_tenant = Tenant()
+        # populate_obj copies every form field by name; html_tags is re-parsed into a list below
+        form.populate_obj(new_tenant)
 
         # Handle Embedding Variables
         new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []
diff --git a/eveai_chat_workers/tasks.py b/eveai_chat_workers/tasks.py
index c6bd0e1..68071cf 100644
--- a/eveai_chat_workers/tasks.py
+++ b/eveai_chat_workers/tasks.py
@@ -21,7 +21,7 @@ from common.models.user import Tenant
 from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
 from common.extensions import db
 from common.utils.celery_utils import current_celery
-from common.utils.model_utils import select_model_variables, create_language_template
+from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
 from common.langchain.EveAIRetriever import EveAIRetriever
 from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
 
@@ -31,7 +31,9 @@ def detail_question(question, language, model_variables, session_id):
     llm = model_variables['llm']
     template = model_variables['history_template']
     language_template = create_language_template(template, language)
-    history_prompt = ChatPromptTemplate.from_template(language_template)
+    full_template = replace_variable_in_template(language_template, "{tenant_context}",
+                                                 model_variables['rag_context'], escape_braces=True)
+    history_prompt = ChatPromptTemplate.from_template(full_template)
     setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
     output_parser = StrOutputParser()
 
@@ -103,7 +105,9 @@ def ask_question(tenant_id, question, language, session_id):
         llm = model_variables['llm']
         template = model_variables['rag_template']
         language_template = create_language_template(template, language)
-        rag_prompt = ChatPromptTemplate.from_template(language_template)
+        full_template = replace_variable_in_template(language_template, "{tenant_context}",
+                                                     model_variables['rag_context'], escape_braces=True)
+        rag_prompt = ChatPromptTemplate.from_template(full_template)
         setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
 
         new_interaction_embeddings = []