From 55a89c11bb6a9ddf0509eff36c029ea4c4deacf8 Mon Sep 17 00:00:00 2001 From: Josako Date: Tue, 25 Feb 2025 11:17:19 +0100 Subject: [PATCH] - Move from OpenAI to Mistral Embeddings - Move embedding model settings from tenant to catalog - BUG: error processing configuration for chunking patterns in HTML_PROCESSOR - Removed eveai_chat from docker-files and nginx configuration, as it is now obsolete - BUG: error in Library Operations when creating a new default RAG library - BUG: Added public type in migration scripts - Removed SocketIO from all code and requirements.txt --- CHANGELOG.md | 23 ++ common/eveai_model/__init__.py | 0 common/eveai_model/eveai_embedding_base.py | 11 + .../eveai_model/tracked_mistral_embeddings.py | 40 +++ common/extensions.py | 2 - common/models/document.py | 6 +- common/models/user.py | 2 - common/utils/config_field_types.py | 15 +- common/utils/eveai_exceptions.py | 13 +- common/utils/model_utils.py | 106 +++++--- common/utils/security.py | 2 - config/config.py | 43 ++- config/prompts/html_parse/1.0.0.yaml | 1 + docker/compose_dev.yaml | 69 ++--- docker/compose_stackhero.yaml | 37 +-- eveai_api/__init__.py | 13 +- eveai_api/api/specialist_execution_api.py | 46 ++++ eveai_app/__init__.py | 5 +- .../templates/document/edit_catalog.html | 2 +- eveai_app/templates/user/edit_tenant.html | 2 +- eveai_app/templates/user/tenant_overview.html | 2 +- eveai_app/views/document_forms.py | 20 +- eveai_app/views/document_views.py | 7 +- eveai_app/views/user_forms.py | 2 - eveai_app/views/user_views.py | 1 - eveai_chat_workers/retrievers/standard_rag.py | 9 +- eveai_workers/tasks.py | 21 +- ...ove_embedding_model_settings_to_catalog.py | 32 +++ migrations/tenant/env.py | 2 +- ...61eee4_link_embedding_models_to_catalog.py | 29 ++ nginx/nginx.conf | 36 +-- requirements.txt | 19 +- tests/interactive_client/specialist_client.py | 247 ------------------ .../test_specialist_client.py | 36 ++- 34 files changed, 457 insertions(+), 444 deletions(-) create mode 100644 common/eveai_model/__init__.py create mode 100644 common/eveai_model/eveai_embedding_base.py create mode 100644 common/eveai_model/tracked_mistral_embeddings.py create mode 100644 migrations/public/versions/b02d9ad000f4_move_embedding_model_settings_to_catalog.py create mode 100644 migrations/tenant/versions/2b04e961eee4_link_embedding_models_to_catalog.py delete mode 100644 tests/interactive_client/specialist_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 27851d4..984e8b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security - In case of vulnerabilities. +## [2.1.0-alfa] + +### Added +- Zapier Refresh Document +- SPIN Specialist definition - from start to finish +- Introduction of startup scripts in eveai_app +- Caching for all configurations added +- Caching for processed specialist configurations +- Caching for specialist history +- Augmented Specialist Editor, including Specialist graphic presentation +- Introduction of specialist_execution_api, introducting SSE +- Introduction of crewai framework for specialist implementation +- Test app for testing specialists - also serves as a sample client application for SSE +- + +### Changed +- Improvement of startup of applications using gevent, and better handling and scaling of multiple connections +- STANDARD_RAG Specialist improvement +- + +### Deprecated +- eveai_chat - using sockets - will be replaced with new specialist_execution_api and SSE + ## [2.0.1-alfa] ### Added diff --git a/common/eveai_model/__init__.py b/common/eveai_model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/eveai_model/eveai_embedding_base.py b/common/eveai_model/eveai_embedding_base.py new file mode 100644 index 0000000..e79b350 --- /dev/null +++ b/common/eveai_model/eveai_embedding_base.py @@ -0,0 +1,11 @@ +from abc import abstractmethod +from typing import List + + +class EveAIEmbeddings: + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + pass + + def embed_query(self, text: str) -> List[float]: + return self.embed_documents([text])[0] \ No newline at end of file diff --git a/common/eveai_model/tracked_mistral_embeddings.py b/common/eveai_model/tracked_mistral_embeddings.py new file mode 100644 index 0000000..73d018d --- /dev/null +++ b/common/eveai_model/tracked_mistral_embeddings.py @@ -0,0 +1,40 @@ +from flask import current_app +from langchain_mistralai import MistralAIEmbeddings +from typing import List, Any +import time + +from common.eveai_model.eveai_embedding_base import EveAIEmbeddings +from common.utils.business_event_context import current_event +from mistralai import Mistral + + +class TrackedMistralAIEmbeddings(EveAIEmbeddings): + def __init__(self, model: str = "mistral_embed"): + api_key = current_app.config['MISTRAL_API_KEY'] + self.client = Mistral( + api_key=api_key + ) + self.model = model + super().__init__() + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + start_time = time.time() + result = self.client.embeddings.create( + model=self.model, + inputs=texts + ) + end_time = time.time() + + metrics = { + 'total_tokens': result.usage.total_tokens, + 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens + 'completion_tokens': result.usage.completion_tokens, + 'time_elapsed': end_time - start_time, + 'interaction_type': 'Embedding', + } + current_event.log_llm_metrics(metrics) + + embeddings = [embedding.embedding for embedding in result.data] + + return embeddings + diff --git a/common/extensions.py b/common/extensions.py index dc4a80f..d2ae2d6 100644 --- a/common/extensions.py +++ b/common/extensions.py @@ -5,7 +5,6 @@ from flask_security import Security from flask_mailman import Mail from flask_login import LoginManager from flask_cors import CORS -from flask_socketio import SocketIO from flask_jwt_extended import JWTManager from flask_session import Session from flask_wtf import CSRFProtect @@ -27,7 +26,6 @@ security = Security() mail = Mail() login_manager = LoginManager() cors = CORS() -socketio = SocketIO() jwt = JWTManager() session = Session() api_rest = Api() diff --git a/common/models/document.py b/common/models/document.py index 49187e5..4fac7aa 100644 --- a/common/models/document.py +++ b/common/models/document.py @@ -12,8 +12,10 @@ class Catalog(db.Model): description = db.Column(db.Text, nullable=True) type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG") - min_chunk_size = db.Column(db.Integer, nullable=True, default=2000) - max_chunk_size = db.Column(db.Integer, nullable=True, default=3000) + embedding_model = db.Column(db.String(50), nullable=True) + + min_chunk_size = db.Column(db.Integer, nullable=True, default=1500) + max_chunk_size = db.Column(db.Integer, nullable=True, default=2500) # Meta Data user_metadata = db.Column(JSONB, nullable=True) diff --git a/common/models/user.py b/common/models/user.py index 6602d31..561cec9 100644 --- a/common/models/user.py +++ b/common/models/user.py @@ -31,7 +31,6 @@ class Tenant(db.Model): allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True) # LLM specific choices - embedding_model = db.Column(db.String(50), nullable=True) llm_model = db.Column(db.String(50), nullable=True) # Entitlements @@ -66,7 +65,6 @@ class Tenant(db.Model): 'type': self.type, 'default_language': self.default_language, 'allowed_languages': self.allowed_languages, - 'embedding_model': self.embedding_model, 'llm_model': self.llm_model, 'currency': self.currency, } diff --git a/common/utils/config_field_types.py b/common/utils/config_field_types.py index 7fd8143..5bf8fc4 100644 --- a/common/utils/config_field_types.py +++ b/common/utils/config_field_types.py @@ -652,12 +652,15 @@ def json_to_patterns(json_content: str) -> str: def json_to_pattern_list(json_content: str) -> list: """Convert JSON patterns list to text area content""" try: - patterns = json.loads(json_content) - if not isinstance(patterns, list): - raise ValueError("JSON must contain a list of patterns") - # Unescape if needed - patterns = [pattern.replace('\\\\', '\\') for pattern in patterns] - return patterns + if json_content: + patterns = json.loads(json_content) + if not isinstance(patterns, list): + raise ValueError("JSON must contain a list of patterns") + # Unescape if needed + patterns = [pattern.replace('\\\\', '\\') for pattern in patterns] + return patterns + else: + return [] except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON format: {e}") diff --git a/common/utils/eveai_exceptions.py b/common/utils/eveai_exceptions.py index 782d565..66db901 100644 --- a/common/utils/eveai_exceptions.py +++ b/common/utils/eveai_exceptions.py @@ -124,4 +124,15 @@ class EveAISocketInputException(EveAIException): """Raised when a socket call receives an invalid payload""" def __init__(self, message, status_code=400, payload=None): - super.__init__(message, status_code, payload) \ No newline at end of file + super.__init__(message, status_code, payload) + + +class EveAIInvalidEmbeddingModel(EveAIException): + """Raised when no or an invalid embedding model is provided in the catalog""" + + def __init__(self, tenant_id, catalog_id, status_code=400, payload=None): + self.tenant_id = tenant_id + self.catalog_id = catalog_id + # Construct the message dynamically + message = f"Tenant with ID '{tenant_id}' has no or an invalid embedding model in Catalog {catalog_id}." + super().__init__(message, status_code, payload) diff --git a/common/utils/model_utils.py b/common/utils/model_utils.py index e307c66..9277819 100644 --- a/common/utils/model_utils.py +++ b/common/utils/model_utils.py @@ -1,23 +1,25 @@ import os -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple import langcodes +from langchain_core.language_models import BaseChatModel from common.langchain.llm_metrics_handler import LLMMetricsHandler -from common.langchain.templates.template_manager import TemplateManager -from langchain_openai import OpenAIEmbeddings, ChatOpenAI, OpenAI +from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic +from langchain_mistralai import ChatMistralAI from flask import current_app -from datetime import datetime as dt, timezone as tz -from common.langchain.tracked_openai_embeddings import TrackedOpenAIEmbeddings +from common.eveai_model.tracked_mistral_embeddings import TrackedMistralAIEmbeddings from common.langchain.tracked_transcription import TrackedOpenAITranscription from common.models.user import Tenant -from common.utils.cache.base import CacheHandler from config.model_config import MODEL_CONFIG -from common.extensions import template_manager, cache_manager -from common.models.document import EmbeddingLargeOpenAI, EmbeddingSmallOpenAI -from common.utils.eveai_exceptions import EveAITenantNotFound +from common.extensions import template_manager +from common.models.document import EmbeddingMistral +from common.utils.eveai_exceptions import EveAITenantNotFound, EveAIInvalidEmbeddingModel + +llm_model_cache: Dict[Tuple[str, float], BaseChatModel] = {} +llm_metrics_handler = LLMMetricsHandler() def create_language_template(template: str, language: str) -> str: @@ -55,6 +57,63 @@ def replace_variable_in_template(template: str, variable: str, value: str) -> st return template.replace(variable, value or "") +def get_embedding_model_and_class(tenant_id, catalog_id, full_embedding_name): + """ + Retrieve the embedding model and embedding model class to store Embeddings + + Args: + tenant_id: ID of the tenant + catalog_id: ID of the catalog + full_embedding_name: The full name of the embedding model: . + + Returns: + embedding_model, embedding_model_class + """ + embedding_provider, embedding_model_name = full_embedding_name.split('.') + + # Calculate the embedding model to be used + if embedding_provider == "mistral": + api_key = current_app.config['MISTRAL_API_KEY'] + embedding_model = TrackedMistralAIEmbeddings( + model=embedding_model_name + ) + else: + raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id) + + # Calculate the Embedding Model Class to be used to store embeddings + if embedding_model_name == "mistral-embed": + embedding_model_class = EmbeddingMistral + else: + raise EveAIInvalidEmbeddingModel(tenant_id, catalog_id) + + return embedding_model, embedding_model_class + + +def get_llm(full_model_name, temperature): + if not full_model_name: + full_model_name = 'openai.gpt-4o' # Default to gpt-4o for now, as this is the original model developed against + + llm = llm_model_cache.get((full_model_name, temperature)) + if not llm: + llm_provider, llm_model_name = full_model_name.split('.') + if llm_provider == "openai": + llm = ChatOpenAI( + api_key=current_app.config['OPENAI_API_KEY'], + model=llm_model_name, + temperature=temperature, + callbacks=[llm_metrics_handler] + ) + elif llm_provider == "mistral": + llm = ChatMistralAI( + api_key=current_app.config['MISTRAL_API_KEY'], + model=llm_model_name, + temperature=temperature, + callbacks=[llm_metrics_handler] + ) + + llm_model_cache[(full_model_name, temperature)] = llm + + class ModelVariables: """Manages model-related variables and configurations""" @@ -63,15 +122,13 @@ class ModelVariables: Initialize ModelVariables with tenant and optional template manager Args: - tenant: Tenant instance - template_manager: Optional TemplateManager instance + tenant_id: Tenant instance + variables: Optional variables """ current_app.logger.info(f'Model variables initialized with tenant {tenant_id} and variables \n{variables}') self.tenant_id = tenant_id self._variables = variables if variables is not None else self._initialize_variables() current_app.logger.info(f'Model _variables initialized to {self._variables}') - self._embedding_model = None - self._embedding_model_class = None self._llm_instances = {} self.llm_metrics_handler = LLMMetricsHandler() self._transcription_model = None @@ -85,7 +142,6 @@ class ModelVariables: raise EveAITenantNotFound(self.tenant_id) # Set model providers - variables['embedding_provider'], variables['embedding_model'] = tenant.embedding_model.split('.') variables['llm_provider'], variables['llm_model'] = tenant.llm_model.split('.') variables['llm_full_model'] = tenant.llm_model @@ -102,28 +158,6 @@ class ModelVariables: return variables - @property - def embedding_model(self): - """Get the embedding model instance""" - if self._embedding_model is None: - api_key = os.getenv('OPENAI_API_KEY') - self._embedding_model = TrackedOpenAIEmbeddings( - api_key=api_key, - model=self._variables['embedding_model'] - ) - return self._embedding_model - - @property - def embedding_model_class(self): - """Get the embedding model class""" - if self._embedding_model_class is None: - if self._variables['embedding_model'] == 'text-embedding-3-large': - self._embedding_model_class = EmbeddingLargeOpenAI - else: # text-embedding-3-small - self._embedding_model_class = EmbeddingSmallOpenAI - - return self._embedding_model_class - @property def annotation_chunk_length(self): return self._variables['annotation_chunk_length'] diff --git a/common/utils/security.py b/common/utils/security.py index 51b86bb..8e47912 100644 --- a/common/utils/security.py +++ b/common/utils/security.py @@ -13,14 +13,12 @@ def set_tenant_session_data(sender, user, **kwargs): tenant = Tenant.query.filter_by(id=user.tenant_id).first() session['tenant'] = tenant.to_dict() session['default_language'] = tenant.default_language - session['default_embedding_model'] = tenant.embedding_model session['default_llm_model'] = tenant.llm_model def clear_tenant_session_data(sender, user, **kwargs): session.pop('tenant', None) session.pop('default_language', None) - session.pop('default_embedding_model', None) session.pop('default_llm_model', None) diff --git a/config/config.py b/config/config.py index 245c2dd..9c8b3d6 100644 --- a/config/config.py +++ b/config/config.py @@ -63,8 +63,10 @@ class Config(object): SUPPORTED_CURRENCIES = ['€', '$'] # supported LLMs - SUPPORTED_EMBEDDINGS = ['openai.text-embedding-3-small', 'openai.text-embedding-3-large', 'mistral.mistral-embed'] - SUPPORTED_LLMS = ['openai.gpt-4o', 'anthropic.claude-3-5-sonnet', 'openai.gpt-4o-mini'] + # SUPPORTED_EMBEDDINGS = ['openai.text-embedding-3-small', 'openai.text-embedding-3-large', 'mistral.mistral-embed'] + SUPPORTED_EMBEDDINGS = ['mistral.mistral-embed'] + SUPPORTED_LLMS = ['openai.gpt-4o', 'anthropic.claude-3-5-sonnet', 'openai.gpt-4o-mini', + 'mistral.mistral-large-latest', 'mistral.mistral-small-latest'] ANTHROPIC_LLM_VERSIONS = {'claude-3-5-sonnet': 'claude-3-5-sonnet-20240620', } @@ -75,13 +77,10 @@ class Config(object): 'anthropic.claude-3-5-sonnet': 8000 } - # OpenAI API Keys + # Environemnt Loaders OPENAI_API_KEY = environ.get('OPENAI_API_KEY') - - # Groq API Keys + MISTRAL_API_KEY = environ.get('MISTRAL_API_KEY') GROQ_API_KEY = environ.get('GROQ_API_KEY') - - # Anthropic API Keys ANTHROPIC_API_KEY = environ.get('ANTHROPIC_API_KEY') # Celery settings @@ -93,7 +92,7 @@ class Config(object): # SocketIO settings # SOCKETIO_ASYNC_MODE = 'threading' - SOCKETIO_ASYNC_MODE = 'gevent' + # SOCKETIO_ASYNC_MODE = 'gevent' # Session Settings SESSION_TYPE = 'redis' @@ -207,13 +206,13 @@ class DevConfig(Config): # UNSTRUCTURED_FULL_URL = 'https://flowitbv-16c4us0m.api.unstructuredapp.io/general/v0/general' # SocketIO settings - SOCKETIO_MESSAGE_QUEUE = f'{REDIS_BASE_URI}/1' - SOCKETIO_CORS_ALLOWED_ORIGINS = '*' - SOCKETIO_LOGGER = True - SOCKETIO_ENGINEIO_LOGGER = True - SOCKETIO_PING_TIMEOUT = 20000 - SOCKETIO_PING_INTERVAL = 25000 - SOCKETIO_MAX_IDLE_TIME = timedelta(minutes=60) # Changing this value ==> change maxConnectionDuration value in + # SOCKETIO_MESSAGE_QUEUE = f'{REDIS_BASE_URI}/1' + # SOCKETIO_CORS_ALLOWED_ORIGINS = '*' + # SOCKETIO_LOGGER = True + # SOCKETIO_ENGINEIO_LOGGER = True + # SOCKETIO_PING_TIMEOUT = 20000 + # SOCKETIO_PING_INTERVAL = 25000 + # SOCKETIO_MAX_IDLE_TIME = timedelta(minutes=60) # Changing this value ==> change maxConnectionDuration value in # eveai-chat-widget.js # Google Cloud settings @@ -299,13 +298,13 @@ class ProdConfig(Config): SESSION_REDIS = redis.from_url(f'{REDIS_BASE_URI}/2') # SocketIO settings - SOCKETIO_MESSAGE_QUEUE = f'{REDIS_BASE_URI}/1' - SOCKETIO_CORS_ALLOWED_ORIGINS = '*' - SOCKETIO_LOGGER = True - SOCKETIO_ENGINEIO_LOGGER = True - SOCKETIO_PING_TIMEOUT = 20000 - SOCKETIO_PING_INTERVAL = 25000 - SOCKETIO_MAX_IDLE_TIME = timedelta(minutes=60) # Changing this value ==> change maxConnectionDuration value in + # SOCKETIO_MESSAGE_QUEUE = f'{REDIS_BASE_URI}/1' + # SOCKETIO_CORS_ALLOWED_ORIGINS = '*' + # SOCKETIO_LOGGER = True + # SOCKETIO_ENGINEIO_LOGGER = True + # SOCKETIO_PING_TIMEOUT = 20000 + # SOCKETIO_PING_INTERVAL = 25000 + # SOCKETIO_MAX_IDLE_TIME = timedelta(minutes=60) # Changing this value ==> change maxConnectionDuration value in # eveai-chat-widget.js # Google Cloud settings diff --git a/config/prompts/html_parse/1.0.0.yaml b/config/prompts/html_parse/1.0.0.yaml index f1bb745..4cacd7f 100644 --- a/config/prompts/html_parse/1.0.0.yaml +++ b/config/prompts/html_parse/1.0.0.yaml @@ -13,6 +13,7 @@ content: | HTML is between triple backquotes. ```{html}``` +model: "mistral.mistral-small-latest" metadata: author: "Josako" date_added: "2024-11-10" diff --git a/docker/compose_dev.yaml b/docker/compose_dev.yaml index 0d52c29..edf392b 100644 --- a/docker/compose_dev.yaml +++ b/docker/compose_dev.yaml @@ -28,6 +28,7 @@ x-common-variables: &common-variables FLOWER_PASSWORD: 'Jungles' OPENAI_API_KEY: 'sk-proj-8R0jWzwjL7PeoPyMhJTZT3BlbkFJLb6HfRB2Hr9cEVFWEhU7' GROQ_API_KEY: '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' + MISTRAL_API_KEY: 'jGDc6fkCbt0iOC0jQsbuZhcjLWBPGc2b' ANTHROPIC_API_KEY: 'sk-ant-api03-c2TmkzbReeGhXBO5JxNH6BJNylRDonc9GmZd0eRbrvyekec2' JWT_SECRET_KEY: 'bsdMkmQ8ObfMD52yAFg4trrvjgjMhuIqg2fjDpD/JqvgY0ccCcmlsEnVFmR79WPiLKEA3i8a5zmejwLZKl4v9Q==' API_ENCRYPTION_KEY: 'xfF5369IsredSrlrYZqkM9ZNrfUASYYS6TCcAR9UKj4=' @@ -65,7 +66,7 @@ services: - ./logs/nginx:/var/log/nginx depends_on: - eveai_app - - eveai_chat + - eveai_api networks: - eveai-network @@ -134,39 +135,39 @@ services: networks: - eveai-network - eveai_chat: - image: josakola/eveai_chat:latest - build: - context: .. - dockerfile: ./docker/eveai_chat/Dockerfile - platforms: - - linux/amd64 - - linux/arm64 - ports: - - 5002:5002 - environment: - <<: *common-variables - COMPONENT_NAME: eveai_chat - volumes: - - ../eveai_chat:/app/eveai_chat - - ../common:/app/common - - ../config:/app/config - - ../scripts:/app/scripts - - ../patched_packages:/app/patched_packages - - ./eveai_logs:/app/logs - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:5002/healthz/ready" ] # Adjust based on your health endpoint - interval: 30s - timeout: 1s - retries: 3 - start_period: 30s - networks: - - eveai-network +# eveai_chat: +# image: josakola/eveai_chat:latest +# build: +# context: .. +# dockerfile: ./docker/eveai_chat/Dockerfile +# platforms: +# - linux/amd64 +# - linux/arm64 +# ports: +# - 5002:5002 +# environment: +# <<: *common-variables +# COMPONENT_NAME: eveai_chat +# volumes: +# - ../eveai_chat:/app/eveai_chat +# - ../common:/app/common +# - ../config:/app/config +# - ../scripts:/app/scripts +# - ../patched_packages:/app/patched_packages +# - ./eveai_logs:/app/logs +# depends_on: +# db: +# condition: service_healthy +# redis: +# condition: service_healthy +# healthcheck: +# test: [ "CMD", "curl", "-f", "http://localhost:5002/healthz/ready" ] # Adjust based on your health endpoint +# interval: 30s +# timeout: 1s +# retries: 3 +# start_period: 30s +# networks: +# - eveai-network eveai_chat_workers: image: josakola/eveai_chat_workers:latest diff --git a/docker/compose_stackhero.yaml b/docker/compose_stackhero.yaml index e9bee08..71c7883 100644 --- a/docker/compose_stackhero.yaml +++ b/docker/compose_stackhero.yaml @@ -31,6 +31,7 @@ x-common-variables: &common-variables OPENAI_API_KEY: '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' GROQ_API_KEY: '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' ANTHROPIC_API_KEY: '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' + MISTRAL_API_KEY: 'PjnUeDRPD7B144wdHlH0CzR7m0z8RHXi' JWT_SECRET_KEY: '0d99e810e686ea567ef305d8e9b06195c4db482952e19276590a726cde60a408' API_ENCRYPTION_KEY: 'Ly5XYWwEKiasfAwEqdEMdwR-k0vhrq6QPYd4whEROB0=' GRAYLOG_HOST: de4zvu.stackhero-network.com @@ -66,7 +67,7 @@ services: - "traefik.http.services.nginx.loadbalancer.server.port=80" depends_on: - eveai_app - - eveai_chat + - eveai_api networks: - eveai-network @@ -99,23 +100,23 @@ services: networks: - eveai-network - eveai_chat: - platform: linux/amd64 - image: josakola/eveai_chat:latest - ports: - - 5002:5002 - environment: - <<: *common-variables - COMPONENT_NAME: eveai_chat - volumes: - - eveai_logs:/app/logs - healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:5002/healthz/ready" ] # Adjust based on your health endpoint - interval: 10s - timeout: 5s - retries: 5 - networks: - - eveai-network +# eveai_chat: +# platform: linux/amd64 +# image: josakola/eveai_chat:latest +# ports: +# - 5002:5002 +# environment: +# <<: *common-variables +# COMPONENT_NAME: eveai_chat +# volumes: +# - eveai_logs:/app/logs +# healthcheck: +# test: [ "CMD", "curl", "-f", "http://localhost:5002/healthz/ready" ] # Adjust based on your health endpoint +# interval: 10s +# timeout: 5s +# retries: 5 +# networks: +# - eveai-network eveai_chat_workers: platform: linux/amd64 diff --git a/eveai_api/__init__.py b/eveai_api/__init__.py index bb5d33f..ad2f983 100644 --- a/eveai_api/__init__.py +++ b/eveai_api/__init__.py @@ -5,7 +5,7 @@ from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request from sqlalchemy.exc import SQLAlchemyError from werkzeug.exceptions import HTTPException -from common.extensions import db, api_rest, jwt, minio_client, simple_encryption, cors +from common.extensions import db, api_rest, jwt, minio_client, simple_encryption, cors, cache_manager import os import logging.config @@ -31,7 +31,7 @@ def create_app(config_file=None): case 'development': app.config.from_object(get_config('dev')) case 'production': - app.config.from_object(get_config('prod')) + app.config.from_object(get_config('prod')) case _: app.config.from_object(get_config('dev')) @@ -60,6 +60,9 @@ def create_app(config_file=None): # Register Request Debugger register_request_debugger(app) + # Register Cache Handlers + register_cache_handlers(app) + @app.before_request def check_cors(): if request.method == 'OPTIONS': @@ -120,6 +123,7 @@ def register_extensions(app): jwt.init_app(app) minio_client.init_app(app) simple_encryption.init_app(app) + cache_manager.init_app(app) cors.init_app(app, resources={ r"/api/v1/*": { "origins": "*", @@ -201,3 +205,8 @@ def register_error_handlers(app): "message": str(e), "type": "BadRequestError" }), 400 + + +def register_cache_handlers(app): + from common.utils.cache.config_cache import register_config_cache_handlers + register_config_cache_handlers(cache_manager) diff --git a/eveai_api/api/specialist_execution_api.py b/eveai_api/api/specialist_execution_api.py index 32837ad..0f302ca 100644 --- a/eveai_api/api/specialist_execution_api.py +++ b/eveai_api/api/specialist_execution_api.py @@ -5,9 +5,11 @@ from flask import Response, stream_with_context, current_app from flask_restx import Namespace, Resource, fields from flask_jwt_extended import jwt_required, get_jwt_identity +from common.extensions import cache_manager from common.utils.celery_utils import current_celery from common.utils.execution_progress import ExecutionProgressTracker from eveai_api.api.auth import requires_service +from common.models.interaction import Specialist specialist_execution_ns = Namespace('specialist-execution', description='Specialist execution operations') @@ -87,3 +89,47 @@ class ExecutionStream(Resource): 'Connection': 'keep-alive' } ) + + +specialist_arguments_input = specialist_execution_ns.model('SpecialistArgumentsInput', { + 'specialist_id': fields.Integer(required=True, description='ID of the specialist to use'), +}) + +specialist_arguments_response = specialist_execution_ns.model('SpecialistArgumentsResponse', { + 'arguments': fields.Raw(description='Dynamic list of attributes for the specialist.'), +}) + + +@specialist_execution_ns.route('/specialist_arguments', methods=['GET']) +class SpecialistArgument(Resource): + @jwt_required() + @requires_service('SPECIALIST_API') + @specialist_execution_ns.expect(specialist_arguments_input) + @specialist_execution_ns.response(200, 'Specialist configuration fetched.', specialist_arguments_response) + @specialist_execution_ns.response(404, 'Specialist configuration not found.') + @specialist_execution_ns.response(500, 'Internal Server Error') + def get(self): + """Start execution of a specialist""" + tenant_id = get_jwt_identity() + data = specialist_execution_ns.payload + specialist_id = data['specialist_id'] + try: + specialist = Specialist.query.get(specialist_id) + if specialist: + configuration = cache_manager.specialists_config_cache.get_config(specialist.type, + specialist.type_version) + current_app.logger.debug(f"Configuration returned: {configuration}") + if configuration: + if 'arguments' in configuration: + return { + 'arguments': configuration['arguments'], + }, 200 + else: + specialist_execution_ns.abort(404, 'No arguments found in specialist configuration.') + else: + specialist_execution_ns.abort(404, 'Error fetching Specialist configuration.') + else: + specialist_execution_ns.abort(404, 'Error fetching Specialist') + except Exception as e: + current_app.logger.error(f"Error while retrieving Specialist configuration: {str(e)}") + specialist_execution_ns.abort(500, 'Unexpected Error while fetching Specialist configuration.') diff --git a/eveai_app/__init__.py b/eveai_app/__init__.py index 74fc580..c7b8d76 100644 --- a/eveai_app/__init__.py +++ b/eveai_app/__init__.py @@ -1,7 +1,7 @@ import logging import os -from flask import Flask, render_template, jsonify, flash, redirect, request -from flask_security import SQLAlchemyUserDatastore, LoginForm +from flask import Flask, jsonify +from flask_security import SQLAlchemyUserDatastore from flask_security.signals import user_authenticated from werkzeug.middleware.proxy_fix import ProxyFix import logging.config @@ -12,7 +12,6 @@ from common.models.user import User, Role, Tenant, TenantDomain import common.models.interaction import common.models.entitlements import common.models.document -from common.utils.nginx_utils import prefixed_url_for from common.utils.startup_eveai import perform_startup_actions from config.logging_config import LOGGING from common.utils.security import set_tenant_session_data diff --git a/eveai_app/templates/document/edit_catalog.html b/eveai_app/templates/document/edit_catalog.html index 77e3c22..5867d72 100644 --- a/eveai_app/templates/document/edit_catalog.html +++ b/eveai_app/templates/document/edit_catalog.html @@ -11,7 +11,7 @@ When you change chunking of embedding information, you'll need to manually refre {% block content %}
{{ form.hidden_tag() }} - {% set disabled_fields = ['type'] %} + {% set disabled_fields = ['type', 'embedding_model'] %} {% set exclude_fields = [] %} {% for field in form.get_static_fields() %} diff --git a/eveai_app/templates/user/edit_tenant.html b/eveai_app/templates/user/edit_tenant.html index 40e31c5..095c4fd 100644 --- a/eveai_app/templates/user/edit_tenant.html +++ b/eveai_app/templates/user/edit_tenant.html @@ -9,7 +9,7 @@ {% block content %} {{ form.hidden_tag() }} - {% set disabled_fields = ['name', 'embedding_model', 'llm_model'] %} + {% set disabled_fields = ['name', 'llm_model'] %} {% set exclude_fields = [] %} {% for field in form %} {{ render_field(field, disabled_fields, exclude_fields) }} diff --git a/eveai_app/templates/user/tenant_overview.html b/eveai_app/templates/user/tenant_overview.html index a734a09..0a80b82 100644 --- a/eveai_app/templates/user/tenant_overview.html +++ b/eveai_app/templates/user/tenant_overview.html @@ -35,7 +35,7 @@
- {% set model_fields = ['embedding_model', 'llm_model'] %} + {% set model_fields = ['llm_model'] %} {% for field in form %} {{ render_included_field(field, disabled_fields=model_fields, include_fields=model_fields) }} {% endfor %} diff --git a/eveai_app/views/document_forms.py b/eveai_app/views/document_forms.py index 71346ad..8fe89f7 100644 --- a/eveai_app/views/document_forms.py +++ b/eveai_app/views/document_forms.py @@ -1,6 +1,7 @@ from flask import session, current_app from flask_wtf import FlaskForm -from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, SelectField, TextAreaField, URLField) +from wtforms import (StringField, BooleanField, SubmitField, DateField, IntegerField, SelectField, TextAreaField, + URLField) from wtforms.validators import DataRequired, Length, Optional, URL, ValidationError, NumberRange from flask_wtf.file import FileField, FileRequired import json @@ -30,10 +31,13 @@ class CatalogForm(FlaskForm): # Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config) type = SelectField('Catalog Type', validators=[DataRequired()]) - min_chunk_size = IntegerField('Minimum Chunk Size (2000)', validators=[NumberRange(min=0), Optional()], - default=2000) - max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], - default=3000) + # Selection fields for processing & creating embeddings + embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()]) + + min_chunk_size = IntegerField('Minimum Chunk Size (1500)', validators=[NumberRange(min=0), Optional()], + default=1500) + max_chunk_size = IntegerField('Maximum Chunk Size (2500)', validators=[NumberRange(min=0), Optional()], + default=2500) # Metadata fields user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json]) @@ -43,6 +47,7 @@ class CatalogForm(FlaskForm): super().__init__(*args, **kwargs) # Dynamically populate the 'type' field using the constructor self.type.choices = [(key, value['name']) for key, value in CATALOG_TYPES.items()] + self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']] class EditCatalogForm(DynamicFormBase): @@ -52,6 +57,9 @@ class EditCatalogForm(DynamicFormBase): # Select Field for Catalog Type (Uses the CATALOG_TYPES defined in config) type = StringField('Catalog Type', validators=[DataRequired()], render_kw={'readonly': True}) + # Selection fields for processing & creating embeddings + embedding_model = StringField('Embedding Model', validators=[DataRequired()], render_kw={'readonly': True}) + min_chunk_size = IntegerField('Minimum Chunk Size (2000)', validators=[NumberRange(min=0), Optional()], default=2000) max_chunk_size = IntegerField('Maximum Chunk Size (3000)', validators=[NumberRange(min=0), Optional()], @@ -59,7 +67,7 @@ class EditCatalogForm(DynamicFormBase): # Metadata fields user_metadata = TextAreaField('User Metadata', validators=[Optional(), validate_json]) - system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json],) + system_metadata = TextAreaField('System Metadata', validators=[Optional(), validate_json], ) class ProcessorForm(FlaskForm): diff --git a/eveai_app/views/document_views.py b/eveai_app/views/document_views.py index 9d8ed42..88edfdd 100644 --- a/eveai_app/views/document_views.py +++ b/eveai_app/views/document_views.py @@ -684,8 +684,9 @@ def create_default_rag_library(): name='Default RAG Catalog', description='Default RAG Catalog', type="STANDARD_CATALOG", - min_chunk_size=2000, - max_chunk_size=3000, + min_chunk_size=1500, + max_chunk_size=2500, + embedding_model="mistral.mistral-embed" ) set_logging_information(cat, timestamp) @@ -696,7 +697,7 @@ def create_default_rag_library(): name='Default HTML Processor', description='Default HTML Processor', catalog_id=cat.id, - type="HTML Processor", + type="HTML_PROCESSOR", configuration={ "html_tags": "p, h1, h2, h3, h4, h5, h6, li, table, thead, tbody, tr, td", "html_end_tags": "p, li, table", diff --git a/eveai_app/views/user_forms.py b/eveai_app/views/user_forms.py index 3a91e1b..8396302 100644 --- a/eveai_app/views/user_forms.py +++ b/eveai_app/views/user_forms.py @@ -21,7 +21,6 @@ class TenantForm(FlaskForm): # Timezone timezone = SelectField('Timezone', choices=[], validators=[DataRequired()]) # LLM fields - embedding_model = SelectField('Embedding Model', choices=[], validators=[DataRequired()]) llm_model = SelectField('Large Language Model', choices=[], validators=[DataRequired()]) # Embedding variables submit = SubmitField('Submit') @@ -36,7 +35,6 @@ class TenantForm(FlaskForm): # initialise timezone self.timezone.choices = [(tz, tz) for tz in pytz.all_timezones] # initialise LLM fields - self.embedding_model.choices = [(model, model) for model in current_app.config['SUPPORTED_EMBEDDINGS']] self.llm_model.choices = [(model, model) for model in current_app.config['SUPPORTED_LLMS']] # Initialize fallback algorithms self.type.choices = [(t, t) for t in current_app.config['TENANT_TYPES']] diff --git a/eveai_app/views/user_views.py b/eveai_app/views/user_views.py index 4d742d0..903eb25 100644 --- a/eveai_app/views/user_views.py +++ b/eveai_app/views/user_views.py @@ -228,7 +228,6 @@ def handle_tenant_selection(): # set tenant information in the session session['tenant'] = the_tenant.to_dict() session['default_language'] = the_tenant.default_language - session['embedding_model'] = the_tenant.embedding_model session['llm_model'] = the_tenant.llm_model # remove catalog-related items from the session session.pop('catalog_id', None) diff --git a/eveai_chat_workers/retrievers/standard_rag.py b/eveai_chat_workers/retrievers/standard_rag.py index 88cb523..6fc4319 100644 --- a/eveai_chat_workers/retrievers/standard_rag.py +++ b/eveai_chat_workers/retrievers/standard_rag.py @@ -10,7 +10,7 @@ from common.extensions import db from common.models.document import Document, DocumentVersion, Catalog, Retriever from common.models.user import Tenant from common.utils.datetime_utils import get_date_in_timezone -from common.utils.model_utils import get_model_variables +from common.utils.model_utils import get_embedding_model_and_class from .base import BaseRetriever from .registry import RetrieverRegistry @@ -25,10 +25,10 @@ class StandardRAGRetriever(BaseRetriever): retriever = Retriever.query.get_or_404(retriever_id) self.catalog_id = retriever.catalog_id + self.tenant_id = tenant_id self.similarity_threshold = retriever.configuration.get('es_similarity_threshold', 0.3) self.k = retriever.configuration.get('es_k', 8) self.tuning = retriever.tuning - self.model_variables = get_model_variables(self.tenant_id) self.log_tuning("Standard RAG retriever initialized") @@ -161,8 +161,9 @@ class StandardRAGRetriever(BaseRetriever): def _get_query_embedding(self, query: str): """Get embedding for the query text""" - embedding_model = self.model_variables.embedding_model - return embedding_model.embed_query(query) + catalog = Catalog.query.get_or_404(self.catalog_id) + embedding_model, embedding_model_class = get_embedding_model_and_class(self.tenant_id, self.catalog_id, + catalog.embedding_model) # Register the retriever type diff --git a/eveai_workers/tasks.py b/eveai_workers/tasks.py index 05de7f5..5f3c043 100644 --- a/eveai_workers/tasks.py +++ b/eveai_workers/tasks.py @@ -12,18 +12,20 @@ from langchain_core.runnables import RunnablePassthrough from sqlalchemy import or_ from sqlalchemy.exc import SQLAlchemyError -from common.extensions import db, minio_client +from common.extensions import db from common.models.document import DocumentVersion, Embedding, Document, Processor, Catalog from common.models.user import Tenant from common.utils.celery_utils import current_celery from common.utils.database import Database -from common.utils.model_utils import create_language_template, get_model_variables +from common.utils.model_utils import create_language_template, get_model_variables, get_embedding_model_and_class from common.utils.business_event import BusinessEvent from common.utils.business_event_context import current_event from config.type_defs.processor_types import PROCESSOR_TYPES from eveai_workers.processors.processor_registry import ProcessorRegistry +from common.utils.eveai_exceptions import EveAIInvalidEmbeddingModel + from common.utils.config_field_types import json_to_pattern_list @@ -155,7 +157,7 @@ def embed_markdown(tenant, model_variables, document_version, catalog, processor # Create embeddings with current_event.create_span("Create Embeddings"): - embeddings = embed_chunks(tenant, model_variables, document_version, enriched_chunks) + embeddings = embed_chunks(tenant, catalog, document_version, enriched_chunks) # Update document version and save embeddings try: @@ -227,9 +229,14 @@ def summarize_chunk(tenant, model_variables, document_version, chunk): raise -def embed_chunks(tenant, model_variables, document_version, chunks): - embedding_model = model_variables.embedding_model +def embed_chunks(tenant, catalog, document_version, chunks): + if catalog.embedding_model: + embedding_model, embedding_model_class = get_embedding_model_and_class(tenant.id, catalog.id, + catalog.embedding_model) + else: + raise EveAIInvalidEmbeddingModel(tenant.id, catalog.id) + # Actually embed try: embeddings = embedding_model.embed_documents(chunks) except LangChainException as e: @@ -241,7 +248,7 @@ def embed_chunks(tenant, model_variables, document_version, chunks): # Add embeddings to the database new_embeddings = [] for chunk, embedding in zip(chunks, embeddings): - new_embedding = model_variables.embedding_model_class() + new_embedding = embedding_model_class() new_embedding.document_version = document_version new_embedding.active = True new_embedding.chunk = chunk @@ -309,7 +316,7 @@ def combine_chunks_for_markdown(potential_chunks, min_chars, max_chars, processo return False - chunking_patterns = json_to_pattern_list(processor.configuration.get('chunking_patterns', [])) + chunking_patterns = json_to_pattern_list(processor.configuration.get('chunking_patterns', "")) processor.log_tuning(f'Chunking Patterns Extraction: ', { 'Full Configuration': processor.configuration, diff --git a/migrations/public/versions/b02d9ad000f4_move_embedding_model_settings_to_catalog.py b/migrations/public/versions/b02d9ad000f4_move_embedding_model_settings_to_catalog.py new file mode 100644 index 0000000..29b9941 --- /dev/null +++ b/migrations/public/versions/b02d9ad000f4_move_embedding_model_settings_to_catalog.py @@ -0,0 +1,32 @@ +"""Move embedding model settings to Catalog + +Revision ID: b02d9ad000f4 +Revises: f0ab991a6411 +Create Date: 2025-02-21 22:11:10.313148 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b02d9ad000f4' +down_revision = 'f0ab991a6411' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant', schema=None) as batch_op: + batch_op.drop_column('embedding_model') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant', schema=None) as batch_op: + batch_op.add_column(sa.Column('embedding_model', sa.VARCHAR(length=50), autoincrement=False, nullable=True)) + + # ### end Alembic commands ### diff --git a/migrations/tenant/env.py b/migrations/tenant/env.py index d223524..5ee44b0 100644 --- a/migrations/tenant/env.py +++ b/migrations/tenant/env.py @@ -71,7 +71,7 @@ target_db = current_app.extensions['migrate'].db def get_public_table_names(): # TODO: This function should include the necessary functionality to automatically retrieve table names return ['role', 'roles_users', 'tenant', 'user', 'tenant_domain','license_tier', 'license', 'license_usage', - 'business_event_log'] + 'business_event_log', 'tenant_project'] PUBLIC_TABLES = get_public_table_names() diff --git a/migrations/tenant/versions/2b04e961eee4_link_embedding_models_to_catalog.py b/migrations/tenant/versions/2b04e961eee4_link_embedding_models_to_catalog.py new file mode 100644 index 0000000..c3a35a2 --- /dev/null +++ b/migrations/tenant/versions/2b04e961eee4_link_embedding_models_to_catalog.py @@ -0,0 +1,29 @@ +"""Link embedding models to Catalog + +Revision ID: 2b04e961eee4 +Revises: e58835fadd96 +Create Date: 2025-02-21 22:06:43.527013 + +""" +from alembic import op +import sqlalchemy as sa +import pgvector + + +# revision identifiers, used by Alembic. +revision = '2b04e961eee4' +down_revision = 'e58835fadd96' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('catalog', sa.Column('embedding_model', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('catalog', 'embedding_model') + # ### end Alembic commands ### diff --git a/nginx/nginx.conf b/nginx/nginx.conf index d2c2e0a..1766f96 100644 --- a/nginx/nginx.conf +++ b/nginx/nginx.conf @@ -74,24 +74,24 @@ http { root /etc/nginx/public; } - location /chat/ { - proxy_pass http://eveai_chat:5002/; - - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_http_version 1.1; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - proxy_buffering off; - - # Add CORS headers - add_header 'Access-Control-Allow-Origin' '*' always; - add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always; - add_header 'Access-Control-Allow-Headers' 'DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization' always; - add_header 'Access-Control-Allow-Credentials' 'true' always; - } +# location /chat/ { +# proxy_pass http://eveai_chat:5002/; +# +# proxy_set_header Host $host; +# proxy_set_header X-Real-IP $remote_addr; +# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +# proxy_set_header X-Forwarded-Proto $scheme; +# proxy_http_version 1.1; +# proxy_set_header Upgrade $http_upgrade; +# proxy_set_header Connection "upgrade"; +# proxy_buffering off; +# +# # Add CORS headers +# add_header 'Access-Control-Allow-Origin' '*' always; +# add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always; +# add_header 'Access-Control-Allow-Headers' 'DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization' always; +# add_header 'Access-Control-Allow-Credentials' 'true' always; +# } location /admin/ { # include uwsgi_params; diff --git a/requirements.txt b/requirements.txt index a95355b..2b9739c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -alembic~=1.13.2 +alembic~=1.14.1 annotated-types~=0.7.0 bcrypt~=4.1.3 beautifulsoup4~=4.12.3 @@ -6,18 +6,17 @@ celery~=5.4.0 certifi~=2024.7.4 chardet~=5.2.0 cors~=1.0.1 -Flask~=3.0.3 +Flask~=3.1.0 Flask-BabelEx~=0.9.4 Flask-Bootstrap~=3.3.7.1 Flask-Cors~=5.0.0 -Flask-JWT-Extended~=4.6.0 +Flask-JWT-Extended~=4.7.1 Flask-Login~=0.6.3 flask-mailman~=1.1.1 -Flask-Migrate~=4.0.7 +Flask-Migrate~=4.1.0 Flask-Principal~=0.4.0 -Flask-Security-Too~=5.5.2 +Flask-Security-Too~=5.6.0 Flask-Session~=0.8.0 -Flask-SocketIO~=5.3.6 Flask-SQLAlchemy~=3.1.1 Flask-WTF~=1.2.1 gevent~=24.2.1 @@ -43,12 +42,10 @@ pgvector~=0.2.5 pycryptodome~=3.20.0 pydantic~=2.9.1 PyJWT~=2.8.0 -PySocks~=1.7.1 python-dateutil~=2.9.0.post0 python-engineio~=4.9.1 python-iso639~=2024.4.27 python-magic~=0.4.27 -python-socketio~=5.11.3 pytz~=2024.1 PyYAML~=6.0.2 redis~=5.0.4 @@ -64,7 +61,7 @@ groq~=0.9.0 pydub~=0.25.1 argparse~=1.4.0 minio~=7.2.7 -Werkzeug~=3.0.3 +Werkzeug~=3.1.3 itsdangerous~=2.2.0 cryptography~=43.0.0 graypy~=2.1.0 @@ -91,4 +88,6 @@ dogpile.cache~=1.3.3 python-docx~=1.1.2 crewai~=0.102.0 sseclient~=0.0.27 -termcolor~=2.5.0 \ No newline at end of file +termcolor~=2.5.0 +mistral-common~=1.5.3 +mistralai~=1.5.0 \ No newline at end of file diff --git a/tests/interactive_client/specialist_client.py b/tests/interactive_client/specialist_client.py deleted file mode 100644 index f898ae1..0000000 --- a/tests/interactive_client/specialist_client.py +++ /dev/null @@ -1,247 +0,0 @@ -#!/usr/bin/env python3 -import json -import logging -import sys -import time -import requests # Used for calling the auth API -from datetime import datetime -import yaml # For loading the YAML configuration -from urllib.parse import urlparse - -import socketio # Official python-socketio client - -# ---------------------------- -# Constants for authentication and specialist selection -# ---------------------------- -API_KEY = "EveAI-8342-2966-4731-6578-1010-8903-4230-4378" -TENANT_ID = 2 -SPECIALIST_ID = 2 -BASE_API_URL = "http://macstudio.ask-eve-ai-local.com:8080/api/api/v1" -BASE_SOCKET_URL = "http://macstudio.ask-eve-ai-local.com:8080" -CONFIG_FILE = "config/specialists/SPIN_SPECIALIST/1.0.0.yaml" # Path to specialist configuration - -# ---------------------------- -# Logging Configuration -# ---------------------------- -LOG_FILENAME = "specialist_client.log" -logging.basicConfig( - filename=LOG_FILENAME, - level=logging.DEBUG, - format="%(asctime)s %(levelname)s: %(message)s" -) -console_handler = logging.StreamHandler(sys.stdout) -console_handler.setLevel(logging.INFO) -logging.getLogger('').addHandler(console_handler) - -# ---------------------------- -# Create the Socket.IO client using the official python-socketio client -# ---------------------------- -sio = socketio.Client(logger=True, engineio_logger=True) -room = None # Global variable to store the assigned room - -# ---------------------------- -# Event Handlers -# ---------------------------- -@sio.event -def connect(): - logging.info("Connected to Socket.IO server.") - print("Connected to server.") - -@sio.event -def disconnect(): - logging.info("Disconnected from Socket.IO server.") - print("Disconnected from server.") - -@sio.on("connect_error") -def on_connect_error(data): - logging.error("Connect error: %s", data) - print("Connect error:", data) - -@sio.on("authenticated") -def on_authenticated(data): - global room - room = data.get("room") - logging.info("Authenticated. Room: %s", room) - print("Authenticated. Room:", room) - -@sio.on("room_join") -def on_room_join(data): - global room - room = data.get("room") - logging.info("Room join event received. Room: %s", room) - print("Joined room:", room) - -@sio.on("token_expired") -def on_token_expired(data): - logging.warning("Token expired.") - print("Token expired. Please refresh your session.") - -@sio.on("reconnect_attempt") -def on_reconnect_attempt(attempt): - logging.info("Reconnect attempt #%s", attempt) - print(f"Reconnect attempt #{attempt}") - -@sio.on("reconnect") -def on_reconnect(): - logging.info("Reconnected successfully.") - print("Reconnected to server.") - -@sio.on("reconnect_failed") -def on_reconnect_failed(): - logging.error("Reconnection failed.") - print("Reconnection failed. Please refresh.") - -@sio.on("room_rejoin_result") -def on_room_rejoin_result(data): - if data.get("success"): - global room - room = data.get("room") - logging.info("Successfully rejoined room: %s", room) - print("Rejoined room:", room) - else: - logging.error("Failed to rejoin room.") - print("Failed to rejoin room.") - -@sio.on("bot_response") -def on_bot_response(data): - logging.info("Received bot response: %s", data) - print("Bot response received:") - print(json.dumps(data, indent=2)) - -@sio.on("task_status") -def on_task_status(data): - logging.info("Received task status: %s", data) - print("Task status:") - print(json.dumps(data, indent=2)) - -# ---------------------------- -# Helper: Retrieve token from REST API -# ---------------------------- -def retrieve_token(api_url: str) -> str: - payload = { - "tenant_id": TENANT_ID, - "api_key": API_KEY - } - try: - logging.info("Requesting token from %s with payload: %s", api_url, payload) - response = requests.post(api_url, json=payload) - response.raise_for_status() - token = response.json()["access_token"] - logging.info("Token retrieved successfully.") - return token - except Exception as e: - logging.error("Failed to retrieve token: %s", e) - raise e - -# ---------------------------- -# Main Interactive UI Function -# ---------------------------- -def main(): - global room - - # Retrieve the token - auth_url = f"{BASE_API_URL}/auth/token" - try: - token = retrieve_token(auth_url) - print("Token retrieved successfully.") - except Exception as e: - print("Error retrieving token. Check logs for details.") - sys.exit(1) - - # Parse the BASE_SOCKET_URL - parsed_url = urlparse(BASE_SOCKET_URL) - host_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - # Connect to the Socket.IO server. - # Note: Use `auth` instead of `query_string` (the official client uses the `auth` parameter) - try: - sio.connect( - host_url, - socketio_path='/chat/socket.io', - auth={"token": token}, - ) - except Exception as e: - logging.error("Failed to connect to Socket.IO server: %s", e) - print("Failed to connect to Socket.IO server:", e) - sys.exit(1) - - # Allow time for authentication and room assignment. - time.sleep(2) - if not room: - logging.warning("No room assigned. Exiting.") - print("No room assigned by the server. Exiting.") - sio.disconnect() - sys.exit(1) - - # Load specialist configuration from YAML. - try: - with open(CONFIG_FILE, "r") as f: - specialist_config = yaml.safe_load(f) - arg_config = specialist_config.get("arguments", {}) - logging.info("Loaded specialist argument configuration: %s", arg_config) - except Exception as e: - logging.error("Failed to load specialist configuration: %s", e) - print("Failed to load specialist configuration. Exiting.") - sys.exit(1) - - # Dictionary to store default values for static arguments (except "query") - static_defaults = {} - - print("\nInteractive Specialist Client") - print("For each iteration, you will be prompted for the following arguments:") - for key, details in arg_config.items(): - print(f" - {details.get('name', key)}: {details.get('description', '')}") - print("Type 'quit' or 'exit' as the query to end the session.\n") - - # Interactive loop: prompt for arguments and send user message. - while True: - current_arguments = {} - for arg_key, arg_details in arg_config.items(): - prompt_msg = f"Enter {arg_details.get('name', arg_key)}" - desc = arg_details.get("description", "") - if desc: - prompt_msg += f" ({desc})" - if arg_key != "query": - default_value = static_defaults.get(arg_key, "") - if default_value: - prompt_msg += f" [default: {default_value}]" - prompt_msg += ": " - value = input(prompt_msg).strip() - if not value: - value = default_value - static_defaults[arg_key] = value - else: - prompt_msg += " (required): " - value = input(prompt_msg).strip() - while not value: - print("Query is required. Please enter a value.") - value = input(prompt_msg).strip() - current_arguments[arg_key] = value - - if current_arguments.get("query", "").lower() in ["quit", "exit"]: - break - - try: - timezone = datetime.now().astimezone().tzname() - except Exception: - timezone = "UTC" - - payload = { - "token": token, - "tenant_id": TENANT_ID, - "specialist_id": SPECIALIST_ID, - "arguments": current_arguments, - "timezone": timezone, - "room": room - } - - logging.info("Sending user_message with payload: %s", payload) - print("Sending message to specialist...") - sio.emit("user_message", payload) - time.sleep(1) - - print("Exiting interactive session.") - sio.disconnect() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/specialist_execution/test_specialist_client.py b/tests/specialist_execution/test_specialist_client.py index 986ec82..3f2a556 100644 --- a/tests/specialist_execution/test_specialist_client.py +++ b/tests/specialist_execution/test_specialist_client.py @@ -18,9 +18,7 @@ sys.path.append(project_root) API_BASE_URL = "http://macstudio.ask-eve-ai-local.com:8080/api/api/v1" TENANT_ID = 2 # Replace with your tenant ID API_KEY = "EveAI-5096-5466-6143-1487-8085-4174-2080-7208" # Replace with your API key -SPECIALIST_TYPE = "SPIN_SPECIALIST" # Replace with your specialist type SPECIALIST_ID = 5 # Replace with your specialist ID -ROOT_FOLDER = "../.." def get_auth_token() -> str: @@ -52,15 +50,27 @@ def get_session_id(auth_token: str) -> str: return response.json()["session_id"] -def load_specialist_config() -> Dict[str, Any]: - """Load specialist configuration from YAML file""" - config_path = f"{ROOT_FOLDER}/config/specialists/{SPECIALIST_TYPE}/1.0.0.yaml" - if not os.path.exists(config_path): - print(colored(f"Error: Configuration file not found: {config_path}", "red")) - sys.exit(1) +def get_specialist_config(auth_token: str, specialist_id: int) -> Dict[str, Any]: + """Get specialist configuration from API""" + headers = { + 'Authorization': f'Bearer {auth_token}', + 'Content-Type': 'application/json' + } - with open(config_path, 'r') as f: - return yaml.safe_load(f) + response = requests.get( + f"{API_BASE_URL}/specialist-execution/specialist_arguments", + headers=headers, + json={ + 'specialist_id': specialist_id + } + ) + + print(colored(f"Status Code: {response.status_code}", "cyan")) + if response.status_code == 200: + config_data = response.json() + return config_data.get('arguments', {}) + else: + raise Exception(f"Failed to get specialist configuration: {response.text}") def get_argument_value(arg_name: str, arg_config: Dict[str, Any], previous_value: Any = None) -> Any: @@ -163,8 +173,10 @@ def main(): auth_token = get_auth_token() # Load specialist configuration - print(colored(f"Loading specialist configuration {SPECIALIST_TYPE}", "cyan")) - config = load_specialist_config() + print(colored(f"Loading specialist configuration", "cyan")) + config = { + 'arguments': get_specialist_config(auth_token, SPECIALIST_ID) + } previous_args = None while True: