- Move to Mistral iso OpenAI as primary choice
This commit is contained in:
@@ -26,6 +26,10 @@ class StandardRAGRetriever(BaseRetriever):
|
||||
retriever = Retriever.query.get_or_404(retriever_id)
|
||||
self.catalog_id = retriever.catalog_id
|
||||
self.tenant_id = tenant_id
|
||||
catalog = Catalog.query.get_or_404(self.catalog_id)
|
||||
self.embedding_model, self.embedding_model_class = get_embedding_model_and_class(self.tenant_id,
|
||||
self.catalog_id,
|
||||
catalog.embedding_model)
|
||||
self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3)
|
||||
self.k = retriever.configuration.get('es_k', 8)
|
||||
self.tuning = retriever.tuning
|
||||
@@ -77,10 +81,10 @@ class StandardRAGRetriever(BaseRetriever):
|
||||
query = arguments.query
|
||||
|
||||
# Get query embedding
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
query_embedding = self.embedding_model.embed_query(query)
|
||||
|
||||
# Get the appropriate embedding database model
|
||||
db_class = self.model_variables.embedding_model_class
|
||||
db_class = self.embedding_model_class
|
||||
|
||||
# Get current date for validity checks
|
||||
current_date = dt.now(tz=tz.utc).date()
|
||||
@@ -159,12 +163,6 @@ class StandardRAGRetriever(BaseRetriever):
|
||||
current_app.logger.error(f'Unexpected error in RAG retrieval: {e}')
|
||||
raise
|
||||
|
||||
def _get_query_embedding(self, query: str):
|
||||
"""Get embedding for the query text"""
|
||||
catalog = Catalog.query.get_or_404(self.catalog_id)
|
||||
embedding_model, embedding_model_class = get_embedding_model_and_class(self.tenant_id, self.catalog_id,
|
||||
catalog.embedding_model)
|
||||
|
||||
|
||||
# Register the retriever type
|
||||
RetrieverRegistry.register("STANDARD_RAG", StandardRAGRetriever)
|
||||
|
||||
@@ -6,7 +6,7 @@ from flask import current_app
|
||||
|
||||
from common.models.interaction import Specialist
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.utils.model_utils import get_model_variables
|
||||
from common.utils.model_utils import get_model_variables, get_crewai_llm
|
||||
from eveai_chat_workers.retrievers.retriever_typing import RetrieverArguments
|
||||
from eveai_chat_workers.specialists.crewai_base_classes import EveAICrewAIAgent, EveAICrewAITask
|
||||
from crewai.tools import BaseTool
|
||||
@@ -78,6 +78,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
|
||||
f"AI:\n{interaction.specialist_results.get('rag_output').get('answer', '')}"
|
||||
for interaction in self._cached_session.interactions
|
||||
])
|
||||
return formatted_history
|
||||
|
||||
def _add_task_agent(self, task_name: str, agent_name: str):
|
||||
self._task_agents[task_name.lower()] = agent_name
|
||||
@@ -119,15 +120,21 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
|
||||
agent_role = agent_config.get('role', '').replace('{custom_role}', agent.role or '')
|
||||
agent_goal = agent_config.get('goal', '').replace('{custom_goal}', agent.goal or '')
|
||||
agent_backstory = agent_config.get('backstory', '').replace('{custom_backstory}', agent.backstory or '')
|
||||
new_agent = EveAICrewAIAgent(
|
||||
self,
|
||||
agent.type.lower(),
|
||||
role=agent_role,
|
||||
goal=agent_goal,
|
||||
backstory=agent_backstory,
|
||||
verbose=agent.tuning,
|
||||
)
|
||||
agent_full_model_name = agent_config.get('full_model_name', 'mistral.mistral-large-latest')
|
||||
agent_temperature = agent_config.get('temperature', 0.3)
|
||||
llm = get_crewai_llm(agent_full_model_name, agent_temperature)
|
||||
if not llm:
|
||||
current_app.logger.error(f"No LLM found for {agent_full_model_name}")
|
||||
raise Exception(f"No LLM found for {agent_full_model_name}")
|
||||
agent_kwargs = {
|
||||
"role": agent_role,
|
||||
"goal": agent_goal,
|
||||
"backstory": agent_backstory,
|
||||
"verbose": agent.tuning,
|
||||
"llm": llm,
|
||||
}
|
||||
agent_name = agent.type.lower()
|
||||
new_agent = EveAICrewAIAgent(self, agent_name, **agent_kwargs)
|
||||
self.log_tuning(f"CrewAI Agent {agent_name} initialized", agent_config)
|
||||
self._agents[agent_name] = new_agent
|
||||
|
||||
@@ -180,7 +187,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
|
||||
if name.startswith('latest_'):
|
||||
element = name[len('latest_'):]
|
||||
if self._cached_session.interactions:
|
||||
return self._cached_session.interactions[-1].get(element, '')
|
||||
return self._cached_session.interactions[-1].specialist_results.get(element, '')
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user