Added rag_context and fallback_algorithms fields to the Tenant model, with partial implementation of their use.

This commit is contained in:
Josako
2024-06-13 15:23:35 +02:00
parent cbddaee810
commit b77e1ab321
6 changed files with 63 additions and 17 deletions

View File

@@ -20,6 +20,7 @@ class Tenant(db.Model):
name = db.Column(db.String(80), unique=True, nullable=False)
website = db.Column(db.String(255), nullable=True)
timezone = db.Column(db.String(50), nullable=True, default='UTC')
rag_context = db.Column(db.Text, nullable=True)
# language information
default_language = db.Column(db.String(2), nullable=True)
@@ -42,6 +43,7 @@ class Tenant(db.Model):
# Chat variables
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)
# Licensing Information
license_start_date = db.Column(db.Date, nullable=True)
@@ -65,11 +67,21 @@ class Tenant(db.Model):
'id': self.id,
'name': self.name,
'website': self.website,
'timezone': self.timezone,
'rag_context': self.rag_context,
'default_language': self.default_language,
'allowed_languages': self.allowed_languages,
'timezone': self.timezone,
'embedding_model': self.embedding_model,
'llm_model': self.llm_model,
'html_tags': self.html_tags,
'html_end_tags': self.html_end_tags,
'html_included_elements': self.html_included_elements,
'html_excluded_elements': self.html_excluded_elements,
'es_k': self.es_k,
'es_similarity_threshold': self.es_similarity_threshold,
'chat_RAG_temperature': self.chat_RAG_temperature,
'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
'fallback_algorithms': self.fallback_algorithms,
'license_start_date': self.license_start_date,
'license_end_date': self.license_end_date,
'allowed_monthly_interactions': self.allowed_monthly_interactions,

View File

@@ -67,6 +67,11 @@ def select_model_variables(tenant):
else:
model_variables['rag_tuning'] = False
if tenant.rag_context:
model_variables['rag_context'] = tenant.rag_context
else:
model_variables['rag_context'] = " "
# Set HTML Chunking Variables
model_variables['html_tags'] = tenant.html_tags
model_variables['html_end_tags'] = tenant.html_end_tags
@@ -98,6 +103,9 @@ def select_model_variables(tenant):
model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['RAG_temperature'])
model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['no_RAG_temperature'])
tool_calling_supported = False
match llm_model:
case 'gpt-4-turbo' | 'gpt-4o':
@@ -132,3 +140,7 @@ def create_language_template(template, language):
language_template = template.replace('{language}', language)
return language_template
def replace_variable_in_template(template, variable, value):
    """Return *template* with every occurrence of *variable* replaced by *value*.

    Thin wrapper over ``str.replace`` used to substitute prompt placeholders
    (e.g. ``{tenant_context}``) before building a ChatPromptTemplate.
    """
    substituted = template.replace(variable, value)
    return substituted

View File

@@ -79,6 +79,7 @@ class Config(object):
```{text}```"""
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication, and cite the sources used.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context:
@@ -86,6 +87,7 @@ class Config(object):
Question:
{question}"""
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context:
@@ -95,6 +97,7 @@ class Config(object):
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
in such a way that the question is understandable without the previous context.
{tenant_context}
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
The history is delimited between triple backquotes.
Your answer by stating the question in {language}.
@@ -103,6 +106,13 @@ class Config(object):
Question to be detailed:
{question}"""
# Fallback Algorithms
FALLBACK_ALGORITHMS = [
"RAG_TENANT",
"RAG_WIKIPEDIA",
"RAG_GOOGLE",
"LLM"
]
# SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading'

View File

@@ -1,7 +1,7 @@
from flask import current_app
from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
SelectField, SelectMultipleField, FieldList, FormField, FloatField, TextAreaField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
import pytz
@@ -16,6 +16,8 @@ class TenantForm(FlaskForm):
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
# Timezone
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
# RAG context
rag_context = TextAreaField('RAG Context', validators=[Optional()])
# LLM fields
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
@@ -37,6 +39,10 @@ class TenantForm(FlaskForm):
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5,
validators=[NumberRange(min=0, max=1)])
# Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
fallback_algorithms = SelectMultipleField('Fallback Algorithms', choices=[], validators=[Optional()])
# Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
@@ -53,6 +59,8 @@ class TenantForm(FlaskForm):
# initialise LLM fields
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
# Initialize fallback algorithms
self.fallback_algorithms.choices = [(algorithm, algorithm.lower()) for algorithm in current_app.config['FALLBACK_ALGORITHMS']]
class BaseUserForm(FlaskForm):

View File

@@ -45,18 +45,20 @@ def tenant():
if form.validate_on_submit():
current_app.logger.debug('Creating new tenant')
# Handle the required attributes
new_tenant = Tenant(name=form.name.data,
website=form.website.data,
default_language=form.default_language.data,
allowed_languages=form.allowed_languages.data,
timezone=form.timezone.data,
embedding_model=form.embedding_model.data,
llm_model=form.llm_model.data,
license_start_date=form.license_start_date.data,
license_end_date=form.license_end_date.data,
allowed_monthly_interactions=form.allowed_monthly_interactions.data,
embed_tuning=form.embed_tuning.data,
rag_tuning=form.rag_tuning.data)
new_tenant = Tenant()
form.populate_obj(new_tenant)
# new_tenant = Tenant(name=form.name.data,
# website=form.website.data,
# default_language=form.default_language.data,
# allowed_languages=form.allowed_languages.data,
# timezone=form.timezone.data,
# embedding_model=form.embedding_model.data,
# llm_model=form.llm_model.data,
# license_start_date=form.license_start_date.data,
# license_end_date=form.license_end_date.data,
# allowed_monthly_interactions=form.allowed_monthly_interactions.data,
# embed_tuning=form.embed_tuning.data,
# rag_tuning=form.rag_tuning.data)
# Handle Embedding Variables
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []

View File

@@ -21,7 +21,7 @@ from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.EveAIRetriever import EveAIRetriever
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
@@ -31,7 +31,8 @@ def detail_question(question, language, model_variables, session_id):
llm = model_variables['llm']
template = model_variables['history_template']
language_template = create_language_template(template, language)
history_prompt = ChatPromptTemplate.from_template(language_template)
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
history_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
output_parser = StrOutputParser()
@@ -103,7 +104,8 @@ def ask_question(tenant_id, question, language, session_id):
llm = model_variables['llm']
template = model_variables['rag_template']
language_template = create_language_template(template, language)
rag_prompt = ChatPromptTemplate.from_template(language_template)
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
rag_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
new_interaction_embeddings = []