Compare commits
340 Commits
1.0.2-alfa
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
| 92edbeacb2 | |||
|
|
30bfecc135 | ||
|
|
2c8347c91b | ||
|
|
fe9fc047ff | ||
|
|
0f8bda0aef | ||
|
|
bab9e89117 | ||
|
|
e25698d6cf | ||
|
|
e30fe7807c | ||
|
|
94b805e0eb | ||
|
|
9b86a220b1 | ||
|
|
5a5d6b03af | ||
|
|
b1d8c9a17d | ||
|
|
14273b8a70 | ||
|
|
5e25216b66 | ||
|
|
d68dfde52a | ||
|
|
4bc2292c4c | ||
|
|
f10bb6f395 | ||
|
|
0d3c3949de | ||
|
|
25adb4213b | ||
|
|
73125887a3 | ||
|
|
c29ed37c09 | ||
|
|
9b1f9e8a3b | ||
|
|
e167df3032 | ||
|
|
20fb2eee70 | ||
|
|
3815399a7e | ||
|
|
f2bd90e6ae | ||
|
|
95c8282eb8 | ||
|
|
04c9d8cf98 | ||
|
|
03f6ef4408 | ||
|
|
e8bb66c2c2 | ||
|
|
5dd711bcd2 | ||
|
|
ee13de7fde | ||
|
|
82ca6b537a | ||
|
|
af37aa7253 | ||
|
|
1748aebd38 | ||
|
|
d6041ebb27 | ||
|
|
b3ee2f7ce9 | ||
|
|
c523250ccb | ||
|
|
a43825f5f0 | ||
|
|
fb261ca0b9 | ||
|
|
3ca2e0a3a9 | ||
|
|
3aa2158a17 | ||
|
|
2bc5832db6 | ||
|
|
1720ddfa11 | ||
|
|
59febb7fbb | ||
|
|
4ec1099925 | ||
|
|
8d1a8d9645 | ||
|
|
1d79a19981 | ||
|
|
aab766fe5e | ||
|
|
05241ecdea | ||
|
|
451f95fbc1 | ||
|
|
842429a659 | ||
|
|
225d494e15 | ||
|
|
5501061dd1 | ||
|
|
eeb76d57b7 | ||
|
|
3ea3a06de6 | ||
|
|
37819cd7e5 | ||
|
|
a798217091 | ||
|
|
83272a4e2a | ||
|
|
b66e2e99ed | ||
|
|
aeee22b305 | ||
|
|
5f387dcef8 | ||
|
|
b499add891 | ||
|
|
2f815616b1 | ||
|
|
f23214bb6d | ||
|
|
6df9aa9c7e | ||
|
|
5465dae52f | ||
|
|
79a3f94ac2 | ||
|
|
06586a1312 | ||
|
|
7b0e3cee7f | ||
|
|
7bef4e69df | ||
|
|
a3e18cb4db | ||
|
|
471b8dd8c3 | ||
|
|
030d1b0e90 | ||
|
|
fa452e4934 | ||
|
|
e24e7265b9 | ||
|
|
a76f87ba75 | ||
|
|
c6fc8ca09a | ||
|
|
16ce59ae98 | ||
|
|
cc47ce2d32 | ||
|
|
b1e9fb71cb | ||
|
|
a57662db3f | ||
|
|
66433f19b3 | ||
|
|
e7397a6d0d | ||
|
|
d097451d42 | ||
|
|
44e5dd5d02 | ||
|
|
3b23be0ea4 | ||
|
|
61ae9c3174 | ||
|
|
b6512b2d8c | ||
|
|
0cd12a8491 | ||
|
|
ae36791ffe | ||
|
|
53bfc6bb23 | ||
|
|
2afee41c2a | ||
|
|
79b1fef5b6 | ||
|
|
2b04692fab | ||
|
|
541d3862e6 | ||
|
|
43fd4ce9c1 | ||
|
|
14ba53e26b | ||
|
|
4ab8b2a714 | ||
|
|
42cb1de0fd | ||
|
|
a325fa5084 | ||
|
|
7cb19ca21e | ||
|
|
6ccba7d1e3 | ||
|
|
6fbaff45a8 | ||
|
|
10ca344c84 | ||
|
|
a9bbd1f466 | ||
|
|
804486664b | ||
|
|
36575c17a8 | ||
|
|
575bfa259e | ||
|
|
362b2fe753 | ||
|
|
5c20e6c1f9 | ||
|
|
b812aedb81 | ||
|
|
d6ea3ba46c | ||
|
|
a6edd5c663 | ||
|
|
6115cc7e13 | ||
|
|
54a9641440 | ||
|
|
af8b5f54cd | ||
|
|
2a0c92b064 | ||
|
|
898bb32318 | ||
|
|
b0e1ad6e03 | ||
|
|
84afc0b2ee | ||
|
|
593dd438aa | ||
|
|
35f58f0c57 | ||
|
|
25ab9ccf23 | ||
|
|
2a4c9d7b00 | ||
|
|
e6c3c24bd8 | ||
|
|
481157fb31 | ||
|
|
376ad328ca | ||
|
|
2bb9d4b0be | ||
|
|
6eae0ab1a3 | ||
|
|
4395d2e407 | ||
|
|
da61f5f9ec | ||
|
|
53283b6687 | ||
|
|
5d715a958c | ||
|
|
0f969972d6 | ||
|
|
4c00d33bc3 | ||
|
|
9c63ecb17f | ||
|
|
d6a2635e50 | ||
|
|
84a9334c80 | ||
|
|
066f579294 | ||
|
|
ebf92b0474 | ||
|
|
7e35549262 | ||
|
|
866cc2a60d | ||
|
|
ed87d73c5a | ||
|
|
212ea28de8 | ||
|
|
cea38e02d2 | ||
|
|
248fae500a | ||
|
|
4d6466038f | ||
|
|
9a88582fff | ||
|
|
998ddf4c03 | ||
|
|
dabf97c96e | ||
|
|
5e81595622 | ||
|
|
ef138462d9 | ||
|
|
42ffe3795f | ||
|
|
ba523a95c5 | ||
|
|
8a85b4540f | ||
|
|
fc3cae1986 | ||
|
|
32df3d0589 | ||
|
|
ccc1a2afb8 | ||
|
|
f16ed85e82 | ||
|
|
e990fe65d8 | ||
|
|
32cf105d7b | ||
|
|
dc6cd9d940 | ||
|
|
a0f806ba4e | ||
|
|
98db88b00b | ||
|
|
4ad621428e | ||
|
|
0f33beddf4 | ||
|
|
f8f941d1e1 | ||
|
|
abc0a50dcc | ||
|
|
854d889413 | ||
|
|
7bbc32e381 | ||
|
|
e75c49d2fa | ||
|
|
ccb844c15c | ||
|
|
b60600e9f6 | ||
|
|
11b1d548bd | ||
|
|
f3a243698c | ||
|
|
000636a229 | ||
|
|
acad28b623 | ||
|
|
42635a583c | ||
|
|
7d7db296d3 | ||
|
|
51fd16bcc6 | ||
|
|
509ee95d81 | ||
|
|
33b5742d2f | ||
|
|
50773fe602 | ||
|
|
51d029d960 | ||
|
|
fbc9f44ac8 | ||
|
|
4338f09f5c | ||
|
|
53e32a67bd | ||
|
|
fda267b479 | ||
|
|
f5c9542a49 | ||
|
|
043cea45f2 | ||
|
|
7b87880045 | ||
|
|
5b2c04501c | ||
|
|
babcd6ec04 | ||
|
|
71adf64668 | ||
|
|
dbea41451a | ||
|
|
82e25b356c | ||
|
|
3c7460f741 | ||
|
|
2835486599 | ||
|
|
f1c60f9574 | ||
|
|
b326c0c6f2 | ||
|
|
5f1a5711f6 | ||
|
|
67ceb57b79 | ||
|
|
23b49516cb | ||
|
|
9cc266b97f | ||
|
|
3f77871c4f | ||
|
|
199cf94cf2 | ||
|
|
c4dcd6a0d3 | ||
|
|
43ee9139d6 | ||
|
|
8f45005713 | ||
|
|
bc1626c4ff | ||
|
|
57c0e7a1ba | ||
|
|
0d05499d2b | ||
|
|
b4e58659a8 | ||
|
|
67078ce925 | ||
|
|
ebdb836448 | ||
|
|
81e754317a | ||
|
|
578981c745 | ||
|
|
8fb2ad43c5 | ||
|
|
49f9077a7b | ||
|
|
d290b46a0c | ||
|
|
73647e4795 | ||
|
|
25e169dbea | ||
|
|
8a29eb0d8f | ||
|
|
0a5f0986e6 | ||
|
|
4d79c4fd5a | ||
|
|
5123de55cc | ||
|
|
1fdbd2ff45 | ||
|
|
d789e431ca | ||
|
|
70de4c0328 | ||
|
|
d2bb51a4a8 | ||
|
|
28aea85b10 | ||
|
|
d2a9092f46 | ||
|
|
5c982fcc2c | ||
|
|
b4f7b210e0 | ||
|
|
1b1eef0d2e | ||
|
|
17d32cd039 | ||
|
|
12a53ebc1c | ||
|
|
a421977918 | ||
|
|
4c480c9baa | ||
|
|
9ea04572c8 | ||
|
|
6ef025363d | ||
|
|
9652d0bff9 | ||
|
|
4bf12db142 | ||
|
|
5f58417d24 | ||
|
|
3eed546879 | ||
|
|
35f0adef1b | ||
|
|
f43e79376c | ||
|
|
be76dd5240 | ||
|
|
c2c3b01b28 | ||
|
|
8daa52d1e9 | ||
|
|
9ad7c1aee9 | ||
|
|
1762b930bc | ||
|
|
d57bc5cf03 | ||
|
|
6c8c33d296 | ||
|
|
4ea16521e2 | ||
|
|
b6ee7182de | ||
|
|
238bdb58f4 | ||
|
|
a35486b573 | ||
|
|
dc64bbc257 | ||
|
|
09555ae8b0 | ||
|
|
cf2201a1f7 | ||
|
|
a6402524ce | ||
|
|
56a00c2894 | ||
|
|
6465e4f358 | ||
|
|
4b43f96afe | ||
|
|
e088ef7e4e | ||
|
|
9e03af45e1 | ||
|
|
5bfd3445bb | ||
|
|
efff63043a | ||
|
|
c15cabc289 | ||
|
|
55a89c11bb | ||
|
|
c037d4135e | ||
|
|
25213f2004 | ||
|
|
d106520d22 | ||
|
|
7bddeb0ebd | ||
|
|
f7cd58ed2a | ||
|
|
53c625599a | ||
|
|
88ee4f482b | ||
|
|
3176b95323 | ||
|
|
46c60b36a0 | ||
|
|
d35ec9f5ae | ||
|
|
311927d5ea | ||
|
|
fb798501b9 | ||
|
|
99135c9b02 | ||
|
|
425b580f15 | ||
|
|
b658e68e65 | ||
|
|
b8e07bec77 | ||
|
|
344ea26ecc | ||
|
|
98cb4e4f2f | ||
|
|
07d89d204f | ||
|
|
7702a6dfcc | ||
|
|
4c009949b3 | ||
|
|
aa4ac3ec7c | ||
|
|
1807435339 | ||
|
|
55a8a95f79 | ||
|
|
503ea7965d | ||
|
|
88f4db1178 | ||
|
|
2df291ea91 | ||
|
|
5841525b4c | ||
|
|
532073d38e | ||
|
|
43547287b1 | ||
|
|
aa358df28e | ||
|
|
30fec27488 | ||
|
|
5e77b478dd | ||
|
|
6f71259822 | ||
|
|
74cc7ae95e | ||
|
|
7f12c8b355 | ||
|
|
6069f5f7e5 | ||
|
|
3e644f1652 | ||
|
|
3316a8bc47 | ||
|
|
270479c77d | ||
|
|
0f4558d775 | ||
|
|
9f5f090f0c | ||
|
|
5ffad160b1 | ||
|
|
d6a7743f26 | ||
|
|
9782e31ae5 | ||
|
|
f638860e90 | ||
|
|
b700cfac64 | ||
|
|
883175b8f5 | ||
|
|
ae697df4c9 | ||
|
|
d9cb00fcdc | ||
|
|
ee1b0f1cfa | ||
|
|
a740c96630 | ||
|
|
67bdeac434 | ||
|
|
1622591afd | ||
|
|
6cf660e622 | ||
|
|
9e14824249 | ||
|
|
76cb825660 | ||
|
|
341ba47d1c | ||
|
|
1fa33c029b | ||
|
|
bcf7d439f3 | ||
|
|
b9acf4d2ae | ||
|
|
ae7bf3dbae | ||
|
|
914c265afe | ||
|
|
a158655247 | ||
|
|
bc350af247 | ||
|
|
6062b7646c | ||
|
|
122d1a18df | ||
|
|
2ca006d82c |
19
.aiignore
Normal file
19
.aiignore
Normal file
@@ -0,0 +1,19 @@
|
||||
# An .aiignore file follows the same syntax as a .gitignore file.
|
||||
# .gitignore documentation: https://git-scm.com/docs/gitignore
|
||||
|
||||
# you can ignore files
|
||||
.DS_Store
|
||||
*.log
|
||||
*.tmp
|
||||
|
||||
# or folders
|
||||
dist/
|
||||
build/
|
||||
out/
|
||||
nginx/node_modules/
|
||||
nginx/static/
|
||||
db_backups/
|
||||
docker/eveai_logs/
|
||||
docker/logs/
|
||||
docker/minio/
|
||||
|
||||
48
.gitignore
vendored
48
.gitignore
vendored
@@ -12,3 +12,51 @@ docker/tenant_files/
|
||||
**/.DS_Store
|
||||
__pycache__
|
||||
**/__pycache__
|
||||
/.idea
|
||||
*.pyc
|
||||
common/.DS_Store
|
||||
common/__pycache__/__init__.cpython-312.pyc
|
||||
common/__pycache__/extensions.cpython-312.pyc
|
||||
common/models/__pycache__/__init__.cpython-312.pyc
|
||||
common/models/__pycache__/document.cpython-312.pyc
|
||||
common/models/__pycache__/interaction.cpython-312.pyc
|
||||
common/models/__pycache__/user.cpython-312.pyc
|
||||
common/utils/.DS_Store
|
||||
common/utils/__pycache__/__init__.cpython-312.pyc
|
||||
common/utils/__pycache__/celery_utils.cpython-312.pyc
|
||||
common/utils/__pycache__/nginx_utils.cpython-312.pyc
|
||||
common/utils/__pycache__/security.cpython-312.pyc
|
||||
common/utils/__pycache__/simple_encryption.cpython-312.pyc
|
||||
common/utils/__pycache__/template_filters.cpython-312.pyc
|
||||
config/.DS_Store
|
||||
config/__pycache__/__init__.cpython-312.pyc
|
||||
config/__pycache__/config.cpython-312.pyc
|
||||
config/__pycache__/logging_config.cpython-312.pyc
|
||||
eveai_app/.DS_Store
|
||||
eveai_app/__pycache__/__init__.cpython-312.pyc
|
||||
eveai_app/__pycache__/errors.cpython-312.pyc
|
||||
eveai_chat/.DS_Store
|
||||
migrations/.DS_Store
|
||||
migrations/public/.DS_Store
|
||||
scripts/.DS_Store
|
||||
scripts/__pycache__/run_eveai_app.cpython-312.pyc
|
||||
/eveai_repo.txt
|
||||
*repo.txt
|
||||
/docker/eveai_logs/
|
||||
/integrations/Wordpress/eveai_sync.zip
|
||||
/integrations/Wordpress/eveai-chat.zip
|
||||
/db_backups/
|
||||
/tests/interactive_client/specialist_client.log
|
||||
/.repopackignore
|
||||
/patched_packages/crewai/
|
||||
/docker/prometheus/data/
|
||||
/docker/grafana/data/
|
||||
/temp_requirements/
|
||||
/nginx/node_modules/
|
||||
/nginx/.parcel-cache/
|
||||
/nginx/static/
|
||||
/docker/build_logs/
|
||||
/content/.Ulysses-Group.plist
|
||||
/content/.Ulysses-Settings.plist
|
||||
/.python-version
|
||||
/q
|
||||
|
||||
8
.idea/.gitignore
generated
vendored
8
.idea/.gitignore
generated
vendored
@@ -1,8 +0,0 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
22
.idea/eveAI.iml
generated
22
.idea/eveAI.iml
generated
@@ -1,22 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="Flask">
|
||||
<option name="enabled" value="true" />
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv2" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.12 (eveai_dev)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TemplatesService">
|
||||
<option name="TEMPLATE_CONFIGURATION" value="Jinja2" />
|
||||
<option name="TEMPLATE_FOLDERS">
|
||||
<list>
|
||||
<option value="$MODULE_DIR$/templates" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
6
.idea/inspectionProfiles/profiles_settings.xml
generated
@@ -1,6 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
7
.idea/misc.xml
generated
@@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.12 (eveai_tbd)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (eveai_tbd)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
8
.idea/modules.xml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/TBD.iml" filepath="$PROJECT_DIR$/.idea/TBD.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/sqldialects.xml
generated
6
.idea/sqldialects.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SqlDialectMappings">
|
||||
<file url="PROJECT" dialect="PostgreSQL" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/vcs.xml
generated
6
.idea/vcs.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -1 +0,0 @@
|
||||
eveai_tbd
|
||||
342
CHANGELOG.md
Normal file
342
CHANGELOG.md
Normal file
@@ -0,0 +1,342 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to EveAI will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [2.3.0-alfa]
|
||||
|
||||
### Added
|
||||
- Introduction of Push Gateway for Prometheus
|
||||
- Introduction of Partner Models
|
||||
- Introduction of Tenant and Partner codes for more security
|
||||
- Introduction of 'Management Partner' type and additional 'Partner Admin'-role
|
||||
- Introduction of a technical services layer
|
||||
- Introduction of partner-specific configurations
|
||||
- Introduction of additional test environment
|
||||
- Introduction of strict no-overage usage
|
||||
- Introduction of LicensePeriod, Payments & Invoices
|
||||
- Introduction of Processed File Viewer
|
||||
- Introduction of Traicie Role Definition Specialist
|
||||
- Allow invocation of non-interactive specialists in administrative interface (eveai_app)
|
||||
- Introduction of advanced JSON editor
|
||||
- Introduction of ChatSession (Specialist Execution) follow-up in administrative interface
|
||||
- Introduce npm for javascript libraries usage and optimisations
|
||||
- Introduction of new top bar in administrative interface to show session defaults (removing old navbar buttons)
|
||||
-
|
||||
|
||||
### Changed
|
||||
- Add 'Register'-button to list views, replacing register menu-items
|
||||
- Add additional environment capabilities in docker
|
||||
- PDF Processor now uses Mistral OCR
|
||||
- Allow additional chunking mechanisms for very long chunks (in case of very large documents)
|
||||
- Allow for TrackedMistralAIEmbedding batching to allow for processing long documents
|
||||
- RAG & SPIN Specialist improvements
|
||||
- Move mail messaging from standard SMTP to Scaleway TEM mails
|
||||
- Improve mail layouts
|
||||
- Add functionality to add a default dictionary for dynamic forms
|
||||
- AI model choices defined by Ask Eve AI iso Tenant (replaces ModelVariables completely)
|
||||
- Improve HTML Processing
|
||||
- Pagination improvements
|
||||
- Update Material Kit Pro to latest version
|
||||
|
||||
### Removed
|
||||
- Repopack implementation ==> Using PyCharm's new AI capabilities instead
|
||||
|
||||
### Fixed
|
||||
- Synchronous vs Asynchronous behaviour in crewAI type specialists
|
||||
- Nasty dynamic boolean fields bug corrected
|
||||
- Several smaller bugfixes
|
||||
- Tasks & Tools editors finished
|
||||
|
||||
### Security
|
||||
- In case of vulnerabilities.
|
||||
|
||||
## [2.2.0-alfa]
|
||||
|
||||
### Added
|
||||
- Mistral AI as main provider for embeddings, chains and specialists
|
||||
- Usage measuring for specialists
|
||||
- RAG from chain to specialist technology
|
||||
- Dossier catalog management possibilities added to eveai_app
|
||||
- Asset definition (Paused - other priorities)
|
||||
- Prometheus and Grafana
|
||||
- Add prometheus monitoring to business events
|
||||
- Asynchronous execution of specialists
|
||||
|
||||
### Changed
|
||||
- Moved choice for AI providers / models to specialists and prompts
|
||||
- Improve RAG to not repeat historic answers
|
||||
- Fixed embedding model, no more choices allowed
|
||||
- clean url (of tracking parameters) before adding it to a catalog
|
||||
|
||||
### Deprecated
|
||||
- For soon-to-be removed features.
|
||||
|
||||
### Removed
|
||||
- Add Multiple URLs removed from menu
|
||||
- Old Specialist items removed from interaction menu
|
||||
-
|
||||
|
||||
### Fixed
|
||||
- Set default language when registering Documents or URLs.
|
||||
|
||||
### Security
|
||||
- In case of vulnerabilities.
|
||||
|
||||
## [2.1.0-alfa]
|
||||
|
||||
### Added
|
||||
- Zapier Refresh Document
|
||||
- SPIN Specialist definition - from start to finish
|
||||
- Introduction of startup scripts in eveai_app
|
||||
- Caching for all configurations added
|
||||
- Caching for processed specialist configurations
|
||||
- Caching for specialist history
|
||||
- Augmented Specialist Editor, including Specialist graphic presentation
|
||||
- Introduction of specialist_execution_api, introducting SSE
|
||||
- Introduction of crewai framework for specialist implementation
|
||||
- Test app for testing specialists - also serves as a sample client application for SSE
|
||||
-
|
||||
|
||||
### Changed
|
||||
- Improvement of startup of applications using gevent, and better handling and scaling of multiple connections
|
||||
- STANDARD_RAG Specialist improvement
|
||||
-
|
||||
|
||||
### Deprecated
|
||||
- eveai_chat - using sockets - will be replaced with new specialist_execution_api and SSE
|
||||
|
||||
## [2.0.1-alfa]
|
||||
|
||||
### Added
|
||||
- Zapîer Integration (partial - only adding files).
|
||||
- Addition of general chunking parameters (chunking_heading_level and chunking_patterns)
|
||||
- Addition of DocX and markdown Processor Types
|
||||
|
||||
### Changed
|
||||
- For changes in existing functionality.
|
||||
|
||||
### Deprecated
|
||||
- For soon-to-be removed features.
|
||||
|
||||
### Removed
|
||||
- For now removed features.
|
||||
|
||||
### Fixed
|
||||
- Ensure the RAG Specialist is using the detailed_question
|
||||
- Wordpress Chat Plugin: languages dropdown filled again
|
||||
- OpenAI update - proxies no longer supported
|
||||
- Build & Release script for Wordpress Plugins (including end user download folder)
|
||||
|
||||
### Security
|
||||
- In case of vulnerabilities.
|
||||
|
||||
## [2.0.0-alfa]
|
||||
|
||||
### Added
|
||||
- Introduction of dynamic Retrievers & Specialists
|
||||
- Introduction of dynamic Processors
|
||||
- Introduction of caching system
|
||||
- Introduction of a better template manager
|
||||
- Modernisation of external API/Socket authentication using projects
|
||||
- Creation of new eveai_chat WordPress plugin to support specialists
|
||||
|
||||
### Changed
|
||||
- Update of eveai_sync WordPress plugin
|
||||
|
||||
### Fixed
|
||||
- Set default language when registering Documents or URLs.
|
||||
|
||||
### Security
|
||||
- Security improvements to Docker images
|
||||
|
||||
## [1.0.14-alfa]
|
||||
|
||||
### Added
|
||||
- New release script added to tag images with release number
|
||||
- Allow the addition of multiple types of Catalogs
|
||||
- Generic functionality to enable dynamic fields
|
||||
- Addition of Retrievers to allow for smart collection of information in Catalogs
|
||||
- Add dynamic fields to Catalog / Retriever / DocumentVersion
|
||||
|
||||
### Changed
|
||||
- Processing parameters defined at Catalog level iso Tenant level
|
||||
- Reroute 'blank' paths to 'admin'
|
||||
|
||||
### Deprecated
|
||||
- For soon-to-be removed features.
|
||||
|
||||
### Removed
|
||||
- For now removed features.
|
||||
|
||||
### Fixed
|
||||
- Set default language when registering Documents or URLs.
|
||||
|
||||
### Security
|
||||
- In case of vulnerabilities.
|
||||
|
||||
## [1.0.13-alfa]
|
||||
|
||||
### Added
|
||||
- Finished Catalog introduction
|
||||
- Reinitialization of WordPress site for syncing
|
||||
|
||||
### Changed
|
||||
- Modification of WordPress Sync Component
|
||||
- Cleanup of attributes in Tenant
|
||||
|
||||
### Fixed
|
||||
- Overall bugfixes as result from the Catalog introduction
|
||||
|
||||
## [1.0.12-alfa]
|
||||
|
||||
### Added
|
||||
- Added Catalog functionality
|
||||
|
||||
### Changed
|
||||
- For changes in existing functionality.
|
||||
|
||||
### Deprecated
|
||||
- For soon-to-be removed features.
|
||||
|
||||
### Removed
|
||||
- For now removed features.
|
||||
|
||||
### Fixed
|
||||
- Set default language when registering Documents or URLs.
|
||||
|
||||
### Security
|
||||
- In case of vulnerabilities.
|
||||
|
||||
## [1.0.11-alfa]
|
||||
|
||||
### Added
|
||||
- License Usage Calculation realised
|
||||
- View License Usages
|
||||
- Celery Beat container added
|
||||
- First schedule in Celery Beat for calculating usage (hourly)
|
||||
|
||||
### Changed
|
||||
- repopack can now split for different components
|
||||
|
||||
### Fixed
|
||||
- Various fixes as consequence of changing file_location / file_name ==> bucket_name / object_name
|
||||
- Celery Routing / Queuing updated
|
||||
|
||||
## [1.0.10-alfa]
|
||||
|
||||
### Added
|
||||
- BusinessEventLog monitoring using Langchain native code
|
||||
|
||||
### Changed
|
||||
- Allow longer audio files (or video) to be uploaded and processed
|
||||
- Storage and Embedding usage now expressed in MiB iso tokens (more logical)
|
||||
- Views for License / LicenseTier
|
||||
|
||||
### Removed
|
||||
- Portkey removed for monitoring usage
|
||||
|
||||
## [1.0.9-alfa] - 2024/10/01
|
||||
|
||||
### Added
|
||||
- Business Event tracing (eveai_workers & eveai_chat_workers)
|
||||
- Flower Container added for monitoring
|
||||
|
||||
### Changed
|
||||
- Healthcheck improvements
|
||||
- model_utils turned into a class with lazy loading
|
||||
|
||||
### Deprecated
|
||||
- For soon-to-be removed features.
|
||||
|
||||
### Removed
|
||||
- For now removed features.
|
||||
|
||||
### Fixed
|
||||
- Set default language when registering Documents or URLs.
|
||||
|
||||
## [1.0.8-alfa] - 2024-09-12
|
||||
|
||||
### Added
|
||||
- Tenant type defined to allow for active, inactive, demo ... tenants
|
||||
- Search and filtering functionality on Tenants
|
||||
- Implementation of health checks (1st version)
|
||||
- Provision for Prometheus monitoring (no implementation yet)
|
||||
- Refine audio_processor and srt_processor to reduce duplicate code and support larger files
|
||||
- Introduction of repopack to reason in LLMs about the code
|
||||
|
||||
### Fixed
|
||||
- Refine audio_processor and srt_processor to reduce duplicate code and support larger files
|
||||
|
||||
## [1.0.7-alfa] - 2024-09-12
|
||||
|
||||
### Added
|
||||
- Full Document API allowing for creation, updating and invalidation of documents.
|
||||
- Metadata fields (JSON) added to DocumentVersion, allowing end-users to add structured information
|
||||
- Wordpress plugin eveai_sync to synchronize Wordpress content with EveAI
|
||||
|
||||
### Fixed
|
||||
- Maximal deduplication of code between views and api in document_utils.py
|
||||
|
||||
## [1.0.6-alfa] - 2024-09-03
|
||||
|
||||
### Fixed
|
||||
- Problems with tenant scheme migrations - may have to be revisited
|
||||
- Correction of default language settings when uploading docs or URLs
|
||||
- Addition of a CHANGELOG.md file
|
||||
|
||||
## [1.0.5-alfa] - 2024-09-02
|
||||
|
||||
### Added
|
||||
- Allow chatwidget to connect to multiple servers (e.g. development and production)
|
||||
- Start implementation of API
|
||||
- Add API-key functionality to tenants
|
||||
- Deduplication of API and Document view code
|
||||
- Allow URL addition to accept all types of files, not just HTML
|
||||
- Allow new file types upload: srt, mp3, ogg, mp4
|
||||
- Improve processing of different file types using Processor classes
|
||||
|
||||
### Removed
|
||||
- Removed direct upload of Youtube URLs, due to continuous changes in Youtube website
|
||||
|
||||
## [1.0.4-alfa] - 2024-08-27
|
||||
Skipped
|
||||
|
||||
## [1.0.3-alfa] - 2024-08-27
|
||||
|
||||
### Added
|
||||
- Refinement of HTML processing - allow for excluded classes and elements.
|
||||
- Allow for multiple instances of Evie on 1 website (pure + Wordpress plugin)
|
||||
|
||||
### Changed
|
||||
- PDF Processing extracted in new PDF Processor class.
|
||||
- Allow for longer and more complex PDFs to be uploaded.
|
||||
|
||||
## [1.0.2-alfa] - 2024-08-22
|
||||
|
||||
### Fixed
|
||||
- Bugfix for ResetPasswordForm in config.py
|
||||
|
||||
## [1.0.1-alfa] - 2024-08-21
|
||||
|
||||
### Added
|
||||
- Full Document Version Overview
|
||||
|
||||
### Changed
|
||||
- Improvements to user creation and registration, renewal of passwords, ...
|
||||
|
||||
## [1.0.0-alfa] - 2024-08-16
|
||||
|
||||
### Added
|
||||
- Initial release of the project.
|
||||
|
||||
### Changed
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
- None
|
||||
|
||||
[Unreleased]: https://github.com/username/repo/compare/v1.0.0...HEAD
|
||||
[1.0.0]: https://github.com/username/repo/releases/tag/v1.0.0
|
||||
85
Evie Overview.md
Normal file
85
Evie Overview.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Evie Overview
|
||||
|
||||
Owner: pieter Laroy
|
||||
|
||||
# Introduction
|
||||
|
||||
The Evie project (developed by AskEveAI) is a SAAS product that enables SMEs to easily introduce AI optimisations for both internal and external use. There are two big concepts:
|
||||
|
||||
- Catalogs: these allow tenants to store information about their organisations or enterprises
|
||||
- Specialists: these allow tenants to perform logic into their processes, communications, …
|
||||
|
||||
As such, we could say we have an advanced RAG system tenants can use to optimise their workings.
|
||||
|
||||
## Multi-tenant
|
||||
|
||||
The application has a multi-tenant setup built in. This is reflected in:
|
||||
|
||||
- The Database:
|
||||
- We have 1 public schema, in which general information is defined such as tenants, their users, domains, licenses, …
|
||||
- We have a schema (named 1, 2, …) for each of the tenants defined in the system, containing all information on the tenant’s catalogs & documents, specialists & interactions, …
|
||||
- File Storage
|
||||
- We use S3-compatible storage
|
||||
- A bucket is defined for each tenant, storing their specific documents, assets, …
|
||||
|
||||
That way, general information required for the operation of Evie is stored in the public schema, and specific and potentially sensitive information is nicely stored behind a Chinese wall for each of the tenants.
|
||||
|
||||
## Partners
|
||||
|
||||
We started to define the concept of a partner. This allows us to have partners that introduce tenants to Evie, or offer them additional functionality (specialists) or knowledge (catalogs). This concept is in an early stage at this point.
|
||||
|
||||
## Domains
|
||||
|
||||
In order to ensure a structured approach, we have defined several domains in the project:
|
||||
|
||||
- **User**: the user domain is used to store all data on partners, tenants, actual users.
|
||||
- **Document**: the document domain is used to store all information on catalogs, documents, how to process documents, …
|
||||
- **Interaction**: This domain allows us to define specialists, agents, … and to interact with the specialists and agents.
|
||||
- **Entitlements**: This domain defines all license information, usage, …
|
||||
|
||||
# Project Structure
|
||||
|
||||
## Common
|
||||
|
||||
The common folder contains code that is used in different components of the system. It contains the following important pieces:
|
||||
|
||||
- **models**: in the models folder you can find the SQLAlchemy models used throughout the application. These models are organised in their relevant domains.
|
||||
- **eveai_model**: some classes to handle usage, wrappers around standard LLM clients
|
||||
- **langchain**: similar to eveai_model, but in the langchain library
|
||||
- **services**: I started to define services to define reusable functionality in the system. There again are defined in their respective domains
|
||||
- **utils**: a whole bunch of utility classes. Some should get converted to services classes in the future
|
||||
- **utils/cache**: contains code for caching different elements in the application
|
||||
|
||||
## config
|
||||
|
||||
The config folder contains quite some configuration data (as the name suggests):
|
||||
|
||||
- **config.py**: general configuration
|
||||
- **logging_config.py**: definition of logging files
|
||||
- **model_config.py**: obsolete
|
||||
- **type_defs**: contains the lists of definitions for several types used throughout the application. E.g. processor_types, specialist_types, …
|
||||
- **All other folders**: detailed configuration of all the types defined in type_defs.
|
||||
|
||||
## docker
|
||||
|
||||
The docker folder contains the configuration and scripts used for all operations on configuring and building containers, distributing containers, …
|
||||
|
||||
## eveai_… folders
|
||||
|
||||
These are different components (containerized) of our application:
|
||||
|
||||
- **eveai_api**: The API of our application.
|
||||
- **eveai_app**: The administrative interface of our application.
|
||||
- **eveai_beat**: a means to install batch processes for our application.
|
||||
- **eveai_chat**: obsolete at this moment
|
||||
- **eveai_chat_workers**: celery based invocation of our specialists
|
||||
- **eveai_client**: newly added. A desktop client to invoking specialists.
|
||||
- **eveai_entitlements**: celery based approach to handling business events, measuring and updating usage, …
|
||||
- **eveai_workers**: celery based approach to filling catalogs with documents (embedding)
|
||||
|
||||
## Remaining folders
|
||||
|
||||
- **integrations**: integrations to e.g. Wordpress and Zapier.
|
||||
- **migrations**: SQLAlchemy database migration files (for public and tenant schema)
|
||||
- **nginx**: configuration and static files for nginx
|
||||
- **scripts**: various scripts used to start up components, to perform database operations, …
|
||||
32
check_running_services.sh
Normal file
32
check_running_services.sh
Normal file
@@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
# Diagnostic script to check what services are running
|
||||
|
||||
echo "=== KIND CLUSTER STATUS ==="
|
||||
echo "Namespaces:"
|
||||
kubectl get namespaces | grep eveai
|
||||
|
||||
echo -e "\nPods in eveai-dev:"
|
||||
kubectl get pods -n eveai-dev
|
||||
|
||||
echo -e "\nServices in eveai-dev:"
|
||||
kubectl get services -n eveai-dev
|
||||
|
||||
echo -e "\n=== TEST CONTAINERS STATUS ==="
|
||||
echo "Running test containers:"
|
||||
podman ps | grep eveai_test
|
||||
|
||||
echo -e "\n=== PORT ANALYSIS ==="
|
||||
echo "What's listening on port 3080:"
|
||||
lsof -i :3080 2>/dev/null || echo "Nothing found"
|
||||
|
||||
echo -e "\nWhat's listening on port 4080:"
|
||||
lsof -i :4080 2>/dev/null || echo "Nothing found"
|
||||
|
||||
echo -e "\n=== SOLUTION ==="
|
||||
echo "The application you see is from TEST CONTAINERS (6 days old),"
|
||||
echo "NOT from the Kind cluster (3 minutes old)."
|
||||
echo ""
|
||||
echo "To test Kind cluster:"
|
||||
echo "1. Stop test containers: podman stop eveai_test_nginx_1 eveai_test_eveai_app_1"
|
||||
echo "2. Deploy Kind services: kup-all-structured"
|
||||
echo "3. Restart test containers if needed"
|
||||
BIN
common/.DS_Store
vendored
BIN
common/.DS_Store
vendored
Binary file not shown.
Binary file not shown.
Binary file not shown.
11
common/eveai_model/eveai_embedding_base.py
Normal file
11
common/eveai_model/eveai_embedding_base.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class EveAIEmbeddings:
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self.embed_documents([text])[0]
|
||||
136
common/eveai_model/tracked_mistral_embeddings.py
Normal file
136
common/eveai_model/tracked_mistral_embeddings.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from flask import current_app
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
from typing import List, Any
|
||||
import time
|
||||
|
||||
from common.eveai_model.eveai_embedding_base import EveAIEmbeddings
|
||||
from common.utils.business_event_context import current_event
|
||||
from mistralai import Mistral
|
||||
|
||||
|
||||
class TrackedMistralAIEmbeddings(EveAIEmbeddings):
    """Mistral embeddings backend that batches API calls and logs token/time
    metrics for every batch to the current business event."""

    def __init__(self, model: str = "mistral_embed", batch_size: int = 10):
        """
        Initialize the TrackedMistralAIEmbeddings class.

        Args:
            model: The embedding model to use
            batch_size: Maximum number of texts to send in a single API call
        """
        # NOTE(review): Mistral's published embedding model id is "mistral-embed";
        # confirm the "mistral_embed" default is accepted by the API.
        api_key = current_app.config['MISTRAL_API_KEY']
        self.client = Mistral(
            api_key=api_key
        )
        self.model = model
        self.batch_size = batch_size
        super().__init__()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """
        Embed a list of texts, processing in batches to avoid API limitations.

        Args:
            texts: A list of texts to embed

        Returns:
            A list of embeddings, one for each input text
        """
        if not texts:
            return []

        all_embeddings = []

        # Process texts in batches
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1  # 1-based batch number, used in metric labels

            start_time = time.time()
            try:
                result = self.client.embeddings.create(
                    model=self.model,
                    inputs=batch
                )
                end_time = time.time()
                batch_time = end_time - start_time

                batch_embeddings = [embedding.embedding for embedding in result.data]
                all_embeddings.extend(batch_embeddings)

                # Log metrics for this batch
                # NOTE(review): assumes the embeddings usage object exposes
                # completion_tokens — confirm against the Mistral SDK response model.
                metrics = {
                    'total_tokens': result.usage.total_tokens,
                    'prompt_tokens': result.usage.prompt_tokens,
                    'completion_tokens': result.usage.completion_tokens,
                    'time_elapsed': batch_time,
                    'interaction_type': 'Embedding',
                    'batch': batch_num,
                    'batch_size': len(batch)
                }
                current_event.log_llm_metrics(metrics)

                # If processing multiple batches, add a small delay to avoid rate limits
                if len(texts) > self.batch_size and i + self.batch_size < len(texts):
                    time.sleep(0.25)  # 250ms pause between batches

            except Exception as e:
                current_app.logger.error(f"Error in embedding batch {batch_num}: {str(e)}")
                # If a batch fails, try to process each text individually
                # (recovery pass: isolates the failing text instead of losing the batch)
                for j, text in enumerate(batch):
                    try:
                        single_start_time = time.time()
                        single_result = self.client.embeddings.create(
                            model=self.model,
                            inputs=[text]
                        )
                        single_end_time = time.time()

                        # Add the single embedding
                        single_embedding = single_result.data[0].embedding
                        all_embeddings.append(single_embedding)

                        # Log metrics for this individual embedding
                        single_metrics = {
                            'total_tokens': single_result.usage.total_tokens,
                            'prompt_tokens': single_result.usage.prompt_tokens,
                            'completion_tokens': single_result.usage.completion_tokens,
                            'time_elapsed': single_end_time - single_start_time,
                            'interaction_type': 'Embedding',
                            'batch': f"{batch_num}-recovery-{j}",
                            'batch_size': 1
                        }
                        current_event.log_llm_metrics(single_metrics)

                    except Exception as inner_e:
                        current_app.logger.error(f"Failed to embed individual text at index {i + j}: {str(inner_e)}")
                        # Add a zero vector as a placeholder for failed embeddings
                        # so the output list stays aligned 1:1 with the input texts.
                        # Use the correct dimensionality for the model (1024 for mistral_embed)
                        embedding_dim = 1024
                        all_embeddings.append([0.0] * embedding_dim)

        total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
        current_app.logger.info(f"Embedded {len(texts)} texts in {total_batches} batches")

        return all_embeddings
|
||||
|
||||
# def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
# start_time = time.time()
|
||||
# result = self.client.embeddings.create(
|
||||
# model=self.model,
|
||||
# inputs=texts
|
||||
# )
|
||||
# end_time = time.time()
|
||||
#
|
||||
# metrics = {
|
||||
# 'total_tokens': result.usage.total_tokens,
|
||||
# 'prompt_tokens': result.usage.prompt_tokens, # For embeddings, all tokens are prompt tokens
|
||||
# 'completion_tokens': result.usage.completion_tokens,
|
||||
# 'time_elapsed': end_time - start_time,
|
||||
# 'interaction_type': 'Embedding',
|
||||
# }
|
||||
# current_event.log_llm_metrics(metrics)
|
||||
#
|
||||
# embeddings = [embedding.embedding for embedding in result.data]
|
||||
#
|
||||
# return embeddings
|
||||
|
||||
53
common/eveai_model/tracked_mistral_ocr_client.py
Normal file
53
common/eveai_model/tracked_mistral_ocr_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import re
|
||||
import time
|
||||
|
||||
from flask import current_app
|
||||
from mistralai import Mistral
|
||||
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedMistralOcrClient:
    """Mistral OCR client that converts PDFs to markdown and logs page-count
    and timing metrics to the current business event."""

    def __init__(self):
        api_key = current_app.config['MISTRAL_API_KEY']
        self.client = Mistral(
            api_key=api_key,
        )
        self.model = "mistral-ocr-latest"

    def _get_title(self, markdown):
        """Return the text of the first level-1 markdown heading, or None."""
        heading = re.search(r'^# (.+)', markdown, re.MULTILINE)
        if heading is None:
            return None
        return heading.group(1).strip()

    def process_pdf(self, file_name, file_content):
        """OCR a PDF and return (concatenated_markdown, extracted_title)."""
        started = time.time()

        # Upload the raw PDF, then obtain a signed URL the OCR endpoint can read.
        uploaded_pdf = self.client.files.upload(
            file={
                "file_name": file_name,
                "content": file_content
            },
            purpose="ocr"
        )
        signed_url = self.client.files.get_signed_url(file_id=uploaded_pdf.id)

        ocr_response = self.client.ocr.process(
            model=self.model,
            document={
                "type": "document_url",
                "document_url": signed_url.url
            },
            include_image_base64=False
        )

        pages = ocr_response.pages
        all_markdown = " ".join(page.markdown for page in pages)
        title = self._get_title(all_markdown)
        finished = time.time()

        current_event.log_llm_metrics({
            'nr_of_pages': len(pages),
            'time_elapsed': finished - started,
            'interaction_type': 'OCR',
        })

        return all_markdown, title
|
||||
@@ -2,15 +2,16 @@ from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_bootstrap import Bootstrap
|
||||
from flask_security import Security
|
||||
from flask_mailman import Mail
|
||||
from flask_login import LoginManager
|
||||
from flask_cors import CORS
|
||||
from flask_socketio import SocketIO
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_session import Session
|
||||
from flask_wtf import CSRFProtect
|
||||
from flask_restx import Api
|
||||
from prometheus_flask_exporter import PrometheusMetrics
|
||||
|
||||
from .utils.nginx_utils import prefixed_url_for
|
||||
from .utils.cache.eveai_cache_manager import EveAICacheManager
|
||||
from .utils.content_utils import ContentManager
|
||||
from .utils.simple_encryption import SimpleEncryption
|
||||
from .utils.minio_utils import MinioClient
|
||||
|
||||
@@ -21,14 +22,14 @@ migrate = Migrate()
|
||||
# Shared Flask extension singletons.
# NOTE(review): presumably each is bound to the app later via init_app() in the
# application factory — confirm against create_app().
bootstrap = Bootstrap()
csrf = CSRFProtect()
security = Security()
mail = Mail()
login_manager = LoginManager()
cors = CORS()
socketio = SocketIO()
jwt = JWTManager()
session = Session()

# kms_client = JosKMSClient.from_service_account_json('config/gc_sa_eveai.json')

api_rest = Api()
simple_encryption = SimpleEncryption()
minio_client = MinioClient()
metrics = PrometheusMetrics.for_app_factory()
cache_manager = EveAICacheManager()
content_manager = ContentManager()
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.interaction import ChatSession, Interaction
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
|
||||
|
||||
class EveAIHistoryRetriever(BaseRetriever):
    """LangChain retriever that returns the full chat history of one session,
    formatted as alternating HUMAN/AI text blocks, oldest first.

    Note: the incoming query is only logged; the whole session history is
    returned regardless of query content.
    """
    # Pydantic fields (BaseRetriever is a pydantic model); both are required.
    model_variables: Dict[str, Any] = Field(...)
    session_id: str = Field(...)

    def __init__(self, model_variables: Dict[str, Any], session_id: str):
        super().__init__()
        self.model_variables = model_variables
        self.session_id = session_id

    def _get_relevant_documents(self, query: str):
        """Return the session's interactions as formatted strings; empty list
        on database error (after rolling back the session)."""
        current_app.logger.debug(f'Retrieving history of interactions for query: {query}')

        try:
            # All interactions belonging to this chat session, ascending by id
            # (id order stands in for chronological order).
            query_obj = (
                db.session.query(Interaction)
                .join(ChatSession, Interaction.chat_session_id == ChatSession.id)
                .filter(ChatSession.session_id == self.session_id)
                .order_by(asc(Interaction.id))
            )

            interactions = query_obj.all()

            result = []
            for interaction in interactions:
                result.append(f'HUMAN:\n{interaction.detailed_question}\n\nAI: \n{interaction.answer}\n\n')

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving history of interactions: {e}')
            db.session.rollback()
            return []

        return result
|
||||
@@ -1,129 +0,0 @@
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from sqlalchemy import func, and_, or_, desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.document import Document, DocumentVersion
|
||||
from common.utils.datetime_utils import get_date_in_timezone
|
||||
|
||||
|
||||
class EveAIRetriever(BaseRetriever):
    """LangChain retriever performing pgvector cosine-similarity search over the
    latest versions of validity-filtered documents.

    When the tenant has rag_tuning enabled, detailed diagnostics (similarity of
    every valid document, the generated SQL, the retrieved rows) are written to
    a dedicated rag_tuning logger.
    """
    # Pydantic fields (BaseRetriever is a pydantic model); both are required.
    model_variables: Dict[str, Any] = Field(...)
    tenant_info: Dict[str, Any] = Field(...)

    def __init__(self, model_variables: Dict[str, Any], tenant_info: Dict[str, Any]):
        super().__init__()
        self.model_variables = model_variables
        self.tenant_info = tenant_info

    def _get_relevant_documents(self, query: str):
        """Return chunks relevant to *query* as 'SOURCE: <id>\\n\\n<chunk>' strings;
        empty list on database error."""

        current_app.logger.debug(f'Retrieving relevant documents for query: {query}')
        query_embedding = self._get_query_embedding(query)
        # Retrieval parameters are injected via model_variables:
        #   embedding_db_model: SQLAlchemy model holding the pgvector column
        #   similarity_threshold: minimum cosine similarity to keep a chunk
        #   k: maximum number of chunks to return
        db_class = self.model_variables['embedding_db_model']
        similarity_threshold = self.model_variables['similarity_threshold']
        k = self.model_variables['k']

        if self.tenant_info['rag_tuning']:
            try:
                current_date = get_date_in_timezone(self.tenant_info['timezone'])
                current_app.rag_tuning_logger.debug(f'Current date: {current_date}\n')

                # Debug query to show similarity for all valid documents (without chunk text)
                debug_query = (
                    db.session.query(
                        Document.id.label('document_id'),
                        DocumentVersion.id.label('version_id'),
                        db_class.id.label('embedding_id'),
                        # cosine_distance is a distance; 1 - distance gives similarity
                        (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
                    )
                    .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                    .join(Document, DocumentVersion.doc_id == Document.id)
                    .filter(
                        # NULL validity bounds mean "always valid"
                        or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                        or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date)
                    )
                    .order_by(desc('similarity'))
                )

                debug_results = debug_query.all()

                current_app.logger.debug("Debug: Similarity for all valid documents:")
                for row in debug_results:
                    current_app.rag_tuning_logger.debug(f"Doc ID: {row.document_id}, "
                                                        f"Version ID: {row.version_id}, "
                                                        f"Embedding ID: {row.embedding_id}, "
                                                        f"Similarity: {row.similarity}")
                current_app.rag_tuning_logger.debug(f'---------------------------------------\n')
            except SQLAlchemyError as e:
                current_app.logger.error(f'Error generating overview: {e}')
                db.session.rollback()

        if self.tenant_info['rag_tuning']:
            current_app.rag_tuning_logger.debug(f'Parameters for Retrieval of documents: \n')
            current_app.rag_tuning_logger.debug(f'Similarity Threshold: {similarity_threshold}\n')
            current_app.rag_tuning_logger.debug(f'K: {k}\n')
            current_app.rag_tuning_logger.debug(f'---------------------------------------\n')

        try:
            current_date = get_date_in_timezone(self.tenant_info['timezone'])
            # Subquery to find the latest version of each document
            subquery = (
                db.session.query(
                    DocumentVersion.doc_id,
                    func.max(DocumentVersion.id).label('latest_version_id')
                )
                .group_by(DocumentVersion.doc_id)
                .subquery()
            )
            # Main query to filter embeddings: latest version only, date-valid,
            # above the similarity threshold, best k first.
            query_obj = (
                db.session.query(db_class,
                                 (1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity'))
                .join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
                .join(Document, DocumentVersion.doc_id == Document.id)
                .join(subquery, DocumentVersion.id == subquery.c.latest_version_id)
                .filter(
                    or_(Document.valid_from.is_(None), func.date(Document.valid_from) <= current_date),
                    or_(Document.valid_to.is_(None), func.date(Document.valid_to) >= current_date),
                    (1 - db_class.embedding.cosine_distance(query_embedding)) > similarity_threshold
                )
                .order_by(desc('similarity'))
                .limit(k)
            )

            if self.tenant_info['rag_tuning']:
                current_app.rag_tuning_logger.debug(f'Query executed for Retrieval of documents: \n')
                current_app.rag_tuning_logger.debug(f'{query_obj.statement}\n')
                current_app.rag_tuning_logger.debug(f'---------------------------------------\n')

            res = query_obj.all()

            if self.tenant_info['rag_tuning']:
                current_app.rag_tuning_logger.debug(f'Retrieved {len(res)} relevant documents \n')
                current_app.rag_tuning_logger.debug(f'Data retrieved: \n')
                current_app.rag_tuning_logger.debug(f'{res}\n')
                current_app.rag_tuning_logger.debug(f'---------------------------------------\n')

            result = []
            for doc in res:
                # doc is a (embedding_row, similarity) tuple from the query above
                if self.tenant_info['rag_tuning']:
                    current_app.rag_tuning_logger.debug(f'Document ID: {doc[0].id} - Distance: {doc[1]}\n')
                    current_app.rag_tuning_logger.debug(f'Chunk: \n {doc[0].chunk}\n\n')
                result.append(f'SOURCE: {doc[0].id}\n\n{doc[0].chunk}\n\n')

        except SQLAlchemyError as e:
            current_app.logger.error(f'Error retrieving relevant documents: {e}')
            db.session.rollback()
            return []
        return result

    def _get_query_embedding(self, query: str):
        """Embed the query text with the configured embedding model."""
        embedding_model = self.model_variables['embedding_model']
        query_embedding = embedding_model.embed_query(query)
        return query_embedding
|
||||
Binary file not shown.
Binary file not shown.
48
common/langchain/llm_metrics_handler.py
Normal file
48
common/langchain/llm_metrics_handler.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import time
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from typing import Dict, Any, List
|
||||
from langchain.schema import LLMResult
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class LLMMetricsHandler(BaseCallbackHandler):
    """LangChain callback that measures token usage and wall-clock time of a
    single LLM call and forwards the metrics to the current business event.

    The handler resets itself after every call, so one instance can be reused
    across sequential LLM invocations.
    """

    def __init__(self):
        self.total_tokens: int = 0
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.start_time: float = 0
        self.end_time: float = 0
        self.total_time: float = 0

    def reset(self):
        """Zero all counters and timers so the next call starts clean."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.start_time = 0
        self.end_time = 0
        self.total_time = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Record the start timestamp of the LLM call."""
        self.start_time = time.time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token usage, log the metrics, and reset for the next call."""
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time

        # LLMResult.llm_output is Optional in LangChain; some providers return
        # None, which previously raised AttributeError here. Treat None as "no
        # usage reported".
        usage = (response.llm_output or {}).get('token_usage', {})
        self.prompt_tokens += usage.get('prompt_tokens', 0)
        self.completion_tokens += usage.get('completion_tokens', 0)
        self.total_tokens = self.prompt_tokens + self.completion_tokens

        metrics = self.get_metrics()
        current_event.log_llm_metrics(metrics)
        self.reset()  # Reset for the next call

    def get_metrics(self) -> Dict[str, int | float]:
        """Return the collected metrics in the business-event logging format."""
        return {
            'total_tokens': self.total_tokens,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            'time_elapsed': self.total_time,
            'interaction_type': 'LLM',
        }
|
||||
23
common/langchain/outputs/base.py
Normal file
23
common/langchain/outputs/base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Output Schema Management - common/langchain/outputs/base.py
|
||||
from typing import Dict, Type, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseSpecialistOutput(BaseModel):
    """Base class for all specialist outputs.

    Concrete specialist schemas subclass this and declare their fields; the
    OutputRegistry maps specialist type names to these subclasses.
    """
    pass
|
||||
|
||||
|
||||
class OutputRegistry:
    """Registry mapping specialist type names to their output schema classes."""

    # Class-level mapping shared by all callers; mutated only via register().
    _schemas: Dict[str, Type[BaseSpecialistOutput]] = {}

    @classmethod
    def register(cls, specialist_type: str, schema_class: Type[BaseSpecialistOutput]):
        """Associate *schema_class* with *specialist_type* (overwrites silently)."""
        cls._schemas[specialist_type] = schema_class

    @classmethod
    def get_schema(cls, specialist_type: str) -> Type[BaseSpecialistOutput]:
        """Return the schema registered for *specialist_type*.

        Raises ValueError when no schema has been registered for that type.
        """
        schema = cls._schemas.get(specialist_type)
        if schema is None:
            raise ValueError(f"No output schema registered for {specialist_type}")
        return schema
|
||||
22
common/langchain/outputs/rag.py
Normal file
22
common/langchain/outputs/rag.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# RAG Specialist Output - common/langchain/outputs/rag.py
|
||||
from typing import List
|
||||
from pydantic import Field
|
||||
from .base import BaseSpecialistOutput
|
||||
|
||||
|
||||
class RAGOutput(BaseSpecialistOutput):
    """Output schema for the RAG specialist.

    Structured answer produced from retrieved sources, including which sources
    were used and whether they sufficed.
    """
    # The original class body contained a second bare string literal after the
    # docstring ("Default docstring - to be replaced..."); that statement was a
    # dead no-op and has been removed.

    answer: str = Field(
        ...,
        description="The answer to the user question, based on the given sources",
    )
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources that were used to generate the answer"
    )
    insufficient_info: bool = Field(
        False,  # Default value is set to False
        description="A boolean indicating whether given sources were sufficient or not to generate the answer"
    )
|
||||
47
common/langchain/persistent_llm_metrics_handler.py
Normal file
47
common/langchain/persistent_llm_metrics_handler.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import time
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from typing import Dict, Any, List
|
||||
from langchain.schema import LLMResult
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class PersistentLLMMetricsHandler(BaseCallbackHandler):
    """Metrics handler that allows metrics to be retrieved from within any call.

    Unlike LLMMetricsHandler, this variant does not auto-log or auto-reset on
    llm_end: counters accumulate until reset() is called, so metrics remain
    readable for purposes other than business-event logging.
    """

    def __init__(self):
        self.total_tokens: int = 0
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.start_time: float = 0
        self.end_time: float = 0
        self.total_time: float = 0

    def reset(self):
        """Zero all counters and timers."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.start_time = 0
        self.end_time = 0
        self.total_time = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Record the start timestamp of the LLM call."""
        self.start_time = time.time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Accumulate token usage; metrics stay available via get_metrics()."""
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time

        # LLMResult.llm_output is Optional in LangChain; guard against None
        # (previously raised AttributeError for providers that omit it).
        usage = (response.llm_output or {}).get('token_usage', {})
        self.prompt_tokens += usage.get('prompt_tokens', 0)
        self.completion_tokens += usage.get('completion_tokens', 0)
        self.total_tokens = self.prompt_tokens + self.completion_tokens

    def get_metrics(self) -> Dict[str, int | float]:
        """Return the accumulated metrics in the business-event logging format."""
        return {
            'total_tokens': self.total_tokens,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            'time_elapsed': self.total_time,
            'interaction_type': 'LLM',
        }
|
||||
51
common/langchain/tracked_openai_embeddings.py
Normal file
51
common/langchain/tracked_openai_embeddings.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from typing import List, Any
|
||||
import time
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedOpenAIEmbeddings(OpenAIEmbeddings):
    """OpenAIEmbeddings subclass that logs estimated token usage and elapsed
    time for every embedding call to the current business event.

    The OpenAI embedding responses are not exposed with token counts through
    this wrapper, so usage is estimated locally with tiktoken.

    Refactoring notes: the token-estimation and metrics-logging code was
    duplicated in embed_documents and embed_query; it now lives in two private
    helpers. The pure pass-through __init__ was removed (inherited behavior is
    identical).
    """

    def _estimate_tokens(self, texts: List[str]) -> int:
        """Estimate the total token count of *texts* for the configured model."""
        # Local import keeps tiktoken an optional, call-time dependency,
        # matching the original inline-import behavior.
        import tiktoken
        enc = tiktoken.encoding_for_model(self.model)
        return sum(len(enc.encode(text)) for text in texts)

    def _log_embedding_metrics(self, total_tokens: int, time_elapsed: float) -> None:
        """Send embedding metrics to the business-event logger."""
        current_event.log_llm_metrics({
            'total_tokens': total_tokens,
            'prompt_tokens': total_tokens,  # For embeddings, all tokens are prompt tokens
            'completion_tokens': 0,
            'time_elapsed': time_elapsed,
            'interaction_type': 'Embedding',
        })

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, logging estimated token usage and elapsed time."""
        start_time = time.time()
        result = super().embed_documents(texts)
        elapsed = time.time() - start_time

        self._log_embedding_metrics(self._estimate_tokens(texts), elapsed)
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string, logging estimated usage and time."""
        start_time = time.time()
        result = super().embed_query(text)
        elapsed = time.time() - start_time

        self._log_embedding_metrics(self._estimate_tokens([text]), elapsed)
        return result
|
||||
77
common/langchain/tracked_transcription.py
Normal file
77
common/langchain/tracked_transcription.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# common/langchain/tracked_transcription.py
|
||||
from typing import Any, Optional, Dict
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
|
||||
class TrackedOpenAITranscription:
    """Wrapper for OpenAI transcription with metric tracking"""

    def __init__(self, api_key: str, **kwargs: Any):
        """Initialize with OpenAI client settings"""
        self.client = OpenAI(api_key=api_key)
        self.model = kwargs.get('model', 'whisper-1')

    def transcribe(self,
                   file: Any,
                   model: Optional[str] = None,
                   language: Optional[str] = None,
                   prompt: Optional[str] = None,
                   response_format: Optional[str] = None,
                   temperature: Optional[float] = None,
                   duration: Optional[int] = None) -> str:
        """
        Transcribe audio with metrics tracking

        Args:
            file: Audio file to transcribe
            model: Model to use (defaults to whisper-1)
            language: Optional language of the audio
            prompt: Optional prompt to guide transcription
            response_format: Response format (json, text, etc)
            temperature: Sampling temperature
            duration: Duration of audio in seconds for metrics

        Returns:
            Transcription text

        Raises:
            Exception: wrapping the underlying client error, with the original
            exception chained as the cause.
        """
        start_time = time.time()

        try:
            # Create transcription options
            options = {
                "file": file,
                "model": model or self.model,
            }
            if language:
                options["language"] = language
            if prompt:
                options["prompt"] = prompt
            if response_format:
                options["response_format"] = response_format
            # Bug fix: `if temperature:` silently dropped temperature=0.0
            # (falsy). An explicit None check passes 0.0 through to the API.
            if temperature is not None:
                options["temperature"] = temperature

            response = self.client.audio.transcriptions.create(**options)

            # Calculate metrics
            end_time = time.time()

            # Token usage for transcriptions is based on audio duration
            metrics = {
                'total_tokens': duration or 600,  # Default to 10 minutes if duration not provided
                'prompt_tokens': 0,  # For transcriptions, all tokens are completion
                'completion_tokens': duration or 600,
                'time_elapsed': end_time - start_time,
                'interaction_type': 'ASR',
            }
            current_event.log_llm_metrics(metrics)

            # Return text from response (response_format="text" yields a plain str)
            if isinstance(response, str):
                return response
            return response.text

        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise Exception(f"Transcription failed: {str(e)}") from e
|
||||
BIN
common/models/.DS_Store
vendored
BIN
common/models/.DS_Store
vendored
Binary file not shown.
2
common/models/README.txt
Normal file
2
common/models/README.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
If models are added to the public schema (i.e. in the user domain), make sure to register their
tables in env.py (get_public_table_names) so that tenant migrations handle them correctly!
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,24 +1,116 @@
|
||||
from common.extensions import db
|
||||
from .user import User, Tenant
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class Catalog(db.Model):
    """A named document catalog: groups documents together with the chunking
    and processing configuration used when ingesting them."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False, unique=True)
    description = db.Column(db.Text, nullable=True)
    # Catalog behavior type plus its semantic version
    type = db.Column(db.String(50), nullable=False, default="STANDARD_CATALOG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")

    # Chunking bounds (characters or tokens — unit not visible here; confirm in processor code)
    min_chunk_size = db.Column(db.Integer, nullable=True, default=1500)
    max_chunk_size = db.Column(db.Integer, nullable=True, default=2500)

    # Meta Data
    user_metadata = db.Column(JSONB, nullable=True)
    system_metadata = db.Column(JSONB, nullable=True)
    configuration = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    def to_dict(self):
        """Serialize the catalog for API/JSON use (audit columns omitted)."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'type': self.type,
            'type_version': self.type_version,
            'min_chunk_size': self.min_chunk_size,
            'max_chunk_size': self.max_chunk_size,
            'user_metadata': self.user_metadata,
            'system_metadata': self.system_metadata,
            'configuration': self.configuration,
        }
|
||||
|
||||
|
||||
class Processor(db.Model):
    """A document processor attached to a catalog; selected by file type /
    sub-file type when ingesting documents."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
    type = db.Column(db.String(50), nullable=False)
    sub_file_type = db.Column(db.String(50), nullable=True)
    active = db.Column(db.Boolean, nullable=True, default=True)

    # Tuning enablers
    tuning = db.Column(db.Boolean, nullable=True, default=False)

    # Meta Data
    user_metadata = db.Column(JSONB, nullable=True)
    system_metadata = db.Column(JSONB, nullable=True)
    configuration = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Retriever(db.Model):
    """Configuration of a retrieval strategy attached to a catalog."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    catalog_id = db.Column(db.Integer, db.ForeignKey('catalog.id'), nullable=True)
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    # Bug fix: type_version previously defaulted to "STANDARD_RAG" — a
    # copy-paste of the *type* default. Versions follow the "1.0.0" scheme
    # (see Catalog.type_version).
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    tuning = db.Column(db.Boolean, nullable=True, default=False)

    # Meta Data
    user_metadata = db.Column(JSONB, nullable=True)
    system_metadata = db.Column(JSONB, nullable=True)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Document(db.Model):
    """A logical document owned by a tenant; its actual content lives in
    DocumentVersion rows (see the `versions` relationship)."""
    id = db.Column(db.Integer, primary_key=True)
    catalog_id = db.Column(db.Integer, db.ForeignKey(Catalog.id), nullable=True)
    name = db.Column(db.String(100), nullable=False)
    tenant_id = db.Column(db.Integer, db.ForeignKey(Tenant.id), nullable=False)
    # Validity window; NULL bounds mean unbounded
    valid_from = db.Column(db.DateTime, nullable=True)
    valid_to = db.Column(db.DateTime, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    # Bug fix: created_by was declared twice (nullable=False, then nullable=True).
    # The later declaration won at class-body evaluation, so the effective
    # schema was nullable=True — keep that single declaration.
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    # Relations
    versions = db.relationship('DocumentVersion', backref='document', lazy=True)

    @property
    def latest_version(self):
        """Returns the latest document version (the one with highest id)"""
        from sqlalchemy import desc
        return DocumentVersion.query.filter_by(doc_id=self.id).order_by(desc(DocumentVersion.id)).first()

    def __repr__(self):
        return f"<Document {self.id}: {self.name}>"
|
||||
|
||||
@@ -27,12 +119,17 @@ class DocumentVersion(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
doc_id = db.Column(db.Integer, db.ForeignKey(Document.id), nullable=False)
|
||||
url = db.Column(db.String(200), nullable=True)
|
||||
file_location = db.Column(db.String(255), nullable=True)
|
||||
file_name = db.Column(db.String(200), nullable=True)
|
||||
bucket_name = db.Column(db.String(255), nullable=True)
|
||||
object_name = db.Column(db.String(200), nullable=True)
|
||||
file_type = db.Column(db.String(20), nullable=True)
|
||||
sub_file_type = db.Column(db.String(50), nullable=True)
|
||||
file_size = db.Column(db.Float, nullable=True)
|
||||
language = db.Column(db.String(2), nullable=False)
|
||||
user_context = db.Column(db.Text, nullable=True)
|
||||
system_context = db.Column(db.Text, nullable=True)
|
||||
user_metadata = db.Column(JSONB, nullable=True)
|
||||
system_metadata = db.Column(JSONB, nullable=True)
|
||||
catalog_properties = db.Column(JSONB, nullable=True)
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
@@ -52,12 +149,6 @@ class DocumentVersion(db.Model):
|
||||
def __repr__(self):
    """Debug representation of a DocumentVersion.

    NOTE(review): `self.document_language` is not among the columns visible
    for this model in this file (which has `doc_id` and `language`) — confirm
    the attribute exists on this model.
    """
    # BUGFIX: removed a stray '>' that appeared mid-string in the original
    # ("...{...language}>.{self.id}>"), which produced a malformed repr.
    return f"<DocumentVersion {self.document_language.document_id}.{self.document_language.language}.{self.id}>"
|
||||
|
||||
def calc_file_location(self):
    """Relative storage path for this version: "<tenant_id>/<document_id>/<language>"."""
    parent = self.document
    return "/".join((str(parent.tenant_id), str(parent.id), str(self.language)))
|
||||
|
||||
def calc_file_name(self):
    """File name for this version: "<id>.<file_type>"."""
    return "{0}.{1}".format(self.id, self.file_type)
|
||||
|
||||
|
||||
class Embedding(db.Model):
|
||||
__tablename__ = 'embeddings'
|
||||
|
||||
528
common/models/entitlements.py
Normal file
528
common/models/entitlements.py
Normal file
@@ -0,0 +1,528 @@
|
||||
from sqlalchemy.sql.expression import text
|
||||
|
||||
from common.extensions import db
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from enum import Enum
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from common.utils.database import Database
|
||||
|
||||
|
||||
class BusinessEventLog(db.Model):
    """Log record of a business event (tracing spans, LLM metrics, documents).

    Stored in the shared 'public' schema. A row can optionally be linked to a
    LicenseUsage record via `license_usage_id`.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    timestamp = db.Column(db.DateTime, nullable=False)
    event_type = db.Column(db.String(50), nullable=False)
    tenant_id = db.Column(db.Integer, nullable=False)
    # Tracing identifiers (trace / span / parent span)
    trace_id = db.Column(db.String(50), nullable=False)
    span_id = db.Column(db.String(50))
    span_name = db.Column(db.String(255))
    parent_span_id = db.Column(db.String(50))
    # Optional document context
    document_version_id = db.Column(db.Integer)
    document_version_file_size = db.Column(db.Float)
    # Optional specialist / chat context
    specialist_id = db.Column(db.Integer)
    specialist_type = db.Column(db.String(50))
    specialist_type_version = db.Column(db.String(20))
    chat_session_id = db.Column(db.String(50))
    interaction_id = db.Column(db.Integer)
    environment = db.Column(db.String(20))
    # LLM usage metrics captured with the event
    llm_metrics_total_tokens = db.Column(db.Integer)
    llm_metrics_prompt_tokens = db.Column(db.Integer)
    llm_metrics_completion_tokens = db.Column(db.Integer)
    llm_metrics_total_time = db.Column(db.Float)
    llm_metrics_nr_of_pages = db.Column(db.Integer)
    llm_metrics_call_count = db.Column(db.Integer)
    llm_interaction_type = db.Column(db.String(20))
    message = db.Column(db.Text)
    # Optional link to a LicenseUsage record
    license_usage_id = db.Column(db.Integer, db.ForeignKey('public.license_usage.id'), nullable=True)
    license_usage = db.relationship('LicenseUsage', backref='events')
|
||||
|
||||
|
||||
class License(db.Model):
    """A tenant's license contract: tier, duration, pricing and limits.

    `end_date` is kept in sync with `start_date` / `nr_of_periods` by the
    attribute 'set' event listeners defined below in this module.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'),nullable=False)  # e.g. 'small', 'medium', 'custom'
    start_date = db.Column(db.Date, nullable=False)
    end_date = db.Column(db.Date, nullable=True)  # derived — see event listeners below
    nr_of_periods = db.Column(db.Integer, nullable=False)
    currency = db.Column(db.String(20), nullable=False)
    yearly_payment = db.Column(db.Boolean, nullable=False, default=False)
    basic_fee = db.Column(db.Float, nullable=False)
    # Storage limits and overage pricing
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price = db.Column(db.Float, nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    # Embedding limits and overage pricing
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    # Interaction-token limits and overage pricing
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    overage_embedding = db.Column(db.Float, nullable=False, default=0)
    overage_interaction = db.Column(db.Float, nullable=False, default=0)
    # Whether overage consumption is allowed at all
    additional_storage_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=False, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=False, default=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    tenant = db.relationship('Tenant', back_populates='licenses')
    license_tier = db.relationship('LicenseTier', back_populates='licenses')
    periods = db.relationship('LicensePeriod', back_populates='license',
                              order_by='LicensePeriod.period_number',
                              cascade='all, delete-orphan')
|
||||
|
||||
def calculate_end_date(start_date, nr_of_periods):
    """Return the inclusive license end date.

    The end date is `start_date` advanced by `nr_of_periods` months, minus one
    day. Returns None when either argument is falsy.
    """
    if not (start_date and nr_of_periods):
        return None
    end = start_date + relativedelta(months=nr_of_periods)
    return end - relativedelta(days=1)
|
||||
|
||||
# Keep License.end_date in sync when start_date changes.
@event.listens_for(License.start_date, 'set')
def set_start_date(target, value, oldvalue, initiator):
    """Recompute end_date whenever start_date receives a new value."""
    if not value:
        return
    periods = target.nr_of_periods
    if periods:
        target.end_date = calculate_end_date(value, periods)
|
||||
|
||||
# Keep License.end_date in sync when nr_of_periods changes.
@event.listens_for(License.nr_of_periods, 'set')
def set_nr_of_periods(target, value, oldvalue, initiator):
    """Recompute end_date whenever nr_of_periods receives a new value."""
    if not value:
        return
    start = target.start_date
    if start:
        target.end_date = calculate_end_date(start, value)
|
||||
|
||||
|
||||
class LicenseTier(db.Model):
    """A versioned pricing tier template that Licenses reference.

    NOTE(review): the `_d` / `_e` column suffixes appear to be per-currency
    price variants — confirm their exact meaning against the billing code.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    version = db.Column(db.String(50), nullable=False)
    # Validity window of this tier definition
    start_date = db.Column(db.Date, nullable=False)
    end_date = db.Column(db.Date, nullable=True)
    basic_fee_d = db.Column(db.Float, nullable=True)
    basic_fee_e = db.Column(db.Float, nullable=True)
    # Storage limits and overage pricing
    max_storage_mb = db.Column(db.Integer, nullable=False)
    additional_storage_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_storage_bucket = db.Column(db.Integer, nullable=False)
    # Embedding limits and overage pricing
    included_embedding_mb = db.Column(db.Integer, nullable=False)
    additional_embedding_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_embedding_bucket = db.Column(db.Integer, nullable=False)
    # Interaction-token limits and overage pricing
    included_interaction_tokens = db.Column(db.Integer, nullable=False)
    additional_interaction_token_price_d = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_token_price_e = db.Column(db.Numeric(10, 4), nullable=False)
    additional_interaction_bucket = db.Column(db.Integer, nullable=False)
    standard_overage_embedding = db.Column(db.Float, nullable=False, default=0)
    standard_overage_interaction = db.Column(db.Float, nullable=False, default=0)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    licenses = db.relationship('License', back_populates='license_tier')
    partner_services = db.relationship('PartnerServiceLicenseTier', back_populates='license_tier')
|
||||
|
||||
|
||||
class PartnerServiceLicenseTier(db.Model):
    """Association table linking a PartnerService to a LicenseTier.

    Composite primary key of (partner_service_id, license_tier_id).
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    partner_service_id = db.Column(db.Integer, db.ForeignKey('public.partner_service.id'), primary_key=True,
                                   nullable=False)
    license_tier_id = db.Column(db.Integer, db.ForeignKey('public.license_tier.id'), primary_key=True,
                                nullable=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=True, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    license_tier = db.relationship('LicenseTier', back_populates='partner_services')
    partner_service = db.relationship('PartnerService', back_populates='license_tiers')
|
||||
|
||||
|
||||
class PeriodStatus(Enum):
    """Lifecycle states of a LicensePeriod (see LicensePeriod.can_transition_to)."""
    UPCOMING = "UPCOMING"    # The period is still in the future
    PENDING = "PENDING"      # The period is active, but prepaid is not yet received
    ACTIVE = "ACTIVE"        # The period is active and prepaid has been received
    COMPLETED = "COMPLETED"  # The period has been completed, but not yet invoiced
    INVOICED = "INVOICED"    # The period has been completed and invoiced, but overage payment still pending
    CLOSED = "CLOSED"        # The period has been closed, invoiced and fully paid
|
||||
|
||||
|
||||
class LicensePeriod(db.Model):
    """One billing period of a License, with a snapshot of the license terms.

    The pricing/limit columns are copied from the License when the period is
    created, so later license changes do not retroactively affect the period.
    Status follows PeriodStatus; transitions are validated by
    can_transition_to() and applied through transition_status().
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Period identification
    period_number = db.Column(db.Integer, nullable=False)
    period_start = db.Column(db.Date, nullable=False)
    period_end = db.Column(db.Date, nullable=False)

    # License configuration snapshot - copied from license when period is created
    currency = db.Column(db.String(20), nullable=True)
    basic_fee = db.Column(db.Float, nullable=True)
    max_storage_mb = db.Column(db.Integer, nullable=True)
    additional_storage_price = db.Column(db.Float, nullable=True)
    additional_storage_bucket = db.Column(db.Integer, nullable=True)
    included_embedding_mb = db.Column(db.Integer, nullable=True)
    additional_embedding_price = db.Column(db.Numeric(10, 4), nullable=True)
    additional_embedding_bucket = db.Column(db.Integer, nullable=True)
    included_interaction_tokens = db.Column(db.Integer, nullable=True)
    additional_interaction_token_price = db.Column(db.Numeric(10, 4), nullable=True)
    additional_interaction_bucket = db.Column(db.Integer, nullable=True)

    # Allowance flags - can be changed from False to True within a period
    additional_storage_allowed = db.Column(db.Boolean, nullable=True, default=False)
    additional_embedding_allowed = db.Column(db.Boolean, nullable=True, default=False)
    additional_interaction_allowed = db.Column(db.Boolean, nullable=True, default=False)

    # Status tracking
    status = db.Column(db.Enum(PeriodStatus), nullable=False, default=PeriodStatus.UPCOMING)

    # State transition timestamps (when each state was entered)
    upcoming_at = db.Column(db.DateTime, nullable=True)
    pending_at = db.Column(db.DateTime, nullable=True)
    active_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    invoiced_at = db.Column(db.DateTime, nullable=True)
    closed_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license = db.relationship('License', back_populates='periods')
    license_usage = db.relationship('LicenseUsage',
                                    uselist=False,  # This makes it one-to-one
                                    back_populates='license_period',
                                    cascade='all, delete-orphan')
    payments = db.relationship('Payment', back_populates='license_period')
    invoices = db.relationship('Invoice', back_populates='license_period',
                               cascade='all, delete-orphan')

    def update_allowance(self, allowance_type, allow_value, user_id=None):
        """
        Update an allowance flag within a period.
        Only allows transitioning from False to True.

        Args:
            allowance_type: One of 'storage', 'embedding', or 'interaction'
            allow_value: The new value (must be True)
            user_id: User ID performing the update

        Raises:
            ValueError: If trying to change from True to False, or invalid allowance type
        """
        field_name = f"additional_{allowance_type}_allowed"

        # Verify the field actually exists on this model
        if not hasattr(self, field_name):
            raise ValueError(f"Invalid allowance type: {allowance_type}")

        current_value = getattr(self, field_name)

        # Only allow the False -> True transition
        if current_value is True and allow_value is True:
            # Already True, no change needed
            return
        elif allow_value is False:
            raise ValueError(f"Cannot change {field_name} from {current_value} to False")

        # Update the field and audit information
        setattr(self, field_name, True)
        self.updated_at = dt.now(tz.utc)
        if user_id:
            self.updated_by = user_id

    @property
    def prepaid_invoice(self):
        """Get the prepaid invoice for this period."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_invoice(self):
        """Get the overage (postpaid) invoice for this period."""
        return Invoice.query.filter_by(
            license_period_id=self.id,
            invoice_type=PaymentType.POSTPAID
        ).first()

    @property
    def prepaid_payment(self):
        """Get the prepaid payment for this period."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.PREPAID
        ).first()

    @property
    def overage_payment(self):
        """Get the overage (postpaid) payment for this period."""
        return Payment.query.filter_by(
            license_period_id=self.id,
            payment_type=PaymentType.POSTPAID
        ).first()

    @property
    def all_invoices(self):
        """Get all invoices for this period."""
        return self.invoices

    @property
    def all_payments(self):
        """Get all payments for this period."""
        return self.payments

    def transition_status(self, new_status: PeriodStatus, user_id: int = None):
        """Transition to a new status with proper validation and timestamping.

        Raises:
            ValueError: if the transition is not allowed by can_transition_to().
        """
        if not self.can_transition_to(new_status):
            raise ValueError(f"Invalid status transition from {self.status} to {new_status}")

        self.status = new_status
        self.updated_at = dt.now(tz.utc)
        if user_id:
            self.updated_by = user_id

        # Record when each state was entered.
        # BUGFIX: the original set `self.prepaid_received_at` on activation,
        # which is not a column of this model and raised AttributeError;
        # `active_at` is the matching state-transition timestamp. The
        # previously-unused `pending_at` column is now also stamped.
        if new_status == PeriodStatus.PENDING and not self.pending_at:
            self.pending_at = dt.now(tz.utc)
        elif new_status == PeriodStatus.ACTIVE and not self.active_at:
            self.active_at = dt.now(tz.utc)
        elif new_status == PeriodStatus.COMPLETED:
            self.completed_at = dt.now(tz.utc)
        elif new_status == PeriodStatus.INVOICED:
            self.invoiced_at = dt.now(tz.utc)
        elif new_status == PeriodStatus.CLOSED:
            self.closed_at = dt.now(tz.utc)

    @property
    def is_overdue(self):
        """True when the period has started but the prepaid is still PENDING."""
        return (self.status == PeriodStatus.PENDING and
                self.period_start <= dt.now(tz.utc).date())

    def can_transition_to(self, new_status: PeriodStatus) -> bool:
        """Check if a status transition is valid."""
        valid_transitions = {
            PeriodStatus.UPCOMING: [PeriodStatus.ACTIVE, PeriodStatus.PENDING],
            PeriodStatus.PENDING: [PeriodStatus.ACTIVE],
            PeriodStatus.ACTIVE: [PeriodStatus.COMPLETED],
            PeriodStatus.COMPLETED: [PeriodStatus.INVOICED, PeriodStatus.CLOSED],
            PeriodStatus.INVOICED: [PeriodStatus.CLOSED],
            PeriodStatus.CLOSED: []
        }
        return new_status in valid_transitions.get(self.status, [])

    def __repr__(self):
        return f'<LicensePeriod {self.id}: License {self.license_id}, Period {self.period_number}>'
|
||||
|
||||
|
||||
class LicenseUsage(db.Model):
    """Aggregated usage counters for one LicensePeriod of a tenant (one-to-one)."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    storage_mb_used = db.Column(db.Float, default=0)
    embedding_mb_used = db.Column(db.Float, default=0)
    embedding_prompt_tokens_used = db.Column(db.Integer, default=0)
    embedding_completion_tokens_used = db.Column(db.Integer, default=0)
    embedding_total_tokens_used = db.Column(db.Integer, default=0)
    interaction_prompt_tokens_used = db.Column(db.Integer, default=0)
    interaction_completion_tokens_used = db.Column(db.Integer, default=0)
    interaction_total_tokens_used = db.Column(db.Integer, default=0)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    license_period = db.relationship('LicensePeriod', back_populates='license_usage')

    def recalculate_storage(self):
        """Recompute storage_mb_used from the tenant's document_version table.

        Switches the session to the tenant's schema first, then sums
        document_version.file_size.
        """
        Database(self.tenant_id).switch_schema()
        # BUGFIX: SUM() yields NULL (None in Python) on an empty table, which
        # previously set storage_mb_used to None; COALESCE guarantees 0.
        # The f-prefix on the SQL string was also removed — it interpolated
        # nothing and invited accidental SQL injection later.
        total_storage = db.session.execute(text("""
            SELECT COALESCE(SUM(file_size), 0)
            FROM document_version
        """)).scalar()

        self.storage_mb_used = total_storage
|
||||
|
||||
|
||||
class PaymentType(Enum):
    """Direction of a payment/invoice: PREPAID, or POSTPAID (used for overage)."""
    PREPAID = "PREPAID"
    POSTPAID = "POSTPAID"
|
||||
|
||||
|
||||
class PaymentStatus(Enum):
    """Lifecycle states of a Payment."""
    PENDING = "PENDING"
    PAID = "PAID"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
|
||||
|
||||
|
||||
class Payment(db.Model):
    """A payment made against a LicensePeriod (prepaid or postpaid/overage)."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Payment details
    payment_type = db.Column(db.Enum(PaymentType), nullable=False)
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    description = db.Column(db.Text, nullable=True)

    # Status tracking
    status = db.Column(db.Enum(PaymentStatus), nullable=False, default=PaymentStatus.PENDING)

    # External provider information
    external_payment_id = db.Column(db.String(255), nullable=True)
    payment_method = db.Column(db.String(50), nullable=True)  # credit_card, bank_transfer, etc.
    provider_data = db.Column(JSONB, nullable=True)  # Provider-specific data

    # Payment information
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='payments')
    invoice = db.relationship('Invoice', back_populates='payment', uselist=False)

    @property
    def is_overdue(self):
        """True when a PENDING prepaid payment's period has already started.

        Postpaid payments always return False here — their due date would
        live on the associated invoice.
        """
        if self.status != PaymentStatus.PENDING:
            return False

        # For prepaid payments, check if period start has passed
        if (self.payment_type == PaymentType.PREPAID and
                self.license_period_id):
            return self.license_period.period_start <= dt.now(tz.utc).date()

        # For postpaid, check against due date (would be on invoice)
        return False

    def __repr__(self):
        return f'<Payment {self.id}: {self.payment_type} {self.amount} {self.currency}>'
|
||||
|
||||
|
||||
class InvoiceStatus(Enum):
    """Lifecycle states of an Invoice."""
    DRAFT = "DRAFT"
    SENT = "SENT"
    PAID = "PAID"
    OVERDUE = "OVERDUE"
    CANCELLED = "CANCELLED"
|
||||
|
||||
|
||||
class Invoice(db.Model):
    """An invoice issued for a LicensePeriod, optionally linked to its Payment."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_period_id = db.Column(db.Integer, db.ForeignKey('public.license_period.id'), nullable=False)
    payment_id = db.Column(db.Integer, db.ForeignKey('public.payment.id'), nullable=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)

    # Invoice details
    invoice_type = db.Column(db.Enum(PaymentType), nullable=False)  # PREPAID or POSTPAID (overage)
    invoice_number = db.Column(db.String(50), unique=True, nullable=False)
    invoice_date = db.Column(db.Date, nullable=False)
    due_date = db.Column(db.Date, nullable=False)

    # Financial details
    amount = db.Column(db.Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), nullable=False)
    tax_amount = db.Column(db.Numeric(10, 2), default=0)

    # Descriptive fields
    description = db.Column(db.Text, nullable=True)
    status = db.Column(db.Enum(InvoiceStatus), nullable=False, default=InvoiceStatus.DRAFT)

    # Timestamps
    sent_at = db.Column(db.DateTime, nullable=True)
    paid_at = db.Column(db.DateTime, nullable=True)

    # Standard audit fields
    created_at = db.Column(db.DateTime, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
    updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relationships
    license_period = db.relationship('LicensePeriod', back_populates='invoices')
    payment = db.relationship('Payment', back_populates='invoice')

    def __repr__(self):
        return f'<Invoice {self.invoice_number}: {self.amount} {self.currency}>'
|
||||
|
||||
|
||||
class LicenseChangeLog(db.Model):
    """
    Log of changes to license configurations.
    Used for auditing and tracking when/why license details changed.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    license_id = db.Column(db.Integer, db.ForeignKey('public.license.id'), nullable=False)
    # Default is evaluated per-row (lambda), timezone-aware UTC
    changed_at = db.Column(db.DateTime, nullable=False, default=lambda: dt.now(tz.utc))

    # What changed (field name plus stringified old/new values)
    field_name = db.Column(db.String(100), nullable=False)
    old_value = db.Column(db.String(255), nullable=True)
    new_value = db.Column(db.String(255), nullable=False)

    # Why it changed
    reason = db.Column(db.Text, nullable=True)

    # Standard audit fields
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Relationships (ordered chronologically on the License side)
    license = db.relationship('License', backref=db.backref('change_logs', order_by='LicenseChangeLog.changed_at'))

    def __repr__(self):
        return f'<LicenseChangeLog: {self.license_id} {self.field_name} {self.old_value} -> {self.new_value}>'
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from ..extensions import db
|
||||
from .user import User, Tenant
|
||||
from .document import Embedding
|
||||
from .user import User, Tenant, TenantMake
|
||||
from .document import Embedding, Retriever
|
||||
|
||||
|
||||
class ChatSession(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
|
||||
session_id = db.Column(db.String(36), nullable=True)
|
||||
session_id = db.Column(db.String(49), nullable=True)
|
||||
session_start = db.Column(db.DateTime, nullable=False)
|
||||
session_end = db.Column(db.DateTime, nullable=True)
|
||||
timezone = db.Column(db.String(30), nullable=True)
|
||||
@@ -18,14 +20,184 @@ class ChatSession(db.Model):
|
||||
return f"<ChatSession {self.id} by {self.user_id}>"
|
||||
|
||||
|
||||
class Specialist(db.Model):
    """A configurable AI specialist, linked to retrievers, agents, tasks,
    tools and dispatchers."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Specialist implementation type and its version
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    tuning = db.Column(db.Boolean, nullable=True, default=False)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)
    active = db.Column(db.Boolean, nullable=True, default=True)

    # Relationship to retrievers through the association table
    retrievers = db.relationship('SpecialistRetriever', backref='specialist', lazy=True,
                                 cascade="all, delete-orphan")
    agents = db.relationship('EveAIAgent', backref='specialist', lazy=True)
    tasks = db.relationship('EveAITask', backref='specialist', lazy=True)
    tools = db.relationship('EveAITool', backref='specialist', lazy=True)
    dispatchers = db.relationship('SpecialistDispatcher', backref='specialist', lazy=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    def __repr__(self):
        return f"<Specialist {self.id}: {self.name}>"

    def to_dict(self):
        """Serialize the specialist to a plain dict.

        Note: `tuning` and the audit columns are not included in the output.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'type': self.type,
            'type_version': self.type_version,
            'configuration': self.configuration,
            'arguments': self.arguments,
            'active': self.active,
        }
|
||||
|
||||
|
||||
class EveAIAsset(db.Model):
    """A stored asset (e.g. a document template) with object-storage location,
    metadata and token-cost information."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    type = db.Column(db.String(50), nullable=False, default="DOCUMENT_TEMPLATE")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")

    # Storage information (object-store bucket/object reference)
    bucket_name = db.Column(db.String(255), nullable=True)
    object_name = db.Column(db.String(200), nullable=True)
    file_type = db.Column(db.String(20), nullable=True)
    file_size = db.Column(db.Float, nullable=True)

    # Metadata information
    user_metadata = db.Column(JSONB, nullable=True)
    system_metadata = db.Column(JSONB, nullable=True)

    # Configuration information
    configuration = db.Column(JSONB, nullable=True)

    # Cost information (LLM token counts)
    prompt_tokens = db.Column(db.Integer, nullable=True)
    completion_tokens = db.Column(db.Integer, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
|
||||
class EveAIDataCapsule(db.Model):
    """Typed, versioned JSON data attached to a ChatSession.

    At most one capsule may exist per (chat_session_id, type, type_version) —
    enforced by the unique constraint below.
    """
    id = db.Column(db.Integer, primary_key=True)
    chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    configuration = db.Column(JSONB, nullable=True)
    data = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    # Unique constraint on chat_session_id, type and type_version
    __table_args__ = (db.UniqueConstraint('chat_session_id', 'type', 'type_version', name='uix_data_capsule_session_type_version'),)
|
||||
|
||||
|
||||
class EveAIAgent(db.Model):
    """Configurable agent definition belonging to a Specialist.

    The role/goal/backstory fields shape the agent's prompt; `configuration`
    and `arguments` carry type-specific JSON settings.
    """
    id = db.Column(db.Integer, primary_key=True)
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=False)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Agent kind plus schema version of its configuration/arguments payloads.
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    role = db.Column(db.Text, nullable=True)
    goal = db.Column(db.Text, nullable=True)
    backstory = db.Column(db.Text, nullable=True)
    # LLM sampling temperature and model override — presumably falls back to
    # tenant defaults when NULL; confirm against callers.
    temperature = db.Column(db.Float, nullable=True)
    llm_model = db.Column(db.String(50), nullable=True)
    # When True, extra tuning/diagnostic output is enabled for this agent.
    tuning = db.Column(db.Boolean, nullable=True, default=False)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class EveAITask(db.Model):
    """Task definition belonging to a Specialist.

    `task_description`/`expected_output` drive the task prompt; `context` and
    `arguments` carry task-specific JSON data; `asynchronous` marks tasks that
    may run out of band.
    """
    id = db.Column(db.Integer, primary_key=True)
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=False)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Task kind plus schema version of its configuration/arguments payloads.
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    task_description = db.Column(db.Text, nullable=True)
    expected_output = db.Column(db.Text, nullable=True)
    # When True, extra tuning/diagnostic output is enabled for this task.
    tuning = db.Column(db.Boolean, nullable=True, default=False)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)
    context = db.Column(JSONB, nullable=True)
    asynchronous = db.Column(db.Boolean, nullable=True, default=False)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class EveAITool(db.Model):
    """Tool definition belonging to a Specialist.

    Same type/type_version + configuration/arguments pattern as EveAIAgent
    and EveAITask.
    """
    id = db.Column(db.Integer, primary_key=True)
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=False)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Tool kind plus schema version of its configuration/arguments payloads.
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    # When True, extra tuning/diagnostic output is enabled for this tool.
    tuning = db.Column(db.Boolean, nullable=True, default=False)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Dispatcher(db.Model):
    """Dispatcher definition (not tied to a single Specialist).

    Linked to specialists through the SpecialistDispatcher association table
    further down in this file.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Dispatcher kind plus schema version of its configuration/arguments payloads.
    type = db.Column(db.String(50), nullable=False, default="STANDARD_RAG")
    type_version = db.Column(db.String(20), nullable=True, default="1.0.0")
    # When True, extra tuning/diagnostic output is enabled for this dispatcher.
    tuning = db.Column(db.Boolean, nullable=True, default=False)
    configuration = db.Column(JSONB, nullable=True)
    arguments = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
|
||||
|
||||
class Interaction(db.Model):
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
chat_session_id = db.Column(db.Integer, db.ForeignKey(ChatSession.id), nullable=False)
|
||||
question = db.Column(db.Text, nullable=False)
|
||||
detailed_question = db.Column(db.Text, nullable=True)
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
algorithm_used = db.Column(db.String(20), nullable=True)
|
||||
language = db.Column(db.String(2), nullable=False)
|
||||
specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id), nullable=True)
|
||||
specialist_arguments = db.Column(JSONB, nullable=True)
|
||||
specialist_results = db.Column(JSONB, nullable=True)
|
||||
timezone = db.Column(db.String(30), nullable=True)
|
||||
appreciation = db.Column(db.Integer, nullable=True)
|
||||
|
||||
@@ -33,6 +205,7 @@ class Interaction(db.Model):
|
||||
question_at = db.Column(db.DateTime, nullable=False)
|
||||
detailed_question_at = db.Column(db.DateTime, nullable=True)
|
||||
answer_at = db.Column(db.DateTime, nullable=True)
|
||||
processing_error = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relations
|
||||
embeddings = db.relationship('InteractionEmbedding', backref='interaction', lazy=True)
|
||||
@@ -44,3 +217,50 @@ class Interaction(db.Model):
|
||||
class InteractionEmbedding(db.Model):
    """Association table linking an Interaction to an Embedding (composite PK).

    Rows are removed automatically when either side is deleted (ON DELETE CASCADE).
    """
    interaction_id = db.Column(db.Integer, db.ForeignKey(Interaction.id, ondelete='CASCADE'), primary_key=True)
    embedding_id = db.Column(db.Integer, db.ForeignKey(Embedding.id, ondelete='CASCADE'), primary_key=True)
|
||||
|
||||
|
||||
class SpecialistRetriever(db.Model):
    """Association table linking a Specialist to a Retriever (composite PK).

    Rows are removed automatically when either side is deleted (ON DELETE CASCADE).
    """
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), primary_key=True)
    retriever_id = db.Column(db.Integer, db.ForeignKey(Retriever.id, ondelete='CASCADE'), primary_key=True)

    # Direct access to the Retriever side; reverse access via Retriever.specialist_retrievers.
    retriever = db.relationship("Retriever", backref="specialist_retrievers")
|
||||
|
||||
|
||||
class SpecialistDispatcher(db.Model):
    """Association table linking a Specialist to a Dispatcher (composite PK).

    Rows are removed automatically when either side is deleted (ON DELETE CASCADE).
    """
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), primary_key=True)
    dispatcher_id = db.Column(db.Integer, db.ForeignKey(Dispatcher.id, ondelete='CASCADE'), primary_key=True)

    # Direct access to the Dispatcher side; reverse access via Dispatcher.specialist_dispatchers.
    dispatcher = db.relationship("Dispatcher", backref="specialist_dispatchers")
|
||||
|
||||
|
||||
class SpecialistMagicLink(db.Model):
    """Shareable magic-link granting access to a Specialist, optionally scoped to a TenantMake."""
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    specialist_id = db.Column(db.Integer, db.ForeignKey(Specialist.id, ondelete='CASCADE'), nullable=False)
    tenant_make_id = db.Column(db.Integer, db.ForeignKey(TenantMake.id, ondelete='CASCADE'), nullable=True)
    # Opaque code embedded in the shared URL; unique across all links.
    magic_link_code = db.Column(db.String(55), nullable=False, unique=True)

    # Optional validity window — presumably NULL means unbounded on that side; confirm against callers.
    valid_from = db.Column(db.DateTime, nullable=True)
    valid_to = db.Column(db.DateTime, nullable=True)

    # Arguments handed to the specialist when the link is used — TODO confirm against caller.
    specialist_args = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey(User.id))

    def __repr__(self):
        return f"<SpecialistMagicLink {self.specialist_id} {self.magic_link_code}>"

    def to_dict(self):
        """Return a plain-dict view of the link.

        NOTE(review): valid_from/valid_to are returned as datetime objects,
        not ISO strings — confirm callers handle serialization.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'magic_link_code': self.magic_link_code,
            'valid_from': self.valid_from,
            'valid_to': self.valid_to,
            'specialist_args': self.specialist_args,
        }
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
|
||||
from common.extensions import db
|
||||
from flask_security import UserMixin, RoleMixin
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import CheckConstraint
|
||||
|
||||
from common.models.entitlements import License
|
||||
|
||||
|
||||
class Tenant(db.Model):
|
||||
@@ -17,50 +21,33 @@ class Tenant(db.Model):
|
||||
|
||||
# company Information
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
code = db.Column(db.String(50), unique=True, nullable=True)
|
||||
name = db.Column(db.String(80), unique=True, nullable=False)
|
||||
website = db.Column(db.String(255), nullable=True)
|
||||
timezone = db.Column(db.String(50), nullable=True, default='UTC')
|
||||
rag_context = db.Column(db.Text, nullable=True)
|
||||
type = db.Column(db.String(20), nullable=True, server_default='Active')
|
||||
|
||||
# language information
|
||||
default_language = db.Column(db.String(2), nullable=True)
|
||||
allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)
|
||||
|
||||
# LLM specific choices
|
||||
embedding_model = db.Column(db.String(50), nullable=True)
|
||||
llm_model = db.Column(db.String(50), nullable=True)
|
||||
|
||||
# Embedding variables
|
||||
html_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li'])
|
||||
html_end_tags = db.Column(ARRAY(sa.String(10)), nullable=True, default=['p', 'li'])
|
||||
html_included_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
html_excluded_elements = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
min_chunk_size = db.Column(db.Integer, nullable=True, default=2000)
|
||||
max_chunk_size = db.Column(db.Integer, nullable=True, default=3000)
|
||||
|
||||
|
||||
# Embedding search variables
|
||||
es_k = db.Column(db.Integer, nullable=True, default=5)
|
||||
es_similarity_threshold = db.Column(db.Float, nullable=True, default=0.7)
|
||||
|
||||
# Chat variables
|
||||
chat_RAG_temperature = db.Column(db.Float, nullable=True, default=0.3)
|
||||
chat_no_RAG_temperature = db.Column(db.Float, nullable=True, default=0.5)
|
||||
fallback_algorithms = db.Column(ARRAY(sa.String(50)), nullable=True)
|
||||
|
||||
# Licensing Information
|
||||
license_start_date = db.Column(db.Date, nullable=True)
|
||||
license_end_date = db.Column(db.Date, nullable=True)
|
||||
allowed_monthly_interactions = db.Column(db.Integer, nullable=True)
|
||||
encrypted_chat_api_key = db.Column(db.String(500), nullable=True)
|
||||
|
||||
# Tuning enablers
|
||||
embed_tuning = db.Column(db.Boolean, nullable=True, default=False)
|
||||
rag_tuning = db.Column(db.Boolean, nullable=True, default=False)
|
||||
# Entitlements
|
||||
currency = db.Column(db.String(20), nullable=True)
|
||||
storage_dirty = db.Column(db.Boolean, nullable=True, default=False)
|
||||
default_tenant_make_id = db.Column(db.Integer, db.ForeignKey('public.tenant_make.id'), nullable=True)
|
||||
|
||||
# Relations
|
||||
users = db.relationship('User', backref='tenant')
|
||||
domains = db.relationship('TenantDomain', backref='tenant')
|
||||
licenses = db.relationship('License', back_populates='tenant')
|
||||
license_usages = db.relationship('LicenseUsage', backref='tenant')
|
||||
tenant_makes = db.relationship('TenantMake', backref='tenant', foreign_keys='TenantMake.tenant_id')
|
||||
default_tenant_make = db.relationship('TenantMake', foreign_keys=[default_tenant_make_id], uselist=False)
|
||||
|
||||
@property
|
||||
def current_license(self):
|
||||
today = date.today()
|
||||
return License.query.filter(
|
||||
License.tenant_id == self.id,
|
||||
License.start_date <= today,
|
||||
(License.end_date.is_(None) | (License.end_date >= today))
|
||||
).order_by(License.start_date.desc()).first()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tenant {self.id}: {self.name}>"
|
||||
@@ -71,27 +58,9 @@ class Tenant(db.Model):
|
||||
'name': self.name,
|
||||
'website': self.website,
|
||||
'timezone': self.timezone,
|
||||
'rag_context': self.rag_context,
|
||||
'default_language': self.default_language,
|
||||
'allowed_languages': self.allowed_languages,
|
||||
'embedding_model': self.embedding_model,
|
||||
'llm_model': self.llm_model,
|
||||
'html_tags': self.html_tags,
|
||||
'html_end_tags': self.html_end_tags,
|
||||
'html_included_elements': self.html_included_elements,
|
||||
'html_excluded_elements': self.html_excluded_elements,
|
||||
'min_chunk_size': self.min_chunk_size,
|
||||
'max_chunk_size': self.max_chunk_size,
|
||||
'es_k': self.es_k,
|
||||
'es_similarity_threshold': self.es_similarity_threshold,
|
||||
'chat_RAG_temperature': self.chat_RAG_temperature,
|
||||
'chat_no_RAG_temperature': self.chat_no_RAG_temperature,
|
||||
'fallback_algorithms': self.fallback_algorithms,
|
||||
'license_start_date': self.license_start_date,
|
||||
'license_end_date': self.license_end_date,
|
||||
'allowed_monthly_interactions': self.allowed_monthly_interactions,
|
||||
'embed_tuning': self.embed_tuning,
|
||||
'rag_tuning': self.rag_tuning,
|
||||
'type': self.type,
|
||||
'currency': self.currency,
|
||||
'default_tenant_make_id': self.default_tenant_make_id,
|
||||
}
|
||||
|
||||
|
||||
@@ -124,6 +93,7 @@ class User(db.Model, UserMixin):
|
||||
|
||||
# User Information
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||
user_name = db.Column(db.String(80), unique=True, nullable=False)
|
||||
email = db.Column(db.String(255), unique=True, nullable=False)
|
||||
password = db.Column(db.String(255), nullable=True)
|
||||
@@ -133,6 +103,8 @@ class User(db.Model, UserMixin):
|
||||
fs_uniquifier = db.Column(db.String(255), unique=True, nullable=False)
|
||||
confirmed_at = db.Column(db.DateTime, nullable=True)
|
||||
valid_to = db.Column(db.Date, nullable=True)
|
||||
is_primary_contact = db.Column(db.Boolean, nullable=True, default=False)
|
||||
is_financial_contact = db.Column(db.Boolean, nullable=True, default=False)
|
||||
|
||||
# Security Trackable Information
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
@@ -143,7 +115,6 @@ class User(db.Model, UserMixin):
|
||||
|
||||
# Relations
|
||||
roles = db.relationship('Role', secondary=RolesUsers.__table__, backref=db.backref('users', lazy='dynamic'))
|
||||
tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return '<User %r>' % self.user_name
|
||||
@@ -151,7 +122,6 @@ class User(db.Model, UserMixin):
|
||||
def has_roles(self, *args):
|
||||
return any(role.name in args for role in self.roles)
|
||||
|
||||
|
||||
class TenantDomain(db.Model):
|
||||
__bind_key__ = 'public'
|
||||
__table_args__ = {'schema': 'public'}
|
||||
@@ -166,10 +136,264 @@ class TenantDomain(db.Model):
|
||||
|
||||
# Versioning Information
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
created_by = db.Column(db.Integer, db.ForeignKey(User.id), nullable=False)
|
||||
created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey(User.id))
|
||||
updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TenantDomain {self.id}: {self.domain}>"
|
||||
|
||||
|
||||
class TenantProject(db.Model):
    """Per-tenant project carrying an API key and the list of enabled services.

    Lives in the shared `public` schema (own bind), like the other tenant-level
    models in this file.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Names of the services this project may use.
    services = db.Column(ARRAY(sa.String(50)), nullable=False)
    # API key stored encrypted; `visual_api_key` is presumably a displayable
    # prefix/hint of the real key — confirm against key-generation code.
    encrypted_api_key = db.Column(db.String(500), nullable=True)
    visual_api_key = db.Column(db.String(20), nullable=True)
    active = db.Column(db.Boolean, nullable=False, default=True)
    responsible_email = db.Column(db.String(255), nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    # Relations
    tenant = db.relationship('Tenant', backref='projects')

    def __repr__(self):
        return f"<TenantProject {self.id}: {self.name}>"
|
||||
|
||||
|
||||
class TenantMake(db.Model):
    """Branding/appearance profile ("make") of a tenant: logo, languages, chat styling."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    # Unique across ALL tenants, not just within one tenant.
    name = db.Column(db.String(50), nullable=False, unique=True)
    description = db.Column(db.Text, nullable=True)
    active = db.Column(db.Boolean, nullable=False, default=True)
    website = db.Column(db.String(255), nullable=True)
    logo_url = db.Column(db.String(255), nullable=True)
    # ISO 639-1 two-letter language codes.
    default_language = db.Column(db.String(2), nullable=True)
    allowed_languages = db.Column(ARRAY(sa.String(2)), nullable=True)

    # Chat customisation options
    chat_customisation_options = db.Column(JSONB, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'))

    def __repr__(self):
        return f"<TenantMake {self.id} for tenant {self.tenant_id}: {self.name}>"

    def to_dict(self):
        """Return a plain-dict view of the make (excludes tenant_id and audit columns)."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'active': self.active,
            'website': self.website,
            'logo_url': self.logo_url,
            'chat_customisation_options': self.chat_customisation_options,
            'allowed_languages': self.allowed_languages,
            'default_language': self.default_language,
        }
|
||||
|
||||
|
||||
class Partner(db.Model):
    """Partner profile: a one-to-one extension of a Tenant that offers services.

    A tenant can back at most one partner (tenant_id is unique).
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False, unique=True)
    code = db.Column(db.String(50), unique=True, nullable=False)

    # Basic information
    logo_url = db.Column(db.String(255), nullable=True)
    active = db.Column(db.Boolean, default=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Relationships
    services = db.relationship('PartnerService', back_populates='partner')
    tenant = db.relationship('Tenant', backref=db.backref('partner', uselist=False))

    def to_dict(self):
        """Return a plain-dict view of the partner, including its services.

        NOTE(review): accessing self.services / self.tenant.name may trigger
        lazy loads — confirm this is only called within a session.
        """
        services_info = []
        for service in self.services:
            services_info.append({
                'id': service.id,
                'name': service.name,
                'description': service.description,
                'type': service.type,
                'type_version': service.type_version,
                'active': service.active,
                'configuration': service.configuration,
                'permissions': service.permissions,
            })
        return {
            'id': self.id,
            'tenant_id': self.tenant_id,
            'code': self.code,
            'logo_url': self.logo_url,
            'active': self.active,
            # Display name comes from the backing tenant, not the partner row.
            'name': self.tenant.name,
            'services': services_info
        }
|
||||
|
||||
|
||||
class PartnerService(db.Model):
    """A service a Partner offers, with versioned type and JSON configuration."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    id = db.Column(db.Integer, primary_key=True)
    partner_id = db.Column(db.Integer, db.ForeignKey('public.partner.id'), nullable=False)

    # Basic info
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text, nullable=True)

    # Service type with versioning (same pattern as the specialist/retriever models)
    type = db.Column(db.String(50), nullable=False)  # REFERRAL, KNOWLEDGE, SPECIALIST, IMPLEMENTATION, WHITE_LABEL
    type_version = db.Column(db.String(20), nullable=False, default="1.0.0")

    # Status
    active = db.Column(db.Boolean, default=True)

    # Dynamic configuration specific to this service.
    # NOTE(review): declared as generic db.JSON here, while sibling models use
    # the postgresql JSONB type — confirm whether JSONB was intended.
    configuration = db.Column(db.JSON, nullable=True)
    permissions = db.Column(db.JSON, nullable=True)

    # For services that need to track shared resources
    system_metadata = db.Column(db.JSON, nullable=True)
    user_metadata = db.Column(db.JSON, nullable=True)

    # Versioning Information
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Relationships
    partner = db.relationship('Partner', back_populates='services')
    license_tiers = db.relationship('PartnerServiceLicenseTier', back_populates='partner_service')
|
||||
|
||||
|
||||
class PartnerTenant(db.Model):
    """Association between a PartnerService and a consuming Tenant (composite PK)."""
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    partner_service_id = db.Column(db.Integer, db.ForeignKey('public.partner_service.id'), primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), primary_key=True)

    # Flexible configuration specific to this relationship.
    # NOTE(review): declared as generic db.JSON despite the original comment
    # saying JSONB — confirm whether JSONB was intended.
    configuration = db.Column(db.JSON, nullable=True)

    # Tracking
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
|
||||
|
||||
|
||||
class TenantConsent(db.Model):
    """Record of a consent given by a user on behalf of a tenant.

    Optionally scoped to a partner and/or a specific partner service.
    `consent_version` should correspond to a ConsentVersion row for the same
    consent_type.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}
    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
    partner_id = db.Column(db.Integer, db.ForeignKey('public.partner.id'), nullable=True)
    partner_service_id = db.Column(db.Integer, db.ForeignKey('public.partner_service.id'), nullable=True)
    # The user who actually gave the consent.
    user_id = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=False)
    consent_type = db.Column(db.String(50), nullable=False)
    consent_date = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    consent_version = db.Column(db.String(20), nullable=False, default="1.0.0")
    # Snapshot of what was consented to — schema depends on consent_type.
    consent_data = db.Column(db.JSON, nullable=False)

    # Tracking
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
|
||||
|
||||
|
||||
class ConsentVersion(db.Model):
    """Published version of a consent text, with its validity window.

    `consent_valid_to` NULL presumably means the version is still current —
    confirm against the consent-checking logic.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}
    id = db.Column(db.Integer, primary_key=True)
    consent_type = db.Column(db.String(50), nullable=False)
    consent_version = db.Column(db.String(20), nullable=False)
    consent_valid_from = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    consent_valid_to = db.Column(db.DateTime, nullable=True)

    # Tracking
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
|
||||
|
||||
|
||||
class ConsentStatus(str, Enum):
    """Possible outcomes when evaluating a tenant's consent state.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """

    CONSENTED = 'CONSENTED'
    NOT_CONSENTED = 'NOT_CONSENTED'
    RENEWAL_REQUIRED = 'RENEWAL_REQUIRED'
    CONSENT_EXPIRED = 'CONSENT_EXPIRED'
    UNKNOWN_CONSENT_VERSION = 'UNKNOWN_CONSENT_VERSION'
|
||||
|
||||
class SpecialistMagicLinkTenant(db.Model):
    """Public-schema lookup mapping a magic-link code to the owning tenant.

    Allows resolving a code to a tenant before switching to that tenant's
    schema — presumably mirrors SpecialistMagicLink.magic_link_code; confirm
    against the code that creates both rows.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    magic_link_code = db.Column(db.String(55), primary_key=True)
    tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||
|
||||
|
||||
class TranslationCache(db.Model):
    """Cache of LLM translations, keyed by a 16-character digest.

    `cache_key` is presumably a hash over source text / languages / context —
    confirm against the code that writes entries. Token counts record the
    one-time cost of producing the translation.
    """
    __bind_key__ = 'public'
    __table_args__ = {'schema': 'public'}

    cache_key = db.Column(db.String(16), primary_key=True)
    source_text = db.Column(db.Text, nullable=False)
    translated_text = db.Column(db.Text, nullable=False)
    # ISO 639-1 codes; source may be NULL (e.g. auto-detected).
    source_language = db.Column(db.String(2), nullable=True)
    target_language = db.Column(db.String(2), nullable=False)
    context = db.Column(db.Text, nullable=True)

    # Translation cost
    prompt_tokens = db.Column(db.Integer, nullable=False)
    completion_tokens = db.Column(db.Integer, nullable=False)

    # Tracking
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
    created_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now(), onupdate=db.func.now())
    updated_by = db.Column(db.Integer, db.ForeignKey('public.user.id'), nullable=True)

    # Supports LRU-style cleanup of stale entries — TODO confirm a cleanup job exists.
    last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
|
||||
# class PartnerRAGRetriever(db.Model):
|
||||
# __bind_key__ = 'public'
|
||||
# __table_args__ = (
|
||||
# db.PrimaryKeyConstraint('tenant_id', 'retriever_id'),
|
||||
# db.UniqueConstraint('partner_id', 'tenant_id', 'retriever_id'),
|
||||
# {'schema': 'public'},
|
||||
# )
|
||||
#
|
||||
# partner_id = db.Column(db.Integer, db.ForeignKey('public.partner.id'), nullable=False)
|
||||
# tenant_id = db.Column(db.Integer, db.ForeignKey('public.tenant.id'), nullable=False)
|
||||
# retriever_id = db.Column(db.Integer, nullable=False)
|
||||
|
||||
9
common/services/entitlements/__init__.py
Normal file
9
common/services/entitlements/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from common.services.entitlements.license_period_services import LicensePeriodServices
|
||||
from common.services.entitlements.license_usage_services import LicenseUsageServices
|
||||
from common.services.entitlements.license_tier_services import LicenseTierServices
|
||||
|
||||
__all__ = [
|
||||
'LicensePeriodServices',
|
||||
'LicenseUsageServices',
|
||||
'LicenseTierServices'
|
||||
]
|
||||
247
common/services/entitlements/license_period_services.py
Normal file
247
common/services/entitlements/license_period_services.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from datetime import datetime as dt, timezone as tz, timedelta
|
||||
from flask import current_app
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.sql.expression import and_
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.entitlements import LicensePeriod, License, PeriodStatus, LicenseUsage
|
||||
from common.utils.eveai_exceptions import EveAILicensePeriodsExceeded, EveAIPendingLicensePeriod, EveAINoActiveLicense
|
||||
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
||||
|
||||
|
||||
class LicensePeriodServices:
|
||||
    @staticmethod
    def find_current_license_period_for_usage(tenant_id: int) -> LicensePeriod:
        """
        Find the current license period for a tenant. It ensures the status of the different license periods are adapted
        when required, and a LicenseUsage object is created if required.

        Args:
            tenant_id: The ID of the tenant to find the license period for

        Returns:
            LicensePeriod: the period covering today, ACTIVE or PENDING-within-grace.

        Raises:
            EveAIException: and derived classes
        """
        try:
            current_app.logger.debug(f"Finding current license period for tenant {tenant_id}")
            current_date = dt.now(tz.utc).date()
            # Look for a period whose [period_start, period_end] range covers today.
            license_period = (db.session.query(LicensePeriod)
                              .filter_by(tenant_id=tenant_id)
                              .filter(and_(LicensePeriod.period_start <= current_date,
                                           LicensePeriod.period_end >= current_date))
                              .first())
            current_app.logger.debug(f"End searching for license period for tenant {tenant_id} ")
            if not license_period:
                # No period covers today: derive the next one from the tenant's active license.
                current_app.logger.debug(f"No license period found for tenant {tenant_id} on date {current_date}")
                license_period = LicensePeriodServices._create_next_license_period_for_usage(tenant_id)
                current_app.logger.debug(f"Created license period {license_period.id} for tenant {tenant_id}")
            if license_period:
                current_app.logger.debug(f"Found license period {license_period.id} for tenant {tenant_id} "
                                         f"with status {license_period.status}")
                match license_period.status:
                    case PeriodStatus.UPCOMING | PeriodStatus.PENDING:
                        current_app.logger.debug(f"In upcoming state")
                        # Close out the previous period before trying to activate this one.
                        LicensePeriodServices._complete_last_license_period(tenant_id=tenant_id)
                        current_app.logger.debug(f"Completed last license period for tenant {tenant_id}")
                        LicensePeriodServices._activate_license_period(license_period=license_period)
                        current_app.logger.debug(f"Activated license period {license_period.id} for tenant {tenant_id}")
                        # Ensure a LicenseUsage record exists for this period.
                        if not license_period.license_usage:
                            new_license_usage = LicenseUsage(
                                tenant_id=tenant_id,
                            )
                            new_license_usage.license_period = license_period
                            try:
                                db.session.add(new_license_usage)
                                db.session.commit()

                            except SQLAlchemyError as e:
                                db.session.rollback()
                                current_app.logger.error(
                                    f"Error creating new license usage for license period "
                                    f"{license_period.id}: {str(e)}")
                                raise e
                        # Re-check: _activate_license_period may or may not have
                        # promoted the period to ACTIVE (e.g. payment pending).
                        if license_period.status == PeriodStatus.ACTIVE:
                            return license_period
                        else:
                            # Status is PENDING, so no prepaid payment received. There is no license period we can use.
                            # We allow for a delay of 5 days before raising an exception.
                            current_date = dt.now(tz.utc).date()
                            delta = abs(current_date - license_period.period_start)
                            if delta > timedelta(days=current_app.config.get('ENTITLEMENTS_MAX_PENDING_DAYS', 5)):
                                raise EveAIPendingLicensePeriod()
                            else:
                                return license_period
                    case PeriodStatus.ACTIVE:
                        return license_period
            else:
                # _create_next_license_period_for_usage returned nothing usable.
                raise EveAILicensePeriodsExceeded(license_id=None)
        except SQLAlchemyError as e:
            db.session.rollback()
            current_app.logger.error(f"Error finding current license period for tenant {tenant_id}: {str(e)}")
            raise e
        except Exception as e:
            raise e
|
||||
|
||||
@staticmethod
def _create_next_license_period_for_usage(tenant_id) -> LicensePeriod:
    """
    Create a new period for this license using the current license configuration.

    The period is created in UPCOMING status; monthly boundaries are derived
    from the license start date and the new period's sequence number.

    Args:
        tenant_id: The ID of the tenant to create the period for

    Returns:
        LicensePeriod: The newly created license period

    Raises:
        EveAINoActiveLicense: If no license covers the current date for this tenant
        EveAILicensePeriodsExceeded: If the license already has all its periods
        SQLAlchemyError: If persisting the new period fails (session is rolled back)
    """
    current_date = dt.now(tz.utc).date()

    # Look up the license that is active for this tenant on the current date
    the_license = (db.session.query(License)
                   .filter_by(tenant_id=tenant_id)
                   .filter(License.start_date <= current_date)
                   .filter(License.end_date >= current_date)
                   .first())

    if not the_license:
        current_app.logger.error(f"No active license found for tenant {tenant_id} on date {current_date}")
        raise EveAINoActiveLicense(tenant_id=tenant_id)
    else:
        current_app.logger.debug(f"Found active license {the_license.id} for tenant {tenant_id} "
                                 f"on date {current_date}")

    next_period_number = 1
    if the_license.periods:
        # If there are existing periods, get the next sequential number
        next_period_number = max(p.period_number for p in the_license.periods) + 1
    current_app.logger.debug(f"Next period number for tenant {tenant_id} is {next_period_number}")

    # A license only allows a fixed number of periods
    if next_period_number > the_license.nr_of_periods:
        raise EveAILicensePeriodsExceeded(license_id=the_license.id)

    new_license_period = LicensePeriod(
        license_id=the_license.id,
        tenant_id=tenant_id,
        period_number=next_period_number,
        # Periods are consecutive calendar months anchored on the license start date:
        # period N runs from start + (N-1) months through the day before start + N months.
        period_start=the_license.start_date + relativedelta(months=next_period_number-1),
        period_end=the_license.start_date + relativedelta(months=next_period_number, days=-1),
        status=PeriodStatus.UPCOMING,
        upcoming_at=dt.now(tz.utc),
    )
    set_logging_information(new_license_period, dt.now(tz.utc))

    try:
        current_app.logger.debug(f"Creating next license period for tenant {tenant_id} ")
        db.session.add(new_license_period)
        db.session.commit()
        current_app.logger.info(f"Created next license period for tenant {tenant_id} "
                                f"with id {new_license_period.id}")
        return new_license_period
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"Error creating next license period for tenant {tenant_id}: {str(e)}")
        raise e
|
||||
|
||||
@staticmethod
def _activate_license_period(license_period_id: int = None, license_period: LicensePeriod = None) -> LicensePeriod:
    """
    Activate a license period.

    The period is first moved to PENDING (recording pending_at), then promoted
    to ACTIVE when a prepaid payment has been received. Pricing and quota
    fields are snapshotted from the owning license onto the period, and a
    LicenseUsage record is created if one does not yet exist.

    Args:
        license_period_id: The ID of the license period to activate (optional if license_period is provided)
        license_period: The LicensePeriod object to activate (optional if license_period_id is provided)

    Returns:
        LicensePeriod: The activated license period object

    Raises:
        ValueError: If neither license_period_id nor license_period is provided
        SQLAlchemyError: If persisting the changes fails (session is rolled back)
    """
    current_app.logger.debug("Activating license period")
    if license_period is None and license_period_id is None:
        raise ValueError("Either license_period_id or license_period must be provided")

    # Get a license period object if only ID was provided
    if license_period is None:
        current_app.logger.debug(f"Getting license period {license_period_id} to activate")
        license_period = LicensePeriod.query.get_or_404(license_period_id)

    # Bug fix: the original condition was inverted (`pending_at is not None`),
    # so a fresh UPCOMING period never received its pending_at timestamp nor
    # the PENDING status the caller relies on. Record the transition the first
    # time the period is activated.
    if license_period.pending_at is None:
        license_period.pending_at = dt.now(tz.utc)
        license_period.status = PeriodStatus.PENDING
    if license_period.prepaid_payment:
        # There is a payment received for the given period
        license_period.active_at = dt.now(tz.utc)
        license_period.status = PeriodStatus.ACTIVE

    # Copy snapshot fields from the license to the period so later license
    # changes do not retroactively alter this period's pricing/quota.
    the_license = License.query.get_or_404(license_period.license_id)
    license_period.currency = the_license.currency
    license_period.basic_fee = the_license.basic_fee
    license_period.max_storage_mb = the_license.max_storage_mb
    license_period.additional_storage_price = the_license.additional_storage_price
    license_period.additional_storage_bucket = the_license.additional_storage_bucket
    license_period.included_embedding_mb = the_license.included_embedding_mb
    license_period.additional_embedding_price = the_license.additional_embedding_price
    license_period.additional_embedding_bucket = the_license.additional_embedding_bucket
    license_period.included_interaction_tokens = the_license.included_interaction_tokens
    license_period.additional_interaction_token_price = the_license.additional_interaction_token_price
    license_period.additional_interaction_bucket = the_license.additional_interaction_bucket
    license_period.additional_storage_allowed = the_license.additional_storage_allowed
    license_period.additional_embedding_allowed = the_license.additional_embedding_allowed
    license_period.additional_interaction_allowed = the_license.additional_interaction_allowed

    update_logging_information(license_period, dt.now(tz.utc))

    # Ensure a usage record exists for this period
    if not license_period.license_usage:
        license_period.license_usage = LicenseUsage(
            tenant_id=license_period.tenant_id,
            license_period_id=license_period.id,
        )

    license_period.license_usage.recalculate_storage()

    try:
        db.session.add(license_period)
        db.session.add(license_period.license_usage)
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        # Bug fix: log the period's own id — license_period_id is None when the
        # caller passed the object directly.
        current_app.logger.error(f"Error activating license period {license_period.id}: {str(e)}")
        raise e

    return license_period
|
||||
|
||||
@staticmethod
def _complete_last_license_period(tenant_id) -> None:
    """
    Complete the active or pending license period for a tenant. This is done by setting the status to COMPLETED.

    A no-op when the tenant has no ACTIVE or PENDING period.

    Args:
        tenant_id: The ID of the tenant

    Raises:
        SQLAlchemyError: If persisting the change fails (session is rolled back)
    """
    # Find the license period for this tenant with status ACTIVE or PENDING
    active_period = (
        db.session.query(LicensePeriod)
        .filter_by(tenant_id=tenant_id)
        .filter(LicensePeriod.status.in_([PeriodStatus.ACTIVE, PeriodStatus.PENDING]))
        .first()
    )

    # If no active/pending period is found there is nothing to do
    if not active_period:
        return

    # Mark the found period as COMPLETED and record when it happened
    active_period.status = PeriodStatus.COMPLETED
    active_period.completed_at = dt.now(tz.utc)
    update_logging_information(active_period, dt.now(tz.utc))

    try:
        db.session.add(active_period)
        db.session.commit()
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"Error completing period {active_period.id} for {tenant_id}: {str(e)}")
        raise e
|
||||
67
common/services/entitlements/license_tier_services.py
Normal file
67
common/services/entitlements/license_tier_services.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from flask import session, flash, current_app
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.entitlements import PartnerServiceLicenseTier
|
||||
from common.models.user import Partner
|
||||
from common.utils.eveai_exceptions import EveAINoManagementPartnerService, EveAINoSessionPartner
|
||||
from common.utils.model_logging_utils import set_logging_information
|
||||
|
||||
|
||||
class LicenseTierServices:
    """Service layer for linking license tiers to partner management services."""

    @staticmethod
    def associate_license_tier_with_partner(license_tier_id):
        """
        Associate a license tier with the session partner's management service.

        Looks up the partner from the Flask session, finds its MANAGEMENT_SERVICE,
        and creates a PartnerServiceLicenseTier association (idempotent: an
        existing association is left untouched).

        Args:
            license_tier_id: ID of the license tier to associate

        Returns:
            True on successful creation; None when the association already existed.

        Raises:
            EveAINoSessionPartner: If the session partner does not exist
            EveAINoManagementPartnerService: If the partner has no management service
            SQLAlchemyError: If persisting the association fails (session is rolled back)
        """
        try:
            partner_id = session['partner']['id']
            # Get partner service (MANAGEMENT_SERVICE type)
            partner = Partner.query.get(partner_id)
            if not partner:
                raise EveAINoSessionPartner()

            # Find a management service for this partner
            management_service = next((service for service in session['partner']['services']
                                       if service.get('type') == 'MANAGEMENT_SERVICE'), None)

            if not management_service:
                flash("Cannot associate license tier with partner. No management service defined for partner", "danger")
                # Bug fix: added the missing space between the two f-string parts
                # (previously rendered as "...partner 123trying to associate...").
                current_app.logger.error(f"No Management Service defined for partner {partner_id} "
                                         f"trying to associate license tier {license_tier_id}.")
                raise EveAINoManagementPartnerService()
            # Check if the association already exists
            existing_association = PartnerServiceLicenseTier.query.filter_by(
                partner_service_id=management_service['id'],
                license_tier_id=license_tier_id
            ).first()

            if existing_association:
                # Association already exists, nothing to do
                flash("License tier was already associated with partner", "info")
                current_app.logger.info(f"Association between partner service {management_service['id']} and "
                                        f"license tier {license_tier_id} already exists.")
                return

            # Create the association
            association = PartnerServiceLicenseTier(
                partner_service_id=management_service['id'],
                license_tier_id=license_tier_id
            )
            set_logging_information(association, dt.now(tz.utc))

            db.session.add(association)
            db.session.commit()

            flash("Successfully associated license tier to partner", "success")
            current_app.logger.info(f"Successfully associated license tier {license_tier_id} with "
                                    f"partner service {management_service['id']}")

            return True

        except SQLAlchemyError as e:
            db.session.rollback()
            # Typo fix in the user-facing message: "associated" -> "associate".
            flash("Failed to associate license tier with partner service due to an internal error. "
                  "Please contact the System Administrator", "danger")
            current_app.logger.error(f"Error associating license tier {license_tier_id} with partner: {str(e)}")
            raise e
|
||||
143
common/services/entitlements/license_usage_services.py
Normal file
143
common/services/entitlements/license_usage_services.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from flask import session, current_app, flash
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.sql.expression import text
|
||||
|
||||
from common.extensions import db, cache_manager
|
||||
from common.models.entitlements import PartnerServiceLicenseTier, License, LicenseUsage, LicensePeriod, PeriodStatus
|
||||
from common.models.user import Partner, PartnerTenant
|
||||
from common.services.entitlements import LicensePeriodServices
|
||||
from common.utils.database import Database
|
||||
from common.utils.eveai_exceptions import EveAINoManagementPartnerService, EveAINoActiveLicense, \
|
||||
EveAIStorageQuotaExceeded, EveAIEmbeddingQuotaExceeded, EveAIInteractionQuotaExceeded, EveAILicensePeriodsExceeded, \
|
||||
EveAIException
|
||||
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from common.utils.security_utils import current_user_has_role
|
||||
|
||||
|
||||
class LicenseUsageServices:
    """Quota checks (storage, embedding, interaction tokens) against a tenant's current license period."""

    @staticmethod
    def check_storage_and_embedding_quota(tenant_id: int, file_size_mb: float) -> None:
        """
        Check if a tenant can add a new document without exceeding storage and embedding quotas.

        Args:
            tenant_id: ID of the tenant
            file_size_mb: Size of the file in MB

        Raises:
            EveAIStorageQuotaExceeded: If storage quota would be exceeded
            EveAIEmbeddingQuotaExceeded: If embedding quota would be exceeded
            EveAINoActiveLicense: If no active license is found
            EveAIException: For other errors
        """
        # Get active license period
        license_period = LicensePeriodServices.find_current_license_period_for_usage(tenant_id)
        # Early return if both overruns are allowed - no need to check usage at all
        if license_period.additional_storage_allowed and license_period.additional_embedding_allowed:
            return

        # Check storage quota only if overruns are not allowed
        if not license_period.additional_storage_allowed:
            LicenseUsageServices._validate_storage_quota(license_period, file_size_mb)

        # Check embedding quota only if overruns are not allowed
        if not license_period.additional_embedding_allowed:
            LicenseUsageServices._validate_embedding_quota(license_period, file_size_mb)

    @staticmethod
    def check_embedding_quota(tenant_id: int, file_size_mb: float) -> None:
        """
        Check if a tenant can re-embed a document without exceeding embedding quota.

        Args:
            tenant_id: ID of the tenant
            file_size_mb: Size of the file in MB

        Raises:
            EveAIEmbeddingQuotaExceeded: If embedding quota would be exceeded
            EveAINoActiveLicense: If no active license is found
            EveAIException: For other errors
        """
        # Get active license period
        license_period = LicensePeriodServices.find_current_license_period_for_usage(tenant_id)
        # Early return if embedding overruns are allowed - no need to check usage
        if license_period.additional_embedding_allowed:
            return

        # Check embedding quota
        LicenseUsageServices._validate_embedding_quota(license_period, file_size_mb)

    @staticmethod
    def check_interaction_quota(tenant_id: int) -> None:
        """
        Check if a tenant can execute a specialist without exceeding interaction quota. As it is impossible to estimate
        the number of interaction tokens, we only check if the interaction quota are exceeded. So we might have a
        limited overrun.

        Args:
            tenant_id: ID of the tenant

        Raises:
            EveAIInteractionQuotaExceeded: If interaction quota would be exceeded
            EveAINoActiveLicense: If no active license is found
            EveAIException: For other errors
        """
        # Get active license period
        license_period = LicensePeriodServices.find_current_license_period_for_usage(tenant_id)
        # Early return if interaction overruns are allowed - no need to check usage
        if license_period.additional_interaction_allowed:
            return

        # Convert tokens to M tokens and check interaction quota
        LicenseUsageServices._validate_interaction_quota(license_period)

    @staticmethod
    def _validate_storage_quota(license_period: LicensePeriod, additional_mb: float) -> None:
        """Check storage quota and raise exception if exceeded"""
        current_storage = license_period.license_usage.storage_mb_used or 0
        projected_storage = current_storage + additional_mb
        max_storage = license_period.max_storage_mb

        # Hard limit check (we only get here if overruns are NOT allowed)
        if projected_storage > max_storage:
            raise EveAIStorageQuotaExceeded(
                current_usage=current_storage,
                limit=max_storage,
                additional=additional_mb
            )

    @staticmethod
    def _validate_embedding_quota(license_period: LicensePeriod, additional_mb: float) -> None:
        """Check embedding quota and raise exception if exceeded"""
        current_embedding = license_period.license_usage.embedding_mb_used or 0
        projected_embedding = current_embedding + additional_mb
        max_embedding = license_period.included_embedding_mb

        # Hard limit check (we only get here if overruns are NOT allowed)
        if projected_embedding > max_embedding:
            raise EveAIEmbeddingQuotaExceeded(
                current_usage=current_embedding,
                limit=max_embedding,
                additional=additional_mb
            )

    @staticmethod
    def _validate_interaction_quota(license_period) -> None:
        """Check interaction quota and raise exception if exceeded (tokens in millions). We might have an overrun!"""
        # Bug fix: the original expression `used / 1_000_000 or 0` applied `or 0`
        # to the *quotient*, so a None usage value raised TypeError instead of
        # counting as 0. Guard the raw value first, like the other validators.
        used_tokens = license_period.license_usage.interaction_total_tokens_used or 0
        current_tokens = used_tokens / 1_000_000
        max_tokens = license_period.included_interaction_tokens

        # Hard limit check (we only get here if overruns are NOT allowed)
        if current_tokens > max_tokens:
            raise EveAIInteractionQuotaExceeded(
                current_usage=current_tokens,
                limit=max_tokens
            )
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
9
common/services/interaction/asset_services.py
Normal file
9
common/services/interaction/asset_services.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from common.models.interaction import EveAIAsset
|
||||
from common.extensions import minio_client
|
||||
|
||||
|
||||
class AssetServices:
    """Service layer for EveAI asset files."""

    @staticmethod
    def add_or_replace_asset_file(asset_id, file_data):
        # Fetches the asset (404s via Flask if it does not exist).
        # NOTE(review): implementation appears incomplete — `file_data` is never
        # stored and the fetched asset is unused; presumably the file should be
        # written to object storage. Confirm intended behavior.
        asset = EveAIAsset.query.get_or_404(asset_id)
|
||||
25
common/services/interaction/capsule_services.py
Normal file
25
common/services/interaction/capsule_services.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from common.models.interaction import EveAIDataCapsule
|
||||
from common.extensions import db
|
||||
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
||||
|
||||
|
||||
class CapsuleServices:
    """Service layer for chat-session data capsules."""

    @staticmethod
    def push_capsule_data(chat_session_id: str, type: str, type_version: str, configuration: dict, data: dict):
        """Create or refresh the data capsule for (chat_session_id, type, type_version).

        An existing capsule is updated in place; otherwise a new one is created.
        The change is committed and the capsule returned.
        """
        capsule = EveAIDataCapsule.query.filter_by(
            chat_session_id=chat_session_id, type=type, type_version=type_version
        ).first()

        if capsule is None:
            # No capsule yet for this session/type/version: create one
            capsule = EveAIDataCapsule(chat_session_id=chat_session_id, type=type,
                                       type_version=type_version,
                                       configuration=configuration, data=data)
            set_logging_information(capsule, dt.now(tz.utc))
            db.session.add(capsule)
        else:
            # Capsule already exists: overwrite its payload
            capsule.configuration = configuration
            capsule.data = data
            update_logging_information(capsule, dt.now(tz.utc))

        db.session.commit()
        return capsule
|
||||
239
common/services/interaction/specialist_services.py
Normal file
239
common/services/interaction/specialist_services.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import uuid
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from common.extensions import db, cache_manager
|
||||
from common.models.interaction import (
|
||||
Specialist, EveAIAgent, EveAITask, EveAITool
|
||||
)
|
||||
from common.utils.celery_utils import current_celery
|
||||
from common.utils.model_logging_utils import set_logging_information, update_logging_information
|
||||
|
||||
|
||||
class SpecialistServices:
    """Service layer for creating, initializing and executing specialists and their components."""

    @staticmethod
    def start_session() -> str:
        """Return a new, unique chat-session identifier."""
        return f"CHAT_SESSION_{uuid.uuid4()}"

    @staticmethod
    def execute_specialist(tenant_id, specialist_id, specialist_arguments, session_id, user_timezone) -> Dict[str, Any]:
        """
        Queue the asynchronous execution of a specialist on the 'llm_interactions' Celery queue.

        Args:
            tenant_id: ID of the tenant the specialist belongs to
            specialist_id: ID of the specialist to execute
            specialist_arguments: arguments forwarded unchanged to the task
            session_id: chat session identifier (see start_session)
            user_timezone: caller's timezone, forwarded to the task

        Returns:
            dict with 'task_id' (Celery task id) and 'status' ('queued')
        """
        current_app.logger.debug(f"Before sending task for {specialist_id} with arguments {specialist_arguments}")
        task = current_celery.send_task(
            'execute_specialist',
            args=[tenant_id,
                  specialist_id,
                  specialist_arguments,
                  session_id,
                  user_timezone,
                  ],
            queue='llm_interactions'
        )
        current_app.logger.debug(f"Task sent for {specialist_id}, task ID: {task.id}")

        return {
            'task_id': task.id,
            'status': 'queued',
        }

    @staticmethod
    def initialize_specialist(specialist_id: int, specialist_type: str, specialist_version: str):
        """
        Initialize an agentic specialist by creating all its components based on configuration.

        Args:
            specialist_id: ID of the specialist to initialize
            specialist_type: Type of the specialist
            specialist_version: Version of the specialist type to use

        Raises:
            ValueError: If specialist not found or invalid configuration
            SQLAlchemyError: If database operations fail
        """
        config = cache_manager.specialists_config_cache.get_config(specialist_type, specialist_version)
        if not config:
            raise ValueError(f"No configuration found for {specialist_type} version {specialist_version}")
        if config['framework'] == 'langchain':
            pass  # Langchain does not require additional items to be initialized. All configuration is in the specialist.

        specialist = Specialist.query.get(specialist_id)
        if not specialist:
            raise ValueError(f"Specialist with ID {specialist_id} not found")

        if config['framework'] == 'crewai':
            SpecialistServices.initialize_crewai_specialist(specialist, config)

    @staticmethod
    def initialize_crewai_specialist(specialist: Specialist, config: Dict[str, Any]):
        """Create every agent, task and tool declared in a crewai specialist's config, committing once at the end."""
        timestamp = dt.now(tz=tz.utc)

        try:
            # Initialize agents
            if 'agents' in config:
                for agent_config in config['agents']:
                    SpecialistServices._create_agent(
                        specialist_id=specialist.id,
                        agent_type=agent_config['type'],
                        agent_version=agent_config['version'],
                        name=agent_config.get('name'),
                        description=agent_config.get('description'),
                        timestamp=timestamp
                    )

            # Initialize tasks
            if 'tasks' in config:
                for task_config in config['tasks']:
                    SpecialistServices._create_task(
                        specialist_id=specialist.id,
                        task_type=task_config['type'],
                        task_version=task_config['version'],
                        name=task_config.get('name'),
                        description=task_config.get('description'),
                        timestamp=timestamp
                    )

            # Initialize tools
            if 'tools' in config:
                for tool_config in config['tools']:
                    SpecialistServices._create_tool(
                        specialist_id=specialist.id,
                        tool_type=tool_config['type'],
                        tool_version=tool_config['version'],
                        name=tool_config.get('name'),
                        description=tool_config.get('description'),
                        timestamp=timestamp
                    )

            db.session.commit()
            current_app.logger.info(f"Successfully initialized crewai specialist {specialist.id}")

        except SQLAlchemyError as e:
            db.session.rollback()
            current_app.logger.error(f"Database error initializing crewai specialist {specialist.id}: {str(e)}")
            raise
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Error initializing crewai specialist {specialist.id}: {str(e)}")
            raise

    @staticmethod
    def _create_agent(
            specialist_id: int,
            agent_type: str,
            agent_version: str,
            name: Optional[str] = None,
            description: Optional[str] = None,
            timestamp: Optional[dt] = None
    ) -> EveAIAgent:
        """Create an agent with the given configuration. Adds to the session; commit is the caller's job."""
        if timestamp is None:
            timestamp = dt.now(tz=tz.utc)

        # Get agent configuration from cache
        agent_config = cache_manager.agents_config_cache.get_config(agent_type, agent_version)

        agent = EveAIAgent(
            specialist_id=specialist_id,
            name=name or agent_config.get('name', agent_type),
            # Robustness fix: default 'metadata' to {} so a config without a
            # metadata section no longer raises AttributeError on None.
            description=description or agent_config.get('metadata', {}).get('description', ''),
            type=agent_type,
            type_version=agent_version,
            role=None,
            goal=None,
            backstory=None,
            tuning=False,
            configuration=None,
            arguments=None
        )

        set_logging_information(agent, timestamp)

        db.session.add(agent)
        # NOTE(review): agent.id is typically unset until flush/commit — confirm this log is meaningful here.
        current_app.logger.info(f"Created agent {agent.id} of type {agent_type}")
        return agent

    @staticmethod
    def _create_task(
            specialist_id: int,
            task_type: str,
            task_version: str,
            name: Optional[str] = None,
            description: Optional[str] = None,
            timestamp: Optional[dt] = None
    ) -> EveAITask:
        """Create a task with the given configuration. Adds to the session; commit is the caller's job."""
        if timestamp is None:
            timestamp = dt.now(tz=tz.utc)

        # Get task configuration from cache
        task_config = cache_manager.tasks_config_cache.get_config(task_type, task_version)

        task = EveAITask(
            specialist_id=specialist_id,
            name=name or task_config.get('name', task_type),
            # Robustness fix: tolerate configs without a 'metadata' section.
            description=description or task_config.get('metadata', {}).get('description', ''),
            type=task_type,
            type_version=task_version,
            task_description=None,
            expected_output=None,
            tuning=False,
            configuration=None,
            arguments=None,
            context=None,
            asynchronous=False,
        )

        set_logging_information(task, timestamp)

        db.session.add(task)
        current_app.logger.info(f"Created task {task.id} of type {task_type}")
        return task

    @staticmethod
    def _create_tool(
            specialist_id: int,
            tool_type: str,
            tool_version: str,
            name: Optional[str] = None,
            description: Optional[str] = None,
            timestamp: Optional[dt] = None
    ) -> EveAITool:
        """Create a tool with the given configuration. Adds to the session; commit is the caller's job."""
        if timestamp is None:
            timestamp = dt.now(tz=tz.utc)

        # Get tool configuration from cache
        tool_config = cache_manager.tools_config_cache.get_config(tool_type, tool_version)

        tool = EveAITool(
            specialist_id=specialist_id,
            name=name or tool_config.get('name', tool_type),
            # Robustness fix: tolerate configs without a 'metadata' section.
            description=description or tool_config.get('metadata', {}).get('description', ''),
            type=tool_type,
            type_version=tool_version,
            tuning=False,
            configuration=None,
            arguments=None,
        )

        set_logging_information(tool, timestamp)

        db.session.add(tool)
        current_app.logger.info(f"Created tool {tool.id} of type {tool_type}")
        return tool

    @staticmethod
    def get_specialist_system_field(specialist_id, config_name, system_name):
        """Get the value of a system field in a specialist's configuration. Returns the actual value, or None."""
        specialist = Specialist.query.get(specialist_id)
        if not specialist:
            raise ValueError(f"Specialist with ID {specialist_id} not found")
        config = cache_manager.specialists_config_cache.get_config(specialist.type, specialist.type_version)
        if not config:
            # Bug fix: Specialist carries its version in 'type_version' (used just
            # above); 'specialist.version' raised AttributeError in this error path.
            raise ValueError(f"No configuration found for {specialist.type} version {specialist.type_version}")
        potential_field = config.get(config_name, None)
        if potential_field:
            # NOTE(review): config values elsewhere in this class are plain dicts;
            # attribute access here (potential_field.type / .system_name) may need
            # to be .get('type') / .get('system_name') — confirm the cache's return type.
            if potential_field.type == 'system' and potential_field.system_name == system_name:
                return specialist.configuration.get(config_name, None)
        return None
|
||||
6
common/services/user/__init__.py
Normal file
6
common/services/user/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from common.services.user.user_services import UserServices
|
||||
from common.services.user.partner_services import PartnerServices
|
||||
from common.services.user.tenant_services import TenantServices
|
||||
from common.services.user.consent_services import ConsentServices
|
||||
|
||||
__all__ = ['UserServices', 'PartnerServices', 'TenantServices', 'ConsentServices']
|
||||
254
common/services/user/consent_services.py
Normal file
254
common/services/user/consent_services.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
from flask import current_app, request, session
|
||||
from flask_security import current_user
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
|
||||
|
||||
from common.extensions import db
|
||||
from common.models.user import TenantConsent, ConsentVersion, ConsentStatus, PartnerService, PartnerTenant, Tenant
|
||||
|
||||
|
||||
@dataclass
class TypeStatus:
    """Evaluation result for a single consent type for one tenant."""
    consent_type: str               # the consent type that was evaluated
    status: ConsentStatus           # outcome of the evaluation for this type
    active_version: Optional[str]   # currently-active consent version label, if known
    last_version: Optional[str]     # version the tenant last consented to, if any
|
||||
|
||||
|
||||
class ConsentServices:
|
||||
@staticmethod
def get_required_consent_types() -> List[str]:
    """Return the consent types this deployment requires, taken from app config."""
    configured = current_app.config.get("CONSENT_TYPES", [])
    return list(configured)
|
||||
|
||||
@staticmethod
def get_active_consent_version(consent_type: str) -> Optional[ConsentVersion]:
    """Return the currently-active ConsentVersion for *consent_type*, or None.

    The active version is the row with consent_valid_to IS NULL, newest by
    consent_valid_from. Database errors are logged and reported as None
    (best-effort lookup).
    """
    lookup = (ConsentVersion.query
              .filter_by(consent_type=consent_type, consent_valid_to=None)
              .order_by(desc(ConsentVersion.consent_valid_from)))
    try:
        return lookup.first()
    except SQLAlchemyError as e:
        current_app.logger.error(f"DB error in get_active_consent_version({consent_type}): {e}")
        return None
|
||||
|
||||
@staticmethod
def get_tenant_last_consent(tenant_id: int, consent_type: str) -> Optional[TenantConsent]:
    """Return the tenant's most recent TenantConsent row for *consent_type*, or None.

    "Most recent" means highest id. Database errors are logged and reported
    as None (best-effort lookup).
    """
    lookup = (TenantConsent.query
              .filter_by(tenant_id=tenant_id, consent_type=consent_type)
              .order_by(desc(TenantConsent.id)))
    try:
        return lookup.first()
    except SQLAlchemyError as e:
        current_app.logger.error(f"DB error in get_tenant_last_consent({tenant_id}, {consent_type}): {e}")
        return None
|
||||
|
||||
@staticmethod
def evaluate_type_status(tenant_id: int, consent_type: str) -> TypeStatus:
    """Evaluate one consent type for a tenant and return its TypeStatus.

    Outcomes:
    - UNKNOWN_CONSENT_VERSION: no active version configured, or the tenant's
      recorded version cannot be resolved.
    - NOT_CONSENTED: tenant never consented, or the grace period for an older
      version has expired.
    - RENEWAL_REQUIRED: tenant consented to an older version still within its
      transition window.
    - CONSENTED: tenant's last consent matches the active version.
    """
    active = ConsentServices.get_active_consent_version(consent_type)
    if not active:
        current_app.logger.error(f"No active ConsentVersion found for type {consent_type}")
        return TypeStatus(consent_type, ConsentStatus.UNKNOWN_CONSENT_VERSION, None, None)

    last = ConsentServices.get_tenant_last_consent(tenant_id, consent_type)
    if not last:
        return TypeStatus(consent_type, ConsentStatus.NOT_CONSENTED, active.consent_version, None)

    # If last consent equals active → CONSENTED
    if last.consent_version == active.consent_version:
        return TypeStatus(consent_type, ConsentStatus.CONSENTED, active.consent_version, last.consent_version)

    # Else: last refers to an older version; check its ConsentVersion to see grace period
    prev_cv = ConsentVersion.query.filter_by(consent_type=consent_type,
                                             consent_version=last.consent_version).first()
    if not prev_cv:
        # Data inconsistency: the tenant's recorded version no longer exists
        current_app.logger.error(f"Tenant {tenant_id} references unknown ConsentVersion {last.consent_version} for {consent_type}")
        return TypeStatus(consent_type, ConsentStatus.UNKNOWN_CONSENT_VERSION, active.consent_version, last.consent_version)

    if prev_cv.consent_valid_to:
        now = dt.now(tz.utc)
        if prev_cv.consent_valid_to >= now:
            # Within transition window
            return TypeStatus(consent_type, ConsentStatus.RENEWAL_REQUIRED, active.consent_version, last.consent_version)
        else:
            return TypeStatus(consent_type, ConsentStatus.NOT_CONSENTED, active.consent_version, last.consent_version)
    else:
        # Should not happen if a newer active exists; treat as unknown config
        current_app.logger.error(f"Previous ConsentVersion without valid_to while a newer active exists for {consent_type}")
        return TypeStatus(consent_type, ConsentStatus.UNKNOWN_CONSENT_VERSION, active.consent_version, last.consent_version)
|
||||
|
||||
@staticmethod
def aggregate_status(type_statuses: List[TypeStatus]) -> ConsentStatus:
    """Collapse per-type statuses into one overall ConsentStatus.

    Severity order (worst wins):
    UNKNOWN_CONSENT_VERSION > NOT_CONSENTED > RENEWAL_REQUIRED > CONSENTED.
    An empty list counts as fully consented.
    """
    severity = {
        ConsentStatus.UNKNOWN_CONSENT_VERSION: 4,
        ConsentStatus.NOT_CONSENTED: 3,
        ConsentStatus.RENEWAL_REQUIRED: 2,
        ConsentStatus.CONSENTED: 1,
    }
    if not type_statuses:
        return ConsentStatus.CONSENTED

    # Track the worst-ranked status; unknown statuses rank 0 (least severe).
    worst_status = type_statuses[0].status
    worst_rank = severity.get(worst_status, 0)
    for ts in type_statuses[1:]:
        rank = severity.get(ts.status, 0)
        if rank > worst_rank:
            worst_rank = rank
            worst_status = ts.status
    return worst_status
|
||||
|
||||
@staticmethod
def get_consent_status(tenant_id: int) -> ConsentStatus:
    """Evaluate every required consent type and return the aggregated status."""
    per_type = []
    for required_type in ConsentServices.get_required_consent_types():
        per_type.append(ConsentServices.evaluate_type_status(tenant_id, required_type))
    return ConsentServices.aggregate_status(per_type)
|
||||
|
||||
@staticmethod
def _is_tenant_admin_for(tenant_id: int) -> bool:
    """True when the current user is an authenticated Tenant Admin of *tenant_id*."""
    try:
        if not current_user.is_authenticated:
            return False
        if not current_user.has_roles('Tenant Admin'):
            return False
        # The user's tenant must match the one consent is recorded for.
        return getattr(current_user, 'tenant_id', None) == tenant_id
    except Exception:
        # current_user may be unavailable/anonymous outside a request context.
        return False
|
||||
|
||||
@staticmethod
def _is_management_partner_for(tenant_id: int) -> Tuple[bool, Optional[int], Optional[int]]:
    """Return (allowed, partner_id, partner_service_id) for management partner context."""
    denied = (False, None, None)
    try:
        is_partner_admin = (current_user.is_authenticated
                            and current_user.has_roles('Partner Admin'))
        if not is_partner_admin:
            return denied

        # Check PartnerTenant relationship via MANAGEMENT_SERVICE
        management_services = PartnerService.query.filter_by(type='MANAGEMENT_SERVICE').all()
        if not management_services:
            return denied

        service_ids = [service.id for service in management_services]
        link = (PartnerTenant.query
                .filter_by(tenant_id=tenant_id)
                .filter(PartnerTenant.partner_service_id.in_(service_ids))
                .first())
        if not link:
            return denied

        the_ps = PartnerService.query.get(link.partner_service_id)
        if the_ps:
            return True, the_ps.partner_id, the_ps.id
        return True, None, None
    except Exception as e:
        current_app.logger.error(f"Error in _is_management_partner_for: {e}")
        return denied
|
||||
|
||||
@staticmethod
def can_consent_on_behalf(tenant_id: int) -> Tuple[bool, str, Optional[int], Optional[int]]:
    """Decide whether the current user may record consent for *tenant_id*.

    Returns:
        (allowed, mode, partner_id, partner_service_id) where mode is
        'tenant_admin', 'management_partner' or 'none'.
    """
    # A tenant admin of the tenant itself always wins; no partner context needed.
    if ConsentServices._is_tenant_admin_for(tenant_id):
        return True, 'tenant_admin', None, None

    allowed, partner_id, partner_service_id = ConsentServices._is_management_partner_for(tenant_id)
    if not allowed:
        return False, 'none', None, None
    return True, 'management_partner', partner_id, partner_service_id
|
||||
|
||||
@staticmethod
def _resolve_consent_content(consent_type: str, version: str) -> Dict:
    """Resolve canonical file ref and hash for a consent document.

    Uses configurable base dir, type subpaths, and patch-dir strategy.

    Defaults:
        - base: 'content'
        - map: {'Data Privacy Agreement':'dpa','Terms & Conditions':'terms'}
        - strategy: 'major_minor' -> a.b.c => a.b/a.b.c.md
        - ext: '.md'

    Returns:
        Dict with 'canonical_document_ref' and 'content_hash' (sha256 hex,
        empty string when the file is missing or unreadable).
    """
    import hashlib
    from pathlib import Path

    cfg = current_app.config if current_app else {}
    base_dir = cfg.get('CONSENT_CONTENT_BASE_DIR', 'content')
    type_paths = cfg.get('CONSENT_TYPE_PATHS', {
        'Data Privacy Agreement': 'dpa',
        'Terms & Conditions': 'terms',
    })
    strategy = cfg.get('CONSENT_PATCH_DIR_STRATEGY', 'major_minor')
    ext = cfg.get('CONSENT_MARKDOWN_EXT', '.md')

    # Unknown types fall back to a slug of the type name.
    type_dir = type_paths.get(consent_type, consent_type.lower().replace(' ', '_'))
    subpath = ''
    filename = f"{version}{ext}"
    try:
        parts = version.split('.')
        if strategy == 'major_minor' and len(parts) >= 2:
            # a.b.c is stored under a.b/ with the patch made explicit (default '0').
            subpath = f"{parts[0]}.{parts[1]}"
            filename = f"{parts[0]}.{parts[1]}.{parts[2] if len(parts) > 2 else '0'}{ext}"
        # Build canonical path. FIX: use the computed filename — the previous
        # code interpolated a stray placeholder instead, leaving `filename` unused.
        if subpath:
            canonical_ref = f"{base_dir}/{type_dir}/{subpath}/{filename}"
        else:
            canonical_ref = f"{base_dir}/{type_dir}/{filename}"
    except Exception:
        canonical_ref = f"{base_dir}/{type_dir}/{version}{ext}"

    # Read file and hash; any failure yields an empty hash (best-effort audit data).
    content_hash = ''
    try:
        # project root = parent of app package
        root = Path(current_app.root_path).parent if current_app else Path('.')
        fpath = root / canonical_ref
        content_bytes = fpath.read_bytes() if fpath.exists() else b''
        content_hash = hashlib.sha256(content_bytes).hexdigest() if content_bytes else ''
    except Exception:
        content_hash = ''

    return {
        'canonical_document_ref': canonical_ref,
        'content_hash': content_hash,
    }
|
||||
|
||||
@staticmethod
def record_consent(tenant_id: int, consent_type: str) -> TenantConsent:
    """Record the current user's consent for *tenant_id* on the active version.

    Idempotent: an existing consent row for the active version is returned
    unchanged. Audit data (source IP, user agent, locale, document ref/hash)
    is stored in consent_data.

    Raises:
        ValueError: unknown consent type.
        RuntimeError: no active ConsentVersion configured for the type.
        PermissionError: current user may not consent for this tenant.
        IntegrityError / SQLAlchemyError: re-raised after rollback when the
            database insert fails and no concurrent row can be returned.
    """
    # Validate type
    if consent_type not in ConsentServices.get_required_consent_types():
        raise ValueError(f"Unknown consent type: {consent_type}")
    active = ConsentServices.get_active_consent_version(consent_type)
    if not active:
        raise RuntimeError(f"No active ConsentVersion for type {consent_type}")

    allowed, mode, partner_id, partner_service_id = ConsentServices.can_consent_on_behalf(tenant_id)
    if not allowed:
        raise PermissionError("Not authorized to record consent for this tenant")

    # Idempotency: if already consented for active version, return existing
    existing = (TenantConsent.query
                .filter_by(tenant_id=tenant_id, consent_type=consent_type, consent_version=active.consent_version)
                .first())
    if existing:
        return existing

    # Build consent_data with audit info.
    # First X-Forwarded-For hop is used when behind a proxy; falls back to remote_addr.
    ip = request.headers.get('X-Forwarded-For', '').split(',')[0].strip() or request.remote_addr or ''
    ua = request.headers.get('User-Agent', '')
    locale = session.get('locale') or request.accept_languages.best or ''
    content_meta = ConsentServices._resolve_consent_content(consent_type, active.consent_version)
    consent_data = {
        'source_ip': ip,
        'user_agent': ua,
        'locale': locale,
        **content_meta,
    }

    # user_id 0 is stored when no authenticated user id is available.
    tc = TenantConsent(
        tenant_id=tenant_id,
        partner_id=partner_id,
        partner_service_id=partner_service_id,
        user_id=getattr(current_user, 'id', None) or 0,
        consent_type=consent_type,
        consent_version=active.consent_version,
        consent_data=consent_data,
    )
    try:
        db.session.add(tc)
        db.session.commit()
        current_app.logger.info(f"Consent recorded: tenant={tenant_id}, type={consent_type}, version={active.consent_version}, mode={mode}, user={getattr(current_user, 'id', None)}")
        return tc
    except IntegrityError as e:
        db.session.rollback()
        # In case of race, fetch existing
        current_app.logger.warning(f"IntegrityError on consent insert, falling back: {e}")
        existing = (TenantConsent.query
                    .filter_by(tenant_id=tenant_id, consent_type=consent_type, consent_version=active.consent_version)
                    .first())
        if existing:
            return existing
        raise
    except SQLAlchemyError as e:
        db.session.rollback()
        current_app.logger.error(f"DB error in record_consent: {e}")
        raise
|
||||
52
common/services/user/partner_services.py
Normal file
52
common/services/user/partner_services.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from common.models.entitlements import PartnerServiceLicenseTier
|
||||
from common.utils.eveai_exceptions import EveAINoManagementPartnerService, EveAINoSessionPartner
|
||||
|
||||
|
||||
|
||||
class PartnerServices:
    """Service helpers scoped to the partner stored in the Flask session."""

    @staticmethod
    def get_allowed_license_tier_ids() -> List[int]:
        """
        Retrieve IDs of all License Tiers associated with the partner's management service

        Returns:
            List of license tier IDs

        Raises:
            EveAINoSessionPartner: If no partner is in the session
            EveAINoManagementPartnerService: If partner has no management service
        """
        if not session.get("partner", None):
            raise EveAINoSessionPartner()

        # Reuse the shared lookup for the partner's MANAGEMENT_SERVICE entry.
        management_service = PartnerServices.get_management_service()
        if management_service is None:
            raise EveAINoManagementPartnerService()

        # All license tiers linked to this management service.
        associations = PartnerServiceLicenseTier.query.filter_by(
            partner_service_id=management_service['id']
        ).all()
        return [association.license_tier_id for association in associations]

    @staticmethod
    def get_management_service() -> Dict[str, Any]:
        """Return the session partner's MANAGEMENT_SERVICE entry, or None when absent."""
        for service in session['partner']['services']:
            if service.get('type') == 'MANAGEMENT_SERVICE':
                return service
        return None
|
||||
|
||||
|
||||
|
||||
182
common/services/user/tenant_services.py
Normal file
182
common/services/user/tenant_services.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from flask import session, current_app
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from common.extensions import db, cache_manager
|
||||
from common.models.user import Partner, PartnerTenant, PartnerService, Tenant, TenantConsent, ConsentStatus, \
|
||||
ConsentVersion
|
||||
from common.utils.eveai_exceptions import EveAINoManagementPartnerService
|
||||
from common.utils.model_logging_utils import set_logging_information
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
|
||||
|
||||
class TenantServices:
    """Service helpers for tenant/partner relationships and tenant capabilities."""

    @staticmethod
    def associate_tenant_with_partner(tenant_id):
        """Associate a tenant with a partner.

        Creates a PartnerTenant link between *tenant_id* and the session
        partner's MANAGEMENT_SERVICE.

        Raises:
            EveAINoManagementPartnerService: if the session partner has no
                MANAGEMENT_SERVICE entry.
            SQLAlchemyError: re-raised after rollback on database failure.
        """
        try:
            partner_id = session['partner']['id']
            # Get partner service (MANAGEMENT_SERVICE type)
            partner = Partner.query.get(partner_id)
            if not partner:
                # Session references a partner that no longer exists; nothing to link.
                return

            # Find a management service for this partner
            management_service = next((service for service in session['partner']['services']
                                       if service.get('type') == 'MANAGEMENT_SERVICE'), None)

            if not management_service:
                current_app.logger.error(f"No Management Service defined for partner {partner_id} "
                                         f"while associating tenant {tenant_id} with partner.")
                raise EveAINoManagementPartnerService()

            # Create the association
            tenant_partner = PartnerTenant(
                partner_service_id=management_service['id'],
                tenant_id=tenant_id,
            )
            set_logging_information(tenant_partner, dt.now(tz.utc))

            db.session.add(tenant_partner)
            db.session.commit()

        except SQLAlchemyError as e:
            db.session.rollback()
            current_app.logger.error(f"Error associating tenant {tenant_id} with partner: {str(e)}")
            raise e

    @staticmethod
    def get_available_types_for_tenant(tenant_id: int, config_type: str) -> Dict[str, Dict[str, str]]:
        """
        Get available configuration types for a tenant based on partner relationships

        Args:
            tenant_id: The tenant ID
            config_type: The configuration type ('specialists', 'agents', 'tasks', etc.)

        Returns:
            Dictionary of available types for the tenant

        Raises:
            ValueError: if *config_type* is not one of the supported kinds.
        """
        # Get the appropriate cache handler based on config_type
        cache_handler = None
        if config_type == 'specialists':
            cache_handler = cache_manager.specialists_types_cache
        elif config_type == 'agents':
            cache_handler = cache_manager.agents_types_cache
        elif config_type == 'tasks':
            cache_handler = cache_manager.tasks_types_cache
        elif config_type == 'tools':
            cache_handler = cache_manager.tools_types_cache
        elif config_type == 'catalogs':
            cache_handler = cache_manager.catalogs_types_cache
        elif config_type == 'retrievers':
            cache_handler = cache_manager.retrievers_types_cache
        else:
            raise ValueError(f"Unsupported config type: {config_type}")

        # Get all types with their metadata (including partner info)
        all_types = cache_handler.get_types()

        # Filter to include:
        # 1. Types with no partner (global)
        # 2. Types with partners that have a SPECIALIST_SERVICE relationship with this tenant
        available_partners = TenantServices.get_tenant_partner_specialist_denominators(tenant_id)

        available_types = {
            type_id: info for type_id, info in all_types.items()
            if info.get('partner') is None or info.get('partner') in available_partners
        }

        return available_types

    @staticmethod
    def get_tenant_partner_specialist_denominators(tenant_id: int) -> List[str]:
        """
        Get names of partners that have a SPECIALIST_SERVICE relationship with this tenant, that can be used for
        filtering configurations.

        Args:
            tenant_id: The tenant ID

        Returns:
            List of partner denominators; empty on database errors (best-effort).
        """
        # Find all PartnerTenant relationships for this tenant
        partner_service_denominators = []
        try:
            # Get all partner services of type SPECIALIST_SERVICE
            specialist_services = (
                PartnerService.query
                .filter_by(type='SPECIALIST_SERVICE')
                .all()
            )

            if not specialist_services:
                return []

            # Find tenant relationships with these services
            partner_tenants = (
                PartnerTenant.query
                .filter_by(tenant_id=tenant_id)
                .filter(PartnerTenant.partner_service_id.in_([svc.id for svc in specialist_services]))
                .all()
            )

            # Get the partner denominators from each linked service's configuration
            for pt in partner_tenants:
                partner_service = (
                    PartnerService.query
                    .filter_by(id=pt.partner_service_id)
                    .first()
                )

                if partner_service:
                    partner_service_denominators.append(partner_service.configuration.get("specialist_denominator", ""))

        except SQLAlchemyError as e:
            # Best-effort: log and return whatever was collected so far.
            current_app.logger.error(f"Database error retrieving partner names: {str(e)}")

        return partner_service_denominators

    @staticmethod
    def can_use_specialist_type(tenant_id: int, specialist_type: str) -> bool:
        """
        Check if a tenant can use a specific specialist type

        Args:
            tenant_id: The tenant ID
            specialist_type: The specialist type ID

        Returns:
            True if the tenant can use the specialist type, False otherwise
        """
        # Get the specialist type definition
        try:
            specialist_types = cache_manager.specialists_types_cache.get_types()
            specialist_def = specialist_types.get(specialist_type)

            if not specialist_def:
                return False

            # If it's a global specialist, anyone can use it
            if specialist_def.get('partner') is None:
                return True

            # If it's a partner-specific specialist, check if tenant has access
            partner_name = specialist_def.get('partner')
            available_partners = TenantServices.get_tenant_partner_specialist_denominators(tenant_id)

            return partner_name in available_partners

        except Exception as e:
            # Deny access on any unexpected failure rather than crashing callers.
            current_app.logger.error(f"Error checking specialist type access: {str(e)}")
            return False

    @staticmethod
    def get_consent_status(tenant_id: int) -> ConsentStatus:
        """Return the tenant's aggregated consent status."""
        # Delegate to centralized ConsentService to ensure consistent logic
        from common.services.user.consent_services import ConsentServices
        return ConsentServices.get_consent_status(tenant_id)
|
||||
95
common/services/user/user_services.py
Normal file
95
common/services/user/user_services.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from flask import session
|
||||
|
||||
from common.models.user import Partner, Role, PartnerTenant
|
||||
|
||||
from common.utils.eveai_exceptions import EveAIRoleAssignmentException
|
||||
from common.utils.security_utils import current_user_has_role
|
||||
|
||||
|
||||
class UserServices:
    """Authorization helpers for role assignment and tenant management."""

    @staticmethod
    def get_assignable_roles():
        """Retrieve roles the logged-in user may assign, given the session's active tenant.

        Returns:
            List of (role_id, role_name) tuples.
        """
        # FIX: tolerate a missing 'tenant' session entry instead of raising AttributeError.
        current_tenant_id = (session.get('tenant') or {}).get('id', None)
        effective_role_names = []
        if current_tenant_id == 1:
            # Tenant 1: only Super Users are assignable there.
            if current_user_has_role("Super User"):
                effective_role_names.append("Super User")
        elif current_tenant_id:
            if current_user_has_role("Tenant Admin"):
                effective_role_names.append("Tenant Admin")
            if current_user_has_role("Partner Admin") or current_user_has_role("Super User"):
                effective_role_names.append("Tenant Admin")
            if session.get('partner'):
                # Partner Admin is only assignable on the partner's own tenant.
                if session.get('partner').get('tenant_id') == current_tenant_id:
                    effective_role_names.append("Partner Admin")
        effective_role_names = list(set(effective_role_names))
        effective_roles = [(role.id, role.name) for role in
                           Role.query.filter(Role.name.in_(effective_role_names)).all()]
        return effective_roles

    @staticmethod
    def validate_role_assignments(role_ids):
        """Return True when every role id in *role_ids* is assignable by the current user."""
        assignable_roles = UserServices.get_assignable_roles()
        assignable_role_ids = {role[0] for role in assignable_roles}
        return set(role_ids).issubset(assignable_role_ids)

    @staticmethod
    def can_user_edit_tenant(tenant_id) -> bool:
        """True when the current user may edit *tenant_id*.

        Allowed for Super Users, for a Partner Admin whose partner owns the
        tenant, or whose partner manages it via a MANAGEMENT_SERVICE link.
        """
        if current_user_has_role('Super User'):
            return True
        if not current_user_has_role('Partner Admin'):
            return False
        partner = session.get('partner', None)
        if not partner:
            # FIX: previously session['partner'] was dereferenced even when absent.
            return False
        if partner["tenant_id"] == tenant_id:
            return True
        partner_service = next((service for service in partner['services']
                                if service.get('type') == 'MANAGEMENT_SERVICE'), None)
        if not partner_service:
            return False
        partner_tenant = PartnerTenant.query.filter(
            PartnerTenant.tenant_id == tenant_id,
            PartnerTenant.partner_service_id == partner_service['id'],
        ).first()
        return partner_tenant is not None

    @staticmethod
    def _management_service_permission(permission: str) -> bool:
        """Check a boolean permission on the session partner's MANAGEMENT_SERVICE."""
        partner = session.get('partner') or {}
        partner_service = next((service for service in partner.get('services', [])
                                if service.get('type') == 'MANAGEMENT_SERVICE'), None)
        if not partner_service:
            return False
        # FIX: 'permissions' may be absent; treat missing as no permission
        # instead of calling .get() on None.
        permissions = partner_service.get('permissions') or {}
        return permissions.get(permission, False)

    @staticmethod
    def can_user_create_tenant() -> bool:
        """True when the current user may create tenants."""
        if current_user_has_role('Super User'):
            return True
        if current_user_has_role('Partner Admin'):
            return UserServices._management_service_permission('can_create_tenant')
        return False

    @staticmethod
    def can_user_assign_license() -> bool:
        """True when the current user may assign licenses."""
        if current_user_has_role('Super User'):
            return True
        if current_user_has_role('Partner Admin'):
            return UserServices._management_service_permission('can_assign_license')
        return False
|
||||
|
||||
108
common/services/utils/human_answer_services.py
Normal file
108
common/services/utils/human_answer_services.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from flask import current_app, session
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from common.utils.business_event import BusinessEvent
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.utils.model_utils import get_template
|
||||
from eveai_chat_workers.outputs.globals.a2q_output.q_a_output_v1_0 import A2QOutput
|
||||
from eveai_chat_workers.outputs.globals.q_a_output.q_a_output_v1_0 import QAOutput
|
||||
|
||||
|
||||
class HumanAnswerServices:
    """LLM-backed helpers for interpreting a human's free-text answer to a question."""

    @staticmethod
    def check_affirmative_answer(tenant_id: int, question: str, answer: str, language_iso: str) -> bool:
        """Return the LLM's verdict on whether *answer* affirms *question*."""
        return HumanAnswerServices._check_answer(tenant_id, question, answer, language_iso, "check_affirmative_answer",
                                                 "Check Affirmative Answer")

    @staticmethod
    def check_additional_information(tenant_id: int, question: str, answer: str, language_iso: str) -> bool:
        """Return the LLM's verdict on whether *answer* carries additional information for *question*."""
        result = HumanAnswerServices._check_answer(tenant_id, question, answer, language_iso,
                                                   "check_additional_information", "Check Additional Information")

        return result

    @staticmethod
    def get_answer_to_question(tenant_id: int, question: str, answer: str, language_iso: str) -> str:
        """Extract the answer to *question* from *answer* via the LLM, traced in a span."""

        language = HumanAnswerServices._process_arguments(question, answer, language_iso)
        span_name = "Get Answer To Question"
        template_name = "get_answer_to_question"

        if not current_event:
            # No ambient business event: open one so the span has a parent.
            with BusinessEvent('Answer Check Service', tenant_id):
                with current_event.create_span(span_name):
                    return HumanAnswerServices._get_answer_to_question_logic(question, answer, language, template_name)
        else:
            # FIX: this branch previously created its span with the copy-pasted
            # name 'Check Affirmative Answer' instead of span_name.
            with current_event.create_span(span_name):
                return HumanAnswerServices._get_answer_to_question_logic(question, answer, language, template_name)

    @staticmethod
    def _check_answer(tenant_id: int, question: str, answer: str, language_iso: str, template_name: str,
                      span_name: str) -> bool:
        """Validate arguments, then run the named check template inside a tracing span."""
        language = HumanAnswerServices._process_arguments(question, answer, language_iso)
        if not current_event:
            with BusinessEvent('Answer Check Service', tenant_id):
                with current_event.create_span(span_name):
                    return HumanAnswerServices._check_answer_logic(question, answer, language, template_name)
        else:
            with current_event.create_span(span_name):
                return HumanAnswerServices._check_answer_logic(question, answer, language, template_name)

    @staticmethod
    def _check_answer_logic(question: str, answer: str, language: str, template_name: str) -> bool:
        """Build and invoke the structured-output LLM chain for a yes/no check."""
        prompt_params = {
            'question': question,
            'answer': answer,
            'language': language,
        }

        template, llm = get_template(template_name)
        check_answer_prompt = ChatPromptTemplate.from_template(template)
        setup = RunnablePassthrough()

        output_schema = QAOutput
        structured_llm = llm.with_structured_output(output_schema)

        chain = (setup | check_answer_prompt | structured_llm)

        raw_answer = chain.invoke(prompt_params)

        return raw_answer.answer

    @staticmethod
    def _get_answer_to_question_logic(question: str, answer: str, language: str, template_name: str) \
            -> str:
        """Build and invoke the structured-output LLM chain that extracts an answer."""
        prompt_params = {
            'question': question,
            'answer': answer,
            'language': language,
        }

        template, llm = get_template(template_name)
        check_answer_prompt = ChatPromptTemplate.from_template(template)
        setup = RunnablePassthrough()

        output_schema = A2QOutput
        structured_llm = llm.with_structured_output(output_schema)

        chain = (setup | check_answer_prompt | structured_llm)

        raw_answer = chain.invoke(prompt_params)

        return raw_answer.answer

    @staticmethod
    def _process_arguments(question, answer, language_iso: str) -> str:
        """Validate inputs and map the ISO 639-1 code to a supported language name.

        Raises:
            ValueError: on empty question/answer/language or an unsupported language code.
        """
        if language_iso.strip() == '':
            raise ValueError("Language cannot be empty")
        language = current_app.config.get('SUPPORTED_LANGUAGE_ISO639_1_LOOKUP').get(language_iso)
        if language is None:
            raise ValueError(f"Unsupported language: {language_iso}")
        if question.strip() == '':
            raise ValueError("Question cannot be empty")
        if answer.strip() == '':
            raise ValueError("Answer cannot be empty")

        return language
|
||||
203
common/services/utils/translation_services.py
Normal file
203
common/services/utils/translation_services.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from flask import session
|
||||
|
||||
from common.extensions import cache_manager
|
||||
from common.utils.business_event import BusinessEvent
|
||||
from common.utils.business_event_context import current_event
|
||||
|
||||
class TranslationServices:
    """Translation helpers for configuration dictionaries and plain text."""

    @staticmethod
    def translate_config(tenant_id: int, config_data: Dict[str, Any], field_config: str, target_language: str,
                         source_language: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """
        Translate a configuration based on a field configuration.

        Args:
            tenant_id: Identifier of the tenant the translation is performed for.
            config_data: A dictionary or JSON string (then converted to a dictionary) with configuration data
            field_config: The name of a field configuration (e.g. 'fields')
            target_language: The language to translate into
            source_language: Optional, the source language of the configuration
            context: Optional, a specific context for the translation

        Returns:
            A dictionary with the translated configuration
        """
        # NOTE(review): if config_data is passed as a JSON string, the .get calls
        # below will fail — the str->dict conversion only happens inside
        # _translate_config. Callers appear to pass dicts; verify.
        config_type = config_data.get('type', 'Unknown')
        config_version = config_data.get('version', 'Unknown')
        span_name = f"{config_type}-{config_version}-{field_config}"

        if current_event:
            with current_event.create_span(span_name):
                translated_config = TranslationServices._translate_config(tenant_id, config_data, field_config,
                                                                          target_language, source_language, context)
                return translated_config
        else:
            # No ambient business event: open one so the span has a parent.
            with BusinessEvent('Config Translation Service', tenant_id):
                with current_event.create_span(span_name):
                    translated_config = TranslationServices._translate_config(tenant_id, config_data, field_config,
                                                                              target_language, source_language, context)
                    return translated_config

    @staticmethod
    def _translate_config(tenant_id: int, config_data: Dict[str, Any], field_config: str, target_language: str,
                          source_language: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Core translation routine; see translate_config for the argument contract."""

        # Make sure we are working with a dictionary
        if isinstance(config_data, str):
            config_data = json.loads(config_data)

        # Deep copy of the original data so the caller's input is never mutated
        translated_config = copy.deepcopy(config_data)

        # Type and version, used for error logging below
        config_type = config_data.get('type', 'Unknown')
        config_version = config_data.get('version', 'Unknown')

        if field_config in config_data:
            fields = config_data[field_config]

            # Use the metadata description as context when no explicit context was given
            description_context = ""
            if not context and 'metadata' in config_data and 'description' in config_data['metadata']:
                description_context = config_data['metadata']['description']

            # Helper functions
            def is_nonempty_str(val: Any) -> bool:
                return isinstance(val, str) and val.strip() != ''

            def safe_translate(text: str, ctx: Optional[str]):
                # Returns the translated text, or None on failure (errors are logged on the event)
                try:
                    res = cache_manager.translation_cache.get_translation(
                        text=text,
                        target_lang=target_language,
                        source_lang=source_language,
                        context=ctx
                    )
                    return res.translated_text if res else None
                except Exception as e:
                    if current_event:
                        current_event.log_error('translation_error', {
                            'tenant_id': tenant_id,
                            'config_type': config_type,
                            'config_version': config_version,
                            'field_config': field_config,
                            'error': str(e)
                        })
                    return None

            # Matches well-formed <tag>...</tag> pairs; used to verify inline
            # markup survives translation unchanged.
            tag_pair_pattern = re.compile(r'<([a-zA-Z][\w-]*)>[\s\S]*?<\/\1>')

            def extract_tag_counts(text: str) -> Dict[str, int]:
                counts: Dict[str, int] = {}
                for m in tag_pair_pattern.finditer(text or ''):
                    tag = m.group(1)
                    counts[tag] = counts.get(tag, 0) + 1
                return counts

            def tags_valid(source: str, translated: str) -> bool:
                return extract_tag_counts(source) == extract_tag_counts(translated)

            # Counters (currently only accumulated, not reported)
            meta_consentRich_translated_count = 0
            meta_aria_translated_count = 0
            meta_inline_tags_invalid_after_translation_count = 0

            # Loop over every field in the configuration
            for field_name, field_data in fields.items():
                # Translate name if present and non-empty (strings only)
                if 'name' in field_data and is_nonempty_str(field_data['name']):
                    field_context = context if context else description_context
                    t = safe_translate(field_data['name'], field_context)
                    if t:
                        translated_config[field_config][field_name]['name'] = t

                if 'title' in field_data and is_nonempty_str(field_data.get('title')):
                    field_context = context if context else description_context
                    t = safe_translate(field_data['title'], field_context)
                    if t:
                        translated_config[field_config][field_name]['title'] = t

                # Translate description if present and non-empty
                if 'description' in field_data and is_nonempty_str(field_data.get('description')):
                    field_context = context if context else description_context
                    t = safe_translate(field_data['description'], field_context)
                    if t:
                        translated_config[field_config][field_name]['description'] = t

                # Translate context if present and non-empty
                if 'context' in field_data and is_nonempty_str(field_data.get('context')):
                    t = safe_translate(field_data['context'], context)
                    if t:
                        translated_config[field_config][field_name]['context'] = t

                # Translate allowed_values if present and non-empty (string items only)
                if 'allowed_values' in field_data and isinstance(field_data['allowed_values'], list) and field_data['allowed_values']:
                    translated_allowed_values = []
                    for allowed_value in field_data['allowed_values']:
                        if is_nonempty_str(allowed_value):
                            t = safe_translate(allowed_value, context)
                            # Keep the original value when translation fails
                            translated_allowed_values.append(t if t else allowed_value)
                        else:
                            translated_allowed_values.append(allowed_value)
                    if translated_allowed_values:
                        translated_config[field_config][field_name]['allowed_values'] = translated_allowed_values

                # Translate meta.consentRich and meta.aria*
                meta = field_data.get('meta')
                if isinstance(meta, dict):
                    # consentRich
                    if is_nonempty_str(meta.get('consentRich')):
                        consent_ctx = (context if context else description_context) or ''
                        consent_ctx = f"Consent rich text with inline tags. Keep tag names intact and translate only inner text. {consent_ctx}".strip()
                        t = safe_translate(meta['consentRich'], consent_ctx)
                        if t and tags_valid(meta['consentRich'], t):
                            translated_config[field_config][field_name].setdefault('meta', {})['consentRich'] = t
                            meta_consentRich_translated_count += 1
                        else:
                            if t and not tags_valid(meta['consentRich'], t) and current_event:
                                src_counts = extract_tag_counts(meta['consentRich'])
                                dst_counts = extract_tag_counts(t)
                                current_event.log_error('inline_tags_validation_failed', {
                                    'tenant_id': tenant_id,
                                    'config_type': config_type,
                                    'config_version': config_version,
                                    'field_config': field_config,
                                    'field_name': field_name,
                                    'target_language': target_language,
                                    'source_tag_counts': src_counts,
                                    'translated_tag_counts': dst_counts
                                })
                                meta_inline_tags_invalid_after_translation_count += 1
                            # fallback: keep original (already in deep copy)
                    # aria*
                    for k, v in list(meta.items()):
                        if isinstance(k, str) and k.startswith('aria') and is_nonempty_str(v):
                            aria_ctx = (context if context else description_context) or ''
                            aria_ctx = f"ARIA label for accessibility. Short, imperative, descriptive. Form '{config_type} {config_version}', field '{field_name}'. {aria_ctx}".strip()
                            t2 = safe_translate(v, aria_ctx)
                            if t2:
                                translated_config[field_config][field_name].setdefault('meta', {})[k] = t2
                                meta_aria_translated_count += 1

        return translated_config

    @staticmethod
    def translate(tenant_id: int, text: str, target_language: str, source_language: Optional[str] = None,
                  context: Optional[str] = None)-> str:
        """Translate a single text via the translation cache, traced in a span."""
        if current_event:
            with current_event.create_span('Translation'):
                translation_cache = cache_manager.translation_cache.get_translation(text, target_language,
                                                                                    source_language, context)
                return translation_cache.translated_text
        else:
            # No ambient business event: open one so the span has a parent.
            with BusinessEvent('Translation Service', tenant_id):
                with current_event.create_span('Translation'):
                    translation_cache = cache_manager.translation_cache.get_translation(text, target_language,
                                                                                        source_language, context)
                    return translation_cache.translated_text
|
||||
14
common/services/utils/version_services.py
Normal file
14
common/services/utils/version_services.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from flask import current_app
|
||||
|
||||
class VersionServices:
    """Helpers for working with dotted version strings."""

    @staticmethod
    def split_version(full_version: str) -> tuple[str, str]:
        """Split a dotted version into ('major.minor', 'patch').

        Examples: '1.2.3' -> ('1.2', '3'); '1.2' -> ('1.2', '');
        '1' -> ('1', ''). Any segments beyond the third are discarded.
        """
        segments = full_version.split(".")
        if len(segments) >= 3:
            return '.'.join(segments[:2]), segments[2]
        if len(segments) == 2:
            # Two segments: the whole string is the major.minor part.
            return '.'.join(segments[:2]), ''
        # Single segment (or empty): no patch component available.
        return full_version, ''
|
||||
22
common/templates/error/401.html
Normal file
22
common/templates/error/401.html
Normal file
@@ -0,0 +1,22 @@
|
||||
<!-- Static error page (presumably rendered for HTTP 401 — confirm against the app's error-handler registration). No template variables. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Unauthorized</title>
    <style>
        body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; background:#f7f7f9; color:#222; }
        .wrap { max-width: 720px; margin: 10vh auto; background:#fff; border:1px solid #e5e7eb; border-radius:12px; padding:32px; box-shadow: 0 8px 24px rgba(0,0,0,0.06); }
        h1 { margin: 0 0 8px; font-size: 28px; }
        p { margin: 0 0 16px; line-height:1.6; }
        a.btn { display:inline-block; padding:10px 16px; background:#2c3e50; color:#fff; text-decoration:none; border-radius:8px; }
    </style>
</head>
<body>
<main class="wrap">
    <h1>Not authorized</h1>
    <p>Your session may have expired or this action is not permitted.</p>
    <p><a class="btn" href="/">Go to home</a></p>
</main>
</body>
</html>
|
||||
22
common/templates/error/403.html
Normal file
22
common/templates/error/403.html
Normal file
@@ -0,0 +1,22 @@
|
||||
<!-- Static error page (presumably rendered for HTTP 403 — confirm against the app's error-handler registration). No template variables. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Forbidden</title>
    <style>
        body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; background:#f7f7f9; color:#222; }
        .wrap { max-width: 720px; margin: 10vh auto; background:#fff; border:1px solid #e5e7eb; border-radius:12px; padding:32px; box-shadow: 0 8px 24px rgba(0,0,0,0.06); }
        h1 { margin: 0 0 8px; font-size: 28px; }
        p { margin: 0 0 16px; line-height:1.6; }
        a.btn { display:inline-block; padding:10px 16px; background:#2c3e50; color:#fff; text-decoration:none; border-radius:8px; }
    </style>
</head>
<body>
<main class="wrap">
    <h1>Access forbidden</h1>
    <p>You don't have permission to access this resource.</p>
    <p><a class="btn" href="/">Go to home</a></p>
</main>
</body>
</html>
|
||||
22
common/templates/error/404.html
Normal file
22
common/templates/error/404.html
Normal file
@@ -0,0 +1,22 @@
|
||||
<!-- Static error page (presumably rendered for HTTP 404 — confirm against the app's error-handler registration). No template variables. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Page not found</title>
    <style>
        body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; background:#f7f7f9; color:#222; }
        .wrap { max-width: 720px; margin: 10vh auto; background:#fff; border:1px solid #e5e7eb; border-radius:12px; padding:32px; box-shadow: 0 8px 24px rgba(0,0,0,0.06); }
        h1 { margin: 0 0 8px; font-size: 28px; }
        p { margin: 0 0 16px; line-height:1.6; }
        a.btn { display:inline-block; padding:10px 16px; background:#2c3e50; color:#fff; text-decoration:none; border-radius:8px; }
    </style>
</head>
<body>
<main class="wrap">
    <h1>Page not found</h1>
    <p>The page you are looking for doesn’t exist or has been moved.</p>
    <p><a class="btn" href="/">Go to home</a></p>
</main>
</body>
</html>
|
||||
22
common/templates/error/500.html
Normal file
22
common/templates/error/500.html
Normal file
@@ -0,0 +1,22 @@
|
||||
<!-- Static error page (presumably rendered for HTTP 500 — confirm against the app's error-handler registration). No template variables. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Something went wrong</title>
    <style>
        body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; background:#f7f7f9; color:#222; }
        .wrap { max-width: 720px; margin: 10vh auto; background:#fff; border:1px solid #e5e7eb; border-radius:12px; padding:32px; box-shadow: 0 8px 24px rgba(0,0,0,0.06); }
        h1 { margin: 0 0 8px; font-size: 28px; }
        p { margin: 0 0 16px; line-height:1.6; }
        a.btn { display:inline-block; padding:10px 16px; background:#2c3e50; color:#fff; text-decoration:none; border-radius:8px; }
    </style>
</head>
<body>
<main class="wrap">
    <h1>We’re sorry — something went wrong</h1>
    <p>Please try again later. If the issue persists, contact support.</p>
    <p><a class="btn" href="/">Go to home</a></p>
</main>
</body>
</html>
|
||||
22
common/templates/error/generic.html
Normal file
22
common/templates/error/generic.html
Normal file
@@ -0,0 +1,22 @@
|
||||
<!-- Static fallback error page for unclassified errors. No template variables. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Error</title>
    <style>
        body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; background:#f7f7f9; color:#222; }
        .wrap { max-width: 720px; margin: 10vh auto; background:#fff; border:1px solid #e5e7eb; border-radius:12px; padding:32px; box-shadow: 0 8px 24px rgba(0,0,0,0.06); }
        h1 { margin: 0 0 8px; font-size: 28px; }
        p { margin: 0 0 16px; line-height:1.6; }
        a.btn { display:inline-block; padding:10px 16px; background:#2c3e50; color:#fff; text-decoration:none; border-radius:8px; }
    </style>
</head>
<body>
<main class="wrap">
    <h1>Oops! Something went wrong</h1>
    <p>Please try again. If the issue persists, contact support.</p>
    <p><a class="btn" href="/">Go to home</a></p>
</main>
</body>
</html>
|
||||
BIN
common/utils/.DS_Store
vendored
BIN
common/utils/.DS_Store
vendored
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
45
common/utils/asset_manifest.py
Normal file
45
common/utils/asset_manifest.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Dict
|
||||
|
||||
# Default manifest path inside app images; override with env
|
||||
DEFAULT_MANIFEST_PATH = os.environ.get(
|
||||
'EVEAI_STATIC_MANIFEST_PATH',
|
||||
'/app/config/static-manifest/manifest.json'
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_manifest(manifest_path: str = DEFAULT_MANIFEST_PATH) -> Dict[str, str]:
|
||||
try:
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def resolve_asset(logical_path: str, manifest_path: str = DEFAULT_MANIFEST_PATH) -> str:
    """
    Map a logical asset path (e.g. 'dist/chat-client.js') to the hashed path
    found in the Parcel manifest. If not found or the manifest is missing,
    return the original logical path for graceful fallback.
    """
    if not logical_path:
        return logical_path

    manifest = _load_manifest(manifest_path)

    # Parcel manifests are inconsistent about key prefixes, so probe a few
    # equivalent spellings of the same logical path in order.
    probes = (
        logical_path,
        logical_path.lstrip('/'),
        logical_path.replace('static/', ''),
        logical_path.replace('dist/', ''),
    )

    for probe in probes:
        if probe in manifest:
            return manifest[probe]

    return logical_path
|
||||
12
common/utils/asset_utils.py
Normal file
12
common/utils/asset_utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from common.extensions import cache_manager, minio_client, db
|
||||
from common.models.interaction import EveAIAsset
|
||||
from common.utils.model_logging_utils import set_logging_information
|
||||
|
||||
|
||||
|
||||
|
||||
656
common/utils/business_event.py
Normal file
656
common/utils/business_event.py
Normal file
@@ -0,0 +1,656 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from prometheus_client import Counter, Histogram, Gauge, Summary, push_to_gateway, REGISTRY
|
||||
|
||||
from .business_event_context import BusinessEventContext
|
||||
from common.models.entitlements import BusinessEventLog
|
||||
from common.extensions import db
|
||||
from .celery_utils import current_celery
|
||||
from common.utils.prometheus_utils import sanitize_label
|
||||
|
||||
# Standard duration buckets for all histograms
DURATION_BUCKETS = [0.1, 0.5, 1, 2.5, 5, 10, 15, 30, 60, 120, 240, 360, float('inf')]

# Prometheus metrics for business events.
# Trace-level metrics: one trace corresponds to one BusinessEvent instance.
TRACE_COUNTER = Counter(
    'eveai_business_events_total',
    'Total number of business events triggered',
    ['tenant_id', 'event_type', 'specialist_id', 'specialist_type', 'specialist_type_version']
)

TRACE_DURATION = Histogram(
    'eveai_business_events_duration_seconds',
    'Duration of business events in seconds',
    ['tenant_id', 'event_type', 'specialist_id', 'specialist_type', 'specialist_type_version'],
    buckets=DURATION_BUCKETS
)

CONCURRENT_TRACES = Gauge(
    'eveai_business_events_concurrent',
    'Number of concurrent business events',
    ['tenant_id', 'event_type', 'specialist_id', 'specialist_type', 'specialist_type_version']
)

# Span-level metrics: spans are opened via BusinessEvent.create_span(); the
# span name is exported under the 'activity_name' label.
SPAN_COUNTER = Counter(
    'eveai_business_spans_total',
    'Total number of spans within business events',
    ['tenant_id', 'event_type', 'activity_name', 'specialist_id', 'specialist_type', 'specialist_type_version']
)

SPAN_DURATION = Histogram(
    'eveai_business_spans_duration_seconds',
    'Duration of spans within business events in seconds',
    ['tenant_id', 'event_type', 'activity_name', 'specialist_id', 'specialist_type', 'specialist_type_version'],
    buckets=DURATION_BUCKETS
)

CONCURRENT_SPANS = Gauge(
    'eveai_business_spans_concurrent',
    'Number of concurrent spans within business events',
    ['tenant_id', 'event_type', 'activity_name', 'specialist_id', 'specialist_type', 'specialist_type_version']
)

# LLM Usage metrics: fed by BusinessEvent.update_llm_metrics().
# 'token_type' distinguishes total/prompt/completion counts.
LLM_TOKENS_COUNTER = Counter(
    'eveai_llm_tokens_total',
    'Total number of tokens used in LLM calls',
    ['tenant_id', 'event_type', 'interaction_type', 'token_type', 'specialist_id', 'specialist_type',
     'specialist_type_version']
)

LLM_DURATION = Histogram(
    'eveai_llm_duration_seconds',
    'Duration of LLM API calls in seconds',
    ['tenant_id', 'event_type', 'interaction_type', 'specialist_id', 'specialist_type', 'specialist_type_version'],
    buckets=DURATION_BUCKETS
)

LLM_CALLS_COUNTER = Counter(
    'eveai_llm_calls_total',
    'Total number of LLM API calls',
    ['tenant_id', 'event_type', 'interaction_type', 'specialist_id', 'specialist_type', 'specialist_type_version']
)
|
||||
|
||||
|
||||
class BusinessEvent:
|
||||
# The BusinessEvent class itself is a context manager, but it doesn't use the @contextmanager decorator.
|
||||
# Instead, it defines __enter__ and __exit__ methods explicitly. This is because we're doing something a bit more
|
||||
# complex - we're interacting with the BusinessEventContext and the _business_event_stack.
|
||||
|
||||
    def __init__(self, event_type: str, tenant_id: int, **kwargs):
        """Create a business-event trace and register its start in Prometheus.

        Args:
            event_type: Human-readable name of the business event.
            tenant_id: Tenant the event belongs to.
            **kwargs: Optional correlation attributes: document_version_id,
                document_version_file_size, chat_session_id, interaction_id,
                specialist_id, specialist_type, specialist_type_version.
        """
        self.event_type = event_type
        self.tenant_id = tenant_id
        # Unique id for the whole trace; individual spans get their own ids.
        self.trace_id = str(uuid.uuid4())
        self.span_id = None
        self.span_name = None
        self.parent_span_id = None
        self.document_version_id = kwargs.get('document_version_id')
        self.document_version_file_size = kwargs.get('document_version_file_size')
        self.chat_session_id = kwargs.get('chat_session_id')
        self.interaction_id = kwargs.get('interaction_id')
        self.specialist_id = kwargs.get('specialist_id')
        self.specialist_type = kwargs.get('specialist_type')
        self.specialist_type_version = kwargs.get('specialist_type_version')
        self.environment = os.environ.get("FLASK_ENV", "development")
        self.span_counter = 0
        # Stack of (span_id, span_name, parent_span_id) saved by create_span
        # so nested spans restore correctly.
        self.spans = []
        # Accumulated LLM usage for the current span; reset between spans.
        self.llm_metrics = {
            'total_tokens': 0,
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'nr_of_pages': 0,
            'total_time': 0,
            'call_count': 0,
            'interaction_type': None
        }
        # Structured log records buffered until _flush_log_buffer() runs.
        self._log_buffer = []

        # Prometheus label values must be strings
        self.tenant_id_str = str(self.tenant_id)
        self.event_type_str = sanitize_label(self.event_type)
        self.specialist_id_str = str(self.specialist_id) if self.specialist_id else ""
        self.specialist_type_str = str(self.specialist_type) if self.specialist_type else ""
        self.specialist_type_version_str = sanitize_label(str(self.specialist_type_version)) \
            if self.specialist_type_version else ""
        self.span_name_str = ""

        # Increment concurrent events gauge when initialized
        CONCURRENT_TRACES.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc()

        # Increment trace counter
        TRACE_COUNTER.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc()

        self._push_to_gateway()
|
||||
|
||||
    def update_attribute(self, attribute: str, value: Any):
        """Set an existing event attribute and keep its Prometheus label in sync.

        Raises:
            AttributeError: if *attribute* is not an attribute of this event.
        """
        if hasattr(self, attribute):
            setattr(self, attribute, value)
            # Update string versions for Prometheus labels if needed.
            # NOTE(review): attributes without a label counterpart (e.g.
            # chat_session_id) fall through this chain silently after being
            # set — confirm that is the intended behavior.
            if attribute == 'specialist_id':
                self.specialist_id_str = str(value) if value else ""
            elif attribute == 'specialist_type':
                self.specialist_type_str = str(value) if value else ""
            elif attribute == 'specialist_type_version':
                self.specialist_type_version_str = sanitize_label(str(value)) if value else ""
            elif attribute == 'tenant_id':
                self.tenant_id_str = str(value)
            elif attribute == 'event_type':
                self.event_type_str = sanitize_label(value)
            elif attribute == 'span_name':
                self.span_name_str = sanitize_label(value)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute}'")
|
||||
|
||||
    def update_llm_metrics(self, metrics: dict):
        """Accumulate one LLM call's usage into this event and into Prometheus.

        Args:
            metrics: Per-call usage dict. Optional keys: total_tokens,
                prompt_tokens, completion_tokens, nr_of_pages, time_elapsed.
                Required key: interaction_type (indexed directly below, so a
                missing key raises KeyError — TODO confirm that is intended).
        """
        self.llm_metrics['total_tokens'] += metrics.get('total_tokens', 0)
        self.llm_metrics['prompt_tokens'] += metrics.get('prompt_tokens', 0)
        self.llm_metrics['completion_tokens'] += metrics.get('completion_tokens', 0)
        self.llm_metrics['nr_of_pages'] += metrics.get('nr_of_pages', 0)
        self.llm_metrics['total_time'] += metrics.get('time_elapsed', 0)
        self.llm_metrics['call_count'] += 1
        self.llm_metrics['interaction_type'] = metrics['interaction_type']

        # Track in Prometheus metrics
        interaction_type_str = sanitize_label(metrics['interaction_type']) if metrics['interaction_type'] else ""

        # Track token usage (one counter series per token_type)
        LLM_TOKENS_COUNTER.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            interaction_type=interaction_type_str,
            token_type='total',
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc(metrics.get('total_tokens', 0))

        LLM_TOKENS_COUNTER.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            interaction_type=interaction_type_str,
            token_type='prompt',
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc(metrics.get('prompt_tokens', 0))

        LLM_TOKENS_COUNTER.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            interaction_type=interaction_type_str,
            token_type='completion',
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc(metrics.get('completion_tokens', 0))

        # Track duration
        LLM_DURATION.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            interaction_type=interaction_type_str,
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).observe(metrics.get('time_elapsed', 0))

        # Track call count
        LLM_CALLS_COUNTER.labels(
            tenant_id=self.tenant_id_str,
            event_type=self.event_type_str,
            interaction_type=interaction_type_str,
            specialist_id=self.specialist_id_str,
            specialist_type=self.specialist_type_str,
            specialist_type_version=self.specialist_type_version_str
        ).inc()

        self._push_to_gateway()
|
||||
|
||||
def reset_llm_metrics(self):
|
||||
self.llm_metrics['total_tokens'] = 0
|
||||
self.llm_metrics['prompt_tokens'] = 0
|
||||
self.llm_metrics['completion_tokens'] = 0
|
||||
self.llm_metrics['nr_of_pages'] = 0
|
||||
self.llm_metrics['total_time'] = 0
|
||||
self.llm_metrics['call_count'] = 0
|
||||
self.llm_metrics['interaction_type'] = None
|
||||
|
||||
@contextmanager
|
||||
def create_span(self, span_name: str):
|
||||
# The create_span method is designed to be used as a context manager. We want to perform some actions when
|
||||
# entering the span (like setting the span ID and name) and some actions when exiting the span (like removing
|
||||
# these temporary attributes). The @contextmanager decorator allows us to write this method in a way that
|
||||
# clearly separates the "entry" and "exit" logic, with the yield statement in between.
|
||||
|
||||
parent_span_id = self.span_id
|
||||
self.span_counter += 1
|
||||
new_span_id = str(uuid.uuid4())
|
||||
|
||||
# Save the current span info
|
||||
self.spans.append((self.span_id, self.span_name, self.parent_span_id))
|
||||
|
||||
# Set the new span info
|
||||
self.span_id = new_span_id
|
||||
self.span_name = span_name
|
||||
self.span_name_str = sanitize_label(span_name) if span_name else ""
|
||||
self.parent_span_id = parent_span_id
|
||||
|
||||
# Track start time for the span
|
||||
span_start_time = time.time()
|
||||
|
||||
# Increment span metrics - using span_name as activity_name for metrics
|
||||
SPAN_COUNTER.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).inc()
|
||||
|
||||
# Increment concurrent spans gauge
|
||||
CONCURRENT_SPANS.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).inc()
|
||||
|
||||
self._push_to_gateway()
|
||||
|
||||
self.log(f"Start")
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Calculate total time for this span
|
||||
span_total_time = time.time() - span_start_time
|
||||
|
||||
# Observe span duration
|
||||
SPAN_DURATION.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).observe(span_total_time)
|
||||
|
||||
# Decrement concurrent spans gauge
|
||||
CONCURRENT_SPANS.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).dec()
|
||||
|
||||
self._push_to_gateway()
|
||||
|
||||
if self.llm_metrics['call_count'] > 0:
|
||||
self.log_final_metrics()
|
||||
self.reset_llm_metrics()
|
||||
self.log(f"End", extra_fields={'span_duration': span_total_time})
|
||||
# Restore the previous span info
|
||||
if self.spans:
|
||||
self.span_id, self.span_name, self.parent_span_id = self.spans.pop()
|
||||
self.span_name_str = sanitize_label(span_name) if span_name else ""
|
||||
else:
|
||||
self.span_id = None
|
||||
self.span_name = None
|
||||
self.parent_span_id = None
|
||||
self.span_name_str = ""
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_span_async(self, span_name: str):
|
||||
"""Async version of create_span using async context manager"""
|
||||
parent_span_id = self.span_id
|
||||
self.span_counter += 1
|
||||
new_span_id = str(uuid.uuid4())
|
||||
|
||||
# Save the current span info
|
||||
self.spans.append((self.span_id, self.span_name, self.parent_span_id))
|
||||
|
||||
# Set the new span info
|
||||
self.span_id = new_span_id
|
||||
self.span_name = span_name
|
||||
self.span_name_str = sanitize_label(span_name) if span_name else ""
|
||||
self.parent_span_id = parent_span_id
|
||||
|
||||
# Track start time for the span
|
||||
span_start_time = time.time()
|
||||
|
||||
# Increment span metrics - using span_name as activity_name for metrics
|
||||
SPAN_COUNTER.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).inc()
|
||||
|
||||
# Increment concurrent spans gauge
|
||||
CONCURRENT_SPANS.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).inc()
|
||||
|
||||
self._push_to_gateway()
|
||||
|
||||
self.log(f"Start")
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Calculate total time for this span
|
||||
span_total_time = time.time() - span_start_time
|
||||
|
||||
# Observe span duration
|
||||
SPAN_DURATION.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).observe(span_total_time)
|
||||
|
||||
# Decrement concurrent spans gauge
|
||||
CONCURRENT_SPANS.labels(
|
||||
tenant_id=self.tenant_id_str,
|
||||
event_type=self.event_type_str,
|
||||
activity_name=self.span_name_str,
|
||||
specialist_id=self.specialist_id_str,
|
||||
specialist_type=self.specialist_type_str,
|
||||
specialist_type_version=self.specialist_type_version_str
|
||||
).dec()
|
||||
|
||||
self._push_to_gateway()
|
||||
|
||||
if self.llm_metrics['call_count'] > 0:
|
||||
self.log_final_metrics()
|
||||
self.reset_llm_metrics()
|
||||
self.log(f"End", extra_fields={'span_duration': span_total_time})
|
||||
# Restore the previous span info
|
||||
if self.spans:
|
||||
self.span_id, self.span_name, self.parent_span_id = self.spans.pop()
|
||||
self.span_name_str = sanitize_label(span_name) if span_name else ""
|
||||
else:
|
||||
self.span_id = None
|
||||
self.span_name = None
|
||||
self.parent_span_id = None
|
||||
self.span_name_str = ""
|
||||
|
||||
def log(self, message: str, level: str = 'info', extra_fields: Dict[str, Any] = None):
|
||||
log_data = {
|
||||
'timestamp': dt.now(tz=tz.utc),
|
||||
'event_type': self.event_type,
|
||||
'tenant_id': self.tenant_id,
|
||||
'trace_id': self.trace_id,
|
||||
'span_id': self.span_id,
|
||||
'span_name': self.span_name,
|
||||
'parent_span_id': self.parent_span_id,
|
||||
'document_version_id': self.document_version_id,
|
||||
'document_version_file_size': self.document_version_file_size,
|
||||
'chat_session_id': self.chat_session_id,
|
||||
'interaction_id': self.interaction_id,
|
||||
'specialist_id': self.specialist_id,
|
||||
'specialist_type': self.specialist_type,
|
||||
'specialist_type_version': self.specialist_type_version,
|
||||
'environment': self.environment,
|
||||
'message': message,
|
||||
}
|
||||
# Add any extra fields
|
||||
if extra_fields:
|
||||
for key, value in extra_fields.items():
|
||||
# For span/trace duration, use the llm_metrics_total_time field
|
||||
if key == 'span_duration' or key == 'trace_duration':
|
||||
log_data['llm_metrics_total_time'] = value
|
||||
else:
|
||||
log_data[key] = value
|
||||
|
||||
self._log_buffer.append(log_data)
|
||||
|
||||
def log_llm_metrics(self, metrics: dict, level: str = 'info'):
|
||||
self.update_llm_metrics(metrics)
|
||||
message = "LLM Metrics"
|
||||
logger = logging.getLogger('business_events')
|
||||
log_data = {
|
||||
'timestamp': dt.now(tz=tz.utc),
|
||||
'event_type': self.event_type,
|
||||
'tenant_id': self.tenant_id,
|
||||
'trace_id': self.trace_id,
|
||||
'span_id': self.span_id,
|
||||
'span_name': self.span_name,
|
||||
'parent_span_id': self.parent_span_id,
|
||||
'document_version_id': self.document_version_id,
|
||||
'document_version_file_size': self.document_version_file_size,
|
||||
'chat_session_id': self.chat_session_id,
|
||||
'interaction_id': self.interaction_id,
|
||||
'specialist_id': self.specialist_id,
|
||||
'specialist_type': self.specialist_type,
|
||||
'specialist_type_version': self.specialist_type_version,
|
||||
'environment': self.environment,
|
||||
'llm_metrics_total_tokens': metrics.get('total_tokens', 0),
|
||||
'llm_metrics_prompt_tokens': metrics.get('prompt_tokens', 0),
|
||||
'llm_metrics_completion_tokens': metrics.get('completion_tokens', 0),
|
||||
'llm_metrics_nr_of_pages': metrics.get('nr_of_pages', 0),
|
||||
'llm_metrics_total_time': metrics.get('time_elapsed', 0),
|
||||
'llm_interaction_type': metrics['interaction_type'],
|
||||
'message': message,
|
||||
}
|
||||
self._log_buffer.append(log_data)
|
||||
|
||||
def log_final_metrics(self, level: str = 'info'):
|
||||
logger = logging.getLogger('business_events')
|
||||
message = "Final LLM Metrics"
|
||||
log_data = {
|
||||
'timestamp': dt.now(tz=tz.utc),
|
||||
'event_type': self.event_type,
|
||||
'tenant_id': self.tenant_id,
|
||||
'trace_id': self.trace_id,
|
||||
'span_id': self.span_id,
|
||||
'span_name': self.span_name,
|
||||
'parent_span_id': self.parent_span_id,
|
||||
'document_version_id': self.document_version_id,
|
||||
'document_version_file_size': self.document_version_file_size,
|
||||
'chat_session_id': self.chat_session_id,
|
||||
'interaction_id': self.interaction_id,
|
||||
'specialist_id': self.specialist_id,
|
||||
'specialist_type': self.specialist_type,
|
||||
'specialist_type_version': self.specialist_type_version,
|
||||
'environment': self.environment,
|
||||
'llm_metrics_total_tokens': self.llm_metrics['total_tokens'],
|
||||
'llm_metrics_prompt_tokens': self.llm_metrics['prompt_tokens'],
|
||||
'llm_metrics_completion_tokens': self.llm_metrics['completion_tokens'],
|
||||
'llm_metrics_nr_of_pages': self.llm_metrics['nr_of_pages'],
|
||||
'llm_metrics_total_time': self.llm_metrics['total_time'],
|
||||
'llm_metrics_call_count': self.llm_metrics['call_count'],
|
||||
'llm_interaction_type': self.llm_metrics['interaction_type'],
|
||||
'message': message,
|
||||
}
|
||||
self._log_buffer.append(log_data)
|
||||
|
||||
    @staticmethod
    def _direct_db_persist(log_entries: List[Dict[str, Any]]):
        """Fallback method to directly persist logs to DB if async fails.

        Converts each buffered record into a BusinessEventLog row and bulk
        inserts them in one transaction. On any failure the transaction is
        rolled back and the error is logged; entries are lost in that case.
        """
        try:
            db_entries = []
            for entry in log_entries:
                # NOTE: pop() mutates the entry dicts; keys not listed here
                # (e.g. arbitrary extra_fields) are simply left behind/ignored.
                event_log = BusinessEventLog(
                    timestamp=entry.pop('timestamp'),
                    event_type=entry.pop('event_type'),
                    tenant_id=entry.pop('tenant_id'),
                    trace_id=entry.pop('trace_id'),
                    span_id=entry.pop('span_id', None),
                    span_name=entry.pop('span_name', None),
                    parent_span_id=entry.pop('parent_span_id', None),
                    document_version_id=entry.pop('document_version_id', None),
                    document_version_file_size=entry.pop('document_version_file_size', None),
                    chat_session_id=entry.pop('chat_session_id', None),
                    interaction_id=entry.pop('interaction_id', None),
                    specialist_id=entry.pop('specialist_id', None),
                    specialist_type=entry.pop('specialist_type', None),
                    specialist_type_version=entry.pop('specialist_type_version', None),
                    environment=entry.pop('environment', None),
                    llm_metrics_total_tokens=entry.pop('llm_metrics_total_tokens', None),
                    llm_metrics_prompt_tokens=entry.pop('llm_metrics_prompt_tokens', None),
                    llm_metrics_completion_tokens=entry.pop('llm_metrics_completion_tokens', None),
                    llm_metrics_total_time=entry.pop('llm_metrics_total_time', None),
                    llm_metrics_call_count=entry.pop('llm_metrics_call_count', None),
                    llm_interaction_type=entry.pop('llm_interaction_type', None),
                    message=entry.pop('message', None)
                )
                db_entries.append(event_log)

            # Bulk insert
            db.session.bulk_save_objects(db_entries)
            db.session.commit()
        except Exception as e:
            logger = logging.getLogger('business_events')
            logger.error(f"Failed to persist logs directly to DB: {e}")
            db.session.rollback()
|
||||
|
||||
def _flush_log_buffer(self):
    """Flush the log buffer to the database via a Celery task"""
    buffered = self._log_buffer
    if buffered:
        try:
            # Hand the buffered entries over to the async persistence task
            current_celery.send_task(
                'persist_business_events',
                args=[buffered],
                queue='entitlements'  # Or dedicated log queue
            )
        except Exception as e:
            # Fallback to direct DB write in case of issues with Celery
            logging.getLogger('business_events').error(
                f"Failed to send logs to Celery. Falling back to direct DB: {e}"
            )
            self._direct_db_persist(buffered)

    # Start over with a fresh buffer after sending
    self._log_buffer = []
|
||||
|
||||
def _push_to_gateway(self):
    """Push metrics to the Prometheus Push Gateway.

    A grouping key (pod, namespace, process) is attached so pushes from
    different pods/processes do not overwrite each other.
    """
    try:
        config = current_app.config
        # Determine grouping labels; fall back to component/env names, then 'dev'.
        grouping_key = {
            'instance': config.get('POD_NAME', config.get('COMPONENT_NAME', 'dev')),
            'namespace': config.get('POD_NAMESPACE', config.get('FLASK_ENV', 'dev')),
            'process': str(os.getpid()),
        }

        push_to_gateway(
            config['PUSH_GATEWAY_URL'],
            job=config['COMPONENT_NAME'],
            registry=REGISTRY,
            grouping_key=grouping_key,
        )
    except Exception as e:
        current_app.logger.error(f"Failed to push metrics to Prometheus Push Gateway: {e}")
|
||||
|
||||
def __enter__(self):
    """Start the trace timer, log the start, and enter a BusinessEventContext."""
    self.trace_start_time = time.time()
    self.log(f'Starting Trace for {self.event_type}')
    context = BusinessEventContext(self)
    return context.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the trace: record metrics, push them, flush buffered logs, leave the context."""
    trace_total_time = time.time() - self.trace_start_time

    # Both metrics share the exact same label set.
    label_values = dict(
        tenant_id=self.tenant_id_str,
        event_type=self.event_type_str,
        specialist_id=self.specialist_id_str,
        specialist_type=self.specialist_type_str,
        specialist_type_version=self.specialist_type_version_str,
    )

    # Record trace duration
    TRACE_DURATION.labels(**label_values).observe(trace_total_time)

    # Decrement concurrent traces gauge
    CONCURRENT_TRACES.labels(**label_values).dec()

    self._push_to_gateway()

    # Emit accumulated LLM metrics once per trace, then reset them
    if self.llm_metrics['call_count'] > 0:
        self.log_final_metrics()
        self.reset_llm_metrics()

    self.log(f'Ending Trace for {self.event_type}', extra_fields={'trace_duration': trace_total_time})
    self._flush_log_buffer()

    return BusinessEventContext(self).__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def __aenter__(self):
    """Async variant of __enter__: start the trace timer and enter the event context."""
    self.trace_start_time = time.time()
    self.log(f'Starting Trace for {self.event_type}')
    context = BusinessEventContext(self)
    return await context.__aenter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async variant of __exit__: record metrics, flush logs, leave the event context."""
    trace_total_time = time.time() - self.trace_start_time

    # Both metrics share the exact same label set.
    label_values = dict(
        tenant_id=self.tenant_id_str,
        event_type=self.event_type_str,
        specialist_id=self.specialist_id_str,
        specialist_type=self.specialist_type_str,
        specialist_type_version=self.specialist_type_version_str,
    )

    # Record trace duration
    TRACE_DURATION.labels(**label_values).observe(trace_total_time)

    # Decrement concurrent traces gauge
    CONCURRENT_TRACES.labels(**label_values).dec()

    self._push_to_gateway()

    # Emit accumulated LLM metrics once per trace, then reset them
    if self.llm_metrics['call_count'] > 0:
        self.log_final_metrics()
        self.reset_llm_metrics()

    self.log(f'Ending Trace for {self.event_type}', extra_fields={'trace_duration': trace_total_time})
    self._flush_log_buffer()
    return await BusinessEventContext(self).__aexit__(exc_type, exc_val, exc_tb)
|
||||
52
common/utils/business_event_context.py
Normal file
52
common/utils/business_event_context.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from werkzeug.local import LocalProxy, LocalStack
|
||||
import asyncio
|
||||
from contextvars import ContextVar
|
||||
import contextvars
|
||||
|
||||
# Werkzeug LocalStack kept for backward compatibility with synchronous code paths.
_business_event_stack = LocalStack()

# ContextVar mirrors the stack for async support; default None means "no active event".
_business_event_contextvar = ContextVar('business_event', default=None)
|
||||
|
||||
|
||||
def _get_current_event():
    """Return the active business event, preferring the async contextvar over the sync stack.

    Raises:
        RuntimeError: If no business event context is active.
    """
    # Try contextvar first (for async code paths)
    event = _business_event_contextvar.get()
    if event is None:
        # Fall back to the stack-based approach (for sync code paths)
        event = _business_event_stack.top
    if event is None:
        raise RuntimeError("No business event context found. Are you sure you're in a business event?")
    return event
|
||||
|
||||
|
||||
# Module-level proxy: attribute access resolves to the currently active event on each use.
current_event = LocalProxy(_get_current_event)
|
||||
|
||||
|
||||
class BusinessEventContext:
    """Context manager that installs a business event into both the LocalStack
    (sync) and the contextvar (async) for the duration of the block."""

    def __init__(self, event):
        self.event = event
        self._token = None  # contextvar reset token, set on entry

    def _activate(self):
        """Push the event onto both context mechanisms and return it."""
        _business_event_stack.push(self.event)
        self._token = _business_event_contextvar.set(self.event)
        return self.event

    def _deactivate(self):
        """Pop the stack and restore the contextvar to its pre-entry value."""
        _business_event_stack.pop()
        if self._token is not None:
            _business_event_contextvar.reset(self._token)

    def __enter__(self):
        return self._activate()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._deactivate()

    async def __aenter__(self):
        return self._activate()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._deactivate()
|
||||
192
common/utils/cache/base.py
vendored
Normal file
192
common/utils/cache/base.py
vendored
Normal file
@@ -0,0 +1,192 @@
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type
|
||||
from dataclasses import dataclass
|
||||
from flask import Flask, current_app
|
||||
from dogpile.cache import CacheRegion
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
T = TypeVar('T') # Generic type parameter for cached data
|
||||
|
||||
|
||||
@dataclass
class CacheKey:
    """
    Composite cache key built from multiple named components.

    Provides structured, deterministic key generation for cache entries.

    Attributes:
        components (Dict[str, Any]): Dictionary of key components and their values

    Example:
        key = CacheKey({'tenant_id': 123, 'user_id': 456})
        str(key) -> "tenant_id=123:user_id=456"
    """
    components: Dict[str, Any]

    def __str__(self) -> str:
        """
        Render the components as a deterministic string.

        Components are sorted alphabetically so the same mapping always
        produces the same key.
        """
        ordered = sorted(self.components.items())
        return ":".join(f"{name}={value}" for name, value in ordered)
|
||||
|
||||
|
||||
class CacheHandler(Generic[T]):
    """
    Base cache handler implementation providing structured caching functionality.
    Uses generics to ensure type safety of cached data.

    Type Parameters:
        T: Type of data being cached

    Attributes:
        region (CacheRegion): Dogpile cache region for storage
        prefix (str): Prefix for all cache keys managed by this handler
    """

    def __init__(self, region: CacheRegion, prefix: str):
        self.region = region
        self.prefix = prefix
        self._key_components = []  # Required key component names (set via configure_keys)

    def _region_name(self) -> str:
        """Region name used as the leading segment of every cache key."""
        return getattr(self.region, 'name', 'default_region')

    @abstractmethod
    def _to_cache_data(self, instance: T) -> Any:
        """
        Convert the data to a cacheable format for internal use.

        Args:
            instance: The data to be cached.

        Returns:
            A serializable format of the instance.
        """
        raise NotImplementedError

    @abstractmethod
    def _from_cache_data(self, data: Any, **kwargs) -> T:
        """
        Convert cached data back to usable format for internal use.

        Args:
            data: The cached data.
            **kwargs: Additional context.

        Returns:
            The data in its usable format.
        """
        raise NotImplementedError

    @abstractmethod
    def _should_cache(self, value: T) -> bool:
        """
        Validate if the value should be cached for internal use.

        Args:
            value: The value to be cached.

        Returns:
            True if the value should be cached, False otherwise.
        """
        raise NotImplementedError

    def configure_keys(self, *components: str):
        """
        Configure required components for cache key generation.

        Args:
            *components: Required key component names

        Returns:
            self for method chaining
        """
        self._key_components = components
        return self

    def generate_key(self, **identifiers) -> str:
        """
        Generate a cache key from provided identifiers.

        Args:
            **identifiers: Key-value pairs for key components

        Returns:
            Formatted cache key string

        Raises:
            ValueError: If required components are missing
        """
        missing = set(self._key_components) - set(identifiers.keys())
        if missing:
            raise ValueError(f"Missing key components: {missing}")

        key = CacheKey({k: identifiers[k] for k in self._key_components})
        return f"{self._region_name()}:{self.prefix}:{str(key)}"

    def get(self, creator_func, **identifiers) -> T:
        """
        Get or create a cached value.

        Args:
            creator_func: Function to create value if not cached
            **identifiers: Key components for cache key

        Returns:
            Cached or newly created value
        """
        cache_key = self.generate_key(**identifiers)

        def creator():
            # Only runs on a cache miss; the result is stored in serialized form.
            instance = creator_func(**identifiers)
            return self._to_cache_data(instance)

        cached_data = self.region.get_or_create(
            cache_key,
            creator,
            should_cache_fn=self._should_cache
        )

        return self._from_cache_data(cached_data, **identifiers)

    def invalidate(self, **identifiers):
        """
        Invalidate a specific cache entry.

        Args:
            **identifiers: Key components for the cache entry
        """
        cache_key = self.generate_key(**identifiers)
        self.region.delete(cache_key)

    def invalidate_by_model(self, model: str, **identifiers):
        """
        Invalidate cache entry based on model changes.

        Args:
            model: Changed model name
            **identifiers: Model instance identifiers
        """
        try:
            self.invalidate(**identifiers)
        except ValueError:
            pass  # Skip if cache key can't be generated from provided identifiers

    def invalidate_region(self):
        """
        Invalidate all cache entries within this region.

        Deletes all keys that start with the region prefix.

        Raises:
            NotImplementedError: If the backend is not Redis.
        """
        # FIX: keys are generated with the region *name* (see generate_key), so the
        # deletion pattern must use the same prefix; f"{self.region}" (object repr)
        # would never match any stored key.
        pattern = f"{self._region_name()}:{self.prefix}:*"

        # Assuming Redis backend with dogpile, use direct Redis access
        if hasattr(self.region.backend, 'client'):
            redis_client = self.region.backend.client
            # NOTE(review): KEYS blocks Redis on large keyspaces; consider scan_iter.
            keys_to_delete = redis_client.keys(pattern)
            if keys_to_delete:
                redis_client.delete(*keys_to_delete)
        else:
            # Fallback for other backends
            raise NotImplementedError("Region invalidation is only supported for Redis backend.")
|
||||
545
common/utils/cache/config_cache.py
vendored
Normal file
545
common/utils/cache/config_cache.py
vendored
Normal file
@@ -0,0 +1,545 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from packaging import version
|
||||
import os
|
||||
from flask import current_app
|
||||
|
||||
from common.utils.cache.base import CacheHandler, CacheKey
|
||||
from config.type_defs import agent_types, task_types, tool_types, specialist_types, retriever_types, prompt_types, \
|
||||
catalog_types, partner_service_types, processor_types, customisation_types, specialist_form_types, capsule_types
|
||||
|
||||
|
||||
def is_major_minor(version: str) -> bool:
    """Return True when *version* is exactly a numeric major.minor pair (e.g. "1.0")."""
    segments = version.strip('.').split('.')
    if len(segments) != 2:
        return False
    return all(segment.isdigit() for segment in segments)
|
||||
|
||||
|
||||
class BaseConfigCacheHandler(CacheHandler[Dict[str, Any]]):
    """Base handler for configuration caching"""

    def __init__(self, region, config_type: str):
        """
        Args:
            region: Cache region
            config_type: Type of configuration (agents, tasks, etc.)
        """
        super().__init__(region, f'config_{config_type}')
        self.config_type = config_type
        self._types_module = None  # Set by subclasses
        self._config_dir = None  # Set by subclasses
        self.version_tree_cache = None  # Injected later via set_version_tree_cache
        self.configure_keys('type_name', 'version')

    def _to_cache_data(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the data to a cacheable format; config dicts are already serializable."""
        return instance

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Convert cached data back to usable format; the dict is returned unchanged."""
        return data

    def _should_cache(self, value: Dict[str, Any]) -> bool:
        """
        Validate if the value should be cached

        Args:
            value: The value to be cached

        Returns:
            bool: True if the value should be cached
        """
        return isinstance(value, dict)  # Cache all dictionaries

    def set_version_tree_cache(self, cache):
        """Set the version tree cache dependency."""
        self.version_tree_cache = cache

    def _load_specific_config(self, type_name: str, version_str: str = 'latest') -> Dict[str, Any]:
        """
        Load a specific configuration version.
        Automatically handles global vs partner-specific configs.

        Raises:
            ValueError: If the version is unknown or the file cannot be loaded.
        """
        version_tree = self.version_tree_cache.get_versions(type_name)
        versions = version_tree['versions']

        if version_str == 'latest':
            version_str = version_tree['latest_version']

        if version_str not in versions:
            raise ValueError(f"Version {version_str} not found for {type_name}")

        version_info = versions[version_str]
        file_path = version_info['file_path']
        partner = version_info.get('partner')

        try:
            with open(file_path) as f:
                config = yaml.safe_load(f)
            # Add partner information to the config
            if partner:
                config['partner'] = partner
            return config
        except Exception as e:
            # FIX: chain the original exception so the root cause stays visible.
            raise ValueError(f"Error loading config from {file_path}: {e}") from e

    def get_config(self, type_name: str, version: Optional[str] = None) -> Dict[str, Any]:
        """
        Get configuration for a specific type and version.
        If version not specified, returns latest.

        Args:
            type_name: Configuration type name
            version: Optional specific version to retrieve; a "major.minor"
                value resolves to the latest matching patch version.

        Returns:
            Configuration data
        """
        if version is None:
            version_str = self.version_tree_cache.get_latest_version(type_name)
        elif is_major_minor(version):
            version_str = self.version_tree_cache.get_latest_patch_version(type_name, version)
        else:
            version_str = version

        # The lambda adapts get()'s keyword names to _load_specific_config's signature.
        return self.get(
            lambda type_name, version: self._load_specific_config(type_name, version),
            type_name=type_name,
            version=version_str
        )
|
||||
|
||||
|
||||
class BaseConfigVersionTreeCacheHandler(CacheHandler[Dict[str, Any]]):
    """Base handler for configuration version tree caching"""

    def __init__(self, region, config_type: str):
        """
        Args:
            region: Cache region
            config_type: Type of configuration (agents, tasks, etc.)
        """
        super().__init__(region, f'config_{config_type}_version_tree')
        self.config_type = config_type
        self._types_module = None  # Set by subclasses
        self._config_dir = None  # Set by subclasses
        self.configure_keys('type_name')

    def _load_version_tree(self, type_name: str) -> Dict[str, Any]:
        """
        Load version tree for a specific type without loading full configurations.
        Checks both global and partner-specific directories.

        Raises:
            ValueError: If the type is not found in any directory.
        """
        # First check the global path
        global_path = Path(self._config_dir) / "globals" / type_name

        # If global path doesn't exist, check if the type exists directly in the root
        # (for backward compatibility)
        if not global_path.exists():
            global_path = Path(self._config_dir) / type_name

        if not global_path.exists():
            # Check if it exists in any partner subdirectories
            partner_dirs = [d for d in Path(self._config_dir).iterdir()
                            if d.is_dir() and d.name != "globals"]

            for partner_dir in partner_dirs:
                partner_type_path = partner_dir / type_name
                if partner_type_path.exists():
                    # Found in partner directory
                    return self._load_versions_from_path(partner_type_path)

            # If we get here, the type wasn't found anywhere
            raise ValueError(f"No configuration found for type {type_name}")

        return self._load_versions_from_path(global_path)

    def _load_versions_from_path(self, path: Path) -> Dict[str, Any]:
        """Load all versions from a specific path"""
        version_files = list(path.glob('*.yaml'))
        if not version_files:
            raise ValueError(f"No versions found in {path}")

        versions = {}
        latest_version = None
        latest_version_obj = None

        for file_path in version_files:
            ver = file_path.stem  # Get version from filename
            try:
                ver_obj = version.parse(ver)
                # Only load minimal metadata for version tree
                with open(file_path) as f:
                    yaml_data = yaml.safe_load(f)
                    metadata = yaml_data.get('metadata', {})
                # Add partner information if available
                partner = None
                if "globals" not in str(file_path):
                    # Extract partner name from path
                    # Path format: config_dir/partner_name/type_name/version.yaml
                    partner = file_path.parent.parent.name

                versions[ver] = {
                    'metadata': metadata,
                    'file_path': str(file_path),
                    'partner': partner
                }

                # Track latest version
                if latest_version_obj is None or ver_obj > latest_version_obj:
                    latest_version = ver
                    latest_version_obj = ver_obj

            except Exception as e:
                # Skip unparsable/broken files but keep processing the rest.
                current_app.logger.error(f"Error loading version {ver}: {e}")
                continue

        return {
            'versions': versions,
            'latest_version': latest_version
        }

    def _to_cache_data(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the data to a cacheable format; the tree dict is already serializable."""
        return instance

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Convert cached data back to usable format; the dict is returned unchanged."""
        return data

    def _should_cache(self, value: Dict[str, Any]) -> bool:
        """
        Validate if the value should be cached

        Args:
            value: The value to be cached

        Returns:
            bool: True if the value should be cached
        """
        return isinstance(value, dict)  # Cache all dictionaries

    def get_versions(self, type_name: str) -> Dict[str, Any]:
        """
        Get version tree for a type

        Args:
            type_name: Type to get versions for

        Returns:
            Dict with version information
        """
        return self.get(
            lambda type_name: self._load_version_tree(type_name),
            type_name=type_name
        )

    def get_latest_version(self, type_name: str) -> str:
        """
        Get the latest version for a given type name.

        Args:
            type_name: Name of the configuration type

        Returns:
            Latest version string

        Raises:
            ValueError: If type not found or no versions available
        """
        version_tree = self.get_versions(type_name)
        # FIX: also reject a None latest_version (e.g. every file failed to parse)
        # instead of silently returning None where callers expect a version string.
        if not version_tree or not version_tree.get('latest_version'):
            raise ValueError(f"No versions found for {type_name}")

        return version_tree['latest_version']

    def get_latest_patch_version(self, type_name: str, major_minor: str) -> str:
        """
        Get the latest patch version for a given major.minor version.

        Args:
            type_name: Name of the configuration type
            major_minor: Major.minor version (e.g. "1.0")

        Returns:
            Latest patch version string (e.g. "1.0.3")

        Raises:
            ValueError: If type not found or no matching versions
        """
        version_tree = self.get_versions(type_name)
        if not version_tree or 'versions' not in version_tree:
            raise ValueError(f"No versions found for {type_name}")

        # Filter versions that match the major.minor prefix
        matching_versions = [
            ver for ver in version_tree['versions'].keys()
            if ver.startswith(major_minor + '.')
        ]

        if not matching_versions:
            raise ValueError(f"No versions found for {type_name} with prefix {major_minor}")

        # Return highest matching version (semantic ordering via packaging.version)
        return max(matching_versions, key=version.parse)
|
||||
|
||||
|
||||
class BaseConfigTypesCacheHandler(CacheHandler[Dict[str, Any]]):
    """Base handler for configuration types caching"""

    def __init__(self, region, config_type: str):
        """
        Args:
            region: Cache region
            config_type: Type of configuration (agents, tasks, etc.)
        """
        super().__init__(region, f'config_{config_type}_types')
        self.config_type = config_type
        self._types_module = None  # Set by subclasses
        self._config_dir = None  # Set by subclasses
        self.configure_keys()  # No key components: a single entry per handler

    def _to_cache_data(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the data to a cacheable format; type dicts are already serializable."""
        return instance

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Convert cached data back to usable format; the dict is returned unchanged."""
        return data

    def _should_cache(self, value: Dict[str, Any]) -> bool:
        """
        Validate if the value should be cached

        Args:
            value: The value to be cached

        Returns:
            bool: True if the value should be cached
        """
        return isinstance(value, dict)  # Cache all dictionaries

    def _load_type_definitions(self) -> Dict[str, Dict[str, Any]]:
        """Load type definitions from the corresponding type_defs module"""
        if not self._types_module:
            raise ValueError("_types_module must be set by subclass")

        # Shallow-copy every field of each type definition so cached data is
        # decoupled from the module-level dicts. (Idiomatic dict() replaces the
        # previous manual key-by-key copy loop.)
        return {type_id: dict(info) for type_id, info in self._types_module.items()}

    def get_types(self) -> Dict[str, Dict[str, Any]]:
        """Get dictionary of available types with all defined properties"""
        # type_name only feeds the cache key; the loader ignores it.
        return self.get(
            lambda type_name: self._load_type_definitions(),
            type_name=f'{self.config_type}_types',
        )
|
||||
|
||||
|
||||
def create_config_cache_handlers(config_type: str, config_dir: str, types_module: dict) -> tuple:
    """
    Factory that builds the three cache handler classes for one configuration type.
    The following cache names are created:
        - <config_type>_config_cache
        - <config_type>_version_tree_cache
        - <config_type>_types_cache

    Args:
        config_type: The configuration type (e.g., 'agents', 'tasks').
        config_dir: The directory where configuration files are stored.
        types_module: The types module defining the available types for this config.

    Returns:
        A tuple of dynamically created classes for config, version tree, and types handlers.
    """

    def _configure(handler):
        """Attach the factory's type module and config directory to a handler instance."""
        handler._types_module = types_module
        handler._config_dir = config_dir

    class ConfigCacheHandler(BaseConfigCacheHandler):
        handler_name = f"{config_type}_config_cache"

        def __init__(self, region):
            super().__init__(region, config_type)
            _configure(self)

    class VersionTreeCacheHandler(BaseConfigVersionTreeCacheHandler):
        handler_name = f"{config_type}_version_tree_cache"

        def __init__(self, region):
            super().__init__(region, config_type)
            _configure(self)

    class TypesCacheHandler(BaseConfigTypesCacheHandler):
        handler_name = f"{config_type}_types_cache"

        def __init__(self, region):
            super().__init__(region, config_type)
            _configure(self)

    return ConfigCacheHandler, VersionTreeCacheHandler, TypesCacheHandler
|
||||
|
||||
|
||||
# Concrete handler triplets (config cache, version-tree cache, types cache),
# one per configuration type, produced by the factory above.
AgentConfigCacheHandler, AgentConfigVersionTreeCacheHandler, AgentConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='agents',
        config_dir='config/agents',
        types_module=agent_types.AGENT_TYPES
    ))


TaskConfigCacheHandler, TaskConfigVersionTreeCacheHandler, TaskConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='tasks',
        config_dir='config/tasks',
        types_module=task_types.TASK_TYPES
    ))


ToolConfigCacheHandler, ToolConfigVersionTreeCacheHandler, ToolConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='tools',
        config_dir='config/tools',
        types_module=tool_types.TOOL_TYPES
    ))


SpecialistConfigCacheHandler, SpecialistConfigVersionTreeCacheHandler, SpecialistConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='specialists',
        config_dir='config/specialists',
        types_module=specialist_types.SPECIALIST_TYPES
    ))


RetrieverConfigCacheHandler, RetrieverConfigVersionTreeCacheHandler, RetrieverConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='retrievers',
        config_dir='config/retrievers',
        types_module=retriever_types.RETRIEVER_TYPES
    ))


PromptConfigCacheHandler, PromptConfigVersionTreeCacheHandler, PromptConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='prompts',
        config_dir='config/prompts',
        types_module=prompt_types.PROMPT_TYPES
    ))

CatalogConfigCacheHandler, CatalogConfigVersionTreeCacheHandler, CatalogConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='catalogs',
        config_dir='config/catalogs',
        types_module=catalog_types.CATALOG_TYPES
    ))

ProcessorConfigCacheHandler, ProcessorConfigVersionTreeCacheHandler, ProcessorConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='processors',
        config_dir='config/processors',
        types_module=processor_types.PROCESSOR_TYPES
    ))

PartnerServiceConfigCacheHandler, PartnerServiceConfigVersionTreeCacheHandler, PartnerServiceConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='partner_services',
        config_dir='config/partner_services',
        types_module=partner_service_types.PARTNER_SERVICE_TYPES
    ))

CustomisationConfigCacheHandler, CustomisationConfigVersionTreeCacheHandler, CustomisationConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='customisations',
        config_dir='config/customisations',
        types_module=customisation_types.CUSTOMISATION_TYPES
    )
)

SpecialistFormConfigCacheHandler, SpecialistFormConfigVersionTreeCacheHandler, SpecialistFormConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='specialist_forms',
        config_dir='config/specialist_forms',
        types_module=specialist_form_types.SPECIALIST_FORM_TYPES
    )
)


CapsuleConfigCacheHandler, CapsuleConfigVersionTreeCacheHandler, CapsuleConfigTypesCacheHandler = (
    create_config_cache_handlers(
        config_type='data_capsules',
        config_dir='config/data_capsules',
        types_module=capsule_types.CAPSULE_TYPES
    )
)
|
||||
|
||||
|
||||
def register_config_cache_handlers(cache_manager) -> None:
    """Register every configuration cache handler on the 'eveai_config'
    region and wire each config cache to its version-tree cache.

    Args:
        cache_manager: EveAICacheManager exposing ``register_handler`` and
            the dynamically registered handler instances
            (e.g. ``agents_config_cache``).
    """
    # NOTE(review): the original registered the Agent handlers twice
    # (copy-paste duplication); registration is keyed by handler class/name,
    # so the duplicates were redundant and have been dropped.
    # NOTE(review): the Capsule* handlers are created at module level but
    # were never registered here — confirm whether that is intentional.
    handler_classes = (
        AgentConfigCacheHandler, AgentConfigTypesCacheHandler,
        AgentConfigVersionTreeCacheHandler,
        TaskConfigCacheHandler, TaskConfigTypesCacheHandler,
        TaskConfigVersionTreeCacheHandler,
        ToolConfigCacheHandler, ToolConfigTypesCacheHandler,
        ToolConfigVersionTreeCacheHandler,
        SpecialistConfigCacheHandler, SpecialistConfigTypesCacheHandler,
        SpecialistConfigVersionTreeCacheHandler,
        RetrieverConfigCacheHandler, RetrieverConfigTypesCacheHandler,
        RetrieverConfigVersionTreeCacheHandler,
        PromptConfigCacheHandler, PromptConfigVersionTreeCacheHandler,
        PromptConfigTypesCacheHandler,
        CatalogConfigCacheHandler, CatalogConfigTypesCacheHandler,
        CatalogConfigVersionTreeCacheHandler,
        ProcessorConfigCacheHandler, ProcessorConfigTypesCacheHandler,
        ProcessorConfigVersionTreeCacheHandler,
        PartnerServiceConfigCacheHandler, PartnerServiceConfigTypesCacheHandler,
        PartnerServiceConfigVersionTreeCacheHandler,
        CustomisationConfigCacheHandler, CustomisationConfigTypesCacheHandler,
        CustomisationConfigVersionTreeCacheHandler,
        SpecialistFormConfigCacheHandler, SpecialistFormConfigTypesCacheHandler,
        SpecialistFormConfigVersionTreeCacheHandler,
    )
    for handler_class in handler_classes:
        cache_manager.register_handler(handler_class, 'eveai_config')

    # Link each config cache to its version-tree cache so config lookups can
    # resolve version references.
    for family in ('agents', 'tasks', 'tools', 'specialists', 'retrievers',
                   'prompts', 'catalogs', 'processors', 'partner_services',
                   'customisations', 'specialist_forms'):
        config_cache = getattr(cache_manager, f'{family}_config_cache')
        version_tree_cache = getattr(cache_manager, f'{family}_version_tree_cache')
        config_cache.set_version_tree_cache(version_tree_cache)
|
||||
218
common/utils/cache/crewai_config_processor.py
vendored
Normal file
218
common/utils/cache/crewai_config_processor.py
vendored
Normal file
@@ -0,0 +1,218 @@
|
||||
from typing import Dict, Any, Type, TypeVar, List
|
||||
from abc import ABC, abstractmethod
|
||||
from flask import current_app
|
||||
|
||||
from common.extensions import cache_manager, db
|
||||
from common.models.interaction import EveAIAgent, EveAITask, EveAITool, Specialist
|
||||
from common.utils.cache.crewai_configuration import (
|
||||
ProcessedAgentConfig, ProcessedTaskConfig, ProcessedToolConfig,
|
||||
SpecialistProcessedConfig
|
||||
)
|
||||
|
||||
T = TypeVar('T') # For generic model types
|
||||
|
||||
|
||||
class BaseCrewAIConfigProcessor:
    """Base processor for specialist configurations.

    Merges the YAML-backed configuration (served by the config caches) with
    tenant/specialist-specific overrides stored in the database, producing
    ready-to-use Processed*Config objects for a specialist's agents, tasks
    and tools.
    """

    # Standard mapping between model fields and template placeholders.
    # DB field name -> placeholder token expected inside the YAML text.
    AGENT_FIELD_MAPPING = {
        'role': 'custom_role',
        'goal': 'custom_goal',
        'backstory': 'custom_backstory'
    }

    TASK_FIELD_MAPPING = {
        'task_description': 'custom_description',
        'expected_output': 'custom_expected_output'
    }

    def __init__(self, tenant_id: int, specialist_id: int):
        """Load the specialist and derive processing settings.

        Args:
            tenant_id: Tenant owning the specialist.
            specialist_id: Primary key of the Specialist to process.

        Raises:
            ValueError: If the specialist does not exist.
        """
        self.tenant_id = tenant_id
        self.specialist_id = specialist_id
        self.specialist = self._get_specialist()
        self.verbose = self._get_verbose_setting()

    def _get_specialist(self) -> Specialist:
        """Get specialist and verify existence."""
        specialist = Specialist.query.get(self.specialist_id)
        if not specialist:
            raise ValueError(f"Specialist {self.specialist_id} not found")
        return specialist

    def _get_verbose_setting(self) -> bool:
        """Derive the verbose flag from the specialist's tuning setting."""
        return bool(self.specialist.tuning)

    def _get_db_items(self, model_class: Type[T], type_list: List[str]) -> Dict[str, T]:
        """Fetch this specialist's DB override rows, keyed by their type."""
        items = (model_class.query
                 .filter_by(specialist_id=self.specialist_id)
                 .filter(model_class.type.in_(type_list))
                 .all())
        return {item.type: item for item in items}

    def _apply_replacements(self, text: str, replacements: Dict[str, str]) -> str:
        """Replace ``{key}`` placeholders in *text* with the given values."""
        result = text
        for key, value in replacements.items():
            if value is not None:  # Only replace if value exists
                placeholder = "{" + key + "}"
                result = result.replace(placeholder, str(value))
        return result

    def _process_agent_configs(self, specialist_config: Dict[str, Any]) -> Dict[str, ProcessedAgentConfig]:
        """Process all agent configurations.

        Starts from the cached YAML values and substitutes any DB-stored
        custom role/goal/backstory into their placeholders.
        """
        agent_configs = {}

        if 'agents' not in specialist_config:
            return agent_configs

        # Get all DB agents at once to avoid per-agent queries.
        agent_types = [agent_def['type'] for agent_def in specialist_config['agents']]
        db_agents = self._get_db_items(EveAIAgent, agent_types)

        for agent_def in specialist_config['agents']:
            agent_type = agent_def['type']
            agent_type_lower = agent_type.lower()
            db_agent = db_agents.get(agent_type)

            # Get full configuration for the requested (or default) version.
            config = cache_manager.agents_config_cache.get_config(
                agent_type,
                agent_def.get('version', '1.0')
            )

            # Start with YAML values
            role = config['role']
            goal = config['goal']
            backstory = config['backstory']

            # Apply DB values if they exist: each non-empty DB field replaces
            # its placeholder wherever it occurs in the three texts.
            if db_agent:
                for model_field, placeholder in self.AGENT_FIELD_MAPPING.items():
                    value = getattr(db_agent, model_field)
                    if value:
                        placeholder_text = "{" + placeholder + "}"
                        role = role.replace(placeholder_text, value)
                        goal = goal.replace(placeholder_text, value)
                        backstory = backstory.replace(placeholder_text, value)

            agent_configs[agent_type_lower] = ProcessedAgentConfig(
                role=role,
                goal=goal,
                backstory=backstory,
                name=agent_def.get('name') or config.get('name', agent_type_lower),
                type=agent_type,
                description=agent_def.get('description') or config.get('description'),
                verbose=self.verbose
            )

        return agent_configs

    def _process_task_configs(self, specialist_config: Dict[str, Any]) -> Dict[str, ProcessedTaskConfig]:
        """Process all task configurations (YAML base + DB placeholder overrides)."""
        task_configs = {}

        if 'tasks' not in specialist_config:
            return task_configs

        # Get all DB tasks at once to avoid per-task queries.
        task_types = [task_def['type'] for task_def in specialist_config['tasks']]
        db_tasks = self._get_db_items(EveAITask, task_types)

        for task_def in specialist_config['tasks']:
            task_type = task_def['type']
            task_type_lower = task_type.lower()
            db_task = db_tasks.get(task_type)

            # Get full configuration for the requested (or default) version.
            config = cache_manager.tasks_config_cache.get_config(
                task_type,
                task_def.get('version', '1.0')
            )

            # Start with YAML values
            task_description = config['task_description']
            expected_output = config['expected_output']

            # Apply DB values if they exist
            if db_task:
                for model_field, placeholder in self.TASK_FIELD_MAPPING.items():
                    value = getattr(db_task, model_field)
                    if value:
                        placeholder_text = "{" + placeholder + "}"
                        task_description = task_description.replace(placeholder_text, value)
                        expected_output = expected_output.replace(placeholder_text, value)

            task_configs[task_type_lower] = ProcessedTaskConfig(
                task_description=task_description,
                expected_output=expected_output,
                name=task_def.get('name') or config.get('name', task_type_lower),
                type=task_type,
                description=task_def.get('description') or config.get('description'),
                verbose=self.verbose
            )

        return task_configs

    def _process_tool_configs(self, specialist_config: Dict[str, Any]) -> Dict[str, ProcessedToolConfig]:
        """Process all tool configurations (YAML base merged with DB overrides)."""
        tool_configs = {}

        if 'tools' not in specialist_config:
            return tool_configs

        # Get all DB tools at once to avoid per-tool queries.
        tool_types = [tool_def['type'] for tool_def in specialist_config['tools']]
        db_tools = self._get_db_items(EveAITool, tool_types)

        for tool_def in specialist_config['tools']:
            tool_type = tool_def['type']
            tool_type_lower = tool_type.lower()
            db_tool = db_tools.get(tool_type)

            # Get full configuration for the requested (or default) version.
            config = cache_manager.tools_config_cache.get_config(
                tool_type,
                tool_def.get('version', '1.0')
            )

            # Combine configuration.  Copy first: the original updated the
            # dict held by the shared config cache in place, leaking
            # tenant-specific overrides into the cached YAML config.
            tool_config = dict(config.get('configuration') or {})
            if db_tool and db_tool.configuration:
                tool_config.update(db_tool.configuration)

            tool_configs[tool_type_lower] = ProcessedToolConfig(
                name=tool_def.get('name') or config.get('name', tool_type_lower),
                type=tool_type,
                description=tool_def.get('description') or config.get('description'),
                configuration=tool_config,
                verbose=self.verbose
            )

        return tool_configs

    def process_config(self) -> SpecialistProcessedConfig:
        """Process the complete specialist configuration.

        Returns:
            SpecialistProcessedConfig with processed agents, tasks and tools.

        Raises:
            ValueError: If no configuration exists for the specialist's type.
        """
        try:
            # Get full specialist configuration
            specialist_config = cache_manager.specialists_config_cache.get_config(
                self.specialist.type,
                self.specialist.type_version
            )

            if not specialist_config:
                raise ValueError(f"No configuration found for {self.specialist.type}")

            # Process all configurations
            processed_config = SpecialistProcessedConfig(
                agents=self._process_agent_configs(specialist_config),
                tasks=self._process_task_configs(specialist_config),
                tools=self._process_tool_configs(specialist_config)
            )
            return processed_config

        except Exception as e:
            current_app.logger.error(f"Error processing specialist configuration: {e}")
            raise
|
||||
126
common/utils/cache/crewai_configuration.py
vendored
Normal file
126
common/utils/cache/crewai_configuration.py
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
@dataclass
class ProcessedAgentConfig:
    """Processed and ready-to-use agent configuration."""
    role: str
    goal: str
    backstory: str
    name: str
    type: str
    description: Optional[str] = None
    verbose: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; keys match the dataclass fields."""
        # vars() yields exactly the instance's fields, in declaration order.
        return dict(vars(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ProcessedAgentConfig':
        """Rebuild an instance from the dict produced by ``to_dict``."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class ProcessedTaskConfig:
    """Processed and ready-to-use task configuration."""
    task_description: str
    expected_output: str
    name: str
    type: str
    description: Optional[str] = None
    verbose: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; keys match the dataclass fields."""
        # vars() yields exactly the instance's fields, in declaration order.
        return dict(vars(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ProcessedTaskConfig':
        """Rebuild an instance from the dict produced by ``to_dict``."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class ProcessedToolConfig:
    """Processed and ready-to-use tool configuration."""
    name: str
    type: str
    description: Optional[str] = None
    configuration: Optional[Dict[str, Any]] = None
    verbose: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; keys match the dataclass fields."""
        # vars() yields exactly the instance's fields, in declaration order.
        return dict(vars(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ProcessedToolConfig':
        """Rebuild an instance from the dict produced by ``to_dict``."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class SpecialistProcessedConfig:
    """Complete processed configuration for a specialist."""
    agents: Dict[str, ProcessedAgentConfig]
    tasks: Dict[str, ProcessedTaskConfig]
    tools: Dict[str, ProcessedToolConfig]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the whole configuration to nested plain dicts."""
        return {
            group_name: {key: cfg.to_dict() for key, cfg in group.items()}
            for group_name, group in (('agents', self.agents),
                                      ('tasks', self.tasks),
                                      ('tools', self.tools))
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SpecialistProcessedConfig':
        """Rebuild from the nested-dict form produced by ``to_dict``."""
        return cls(
            agents={key: ProcessedAgentConfig.from_dict(cfg)
                    for key, cfg in data['agents'].items()},
            tasks={key: ProcessedTaskConfig.from_dict(cfg)
                   for key, cfg in data['tasks'].items()},
            tools={key: ProcessedToolConfig.from_dict(cfg)
                   for key, cfg in data['tools'].items()},
        )
|
||||
75
common/utils/cache/crewai_processed_config_cache.py
vendored
Normal file
75
common/utils/cache/crewai_processed_config_cache.py
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Dict, Any, Type
|
||||
from flask import current_app
|
||||
|
||||
from common.utils.cache.base import CacheHandler
|
||||
from common.utils.cache.crewai_configuration import SpecialistProcessedConfig
|
||||
from common.utils.cache.crewai_config_processor import BaseCrewAIConfigProcessor
|
||||
|
||||
|
||||
class CrewAIProcessedConfigCacheHandler(CacheHandler[SpecialistProcessedConfig]):
    """Handles caching of processed specialist configurations."""
    handler_name = 'crewai_processed_config_cache'

    def __init__(self, region):
        super().__init__(region, 'crewai_processed_config')
        # Cache entries are scoped per tenant + specialist.
        self.configure_keys('tenant_id', 'specialist_id')

    def _to_cache_data(self, instance: SpecialistProcessedConfig) -> Dict[str, Any]:
        """Convert SpecialistProcessedConfig to cache data."""
        return instance.to_dict()

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> SpecialistProcessedConfig:
        """Create SpecialistProcessedConfig from cache data."""
        return SpecialistProcessedConfig.from_dict(data)

    def _should_cache(self, value: Dict[str, Any]) -> bool:
        """Validate cache data before it is persisted.

        Only configurations containing all structural keys and at least one
        agent or task are cached.
        """
        required_keys = {'agents', 'tasks', 'tools'}
        missing_keys = required_keys - value.keys()
        if missing_keys:
            # Log the keys that are actually absent; the original always
            # logged the full required set, hiding the real problem.
            current_app.logger.warning(
                f'CrewAI Processed Config Cache missing required keys: {missing_keys}'
            )
            return False
        return bool(value['agents'] or value['tasks'])

    def get_specialist_config(self, tenant_id: int, specialist_id: int) -> SpecialistProcessedConfig:
        """
        Get or create processed configuration for a specialist

        Args:
            tenant_id: Tenant ID
            specialist_id: Specialist ID

        Returns:
            Processed specialist configuration

        Raises:
            ValueError: If specialist not found or processor not configured
        """

        def creator_func(tenant_id: int, specialist_id: int) -> SpecialistProcessedConfig:
            # Cache miss: process the configuration from scratch.
            processor = BaseCrewAIConfigProcessor(tenant_id, specialist_id)
            return processor.process_config()

        return self.get(
            creator_func,
            tenant_id=tenant_id,
            specialist_id=specialist_id
        )

    def invalidate_tenant_specialist(self, tenant_id: int, specialist_id: int):
        """Invalidate cache for a specific tenant's specialist."""
        self.invalidate(
            tenant_id=tenant_id,
            specialist_id=specialist_id
        )
        current_app.logger.info(
            f"Invalidated cache for tenant {tenant_id} specialist {specialist_id}"
        )
|
||||
|
||||
|
||||
def register_specialist_cache_handlers(cache_manager) -> None:
    """Register specialist cache handlers with cache manager."""
    # Processed specialist configs live in the chat-workers region.
    target_region = 'eveai_chat_workers'
    cache_manager.register_handler(CrewAIProcessedConfigCacheHandler, target_region)
|
||||
51
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
51
common/utils/cache/eveai_cache_manager.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Type
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from common.utils.cache.base import CacheHandler
|
||||
from common.utils.cache.regions import create_cache_regions
|
||||
from common.utils.cache.config_cache import AgentConfigCacheHandler
|
||||
|
||||
|
||||
class EveAICacheManager:
    """Cache manager with registration capabilities.

    Regions are created in ``init_app``; handler classes are registered via
    ``register_handler`` and their instances are then reachable as
    attributes named after each handler's ``handler_name``.
    """

    def __init__(self):
        self._regions = {}            # region name -> dogpile region
        self._handlers = {}           # handler class -> region name
        self._handler_instances = {}  # handler_name -> handler instance

    def init_app(self, app: Flask):
        """Initialize cache regions from the Flask app config."""
        self._regions = create_cache_regions(app)

        # Also expose each region as an attribute, e.g. ``eveai_config_region``.
        for region_name, region in self._regions.items():
            setattr(self, f"{region_name}_region", region)

        app.logger.info(f'Cache regions initialized: {self._regions.keys()}')

    def register_handler(self, handler_class: Type[CacheHandler], region: str):
        """Register a cache handler class with its region.

        Raises:
            ValueError: If the handler defines no ``handler_name`` or the
                region is unknown.  (The original raised a bare KeyError for
                an unknown region, inconsistent with ``invalidate_region``.)
        """
        if not hasattr(handler_class, 'handler_name'):
            raise ValueError("Cache handler must define handler_name class attribute")
        if region not in self._regions:
            raise ValueError(f"Unknown cache region: {region}")
        self._handlers[handler_class] = region

        # Instantiate immediately so the handler is reachable via __getattr__.
        handler_instance = handler_class(self._regions[region])
        self._handler_instances[handler_class.handler_name] = handler_instance

    def invalidate_region(self, region_name: str):
        """Invalidate an entire cache region."""
        if region_name in self._regions:
            self._regions[region_name].invalidate()
        else:
            raise ValueError(f"Unknown cache region: {region_name}")

    def __getattr__(self, name):
        """Handle dynamic access to registered handlers by handler_name."""
        # Use object.__getattribute__ to avoid re-entering __getattr__
        # before _handler_instances exists.
        instances = object.__getattribute__(self, '_handler_instances')
        if name in instances:
            return instances[name]
        raise AttributeError(f"'EveAICacheManager' object has no attribute '{name}'")
|
||||
102
common/utils/cache/license_cache.py
vendored
Normal file
102
common/utils/cache/license_cache.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# common/utils/cache/license_cache.py
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from common.utils.cache.base import CacheHandler
|
||||
from common.models.entitlements import License
|
||||
|
||||
|
||||
class LicenseCacheHandler(CacheHandler[License]):
    """Handles caching of active licenses for tenants."""
    handler_name = 'license_cache'

    def __init__(self, region):
        super().__init__(region, 'active_license')
        # One cached license per tenant.
        self.configure_keys('tenant_id')

    def _to_cache_data(self, instance: License) -> Dict[str, Any]:
        """Convert License instance to cache data using SQLAlchemy inspection."""
        if not instance:
            return {}

        # Serialize every mapped column; datetimes become ISO-8601 strings.
        mapper = inspect(License)
        data = {}

        for column in mapper.columns:
            value = getattr(instance, column.name)
            if isinstance(value, dt):
                data[column.name] = value.isoformat()
            else:
                data[column.name] = value

        return data

    def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Optional[License]:
        """Create a detached License instance from cache data.

        Returns None for empty cache payloads.
        """
        if not data:
            return None

        cached_license = License()  # renamed: ``license`` shadows a builtin
        mapper = inspect(License)

        # Set all column attributes dynamically.
        for column in mapper.columns:
            if column.name in data:
                value = data[column.name]

                # NOTE(review): assumes '_date'-suffixed columns hold dates
                # (not datetimes) — confirm against the License model.
                if column.name.endswith('_date') and value:
                    if isinstance(value, str):
                        value = dt.fromisoformat(value).date()

                setattr(cached_license, column.name, value)

        return cached_license

    def _should_cache(self, value: License) -> bool:
        """Only cache persisted licenses (those with a primary key)."""
        return value is not None and value.id is not None

    def get_active_license(self, tenant_id: int) -> Optional[License]:
        """
        Get the currently active license for a tenant

        Args:
            tenant_id: ID of the tenant

        Returns:
            License instance if found, None otherwise
        """

        def creator_func(tenant_id: int) -> Optional[License]:
            from common.extensions import db
            current_date = dt.now(tz=tz.utc).date()

            # TODO --> Active License via active Period?

            # Pick the most recently started license that is already active.
            # (The original called Query.last(), which does not exist in
            # SQLAlchemy and raised AttributeError at runtime.)
            return (db.session.query(License)
                    .filter_by(tenant_id=tenant_id)
                    .filter(License.start_date <= current_date)
                    .order_by(License.start_date.desc())
                    .first())

        return self.get(creator_func, tenant_id=tenant_id)

    def invalidate_tenant_license(self, tenant_id: int):
        """Invalidate cached license for specific tenant."""
        self.invalidate(tenant_id=tenant_id)
|
||||
|
||||
|
||||
def register_license_cache_handlers(cache_manager) -> None:
    """Register license cache handlers with cache manager."""
    # The license cache reuses the existing 'eveai_model' region.
    target_region = 'eveai_model'
    cache_manager.register_handler(LicenseCacheHandler, target_region)
|
||||
90
common/utils/cache/regions.py
vendored
Normal file
90
common/utils/cache/regions.py
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
# common/utils/cache/regions.py
|
||||
import time
|
||||
|
||||
import redis
|
||||
from dogpile.cache import make_region
|
||||
import ssl
|
||||
|
||||
def get_redis_config(app):
    """
    Create a Redis configuration dict based on app config.

    Handles both authenticated and non-authenticated setups, and adds SSL
    options when a CA certificate is configured together with the 'rediss'
    scheme.

    Args:
        app: Flask application providing ``config`` and ``logger``.

    Returns:
        dict: Keyword arguments suitable for ``redis.ConnectionPool``.
    """
    app.logger.debug("Creating Redis config")

    config = {
        'host': app.config['REDIS_URL'],
        'port': app.config['REDIS_PORT'],
        'max_connections': 20,
        'retry_on_timeout': True,
        'socket_keepalive': True,
        'socket_keepalive_options': {},
    }

    # Add authentication if provided
    un = app.config.get('REDIS_USER')
    pw = app.config.get('REDIS_PASS')
    if un and pw:
        config.update({
            'username': un,
            'password': pw
        })

    # SSL support using centralised config
    cert_path = app.config.get('REDIS_CA_CERT_PATH')
    redis_scheme = app.config.get('REDIS_SCHEME')
    if cert_path and redis_scheme == 'rediss':
        config.update({
            'connection_class': redis.SSLConnection,
            'ssl_cert_reqs': ssl.CERT_REQUIRED,
            'ssl_check_hostname': app.config.get('REDIS_SSL_CHECK_HOSTNAME', True),
            'ssl_ca_certs': cert_path,
        })

    # Redact the password before logging: the original logged the full
    # config dict, leaking Redis credentials into the logs.
    loggable_config = {k: ('***' if k == 'password' else v) for k, v in config.items()}
    app.logger.debug(f"config for Redis connection: {loggable_config}")

    return config
|
||||
|
||||
|
||||
def create_cache_regions(app):
    """Initialise all cache regions with app config.

    All regions share one Redis connection pool; they are kept separate so
    each subsystem can be configured and invalidated independently.
    (The original duplicated the same configure() call four times and
    computed an unused ``startup_time``; both cleaned up here.)

    Returns:
        dict: region name -> configured dogpile.cache region.
    """
    redis_config = get_redis_config(app)
    redis_pool = redis.ConnectionPool(**redis_config)

    # One region per subsystem; identical backend settings for now.
    region_names = (
        'eveai_model',         # model-related caching (ModelVariables etc)
        'eveai_chat_workers',  # chat worker components (Specialists, Retrievers, ...)
        'eveai_workers',       # worker components (Processors, ...)
        'eveai_config',        # configuration caches
    )

    regions = {}
    for name in region_names:
        regions[name] = make_region(name=name).configure(
            'dogpile.cache.redis',
            arguments={'connection_pool': redis_pool},
            replace_existing_backend=True
        )

    return regions
|
||||
|
||||
223
common/utils/cache/translation_cache.py
vendored
Normal file
223
common/utils/cache/translation_cache.py
vendored
Normal file
@@ -0,0 +1,223 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime as dt, timezone as tz
|
||||
|
||||
import xxhash
|
||||
from flask import current_app
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from common.langchain.persistent_llm_metrics_handler import PersistentLLMMetricsHandler
|
||||
from common.utils.business_event_context import current_event
|
||||
from common.utils.cache.base import CacheHandler, T
|
||||
from common.extensions import db
|
||||
|
||||
from common.models.user import TranslationCache
|
||||
from flask_security import current_user
|
||||
|
||||
from common.utils.model_utils import get_template
|
||||
|
||||
|
||||
class TranslationCacheHandler(CacheHandler[TranslationCache]):
|
||||
"""Handles caching of translations with fallback to database and external translation service"""
|
||||
handler_name = 'translation_cache'
|
||||
|
||||
def __init__(self, region):
    """Bind the handler to its region under the 'translation' namespace."""
    super().__init__(region, 'translation')
    # Cache entries are keyed solely by the precomputed content hash.
    self.configure_keys('hash_key')
|
||||
|
||||
def _to_cache_data(self, instance: TranslationCache) -> Dict[str, Any]:
    """Convert TranslationCache instance to cache data using SQLAlchemy inspection"""
    if not instance:
        return {}

    # Serialize every mapped column of the model dynamically, so the cache
    # payload stays in sync with the schema without manual field lists.
    mapper = inspect(TranslationCache)
    data = {}

    for column in mapper.columns:
        value = getattr(instance, column.name)

        # Handle date serialization: datetimes become ISO-8601 strings so the
        # payload is JSON-friendly; everything else is stored as-is.
        if isinstance(value, dt):
            data[column.name] = value.isoformat()
        else:
            data[column.name] = value

    return data
|
||||
|
||||
def _from_cache_data(self, data: Dict[str, Any], **kwargs) -> Optional[TranslationCache]:
    """Rebuild a detached TranslationCache instance from cached column data.

    Also records token metrics tagged 'TRANSLATION-CACHE' on every cache
    hit, so cached translations still appear in usage accounting.
    Returns None for an empty payload.
    """
    if not data:
        return None

    # Create a new TranslationCache instance
    translation = TranslationCache()
    mapper = inspect(TranslationCache)

    # Set all attributes dynamically
    for column in mapper.columns:
        if column.name in data:
            value = data[column.name]

            # Handle date deserialization
            # NOTE(review): assumes '_date'-suffixed columns hold dates (not
            # datetimes) — confirm against the TranslationCache model.
            if column.name.endswith('_date') and value:
                if isinstance(value, str):
                    value = dt.fromisoformat(value).date()

            setattr(translation, column.name, value)

    # Report the original translation's token cost with zero elapsed time.
    # NOTE(review): assumes prompt_tokens/completion_tokens are non-None for
    # cached rows — a None here raises TypeError; confirm.
    metrics = {
        'total_tokens': translation.prompt_tokens + translation.completion_tokens,
        'prompt_tokens': translation.prompt_tokens,
        'completion_tokens': translation.completion_tokens,
        'time_elapsed': 0,
        'interaction_type': 'TRANSLATION-CACHE'
    }
    current_event.log_llm_metrics(metrics)

    return translation
|
||||
|
||||
def _should_cache(self, value) -> bool:
    """Validate if the translation should be cached.

    Accepts either a TranslationCache instance or its serialized dict form;
    an entry is cacheable only when it carries a cache_key.
    """
    if value is None:
        return False

    # Handle both TranslationCache objects and serialized data (dict)
    if isinstance(value, TranslationCache):
        return value.cache_key is not None
    elif isinstance(value, dict):
        return value.get('cache_key') is not None

    # Any other payload type is never cached.
    return False
|
||||
|
||||
def get_translation(self, text: str, target_lang: str, source_lang: str = None, context: str = None) -> Optional[TranslationCache]:
    """
    Get the translation for a text in a specific language.

    Lookup order: cache region (via self.get) -> database row keyed by the
    request hash -> fresh LLM translation, which is then persisted.

    Args:
        text: The text to be translated
        target_lang: The target language for the translation
        source_lang: The source language of the text to be translated
        context: Optional context for the translation

    Returns:
        TranslationCache instance if found, None otherwise
    """
    # A fixed placeholder keeps the cache key deterministic when no context
    # is supplied.
    if not context:
        context = 'No context provided.'

    def creator_func(hash_key: str) -> Optional[TranslationCache]:
        # Invoked only on a cache miss. Check if the translation already
        # exists in the database.
        existing_translation = db.session.query(TranslationCache).filter_by(cache_key=hash_key).first()

        if existing_translation:
            # Update last used timestamp
            existing_translation.last_used_at = dt.now(tz=tz.utc)
            # Re-log the original token counts so usage accounting reflects
            # this hit without a new LLM call (time_elapsed 0 marks it).
            metrics = {
                'total_tokens': existing_translation.prompt_tokens + existing_translation.completion_tokens,
                'prompt_tokens': existing_translation.prompt_tokens,
                'completion_tokens': existing_translation.completion_tokens,
                'time_elapsed': 0,
                'interaction_type': 'TRANSLATION-DB'
            }
            current_event.log_llm_metrics(metrics)
            db.session.commit()
            return existing_translation

        # Translation not found in DB, need to create it.
        # Get the translation and metrics from the LLM.
        translated_text, metrics = self.translate_text(
            text_to_translate=text,
            target_lang=target_lang,
            source_lang=source_lang,
            context=context
        )

        # Create new translation cache record.
        # NOTE(review): the `'current_user' in globals()` guard checks this
        # module's globals, which likely never contain current_user when it
        # is imported elsewhere — confirm created_by/updated_by are ever set.
        new_translation = TranslationCache(
            cache_key=hash_key,
            source_text=text,
            translated_text=translated_text,
            source_language=source_lang,
            target_language=target_lang,
            context=context,
            prompt_tokens=metrics.get('prompt_tokens', 0),
            completion_tokens=metrics.get('completion_tokens', 0),
            created_at=dt.now(tz=tz.utc),
            created_by=getattr(current_user, 'id', None) if 'current_user' in globals() else None,
            updated_at=dt.now(tz=tz.utc),
            updated_by=getattr(current_user, 'id', None) if 'current_user' in globals() else None,
            last_used_at=dt.now(tz=tz.utc)
        )

        # Save to database
        db.session.add(new_translation)
        db.session.commit()

        return new_translation

    # Generate the deterministic request hash used as cache and DB key.
    hash_key = self._generate_cache_key(text, target_lang, source_lang, context)

    # self.get consults the cache region first and calls creator_func only
    # on a miss.
    return self.get(creator_func, hash_key=hash_key)
|
||||
|
||||
def invalidate_tenant_translations(self, tenant_id: int):
    """Drop every cached translation belonging to the given tenant."""
    # Delegate to the generic invalidation hook with a tenant scope.
    self.invalidate(tenant_id=tenant_id)
|
||||
|
||||
def _generate_cache_key(self, text: str, target_lang: str, source_lang: str = None, context: str = None) -> str:
    """Build a deterministic xxh64 hex key from the normalized request.

    Text and context are stripped, language codes lowercased, and the dict
    is JSON-encoded with sorted keys so equivalent requests hash equally.
    """
    normalized = {
        "text": text.strip(),
        "target_lang": target_lang.lower(),
        "source_lang": source_lang.lower() if source_lang else None,
        "context": context.strip() if context else None,
    }

    payload = json.dumps(normalized, sort_keys=True, ensure_ascii=False).encode('utf-8')
    return xxhash.xxh64(payload).hexdigest()
|
||||
|
||||
def translate_text(self, text_to_translate: str, target_lang: str, source_lang: str = None, context: str = None) \
        -> tuple[str, dict[str, int | float]]:
    """Translate text through the configured LLM template chain.

    Args:
        text_to_translate: Source text to translate.
        target_lang: ISO 639-1 code, resolved to a language name via app config.
        source_lang: Accepted for interface symmetry but not used in this
            function — the prompt templates apparently infer the source
            language. TODO confirm this is intentional.
        context: Optional domain context; selects the with-context template.

    Returns:
        (translated_text, metrics) where metrics holds the token counts
        captured by the metrics handler.
    """
    # Raises KeyError for unsupported language codes.
    target_language = current_app.config['SUPPORTED_LANGUAGE_ISO639_1_LOOKUP'][target_lang]
    prompt_params = {
        "text_to_translate": text_to_translate,
        "target_language": target_language,
    }
    if context:
        template, llm = get_template("translation_with_context")
        prompt_params["context"] = context
    else:
        template, llm = get_template("translation_without_context")

    # Add a metrics handler to capture usage.
    # NOTE(review): this rebinds the llm object's callback list on every
    # call — confirm get_template returns a fresh llm instance, otherwise
    # handlers accumulate across calls.
    metrics_handler = PersistentLLMMetricsHandler()
    existing_callbacks = llm.callbacks
    llm.callbacks = existing_callbacks + [metrics_handler]

    translation_prompt = ChatPromptTemplate.from_template(template)

    setup = RunnablePassthrough()

    # prompt -> model -> plain-string output
    chain = (setup | translation_prompt | llm | StrOutputParser())

    translation = chain.invoke(prompt_params)

    # Remove double square brackets from translation (template delimiter
    # artifacts), keeping the enclosed text.
    translation = re.sub(r'\[\[(.*?)\]\]', r'\1', translation)

    metrics = metrics_handler.get_metrics()

    return translation, metrics
|
||||
|
||||
def register_translation_cache_handlers(cache_manager) -> None:
    """Attach the translation cache handler to the shared cache manager."""
    # Reuse the existing eveai_model cache region instead of creating a new one.
    cache_manager.register_handler(TranslationCacheHandler, 'eveai_model')
|
||||
@@ -1,35 +1,97 @@
|
||||
import ssl
|
||||
|
||||
from celery import Celery
|
||||
from kombu import Queue
|
||||
from werkzeug.local import LocalProxy
|
||||
from redbeat import RedBeatScheduler
|
||||
|
||||
celery_app = Celery()
|
||||
|
||||
|
||||
def init_celery(celery, app):
|
||||
def init_celery(celery, app, is_beat=False):
|
||||
celery_app.main = app.name
|
||||
app.logger.debug(f'CELERY_BROKER_URL: {app.config["CELERY_BROKER_URL"]}')
|
||||
app.logger.debug(f'CELERY_RESULT_BACKEND: {app.config["CELERY_RESULT_BACKEND"]}')
|
||||
|
||||
celery_config = {
|
||||
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
|
||||
'broker_url': app.config.get('CELERY_BROKER_URL', 'redis://localhost:6379/0'),
|
||||
'result_backend': app.config.get('CELERY_RESULT_BACKEND', 'redis://localhost:6379/0'),
|
||||
'task_serializer': app.config.get('CELERY_TASK_SERIALIZER', 'json'),
|
||||
'result_serializer': app.config.get('CELERY_RESULT_SERIALIZER', 'json'),
|
||||
'accept_content': app.config.get('CELERY_ACCEPT_CONTENT', ['json']),
|
||||
'timezone': app.config.get('CELERY_TIMEZONE', 'UTC'),
|
||||
'enable_utc': app.config.get('CELERY_ENABLE_UTC', True),
|
||||
'task_routes': {'eveai_worker.tasks.create_embeddings': {'queue': 'embeddings',
|
||||
'routing_key': 'embeddings.create_embeddings'}},
|
||||
# connection pools
|
||||
# 'broker_pool_limit': app.config.get('CELERY_BROKER_POOL_LIMIT', 10),
|
||||
}
|
||||
|
||||
# Transport options (timeouts, max_connections for Redis transport)
|
||||
# broker_transport_options = {
|
||||
# 'master_name': None, # only relevant for Sentinel; otherwise harmless
|
||||
# 'max_connections': 20,
|
||||
# 'retry_on_timeout': True,
|
||||
# 'socket_connect_timeout': 5,
|
||||
# 'socket_timeout': 5,
|
||||
# }
|
||||
# celery_config['broker_transport_options'] = broker_transport_options
|
||||
#
|
||||
# # Backend transport options (Redis backend accepts similar timeouts)
|
||||
# result_backend_transport_options = {
|
||||
# 'retry_on_timeout': True,
|
||||
# 'socket_connect_timeout': 5,
|
||||
# 'socket_timeout': 5,
|
||||
# # max_connections may be supported on newer Celery/redis backends; harmless if ignored
|
||||
# 'max_connections': 20,
|
||||
# }
|
||||
# celery_config['result_backend_transport_options'] = result_backend_transport_options
|
||||
|
||||
# TLS (only when cert is provided or your URLs are rediss://)
|
||||
ssl_opts = None
|
||||
cert_path = app.config.get('REDIS_CA_CERT_PATH')
|
||||
if cert_path:
|
||||
ssl_opts = {
|
||||
'ssl_cert_reqs': ssl.CERT_REQUIRED,
|
||||
'ssl_ca_certs': cert_path,
|
||||
'ssl_check_hostname': app.config.get('REDIS_SSL_CHECK_HOSTNAME', True),
|
||||
}
|
||||
app.logger.info(
|
||||
"SSL configured for Celery Redis connection (CA: %s, hostname-check: %s)",
|
||||
cert_path,
|
||||
'enabled' if app.config.get('REDIS_SSL_CHECK_HOSTNAME', True) else 'disabled (IP)'
|
||||
)
|
||||
celery_config['broker_use_ssl'] = ssl_opts
|
||||
celery_config['redis_backend_use_ssl'] = ssl_opts
|
||||
|
||||
# Beat/RedBeat
|
||||
if is_beat:
|
||||
celery_config['beat_scheduler'] = 'redbeat.RedBeatScheduler'
|
||||
celery_config['redbeat_lock_key'] = 'redbeat::lock'
|
||||
celery_config['beat_max_loop_interval'] = 10
|
||||
|
||||
celery_app.conf.update(**celery_config)
|
||||
|
||||
# Setting up Celery task queues
|
||||
celery_app.conf.task_queues = (
|
||||
Queue('default', routing_key='task.#'),
|
||||
Queue('embeddings', routing_key='embeddings.#', queue_arguments={'x-max-priority': 10}),
|
||||
Queue('llm_interactions', routing_key='llm_interactions.#', queue_arguments={'x-max-priority': 5}),
|
||||
)
|
||||
# Queues for workers (note: Redis ignores routing_key and priority features like RabbitMQ)
|
||||
if not is_beat:
|
||||
celery_app.conf.task_queues = (
|
||||
Queue('default', routing_key='task.#'),
|
||||
Queue('embeddings', routing_key='embeddings.#', queue_arguments={'x-max-priority': 10}),
|
||||
Queue('llm_interactions', routing_key='llm_interactions.#', queue_arguments={'x-max-priority': 5}),
|
||||
Queue('entitlements', routing_key='entitlements.#', queue_arguments={'x-max-priority': 10}),
|
||||
)
|
||||
celery_app.conf.task_routes = {
|
||||
'eveai_workers.*': { # All tasks from eveai_workers module
|
||||
'queue': 'embeddings',
|
||||
'routing_key': 'embeddings.#',
|
||||
},
|
||||
'eveai_chat_workers.*': { # All tasks from eveai_chat_workers module
|
||||
'queue': 'llm_interactions',
|
||||
'routing_key': 'llm_interactions.#',
|
||||
},
|
||||
'eveai_entitlements.*': { # All tasks from eveai_entitlements module
|
||||
'queue': 'entitlements',
|
||||
'routing_key': 'entitlements.#',
|
||||
}
|
||||
}
|
||||
|
||||
# Ensuring tasks execute with Flask application context
|
||||
# Ensure tasks execute with Flask context
|
||||
class ContextTask(celery.Task):
|
||||
def __call__(self, *args, **kwargs):
|
||||
with app.app_context():
|
||||
@@ -39,6 +101,7 @@ def init_celery(celery, app):
|
||||
|
||||
|
||||
def make_celery(app_name, config):
    """Return the module-level Celery instance.

    Kept for backward API compatibility: both parameters are ignored and the
    shared singleton (configured by init_celery) is returned.
    """
    # keep API but return the single instance
    return celery_app
|
||||
|
||||
|
||||
@@ -46,4 +109,4 @@ def _get_current_celery():
|
||||
return celery_app
|
||||
|
||||
|
||||
current_celery = LocalProxy(_get_current_celery)
|
||||
current_celery = LocalProxy(_get_current_celery)
|
||||
175
common/utils/chat_utils.py
Normal file
175
common/utils/chat_utils.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
"""
|
||||
Utility functions for chat customization.
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
|
||||
def get_default_chat_customisation(tenant_customisation=None):
    """
    Get chat customization options with default values for missing options.

    Args:
        tenant_customisation (dict or str, optional): The tenant's customization
            options. Can be a dict or a JSON string. Defaults to None.

    Returns:
        dict: All customization options; tenant values override defaults, and
        keys that are not recognized defaults are ignored.
    """
    defaults = {
        'sidebar_markdown': '',
        'sidebar_color': '#f8f9fa',
        'sidebar_background': '#2c3e50',
        'markdown_background_color': 'transparent',
        'markdown_text_color': '#ffffff',
        'gradient_start_color': '#f5f7fa',
        'gradient_end_color': '#c3cfe2',
        'progress_tracker_insights': 'No Information',
        'form_title_display': 'Full Title',
        'active_background_color': '#ffffff',
        'history_background': 10,
        'ai_message_background': '#ffffff',
        'ai_message_text_color': '#212529',
        'human_message_background': '#212529',
        'human_message_text_color': '#ffffff',
        'human_message_inactive_text_color': '#808080',
        'tab_background': '#0a0a0a',
        'tab_icon_active_color': '#ffffff',
        'tab_icon_inactive_color': '#f0f0f0',
    }

    if tenant_customisation is None:
        return defaults

    # Accept a JSON string; a parse failure falls back to pure defaults.
    if isinstance(tenant_customisation, str):
        try:
            tenant_customisation = json.loads(tenant_customisation)
        except json.JSONDecodeError as e:
            current_app.logger.error(f"Error parsing JSON customisation: {e}")
            return defaults

    merged = defaults.copy()
    if tenant_customisation:
        # Only known option keys are honored; unknown keys are dropped.
        merged.update(
            (key, value)
            for key, value in tenant_customisation.items()
            if key in merged
        )

    return merged
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color):
    """
    Convert hex color to RGB tuple.

    Accepts '#rrggbb', 'rrggbb' and 3-character shorthand ('#rgb').

    Args:
        hex_color (str): Hex color string (e.g., '#ffffff' or 'ffffff')

    Returns:
        tuple: RGB values as (r, g, b); white (255, 255, 255) on any
        unparsable input, matching the function's forgiving contract.
    """
    # Fix: a non-string input (e.g. None) previously raised AttributeError
    # on .lstrip; now it takes the same white fallback as other bad input.
    if not isinstance(hex_color, str):
        return (255, 255, 255)

    # Remove # if present
    hex_color = hex_color.lstrip('#')

    # Expand 3-character shorthand, e.g. 'f0a' -> 'ff00aa'
    if len(hex_color) == 3:
        hex_color = ''.join(c * 2 for c in hex_color)

    # Convert to RGB; invalid hex digits fall back to white.
    try:
        return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return (255, 255, 255)
|
||||
|
||||
|
||||
def adjust_color_alpha(percentage):
    """
    Map a lighten/darken percentage onto an rgba() CSS color string.

    Args:
        percentage (int): -50 to 50. Positive selects a white overlay
            (lighten), negative a black overlay (darken), zero is fully
            transparent.

    Returns:
        str: RGBA color string for CSS
    """
    if percentage == 0:
        return 'rgba(255, 255, 255, 0)'  # fully transparent

    # White overlay lightens, black overlay darkens.
    r, g, b = (255, 255, 255) if percentage > 0 else (0, 0, 0)

    # |percentage| of 50 maps to full opacity; clamp into [0.0, 1.0].
    alpha = min(1.0, max(0.0, abs(percentage) / 50.0))

    return f'rgba({r}, {g}, {b}, {alpha})'
|
||||
|
||||
|
||||
def adjust_color_brightness(hex_color, percentage):
    """
    Adjust the brightness of a hex color by a percentage.

    Args:
        hex_color (str): Hex color string (e.g., '#ffffff')
        percentage (int): Percentage to adjust (-100 to 100).
            Positive moves toward white, negative (or zero) toward black.

    Returns:
        str: RGBA color string for CSS (e.g., 'rgba(255, 255, 255, 0.9)')
    """
    if not hex_color or not isinstance(hex_color, str):
        return 'rgba(255, 255, 255, 0.1)'

    # Parse the hex string inline (same contract as hex_to_rgb, including
    # the white fallback on malformed digits).
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        digits = ''.join(c * 2 for c in digits)
    try:
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        r, g, b = 255, 255, 255

    if percentage > 0:
        # Interpolate each channel toward white.
        factor = percentage / 100.0
        r, g, b = (int(c + (255 - c) * factor) for c in (r, g, b))
    else:
        # Interpolate toward black (percentage == 0 leaves the color as-is).
        factor = abs(percentage) / 100.0
        r, g, b = (int(c * (1 - factor)) for c in (r, g, b))

    # Clamp defensively to the valid channel range.
    r, g, b = (max(0, min(255, c)) for c in (r, g, b))

    # Slight transparency for better blending.
    return f'rgba({r}, {g}, {b}, 0.9)'
|
||||
|
||||
|
||||
def get_base_background_color():
    """
    Base background color used as the anchor for history adjustments
    (the main chat background).

    Returns:
        str: Hex color string
    """
    # Neutral light gray that blends well with brightness/alpha tweaks.
    return '#f8f9fa'
|
||||
711
common/utils/config_field_types.py
Normal file
711
common/utils/config_field_types.py
Normal file
@@ -0,0 +1,711 @@
|
||||
from typing import Optional, List, Union, Dict, Any, Pattern
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from typing_extensions import Annotated
|
||||
import re
|
||||
from datetime import datetime
|
||||
import json
|
||||
from textwrap import dedent
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TaggingField(BaseModel):
    """Represents a single tagging field configuration."""
    # Field data type; one of: string, integer, float, date, enum, color.
    type: str
    required: bool = False
    description: Optional[str] = None
    allowed_values: Optional[List[Any]] = None  # for enum type
    min_value: Optional[Union[int, float]] = None  # for numeric types
    max_value: Optional[Union[int, float]] = None  # for numeric types

    @field_validator('type', mode='before')
    @classmethod
    def validate_type(cls, v: str) -> str:
        # mode='before' makes an unsupported type fail fast with this
        # message instead of a generic coercion error.
        valid_types = ['string', 'integer', 'float', 'date', 'enum', 'color']
        if v not in valid_types:
            raise ValueError(f'type must be one of {valid_types}')
        return v

    @model_validator(mode='after')
    def validate_field_constraints(self) -> 'TaggingField':
        """Cross-field checks that depend on the resolved ``type``."""
        # Enum fields must declare allowed_values; no other type may set it.
        if self.type == 'enum':
            if not self.allowed_values:
                raise ValueError('allowed_values must be provided for enum type')
        elif self.allowed_values is not None:
            raise ValueError('allowed_values only valid for enum type')

        # min/max only make sense for numeric fields, and when both are
        # present the range must be non-empty.
        # NOTE(review): min_value == max_value is rejected here while
        # NumericConstraint accepts it — confirm which behavior is intended.
        if self.type not in ('integer', 'float'):
            if self.min_value is not None or self.max_value is not None:
                raise ValueError('min_value/max_value only valid for numeric types')
        else:
            if self.min_value is not None and self.max_value is not None and self.min_value >= self.max_value:
                raise ValueError('min_value must be less than max_value')

        return self
|
||||
|
||||
|
||||
class TaggingFields(BaseModel):
    """Represents a collection of tagging fields, mapped by their names."""
    fields: Dict[str, TaggingField]

    @classmethod
    def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'TaggingFields':
        """Build the collection from a raw {name: config-dict} mapping."""
        parsed: Dict[str, TaggingField] = {}
        for field_name, field_config in data.items():
            parsed[field_name] = TaggingField(**field_config)
        return cls(fields=parsed)

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Inverse of from_dict; options left as None are omitted."""
        return {
            name: definition.model_dump(exclude_none=True)
            for name, definition in self.fields.items()
        }
|
||||
|
||||
|
||||
class ChunkingPatternsField(BaseModel):
    """Represents a set of chunking patterns (regexes used to split documents)."""
    patterns: List[str]

    @field_validator('patterns')
    @classmethod
    def validate_patterns(cls, patterns: List[str]) -> List[str]:
        """Reject any pattern that is not a valid regular expression.

        Fix: added the @classmethod decorator for consistency with every
        other field_validator in this module, and chained the original
        re.error so the compile failure stays visible in the traceback.
        """
        for pattern in patterns:
            try:
                re.compile(pattern)
            except re.error as e:
                raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}") from e
        return patterns
|
||||
|
||||
|
||||
class ArgumentConstraint(BaseModel):
    """Base class for all argument constraints."""
    # Human-readable explanation of what the constraint enforces.
    description: Optional[str] = None
    # Custom message surfaced to callers when validation fails; subclasses'
    # validate() callers fall back to a generic message when this is None.
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
class NumericConstraint(ArgumentConstraint):
    """Range constraint for numeric (int/float) argument values."""
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    include_min: bool = True  # inclusive lower bound (>=) when True, else strict (>)
    include_max: bool = True  # inclusive upper bound (<=) when True, else strict (<)

    @model_validator(mode='after')
    def validate_ranges(self) -> 'NumericConstraint':
        """Ensure the configured range is non-empty."""
        both_bounded = self.min_value is not None and self.max_value is not None
        if both_bounded and self.min_value > self.max_value:
            raise ValueError("min_value must be less than or equal to max_value")
        return self

    def validate(self, value: Union[int, float]) -> bool:
        """Return True when value lies within the configured bounds."""
        if self.min_value is not None:
            below = value < self.min_value if self.include_min else value <= self.min_value
            if below:
                return False
        if self.max_value is not None:
            above = value > self.max_value if self.include_max else value >= self.max_value
            if above:
                return False
        return True
|
||||
|
||||
|
||||
class StringConstraint(ArgumentConstraint):
    """Length and regex-pattern constraints for string argument values."""
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    patterns: Optional[List[str]] = None  # regexes the value should match
    pattern_match_all: bool = False  # True: all patterns must match; False: any
    forbidden_patterns: Optional[List[str]] = None  # regexes that must not match
    allow_empty: bool = False

    @field_validator('patterns', 'forbidden_patterns')
    @classmethod
    def validate_patterns(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Reject regexes that do not compile."""
        for pattern in v or []:
            try:
                re.compile(pattern)
            except re.error as e:
                raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
        return v

    def validate(self, value: str) -> bool:
        """Return True when value satisfies every configured constraint."""
        if not value and not self.allow_empty:
            return False

        length = len(value)
        if self.min_length is not None and length < self.min_length:
            return False
        if self.max_length is not None and length > self.max_length:
            return False

        if self.patterns:
            hits = [re.search(pattern, value) is not None for pattern in self.patterns]
            satisfied = all(hits) if self.pattern_match_all else any(hits)
            if not satisfied:
                return False

        if self.forbidden_patterns and any(
                re.search(pattern, value) for pattern in self.forbidden_patterns):
            return False

        return True
|
||||
|
||||
|
||||
class DateConstraint(ArgumentConstraint):
    """Range constraint for date/datetime argument values."""
    min_date: Optional[datetime] = None
    max_date: Optional[datetime] = None
    include_min: bool = True  # inclusive lower bound when True
    include_max: bool = True  # inclusive upper bound when True
    allowed_formats: Optional[List[str]] = None  # strptime formats accepted for parsing

    @model_validator(mode='after')
    def validate_ranges(self) -> 'DateConstraint':
        """Ensure the configured date range is non-empty."""
        if self.min_date and self.max_date and self.min_date > self.max_date:
            raise ValueError("min_date must be less than or equal to max_date")
        return self

    def validate(self, value: datetime) -> bool:
        """Return True when value lies within the configured date bounds."""
        if self.min_date is not None:
            too_early = value < self.min_date if self.include_min else value <= self.min_date
            if too_early:
                return False

        if self.max_date is not None:
            too_late = value > self.max_date if self.include_max else value >= self.max_date
            if too_late:
                return False

        return True
|
||||
|
||||
|
||||
class EnumConstraint(ArgumentConstraint):
    """Membership constraint for enum values (single or multi-select)."""
    allowed_values: List[Any]
    case_sensitive: bool = True  # applies to string comparisons only
    allow_multiple: bool = False  # when True, value may be a list of allowed values
    min_selections: Optional[int] = None  # only meaningful with allow_multiple
    max_selections: Optional[int] = None  # only meaningful with allow_multiple

    @model_validator(mode='after')
    def validate_selections(self) -> 'EnumConstraint':
        """Sanity-check the selection bounds for multi-select constraints."""
        if self.allow_multiple:
            both_set = self.min_selections is not None and self.max_selections is not None
            if both_set:
                if self.min_selections > self.max_selections:
                    raise ValueError("min_selections must be less than or equal to max_selections")
                if self.max_selections > len(self.allowed_values):
                    raise ValueError("max_selections cannot be greater than number of allowed values")
        return self

    def validate(self, value: Union[Any, List[Any]]) -> bool:
        """Return True when value (or every element of it) is allowed."""
        if not self.allow_multiple:
            return self._validate_single_value(value)

        if not isinstance(value, list):
            return False

        selected = len(value)
        if self.min_selections is not None and selected < self.min_selections:
            return False
        if self.max_selections is not None and selected > self.max_selections:
            return False

        return all(self._validate_single_value(item) for item in value)

    def _validate_single_value(self, value: Any) -> bool:
        # Case-insensitive comparison only applies to string values.
        if isinstance(value, str) and not self.case_sensitive:
            needle = str(value).lower()
            return any(needle == str(candidate).lower() for candidate in self.allowed_values)
        return value in self.allowed_values
|
||||
|
||||
|
||||
class ArgumentDefinition(BaseModel):
    """Defines an argument with its type and constraints."""
    name: str
    type: str
    description: Optional[str] = None
    required: bool = False
    default: Optional[Any] = None
    constraints: Optional[Union[NumericConstraint, StringConstraint, DateConstraint, EnumConstraint]] = None

    @field_validator('type')
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Restrict ``type`` to the supported argument types."""
        valid_types = ['string', 'integer', 'float', 'date', 'enum', 'color']
        if v not in valid_types:
            raise ValueError(f'type must be one of {valid_types}')
        return v

    @model_validator(mode='after')
    def validate_constraints(self) -> 'ArgumentDefinition':
        """Check the constraint object matches the declared type and vet the default."""
        if not self.constraints:
            return self

        # Each argument type admits exactly one constraint class.
        expected_constraint_types = {
            'string': StringConstraint,
            'integer': NumericConstraint,
            'float': NumericConstraint,
            'date': DateConstraint,
            'enum': EnumConstraint,
            'color': StringConstraint
        }
        expected_type = expected_constraint_types.get(self.type)
        if not isinstance(self.constraints, expected_type):
            raise ValueError(f'Constraints for type {self.type} must be of type {expected_type.__name__}')

        # A declared default must itself satisfy the constraints.
        if self.default is not None and not self.constraints.validate(self.default):
            raise ValueError(f'Default value does not satisfy constraints for {self.name}')

        return self
|
||||
|
||||
|
||||
class ArgumentDefinitions(BaseModel):
    """Collection of argument definitions."""
    arguments: Dict[str, ArgumentDefinition]

    @classmethod
    def from_dict(cls, data: Dict[str, Dict[str, Any]]) -> 'ArgumentDefinitions':
        # Build each ArgumentDefinition from its raw config dict.
        return cls(arguments={
            arg_name: ArgumentDefinition(**arg_config)
            for arg_name, arg_config in data.items()
        })

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        # Inverse of from_dict; options left as None are omitted.
        return {
            arg_name: arg.model_dump(exclude_none=True)
            for arg_name, arg in self.arguments.items()
        }

    def validate_argument_values(self, values: Dict[str, Any]) -> Dict[str, str]:
        """
        Validate a set of argument values against their definitions.
        Returns a dictionary of error messages for invalid arguments
        (empty dict means all values are valid).
        """
        errors = {}

        # Check for required arguments
        for name, arg_def in self.arguments.items():
            if arg_def.required and name not in values:
                errors[name] = "Required argument missing"
                continue

            if name in values:
                value = values[name]

                # Validate type: coerce the raw value to the declared type
                # before constraint checks; a failed coercion records an
                # error and skips the constraint check for this argument.
                try:
                    if arg_def.type == 'integer':
                        value = int(value)
                    elif arg_def.type == 'float':
                        value = float(value)
                    elif arg_def.type == 'date' and isinstance(value, str):
                        if arg_def.constraints and arg_def.constraints.allowed_formats:
                            # for/else: the else branch runs only when no
                            # allowed format parsed the string.
                            for fmt in arg_def.constraints.allowed_formats:
                                try:
                                    value = datetime.strptime(value, fmt)
                                    break
                                except ValueError:
                                    continue
                            else:
                                errors[name] = f"Invalid date format. Allowed formats: {arg_def.constraints.allowed_formats}"
                                continue
                        # NOTE(review): when no allowed_formats are configured
                        # the string is passed through unparsed, and a
                        # DateConstraint comparison below would raise on
                        # str-vs-datetime — confirm this path is intended.
                except (ValueError, TypeError):
                    errors[name] = f"Invalid type. Expected {arg_def.type}"
                    continue

                # Validate constraints
                if arg_def.constraints and not arg_def.constraints.validate(value):
                    errors[name] = arg_def.constraints.error_message or "Value does not satisfy constraints"

        return errors
|
||||
|
||||
|
||||
@dataclass
class DocumentationFormat:
    """Constants for documentation formats (output encodings of the generator)."""
    MARKDOWN = "markdown"  # human-readable markdown
    JSON = "json"  # machine-readable JSON
    YAML = "yaml"  # machine-readable YAML
|
||||
|
||||
|
||||
@dataclass
class DocumentationVersion:
    """Constants for documentation versions."""
    BASIC = "basic"  # Original documentation without retriever info
    EXTENDED = "extended"  # Including retriever documentation
|
||||
|
||||
|
||||
def _generate_argument_constraints(field_config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Generate possible argument constraints based on field type"""
|
||||
constraints = []
|
||||
|
||||
base_constraint = {
|
||||
"description": f"Constraint for {field_config.get('description', 'field')}",
|
||||
"error_message": "Optional custom error message"
|
||||
}
|
||||
|
||||
if field_config["type"] == "integer" or field_config["type"] == "float":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "NumericConstraint",
|
||||
"possible_constraints": {
|
||||
"min_value": "number",
|
||||
"max_value": "number",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_value": field_config.get("min_value", 0),
|
||||
"max_value": field_config.get("max_value", 100),
|
||||
"include_min": True,
|
||||
"include_max": True
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "string":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "StringConstraint",
|
||||
"possible_constraints": {
|
||||
"min_length": "integer",
|
||||
"max_length": "integer",
|
||||
"patterns": "list[str]",
|
||||
"pattern_match_all": "boolean",
|
||||
"forbidden_patterns": "list[str]",
|
||||
"allow_empty": "boolean"
|
||||
},
|
||||
"example": {
|
||||
"min_length": 1,
|
||||
"max_length": 100,
|
||||
"patterns": ["^[A-Za-z0-9]+$"],
|
||||
"pattern_match_all": False,
|
||||
"forbidden_patterns": ["^test_", "_temp$"],
|
||||
"allow_empty": False
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "enum":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "EnumConstraint",
|
||||
"possible_constraints": {
|
||||
"allowed_values": f"list[{field_config.get('allowed_values', ['value1', 'value2'])}]",
|
||||
"case_sensitive": "boolean",
|
||||
"allow_multiple": "boolean",
|
||||
"min_selections": "integer",
|
||||
"max_selections": "integer"
|
||||
},
|
||||
"example": {
|
||||
"allowed_values": field_config.get("allowed_values", ["value1", "value2"]),
|
||||
"case_sensitive": True,
|
||||
"allow_multiple": True,
|
||||
"min_selections": 1,
|
||||
"max_selections": 2
|
||||
}
|
||||
})
|
||||
|
||||
elif field_config["type"] == "date":
|
||||
constraints.append({
|
||||
**base_constraint,
|
||||
"type": "DateConstraint",
|
||||
"possible_constraints": {
|
||||
"min_date": "datetime",
|
||||
"max_date": "datetime",
|
||||
"include_min": "boolean",
|
||||
"include_max": "boolean",
|
||||
"allowed_formats": "list[str]"
|
||||
},
|
||||
"example": {
|
||||
"min_date": "2024-01-01T00:00:00",
|
||||
"max_date": "2024-12-31T23:59:59",
|
||||
"include_min": True,
|
||||
"include_max": True,
|
||||
"allowed_formats": ["%Y-%m-%d", "%Y/%m/%d"]
|
||||
}
|
||||
})
|
||||
|
||||
return constraints
|
||||
|
||||
|
||||
def generate_field_documentation(
    tagging_fields: Dict[str, Any],
    format: str = "markdown",
    version: str = "basic"
) -> str:
    """
    Generate documentation for tagging fields configuration.

    Args:
        tagging_fields: Dictionary containing tagging fields configuration
        format: Output format ("markdown", "json", or "yaml")
        version: Documentation version ("basic" or "extended")

    Returns:
        str: Formatted documentation

    Raises:
        ValueError: If the requested version or format is not supported.
    """
    if version not in (DocumentationVersion.BASIC, DocumentationVersion.EXTENDED):
        raise ValueError(f"Unsupported documentation version: {version}")

    # Per-type constraint renderers: a (config_key, formatter) pair only
    # produces a human-readable constraint line when the key is present.
    _numeric = [
        ("min_value", lambda v: f"Minimum value: {v}"),
        ("max_value", lambda v: f"Maximum value: {v}"),
    ]
    renderers_by_type = {
        "integer": _numeric,
        "float": _numeric,
        "string": [
            ("min_length", lambda v: f"Minimum length: {v}"),
            ("max_length", lambda v: f"Maximum length: {v}"),
            ("patterns", lambda v: f"Must match patterns: {', '.join(v)}"),
        ],
        "enum": [
            ("allowed_values",
             lambda v: f"Allowed values: {', '.join(str(x) for x in v)}"),
        ],
        "date": [
            ("min_date", lambda v: f"Minimum date: {v}"),
            ("max_date", lambda v: f"Maximum date: {v}"),
            ("allowed_formats", lambda v: f"Allowed formats: {', '.join(v)}"),
        ],
    }

    # Normalize every field into a uniform documentation entry.
    normalized_fields = {}
    for name, cfg in tagging_fields.items():
        entry = {
            "name": name,
            "type": cfg["type"],
            "required": cfg.get("required", False),
            "description": cfg.get("description", "No description provided"),
            "constraints": [],
        }

        # Retriever argument docs are an extended-version-only detail.
        if version == DocumentationVersion.EXTENDED:
            entry["possible_arguments"] = _generate_argument_constraints(cfg)

        for key, render in renderers_by_type.get(cfg["type"], []):
            if key in cfg:
                entry["constraints"].append(render(cfg[key]))

        normalized_fields[name] = entry

    # Dispatch to the requested output renderer.
    if format == DocumentationFormat.MARKDOWN:
        return _generate_markdown_docs(normalized_fields, version)
    if format == DocumentationFormat.JSON:
        return _generate_json_docs(normalized_fields, version)
    if format == DocumentationFormat.YAML:
        return _generate_yaml_docs(normalized_fields, version)
    raise ValueError(f"Unsupported documentation format: {format}")
|
||||
|
||||
|
||||
def _generate_markdown_docs(fields: Dict[str, Any], version: str) -> str:
    """Render normalized field specs as a markdown document.

    Produces an overview table, a detailed section per field, and — for the
    extended version — retriever argument docs plus an example configuration.
    """
    lines = ["# Tagging Fields Documentation\n", "## Fields Overview\n"]

    # Overview table: one row per field.
    lines.append("| Field Name | Type | Required | Description |")
    lines.append("|------------|------|----------|-------------|")
    lines.extend(
        f"| {name} | {spec['type']} | "
        f"{'Yes' if spec['required'] else 'No'} | {spec['description']} |"
        for name, spec in fields.items()
    )

    lines.append("\n## Detailed Field Specifications\n")
    for name, spec in fields.items():
        lines.append(f"### {name}\n")
        lines.append(f"**Type:** {spec['type']}")
        lines.append(f"**Required:** {'Yes' if spec['required'] else 'No'}")
        lines.append(f"**Description:** {spec['description']}\n")

        if spec["constraints"]:
            lines.append("**Field Constraints:**")
            lines.extend(f"- {c}" for c in spec["constraints"])
            lines.append("")

        # Retriever argument documentation exists only in the extended version.
        if version == DocumentationVersion.EXTENDED and "possible_arguments" in spec:
            lines.append("**Possible Retriever Arguments:**")
            for arg in spec["possible_arguments"]:
                lines.append(f"\n*{arg['type']}*")
                lines.append(f"Description: {arg['description']}")
                lines.append("\nPossible constraints:")
                lines.extend(
                    f"- `{key}`: {kind}"
                    for key, kind in arg["possible_constraints"].items()
                )
                lines.append("\nExample:")
                lines.append("```python")
                lines.append(json.dumps(arg["example"], indent=2))
                lines.append("```\n")

    # Extended version closes with a full example retriever configuration.
    if version == DocumentationVersion.EXTENDED:
        lines.append("\n## Example Retriever Configuration\n")
        lines.append("```python")
        example_config = {
            "metadata_filters": {
                name: spec["possible_arguments"][0]["example"]
                for name, spec in fields.items()
                if "possible_arguments" in spec
            }
        }
        lines.append(json.dumps(example_config, indent=2))
        lines.append("```")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_json_docs(fields: Dict[str, Any], version: str) -> str:
    """Render normalized field specs as a pretty-printed JSON document."""
    body = {"version": version, "fields": fields}

    # Extended version additionally embeds an example retriever config built
    # from the first example of each field's possible arguments.
    if version == DocumentationVersion.EXTENDED:
        body["example_retriever_config"] = {
            "metadata_filters": {
                name: spec["possible_arguments"][0]["example"]
                for name, spec in fields.items()
                if "possible_arguments" in spec
            }
        }

    return json.dumps({"tagging_fields_documentation": body}, indent=2)
|
||||
|
||||
|
||||
def _generate_yaml_docs(fields: Dict[str, Any], version: str) -> str:
    """Render normalized field specs as a YAML document."""
    body = {"version": version, "fields": fields}

    # Extended version additionally embeds an example retriever config built
    # from the first example of each field's possible arguments.
    if version == DocumentationVersion.EXTENDED:
        body["example_retriever_config"] = {
            "metadata_filters": {
                name: spec["possible_arguments"][0]["example"]
                for name, spec in fields.items()
                if "possible_arguments" in spec
            }
        }

    return yaml.dump(
        {"tagging_fields_documentation": body},
        sort_keys=False,
        default_flow_style=False,
    )
|
||||
|
||||
|
||||
def patterns_to_json(text_area_content: str) -> str:
    """Serialize newline-separated patterns from a text area into a JSON list string.

    Blank lines and surrounding whitespace are discarded; empty input yields "[]".
    """
    stripped = text_area_content.strip()
    if not stripped:
        return json.dumps([])

    # One pattern per non-empty line, whitespace-trimmed.
    trimmed = (line.strip() for line in stripped.split('\n'))
    return json.dumps([line for line in trimmed if line])
|
||||
|
||||
|
||||
def json_to_patterns(json_content: str) -> str:
    """Convert a JSON list of patterns to newline-separated text area content.

    Args:
        json_content: JSON string encoding a list of pattern strings.

    Returns:
        str: The patterns joined with newlines (empty string for an empty list).

    Raises:
        ValueError: If the input is not valid JSON or does not encode a list.
    """
    try:
        patterns = json.loads(json_content)
        if not isinstance(patterns, list):
            raise ValueError("JSON must contain a list of patterns")
        # One pattern per line in the text area.
        return '\n'.join(patterns)
    except json.JSONDecodeError as e:
        # Chain the cause so the original parse error stays visible.
        raise ValueError(f"Invalid JSON format: {e}") from e
|
||||
|
||||
|
||||
def json_to_pattern_list(json_content: str) -> list:
    """Convert a JSON list of patterns to a Python list of pattern strings.

    (Docstring fix: the previous one was copied from json_to_patterns and
    wrongly claimed this returns text area content.)

    Args:
        json_content: JSON string encoding a list of pattern strings;
            falsy input (empty string / None) yields an empty list.

    Returns:
        list: The pattern strings, with doubled backslashes collapsed.

    Raises:
        ValueError: If the input is not valid JSON or does not encode a list.
    """
    try:
        if not json_content:
            return []
        patterns = json.loads(json_content)
        if not isinstance(patterns, list):
            raise ValueError("JSON must contain a list of patterns")
        # Unescape if needed: collapse double backslashes produced by
        # round-tripping regex patterns through JSON string escaping.
        return [pattern.replace('\\\\', '\\') for pattern in patterns]
    except json.JSONDecodeError as e:
        # Chain the cause so the original parse error stays visible.
        raise ValueError(f"Invalid JSON format: {e}") from e
|
||||
|
||||
|
||||
def normalize_json_field(value: str | dict | None, field_name: str = "JSON field") -> dict:
|
||||
"""
|
||||
Normalize a JSON field value to ensure it's a valid dictionary.
|
||||
|
||||
Args:
|
||||
value: The input value which can be:
|
||||
- None (will return empty dict)
|
||||
- String (will be parsed as JSON)
|
||||
- Dict (will be validated and returned)
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
dict: The normalized JSON data as a Python dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If the input string is not valid JSON or the input dict contains invalid types
|
||||
"""
|
||||
# Handle None case
|
||||
if value is None:
|
||||
return {}
|
||||
|
||||
# Handle dictionary case
|
||||
if isinstance(value, dict):
|
||||
try:
|
||||
# Validate all values are JSON serializable
|
||||
import json
|
||||
json.dumps(value)
|
||||
return value
|
||||
except TypeError as e:
|
||||
raise ValueError(f"{field_name} contains invalid types: {str(e)}")
|
||||
|
||||
# Handle string case
|
||||
if isinstance(value, str):
|
||||
if not value.strip():
|
||||
return {}
|
||||
|
||||
try:
|
||||
import json
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{field_name} contains invalid JSON: {str(e)}")
|
||||
|
||||
raise ValueError(f"{field_name} must be a string, dictionary, or None (got {type(value)})")
|
||||
222
common/utils/content_utils.py
Normal file
222
common/utils/content_utils.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from packaging import version
|
||||
from flask import current_app
|
||||
|
||||
class ContentManager:
    """Serve versioned content files (e.g. changelogs, terms) from disk.

    Content is expected to be laid out as:

        <CONTENT_DIR>/<content_type>/<major.minor>/<major.minor>.<patch>.md

    e.g. ``content/changelog/1.0/1.0.5.md``. Lookups are best-effort:
    errors are logged via the Flask app logger and None / an empty result
    is returned instead of raising.
    """

    def __init__(self, app=None):
        # Standard Flask extension pattern: allow deferred binding via init_app().
        self.app = app
        if app:
            self.init_app(app)

    def init_app(self, app):
        """Bind this manager to a Flask application."""
        self.app = app
        # NOTE(review): a CONTENT_DIR existence check used to live here but
        # was disabled; re-enable it if content lookups start failing silently.

    def get_content_path(self, content_type, major_minor=None, patch=None):
        """Return the full path to a content directory or file.

        Args:
            content_type (str): Content type (e.g. 'changelog', 'terms')
            major_minor (str, optional): Major.minor version (e.g. '1.0')
            patch (str, optional): Patch number (e.g. '5')

        Returns:
            str: Path to the content-type directory, the version directory,
                 or the markdown file, depending on which arguments are given.
        """
        content_path = os.path.join(self.app.config['CONTENT_DIR'], content_type)

        if major_minor:
            content_path = os.path.join(content_path, major_minor)
            # Fix: only build a file path when major_minor is known — previously
            # a patch without major_minor produced a bogus "None.<patch>.md".
            if patch:
                content_path = os.path.join(content_path, f"{major_minor}.{patch}.md")

        return content_path

    def _parse_version(self, filename):
        """Parse a version number out of a filename.

        Returns:
            tuple: (major_minor, patch) as strings, or (None, None) when the
            filename is not of the form '<d.d>.<d>.md'.
        """
        # Fix: anchor the pattern at the end of the string so stray files
        # such as '1.0.5.md.bak' are no longer treated as content files.
        match = re.match(r'(\d+\.\d+)\.(\d+)\.md$', filename)
        if match:
            return match.group(1), match.group(2)
        return None, None

    def get_latest_version(self, content_type, major_minor=None):
        """Find the latest version of a given content type.

        Args:
            content_type (str): Content type (e.g. 'changelog', 'terms')
            major_minor (str, optional): Specific major.minor version;
                when omitted the highest available one is used.

        Returns:
            tuple: (major_minor, patch, full_version) or None when not found
        """
        try:
            # Base path for this content type
            content_path = os.path.join(self.app.config['CONTENT_DIR'], content_type)

            if not os.path.exists(content_path):
                current_app.logger.error(f"Content path does not exist: {content_path}")
                return None

            # When no major_minor is given, pick the highest available one
            if not major_minor:
                available_versions = [f for f in os.listdir(content_path) if not f.startswith('.')]
                if not available_versions:
                    return None

                # Sort semantically (so '1.10' > '1.9') and take the highest
                available_versions.sort(key=lambda v: version.parse(v))
                major_minor = available_versions[-1]

            # With major_minor known, find the highest patch
            major_minor_path = os.path.join(content_path, major_minor)
            current_app.logger.debug(f"Major/Minor path: {major_minor_path}")

            if not os.path.exists(major_minor_path):
                current_app.logger.error(f"Version path does not exist: {major_minor_path}")
                return None

            files = [f for f in os.listdir(major_minor_path) if not f.startswith('.')]
            current_app.logger.debug(f"Files in version path: {files}")
            version_files = []

            for file in files:
                mm, p = self._parse_version(file)
                current_app.logger.debug(f"File: {file}, mm: {mm}, p: {p}")
                if mm == major_minor and p:
                    version_files.append((mm, p, f"{mm}.{p}"))

            if not version_files:
                return None

            # Sort numerically on the patch number; the last entry is latest
            version_files.sort(key=lambda v: int(v[1]))

            current_app.logger.debug(f"Latest version: {version_files[-1]}")
            return version_files[-1]

        except Exception as e:
            # Best-effort: log and report "not found" rather than raising
            current_app.logger.error(f"Error finding latest version for {content_type}: {str(e)}")
            return None

    def read_content(self, content_type, major_minor=None, patch=None):
        """Read content with version support.

        When neither major_minor nor patch is given, the latest version is
        used. When only major_minor is given, its latest patch is used.

        Args:
            content_type (str): Content type (e.g. 'changelog', 'terms')
            major_minor (str, optional): Major.minor version (e.g. '1.0')
            patch (str, optional): Patch number (e.g. '5')

        Returns:
            dict: {
                'content': str,
                'version': str,
                'content_type': str
            } or None on error
        """
        try:
            current_app.logger.debug(f"Reading content {content_type}")
            # No version given: resolve the latest overall
            if not major_minor:
                version_info = self.get_latest_version(content_type)
                if not version_info:
                    current_app.logger.error(f"No versions found for {content_type}")
                    return None

                major_minor, patch, full_version = version_info

            # Only major_minor given: resolve its latest patch
            elif not patch:
                version_info = self.get_latest_version(content_type, major_minor)
                if not version_info:
                    current_app.logger.error(f"No versions found for {content_type} {major_minor}")
                    return None

                major_minor, patch, full_version = version_info
            else:
                full_version = f"{major_minor}.{patch}"

            # With major_minor and patch resolved, read the file
            file_path = self.get_content_path(content_type, major_minor, patch)
            current_app.logger.debug(f"Content File path: {file_path}")

            if not os.path.exists(file_path):
                current_app.logger.error(f"Content file does not exist: {file_path}")
                return None

            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()

            current_app.logger.debug(f"Content read: {content}")

            return {
                'content': content,
                'version': full_version,
                'content_type': content_type
            }

        except Exception as e:
            current_app.logger.error(f"Error reading content {content_type} {major_minor}.{patch}: {str(e)}")
            return None

    def list_content_types(self):
        """List all available content types (top-level directories)."""
        try:
            return [d for d in os.listdir(self.app.config['CONTENT_DIR'])
                    if os.path.isdir(os.path.join(self.app.config['CONTENT_DIR'], d))]
        except Exception as e:
            current_app.logger.error(f"Error listing content types: {str(e)}")
            return []

    def list_versions(self, content_type):
        """List all available versions for a content type.

        Returns:
            list: Version info dicts, sorted ascending by semantic version:
                [{'version': '1.0.0', 'path': '/path/to/file', 'date_modified': <mtime>}]
        """
        versions = []
        try:
            content_path = os.path.join(self.app.config['CONTENT_DIR'], content_type)

            if not os.path.exists(content_path):
                return []

            for major_minor in os.listdir(content_path):
                major_minor_path = os.path.join(content_path, major_minor)

                if not os.path.isdir(major_minor_path):
                    continue

                for file in os.listdir(major_minor_path):
                    mm, p = self._parse_version(file)
                    if mm and p:
                        file_path = os.path.join(major_minor_path, file)
                        mod_time = os.path.getmtime(file_path)
                        versions.append({
                            'version': f"{mm}.{p}",
                            'path': file_path,
                            'date_modified': mod_time
                        })

            # Sort ascending by semantic version
            versions.sort(key=lambda v: version.parse(v['version']))
            return versions

        except Exception as e:
            current_app.logger.error(f"Error listing versions for {content_type}: {str(e)}")
            return []
|
||||
@@ -1,14 +1,14 @@
|
||||
from flask import request, current_app, session
|
||||
from flask_jwt_extended import decode_token, verify_jwt_in_request, get_jwt_identity
|
||||
|
||||
from common.models.user import Tenant, TenantDomain
|
||||
|
||||
|
||||
def get_allowed_origins(tenant_id):
|
||||
session_key = f"allowed_origins_{tenant_id}"
|
||||
if session_key in session:
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from session")
|
||||
return session[session_key]
|
||||
|
||||
current_app.logger.debug(f"Fetching allowed origins for tenant {tenant_id} from database")
|
||||
tenant_domains = TenantDomain.query.filter_by(tenant_id=int(tenant_id)).all()
|
||||
allowed_origins = [domain.domain for domain in tenant_domains]
|
||||
|
||||
@@ -18,43 +18,52 @@ def get_allowed_origins(tenant_id):
|
||||
|
||||
|
||||
def cors_after_request(response, prefix):
    """Attach CORS headers to a response based on the requesting tenant.

    Resolves the tenant from (in order) the JSON body, a Socket.IO JWT query
    token, query parameters / the X-Tenant-ID header, and finally an optional
    JWT in the request; then echoes the Origin header back only when it is in
    the tenant's allowed-origins list.

    Fixes vs. previous revision:
    - the allowed-origin CORS headers were added twice (duplicate block);
    - the query/header fallback could clobber a tenant_id already found in
      the JSON body with None.

    Args:
        response: The Flask response object to mutate.
        prefix: Blueprint/url prefix, used only for logging.

    Returns:
        The (possibly mutated) response object.
    """
    current_app.logger.debug(f'CORS after request: {request.path}, prefix: {prefix}')
    current_app.logger.debug(f'request.headers: {request.headers}')
    current_app.logger.debug(f'request.args: {request.args}')
    current_app.logger.debug(f'request is json?: {request.is_json}')

    # Exclude health checks from checks
    if request.path.startswith('/healthz') or request.path.startswith('/_healthz'):
        response.headers.add('Access-Control-Allow-Origin', '*')
        response.headers.add('Access-Control-Allow-Headers', '*')
        response.headers.add('Access-Control-Allow-Methods', '*')
        return response

    # Handle OPTIONS preflight requests
    if request.method == 'OPTIONS':
        response.headers.add('Access-Control-Allow-Origin', '*')
        response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization,X-Tenant-ID')
        response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
        response.headers.add('Access-Control-Allow-Credentials', 'true')
        return response

    tenant_id = None
    allowed_origins = []

    # Try to get tenant_id from JSON payload
    json_data = request.get_json(silent=True)
    current_app.logger.debug(f'request.get_json(silent=True): {json_data}')

    if json_data and 'tenant_id' in json_data:
        tenant_id = json_data['tenant_id']

    # Check Socket.IO connection: the identity travels in a JWT query token
    if 'socket.io' in request.path:
        token = request.args.get('token')
        if token:
            try:
                decoded = decode_token(token)
                tenant_id = decoded['sub']
            except Exception as e:
                current_app.logger.error(f'Error decoding token: {e}')
                return response
    elif tenant_id is None:
        # Fallback to get tenant_id from query parameters or headers if JSON is not available
        tenant_id = request.args.get('tenant_id') or request.args.get('tenantId') or request.headers.get('X-Tenant-ID')

    current_app.logger.debug(f'Identified tenant_id: {tenant_id}')

    # Regular API requests: a verified JWT identity takes precedence
    try:
        if verify_jwt_in_request(optional=True):
            tenant_id = get_jwt_identity()
    except Exception as e:
        current_app.logger.error(f'Error verifying JWT: {e}')
        return response

    if tenant_id:
        allowed_origins = get_allowed_origins(tenant_id)
        current_app.logger.debug(f'Allowed origins for tenant {tenant_id}: {allowed_origins}')
    else:
        current_app.logger.warning('tenant_id not found in request')

    origin = request.headers.get('Origin')
    current_app.logger.debug(f'Origin: {origin}')

    if origin in allowed_origins:
        response.headers.add('Access-Control-Allow-Origin', origin)
        response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
        response.headers.add('Access-Control-Allow-Methods', 'GET,POST,PUT,DELETE,OPTIONS')
        response.headers.add('Access-Control-Allow-Credentials', 'true')
        current_app.logger.debug(f'CORS headers set for origin: {origin}')
    else:
        current_app.logger.warning(f'Origin {origin} not allowed')

    return response
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Database related functions"""
|
||||
from os import popen
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import text, event
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
from sqlalchemy.exc import InternalError
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, Session as SASession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from flask import current_app
|
||||
|
||||
@@ -16,6 +16,66 @@ class Database:
|
||||
def __init__(self, tenant: str) -> None:
|
||||
self.schema = str(tenant)
|
||||
|
||||
# --- Session / Transaction events to ensure correct search_path per transaction ---
# NOTE(review): registered once when the enclosing class body executes; this is
# a process-wide SQLAlchemy Session event hook, not an instance method (no self).
@event.listens_for(SASession, "after_begin")
def _set_search_path_per_tx(session, transaction, connection):
    """Ensure each transaction sees the right tenant schema, regardless of
    which pooled connection is used. Uses SET LOCAL so it is scoped to the tx.

    Reads the target schema from ``session.info["tenant_schema"]``, which
    ``switch_schema()`` populates; sessions without that key are untouched.
    Deliberately never raises: failures are logged and the transaction
    proceeds with whatever search_path the connection already has.
    """
    schema = session.info.get("tenant_schema")
    if schema:
        try:
            # SET LOCAL is transaction-scoped, so a pooled connection cannot
            # leak this search_path into another tenant's transaction.
            connection.exec_driver_sql(f'SET LOCAL search_path TO "{schema}", public')
            # Optional visibility/logging for debugging
            sp = connection.exec_driver_sql("SHOW search_path").scalar()
            try:
                current_app.logger.info(f"DBCTX tx_begin conn_id={id(connection.connection)} search_path={sp}")
            except Exception:
                # Logging may fail outside an app context; never break the tx for it
                pass
        except Exception as e:
            try:
                current_app.logger.error(f"Failed to SET LOCAL search_path for schema {schema}: {e!r}")
            except Exception:
                pass
|
||||
|
||||
def _log_db_context(self, origin: str = "") -> None:
    """Log key DB context info to diagnose schema/search_path issues.

    Collects and logs in a single structured line:
    - current_database()
    - inet_server_addr(), inet_server_port()
    - SHOW search_path
    - current_schema()
    - to_regclass('interaction')
    - to_regclass('<tenant>.interaction')

    Args:
        origin: Free-form tag identifying the call site (e.g. "before_switch"),
            so log lines can be correlated with the surrounding operation.

    Never raises: SQLAlchemy errors are caught and logged. Note the probe
    queries run on the current db.session, so they observe the same
    connection/transaction state as the caller.
    """
    try:
        db_name = db.session.execute(text("SELECT current_database()"))\
            .scalar()
        host = db.session.execute(text("SELECT inet_server_addr()"))\
            .scalar()
        port = db.session.execute(text("SELECT inet_server_port()"))\
            .scalar()
        search_path = db.session.execute(text("SHOW search_path"))\
            .scalar()
        current_schema = db.session.execute(text("SELECT current_schema()"))\
            .scalar()
        # to_regclass returns NULL when the name does not resolve — a cheap
        # way to check whether 'interaction' is visible on the search_path.
        reg_unqualified = db.session.execute(text("SELECT to_regclass('interaction')"))\
            .scalar()
        qualified = f"{self.schema}.interaction"
        # Bound parameter (not f-string) for the schema-qualified probe.
        reg_qualified = db.session.execute(
            text("SELECT to_regclass(:qn)"),
            {"qn": qualified}
        ).scalar()
        current_app.logger.info(
            "DBCTX origin=%s db=%s host=%s port=%s search_path=%s current_schema=%s to_regclass(interaction)=%s to_regclass(%s)=%s",
            origin, db_name, host, port, search_path, current_schema, reg_unqualified, qualified, reg_qualified
        )
    except SQLAlchemyError as e:
        current_app.logger.error(
            f"DBCTX logging failed at {origin} for schema {self.schema}: {e!r}"
        )
|
||||
|
||||
def get_engine(self):
|
||||
"""create new schema engine"""
|
||||
return db.engine.execution_options(
|
||||
@@ -46,12 +106,38 @@ class Database:
|
||||
|
||||
def create_tables(self):
    """Create all mapped tables in this tenant's schema.

    Best-effort: a failure is logged rather than raised, so tenant
    provisioning can continue and the error surfaces in the logs.
    (Cleans up diff residue where the old unguarded call and the new
    try/except both appeared.)
    """
    try:
        db.metadata.create_all(self.get_engine())
    except SQLAlchemyError as e:
        current_app.logger.error(f"💔 Error creating tables for schema {self.schema}: {e.args}")
|
||||
|
||||
def switch_schema(self):
    """Switch between tenant/public database schema with diagnostics logging.

    Records the tenant schema on the active Session (so the after_begin
    event can re-apply it per transaction), sets the session search_path,
    and logs the DB context before and after. On failure the session is
    rolled back and the error re-raised for the caller to handle.
    (Cleans up diff residue where the old and new bodies were interleaved.)
    """
    # Record the desired tenant schema on the active Session so events can use it
    try:
        db.session.info["tenant_schema"] = self.schema
    except Exception:
        pass
    # Log the context before switching
    self._log_db_context("before_switch")
    try:
        db.session.execute(text(f'set search_path to "{self.schema}", public'))
        db.session.commit()
    except SQLAlchemyError as e:
        # Rollback on error to avoid InFailedSqlTransaction and log details
        try:
            db.session.rollback()
        except Exception:
            pass
        current_app.logger.error(
            f"Error switching search_path to {self.schema}: {e!r}"
        )
        # Also log context after failure
        self._log_db_context("after_switch_failed")
        # Re-raise to let caller decide handling if needed
        raise
    # Log the context after successful switch
    self._log_db_context("after_switch")
|
||||
|
||||
def migrate_tenant_schema(self):
|
||||
"""migrate tenant database schema for new tenant"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user