Added rag_context and fallback_algorithms variables to Tenant, along with parts of their implementation.
@@ -20,6 +20,7 @@ class Tenant(db.Model):
name = db.Column(db.String(80), unique=True, nullable=False)
website = db.Column(db.String(255), nullable=True)
timezone = db.Column(db.String(50), nullable=True, default='UTC')
rag_context = db.Column(db.Text, nullable=True)

# language information
default_language = db.Column(db.String(2), nullable=True)
@@ -42,6 +43,7 @@ class Tenant(db.Model):
# Chat variables
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)

# Licensing Information
license_start_date = db.Column(db.Date, nullable=True)
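The commit does not include the matching database migration, so here is a minimal Alembic sketch for the two new columns, given purely as an assumption: the table name `tenant` is hypothetical, and the ARRAY type requires PostgreSQL.

```python
# Hypothetical Alembic migration for the two new Tenant columns (not part of this commit)
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY


def upgrade():
    # Free-text context injected into the RAG prompt templates
    op.add_column('tenant', sa.Column('rag_context', sa.Text(), nullable=True))
    # List of fallback algorithm names selected for the tenant
    op.add_column('tenant', sa.Column('fallback_algorithms', ARRAY(sa.String(length=50)), nullable=True))


def downgrade():
    op.drop_column('tenant', 'fallback_algorithms')
    op.drop_column('tenant', 'rag_context')
```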
@@ -65,11 +67,21 @@ class Tenant(db.Model):
'id': self.id,
'name': self.name,
'website': self.website,
'timezone': self.timezone,
'rag_context': self.rag_context,
'default_language': self.default_language,
'allowed_languages': self.allowed_languages,
'timezone': self.timezone,
'embedding_model': self.embedding_model,
'llm_model': self.llm_model,
'html_tags': self.html_tags,
'html_end_tags': self.html_end_tags,
'html_included_elements': self.html_included_elements,
'html_excluded_elements': self.html_excluded_elements,
'es_k': self.es_k,
'es_similarity_threshold': self.es_similarity_threshold,
'chat_RAG_temperature': self.chat_RAG_temperature,
'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
'fallback_algorithms': self.fallback_algorithms,
'license_start_date': self.license_start_date,
'license_end_date': self.license_end_date,
'allowed_monthly_interactions': self.allowed_monthly_interactions,
@@ -67,6 +67,11 @@ def select_model_variables(tenant):
else:
model_variables['rag_tuning'] = False

if tenant.rag_context:
model_variables['rag_context'] = tenant.rag_context
else:
model_variables['rag_context'] = " "

# Set HTML Chunking Variables
model_variables['html_tags'] = tenant.html_tags
model_variables['html_end_tags'] = tenant.html_end_tags
@@ -98,6 +103,9 @@ def select_model_variables(tenant):
model_variables['llm'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['RAG_temperature'])
model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
model=llm_model,
temperature=model_variables['no_RAG_temperature'])
tool_calling_supported = False
match llm_model:
case 'gpt-4-turbo' | 'gpt-4o':
@@ -132,3 +140,7 @@ def create_language_template(template, language):
language_template = template.replace('{language}', language)

return language_template


def replace_variable_in_template(template, variable, value):
return template.replace(variable, value)
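Taken together, the two helpers fill the prompt placeholders in two passes: `create_language_template` substitutes `{language}`, and `replace_variable_in_template` substitutes `{tenant_context}` with the tenant's `rag_context` (or the single-space fallback set in `select_model_variables`). A minimal sketch of the combined flow; the template and context strings here are illustrative, not from the repository:

```python
template = ("Answer in {language}.\n"
            "{tenant_context}\n"
            "Question: {question}")

# First pass: fix the language wording
language_template = create_language_template(template, "English")

# Second pass: inject the tenant's RAG context (or " " when none is set)
full_template = replace_variable_in_template(
    language_template, "{tenant_context}",
    "EveAI answers questions about ACME products.")

# {question} is intentionally left in place for ChatPromptTemplate to fill at run time
```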
@@ -79,6 +79,7 @@ class Config(object):
```{text}```"""

GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication, and cite the sources used.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context:
@@ -86,6 +87,7 @@ class Config(object):
Question:
{question}"""
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
{tenant_context}
Use the following {language} in your communication.
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
Context:
@@ -95,6 +97,7 @@ class Config(object):

GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
in such a way that the question is understandable without the previous context.
{tenant_context}
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
The history is delimited between triple backquotes.
You answer by stating the question in {language}.
@@ -103,6 +106,13 @@ class Config(object):
Question to be detailed:
{question}"""

# Fallback Algorithms
FALLBACK_ALGORITHMS = [
"RAG_TENANT",
"RAG_WIKIPEDIA",
"RAG_GOOGLE",
"LLM"
]

# SocketIO settings
# SOCKETIO_ASYNC_MODE = 'threading'
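This commit only wires `FALLBACK_ALGORITHMS` into the form choices and the `Tenant.fallback_algorithms` column; how the fallbacks are executed is not part of the change. One plausible reading, sketched here purely as an assumption, is a loop that tries each algorithm the tenant selected, in order, until one produces an answer. All `run_*` helpers below are hypothetical names.

```python
def answer_with_fallbacks(tenant, question, language):
    # Hypothetical dispatch table; only the algorithm names come from FALLBACK_ALGORITHMS
    handlers = {
        "RAG_TENANT": run_tenant_rag,        # hypothetical helper
        "RAG_WIKIPEDIA": run_wikipedia_rag,  # hypothetical helper
        "RAG_GOOGLE": run_google_rag,        # hypothetical helper
        "LLM": run_plain_llm,                # hypothetical helper
    }
    for name in tenant.fallback_algorithms or ["RAG_TENANT"]:
        answer = handlers[name](tenant, question, language)
        if answer:
            return answer
    return "I have insufficient information to answer this question."
```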
@@ -1,7 +1,7 @@
from flask import current_app
from flask_wtf import FlaskForm
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
SelectField, SelectMultipleField, FieldList, FormField, FloatField, TextAreaField)
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
import pytz
@@ -16,6 +16,8 @@ class TenantForm(FlaskForm):
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
# Timezone
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
# RAG context
rag_context = TextAreaField('RAG Context', validators=[Optional()])
# LLM fields
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
@@ -37,6 +39,10 @@ class TenantForm(FlaskForm):
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
default=0.5,
validators=[NumberRange(min=0, max=1)])
# Chat Variables
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
fallback_algorithms = SelectMultipleField('Fallback Algorithms', choices=[], validators=[Optional()])
# Tuning variables
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
@@ -53,6 +59,8 @@ class TenantForm(FlaskForm):
# initialise LLM fields
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
# Initialize fallback algorithms
self.fallback_algorithms.choices = [(algorithm, algorithm.lower()) for algorithm in current_app.config['FALLBACK_ALGORITHMS']]


class BaseUserForm(FlaskForm):
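A note on the choices tuples: WTForms treats each pair as `(value, label)`, so the value stored on submit is the uppercase constant from `FALLBACK_ALGORITHMS` while the rendered label is its lowercased form. With the configured list, the comprehension evaluates to:

```python
# Result of the choices comprehension for FALLBACK_ALGORITHMS as configured above
self.fallback_algorithms.choices = [
    ("RAG_TENANT", "rag_tenant"),
    ("RAG_WIKIPEDIA", "rag_wikipedia"),
    ("RAG_GOOGLE", "rag_google"),
    ("LLM", "llm"),
]
```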
@@ -45,18 +45,20 @@ def tenant():
if form.validate_on_submit():
current_app.logger.debug('Creating new tenant')
# Handle the required attributes
new_tenant = Tenant(name=form.name.data,
website=form.website.data,
default_language=form.default_language.data,
allowed_languages=form.allowed_languages.data,
timezone=form.timezone.data,
embedding_model=form.embedding_model.data,
llm_model=form.llm_model.data,
license_start_date=form.license_start_date.data,
license_end_date=form.license_end_date.data,
allowed_monthly_interactions=form.allowed_monthly_interactions.data,
embed_tuning=form.embed_tuning.data,
rag_tuning=form.rag_tuning.data)
new_tenant = Tenant()
form.populate_obj(new_tenant)
# new_tenant = Tenant(name=form.name.data,
# website=form.website.data,
# default_language=form.default_language.data,
# allowed_languages=form.allowed_languages.data,
# timezone=form.timezone.data,
# embedding_model=form.embedding_model.data,
# llm_model=form.llm_model.data,
# license_start_date=form.license_start_date.data,
# license_end_date=form.license_end_date.data,
# allowed_monthly_interactions=form.allowed_monthly_interactions.data,
# embed_tuning=form.embed_tuning.data,
# rag_tuning=form.rag_tuning.data)

# Handle Embedding Variables
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []
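The switch from the explicit keyword constructor to `form.populate_obj(new_tenant)` relies on WTForms assigning each field's `data` to the attribute of the same name on the target object, which is also why fields that need post-processing (such as `html_tags`, stored as a list) are still handled afterwards. Roughly, the call behaves like this sketch:

```python
# Approximately what form.populate_obj(new_tenant) does under the hood.
# Note: it iterates every field on the form, including submit and CSRF fields,
# so attributes without a matching column are simply set on the instance.
for field in form:
    setattr(new_tenant, field.name, field.data)
```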
@@ -21,7 +21,7 @@ from common.models.user import Tenant
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
from common.extensions import db
from common.utils.celery_utils import current_celery
from common.utils.model_utils import select_model_variables, create_language_template
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
from common.langchain.EveAIRetriever import EveAIRetriever
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
@@ -31,7 +31,8 @@ def detail_question(question, language, model_variables, session_id):
llm = model_variables['llm']
template = model_variables['history_template']
language_template = create_language_template(template, language)
history_prompt = ChatPromptTemplate.from_template(language_template)
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
history_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
output_parser = StrOutputParser()
@@ -103,7 +104,8 @@ def ask_question(tenant_id, question, language, session_id):
llm = model_variables['llm']
template = model_variables['rag_template']
language_template = create_language_template(template, language)
rag_prompt = ChatPromptTemplate.from_template(language_template)
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
rag_prompt = ChatPromptTemplate.from_template(full_template)
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})

new_interaction_embeddings = []
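The hunk stops before the chain is assembled, but given the pieces shown (a prompt with the tenant context already substituted, the retriever wrapped in `RunnableParallel`, the LLM, and the string output parser used in `detail_question`), the downstream composition presumably follows the usual LangChain LCEL pattern. A sketch under that assumption, not the shown code:

```python
# Assumed continuation of ask_question: compose and invoke the LCEL chain
output_parser = StrOutputParser()  # mirrors the parser used in detail_question
rag_chain = setup_and_retrieval | rag_prompt | llm | output_parser
answer = rag_chain.invoke(question)
```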