Added rag_context and fallback_algorithms fields to the Tenant model, plus partial implementation across the model-variable selection, config templates, tenant form, and chat tasks.

This commit is contained in:
Josako
2024-06-13 15:23:35 +02:00
parent cbddaee810
commit b77e1ab321
6 changed files with 63 additions and 17 deletions

View File

@@ -20,6 +20,7 @@ class Tenant(db.Model):
name = db.Column(db.String(80), unique=True, nullable=False) name = db.Column(db.String(80), unique=True, nullable=False)
website = db.Column(db.String(255), nullable=True) website = db.Column(db.String(255), nullable=True)
timezone = db.Column(db.String(50), nullable=True, default='UTC') timezone = db.Column(db.String(50), nullable=True, default='UTC')
rag_context = db.Column(db.Text, nullable=True)
# language information # language information
default_language = db.Column(db.String(2), nullable=True) default_language = db.Column(db.String(2), nullable=True)
@@ -42,6 +43,7 @@ class Tenant(db.Model):
# Chat variables # Chat variables
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3) chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5) chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)
# Licensing Information # Licensing Information
license_start_date = db.Column(db.Date, nullable=True) license_start_date = db.Column(db.Date, nullable=True)
@@ -65,11 +67,21 @@ class Tenant(db.Model):
'id': self.id, 'id': self.id,
'name': self.name, 'name': self.name,
'website': self.website, 'website': self.website,
'timezone': self.timezone,
'rag_context': self.rag_context,
'default_language': self.default_language, 'default_language': self.default_language,
'allowed_languages': self.allowed_languages, 'allowed_languages': self.allowed_languages,
'timezone': self.timezone,
'embedding_model': self.embedding_model, 'embedding_model': self.embedding_model,
'llm_model': self.llm_model, 'llm_model': self.llm_model,
'html_tags': self.html_tags,
'html_end_tags': self.html_end_tags,
'html_included_elements': self.html_included_elements,
'html_excluded_elements': self.html_excluded_elements,
'es_k': self.es_k,
'es_similarity_threshold': self.es_similarity_threshold,
'chat_RAG_temperature': self.chat_RAG_temperature,
'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
'fallback_algorithms': self.fallback_algorithms,
'license_start_date': self.license_start_date, 'license_start_date': self.license_start_date,
'license_end_date': self.license_end_date, 'license_end_date': self.license_end_date,
'allowed_monthly_interactions': self.allowed_monthly_interactions, 'allowed_monthly_interactions': self.allowed_monthly_interactions,

View File

@@ -67,6 +67,11 @@ def select_model_variables(tenant):
else: else:
model_variables['rag_tuning'] = False model_variables['rag_tuning'] = False
if tenant.rag_context:
model_variables['rag_context'] = tenant.rag_context
else:
model_variables['rag_context'] = " "
# Set HTML Chunking Variables # Set HTML Chunking Variables
model_variables['html_tags'] = tenant.html_tags model_variables['html_tags'] = tenant.html_tags
model_variables['html_end_tags'] = tenant.html_end_tags model_variables['html_end_tags'] = tenant.html_end_tags
@@ -98,6 +103,9 @@ def select_model_variables(tenant):
model_variables['llm'] = ChatOpenAI(api_key=api_key, model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model, model=llm_model,
temperature=model_variables['RAG_temperature']) temperature=model_variables['RAG_temperature'])
model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['no_RAG_temperature'])
tool_calling_supported = False tool_calling_supported = False
match llm_model: match llm_model:
case 'gpt-4-turbo' | 'gpt-4o': case 'gpt-4-turbo' | 'gpt-4o':
@@ -132,3 +140,7 @@ def create_language_template(template, language):
language_template = template.replace('{language}', language) language_template = template.replace('{language}', language)
return language_template return language_template
def replace_variable_in_template(template, variable, value):
    """Return *template* with every occurrence of *variable* substituted by *value*.

    Thin wrapper over ``str.replace`` used to inject per-tenant text (e.g. the
    ``{tenant_context}`` placeholder) into prompt templates.
    """
    substituted = template.replace(variable, value)
    return substituted

View File

@@ -79,6 +79,7 @@ class Config(object):
```{text}```""" ```{text}```"""
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication, and cite the sources used. Use the following {language} in your communication, and cite the sources used.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question." If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context: Context:
@@ -86,6 +87,7 @@ class Config(object):
Question: Question:
{question}""" {question}"""
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes. GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication. Use the following {language} in your communication.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question." If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context: Context:
@@ -95,6 +97,7 @@ class Config(object):
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context, GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
in such a way that the question is understandable without the previous context. in such a way that the question is understandable without the previous context.
{tenant_context}
The context is a conversation history, with the HUMAN asking questions, the AI answering questions. The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
The history is delimited between triple backquotes. The history is delimited between triple backquotes.
Your answer by stating the question in {language}. Your answer by stating the question in {language}.
@@ -103,6 +106,13 @@ class Config(object):
Question to be detailed: Question to be detailed:
{question}""" {question}"""
# Fallback Algorithms
# NOTE(review): used to populate the TenantForm.fallback_algorithms
# SelectMultipleField choices; presumably the set of answer strategies a
# tenant may enable when the primary RAG flow yields no answer — confirm
# ordering semantics with the chat task implementation.
FALLBACK_ALGORITHMS = [
"RAG_TENANT",
"RAG_WIKIPEDIA",
"RAG_GOOGLE",
"LLM"
]
# SocketIO settings # SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading' # SOCKETIO_ASYNC_MODE = 'threading'

View File

@@ -1,7 +1,7 @@
from flask import current_app from flask import current_app
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField, from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField, FloatField) SelectField, SelectMultipleField, FieldList, FormField, FloatField, TextAreaField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
import pytz import pytz
@@ -16,6 +16,8 @@ class TenantForm(FlaskForm):
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()]) allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
# Timezone # Timezone
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()]) timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
# RAG context
rag_context = TextAreaField('RAG Context', validators=[Optional()])
# LLM fields # LLM fields
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()]) embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()]) llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
@@ -37,6 +39,10 @@ class TenantForm(FlaskForm):
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)', es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5, default=0.5,
validators=[NumberRange(min=0, max=1)]) validators=[NumberRange(min=0, max=1)])
# Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
fallback_algorithms = SelectMultipleField('Fallback Algorithms', choices=[], validators=[Optional()])
# Tuning variables # Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False) embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False) rag_tuning = BooleanField('Enable RAG Tuning', default=False)
@@ -53,6 +59,8 @@ class TenantForm(FlaskForm):
# initialise LLM fields # initialise LLM fields
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']] self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']] self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
# Initialize fallback algorithms
self.fallback_algorithms.choices = [(algorithm, algorithm.lower()) for algorithm in current_app.config['FALLBACK_ALGORITHMS']]
class BaseUserForm(FlaskForm): class BaseUserForm(FlaskForm):

View File

@@ -45,18 +45,20 @@ def tenant():
if form.validate_on_submit(): if form.validate_on_submit():
current_app.logger.debug('Creating new tenant') current_app.logger.debug('Creating new tenant')
# Handle the required attributes # Handle the required attributes
new_tenant = Tenant(name=form.name.data, new_tenant = Tenant()
website=form.website.data, form.populate_obj(new_tenant)
default_language=form.default_language.data, # new_tenant = Tenant(name=form.name.data,
allowed_languages=form.allowed_languages.data, # website=form.website.data,
timezone=form.timezone.data, # default_language=form.default_language.data,
embedding_model=form.embedding_model.data, # allowed_languages=form.allowed_languages.data,
llm_model=form.llm_model.data, # timezone=form.timezone.data,
license_start_date=form.license_start_date.data, # embedding_model=form.embedding_model.data,
license_end_date=form.license_end_date.data, # llm_model=form.llm_model.data,
allowed_monthly_interactions=form.allowed_monthly_interactions.data, # license_start_date=form.license_start_date.data,
embed_tuning=form.embed_tuning.data, # license_end_date=form.license_end_date.data,
rag_tuning=form.rag_tuning.data) # allowed_monthly_interactions=form.allowed_monthly_interactions.data,
# embed_tuning=form.embed_tuning.data,
# rag_tuning=form.rag_tuning.data)
# Handle Embedding Variables # Handle Embedding Variables
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else [] new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []

View File

@@ -21,7 +21,7 @@ from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db from common.extensions import db
from common.utils.celery_utils import current_celery from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.EveAIRetriever import EveAIRetriever from common.langchain.EveAIRetriever import EveAIRetriever
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
@@ -31,7 +31,8 @@ def detail_question(question, language, model_variables, session_id):
llm = model_variables['llm'] llm = model_variables['llm']
template = model_variables['history_template'] template = model_variables['history_template']
language_template = create_language_template(template, language) language_template = create_language_template(template, language)
history_prompt = ChatPromptTemplate.from_template(language_template) full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
history_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()}) setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
output_parser = StrOutputParser() output_parser = StrOutputParser()
@@ -103,7 +104,8 @@ def ask_question(tenant_id, question, language, session_id):
llm = model_variables['llm'] llm = model_variables['llm']
template = model_variables['rag_template'] template = model_variables['rag_template']
language_template = create_language_template(template, language) language_template = create_language_template(template, language)
rag_prompt = ChatPromptTemplate.from_template(language_template) full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
rag_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()}) setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
new_interaction_embeddings = [] new_interaction_embeddings = []