Add support for ML model storage and execution (#3015)
This commit is contained in:
parent
fc66e2f4ea
commit
2ae8416791
64 changed files with 1815 additions and 320 deletions
28
.github/workflows/ci.yml
vendored
28
.github/workflows/ci.yml
vendored
|
@ -199,6 +199,34 @@ jobs:
|
||||||
- name: Run HTTP integration tests
|
- name: Run HTTP integration tests
|
||||||
run: cargo make ci-http-integration
|
run: cargo make ci-http-integration
|
||||||
|
|
||||||
|
ml-support:
|
||||||
|
name: ML integration tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
|
||||||
|
- name: Install stable toolchain
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
toolchain: 1.71.1
|
||||||
|
|
||||||
|
- name: Checkout sources
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup cache
|
||||||
|
uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get -y update
|
||||||
|
|
||||||
|
- name: Install cargo-make
|
||||||
|
run: cargo install --debug --locked cargo-make
|
||||||
|
|
||||||
|
- name: Run ML integration tests
|
||||||
|
run: cargo make ci-ml-integration
|
||||||
|
|
||||||
ws-server:
|
ws-server:
|
||||||
name: WebSocket integration tests
|
name: WebSocket integration tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -45,5 +45,7 @@ Temporary Items
|
||||||
# Specific
|
# Specific
|
||||||
# -----------------------------------
|
# -----------------------------------
|
||||||
|
|
||||||
|
/cache/
|
||||||
|
/store/
|
||||||
surreal
|
surreal
|
||||||
history.txt
|
history.txt
|
||||||
|
|
187
Cargo.lock
generated
187
Cargo.lock
generated
|
@ -1813,6 +1813,18 @@ dependencies = [
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "filetime"
|
||||||
|
version = "0.2.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d4029edd3e734da6fe05b6cd7bd2960760a616bd2ddd0d59a0124746d6272af0"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"redox_syscall 0.3.5",
|
||||||
|
"windows-sys 0.48.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "findshlibs"
|
name = "findshlibs"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
|
@ -2917,7 +2929,7 @@ checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.4.1",
|
"bitflags 2.4.1",
|
||||||
"libc",
|
"libc",
|
||||||
"redox_syscall",
|
"redox_syscall 0.4.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3036,6 +3048,16 @@ version = "0.7.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.3.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "md-5"
|
name = "md-5"
|
||||||
version = "0.10.6"
|
version = "0.10.6"
|
||||||
|
@ -3181,6 +3203,19 @@ dependencies = [
|
||||||
"tempfile",
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.15.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply",
|
||||||
|
"num-complex",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "new_debug_unreachable"
|
name = "new_debug_unreachable"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
|
@ -3272,6 +3307,15 @@ dependencies = [
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-format"
|
name = "num-format"
|
||||||
version = "0.4.4"
|
version = "0.4.4"
|
||||||
|
@ -3332,6 +3376,27 @@ dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "object_store"
|
||||||
|
version = "0.8.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"bytes",
|
||||||
|
"chrono",
|
||||||
|
"futures",
|
||||||
|
"humantime",
|
||||||
|
"itertools 0.11.0",
|
||||||
|
"parking_lot",
|
||||||
|
"percent-encoding",
|
||||||
|
"snafu",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
"url",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.18.0"
|
version = "1.18.0"
|
||||||
|
@ -3467,6 +3532,24 @@ dependencies = [
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ort"
|
||||||
|
version = "1.16.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "889dca4c98efa21b1ba54ddb2bde44fd4920d910f492b618351f839d8428d79d"
|
||||||
|
dependencies = [
|
||||||
|
"flate2",
|
||||||
|
"lazy_static",
|
||||||
|
"libc",
|
||||||
|
"libloading",
|
||||||
|
"ndarray",
|
||||||
|
"tar",
|
||||||
|
"thiserror",
|
||||||
|
"tracing",
|
||||||
|
"vswhom",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "overload"
|
name = "overload"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
@ -3497,7 +3580,7 @@ checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
"redox_syscall",
|
"redox_syscall 0.4.1",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"windows-targets 0.48.5",
|
"windows-targets 0.48.5",
|
||||||
]
|
]
|
||||||
|
@ -4017,6 +4100,12 @@ dependencies = [
|
||||||
"rand_core 0.6.4",
|
"rand_core 0.6.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rawpointer"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon"
|
name = "rayon"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
@ -4049,6 +4138,15 @@ dependencies = [
|
||||||
"yasna",
|
"yasna",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redox_syscall"
|
||||||
|
version = "0.3.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
@ -4950,6 +5048,28 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "snafu"
|
||||||
|
version = "0.7.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6"
|
||||||
|
dependencies = [
|
||||||
|
"doc-comment",
|
||||||
|
"snafu-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "snafu-derive"
|
||||||
|
version = "0.7.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "snap"
|
name = "snap"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
@ -5120,6 +5240,7 @@ dependencies = [
|
||||||
"ipnet",
|
"ipnet",
|
||||||
"jemallocator",
|
"jemallocator",
|
||||||
"mimalloc",
|
"mimalloc",
|
||||||
|
"ndarray",
|
||||||
"nix 0.27.1",
|
"nix 0.27.1",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
|
@ -5136,6 +5257,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serial_test",
|
"serial_test",
|
||||||
"surrealdb",
|
"surrealdb",
|
||||||
|
"surrealml-core",
|
||||||
"temp-env",
|
"temp-env",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"test-log",
|
"test-log",
|
||||||
|
@ -5184,6 +5306,7 @@ dependencies = [
|
||||||
"futures-concurrency",
|
"futures-concurrency",
|
||||||
"fuzzy-matcher",
|
"fuzzy-matcher",
|
||||||
"geo 0.27.0",
|
"geo 0.27.0",
|
||||||
|
"hex",
|
||||||
"indexmap 2.1.0",
|
"indexmap 2.1.0",
|
||||||
"indxdb",
|
"indxdb",
|
||||||
"ipnet",
|
"ipnet",
|
||||||
|
@ -5192,8 +5315,10 @@ dependencies = [
|
||||||
"md-5",
|
"md-5",
|
||||||
"nanoid",
|
"nanoid",
|
||||||
"native-tls",
|
"native-tls",
|
||||||
|
"ndarray",
|
||||||
"nom",
|
"nom",
|
||||||
"num_cpus",
|
"num_cpus",
|
||||||
|
"object_store",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"path-clean",
|
"path-clean",
|
||||||
"pbkdf2",
|
"pbkdf2",
|
||||||
|
@ -5224,6 +5349,7 @@ dependencies = [
|
||||||
"surrealdb-derive",
|
"surrealdb-derive",
|
||||||
"surrealdb-jsonwebtoken",
|
"surrealdb-jsonwebtoken",
|
||||||
"surrealdb-tikv-client",
|
"surrealdb-tikv-client",
|
||||||
|
"surrealml-core",
|
||||||
"temp-dir",
|
"temp-dir",
|
||||||
"test-log",
|
"test-log",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -5299,6 +5425,21 @@ dependencies = [
|
||||||
"tonic 0.9.2",
|
"tonic 0.9.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "surrealml-core"
|
||||||
|
version = "0.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2aabe3c44a73d6f7a3d069400a4966d2ffcf643fb1d60399b1746787abc9dd12"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"ndarray",
|
||||||
|
"once_cell",
|
||||||
|
"ort",
|
||||||
|
"regex",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "symbolic-common"
|
name = "symbolic-common"
|
||||||
version = "12.7.0"
|
version = "12.7.0"
|
||||||
|
@ -5389,6 +5530,17 @@ version = "1.0.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tar"
|
||||||
|
version = "0.4.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
|
||||||
|
dependencies = [
|
||||||
|
"filetime",
|
||||||
|
"libc",
|
||||||
|
"xattr",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "temp-dir"
|
name = "temp-dir"
|
||||||
version = "0.1.11"
|
version = "0.1.11"
|
||||||
|
@ -5413,7 +5565,7 @@ checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"fastrand 2.0.1",
|
"fastrand 2.0.1",
|
||||||
"redox_syscall",
|
"redox_syscall 0.4.1",
|
||||||
"rustix",
|
"rustix",
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
]
|
]
|
||||||
|
@ -6103,6 +6255,26 @@ version = "0.9.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vswhom"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"vswhom-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vswhom-sys"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "waker-fn"
|
name = "waker-fn"
|
||||||
version = "1.1.1"
|
version = "1.1.1"
|
||||||
|
@ -6501,6 +6673,15 @@ dependencies = [
|
||||||
"tap",
|
"tap",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xattr"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xml-rs"
|
name = "xml-rs"
|
||||||
version = "0.8.19"
|
version = "0.8.19"
|
||||||
|
|
|
@ -8,7 +8,7 @@ authors = ["Tobie Morgan Hitchcock <tobie@surrealdb.com>"]
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# Public features
|
# Public features
|
||||||
default = ["storage-mem", "storage-rocksdb", "scripting", "http"]
|
default = ["storage-mem", "storage-rocksdb", "scripting", "http", "ml"]
|
||||||
storage-mem = ["surrealdb/kv-mem", "has-storage"]
|
storage-mem = ["surrealdb/kv-mem", "has-storage"]
|
||||||
storage-rocksdb = ["surrealdb/kv-rocksdb", "has-storage"]
|
storage-rocksdb = ["surrealdb/kv-rocksdb", "has-storage"]
|
||||||
storage-speedb = ["surrealdb/kv-speedb", "has-storage"]
|
storage-speedb = ["surrealdb/kv-speedb", "has-storage"]
|
||||||
|
@ -17,6 +17,7 @@ storage-fdb = ["surrealdb/kv-fdb-7_1", "has-storage"]
|
||||||
scripting = ["surrealdb/scripting"]
|
scripting = ["surrealdb/scripting"]
|
||||||
http = ["surrealdb/http"]
|
http = ["surrealdb/http"]
|
||||||
http-compression = []
|
http-compression = []
|
||||||
|
ml = ["surrealdb/ml", "surrealml-core"]
|
||||||
# Private features
|
# Private features
|
||||||
has-storage = []
|
has-storage = []
|
||||||
|
|
||||||
|
@ -49,6 +50,7 @@ http = "0.2.11"
|
||||||
http-body = "0.4.5"
|
http-body = "0.4.5"
|
||||||
hyper = "0.14.27"
|
hyper = "0.14.27"
|
||||||
ipnet = "2.9.0"
|
ipnet = "2.9.0"
|
||||||
|
ndarray = { version = "0.15.6", optional = true }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
opentelemetry = { version = "0.19", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.19", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
|
opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
|
||||||
|
@ -61,6 +63,7 @@ serde_cbor = "0.11.2"
|
||||||
serde_json = "1.0.108"
|
serde_json = "1.0.108"
|
||||||
serde_pack = { version = "1.1.2", package = "rmp-serde" }
|
serde_pack = { version = "1.1.2", package = "rmp-serde" }
|
||||||
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
|
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
|
||||||
|
surrealml-core = { version = "0.0.2", optional = true}
|
||||||
tempfile = "3.8.1"
|
tempfile = "3.8.1"
|
||||||
thiserror = "1.0.50"
|
thiserror = "1.0.50"
|
||||||
tokio = { version = "1.34.0", features = ["macros", "signal"] }
|
tokio = { version = "1.34.0", features = ["macros", "signal"] }
|
||||||
|
|
|
@ -39,6 +39,12 @@ command = "cargo"
|
||||||
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
||||||
|
|
||||||
|
[tasks.ci-ml-integration]
|
||||||
|
category = "CI - INTEGRATION TESTS"
|
||||||
|
command = "cargo"
|
||||||
|
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||||
|
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
|
||||||
|
|
||||||
[tasks.ci-workspace-coverage]
|
[tasks.ci-workspace-coverage]
|
||||||
category = "CI - INTEGRATION TESTS"
|
category = "CI - INTEGRATION TESTS"
|
||||||
command = "cargo"
|
command = "cargo"
|
||||||
|
|
|
@ -10,7 +10,7 @@ reduce_output = true
|
||||||
default_to_workspace = false
|
default_to_workspace = false
|
||||||
|
|
||||||
[env]
|
[env]
|
||||||
DEV_FEATURES={ value = "storage-mem,http,scripting", condition = { env_not_set = ["DEV_FEATURES"] } }
|
DEV_FEATURES={ value = "storage-mem,scripting,http,ml", condition = { env_not_set = ["DEV_FEATURES"] } }
|
||||||
SURREAL_LOG={ value = "trace", condition = { env_not_set = ["SURREAL_LOG"] } }
|
SURREAL_LOG={ value = "trace", condition = { env_not_set = ["SURREAL_LOG"] } }
|
||||||
SURREAL_USER={ value = "root", condition = { env_not_set = ["SURREAL_USER"] } }
|
SURREAL_USER={ value = "root", condition = { env_not_set = ["SURREAL_USER"] } }
|
||||||
SURREAL_PASS={ value = "root", condition = { env_not_set = ["SURREAL_PASS"] } }
|
SURREAL_PASS={ value = "root", condition = { env_not_set = ["SURREAL_PASS"] } }
|
||||||
|
|
|
@ -37,6 +37,7 @@ scripting = ["dep:js"]
|
||||||
http = ["dep:reqwest"]
|
http = ["dep:reqwest"]
|
||||||
native-tls = ["dep:native-tls", "reqwest?/native-tls", "tokio-tungstenite?/native-tls"]
|
native-tls = ["dep:native-tls", "reqwest?/native-tls", "tokio-tungstenite?/native-tls"]
|
||||||
rustls = ["dep:rustls", "reqwest?/rustls-tls", "tokio-tungstenite?/rustls-tls-webpki-roots"]
|
rustls = ["dep:rustls", "reqwest?/rustls-tls", "tokio-tungstenite?/rustls-tls-webpki-roots"]
|
||||||
|
ml = ["surrealml-core", "ndarray"]
|
||||||
# Private features
|
# Private features
|
||||||
kv-fdb = ["foundationdb", "tokio/time"]
|
kv-fdb = ["foundationdb", "tokio/time"]
|
||||||
|
|
||||||
|
@ -74,6 +75,7 @@ futures = "0.3.29"
|
||||||
futures-concurrency = "7.4.3"
|
futures-concurrency = "7.4.3"
|
||||||
fuzzy-matcher = "0.3.7"
|
fuzzy-matcher = "0.3.7"
|
||||||
geo = { version = "0.27.0", features = ["use-serde"] }
|
geo = { version = "0.27.0", features = ["use-serde"] }
|
||||||
|
hex = { version = "0.4.3", optional = false }
|
||||||
indexmap = { version = "2.1.0", features = ["serde"] }
|
indexmap = { version = "2.1.0", features = ["serde"] }
|
||||||
indxdb = { version = "0.4.0", optional = true }
|
indxdb = { version = "0.4.0", optional = true }
|
||||||
ipnet = "2.9.0"
|
ipnet = "2.9.0"
|
||||||
|
@ -84,8 +86,10 @@ lru = "0.12.1"
|
||||||
md-5 = "0.10.6"
|
md-5 = "0.10.6"
|
||||||
nanoid = "0.4.0"
|
nanoid = "0.4.0"
|
||||||
native-tls = { version = "0.2.11", optional = true }
|
native-tls = { version = "0.2.11", optional = true }
|
||||||
|
ndarray = { version = "0.15.6", optional = true }
|
||||||
nom = { version = "7.1.3", features = ["alloc"] }
|
nom = { version = "7.1.3", features = ["alloc"] }
|
||||||
num_cpus = "1.16.0"
|
num_cpus = "1.16.0"
|
||||||
|
object_store = { version = "0.8.0", optional = false }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
path-clean = "1.0.1"
|
path-clean = "1.0.1"
|
||||||
pbkdf2 = { version = "0.12.2", features = ["simple"] }
|
pbkdf2 = { version = "0.12.2", features = ["simple"] }
|
||||||
|
@ -109,6 +113,7 @@ sha2 = "0.10.8"
|
||||||
snap = "1.1.0"
|
snap = "1.1.0"
|
||||||
speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true }
|
speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true }
|
||||||
storekey = "0.5.0"
|
storekey = "0.5.0"
|
||||||
|
surrealml-core = { version = "0.0.2", optional = true }
|
||||||
thiserror = "1.0.50"
|
thiserror = "1.0.50"
|
||||||
tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true }
|
tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true }
|
||||||
tokio-util = { version = "0.7.10", optional = true, features = ["compat"] }
|
tokio-util = { version = "0.7.10", optional = true, features = ["compat"] }
|
||||||
|
|
|
@ -113,6 +113,16 @@ pub enum DbResponse {
|
||||||
Other(Value),
|
Other(Value),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(dead_code)] // used by ML model import and export functions
|
||||||
|
pub(crate) enum MlConfig {
|
||||||
|
Import,
|
||||||
|
Export {
|
||||||
|
name: String,
|
||||||
|
version: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
/// Holds the parameters given to the caller
|
/// Holds the parameters given to the caller
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
#[allow(dead_code)] // used by the embedded and remote connections
|
#[allow(dead_code)] // used by the embedded and remote connections
|
||||||
|
@ -122,6 +132,7 @@ pub struct Param {
|
||||||
pub(crate) file: Option<PathBuf>,
|
pub(crate) file: Option<PathBuf>,
|
||||||
pub(crate) bytes_sender: Option<channel::Sender<Result<Vec<u8>>>>,
|
pub(crate) bytes_sender: Option<channel::Sender<Result<Vec<u8>>>>,
|
||||||
pub(crate) notification_sender: Option<channel::Sender<Notification>>,
|
pub(crate) notification_sender: Option<channel::Sender<Notification>>,
|
||||||
|
pub(crate) ml_config: Option<MlConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Param {
|
impl Param {
|
||||||
|
|
|
@ -28,6 +28,8 @@ pub(crate) mod wasm;
|
||||||
|
|
||||||
use crate::api::conn::DbResponse;
|
use crate::api::conn::DbResponse;
|
||||||
use crate::api::conn::Method;
|
use crate::api::conn::Method;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::api::conn::MlConfig;
|
||||||
use crate::api::conn::Param;
|
use crate::api::conn::Param;
|
||||||
use crate::api::engine::create_statement;
|
use crate::api::engine::create_statement;
|
||||||
use crate::api::engine::delete_statement;
|
use crate::api::engine::delete_statement;
|
||||||
|
@ -44,9 +46,27 @@ use crate::api::Surreal;
|
||||||
use crate::dbs::Notification;
|
use crate::dbs::Notification;
|
||||||
use crate::dbs::Response;
|
use crate::dbs::Response;
|
||||||
use crate::dbs::Session;
|
use crate::dbs::Session;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::iam::check::check_ns_db;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::iam::Action;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::iam::ResourceKind;
|
||||||
use crate::kvs::Datastore;
|
use crate::kvs::Datastore;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::kvs::{LockType, TransactionType};
|
||||||
use crate::method::Stats;
|
use crate::method::Stats;
|
||||||
use crate::opt::IntoEndpoint;
|
use crate::opt::IntoEndpoint;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::sql::statements::DefineModelStatement;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::sql::statements::DefineStatement;
|
||||||
use crate::sql::statements::KillStatement;
|
use crate::sql::statements::KillStatement;
|
||||||
use crate::sql::Array;
|
use crate::sql::Array;
|
||||||
use crate::sql::Query;
|
use crate::sql::Query;
|
||||||
|
@ -56,6 +76,9 @@ use crate::sql::Strand;
|
||||||
use crate::sql::Uuid;
|
use crate::sql::Uuid;
|
||||||
use crate::sql::Value;
|
use crate::sql::Value;
|
||||||
use channel::Sender;
|
use channel::Sender;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use futures::StreamExt;
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
@ -65,6 +88,9 @@ use std::mem;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use surrealml_core::storage::surml_file::SurMlFile;
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use tokio::fs::OpenOptions;
|
use tokio::fs::OpenOptions;
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
@ -405,11 +431,34 @@ async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
|
||||||
async fn export(
|
async fn export(
|
||||||
kvs: &Datastore,
|
kvs: &Datastore,
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
ns: String,
|
|
||||||
db: String,
|
|
||||||
chn: channel::Sender<Vec<u8>>,
|
chn: channel::Sender<Vec<u8>>,
|
||||||
|
ml_config: Option<MlConfig>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if let Err(error) = kvs.export(sess, ns, db, chn).await?.await {
|
match ml_config {
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
Some(MlConfig::Export {
|
||||||
|
name,
|
||||||
|
version,
|
||||||
|
}) => {
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let (nsv, dbv) = check_ns_db(sess)?;
|
||||||
|
// Check the permissions level
|
||||||
|
kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?;
|
||||||
|
// Start a new readonly transaction
|
||||||
|
let mut tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?;
|
||||||
|
// Attempt to get the model definition
|
||||||
|
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
|
||||||
|
// Export the file data in to the store
|
||||||
|
let mut data = crate::obs::stream(info.hash.to_owned()).await?;
|
||||||
|
// Process all stream values
|
||||||
|
while let Some(Ok(bytes)) = data.next().await {
|
||||||
|
if chn.send(bytes.to_vec()).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Err(error) = kvs.export(sess, chn).await?.await {
|
||||||
if let crate::error::Db::Channel(message) = error {
|
if let crate::error::Db::Channel(message) = error {
|
||||||
// This is not really an error. Just logging it for improved visibility.
|
// This is not really an error. Just logging it for improved visibility.
|
||||||
trace!("{message}");
|
trace!("{message}");
|
||||||
|
@ -417,6 +466,8 @@ async fn export(
|
||||||
}
|
}
|
||||||
return Err(error.into());
|
return Err(error.into());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,8 +614,6 @@ async fn router(
|
||||||
Method::Export | Method::Import => unreachable!(),
|
Method::Export | Method::Import => unreachable!(),
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
Method::Export => {
|
Method::Export => {
|
||||||
let ns = session.ns.clone().unwrap_or_default();
|
|
||||||
let db = session.db.clone().unwrap_or_default();
|
|
||||||
let (tx, rx) = crate::channel::bounded(1);
|
let (tx, rx) = crate::channel::bounded(1);
|
||||||
|
|
||||||
match (param.file, param.bytes_sender) {
|
match (param.file, param.bytes_sender) {
|
||||||
|
@ -572,7 +621,7 @@ async fn router(
|
||||||
let (mut writer, mut reader) = io::duplex(10_240);
|
let (mut writer, mut reader) = io::duplex(10_240);
|
||||||
|
|
||||||
// Write to channel.
|
// Write to channel.
|
||||||
let export = export(kvs, session, ns, db, tx);
|
let export = export(kvs, session, tx, param.ml_config);
|
||||||
|
|
||||||
// Read from channel and write to pipe.
|
// Read from channel and write to pipe.
|
||||||
let bridge = async move {
|
let bridge = async move {
|
||||||
|
@ -613,7 +662,7 @@ async fn router(
|
||||||
let session = session.clone();
|
let session = session.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let export = async {
|
let export = async {
|
||||||
if let Err(error) = export(&kvs, &session, ns, db, tx).await {
|
if let Err(error) = export(&kvs, &session, tx, param.ml_config).await {
|
||||||
let _ = backup.send(Err(error)).await;
|
let _ = backup.send(Err(error)).await;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -647,6 +696,52 @@ async fn router(
|
||||||
.into());
|
.into());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let responses = match param.ml_config {
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
Some(MlConfig::Import) => {
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let (nsv, dbv) = check_ns_db(session)?;
|
||||||
|
// Check the permissions level
|
||||||
|
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
|
||||||
|
// Create a new buffer
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
// Load all the uploaded file chunks
|
||||||
|
if let Err(error) = file.read_to_end(&mut buffer).await {
|
||||||
|
return Err(Error::FileRead {
|
||||||
|
path,
|
||||||
|
error,
|
||||||
|
}
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
// Check that the SurrealML file is valid
|
||||||
|
let file = match SurMlFile::from_bytes(buffer) {
|
||||||
|
Ok(file) => file,
|
||||||
|
Err(error) => {
|
||||||
|
return Err(Error::FileRead {
|
||||||
|
path,
|
||||||
|
error,
|
||||||
|
}
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Convert the file back in to raw bytes
|
||||||
|
let data = file.to_bytes();
|
||||||
|
// Calculate the hash of the model file
|
||||||
|
let hash = crate::obs::hash(&data);
|
||||||
|
// Insert the file data in to the store
|
||||||
|
crate::obs::put(&hash, data).await?;
|
||||||
|
// Insert the model in to the database
|
||||||
|
let query = DefineStatement::Model(DefineModelStatement {
|
||||||
|
hash,
|
||||||
|
name: file.header.name.to_string().into(),
|
||||||
|
version: file.header.version.to_string(),
|
||||||
|
comment: Some(file.header.description.to_string().into()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.into();
|
||||||
|
kvs.process(query, session, Some(vars.clone())).await?
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
let mut statements = String::new();
|
let mut statements = String::new();
|
||||||
if let Err(error) = file.read_to_string(&mut statements).await {
|
if let Err(error) = file.read_to_string(&mut statements).await {
|
||||||
return Err(Error::FileRead {
|
return Err(Error::FileRead {
|
||||||
|
@ -655,7 +750,9 @@ async fn router(
|
||||||
}
|
}
|
||||||
.into());
|
.into());
|
||||||
}
|
}
|
||||||
let responses = kvs.execute(&statements, &*session, Some(vars.clone())).await?;
|
kvs.execute(&statements, &*session, Some(vars.clone())).await?
|
||||||
|
}
|
||||||
|
};
|
||||||
for response in responses {
|
for response in responses {
|
||||||
response.result?;
|
response.result?;
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,9 @@ pub(crate) mod wasm;
|
||||||
|
|
||||||
use crate::api::conn::DbResponse;
|
use crate::api::conn::DbResponse;
|
||||||
use crate::api::conn::Method;
|
use crate::api::conn::Method;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use crate::api::conn::MlConfig;
|
||||||
use crate::api::conn::Param;
|
use crate::api::conn::Param;
|
||||||
use crate::api::engine::create_statement;
|
use crate::api::engine::create_statement;
|
||||||
use crate::api::engine::delete_statement;
|
use crate::api::engine::delete_statement;
|
||||||
|
@ -516,7 +519,14 @@ async fn router(
|
||||||
Method::Export | Method::Import => unreachable!(),
|
Method::Export | Method::Import => unreachable!(),
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
Method::Export => {
|
Method::Export => {
|
||||||
let path = base_url.join(Method::Export.as_str())?;
|
let path = match param.ml_config {
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
Some(MlConfig::Export {
|
||||||
|
name,
|
||||||
|
version,
|
||||||
|
}) => base_url.join(&format!("ml/export/{name}/{version}"))?,
|
||||||
|
_ => base_url.join(Method::Export.as_str())?,
|
||||||
|
};
|
||||||
let request = client
|
let request = client
|
||||||
.get(path)
|
.get(path)
|
||||||
.headers(headers.clone())
|
.headers(headers.clone())
|
||||||
|
@ -527,7 +537,11 @@ async fn router(
|
||||||
}
|
}
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
Method::Import => {
|
Method::Import => {
|
||||||
let path = base_url.join(Method::Import.as_str())?;
|
let path = match param.ml_config {
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
Some(MlConfig::Import) => base_url.join("ml/import")?,
|
||||||
|
_ => base_url.join(Method::Import.as_str())?,
|
||||||
|
};
|
||||||
let file = param.file.expect("file to import from");
|
let file = param.file.expect("file to import from");
|
||||||
let request = client
|
let request = client
|
||||||
.post(path)
|
.post(path)
|
||||||
|
|
|
@ -1,15 +1,18 @@
|
||||||
use crate::api::conn::Method;
|
use crate::api::conn::Method;
|
||||||
|
use crate::api::conn::MlConfig;
|
||||||
use crate::api::conn::Param;
|
use crate::api::conn::Param;
|
||||||
use crate::api::Connection;
|
use crate::api::Connection;
|
||||||
use crate::api::Error;
|
use crate::api::Error;
|
||||||
use crate::api::ExtraFeatures;
|
use crate::api::ExtraFeatures;
|
||||||
use crate::api::Result;
|
use crate::api::Result;
|
||||||
|
use crate::method::Model;
|
||||||
use crate::method::OnceLockExt;
|
use crate::method::OnceLockExt;
|
||||||
use crate::opt::ExportDestination;
|
use crate::opt::ExportDestination;
|
||||||
use crate::Surreal;
|
use crate::Surreal;
|
||||||
use channel::Receiver;
|
use channel::Receiver;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use semver::Version;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::future::IntoFuture;
|
use std::future::IntoFuture;
|
||||||
|
@ -22,18 +25,39 @@ use std::task::Poll;
|
||||||
/// A database export future
|
/// A database export future
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||||
pub struct Export<'r, C: Connection, R> {
|
pub struct Export<'r, C: Connection, R, T = ()> {
|
||||||
pub(super) client: Cow<'r, Surreal<C>>,
|
pub(super) client: Cow<'r, Surreal<C>>,
|
||||||
pub(super) target: ExportDestination,
|
pub(super) target: ExportDestination,
|
||||||
|
pub(super) ml_config: Option<MlConfig>,
|
||||||
pub(super) response: PhantomData<R>,
|
pub(super) response: PhantomData<R>,
|
||||||
|
pub(super) export_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C, R> Export<'_, C, R>
|
impl<'r, C, R> Export<'r, C, R>
|
||||||
|
where
|
||||||
|
C: Connection,
|
||||||
|
{
|
||||||
|
/// Export machine learning model
|
||||||
|
pub fn ml(self, name: &str, version: Version) -> Export<'r, C, R, Model> {
|
||||||
|
Export {
|
||||||
|
client: self.client,
|
||||||
|
target: self.target,
|
||||||
|
ml_config: Some(MlConfig::Export {
|
||||||
|
name: name.to_owned(),
|
||||||
|
version: version.to_string(),
|
||||||
|
}),
|
||||||
|
response: self.response,
|
||||||
|
export_type: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, R, T> Export<'_, C, R, T>
|
||||||
where
|
where
|
||||||
C: Connection,
|
C: Connection,
|
||||||
{
|
{
|
||||||
/// Converts to an owned type which can easily be moved to a different thread
|
/// Converts to an owned type which can easily be moved to a different thread
|
||||||
pub fn into_owned(self) -> Export<'static, C, R> {
|
pub fn into_owned(self) -> Export<'static, C, R, T> {
|
||||||
Export {
|
Export {
|
||||||
client: Cow::Owned(self.client.into_owned()),
|
client: Cow::Owned(self.client.into_owned()),
|
||||||
..self
|
..self
|
||||||
|
@ -41,7 +65,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, Client> IntoFuture for Export<'r, Client, PathBuf>
|
impl<'r, Client, T> IntoFuture for Export<'r, Client, PathBuf, T>
|
||||||
where
|
where
|
||||||
Client: Connection,
|
Client: Connection,
|
||||||
{
|
{
|
||||||
|
@ -55,15 +79,17 @@ where
|
||||||
return Err(Error::BackupsNotSupported.into());
|
return Err(Error::BackupsNotSupported.into());
|
||||||
}
|
}
|
||||||
let mut conn = Client::new(Method::Export);
|
let mut conn = Client::new(Method::Export);
|
||||||
match self.target {
|
let mut param = match self.target {
|
||||||
ExportDestination::File(path) => conn.execute_unit(router, Param::file(path)).await,
|
ExportDestination::File(path) => Param::file(path),
|
||||||
ExportDestination::Memory => unreachable!(),
|
ExportDestination::Memory => unreachable!(),
|
||||||
}
|
};
|
||||||
|
param.ml_config = self.ml_config;
|
||||||
|
conn.execute_unit(router, param).await
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, Client> IntoFuture for Export<'r, Client, ()>
|
impl<'r, Client, T> IntoFuture for Export<'r, Client, (), T>
|
||||||
where
|
where
|
||||||
Client: Connection,
|
Client: Connection,
|
||||||
{
|
{
|
||||||
|
@ -81,7 +107,9 @@ where
|
||||||
let ExportDestination::Memory = self.target else {
|
let ExportDestination::Memory = self.target else {
|
||||||
unreachable!();
|
unreachable!();
|
||||||
};
|
};
|
||||||
conn.execute_unit(router, Param::bytes_sender(tx)).await?;
|
let mut param = Param::bytes_sender(tx);
|
||||||
|
param.ml_config = self.ml_config;
|
||||||
|
conn.execute_unit(router, param).await?;
|
||||||
Ok(Backup {
|
Ok(Backup {
|
||||||
rx,
|
rx,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,31 +1,51 @@
|
||||||
use crate::api::conn::Method;
|
use crate::api::conn::Method;
|
||||||
|
use crate::api::conn::MlConfig;
|
||||||
use crate::api::conn::Param;
|
use crate::api::conn::Param;
|
||||||
use crate::api::Connection;
|
use crate::api::Connection;
|
||||||
use crate::api::Error;
|
use crate::api::Error;
|
||||||
use crate::api::ExtraFeatures;
|
use crate::api::ExtraFeatures;
|
||||||
use crate::api::Result;
|
use crate::api::Result;
|
||||||
|
use crate::method::Model;
|
||||||
use crate::method::OnceLockExt;
|
use crate::method::OnceLockExt;
|
||||||
use crate::Surreal;
|
use crate::Surreal;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::future::IntoFuture;
|
use std::future::IntoFuture;
|
||||||
|
use std::marker::PhantomData;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
/// An database import future
|
/// An database import future
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||||
pub struct Import<'r, C: Connection> {
|
pub struct Import<'r, C: Connection, T = ()> {
|
||||||
pub(super) client: Cow<'r, Surreal<C>>,
|
pub(super) client: Cow<'r, Surreal<C>>,
|
||||||
pub(super) file: PathBuf,
|
pub(super) file: PathBuf,
|
||||||
|
pub(super) ml_config: Option<MlConfig>,
|
||||||
|
pub(super) import_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Import<'_, C>
|
impl<'r, C> Import<'r, C>
|
||||||
|
where
|
||||||
|
C: Connection,
|
||||||
|
{
|
||||||
|
/// Import machine learning model
|
||||||
|
pub fn ml(self) -> Import<'r, C, Model> {
|
||||||
|
Import {
|
||||||
|
client: self.client,
|
||||||
|
file: self.file,
|
||||||
|
ml_config: Some(MlConfig::Import),
|
||||||
|
import_type: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r, C, T> Import<'r, C, T>
|
||||||
where
|
where
|
||||||
C: Connection,
|
C: Connection,
|
||||||
{
|
{
|
||||||
/// Converts to an owned type which can easily be moved to a different thread
|
/// Converts to an owned type which can easily be moved to a different thread
|
||||||
pub fn into_owned(self) -> Import<'static, C> {
|
pub fn into_owned(self) -> Import<'static, C, T> {
|
||||||
Import {
|
Import {
|
||||||
client: Cow::Owned(self.client.into_owned()),
|
client: Cow::Owned(self.client.into_owned()),
|
||||||
..self
|
..self
|
||||||
|
@ -33,7 +53,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, Client> IntoFuture for Import<'r, Client>
|
impl<'r, Client, T> IntoFuture for Import<'r, Client, T>
|
||||||
where
|
where
|
||||||
Client: Connection,
|
Client: Connection,
|
||||||
{
|
{
|
||||||
|
@ -47,7 +67,9 @@ where
|
||||||
return Err(Error::BackupsNotSupported.into());
|
return Err(Error::BackupsNotSupported.into());
|
||||||
}
|
}
|
||||||
let mut conn = Client::new(Method::Import);
|
let mut conn = Client::new(Method::Import);
|
||||||
conn.execute_unit(router, Param::file(self.file)).await
|
let mut param = Param::file(self.file);
|
||||||
|
param.ml_config = self.ml_config;
|
||||||
|
conn.execute_unit(router, param).await
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,6 +90,9 @@ pub struct Stats {
|
||||||
pub execution_time: Duration,
|
pub execution_time: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Machine learning model marker type for import and export types
|
||||||
|
pub struct Model;
|
||||||
|
|
||||||
/// Responses returned with statistics
|
/// Responses returned with statistics
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WithStats<T>(T);
|
pub struct WithStats<T>(T);
|
||||||
|
@ -1004,7 +1007,9 @@ where
|
||||||
Export {
|
Export {
|
||||||
client: Cow::Borrowed(self),
|
client: Cow::Borrowed(self),
|
||||||
target: target.into_export_destination(),
|
target: target.into_export_destination(),
|
||||||
|
ml_config: None,
|
||||||
response: PhantomData,
|
response: PhantomData,
|
||||||
|
export_type: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1034,6 +1039,8 @@ where
|
||||||
Import {
|
Import {
|
||||||
client: Cow::Borrowed(self),
|
client: Cow::Borrowed(self),
|
||||||
file: file.as_ref().to_owned(),
|
file: file.as_ref().to_owned(),
|
||||||
|
ml_config: None,
|
||||||
|
import_type: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -409,7 +409,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
|
||||||
},
|
},
|
||||||
Value::Cast(cast) => json!(cast),
|
Value::Cast(cast) => json!(cast),
|
||||||
Value::Function(function) => json!(function),
|
Value::Function(function) => json!(function),
|
||||||
Value::MlModel(model) => json!(model),
|
Value::Model(model) => json!(model),
|
||||||
Value::Query(query) => json!(query),
|
Value::Query(query) => json!(query),
|
||||||
Value::Subquery(subquery) => json!(subquery),
|
Value::Subquery(subquery) => json!(subquery),
|
||||||
Value::Expression(expression) => json!(expression),
|
Value::Expression(expression) => json!(expression),
|
||||||
|
|
|
@ -38,8 +38,7 @@ pub const PROCESSOR_BATCH_SIZE: u32 = 50;
|
||||||
|
|
||||||
/// Forward all signup/signin query errors to a client trying authenticate to a scope. Do not use in production.
|
/// Forward all signup/signin query errors to a client trying authenticate to a scope. Do not use in production.
|
||||||
pub static INSECURE_FORWARD_SCOPE_ERRORS: Lazy<bool> = Lazy::new(|| {
|
pub static INSECURE_FORWARD_SCOPE_ERRORS: Lazy<bool> = Lazy::new(|| {
|
||||||
let default = false;
|
option_env!("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS")
|
||||||
std::env::var("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS")
|
.and_then(|s| s.parse::<bool>().ok())
|
||||||
.map(|v| v.parse::<bool>().unwrap_or(default))
|
.unwrap_or(false)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
|
@ -12,6 +12,7 @@ use base64_lib::DecodeError as Base64Error;
|
||||||
use bincode::Error as BincodeError;
|
use bincode::Error as BincodeError;
|
||||||
use fst::Error as FstError;
|
use fst::Error as FstError;
|
||||||
use jsonwebtoken::errors::Error as JWTError;
|
use jsonwebtoken::errors::Error as JWTError;
|
||||||
|
use object_store::Error as ObjectStoreError;
|
||||||
use revision::Error as RevisionError;
|
use revision::Error as RevisionError;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::io::Error as IoError;
|
use std::io::Error as IoError;
|
||||||
|
@ -194,6 +195,12 @@ pub enum Error {
|
||||||
message: String,
|
message: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// There was an error with the provided machine learning model
|
||||||
|
#[error("Problem with machine learning computation. {message}")]
|
||||||
|
InvalidModel {
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
|
||||||
/// There was a problem running the specified function
|
/// There was a problem running the specified function
|
||||||
#[error("There was a problem running the {name}() function. {message}")]
|
#[error("There was a problem running the {name}() function. {message}")]
|
||||||
InvalidFunction {
|
InvalidFunction {
|
||||||
|
@ -316,6 +323,12 @@ pub enum Error {
|
||||||
value: String,
|
value: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// The requested model does not exist
|
||||||
|
#[error("The model 'ml::{value}' does not exist")]
|
||||||
|
MlNotFound {
|
||||||
|
value: String,
|
||||||
|
},
|
||||||
|
|
||||||
/// The requested scope does not exist
|
/// The requested scope does not exist
|
||||||
#[error("The scope '{value}' does not exist")]
|
#[error("The scope '{value}' does not exist")]
|
||||||
ScNotFound {
|
ScNotFound {
|
||||||
|
@ -635,6 +648,14 @@ pub enum Error {
|
||||||
#[error("Utf8 error: {0}")]
|
#[error("Utf8 error: {0}")]
|
||||||
Utf8Error(#[from] FromUtf8Error),
|
Utf8Error(#[from] FromUtf8Error),
|
||||||
|
|
||||||
|
/// Represents an underlying error with the Object Store
|
||||||
|
#[error("Object Store error: {0}")]
|
||||||
|
ObsError(#[from] ObjectStoreError),
|
||||||
|
|
||||||
|
/// There was an error with model computation
|
||||||
|
#[error("There was an error with model computation: {0}")]
|
||||||
|
ModelComputation(String),
|
||||||
|
|
||||||
/// The feature has not yet being implemented
|
/// The feature has not yet being implemented
|
||||||
#[error("Feature not yet implemented: {feature}")]
|
#[error("Feature not yet implemented: {feature}")]
|
||||||
FeatureNotYetImplemented {
|
FeatureNotYetImplemented {
|
||||||
|
|
17
lib/src/iam/check.rs
Normal file
17
lib/src/iam/check.rs
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
use crate::dbs::Session;
|
||||||
|
use crate::err::Error;
|
||||||
|
|
||||||
|
pub fn check_ns_db(sess: &Session) -> Result<(String, String), Error> {
|
||||||
|
// Ensure that a namespace was specified
|
||||||
|
let ns = match sess.ns.clone() {
|
||||||
|
Some(ns) => ns,
|
||||||
|
None => return Err(Error::NsEmpty),
|
||||||
|
};
|
||||||
|
// Ensure that a database was specified
|
||||||
|
let db = match sess.db.clone() {
|
||||||
|
Some(db) => db,
|
||||||
|
None => return Err(Error::DbEmpty),
|
||||||
|
};
|
||||||
|
// All ok
|
||||||
|
Ok((ns, db))
|
||||||
|
}
|
|
@ -23,6 +23,7 @@ pub enum ResourceKind {
|
||||||
Function,
|
Function,
|
||||||
Analyzer,
|
Analyzer,
|
||||||
Parameter,
|
Parameter,
|
||||||
|
Model,
|
||||||
Event,
|
Event,
|
||||||
Field,
|
Field,
|
||||||
Index,
|
Index,
|
||||||
|
@ -44,6 +45,7 @@ impl std::fmt::Display for ResourceKind {
|
||||||
ResourceKind::Function => write!(f, "Function"),
|
ResourceKind::Function => write!(f, "Function"),
|
||||||
ResourceKind::Analyzer => write!(f, "Analyzer"),
|
ResourceKind::Analyzer => write!(f, "Analyzer"),
|
||||||
ResourceKind::Parameter => write!(f, "Parameter"),
|
ResourceKind::Parameter => write!(f, "Parameter"),
|
||||||
|
ResourceKind::Model => write!(f, "Model"),
|
||||||
ResourceKind::Event => write!(f, "Event"),
|
ResourceKind::Event => write!(f, "Event"),
|
||||||
ResourceKind::Field => write!(f, "Field"),
|
ResourceKind::Field => write!(f, "Field"),
|
||||||
ResourceKind::Index => write!(f, "Index"),
|
ResourceKind::Index => write!(f, "Index"),
|
||||||
|
|
|
@ -4,6 +4,7 @@ use thiserror::Error;
|
||||||
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod base;
|
pub mod base;
|
||||||
|
pub mod check;
|
||||||
pub mod clear;
|
pub mod clear;
|
||||||
pub mod entities;
|
pub mod entities;
|
||||||
pub mod policies;
|
pub mod policies;
|
||||||
|
|
89
lib/src/key/database/ml.rs
Normal file
89
lib/src/key/database/ml.rs
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
/// Stores a DEFINE MODEL config definition
|
||||||
|
use crate::key::error::KeyCategory;
|
||||||
|
use crate::key::key_req::KeyRequirements;
|
||||||
|
use derive::Key;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)]
|
||||||
|
pub struct Ml<'a> {
|
||||||
|
__: u8,
|
||||||
|
_a: u8,
|
||||||
|
pub ns: &'a str,
|
||||||
|
_b: u8,
|
||||||
|
pub db: &'a str,
|
||||||
|
_c: u8,
|
||||||
|
_d: u8,
|
||||||
|
_e: u8,
|
||||||
|
pub ml: &'a str,
|
||||||
|
pub vn: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new<'a>(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Ml<'a> {
|
||||||
|
Ml::new(ns, db, ml, vn)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prefix(ns: &str, db: &str) -> Vec<u8> {
|
||||||
|
let mut k = super::all::new(ns, db).encode().unwrap();
|
||||||
|
k.extend_from_slice(&[b'!', b'm', b'l', 0x00]);
|
||||||
|
k
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn suffix(ns: &str, db: &str) -> Vec<u8> {
|
||||||
|
let mut k = super::all::new(ns, db).encode().unwrap();
|
||||||
|
k.extend_from_slice(&[b'!', b'm', b'l', 0xff]);
|
||||||
|
k
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KeyRequirements for Ml<'_> {
|
||||||
|
fn key_category(&self) -> KeyCategory {
|
||||||
|
KeyCategory::DatabaseModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Ml<'a> {
|
||||||
|
pub fn new(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Self {
|
||||||
|
Self {
|
||||||
|
__: b'/',
|
||||||
|
_a: b'*',
|
||||||
|
ns,
|
||||||
|
_b: b'*',
|
||||||
|
db,
|
||||||
|
_c: b'!',
|
||||||
|
_d: b'm',
|
||||||
|
_e: b'l',
|
||||||
|
ml,
|
||||||
|
vn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[test]
|
||||||
|
fn key() {
|
||||||
|
use super::*;
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let val = Ml::new(
|
||||||
|
"testns",
|
||||||
|
"testdb",
|
||||||
|
"testml",
|
||||||
|
"1.0.0",
|
||||||
|
);
|
||||||
|
let enc = Ml::encode(&val).unwrap();
|
||||||
|
assert_eq!(enc, b"/*testns\x00*testdb\x00!mltestml\x001.0.0\x00");
|
||||||
|
let dec = Ml::decode(&enc).unwrap();
|
||||||
|
assert_eq!(val, dec);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_prefix() {
|
||||||
|
let val = super::prefix("testns", "testdb");
|
||||||
|
assert_eq!(val, b"/*testns\0*testdb\0!ml\0");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_suffix() {
|
||||||
|
let val = super::suffix("testns", "testdb");
|
||||||
|
assert_eq!(val, b"/*testns\0*testdb\0!ml\xff");
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
pub mod all;
|
pub mod all;
|
||||||
pub mod az;
|
pub mod az;
|
||||||
pub mod fc;
|
pub mod fc;
|
||||||
|
pub mod ml;
|
||||||
pub mod pa;
|
pub mod pa;
|
||||||
pub mod sc;
|
pub mod sc;
|
||||||
pub mod tb;
|
pub mod tb;
|
||||||
|
|
|
@ -44,6 +44,8 @@ pub enum KeyCategory {
|
||||||
DatabaseFunction,
|
DatabaseFunction,
|
||||||
/// crate::key::database::lg /*{ns}*{db}!lg{lg}
|
/// crate::key::database::lg /*{ns}*{db}!lg{lg}
|
||||||
DatabaseLog,
|
DatabaseLog,
|
||||||
|
/// crate::key::database::ml /*{ns}*{db}!ml{ml}{vn}
|
||||||
|
DatabaseModel,
|
||||||
/// crate::key::database::pa /*{ns}*{db}!pa{pa}
|
/// crate::key::database::pa /*{ns}*{db}!pa{pa}
|
||||||
DatabaseParameter,
|
DatabaseParameter,
|
||||||
/// crate::key::database::sc /*{ns}*{db}!sc{sc}
|
/// crate::key::database::sc /*{ns}*{db}!sc{sc}
|
||||||
|
@ -138,6 +140,7 @@ impl Display for KeyCategory {
|
||||||
KeyCategory::DatabaseAnalyzer => "DatabaseAnalyzer",
|
KeyCategory::DatabaseAnalyzer => "DatabaseAnalyzer",
|
||||||
KeyCategory::DatabaseFunction => "DatabaseFunction",
|
KeyCategory::DatabaseFunction => "DatabaseFunction",
|
||||||
KeyCategory::DatabaseLog => "DatabaseLog",
|
KeyCategory::DatabaseLog => "DatabaseLog",
|
||||||
|
KeyCategory::DatabaseModel => "DatabaseModel",
|
||||||
KeyCategory::DatabaseParameter => "DatabaseParameter",
|
KeyCategory::DatabaseParameter => "DatabaseParameter",
|
||||||
KeyCategory::DatabaseScope => "DatabaseScope",
|
KeyCategory::DatabaseScope => "DatabaseScope",
|
||||||
KeyCategory::DatabaseTable => "DatabaseTable",
|
KeyCategory::DatabaseTable => "DatabaseTable",
|
||||||
|
|
|
@ -6,6 +6,7 @@ use crate::sql::statements::DefineEventStatement;
|
||||||
use crate::sql::statements::DefineFieldStatement;
|
use crate::sql::statements::DefineFieldStatement;
|
||||||
use crate::sql::statements::DefineFunctionStatement;
|
use crate::sql::statements::DefineFunctionStatement;
|
||||||
use crate::sql::statements::DefineIndexStatement;
|
use crate::sql::statements::DefineIndexStatement;
|
||||||
|
use crate::sql::statements::DefineModelStatement;
|
||||||
use crate::sql::statements::DefineNamespaceStatement;
|
use crate::sql::statements::DefineNamespaceStatement;
|
||||||
use crate::sql::statements::DefineParamStatement;
|
use crate::sql::statements::DefineParamStatement;
|
||||||
use crate::sql::statements::DefineScopeStatement;
|
use crate::sql::statements::DefineScopeStatement;
|
||||||
|
@ -22,6 +23,7 @@ pub enum Entry {
|
||||||
Db(Arc<DefineDatabaseStatement>),
|
Db(Arc<DefineDatabaseStatement>),
|
||||||
Fc(Arc<DefineFunctionStatement>),
|
Fc(Arc<DefineFunctionStatement>),
|
||||||
Ix(Arc<DefineIndexStatement>),
|
Ix(Arc<DefineIndexStatement>),
|
||||||
|
Ml(Arc<DefineModelStatement>),
|
||||||
Ns(Arc<DefineNamespaceStatement>),
|
Ns(Arc<DefineNamespaceStatement>),
|
||||||
Pa(Arc<DefineParamStatement>),
|
Pa(Arc<DefineParamStatement>),
|
||||||
Tb(Arc<DefineTableStatement>),
|
Tb(Arc<DefineTableStatement>),
|
||||||
|
@ -36,6 +38,7 @@ pub enum Entry {
|
||||||
Fts(Arc<[DefineTableStatement]>),
|
Fts(Arc<[DefineTableStatement]>),
|
||||||
Ixs(Arc<[DefineIndexStatement]>),
|
Ixs(Arc<[DefineIndexStatement]>),
|
||||||
Lvs(Arc<[LiveStatement]>),
|
Lvs(Arc<[LiveStatement]>),
|
||||||
|
Mls(Arc<[DefineModelStatement]>),
|
||||||
Nss(Arc<[DefineNamespaceStatement]>),
|
Nss(Arc<[DefineNamespaceStatement]>),
|
||||||
Nts(Arc<[DefineTokenStatement]>),
|
Nts(Arc<[DefineTokenStatement]>),
|
||||||
Nus(Arc<[DefineUserStatement]>),
|
Nus(Arc<[DefineUserStatement]>),
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::dbs::{
|
||||||
Variables,
|
Variables,
|
||||||
};
|
};
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
use crate::iam::{Action, Auth, Error as IamError, ResourceKind, Role};
|
use crate::iam::{Action, Auth, Error as IamError, Resource, Role};
|
||||||
use crate::key::root::hb::Hb;
|
use crate::key::root::hb::Hb;
|
||||||
use crate::kvs::clock::SizedClock;
|
use crate::kvs::clock::SizedClock;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
|
@ -1232,9 +1232,11 @@ impl Datastore {
|
||||||
self.notification_channel.as_ref().map(|v| v.1.clone())
|
self.notification_channel.as_ref().map(|v| v.1.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
/// Performs a database import from SQL
|
||||||
pub(crate) fn live_sender(&self) -> Option<Arc<RwLock<Sender<Notification>>>> {
|
#[instrument(level = "debug", skip(self, sess, sql))]
|
||||||
self.notification_channel.as_ref().map(|v| Arc::new(RwLock::new(v.0.clone())))
|
pub async fn import(&self, sql: &str, sess: &Session) -> Result<Vec<Response>, Error> {
|
||||||
|
// Execute the SQL import
|
||||||
|
self.execute(sql, sess, None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Performs a full database export as SQL
|
/// Performs a full database export as SQL
|
||||||
|
@ -1242,15 +1244,10 @@ impl Datastore {
|
||||||
pub async fn export(
|
pub async fn export(
|
||||||
&self,
|
&self,
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
ns: String,
|
|
||||||
db: String,
|
|
||||||
chn: Sender<Vec<u8>>,
|
chn: Sender<Vec<u8>>,
|
||||||
) -> Result<impl Future<Output = Result<(), Error>>, Error> {
|
) -> Result<impl Future<Output = Result<(), Error>>, Error> {
|
||||||
// Skip auth for Anonymous users if auth is disabled
|
// Retrieve the provided NS and DB
|
||||||
let skip_auth = !self.is_auth_enabled() && sess.au.is_anon();
|
let (ns, db) = crate::iam::check::check_ns_db(sess)?;
|
||||||
if !skip_auth {
|
|
||||||
sess.au.is_allowed(Action::View, &ResourceKind::Any.on_db(&ns, &db))?;
|
|
||||||
}
|
|
||||||
// Create a new readonly transaction
|
// Create a new readonly transaction
|
||||||
let mut txn = self.transaction(Read, Optimistic).await?;
|
let mut txn = self.transaction(Read, Optimistic).await?;
|
||||||
// Return an async export job
|
// Return an async export job
|
||||||
|
@ -1262,18 +1259,15 @@ impl Datastore {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Performs a database import from SQL
|
/// Checks the required permissions level for this session
|
||||||
#[instrument(level = "debug", skip(self, sess, sql))]
|
#[instrument(level = "debug", skip(self, sess))]
|
||||||
pub async fn import(&self, sql: &str, sess: &Session) -> Result<Vec<Response>, Error> {
|
pub fn check(&self, sess: &Session, action: Action, resource: Resource) -> Result<(), Error> {
|
||||||
// Skip auth for Anonymous users if auth is disabled
|
// Skip auth for Anonymous users if auth is disabled
|
||||||
let skip_auth = !self.is_auth_enabled() && sess.au.is_anon();
|
let skip_auth = !self.is_auth_enabled() && sess.au.is_anon();
|
||||||
if !skip_auth {
|
if !skip_auth {
|
||||||
sess.au.is_allowed(
|
sess.au.is_allowed(action, &resource)?;
|
||||||
Action::Edit,
|
|
||||||
&ResourceKind::Any.on_level(sess.au.level().to_owned()),
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
// Execute the SQL import
|
// All ok
|
||||||
self.execute(sql, sess, None).await
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,64 +1,55 @@
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
|
||||||
pub static ROCKSDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
|
pub static ROCKSDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = num_cpus::get() as i32;
|
option_env!("SURREAL_ROCKSDB_THREAD_COUNT")
|
||||||
std::env::var("SURREAL_ROCKSDB_THREAD_COUNT")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(num_cpus::get() as i32)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
|
pub static ROCKSDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 256 * 1024 * 1024;
|
option_env!("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE")
|
||||||
std::env::var("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(256 * 1024 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
|
pub static ROCKSDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
|
||||||
let default = 512 * 1024 * 1024;
|
option_env!("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE")
|
||||||
std::env::var("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE")
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.map(|v| v.parse::<u64>().unwrap_or(default))
|
.unwrap_or(512 * 1024 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
|
pub static ROCKSDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = 32;
|
option_env!("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER")
|
||||||
std::env::var("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(32)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
|
pub static ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = 4;
|
option_env!("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
|
||||||
std::env::var("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(4)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
|
pub static ROCKSDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
|
||||||
let default = true;
|
option_env!("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES")
|
||||||
std::env::var("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES")
|
.and_then(|s| s.parse::<bool>().ok())
|
||||||
.map(|v| v.parse::<bool>().unwrap_or(default))
|
.unwrap_or(true)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
|
pub static ROCKSDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
|
||||||
let default = true;
|
option_env!("SURREAL_ROCKSDB_ENABLE_BLOB_FILES")
|
||||||
std::env::var("SURREAL_ROCKSDB_ENABLE_BLOB_FILES")
|
.and_then(|s| s.parse::<bool>().ok())
|
||||||
.map(|v| v.parse::<bool>().unwrap_or(default))
|
.unwrap_or(true)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
|
pub static ROCKSDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
|
||||||
let default = 4 * 1024;
|
option_env!("SURREAL_ROCKSDB_MIN_BLOB_SIZE")
|
||||||
std::env::var("SURREAL_ROCKSDB_MIN_BLOB_SIZE")
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.map(|v| v.parse::<u64>().unwrap_or(default))
|
.unwrap_or(4 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static ROCKSDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
|
pub static ROCKSDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 20;
|
option_env!("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM")
|
||||||
std::env::var("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(20)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
|
@ -1,64 +1,55 @@
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
|
||||||
pub static SPEEDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
|
pub static SPEEDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = num_cpus::get() as i32;
|
option_env!("SURREAL_SPEEDB_THREAD_COUNT")
|
||||||
std::env::var("SURREAL_SPEEDB_THREAD_COUNT")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(num_cpus::get() as i32)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
|
pub static SPEEDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 256 * 1024 * 1024;
|
option_env!("SURREAL_SPEEDB_WRITE_BUFFER_SIZE")
|
||||||
std::env::var("SURREAL_SPEEDB_WRITE_BUFFER_SIZE")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(256 * 1024 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
|
pub static SPEEDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
|
||||||
let default = 512 * 1024 * 1024;
|
option_env!("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE")
|
||||||
std::env::var("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE")
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.map(|v| v.parse::<u64>().unwrap_or(default))
|
.unwrap_or(512 * 1024 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
|
pub static SPEEDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = 32;
|
option_env!("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER")
|
||||||
std::env::var("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(32)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
|
pub static SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
|
||||||
let default = 4;
|
option_env!("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
|
||||||
std::env::var("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
.map(|v| v.parse::<i32>().unwrap_or(default))
|
.unwrap_or(4)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
|
pub static SPEEDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
|
||||||
let default = true;
|
option_env!("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES")
|
||||||
std::env::var("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES")
|
.and_then(|s| s.parse::<bool>().ok())
|
||||||
.map(|v| v.parse::<bool>().unwrap_or(default))
|
.unwrap_or(true)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
|
pub static SPEEDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
|
||||||
let default = true;
|
option_env!("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
|
||||||
std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
|
.and_then(|s| s.parse::<bool>().ok())
|
||||||
.map(|v| v.parse::<bool>().unwrap_or(default))
|
.unwrap_or(true)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
|
pub static SPEEDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
|
||||||
let default = 4 * 1024;
|
option_env!("SURREAL_SPEEDB_MIN_BLOB_SIZE")
|
||||||
std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.map(|v| v.parse::<u64>().unwrap_or(default))
|
.unwrap_or(4 * 1024)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static SPEEDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
|
pub static SPEEDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 20;
|
option_env!("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM")
|
||||||
std::env::var("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(20)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
|
@ -33,6 +33,7 @@ use sql::statements::DefineEventStatement;
|
||||||
use sql::statements::DefineFieldStatement;
|
use sql::statements::DefineFieldStatement;
|
||||||
use sql::statements::DefineFunctionStatement;
|
use sql::statements::DefineFunctionStatement;
|
||||||
use sql::statements::DefineIndexStatement;
|
use sql::statements::DefineIndexStatement;
|
||||||
|
use sql::statements::DefineModelStatement;
|
||||||
use sql::statements::DefineNamespaceStatement;
|
use sql::statements::DefineNamespaceStatement;
|
||||||
use sql::statements::DefineParamStatement;
|
use sql::statements::DefineParamStatement;
|
||||||
use sql::statements::DefineScopeStatement;
|
use sql::statements::DefineScopeStatement;
|
||||||
|
@ -1557,6 +1558,29 @@ impl Transaction {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve all model definitions for a specific database.
|
||||||
|
pub async fn all_db_models(
|
||||||
|
&mut self,
|
||||||
|
ns: &str,
|
||||||
|
db: &str,
|
||||||
|
) -> Result<Arc<[DefineModelStatement]>, Error> {
|
||||||
|
let key = crate::key::database::ml::prefix(ns, db);
|
||||||
|
Ok(if let Some(e) = self.cache.get(&key) {
|
||||||
|
if let Entry::Mls(v) = e {
|
||||||
|
v
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let beg = crate::key::database::ml::prefix(ns, db);
|
||||||
|
let end = crate::key::database::ml::suffix(ns, db);
|
||||||
|
let val = self.getr(beg..end, u32::MAX).await?;
|
||||||
|
let val = val.convert().into();
|
||||||
|
self.cache.set(key, Entry::Mls(Arc::clone(&val)));
|
||||||
|
val
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve all scope definitions for a specific database.
|
/// Retrieve all scope definitions for a specific database.
|
||||||
pub async fn all_sc(
|
pub async fn all_sc(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -1840,6 +1864,21 @@ impl Transaction {
|
||||||
Ok(val.into())
|
Ok(val.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve a specific model definition from a database.
|
||||||
|
pub async fn get_db_model(
|
||||||
|
&mut self,
|
||||||
|
ns: &str,
|
||||||
|
db: &str,
|
||||||
|
ml: &str,
|
||||||
|
vn: &str,
|
||||||
|
) -> Result<DefineModelStatement, Error> {
|
||||||
|
let key = crate::key::database::ml::new(ns, db, ml, vn);
|
||||||
|
let val = self.get(key).await?.ok_or(Error::MlNotFound {
|
||||||
|
value: format!("{ml}<{vn}>"),
|
||||||
|
})?;
|
||||||
|
Ok(val.into())
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve a specific database token definition.
|
/// Retrieve a specific database token definition.
|
||||||
pub async fn get_db_token(
|
pub async fn get_db_token(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -2178,6 +2217,31 @@ impl Transaction {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve a specific model definition.
|
||||||
|
pub async fn get_and_cache_db_model(
|
||||||
|
&mut self,
|
||||||
|
ns: &str,
|
||||||
|
db: &str,
|
||||||
|
ml: &str,
|
||||||
|
vn: &str,
|
||||||
|
) -> Result<Arc<DefineModelStatement>, Error> {
|
||||||
|
let key = crate::key::database::ml::new(ns, db, ml, vn).encode()?;
|
||||||
|
Ok(if let Some(e) = self.cache.get(&key) {
|
||||||
|
if let Entry::Ml(v) = e {
|
||||||
|
v
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let val = self.get(key.clone()).await?.ok_or(Error::MlNotFound {
|
||||||
|
value: format!("{ml}<{vn}>"),
|
||||||
|
})?;
|
||||||
|
let val: Arc<DefineModelStatement> = Arc::new(val.into());
|
||||||
|
self.cache.set(key, Entry::Ml(Arc::clone(&val)));
|
||||||
|
val
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve a specific table index definition.
|
/// Retrieve a specific table index definition.
|
||||||
pub async fn get_and_cache_tb_index(
|
pub async fn get_and_cache_tb_index(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|
|
@ -134,6 +134,10 @@ pub mod idx;
|
||||||
pub mod key;
|
pub mod key;
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub mod kvs;
|
pub mod kvs;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub mod obs;
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub mod syn;
|
pub mod syn;
|
||||||
|
|
||||||
|
|
96
lib/src/obs/mod.rs
Normal file
96
lib/src/obs/mod.rs
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
//! This module defines the operations for object storage using the [object_store](https://docs.rs/object_store/latest/object_store/)
|
||||||
|
//! crate. This will enable the user to store objects using local file storage, or cloud storage such as S3 or GCS.
|
||||||
|
use crate::err::Error;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
use object_store::local::LocalFileSystem;
|
||||||
|
use object_store::parse_url;
|
||||||
|
use object_store::path::Path;
|
||||||
|
use object_store::ObjectStore;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use sha1::{Digest, Sha1};
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
static STORE: Lazy<Arc<dyn ObjectStore>> =
|
||||||
|
Lazy::new(|| match std::env::var("SURREAL_OBJECT_STORE") {
|
||||||
|
Ok(url) => {
|
||||||
|
let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE");
|
||||||
|
let (store, _) =
|
||||||
|
parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE");
|
||||||
|
Arc::new(store)
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let path = env::current_dir().unwrap().join("store");
|
||||||
|
if !path.exists() || !path.is_dir() {
|
||||||
|
fs::create_dir_all(&path)
|
||||||
|
.expect("Unable to create directory structure for SURREAL_OBJECT_STORE");
|
||||||
|
}
|
||||||
|
// As long as the provided path is correct, the following should never panic
|
||||||
|
Arc::new(LocalFileSystem::new_with_prefix(path).unwrap())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
static CACHE: Lazy<Arc<dyn ObjectStore>> =
|
||||||
|
Lazy::new(|| match std::env::var("SURREAL_OBJECT_CACHE") {
|
||||||
|
Ok(url) => {
|
||||||
|
let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE");
|
||||||
|
let (store, _) =
|
||||||
|
parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE");
|
||||||
|
Arc::new(store)
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let path = env::current_dir().unwrap().join("cache");
|
||||||
|
if !path.exists() || !path.is_dir() {
|
||||||
|
fs::create_dir_all(&path)
|
||||||
|
.expect("Unable to create directory structure for SURREAL_OBJECT_CACHE");
|
||||||
|
}
|
||||||
|
// As long as the provided path is correct, the following should never panic
|
||||||
|
Arc::new(LocalFileSystem::new_with_prefix(path).unwrap())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Gets the file from the local file system object storage.
|
||||||
|
pub async fn stream(
|
||||||
|
file: String,
|
||||||
|
) -> Result<BoxStream<'static, Result<Bytes, object_store::Error>>, Error> {
|
||||||
|
match CACHE.get(&Path::from(file.as_str())).await {
|
||||||
|
Ok(data) => Ok(data.into_stream()),
|
||||||
|
_ => Ok(STORE.get(&Path::from(file.as_str())).await?.into_stream()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the file from the local file system object storage.
|
||||||
|
pub async fn get(file: &str) -> Result<Vec<u8>, Error> {
|
||||||
|
match CACHE.get(&Path::from(file)).await {
|
||||||
|
Ok(data) => Ok(data.bytes().await?.to_vec()),
|
||||||
|
_ => {
|
||||||
|
let data = STORE.get(&Path::from(file)).await?;
|
||||||
|
CACHE.put(&Path::from(file), data.bytes().await?).await?;
|
||||||
|
Ok(CACHE.get(&Path::from(file)).await?.bytes().await?.to_vec())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the file from the local file system object storage.
|
||||||
|
pub async fn put(file: &str, data: Vec<u8>) -> Result<(), Error> {
|
||||||
|
let _ = STORE.put(&Path::from(file), Bytes::from(data)).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the file from the local file system object storage.
|
||||||
|
pub async fn del(file: &str) -> Result<(), Error> {
|
||||||
|
Ok(STORE.delete(&Path::from(file)).await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hashes the bytes of a file to a string for the storage of a file.
|
||||||
|
pub fn hash(data: &Vec<u8>) -> String {
|
||||||
|
let mut hasher = Sha1::new();
|
||||||
|
hasher.update(data);
|
||||||
|
let result = hasher.finalize();
|
||||||
|
let mut output = hex::encode(result);
|
||||||
|
output.truncate(6);
|
||||||
|
output
|
||||||
|
}
|
|
@ -159,8 +159,10 @@ impl Function {
|
||||||
fnc::run(ctx, opt, txn, doc, s, a).await
|
fnc::run(ctx, opt, txn, doc, s, a).await
|
||||||
}
|
}
|
||||||
Self::Custom(s, x) => {
|
Self::Custom(s, x) => {
|
||||||
|
// Get the full name of this function
|
||||||
|
let name = format!("fn::{s}");
|
||||||
// Check this function is allowed
|
// Check this function is allowed
|
||||||
ctx.check_allowed_function(format!("fn::{s}").as_str())?;
|
ctx.check_allowed_function(name.as_str())?;
|
||||||
// Get the function definition
|
// Get the function definition
|
||||||
let val = {
|
let val = {
|
||||||
// Claim transaction
|
// Claim transaction
|
||||||
|
@ -189,15 +191,16 @@ impl Function {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Return the value
|
// Get the number of function arguments
|
||||||
// Check the function arguments
|
|
||||||
let max_args_len = val.args.len();
|
let max_args_len = val.args.len();
|
||||||
|
// Track the number of required arguments
|
||||||
let mut min_args_len = 0;
|
let mut min_args_len = 0;
|
||||||
|
// Check for any final optional arguments
|
||||||
val.args.iter().rev().for_each(|(_, kind)| match kind {
|
val.args.iter().rev().for_each(|(_, kind)| match kind {
|
||||||
Kind::Option(_) if min_args_len == 0 => {}
|
Kind::Option(_) if min_args_len == 0 => {}
|
||||||
_ => min_args_len += 1,
|
_ => min_args_len += 1,
|
||||||
});
|
});
|
||||||
|
// Check the necessary arguments are passed
|
||||||
if x.len() < min_args_len || max_args_len < x.len() {
|
if x.len() < min_args_len || max_args_len < x.len() {
|
||||||
return Err(Error::InvalidArguments {
|
return Err(Error::InvalidArguments {
|
||||||
name: format!("fn::{}", val.name),
|
name: format!("fn::{}", val.name),
|
||||||
|
|
|
@ -1,16 +1,29 @@
|
||||||
use crate::{
|
use crate::ctx::Context;
|
||||||
ctx::Context,
|
use crate::dbs::{Options, Transaction};
|
||||||
dbs::{Options, Transaction},
|
use crate::doc::CursorDoc;
|
||||||
doc::CursorDoc,
|
use crate::err::Error;
|
||||||
err::Error,
|
use crate::sql::value::Value;
|
||||||
sql::value::Value,
|
|
||||||
};
|
|
||||||
use async_recursion::async_recursion;
|
|
||||||
use derive::Store;
|
use derive::Store;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use crate::iam::Action;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use crate::sql::Permission;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use futures::future::try_join_all;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use std::collections::HashMap;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use surrealml_core::execution::compute::ModelComputation;
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use surrealml_core::storage::surml_file::SurMlFile;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
const ARGUMENTS: &str = "The model expects 1 argument. The argument can be either a number, an object, or an array of numbers.";
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||||
#[revisioned(revision = 1)]
|
#[revisioned(revision = 1)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
|
@ -33,15 +46,165 @@ impl fmt::Display for Model {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
|
#[cfg(feature = "ml")]
|
||||||
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
|
pub(crate) async fn compute(
|
||||||
|
&self,
|
||||||
|
ctx: &Context<'_>,
|
||||||
|
opt: &Options,
|
||||||
|
txn: &Transaction,
|
||||||
|
doc: Option<&CursorDoc<'_>>,
|
||||||
|
) -> Result<Value, Error> {
|
||||||
|
// Ensure futures are run
|
||||||
|
let opt = &opt.new_with_futures(true);
|
||||||
|
// Get the full name of this model
|
||||||
|
let name = format!("ml::{}", self.name);
|
||||||
|
// Check this function is allowed
|
||||||
|
ctx.check_allowed_function(name.as_str())?;
|
||||||
|
// Get the model definition
|
||||||
|
let val = {
|
||||||
|
// Claim transaction
|
||||||
|
let mut run = txn.lock().await;
|
||||||
|
// Get the function definition
|
||||||
|
run.get_and_cache_db_model(opt.ns(), opt.db(), &self.name, &self.version).await?
|
||||||
|
};
|
||||||
|
// Calculate the model path
|
||||||
|
let path = format!(
|
||||||
|
"ml/{}/{}/{}-{}-{}.surml",
|
||||||
|
opt.ns(),
|
||||||
|
opt.db(),
|
||||||
|
self.name,
|
||||||
|
self.version,
|
||||||
|
val.hash
|
||||||
|
);
|
||||||
|
// Check permissions
|
||||||
|
if opt.check_perms(Action::View) {
|
||||||
|
match &val.permissions {
|
||||||
|
Permission::Full => (),
|
||||||
|
Permission::None => {
|
||||||
|
return Err(Error::FunctionPermissions {
|
||||||
|
name: self.name.to_owned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Permission::Specific(e) => {
|
||||||
|
// Disable permissions
|
||||||
|
let opt = &opt.new_with_perms(false);
|
||||||
|
// Process the PERMISSION clause
|
||||||
|
if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
|
||||||
|
return Err(Error::FunctionPermissions {
|
||||||
|
name: self.name.to_owned(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Compute the function arguments
|
||||||
|
let mut args =
|
||||||
|
try_join_all(self.args.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
|
||||||
|
// Check the minimum argument length
|
||||||
|
if args.len() != 1 {
|
||||||
|
return Err(Error::InvalidArguments {
|
||||||
|
name: format!("ml::{}<{}>", self.name, self.version),
|
||||||
|
message: ARGUMENTS.into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Take the first and only specified argument
|
||||||
|
match args.swap_remove(0) {
|
||||||
|
// Perform bufferered compute
|
||||||
|
Value::Object(v) => {
|
||||||
|
// Compute the model function arguments
|
||||||
|
let mut args = v
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| Ok((k, Value::try_into(v)?)))
|
||||||
|
.collect::<Result<HashMap<String, f32>, Error>>()
|
||||||
|
.map_err(|_| Error::InvalidArguments {
|
||||||
|
name: format!("ml::{}<{}>", self.name, self.version),
|
||||||
|
message: ARGUMENTS.into(),
|
||||||
|
})?;
|
||||||
|
// Get the model file as bytes
|
||||||
|
let bytes = crate::obs::get(&path).await?;
|
||||||
|
// Run the compute in a blocking task
|
||||||
|
let outcome = tokio::task::spawn_blocking(move || {
|
||||||
|
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||||
|
let compute_unit = ModelComputation {
|
||||||
|
surml_file: &mut file,
|
||||||
|
};
|
||||||
|
compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()?;
|
||||||
|
// Convert the output to a value
|
||||||
|
Ok(outcome[0].into())
|
||||||
|
}
|
||||||
|
// Perform raw compute
|
||||||
|
Value::Number(v) => {
|
||||||
|
// Compute the model function arguments
|
||||||
|
let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments {
|
||||||
|
name: format!("ml::{}<{}>", self.name, self.version),
|
||||||
|
message: ARGUMENTS.into(),
|
||||||
|
})?;
|
||||||
|
// Get the model file as bytes
|
||||||
|
let bytes = crate::obs::get(&path).await?;
|
||||||
|
// Convert the argument to a tensor
|
||||||
|
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
|
||||||
|
// Run the compute in a blocking task
|
||||||
|
let outcome = tokio::task::spawn_blocking(move || {
|
||||||
|
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||||
|
let compute_unit = ModelComputation {
|
||||||
|
surml_file: &mut file,
|
||||||
|
};
|
||||||
|
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()?;
|
||||||
|
// Convert the output to a value
|
||||||
|
Ok(outcome[0].into())
|
||||||
|
}
|
||||||
|
// Perform raw compute
|
||||||
|
Value::Array(v) => {
|
||||||
|
// Compute the model function arguments
|
||||||
|
let args = v
|
||||||
|
.into_iter()
|
||||||
|
.map(Value::try_into)
|
||||||
|
.collect::<Result<Vec<f32>, Error>>()
|
||||||
|
.map_err(|_| Error::InvalidArguments {
|
||||||
|
name: format!("ml::{}<{}>", self.name, self.version),
|
||||||
|
message: ARGUMENTS.into(),
|
||||||
|
})?;
|
||||||
|
// Get the model file as bytes
|
||||||
|
let bytes = crate::obs::get(&path).await?;
|
||||||
|
// Convert the argument to a tensor
|
||||||
|
let tensor = ndarray::arr1::<f32>(&args).into_dyn();
|
||||||
|
// Run the compute in a blocking task
|
||||||
|
let outcome = tokio::task::spawn_blocking(move || {
|
||||||
|
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||||
|
let compute_unit = ModelComputation {
|
||||||
|
surml_file: &mut file,
|
||||||
|
};
|
||||||
|
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()?;
|
||||||
|
// Convert the output to a value
|
||||||
|
Ok(outcome[0].into())
|
||||||
|
}
|
||||||
|
//
|
||||||
|
_ => Err(Error::InvalidArguments {
|
||||||
|
name: format!("ml::{}<{}>", self.name, self.version),
|
||||||
|
message: ARGUMENTS.into(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "ml"))]
|
||||||
pub(crate) async fn compute(
|
pub(crate) async fn compute(
|
||||||
&self,
|
&self,
|
||||||
_ctx: &Context<'_>,
|
_ctx: &Context<'_>,
|
||||||
_opt: &Options,
|
_opt: &Options,
|
||||||
_txn: &Transaction,
|
_txn: &Transaction,
|
||||||
_doc: Option<&'async_recursion CursorDoc<'_>>,
|
_doc: Option<&CursorDoc<'_>>,
|
||||||
) -> Result<Value, Error> {
|
) -> Result<Value, Error> {
|
||||||
Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string()))
|
Err(Error::InvalidModel {
|
||||||
|
message: String::from("Machine learning computation is not enabled."),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::sql::fmt::Pretty;
|
use crate::sql::fmt::Pretty;
|
||||||
use crate::sql::statement::{Statement, Statements};
|
use crate::sql::statement::{Statement, Statements};
|
||||||
use crate::sql::Value;
|
use crate::sql::statements::{DefineStatement, RemoveStatement};
|
||||||
use derive::Store;
|
use derive::Store;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -16,6 +16,18 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
|
||||||
#[serde(rename = "$surrealdb::private::sql::Query")]
|
#[serde(rename = "$surrealdb::private::sql::Query")]
|
||||||
pub struct Query(pub Statements);
|
pub struct Query(pub Statements);
|
||||||
|
|
||||||
|
impl From<DefineStatement> for Query {
|
||||||
|
fn from(s: DefineStatement) -> Self {
|
||||||
|
Query(Statements(vec![Statement::Define(s)]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RemoveStatement> for Query {
|
||||||
|
fn from(s: RemoveStatement) -> Self {
|
||||||
|
Query(Statements(vec![Statement::Remove(s)]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Deref for Query {
|
impl Deref for Query {
|
||||||
type Target = Vec<Statement>;
|
type Target = Vec<Statement>;
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
|
@ -31,12 +43,6 @@ impl IntoIterator for Query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Query> for Value {
|
|
||||||
fn from(q: Query) -> Self {
|
|
||||||
Value::Query(q)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Query {
|
impl Display for Query {
|
||||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||||
write!(Pretty::from(f), "{}", &self.0)
|
write!(Pretty::from(f), "{}", &self.0)
|
||||||
|
|
|
@ -52,7 +52,7 @@ pub enum DefineStatement {
|
||||||
Field(DefineFieldStatement),
|
Field(DefineFieldStatement),
|
||||||
Index(DefineIndexStatement),
|
Index(DefineIndexStatement),
|
||||||
User(DefineUserStatement),
|
User(DefineUserStatement),
|
||||||
MlModel(DefineModelStatement),
|
Model(DefineModelStatement),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DefineStatement {
|
impl DefineStatement {
|
||||||
|
@ -81,7 +81,7 @@ impl DefineStatement {
|
||||||
Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await,
|
Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await,
|
Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Self::User(ref v) => v.compute(ctx, opt, txn, doc).await,
|
Self::User(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Self::MlModel(ref v) => v.compute(ctx, opt, txn, doc).await,
|
Self::Model(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,7 +101,7 @@ impl Display for DefineStatement {
|
||||||
Self::Field(v) => Display::fmt(v, f),
|
Self::Field(v) => Display::fmt(v, f),
|
||||||
Self::Index(v) => Display::fmt(v, f),
|
Self::Index(v) => Display::fmt(v, f),
|
||||||
Self::Analyzer(v) => Display::fmt(v, f),
|
Self::Analyzer(v) => Display::fmt(v, f),
|
||||||
Self::MlModel(v) => Display::fmt(v, f),
|
Self::Model(v) => Display::fmt(v, f),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +1,21 @@
|
||||||
|
use crate::ctx::Context;
|
||||||
|
use crate::dbs::{Options, Transaction};
|
||||||
|
use crate::doc::CursorDoc;
|
||||||
|
use crate::err::Error;
|
||||||
|
use crate::iam::{Action, ResourceKind};
|
||||||
use crate::sql::{
|
use crate::sql::{
|
||||||
fmt::{is_pretty, pretty_indent},
|
fmt::{is_pretty, pretty_indent},
|
||||||
Permission,
|
Base, Ident, Permission, Strand, Value,
|
||||||
};
|
};
|
||||||
use async_recursion::async_recursion;
|
|
||||||
use derive::Store;
|
use derive::Store;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt::{self, Write};
|
||||||
use std::fmt::Write;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
ctx::Context,
|
|
||||||
dbs::{Options, Transaction},
|
|
||||||
doc::CursorDoc,
|
|
||||||
err::Error,
|
|
||||||
sql::{Ident, Strand, Value},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||||
#[revisioned(revision = 1)]
|
#[revisioned(revision = 1)]
|
||||||
pub struct DefineModelStatement {
|
pub struct DefineModelStatement {
|
||||||
|
pub hash: String,
|
||||||
pub name: Ident,
|
pub name: Ident,
|
||||||
pub version: String,
|
pub version: String,
|
||||||
pub comment: Option<Strand>,
|
pub comment: Option<Strand>,
|
||||||
|
@ -32,7 +28,6 @@ impl fmt::Display for DefineModelStatement {
|
||||||
if let Some(comment) = self.comment.as_ref() {
|
if let Some(comment) = self.comment.as_ref() {
|
||||||
write!(f, " COMMENT {}", comment)?;
|
write!(f, " COMMENT {}", comment)?;
|
||||||
}
|
}
|
||||||
if !self.permissions.is_full() {
|
|
||||||
let _indent = if is_pretty() {
|
let _indent = if is_pretty() {
|
||||||
Some(pretty_indent())
|
Some(pretty_indent())
|
||||||
} else {
|
} else {
|
||||||
|
@ -40,21 +35,33 @@ impl fmt::Display for DefineModelStatement {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
write!(f, "PERMISSIONS {}", self.permissions)?;
|
write!(f, "PERMISSIONS {}", self.permissions)?;
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DefineModelStatement {
|
impl DefineModelStatement {
|
||||||
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
|
/// Process this type returning a computed simple Value
|
||||||
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
|
|
||||||
pub(crate) async fn compute(
|
pub(crate) async fn compute(
|
||||||
&self,
|
&self,
|
||||||
_ctx: &Context<'_>,
|
_ctx: &Context<'_>,
|
||||||
_opt: &Options,
|
opt: &Options,
|
||||||
_txn: &Transaction,
|
txn: &Transaction,
|
||||||
_doc: Option<&'async_recursion CursorDoc<'_>>,
|
_doc: Option<&CursorDoc<'_>>,
|
||||||
) -> Result<Value, Error> {
|
) -> Result<Value, Error> {
|
||||||
Err(Error::Unimplemented("Ml model definition not yet implemented".to_string()))
|
// Allowed to run?
|
||||||
|
opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?;
|
||||||
|
// Claim transaction
|
||||||
|
let mut run = txn.lock().await;
|
||||||
|
// Clear the cache
|
||||||
|
run.clear_cache();
|
||||||
|
// Process the statement
|
||||||
|
let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version);
|
||||||
|
run.add_ns(opt.ns(), opt.strict).await?;
|
||||||
|
run.add_db(opt.ns(), opt.db(), opt.strict).await?;
|
||||||
|
run.set(key, self).await?;
|
||||||
|
// Store the model file
|
||||||
|
// TODO
|
||||||
|
// Ok all good
|
||||||
|
Ok(Value::None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,6 +107,12 @@ impl InfoStatement {
|
||||||
tmp.insert(v.name.to_string(), v.to_string().into());
|
tmp.insert(v.name.to_string(), v.to_string().into());
|
||||||
}
|
}
|
||||||
res.insert("functions".to_owned(), tmp.into());
|
res.insert("functions".to_owned(), tmp.into());
|
||||||
|
// Process the models
|
||||||
|
let mut tmp = Object::default();
|
||||||
|
for v in run.all_db_models(opt.ns(), opt.db()).await?.iter() {
|
||||||
|
tmp.insert(format!("{}<{}>", v.name, v.version), v.to_string().into());
|
||||||
|
}
|
||||||
|
res.insert("models".to_owned(), tmp.into());
|
||||||
// Process the params
|
// Process the params
|
||||||
let mut tmp = Object::default();
|
let mut tmp = Object::default();
|
||||||
for v in run.all_db_params(opt.ns(), opt.db()).await?.iter() {
|
for v in run.all_db_params(opt.ns(), opt.db()).await?.iter() {
|
||||||
|
|
|
@ -52,14 +52,14 @@ pub use self::update::UpdateStatement;
|
||||||
|
|
||||||
pub use self::define::{
|
pub use self::define::{
|
||||||
DefineAnalyzerStatement, DefineDatabaseStatement, DefineEventStatement, DefineFieldStatement,
|
DefineAnalyzerStatement, DefineDatabaseStatement, DefineEventStatement, DefineFieldStatement,
|
||||||
DefineFunctionStatement, DefineIndexStatement, DefineNamespaceStatement, DefineParamStatement,
|
DefineFunctionStatement, DefineIndexStatement, DefineModelStatement, DefineNamespaceStatement,
|
||||||
DefineScopeStatement, DefineStatement, DefineTableStatement, DefineTokenStatement,
|
DefineParamStatement, DefineScopeStatement, DefineStatement, DefineTableStatement,
|
||||||
DefineUserStatement,
|
DefineTokenStatement, DefineUserStatement,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub use self::remove::{
|
pub use self::remove::{
|
||||||
RemoveAnalyzerStatement, RemoveDatabaseStatement, RemoveEventStatement, RemoveFieldStatement,
|
RemoveAnalyzerStatement, RemoveDatabaseStatement, RemoveEventStatement, RemoveFieldStatement,
|
||||||
RemoveFunctionStatement, RemoveIndexStatement, RemoveNamespaceStatement, RemoveParamStatement,
|
RemoveFunctionStatement, RemoveIndexStatement, RemoveModelStatement, RemoveNamespaceStatement,
|
||||||
RemoveScopeStatement, RemoveStatement, RemoveTableStatement, RemoveTokenStatement,
|
RemoveParamStatement, RemoveScopeStatement, RemoveStatement, RemoveTableStatement,
|
||||||
RemoveUserStatement,
|
RemoveTokenStatement, RemoveUserStatement,
|
||||||
};
|
};
|
||||||
|
|
|
@ -4,6 +4,7 @@ mod event;
|
||||||
mod field;
|
mod field;
|
||||||
mod function;
|
mod function;
|
||||||
mod index;
|
mod index;
|
||||||
|
mod model;
|
||||||
mod namespace;
|
mod namespace;
|
||||||
mod param;
|
mod param;
|
||||||
mod scope;
|
mod scope;
|
||||||
|
@ -17,6 +18,7 @@ pub use event::RemoveEventStatement;
|
||||||
pub use field::RemoveFieldStatement;
|
pub use field::RemoveFieldStatement;
|
||||||
pub use function::RemoveFunctionStatement;
|
pub use function::RemoveFunctionStatement;
|
||||||
pub use index::RemoveIndexStatement;
|
pub use index::RemoveIndexStatement;
|
||||||
|
pub use model::RemoveModelStatement;
|
||||||
pub use namespace::RemoveNamespaceStatement;
|
pub use namespace::RemoveNamespaceStatement;
|
||||||
pub use param::RemoveParamStatement;
|
pub use param::RemoveParamStatement;
|
||||||
pub use scope::RemoveScopeStatement;
|
pub use scope::RemoveScopeStatement;
|
||||||
|
@ -49,6 +51,7 @@ pub enum RemoveStatement {
|
||||||
Field(RemoveFieldStatement),
|
Field(RemoveFieldStatement),
|
||||||
Index(RemoveIndexStatement),
|
Index(RemoveIndexStatement),
|
||||||
User(RemoveUserStatement),
|
User(RemoveUserStatement),
|
||||||
|
Model(RemoveModelStatement),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RemoveStatement {
|
impl RemoveStatement {
|
||||||
|
@ -77,6 +80,7 @@ impl RemoveStatement {
|
||||||
Self::Index(ref v) => v.compute(ctx, opt, txn).await,
|
Self::Index(ref v) => v.compute(ctx, opt, txn).await,
|
||||||
Self::Analyzer(ref v) => v.compute(ctx, opt, txn).await,
|
Self::Analyzer(ref v) => v.compute(ctx, opt, txn).await,
|
||||||
Self::User(ref v) => v.compute(ctx, opt, txn).await,
|
Self::User(ref v) => v.compute(ctx, opt, txn).await,
|
||||||
|
Self::Model(ref v) => v.compute(ctx, opt, txn).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,6 +100,7 @@ impl Display for RemoveStatement {
|
||||||
Self::Index(v) => Display::fmt(v, f),
|
Self::Index(v) => Display::fmt(v, f),
|
||||||
Self::Analyzer(v) => Display::fmt(v, f),
|
Self::Analyzer(v) => Display::fmt(v, f),
|
||||||
Self::User(v) => Display::fmt(v, f),
|
Self::User(v) => Display::fmt(v, f),
|
||||||
|
Self::Model(v) => Display::fmt(v, f),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
47
lib/src/sql/statements/remove/model.rs
Normal file
47
lib/src/sql/statements/remove/model.rs
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
use crate::ctx::Context;
|
||||||
|
use crate::dbs::{Options, Transaction};
|
||||||
|
use crate::err::Error;
|
||||||
|
use crate::iam::{Action, ResourceKind};
|
||||||
|
use crate::sql::{Base, Ident, Value};
|
||||||
|
use derive::Store;
|
||||||
|
use revision::revisioned;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt::{self, Display};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||||
|
#[revisioned(revision = 1)]
|
||||||
|
pub struct RemoveModelStatement {
|
||||||
|
pub name: Ident,
|
||||||
|
pub version: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoveModelStatement {
|
||||||
|
/// Process this type returning a computed simple Value
|
||||||
|
pub(crate) async fn compute(
|
||||||
|
&self,
|
||||||
|
_ctx: &Context<'_>,
|
||||||
|
opt: &Options,
|
||||||
|
txn: &Transaction,
|
||||||
|
) -> Result<Value, Error> {
|
||||||
|
// Allowed to run?
|
||||||
|
opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?;
|
||||||
|
// Claim transaction
|
||||||
|
let mut run = txn.lock().await;
|
||||||
|
// Clear the cache
|
||||||
|
run.clear_cache();
|
||||||
|
// Delete the definition
|
||||||
|
let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version);
|
||||||
|
run.del(key).await?;
|
||||||
|
// Remove the model file
|
||||||
|
// TODO
|
||||||
|
// Ok all good
|
||||||
|
Ok(Value::None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for RemoveModelStatement {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
// Bypass ident display since we don't want backticks arround the ident.
|
||||||
|
write!(f, "REMOVE MODEL ml::{}<{}>", self.name.0, self.version)
|
||||||
|
}
|
||||||
|
}
|
|
@ -99,12 +99,11 @@ pub enum Value {
|
||||||
Edges(Box<Edges>),
|
Edges(Box<Edges>),
|
||||||
Future(Box<Future>),
|
Future(Box<Future>),
|
||||||
Constant(Constant),
|
Constant(Constant),
|
||||||
// Closure(Box<Closure>),
|
|
||||||
Function(Box<Function>),
|
Function(Box<Function>),
|
||||||
Subquery(Box<Subquery>),
|
Subquery(Box<Subquery>),
|
||||||
Expression(Box<Expression>),
|
Expression(Box<Expression>),
|
||||||
Query(Query),
|
Query(Query),
|
||||||
MlModel(Box<Model>),
|
Model(Box<Model>),
|
||||||
// Add new variants here
|
// Add new variants here
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +256,7 @@ impl From<Function> for Value {
|
||||||
|
|
||||||
impl From<Model> for Value {
|
impl From<Model> for Value {
|
||||||
fn from(v: Model) -> Self {
|
fn from(v: Model) -> Self {
|
||||||
Value::MlModel(Box::new(v))
|
Value::Model(Box::new(v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -505,6 +504,12 @@ impl From<Id> for Value {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Query> for Value {
|
||||||
|
fn from(q: Query) -> Self {
|
||||||
|
Value::Query(q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TryFrom<Value> for i8 {
|
impl TryFrom<Value> for i8 {
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
fn try_from(value: Value) -> Result<Self, Self::Error> {
|
fn try_from(value: Value) -> Result<Self, Self::Error> {
|
||||||
|
@ -1035,7 +1040,7 @@ impl Value {
|
||||||
pub fn can_start_idiom(&self) -> bool {
|
pub fn can_start_idiom(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Value::Function(x) => !x.is_script(),
|
Value::Function(x) => !x.is_script(),
|
||||||
Value::MlModel(_)
|
Value::Model(_)
|
||||||
| Value::Subquery(_)
|
| Value::Subquery(_)
|
||||||
| Value::Constant(_)
|
| Value::Constant(_)
|
||||||
| Value::Datetime(_)
|
| Value::Datetime(_)
|
||||||
|
@ -2526,7 +2531,7 @@ impl fmt::Display for Value {
|
||||||
Value::Edges(v) => write!(f, "{v}"),
|
Value::Edges(v) => write!(f, "{v}"),
|
||||||
Value::Expression(v) => write!(f, "{v}"),
|
Value::Expression(v) => write!(f, "{v}"),
|
||||||
Value::Function(v) => write!(f, "{v}"),
|
Value::Function(v) => write!(f, "{v}"),
|
||||||
Value::MlModel(v) => write!(f, "{v}"),
|
Value::Model(v) => write!(f, "{v}"),
|
||||||
Value::Future(v) => write!(f, "{v}"),
|
Value::Future(v) => write!(f, "{v}"),
|
||||||
Value::Geometry(v) => write!(f, "{v}"),
|
Value::Geometry(v) => write!(f, "{v}"),
|
||||||
Value::Idiom(v) => write!(f, "{v}"),
|
Value::Idiom(v) => write!(f, "{v}"),
|
||||||
|
@ -2557,7 +2562,7 @@ impl Value {
|
||||||
Value::Function(v) => {
|
Value::Function(v) => {
|
||||||
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
|
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
|
||||||
}
|
}
|
||||||
Value::MlModel(m) => m.args.iter().any(Value::writeable),
|
Value::Model(m) => m.args.iter().any(Value::writeable),
|
||||||
Value::Subquery(v) => v.writeable(),
|
Value::Subquery(v) => v.writeable(),
|
||||||
Value::Expression(v) => v.writeable(),
|
Value::Expression(v) => v.writeable(),
|
||||||
_ => false,
|
_ => false,
|
||||||
|
@ -2588,7 +2593,7 @@ impl Value {
|
||||||
Value::Future(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Future(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Value::Constant(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Constant(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Value::Function(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Function(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Value::MlModel(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Model(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
Value::Expression(v) => v.compute(ctx, opt, txn, doc).await,
|
Value::Expression(v) => v.compute(ctx, opt, txn, doc).await,
|
||||||
_ => Ok(self.to_owned()),
|
_ => Ok(self.to_owned()),
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
mod api_integration {
|
mod api_integration {
|
||||||
use chrono::DateTime;
|
use chrono::DateTime;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
use semver::Version;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
|
@ -23,3 +23,17 @@ async fn export_import() {
|
||||||
db.import(&file).await.unwrap();
|
db.import(&file).await.unwrap();
|
||||||
remove_file(file).await.unwrap();
|
remove_file(file).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test_log::test(tokio::test)]
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
async fn ml_export_import() {
|
||||||
|
let (permit, db) = new_db().await;
|
||||||
|
let db_name = Ulid::new().to_string();
|
||||||
|
db.use_ns(NS).use_db(&db_name).await.unwrap();
|
||||||
|
db.import("../tests/linear_test.surml").ml().await.unwrap();
|
||||||
|
drop(permit);
|
||||||
|
let file = format!("{db_name}.surml");
|
||||||
|
db.export(&file).ml("Prediction", Version::new(0, 0, 1)).await.unwrap();
|
||||||
|
db.import(&file).ml().await.unwrap();
|
||||||
|
remove_file(file).await.unwrap();
|
||||||
|
}
|
||||||
|
|
|
@ -87,8 +87,7 @@ async fn define_statement_function() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: { test: 'DEFINE FUNCTION fn::test($first: string, $last: string) { RETURN $first + $last; } PERMISSIONS FULL' },
|
functions: { test: 'DEFINE FUNCTION fn::test($first: string, $last: string) { RETURN $first + $last; } PERMISSIONS FULL' },
|
||||||
params: {},
|
models: {},
|
||||||
scopes: {},
|
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {},
|
tables: {},
|
||||||
|
@ -120,6 +119,7 @@ async fn define_statement_table_drop() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: { test: 'DEFINE TABLE test DROP SCHEMALESS PERMISSIONS NONE' },
|
tables: { test: 'DEFINE TABLE test DROP SCHEMALESS PERMISSIONS NONE' },
|
||||||
|
@ -151,6 +151,7 @@ async fn define_statement_table_schemaless() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },
|
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },
|
||||||
|
@ -186,6 +187,7 @@ async fn define_statement_table_schemafull() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
|
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
|
||||||
|
@ -217,6 +219,7 @@ async fn define_statement_table_schemaful() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
|
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
|
||||||
|
@ -256,6 +259,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {
|
tables: {
|
||||||
|
@ -288,6 +292,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {
|
tables: {
|
||||||
|
@ -1177,6 +1182,7 @@ async fn define_statement_analyzer() -> Result<(), Error> {
|
||||||
functions: {
|
functions: {
|
||||||
stripHtml: "DEFINE FUNCTION fn::stripHtml($html: string) { RETURN string::replace($html, /<[^>]*>/, ''); } PERMISSIONS FULL"
|
stripHtml: "DEFINE FUNCTION fn::stripHtml($html: string) { RETURN string::replace($html, /<[^>]*>/, ''); } PERMISSIONS FULL"
|
||||||
},
|
},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {},
|
tables: {},
|
||||||
|
@ -1496,8 +1502,8 @@ async fn permissions_checks_define_function() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1538,8 +1544,8 @@ async fn permissions_checks_define_analyzer() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1622,8 +1628,8 @@ async fn permissions_checks_define_token_db() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1748,8 +1754,8 @@ async fn permissions_checks_define_user_db() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1790,8 +1796,8 @@ async fn permissions_checks_define_scope() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1832,8 +1838,8 @@ async fn permissions_checks_define_param() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -1871,8 +1877,8 @@ async fn permissions_checks_define_table() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -2058,6 +2064,7 @@ async fn define_statement_table_permissions() -> Result<(), Error> {
|
||||||
"{
|
"{
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {
|
tables: {
|
||||||
|
|
|
@ -312,8 +312,8 @@ async fn permissions_checks_info_db() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
|
|
@ -29,6 +29,7 @@ async fn define_global_param() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: { test: 'DEFINE PARAM $test VALUE 12345 PERMISSIONS FULL' },
|
params: { test: 'DEFINE PARAM $test VALUE 12345 PERMISSIONS FULL' },
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {},
|
tables: {},
|
||||||
|
|
|
@ -39,6 +39,7 @@ async fn remove_statement_table() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {},
|
tables: {},
|
||||||
|
@ -73,6 +74,7 @@ async fn remove_statement_analyzer() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: {},
|
tables: {},
|
||||||
|
@ -222,8 +224,8 @@ async fn permissions_checks_remove_function() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -264,8 +266,8 @@ async fn permissions_checks_remove_analyzer() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -348,8 +350,8 @@ async fn permissions_checks_remove_db_token() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -474,8 +476,8 @@ async fn permissions_checks_remove_db_user() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -516,8 +518,8 @@ async fn permissions_checks_remove_scope() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -558,8 +560,8 @@ async fn permissions_checks_remove_param() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
@ -600,8 +602,8 @@ async fn permissions_checks_remove_table() {
|
||||||
|
|
||||||
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
// Define the expected results for the check statement when the test statement succeeded and when it failed
|
||||||
let check_results = [
|
let check_results = [
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
|
||||||
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
|
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
|
||||||
];
|
];
|
||||||
|
|
||||||
let test_cases = [
|
let test_cases = [
|
||||||
|
|
|
@ -255,6 +255,7 @@ async fn loose_mode_all_ok() -> Result<(), Error> {
|
||||||
analyzers: {},
|
analyzers: {},
|
||||||
tokens: {},
|
tokens: {},
|
||||||
functions: {},
|
functions: {},
|
||||||
|
models: {},
|
||||||
params: {},
|
params: {},
|
||||||
scopes: {},
|
scopes: {},
|
||||||
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },
|
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },
|
||||||
|
|
|
@ -10,7 +10,7 @@ use tokio::io::{self, AsyncWriteExt};
|
||||||
|
|
||||||
#[derive(Args, Debug)]
|
#[derive(Args, Debug)]
|
||||||
pub struct ExportCommandArguments {
|
pub struct ExportCommandArguments {
|
||||||
#[arg(help = "Path to the sql file to export. Use dash - to write into stdout.")]
|
#[arg(help = "Path to the SurrealQL file to export. Use dash - to write into stdout.")]
|
||||||
#[arg(default_value = "-")]
|
#[arg(default_value = "-")]
|
||||||
#[arg(index = 1)]
|
#[arg(index = 1)]
|
||||||
file: String,
|
file: String,
|
||||||
|
@ -87,7 +87,7 @@ pub async fn init(
|
||||||
} else {
|
} else {
|
||||||
client.export(file).await?;
|
client.export(file).await?;
|
||||||
}
|
}
|
||||||
info!("The SQL file was exported successfully");
|
info!("The SurrealQL file was exported successfully");
|
||||||
// Everything OK
|
// Everything OK
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ use surrealdb::opt::Config;
|
||||||
|
|
||||||
#[derive(Args, Debug)]
|
#[derive(Args, Debug)]
|
||||||
pub struct ImportCommandArguments {
|
pub struct ImportCommandArguments {
|
||||||
#[arg(help = "Path to the sql file to import")]
|
#[arg(help = "Path to the SurrealQL file to import")]
|
||||||
#[arg(index = 1)]
|
#[arg(index = 1)]
|
||||||
file: String,
|
file: String,
|
||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
|
@ -75,7 +75,7 @@ pub async fn init(
|
||||||
client.use_ns(namespace).use_db(database).await?;
|
client.use_ns(namespace).use_db(database).await?;
|
||||||
// Import the data into the database
|
// Import the data into the database
|
||||||
client.import(file).await?;
|
client.import(file).await?;
|
||||||
info!("The SQL file was imported successfully");
|
info!("The SurrealQL file was imported successfully");
|
||||||
// Everything OK
|
// Everything OK
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
117
src/cli/ml/export.rs
Normal file
117
src/cli/ml/export.rs
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel};
|
||||||
|
use crate::cli::abstraction::{
|
||||||
|
AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments,
|
||||||
|
};
|
||||||
|
use crate::err::Error;
|
||||||
|
use clap::Args;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use surrealdb::engine::any::{connect, IntoEndpoint};
|
||||||
|
use tokio::io::{self, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Args, Debug)]
|
||||||
|
pub struct ModelArguments {
|
||||||
|
#[arg(help = "The name of the model")]
|
||||||
|
#[arg(env = "SURREAL_NAME", long = "name")]
|
||||||
|
pub(crate) name: String,
|
||||||
|
#[arg(help = "The version of the model")]
|
||||||
|
#[arg(env = "SURREAL_VERSION", long = "version")]
|
||||||
|
pub(crate) version: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Args, Debug)]
|
||||||
|
pub struct ExportCommandArguments {
|
||||||
|
#[arg(help = "Path to the SurrealML file to export. Use dash - to write into stdout.")]
|
||||||
|
#[arg(default_value = "-")]
|
||||||
|
#[arg(index = 1)]
|
||||||
|
file: String,
|
||||||
|
|
||||||
|
#[command(flatten)]
|
||||||
|
model: ModelArguments,
|
||||||
|
#[command(flatten)]
|
||||||
|
conn: DatabaseConnectionArguments,
|
||||||
|
#[command(flatten)]
|
||||||
|
auth: AuthArguments,
|
||||||
|
#[command(flatten)]
|
||||||
|
sel: DatabaseSelectionArguments,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(
|
||||||
|
ExportCommandArguments {
|
||||||
|
file,
|
||||||
|
model: ModelArguments {
|
||||||
|
name,
|
||||||
|
version,
|
||||||
|
},
|
||||||
|
conn: DatabaseConnectionArguments {
|
||||||
|
endpoint,
|
||||||
|
},
|
||||||
|
auth: AuthArguments {
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
auth_level,
|
||||||
|
},
|
||||||
|
sel: DatabaseSelectionArguments {
|
||||||
|
namespace,
|
||||||
|
database,
|
||||||
|
},
|
||||||
|
}: ExportCommandArguments,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
// Initialize opentelemetry and logging
|
||||||
|
crate::telemetry::builder().with_log_level("error").init();
|
||||||
|
|
||||||
|
// If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate.
|
||||||
|
// If we are connecting directly to a datastore (i.e. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled.
|
||||||
|
let client = if username.is_some()
|
||||||
|
&& password.is_some()
|
||||||
|
&& !endpoint.clone().into_endpoint()?.parse_kind()?.is_local()
|
||||||
|
{
|
||||||
|
debug!("Connecting to the database engine with authentication");
|
||||||
|
let creds = CredentialsBuilder::default()
|
||||||
|
.with_username(username.as_deref())
|
||||||
|
.with_password(password.as_deref())
|
||||||
|
.with_namespace(namespace.as_str())
|
||||||
|
.with_database(database.as_str());
|
||||||
|
|
||||||
|
let client = connect(endpoint).await?;
|
||||||
|
|
||||||
|
debug!("Signing in to the database engine at '{:?}' level", auth_level);
|
||||||
|
match auth_level {
|
||||||
|
CredentialsLevel::Root => client.signin(creds.root()?).await?,
|
||||||
|
CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?,
|
||||||
|
CredentialsLevel::Database => client.signin(creds.database()?).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
client
|
||||||
|
} else {
|
||||||
|
debug!("Connecting to the database engine without authentication");
|
||||||
|
connect(endpoint).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse model version
|
||||||
|
let version = match version.parse() {
|
||||||
|
Ok(version) => version,
|
||||||
|
Err(_) => {
|
||||||
|
return Err(Error::Other(format!("`{version}` is not a valid semantic version")));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Use the specified namespace / database
|
||||||
|
client.use_ns(namespace).use_db(database).await?;
|
||||||
|
// Export the data from the database
|
||||||
|
debug!("Exporting data from the database");
|
||||||
|
if file == "-" {
|
||||||
|
// Prepare the backup
|
||||||
|
let mut backup = client.export(()).ml(&name, version).await?;
|
||||||
|
// Get a handle to standard output
|
||||||
|
let mut stdout = io::stdout();
|
||||||
|
// Write the backup to standard output
|
||||||
|
while let Some(bytes) = backup.next().await {
|
||||||
|
stdout.write_all(&bytes?).await?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client.export(file).ml(&name, version).await?;
|
||||||
|
}
|
||||||
|
info!("The SurrealML file was exported successfully");
|
||||||
|
// Everything OK
|
||||||
|
Ok(())
|
||||||
|
}
|
81
src/cli/ml/import.rs
Normal file
81
src/cli/ml/import.rs
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel};
|
||||||
|
use crate::cli::abstraction::{
|
||||||
|
AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments,
|
||||||
|
};
|
||||||
|
use crate::err::Error;
|
||||||
|
use clap::Args;
|
||||||
|
use surrealdb::dbs::Capabilities;
|
||||||
|
use surrealdb::engine::any::{connect, IntoEndpoint};
|
||||||
|
use surrealdb::opt::Config;
|
||||||
|
|
||||||
|
#[derive(Args, Debug)]
|
||||||
|
pub struct ImportCommandArguments {
|
||||||
|
#[arg(help = "Path to the SurrealML file to import")]
|
||||||
|
#[arg(index = 1)]
|
||||||
|
file: String,
|
||||||
|
#[command(flatten)]
|
||||||
|
conn: DatabaseConnectionArguments,
|
||||||
|
#[command(flatten)]
|
||||||
|
auth: AuthArguments,
|
||||||
|
#[command(flatten)]
|
||||||
|
sel: DatabaseSelectionArguments,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(
|
||||||
|
ImportCommandArguments {
|
||||||
|
file,
|
||||||
|
conn: DatabaseConnectionArguments {
|
||||||
|
endpoint,
|
||||||
|
},
|
||||||
|
auth: AuthArguments {
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
auth_level,
|
||||||
|
},
|
||||||
|
sel: DatabaseSelectionArguments {
|
||||||
|
namespace,
|
||||||
|
database,
|
||||||
|
},
|
||||||
|
}: ImportCommandArguments,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
// Initialize opentelemetry and logging
|
||||||
|
crate::telemetry::builder().with_log_level("info").init();
|
||||||
|
// Default datastore configuration for local engines
|
||||||
|
let config = Config::new().capabilities(Capabilities::all());
|
||||||
|
|
||||||
|
// If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate.
|
||||||
|
// If we are connecting directly to a datastore (i.e. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled.
|
||||||
|
let client = if username.is_some()
|
||||||
|
&& password.is_some()
|
||||||
|
&& !endpoint.clone().into_endpoint()?.parse_kind()?.is_local()
|
||||||
|
{
|
||||||
|
debug!("Connecting to the database engine with authentication");
|
||||||
|
let creds = CredentialsBuilder::default()
|
||||||
|
.with_username(username.as_deref())
|
||||||
|
.with_password(password.as_deref())
|
||||||
|
.with_namespace(namespace.as_str())
|
||||||
|
.with_database(database.as_str());
|
||||||
|
|
||||||
|
let client = connect(endpoint).await?;
|
||||||
|
|
||||||
|
debug!("Signing in to the database engine at '{:?}' level", auth_level);
|
||||||
|
match auth_level {
|
||||||
|
CredentialsLevel::Root => client.signin(creds.root()?).await?,
|
||||||
|
CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?,
|
||||||
|
CredentialsLevel::Database => client.signin(creds.database()?).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
client
|
||||||
|
} else {
|
||||||
|
debug!("Connecting to the database engine without authentication");
|
||||||
|
connect((endpoint, config)).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Use the specified namespace / database
|
||||||
|
client.use_ns(namespace).use_db(database).await?;
|
||||||
|
// Import the data into the database
|
||||||
|
client.import(file).ml().await?;
|
||||||
|
info!("The SurrealML file was imported successfully");
|
||||||
|
// Everything OK
|
||||||
|
Ok(())
|
||||||
|
}
|
22
src/cli/ml/mod.rs
Normal file
22
src/cli/ml/mod.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
mod export;
|
||||||
|
mod import;
|
||||||
|
|
||||||
|
use self::export::ExportCommandArguments;
|
||||||
|
use self::import::ImportCommandArguments;
|
||||||
|
use crate::err::Error;
|
||||||
|
use clap::Subcommand;
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
pub enum MlCommand {
|
||||||
|
#[command(about = "Import a SurrealML model into an existing database")]
|
||||||
|
Import(ImportCommandArguments),
|
||||||
|
#[command(about = "Export a SurrealML model from an existing database")]
|
||||||
|
Export(ExportCommandArguments),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(command: MlCommand) -> Result<(), Error> {
|
||||||
|
match command {
|
||||||
|
MlCommand::Import(args) => import::init(args).await,
|
||||||
|
MlCommand::Export(args) => export::init(args).await,
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ mod config;
|
||||||
mod export;
|
mod export;
|
||||||
mod import;
|
mod import;
|
||||||
mod isready;
|
mod isready;
|
||||||
|
mod ml;
|
||||||
mod sql;
|
mod sql;
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
mod start;
|
mod start;
|
||||||
|
@ -20,6 +21,7 @@ pub use config::CF;
|
||||||
use export::ExportCommandArguments;
|
use export::ExportCommandArguments;
|
||||||
use import::ImportCommandArguments;
|
use import::ImportCommandArguments;
|
||||||
use isready::IsReadyCommandArguments;
|
use isready::IsReadyCommandArguments;
|
||||||
|
use ml::MlCommand;
|
||||||
use sql::SqlCommandArguments;
|
use sql::SqlCommandArguments;
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
use start::StartCommandArguments;
|
use start::StartCommandArguments;
|
||||||
|
@ -68,6 +70,8 @@ enum Commands {
|
||||||
Upgrade(UpgradeCommandArguments),
|
Upgrade(UpgradeCommandArguments),
|
||||||
#[command(about = "Start an SQL REPL in your terminal with pipe support")]
|
#[command(about = "Start an SQL REPL in your terminal with pipe support")]
|
||||||
Sql(SqlCommandArguments),
|
Sql(SqlCommandArguments),
|
||||||
|
#[command(subcommand, about = "Manage SurrealML models within an existing database")]
|
||||||
|
Ml(MlCommand),
|
||||||
#[command(
|
#[command(
|
||||||
about = "Check if the SurrealDB server is ready to accept connections",
|
about = "Check if the SurrealDB server is ready to accept connections",
|
||||||
visible_alias = "isready"
|
visible_alias = "isready"
|
||||||
|
@ -88,6 +92,7 @@ pub async fn init() -> ExitCode {
|
||||||
Commands::Version(args) => version::init(args).await,
|
Commands::Version(args) => version::init(args).await,
|
||||||
Commands::Upgrade(args) => upgrade::init(args).await,
|
Commands::Upgrade(args) => upgrade::init(args).await,
|
||||||
Commands::Sql(args) => sql::init(args).await,
|
Commands::Sql(args) => sql::init(args).await,
|
||||||
|
Commands::Ml(args) => ml::init(args).await,
|
||||||
Commands::IsReady(args) => isready::init(args).await,
|
Commands::IsReady(args) => isready::init(args).await,
|
||||||
Commands::Validate(args) => validate::init(args).await,
|
Commands::Validate(args) => validate::init(args).await,
|
||||||
};
|
};
|
||||||
|
|
|
@ -29,31 +29,50 @@ pub const APP_ENDPOINT: &str = "https://surrealdb.com/app";
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
pub const WEBSOCKET_PING_FREQUENCY: Duration = Duration::from_secs(5);
|
pub const WEBSOCKET_PING_FREQUENCY: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
/// Set the maximum WebSocket frame size to 16mb
|
/// What is the maximum WebSocket frame size (defaults to 16 MiB)
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
pub static WEBSOCKET_MAX_FRAME_SIZE: Lazy<usize> = Lazy::new(|| {
|
pub static WEBSOCKET_MAX_FRAME_SIZE: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 16 << 20;
|
option_env!("SURREAL_WEBSOCKET_MAX_FRAME_SIZE")
|
||||||
std::env::var("SURREAL_WEBSOCKET_MAX_FRAME_SIZE")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(16 << 20)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
/// Set the maximum WebSocket frame size to 128mb
|
/// What is the maximum WebSocket message size (defaults to 128 MiB)
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
pub static WEBSOCKET_MAX_MESSAGE_SIZE: Lazy<usize> = Lazy::new(|| {
|
pub static WEBSOCKET_MAX_MESSAGE_SIZE: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 128 << 20;
|
option_env!("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE")
|
||||||
std::env::var("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(128 << 20)
|
||||||
|
});
|
||||||
|
|
||||||
|
/// How many concurrent tasks can be handled on each WebSocket (defaults to 24)
|
||||||
|
#[cfg(feature = "has-storage")]
|
||||||
|
pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy<usize> = Lazy::new(|| {
|
||||||
|
option_env!("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
|
||||||
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
|
.unwrap_or(24)
|
||||||
|
});
|
||||||
|
|
||||||
|
/// What is the runtime thread memory stack size (defaults to 10MiB)
|
||||||
|
#[cfg(feature = "has-storage")]
|
||||||
|
pub static RUNTIME_STACK_SIZE: Lazy<usize> = Lazy::new(|| {
|
||||||
|
// Stack frames are generally larger in debug mode.
|
||||||
|
let default = if cfg!(debug_assertions) {
|
||||||
|
20 * 1024 * 1024 // 20MiB in debug mode
|
||||||
|
} else {
|
||||||
|
10 * 1024 * 1024 // 10MiB in release mode
|
||||||
|
};
|
||||||
|
option_env!("SURREAL_RUNTIME_STACK_SIZE")
|
||||||
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.unwrap_or(default)
|
.unwrap_or(default)
|
||||||
});
|
});
|
||||||
|
|
||||||
/// How many concurrent tasks can be handled in a WebSocket
|
/// How many threads which can be started for blocking operations (defaults to 512)
|
||||||
#[cfg(feature = "has-storage")]
|
#[cfg(feature = "has-storage")]
|
||||||
pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy<usize> = Lazy::new(|| {
|
pub static RUNTIME_MAX_BLOCKING_THREADS: Lazy<usize> = Lazy::new(|| {
|
||||||
let default = 24;
|
option_env!("SURREAL_RUNTIME_MAX_BLOCKING_THREADS")
|
||||||
std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
.map(|v| v.parse::<usize>().unwrap_or(default))
|
.unwrap_or(512)
|
||||||
.unwrap_or(default)
|
|
||||||
});
|
});
|
||||||
|
|
||||||
/// The version identifier of this build
|
/// The version identifier of this build
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::cli::abstraction::auth::Error as SurrealAuthError;
|
use crate::cli::abstraction::auth::Error as SurrealAuthError;
|
||||||
use axum::extract::rejection::TypedHeaderRejection;
|
use axum::extract::rejection::TypedHeaderRejection;
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::Error as AxumError;
|
||||||
use axum::Json;
|
use axum::Json;
|
||||||
use base64::DecodeError as Base64Error;
|
use base64::DecodeError as Base64Error;
|
||||||
use http::{HeaderName, StatusCode};
|
use http::{HeaderName, StatusCode};
|
||||||
|
@ -48,6 +49,9 @@ pub enum Error {
|
||||||
#[error("Couldn't open the specified file: {0}")]
|
#[error("Couldn't open the specified file: {0}")]
|
||||||
Io(#[from] IoError),
|
Io(#[from] IoError),
|
||||||
|
|
||||||
|
#[error("There was an error with the network: {0}")]
|
||||||
|
Axum(#[from] AxumError),
|
||||||
|
|
||||||
#[error("There was an error serializing to JSON: {0}")]
|
#[error("There was an error serializing to JSON: {0}")]
|
||||||
Json(#[from] JsonError),
|
Json(#[from] JsonError),
|
||||||
|
|
||||||
|
@ -60,11 +64,15 @@ pub enum Error {
|
||||||
#[error("There was an error with the remote request: {0}")]
|
#[error("There was an error with the remote request: {0}")]
|
||||||
Remote(#[from] ReqwestError),
|
Remote(#[from] ReqwestError),
|
||||||
|
|
||||||
|
#[error("There was an error with auth: {0}")]
|
||||||
|
Auth(#[from] SurrealAuthError),
|
||||||
|
|
||||||
#[error("There was an error with the node agent")]
|
#[error("There was an error with the node agent")]
|
||||||
NodeAgent,
|
NodeAgent,
|
||||||
|
|
||||||
#[error("There was an error with auth: {0}")]
|
/// Statement has been deprecated
|
||||||
Auth(#[from] SurrealAuthError),
|
#[error("{0}")]
|
||||||
|
Other(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Error> for String {
|
impl From<Error> for String {
|
||||||
|
|
11
src/main.rs
11
src/main.rs
|
@ -42,15 +42,12 @@ fn main() -> ExitCode {
|
||||||
|
|
||||||
/// Rust's default thread stack size of 2MiB doesn't allow sufficient recursion depth.
|
/// Rust's default thread stack size of 2MiB doesn't allow sufficient recursion depth.
|
||||||
fn with_enough_stack<T>(fut: impl Future<Output = T> + Send) -> T {
|
fn with_enough_stack<T>(fut: impl Future<Output = T> + Send) -> T {
|
||||||
let stack_size = 10 * 1024 * 1024;
|
// Start a Tokio runtime with custom configuration
|
||||||
|
|
||||||
// Stack frames are generally larger in debug mode.
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
let stack_size = stack_size * 2;
|
|
||||||
|
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.thread_stack_size(stack_size)
|
.max_blocking_threads(*cnf::RUNTIME_MAX_BLOCKING_THREADS)
|
||||||
|
.thread_stack_size(*cnf::RUNTIME_STACK_SIZE)
|
||||||
|
.thread_name("surrealdb-worker")
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.block_on(fut)
|
.block_on(fut)
|
||||||
|
|
|
@ -9,6 +9,9 @@ use http::StatusCode;
|
||||||
use http_body::Body as HttpBody;
|
use http_body::Body as HttpBody;
|
||||||
use hyper::body::Body;
|
use hyper::body::Body;
|
||||||
use surrealdb::dbs::Session;
|
use surrealdb::dbs::Session;
|
||||||
|
use surrealdb::iam::check::check_ns_db;
|
||||||
|
use surrealdb::iam::Action::View;
|
||||||
|
use surrealdb::iam::ResourceKind::Any;
|
||||||
|
|
||||||
pub(super) fn router<S, B>() -> Router<S, B>
|
pub(super) fn router<S, B>() -> Router<S, B>
|
||||||
where
|
where
|
||||||
|
@ -18,35 +21,27 @@ where
|
||||||
Router::new().route("/export", get(handler))
|
Router::new().route("/export", get(handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(Extension(session): Extension<Session>) -> Result<impl IntoResponse, Error> {
|
||||||
Extension(session): Extension<Session>,
|
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
// Extract the NS header value
|
|
||||||
let nsv = match session.ns.clone() {
|
|
||||||
Some(ns) => ns,
|
|
||||||
None => return Err(Error::NoNamespace),
|
|
||||||
};
|
|
||||||
// Extract the DB header value
|
|
||||||
let dbv = match session.db.clone() {
|
|
||||||
Some(db) => db,
|
|
||||||
None => return Err(Error::NoDatabase),
|
|
||||||
};
|
|
||||||
// Create a chunked response
|
// Create a chunked response
|
||||||
let (mut chn, bdy) = Body::channel();
|
let (mut chn, body) = Body::channel();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let (nsv, dbv) = check_ns_db(&session)?;
|
||||||
|
// Check the permissions level
|
||||||
|
db.check(&session, View, Any.on_db(&nsv, &dbv))?;
|
||||||
// Create a new bounded channel
|
// Create a new bounded channel
|
||||||
let (snd, rcv) = surrealdb::channel::bounded(1);
|
let (snd, rcv) = surrealdb::channel::bounded(1);
|
||||||
|
// Start the export task
|
||||||
let export_job = db.export(&session, nsv, dbv, snd).await.map_err(Error::from)?;
|
let task = db.export(&session, snd).await?;
|
||||||
// Spawn a new database export job
|
// Spawn a new database export job
|
||||||
tokio::spawn(export_job);
|
tokio::spawn(task);
|
||||||
// Process all processed values
|
// Process all chunk values
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Ok(v) = rcv.recv().await {
|
while let Ok(v) = rcv.recv().await {
|
||||||
let _ = chn.send_data(Bytes::from(v)).await;
|
let _ = chn.send_data(Bytes::from(v)).await;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// Return the chunked body
|
// Return the chunked body
|
||||||
Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap())
|
Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use super::headers::Accept;
|
||||||
use crate::dbs::DB;
|
use crate::dbs::DB;
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
use crate::net::input::bytes_to_utf8;
|
use crate::net::input::bytes_to_utf8;
|
||||||
|
@ -11,10 +12,10 @@ use axum::TypedHeader;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http_body::Body as HttpBody;
|
use http_body::Body as HttpBody;
|
||||||
use surrealdb::dbs::Session;
|
use surrealdb::dbs::Session;
|
||||||
|
use surrealdb::iam::Action::Edit;
|
||||||
|
use surrealdb::iam::ResourceKind::Any;
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
|
|
||||||
use super::headers::Accept;
|
|
||||||
|
|
||||||
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
|
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
|
||||||
|
|
||||||
pub(super) fn router<S, B>() -> Router<S, B>
|
pub(super) fn router<S, B>() -> Router<S, B>
|
||||||
|
@ -32,24 +33,26 @@ where
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
sql: Bytes,
|
sql: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
// Convert the body to a byte slice
|
// Convert the body to a byte slice
|
||||||
let sql = bytes_to_utf8(&sql)?;
|
let sql = bytes_to_utf8(&sql)?;
|
||||||
|
// Check the permissions level
|
||||||
|
db.check(&session, Edit, Any.on_level(session.au.level().to_owned()))?;
|
||||||
// Execute the sql query in the database
|
// Execute the sql query in the database
|
||||||
match db.import(sql, &session).await {
|
match db.import(sql, &session).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
|
||||||
// Internal serialization
|
|
||||||
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
|
||||||
// Return nothing
|
// Return nothing
|
||||||
Some(Accept::ApplicationOctetStream) => Ok(output::none()),
|
Some(Accept::ApplicationOctetStream) => Ok(output::none()),
|
||||||
|
// Internal serialization
|
||||||
|
Some(Accept::Surrealdb) => Ok(output::full(&res)),
|
||||||
// An incorrect content-type was requested
|
// An incorrect content-type was requested
|
||||||
_ => Err(Error::InvalidType),
|
_ => Err(Error::InvalidType),
|
||||||
},
|
},
|
||||||
|
|
|
@ -13,6 +13,7 @@ use http_body::Body as HttpBody;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::str;
|
use std::str;
|
||||||
use surrealdb::dbs::Session;
|
use surrealdb::dbs::Session;
|
||||||
|
use surrealdb::iam::check::check_ns_db;
|
||||||
use surrealdb::sql::Value;
|
use surrealdb::sql::Value;
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
|
|
||||||
|
@ -68,12 +69,14 @@ where
|
||||||
|
|
||||||
async fn select_all(
|
async fn select_all(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path(table): Path<String>,
|
Path(table): Path<String>,
|
||||||
Query(query): Query<QueryOptions>,
|
Query(query): Query<QueryOptions>,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Specify the request statement
|
// Specify the request statement
|
||||||
let sql = match query.fields {
|
let sql = match query.fields {
|
||||||
None => "SELECT * FROM type::table($table) LIMIT $limit START $start",
|
None => "SELECT * FROM type::table($table) LIMIT $limit START $start",
|
||||||
|
@ -88,7 +91,7 @@ async fn select_all(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -105,13 +108,15 @@ async fn select_all(
|
||||||
|
|
||||||
async fn create_all(
|
async fn create_all(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path(table): Path<String>,
|
Path(table): Path<String>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the request body as JSON
|
// Parse the request body as JSON
|
||||||
|
@ -127,7 +132,7 @@ async fn create_all(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -147,13 +152,15 @@ async fn create_all(
|
||||||
|
|
||||||
async fn update_all(
|
async fn update_all(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path(table): Path<String>,
|
Path(table): Path<String>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the request body as JSON
|
// Parse the request body as JSON
|
||||||
|
@ -169,7 +176,7 @@ async fn update_all(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -189,13 +196,15 @@ async fn update_all(
|
||||||
|
|
||||||
async fn modify_all(
|
async fn modify_all(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path(table): Path<String>,
|
Path(table): Path<String>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the request body as JSON
|
// Parse the request body as JSON
|
||||||
|
@ -211,7 +220,7 @@ async fn modify_all(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -231,12 +240,14 @@ async fn modify_all(
|
||||||
|
|
||||||
async fn delete_all(
|
async fn delete_all(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path(table): Path<String>,
|
Path(table): Path<String>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Specify the request statement
|
// Specify the request statement
|
||||||
let sql = "DELETE type::table($table) RETURN BEFORE";
|
let sql = "DELETE type::table($table) RETURN BEFORE";
|
||||||
// Specify the request variables
|
// Specify the request variables
|
||||||
|
@ -246,7 +257,7 @@ async fn delete_all(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -267,12 +278,14 @@ async fn delete_all(
|
||||||
|
|
||||||
async fn select_one(
|
async fn select_one(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path((table, id)): Path<(String, String)>,
|
Path((table, id)): Path<(String, String)>,
|
||||||
Query(query): Query<QueryOptions>,
|
Query(query): Query<QueryOptions>,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Specify the request statement
|
// Specify the request statement
|
||||||
let sql = match query.fields {
|
let sql = match query.fields {
|
||||||
None => "SELECT * FROM type::thing($table, $id)",
|
None => "SELECT * FROM type::thing($table, $id)",
|
||||||
|
@ -291,7 +304,7 @@ async fn select_one(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -308,13 +321,15 @@ async fn select_one(
|
||||||
|
|
||||||
async fn create_one(
|
async fn create_one(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
Path((table, id)): Path<(String, String)>,
|
Path((table, id)): Path<(String, String)>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the Record ID as a SurrealQL value
|
// Parse the Record ID as a SurrealQL value
|
||||||
|
@ -336,7 +351,7 @@ async fn create_one(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -356,13 +371,15 @@ async fn create_one(
|
||||||
|
|
||||||
async fn update_one(
|
async fn update_one(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
Path((table, id)): Path<(String, String)>,
|
Path((table, id)): Path<(String, String)>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the Record ID as a SurrealQL value
|
// Parse the Record ID as a SurrealQL value
|
||||||
|
@ -384,7 +401,7 @@ async fn update_one(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -404,13 +421,15 @@ async fn update_one(
|
||||||
|
|
||||||
async fn modify_one(
|
async fn modify_one(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
Path((table, id)): Path<(String, String)>,
|
Path((table, id)): Path<(String, String)>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Convert the HTTP request body
|
// Convert the HTTP request body
|
||||||
let data = bytes_to_utf8(&body)?;
|
let data = bytes_to_utf8(&body)?;
|
||||||
// Parse the Record ID as a SurrealQL value
|
// Parse the Record ID as a SurrealQL value
|
||||||
|
@ -432,7 +451,7 @@ async fn modify_one(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
@ -452,11 +471,13 @@ async fn modify_one(
|
||||||
|
|
||||||
async fn delete_one(
|
async fn delete_one(
|
||||||
Extension(session): Extension<Session>,
|
Extension(session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
Path((table, id)): Path<(String, String)>,
|
Path((table, id)): Path<(String, String)>,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get the datastore reference
|
// Get the datastore reference
|
||||||
let db = DB.get().unwrap();
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let _ = check_ns_db(&session)?;
|
||||||
// Specify the request statement
|
// Specify the request statement
|
||||||
let sql = "DELETE type::thing($table, $id) RETURN BEFORE";
|
let sql = "DELETE type::thing($table, $id) RETURN BEFORE";
|
||||||
// Parse the Record ID as a SurrealQL value
|
// Parse the Record ID as a SurrealQL value
|
||||||
|
@ -471,7 +492,7 @@ async fn delete_one(
|
||||||
};
|
};
|
||||||
// Execute the query and return the result
|
// Execute the query and return the result
|
||||||
match db.execute(sql, &session, Some(vars)).await {
|
match db.execute(sql, &session, Some(vars)).await {
|
||||||
Ok(res) => match maybe_output.as_deref() {
|
Ok(res) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
|
||||||
|
|
123
src/net/ml.rs
Normal file
123
src/net/ml.rs
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
//! This file defines the endpoints for the ML API for importing and exporting SurrealML models.
|
||||||
|
use crate::dbs::DB;
|
||||||
|
use crate::err::Error;
|
||||||
|
use crate::net::output;
|
||||||
|
use axum::extract::{BodyStream, DefaultBodyLimit, Path};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum::response::Response;
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use axum::Extension;
|
||||||
|
use axum::Router;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use http::StatusCode;
|
||||||
|
use http_body::Body as HttpBody;
|
||||||
|
use hyper::body::Body;
|
||||||
|
use surrealdb::dbs::Session;
|
||||||
|
use surrealdb::iam::check::check_ns_db;
|
||||||
|
use surrealdb::iam::Action::{Edit, View};
|
||||||
|
use surrealdb::iam::ResourceKind::Model;
|
||||||
|
use surrealdb::kvs::{LockType::Optimistic, TransactionType::Read};
|
||||||
|
use surrealdb::sql::statements::{DefineModelStatement, DefineStatement};
|
||||||
|
use surrealml_core::storage::surml_file::SurMlFile;
|
||||||
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
|
|
||||||
|
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
|
||||||
|
|
||||||
|
/// The router definition for the ML API endpoints.
|
||||||
|
pub(super) fn router<S, B>() -> Router<S, B>
|
||||||
|
where
|
||||||
|
B: HttpBody + Send + 'static,
|
||||||
|
B::Data: Send + Into<Bytes>,
|
||||||
|
B::Error: std::error::Error + Send + Sync + 'static,
|
||||||
|
S: Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
Router::new()
|
||||||
|
.route("/ml/import", post(import))
|
||||||
|
.route("/ml/export/:name/:version", get(export))
|
||||||
|
.route_layer(DefaultBodyLimit::disable())
|
||||||
|
.layer(RequestBodyLimitLayer::new(MAX))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This endpoint allows the user to import a model into the database.
|
||||||
|
async fn import(
|
||||||
|
Extension(session): Extension<Session>,
|
||||||
|
mut stream: BodyStream,
|
||||||
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
|
// Get the datastore reference
|
||||||
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let (nsv, dbv) = check_ns_db(&session)?;
|
||||||
|
// Check the permissions level
|
||||||
|
db.check(&session, Edit, Model.on_db(&nsv, &dbv))?;
|
||||||
|
// Create a new buffer
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
// Load all the uploaded file chunks
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
buffer.extend_from_slice(&chunk?);
|
||||||
|
}
|
||||||
|
// Check that the SurrealML file is valid
|
||||||
|
let file = match SurMlFile::from_bytes(buffer) {
|
||||||
|
Ok(file) => file,
|
||||||
|
Err(err) => return Err(Error::Other(err.to_string())),
|
||||||
|
};
|
||||||
|
// Convert the file back in to raw bytes
|
||||||
|
let data = file.to_bytes();
|
||||||
|
// Calculate the hash of the model file
|
||||||
|
let hash = surrealdb::obs::hash(&data);
|
||||||
|
// Calculate the path of the model file
|
||||||
|
let path = format!(
|
||||||
|
"ml/{nsv}/{dbv}/{}-{}-{hash}.surml",
|
||||||
|
file.header.name.to_string(),
|
||||||
|
file.header.version.to_string()
|
||||||
|
);
|
||||||
|
// Insert the file data in to the store
|
||||||
|
surrealdb::obs::put(&path, data).await?;
|
||||||
|
// Insert the model in to the database
|
||||||
|
db.process(
|
||||||
|
DefineStatement::Model(DefineModelStatement {
|
||||||
|
hash,
|
||||||
|
name: file.header.name.to_string().into(),
|
||||||
|
version: file.header.version.to_string(),
|
||||||
|
comment: Some(file.header.description.to_string().into()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
&session,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
//
|
||||||
|
Ok(output::none())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This endpoint allows the user to export a model from the database.
|
||||||
|
async fn export(
|
||||||
|
Extension(session): Extension<Session>,
|
||||||
|
Path((name, version)): Path<(String, String)>,
|
||||||
|
) -> Result<impl IntoResponse, Error> {
|
||||||
|
// Get the datastore reference
|
||||||
|
let db = DB.get().unwrap();
|
||||||
|
// Ensure a NS and DB are set
|
||||||
|
let (nsv, dbv) = check_ns_db(&session)?;
|
||||||
|
// Check the permissions level
|
||||||
|
db.check(&session, View, Model.on_db(&nsv, &dbv))?;
|
||||||
|
// Start a new readonly transaction
|
||||||
|
let mut tx = db.transaction(Read, Optimistic).await?;
|
||||||
|
// Attempt to get the model definition
|
||||||
|
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
|
||||||
|
// Calculate the path of the model file
|
||||||
|
let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{}.surml", info.hash);
|
||||||
|
// Export the file data in to the store
|
||||||
|
let mut data = surrealdb::obs::stream(path).await?;
|
||||||
|
// Create a chunked response
|
||||||
|
let (mut chn, body) = Body::channel();
|
||||||
|
// Process all stream values
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(Ok(v)) = data.next().await {
|
||||||
|
let _ = chn.send_data(v).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Return the streamed body
|
||||||
|
Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
|
||||||
|
}
|
|
@ -17,6 +17,9 @@ mod sync;
|
||||||
mod tracer;
|
mod tracer;
|
||||||
mod version;
|
mod version;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
mod ml;
|
||||||
|
|
||||||
use axum::response::Redirect;
|
use axum::response::Redirect;
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use axum::{middleware, Router};
|
use axum::{middleware, Router};
|
||||||
|
@ -150,8 +153,12 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
|
||||||
.merge(sql::router())
|
.merge(sql::router())
|
||||||
.merge(signin::router())
|
.merge(signin::router())
|
||||||
.merge(signup::router())
|
.merge(signup::router())
|
||||||
.merge(key::router())
|
.merge(key::router());
|
||||||
.layer(service);
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
let axum_app = axum_app.merge(ml::router());
|
||||||
|
|
||||||
|
let axum_app = axum_app.layer(service);
|
||||||
|
|
||||||
// Setup the graceful shutdown
|
// Setup the graceful shutdown
|
||||||
let handle = Handle::new();
|
let handle = Handle::new();
|
||||||
|
|
|
@ -51,7 +51,7 @@ where
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(
|
||||||
Extension(mut session): Extension<Session>,
|
Extension(mut session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get a database reference
|
// Get a database reference
|
||||||
|
@ -65,15 +65,15 @@ async fn handler(
|
||||||
match surrealdb::iam::signin::signin(kvs, &mut session, vars).await.map_err(Error::from)
|
match surrealdb::iam::signin::signin(kvs, &mut session, vars).await.map_err(Error::from)
|
||||||
{
|
{
|
||||||
// Authentication was successful
|
// Authentication was successful
|
||||||
Ok(v) => match maybe_output.as_deref() {
|
Ok(v) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
||||||
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
||||||
// Internal serialization
|
|
||||||
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
|
||||||
// Text serialization
|
// Text serialization
|
||||||
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
||||||
|
// Internal serialization
|
||||||
|
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
||||||
// Return nothing
|
// Return nothing
|
||||||
None => Ok(output::none()),
|
None => Ok(output::none()),
|
||||||
// An incorrect content-type was requested
|
// An incorrect content-type was requested
|
||||||
|
|
|
@ -49,7 +49,7 @@ where
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(
|
||||||
Extension(mut session): Extension<Session>,
|
Extension(mut session): Extension<Session>,
|
||||||
maybe_output: Option<TypedHeader<Accept>>,
|
accept: Option<TypedHeader<Accept>>,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
// Get a database reference
|
// Get a database reference
|
||||||
|
@ -63,15 +63,15 @@ async fn handler(
|
||||||
match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from)
|
match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from)
|
||||||
{
|
{
|
||||||
// Authentication was successful
|
// Authentication was successful
|
||||||
Ok(v) => match maybe_output.as_deref() {
|
Ok(v) => match accept.as_deref() {
|
||||||
// Simple serialization
|
// Simple serialization
|
||||||
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
|
||||||
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
|
||||||
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
|
||||||
// Internal serialization
|
|
||||||
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
|
||||||
// Text serialization
|
// Text serialization
|
||||||
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
|
||||||
|
// Internal serialization
|
||||||
|
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
|
||||||
// Return nothing
|
// Return nothing
|
||||||
None => Ok(output::none()),
|
None => Ok(output::none()),
|
||||||
// An incorrect content-type was requested
|
// An incorrect content-type was requested
|
||||||
|
|
BIN
tests/linear_test.surml
Normal file
BIN
tests/linear_test.surml
Normal file
Binary file not shown.
149
tests/ml_integration.rs
Normal file
149
tests/ml_integration.rs
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
// RUST_LOG=warn cargo make ci-ml-integration
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
mod ml_integration {
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use http::{header, StatusCode};
|
||||||
|
use hyper::Body;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::time::Duration;
|
||||||
|
use surrealml_core::storage::stream_adapter::StreamAdapter;
|
||||||
|
use test_log::test;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
static LOCK: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
struct Data {
|
||||||
|
result: f64,
|
||||||
|
status: String,
|
||||||
|
time: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LockHandle;
|
||||||
|
|
||||||
|
impl LockHandle {
|
||||||
|
fn acquire_lock() -> Self {
|
||||||
|
while LOCK.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
|
||||||
|
!= Ok(false)
|
||||||
|
{
|
||||||
|
std::thread::sleep(Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
LockHandle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for LockHandle {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
LOCK.store(false, Ordering::Release);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string());
|
||||||
|
let body = Body::wrap_stream(generator);
|
||||||
|
// Prepare HTTP client
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert("NS", ns.parse()?);
|
||||||
|
headers.insert("DB", db.parse()?);
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(Duration::from_secs(1))
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()?;
|
||||||
|
// Send HTTP request
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{addr}/ml/import"))
|
||||||
|
.basic_auth(common::USER, Some(common::PASS))
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
// Check response code
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn upload_model() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let _lock = LockHandle::acquire_lock();
|
||||||
|
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||||
|
let ns = Ulid::new().to_string();
|
||||||
|
let db = Ulid::new().to_string();
|
||||||
|
upload_file(&addr, &ns, &db).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let _lock = LockHandle::acquire_lock();
|
||||||
|
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||||
|
|
||||||
|
let ns = Ulid::new().to_string();
|
||||||
|
let db = Ulid::new().to_string();
|
||||||
|
|
||||||
|
upload_file(&addr, &ns, &db).await?;
|
||||||
|
|
||||||
|
// Prepare HTTP client
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert("NS", ns.parse()?);
|
||||||
|
headers.insert("DB", db.parse()?);
|
||||||
|
headers.insert(header::ACCEPT, "application/json".parse()?);
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(Duration::from_millis(10))
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
// perform an SQL query to check if the model is available
|
||||||
|
{
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{addr}/sql"))
|
||||||
|
.basic_auth(common::USER, Some(common::PASS))
|
||||||
|
.body(r#"ml::Prediction<0.0.1>([1.0, 1.0]);"#)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||||
|
let body = res.text().await?;
|
||||||
|
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
|
||||||
|
assert_eq!(deserialized_data[0].result, 0.9998061656951904);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn buffered_compute() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let _lock = LockHandle::acquire_lock();
|
||||||
|
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||||
|
|
||||||
|
let ns = Ulid::new().to_string();
|
||||||
|
let db = Ulid::new().to_string();
|
||||||
|
|
||||||
|
upload_file(&addr, &ns, &db).await?;
|
||||||
|
|
||||||
|
// Prepare HTTP client
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert("NS", ns.parse()?);
|
||||||
|
headers.insert("DB", db.parse()?);
|
||||||
|
headers.insert(header::ACCEPT, "application/json".parse()?);
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(Duration::from_millis(10))
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
// perform an SQL query to check if the model is available
|
||||||
|
{
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{addr}/sql"))
|
||||||
|
.basic_auth(common::USER, Some(common::PASS))
|
||||||
|
.body(r#"ml::Prediction<0.0.1>({squarefoot: 500.0, num_floors: 1.0});"#)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||||
|
let body = res.text().await?;
|
||||||
|
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
|
||||||
|
assert_eq!(deserialized_data[0].result, 177206.21875);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue