Add rag_context and fallback_algorithms variables to Tenant, together with parts of their implementation.
This commit is contained in:
@@ -20,6 +20,7 @@ class Tenant(db.Model):
|
|||||||
name = db.Column(db.String(80), unique=True, nullable=False)
|
name = db.Column(db.String(80), unique=True, nullable=False)
|
||||||
website = db.Column(db.String(255), nullable=True)
|
website = db.Column(db.String(255), nullable=True)
|
||||||
timezone = db.Column(db.String(50), nullable=True, default='UTC')
|
timezone = db.Column(db.String(50), nullable=True, default='UTC')
|
||||||
|
rag_context = db.Column(db.Text, nullable=True)
|
||||||
|
|
||||||
# language information
|
# language information
|
||||||
default_language = db.Column(db.String(2), nullable=True)
|
default_language = db.Column(db.String(2), nullable=True)
|
||||||
@@ -42,6 +43,7 @@ class Tenant(db.Model):
|
|||||||
# Chat variables
|
# Chat variables
|
||||||
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
||||||
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
||||||
|
fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||||
|
|
||||||
# Licensing Information
|
# Licensing Information
|
||||||
license_start_date = db.Column(db.Date, nullable=True)
|
license_start_date = db.Column(db.Date, nullable=True)
|
||||||
@@ -65,11 +67,21 @@ class Tenant(db.Model):
|
|||||||
'id': self.id,
|
'id': self.id,
|
||||||
'name': self.name,
|
'name': self.name,
|
||||||
'website': self.website,
|
'website': self.website,
|
||||||
|
'timezone': self.timezone,
|
||||||
|
'rag_context': self.rag_context,
|
||||||
'default_language': self.default_language,
|
'default_language': self.default_language,
|
||||||
'allowed_languages': self.allowed_languages,
|
'allowed_languages': self.allowed_languages,
|
||||||
'timezone': self.timezone,
|
|
||||||
'embedding_model': self.embedding_model,
|
'embedding_model': self.embedding_model,
|
||||||
'llm_model': self.llm_model,
|
'llm_model': self.llm_model,
|
||||||
|
'html_tags': self.html_tags,
|
||||||
|
'html_end_tags': self.html_end_tags,
|
||||||
|
'html_included_elements': self.html_included_elements,
|
||||||
|
'html_excluded_elements': self.html_excluded_elements,
|
||||||
|
'es_k': self.es_k,
|
||||||
|
'es_similarity_threshold': self.es_similarity_threshold,
|
||||||
|
'chat_RAG_temperature': self.chat_RAG_temperature,
|
||||||
|
'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
|
||||||
|
'fallback_algorithms': self.fallback_algorithms,
|
||||||
'license_start_date': self.license_start_date,
|
'license_start_date': self.license_start_date,
|
||||||
'license_end_date': self.license_end_date,
|
'license_end_date': self.license_end_date,
|
||||||
'allowed_monthly_interactions': self.allowed_monthly_interactions,
|
'allowed_monthly_interactions': self.allowed_monthly_interactions,
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ def select_model_variables(tenant):
|
|||||||
else:
|
else:
|
||||||
model_variables['rag_tuning'] = False
|
model_variables['rag_tuning'] = False
|
||||||
|
|
||||||
|
if tenant.rag_context:
|
||||||
|
model_variables['rag_context'] = tenant.rag_context
|
||||||
|
else:
|
||||||
|
model_variables['rag_context'] = " "
|
||||||
|
|
||||||
# Set HTML Chunking Variables
|
# Set HTML Chunking Variables
|
||||||
model_variables['html_tags'] = tenant.html_tags
|
model_variables['html_tags'] = tenant.html_tags
|
||||||
model_variables['html_end_tags'] = tenant.html_end_tags
|
model_variables['html_end_tags'] = tenant.html_end_tags
|
||||||
@@ -98,6 +103,9 @@ def select_model_variables(tenant):
|
|||||||
model_variables['llm'] = ChatOpenAI(api_key=api_key,
|
model_variables['llm'] = ChatOpenAI(api_key=api_key,
|
||||||
model=llm_model,
|
model=llm_model,
|
||||||
temperature=model_variables['RAG_temperature'])
|
temperature=model_variables['RAG_temperature'])
|
||||||
|
model_variables['llm_no_rag'] = ChatOpenAI(api_key=api_key,
|
||||||
|
model=llm_model,
|
||||||
|
temperature=model_variables['no_RAG_temperature'])
|
||||||
tool_calling_supported = False
|
tool_calling_supported = False
|
||||||
match llm_model:
|
match llm_model:
|
||||||
case 'gpt-4-turbo' | 'gpt-4o':
|
case 'gpt-4-turbo' | 'gpt-4o':
|
||||||
@@ -132,3 +140,7 @@ def create_language_template(template, language):
|
|||||||
language_template = template.replace('{language}', language)
|
language_template = template.replace('{language}', language)
|
||||||
|
|
||||||
return language_template
|
return language_template
|
||||||
|
|
||||||
|
|
||||||
|
def replace_variable_in_template(template, variable, value):
    """Return ``template`` with every occurrence of ``variable`` replaced.

    Args:
        template: The template text to process.
        variable: The literal placeholder to look for, e.g. ``"{tenant_context}"``.
        value: The replacement text. ``None`` is treated as an empty string so
            that an unset value does not raise ``TypeError``.

    Returns:
        A new string; ``template`` itself is not modified.

    NOTE(review): ``value`` is inserted verbatim. If the result is later parsed
    as a format-style template, any curly braces inside ``value`` will be seen
    by that parser -- confirm whether callers need to escape them.
    """
    replacement = "" if value is None else value
    return template.replace(variable, replacement)
|
||||||
@@ -79,6 +79,7 @@ class Config(object):
|
|||||||
```{text}```"""
|
```{text}```"""
|
||||||
|
|
||||||
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
GPT4_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
||||||
|
{tenant_context}
|
||||||
Use the following {language} in your communication, and cite the sources used.
|
Use the following {language} in your communication, and cite the sources used.
|
||||||
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
||||||
Context:
|
Context:
|
||||||
@@ -86,6 +87,7 @@ class Config(object):
|
|||||||
Question:
|
Question:
|
||||||
{question}"""
|
{question}"""
|
||||||
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
GPT3_5_RAG_TEMPLATE = """Answer the question based on the following context, delimited between triple backquotes.
|
||||||
|
{tenant_context}
|
||||||
Use the following {language} in your communication.
|
Use the following {language} in your communication.
|
||||||
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
If the question cannot be answered using the given context, say "I have insufficient information to answer this question."
|
||||||
Context:
|
Context:
|
||||||
@@ -95,6 +97,7 @@ class Config(object):
|
|||||||
|
|
||||||
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
|
GPT4_HISTORY_TEMPLATE = """You are a helpful assistant that details a question based on a previous context,
|
||||||
in such a way that the question is understandable without the previous context.
|
in such a way that the question is understandable without the previous context.
|
||||||
|
{tenant_context}
|
||||||
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
|
The context is a conversation history, with the HUMAN asking questions, the AI answering questions.
|
||||||
The history is delimited between triple backquotes.
|
The history is delimited between triple backquotes.
|
||||||
Your answer by stating the question in {language}.
|
Your answer by stating the question in {language}.
|
||||||
@@ -103,6 +106,13 @@ class Config(object):
|
|||||||
Question to be detailed:
|
Question to be detailed:
|
||||||
{question}"""
|
{question}"""
|
||||||
|
|
||||||
|
# Fallback Algorithms
|
||||||
|
FALLBACK_ALGORITHMS = [
|
||||||
|
"RAG_TENANT",
|
||||||
|
"RAG_WIKIPEDIA",
|
||||||
|
"RAG_GOOGLE",
|
||||||
|
"LLM"
|
||||||
|
]
|
||||||
|
|
||||||
# SocketIO settings
|
# SocketIO settings
|
||||||
# SOCKETIO_ASYNC_MODE = 'threading'
|
# SOCKETIO_ASYNC_MODE = 'threading'
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
from wtforms import (StringField, PasswordField, BooleanField, SubmitField, EmailField, IntegerField, DateField,
|
||||||
SelectField, SelectMultipleField, FieldList, FormField, FloatField)
|
SelectField, SelectMultipleField, FieldList, FormField, FloatField, TextAreaField)
|
||||||
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
from wtforms.validators import DataRequired, Length, Email, NumberRange, Optional
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
@@ -16,6 +16,8 @@ class TenantForm(FlaskForm):
|
|||||||
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
|
allowed_languages = SelectMultipleField('Allowed Languages', choices=[], validators=[DataRequired()])
|
||||||
# Timezone
|
# Timezone
|
||||||
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
|
timezone = SelectField('Timezone', choices=[], validators=[DataRequired()])
|
||||||
|
# RAG context
|
||||||
|
rag_context = TextAreaField('RAG Context', validators=[Optional()])
|
||||||
# LLM fields
|
# LLM fields
|
||||||
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
|
embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()])
|
||||||
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
|
llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()])
|
||||||
@@ -37,6 +39,10 @@ class TenantForm(FlaskForm):
|
|||||||
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
|
es_similarity_threshold = FloatField('Similarity Threshold for Searching Embeddings (0.5)',
|
||||||
default=0.5,
|
default=0.5,
|
||||||
validators=[NumberRange(min=0, max=1)])
|
validators=[NumberRange(min=0, max=1)])
|
||||||
|
# Chat Variables
|
||||||
|
chat_RAG_temperature = FloatField('RAG Temperature', default=0.3, validators=[NumberRange(min=0, max=1)])
|
||||||
|
chat_no_RAG_temperature = FloatField('No RAG Temperature', default=0.5, validators=[NumberRange(min=0, max=1)])
|
||||||
|
fallback_algorithms = SelectMultipleField('Fallback Algorithms', choices=[], validators=[Optional()])
|
||||||
# Tuning variables
|
# Tuning variables
|
||||||
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
|
embed_tuning = BooleanField('Enable Embedding Tuning', default=False)
|
||||||
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
|
rag_tuning = BooleanField('Enable RAG Tuning', default=False)
|
||||||
@@ -53,6 +59,8 @@ class TenantForm(FlaskForm):
|
|||||||
# initialise LLM fields
|
# initialise LLM fields
|
||||||
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
|
self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']]
|
||||||
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
|
self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']]
|
||||||
|
# Initialize fallback algorithms
|
||||||
|
self.fallback_algorithms.choices = [(algorithm, algorithm.lower()) for algorithm in current_app.config['FALLBACK_ALGORITHMS']]
|
||||||
|
|
||||||
|
|
||||||
class BaseUserForm(FlaskForm):
|
class BaseUserForm(FlaskForm):
|
||||||
|
|||||||
@@ -45,18 +45,20 @@ def tenant():
|
|||||||
if form.validate_on_submit():
|
if form.validate_on_submit():
|
||||||
current_app.logger.debug('Creating new tenant')
|
current_app.logger.debug('Creating new tenant')
|
||||||
# Handle the required attributes
|
# Handle the required attributes
|
||||||
new_tenant = Tenant(name=form.name.data,
|
new_tenant = Tenant()
|
||||||
website=form.website.data,
|
form.populate_obj(new_tenant)
|
||||||
default_language=form.default_language.data,
|
# new_tenant = Tenant(name=form.name.data,
|
||||||
allowed_languages=form.allowed_languages.data,
|
# website=form.website.data,
|
||||||
timezone=form.timezone.data,
|
# default_language=form.default_language.data,
|
||||||
embedding_model=form.embedding_model.data,
|
# allowed_languages=form.allowed_languages.data,
|
||||||
llm_model=form.llm_model.data,
|
# timezone=form.timezone.data,
|
||||||
license_start_date=form.license_start_date.data,
|
# embedding_model=form.embedding_model.data,
|
||||||
license_end_date=form.license_end_date.data,
|
# llm_model=form.llm_model.data,
|
||||||
allowed_monthly_interactions=form.allowed_monthly_interactions.data,
|
# license_start_date=form.license_start_date.data,
|
||||||
embed_tuning=form.embed_tuning.data,
|
# license_end_date=form.license_end_date.data,
|
||||||
rag_tuning=form.rag_tuning.data)
|
# allowed_monthly_interactions=form.allowed_monthly_interactions.data,
|
||||||
|
# embed_tuning=form.embed_tuning.data,
|
||||||
|
# rag_tuning=form.rag_tuning.data)
|
||||||
|
|
||||||
# Handle Embedding Variables
|
# Handle Embedding Variables
|
||||||
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []
|
new_tenant.html_tags = form.html_tags.data.split(',') if form.html_tags.data else []
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from common.models.user import Tenant
|
|||||||
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
|
from common.models.interaction import ChatSession, Interaction, InteractionEmbedding
|
||||||
from common.extensions import db
|
from common.extensions import db
|
||||||
from common.utils.celery_utils import current_celery
|
from common.utils.celery_utils import current_celery
|
||||||
from common.utils.model_utils import select_model_variables, create_language_template
|
from common.utils.model_utils import select_model_variables, create_language_template, replace_variable_in_template
|
||||||
from common.langchain.EveAIRetriever import EveAIRetriever
|
from common.langchain.EveAIRetriever import EveAIRetriever
|
||||||
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
|
from common.langchain.EveAIHistoryRetriever import EveAIHistoryRetriever
|
||||||
|
|
||||||
@@ -31,7 +31,8 @@ def detail_question(question, language, model_variables, session_id):
|
|||||||
llm = model_variables['llm']
|
llm = model_variables['llm']
|
||||||
template = model_variables['history_template']
|
template = model_variables['history_template']
|
||||||
language_template = create_language_template(template, language)
|
language_template = create_language_template(template, language)
|
||||||
history_prompt = ChatPromptTemplate.from_template(language_template)
|
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
|
||||||
|
history_prompt = ChatPromptTemplate.from_template(full_template)
|
||||||
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
|
setup_and_retrieval = RunnableParallel({"history": retriever,"question": RunnablePassthrough()})
|
||||||
output_parser = StrOutputParser()
|
output_parser = StrOutputParser()
|
||||||
|
|
||||||
@@ -103,7 +104,8 @@ def ask_question(tenant_id, question, language, session_id):
|
|||||||
llm = model_variables['llm']
|
llm = model_variables['llm']
|
||||||
template = model_variables['rag_template']
|
template = model_variables['rag_template']
|
||||||
language_template = create_language_template(template, language)
|
language_template = create_language_template(template, language)
|
||||||
rag_prompt = ChatPromptTemplate.from_template(language_template)
|
full_template = replace_variable_in_template(language_template, "{tenant_context}", model_variables['rag_context'])
|
||||||
|
rag_prompt = ChatPromptTemplate.from_template(full_template)
|
||||||
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
|
setup_and_retrieval = RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
|
||||||
|
|
||||||
new_interaction_embeddings = []
|
new_interaction_embeddings = []
|
||||||
|
|||||||
Reference in New Issue
Block a user