From 2ae841679111686eb4a0113861c9bd74ea74636d Mon Sep 17 00:00:00 2001 From: Maxwell Flitton Date: Tue, 12 Dec 2023 13:51:43 +0000 Subject: [PATCH] Add support for ML model storage and execution (#3015) --- .github/workflows/ci.yml | 28 +++ .gitignore | 2 + Cargo.lock | 187 ++++++++++++++++++++- Cargo.toml | 5 +- Makefile.ci.toml | 6 + Makefile.toml | 2 +- lib/Cargo.toml | 5 + lib/src/api/conn.rs | 11 ++ lib/src/api/engine/local/mod.rs | 137 ++++++++++++--- lib/src/api/engine/remote/http/mod.rs | 18 +- lib/src/api/method/export.rs | 46 ++++- lib/src/api/method/import.rs | 32 +++- lib/src/api/method/mod.rs | 7 + lib/src/api/opt/mod.rs | 2 +- lib/src/cnf/mod.rs | 7 +- lib/src/err/mod.rs | 21 +++ lib/src/iam/check.rs | 17 ++ lib/src/iam/entities/resources/resource.rs | 2 + lib/src/iam/mod.rs | 1 + lib/src/key/database/ml.rs | 89 ++++++++++ lib/src/key/database/mod.rs | 1 + lib/src/key/error.rs | 3 + lib/src/kvs/cache.rs | 3 + lib/src/kvs/ds.rs | 34 ++-- lib/src/kvs/rocksdb/cnf.rs | 63 +++---- lib/src/kvs/speedb/cnf.rs | 63 +++---- lib/src/kvs/tx.rs | 64 +++++++ lib/src/lib.rs | 4 + lib/src/obs/mod.rs | 96 +++++++++++ lib/src/sql/function.rs | 11 +- lib/src/sql/model.rs | 187 +++++++++++++++++++-- lib/src/sql/query.rs | 20 ++- lib/src/sql/statements/define/mod.rs | 6 +- lib/src/sql/statements/define/model.rs | 63 ++++--- lib/src/sql/statements/info.rs | 6 + lib/src/sql/statements/mod.rs | 12 +- lib/src/sql/statements/remove/mod.rs | 5 + lib/src/sql/statements/remove/model.rs | 47 ++++++ lib/src/sql/value/value.rs | 19 ++- lib/tests/api.rs | 1 + lib/tests/api/backup.rs | 14 ++ lib/tests/define.rs | 39 +++-- lib/tests/info.rs | 4 +- lib/tests/param.rs | 1 + lib/tests/remove.rs | 30 ++-- lib/tests/strict.rs | 1 + src/cli/export.rs | 4 +- src/cli/import.rs | 4 +- src/cli/ml/export.rs | 117 +++++++++++++ src/cli/ml/import.rs | 81 +++++++++ src/cli/ml/mod.rs | 22 +++ src/cli/mod.rs | 5 + src/cnf/mod.rs | 49 ++++-- src/err/mod.rs | 12 +- src/main.rs | 11 +- src/net/export.rs | 33 ++-- src/net/import.rs | 15 +- src/net/key.rs | 61 ++++--- src/net/ml.rs | 123 ++++++++++++++ src/net/mod.rs | 11 +- src/net/signin.rs | 8 +- src/net/signup.rs | 8 +- tests/linear_test.surml | Bin 0 -> 552 bytes tests/ml_integration.rs | 149 ++++++++++++++++ 64 files changed, 1815 insertions(+), 320 deletions(-) create mode 100644 lib/src/iam/check.rs create mode 100644 lib/src/key/database/ml.rs create mode 100644 lib/src/obs/mod.rs create mode 100644 lib/src/sql/statements/remove/model.rs create mode 100644 src/cli/ml/export.rs create mode 100644 src/cli/ml/import.rs create mode 100644 src/cli/ml/mod.rs create mode 100644 src/net/ml.rs create mode 100644 tests/linear_test.surml create mode 100644 tests/ml_integration.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4dbad463..f1e99b60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -198,6 +198,34 @@ jobs: - name: Run HTTP integration tests run: cargo make ci-http-integration + + ml-support: + name: ML integration tests + runs-on: ubuntu-latest + steps: + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: 1.71.1 + + - name: Checkout sources + uses: actions/checkout@v3 + + - name: Setup cache + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + + - name: Install dependencies + run: | + sudo apt-get -y update + + - name: Install cargo-make + run: cargo install --debug --locked cargo-make + + - name: Run ML integration tests + run: cargo make ci-ml-integration ws-server: name: WebSocket integration tests diff --git a/.gitignore b/.gitignore index 804d8dcd..96c2b960 100644 --- a/.gitignore +++ b/.gitignore @@ -45,5 +45,7 @@ Temporary Items # Specific # ----------------------------------- +/cache/ +/store/ surreal history.txt diff --git a/Cargo.lock b/Cargo.lock index 20338b21..1b852d09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1813,6 +1813,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "filetime" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4029edd3e734da6fe05b6cd7bd2960760a616bd2ddd0d59a0124746d6272af0" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.3.5", + "windows-sys 0.48.0", +] + [[package]] name = "findshlibs" version = "0.10.2" @@ -2917,7 +2929,7 @@ checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ "bitflags 2.4.1", "libc", - "redox_syscall", + "redox_syscall 0.4.1", ] [[package]] @@ -3036,6 +3048,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -3181,6 +3203,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "new_debug_unreachable" version = "1.0.4" @@ -3272,6 +3307,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + [[package]] name = "num-format" version = "0.4.4" @@ -3332,6 +3376,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "humantime", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.18.0" @@ -3467,6 +3532,24 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "ort" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "889dca4c98efa21b1ba54ddb2bde44fd4920d910f492b618351f839d8428d79d" +dependencies = [ + "flate2", + "lazy_static", + "libc", + "libloading", + "ndarray", + "tar", + "thiserror", + "tracing", + "vswhom", + "winapi", +] + [[package]] name = "overload" version = "0.1.1" @@ -3497,7 +3580,7 @@ checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.4.1", "smallvec", "windows-targets 0.48.5", ] @@ -4017,6 +4100,12 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.8.0" @@ -4049,6 +4138,15 @@ dependencies = [ "yasna", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -4950,6 +5048,28 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "snap" version = "1.1.0" @@ -5120,6 +5240,7 @@ dependencies = [ "ipnet", "jemallocator", "mimalloc", + "ndarray", "nix 0.27.1", "once_cell", "opentelemetry", @@ -5136,6 +5257,7 @@ dependencies = [ "serde_json", "serial_test", "surrealdb", + "surrealml-core", "temp-env", "tempfile", "test-log", @@ -5184,6 +5306,7 @@ dependencies = [ "futures-concurrency", "fuzzy-matcher", "geo 0.27.0", + "hex", "indexmap 2.1.0", "indxdb", "ipnet", @@ -5192,8 +5315,10 @@ dependencies = [ "md-5", "nanoid", "native-tls", + "ndarray", "nom", "num_cpus", + "object_store", "once_cell", "path-clean", "pbkdf2", @@ -5224,6 +5349,7 @@ dependencies = [ "surrealdb-derive", "surrealdb-jsonwebtoken", "surrealdb-tikv-client", + "surrealml-core", "temp-dir", "test-log", "thiserror", @@ -5299,6 +5425,21 @@ dependencies = [ "tonic 0.9.2", ] +[[package]] +name = "surrealml-core" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2aabe3c44a73d6f7a3d069400a4966d2ffcf643fb1d60399b1746787abc9dd12" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "ndarray", + "once_cell", + "ort", + "regex", +] + [[package]] name = "symbolic-common" version = "12.7.0" @@ -5389,6 +5530,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tar" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "temp-dir" version = "0.1.11" @@ -5413,7 +5565,7 @@ checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" dependencies = [ "cfg-if", "fastrand 2.0.1", - "redox_syscall", + "redox_syscall 0.4.1", "rustix", "windows-sys 0.48.0", ] @@ -6103,6 +6255,26 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "vswhom" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b" +dependencies = [ + "libc", + "vswhom-sys", +] + +[[package]] +name = "vswhom-sys" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "waker-fn" version = "1.1.1" @@ -6501,6 +6673,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xattr" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985" +dependencies = [ + "libc", +] + [[package]] name = "xml-rs" version = "0.8.19" diff --git a/Cargo.toml b/Cargo.toml index cfa62147..3cda9b3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Tobie Morgan Hitchcock "] [features] # Public features -default = ["storage-mem", "storage-rocksdb", "scripting", "http"] +default = ["storage-mem", "storage-rocksdb", "scripting", "http", "ml"] storage-mem = ["surrealdb/kv-mem", "has-storage"] storage-rocksdb = ["surrealdb/kv-rocksdb", "has-storage"] storage-speedb = ["surrealdb/kv-speedb", "has-storage"] @@ -17,6 +17,7 @@ storage-fdb = ["surrealdb/kv-fdb-7_1", "has-storage"] scripting = ["surrealdb/scripting"] http = ["surrealdb/http"] http-compression = [] +ml = ["surrealdb/ml", "surrealml-core"] # Private features has-storage = [] @@ -49,6 +50,7 @@ http = "0.2.11" http-body = "0.4.5" hyper = "0.14.27" ipnet = "2.9.0" +ndarray = { version = "0.15.6", optional = true } once_cell = "1.18.0" opentelemetry = { version = "0.19", features = ["rt-tokio"] } opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] } @@ -61,6 +63,7 @@ serde_cbor = "0.11.2" serde_json = "1.0.108" serde_pack = { version = "1.1.2", package = "rmp-serde" } surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] } +surrealml-core = { version = "0.0.2", optional = true} tempfile = "3.8.1" thiserror = "1.0.50" tokio = { version = "1.34.0", features = ["macros", "signal"] } diff --git a/Makefile.ci.toml b/Makefile.ci.toml index 905a3789..bfbfa208 100644 --- a/Makefile.ci.toml +++ b/Makefile.ci.toml @@ -39,6 +39,12 @@ command = "cargo" env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"] +[tasks.ci-ml-integration] +category = "CI - INTEGRATION TESTS" +command = "cargo" +env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } } +args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"] + [tasks.ci-workspace-coverage] category = "CI - INTEGRATION TESTS" command = "cargo" diff --git a/Makefile.toml b/Makefile.toml index b10ce6ad..05952d12 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -10,7 +10,7 @@ reduce_output = true default_to_workspace = false [env] -DEV_FEATURES={ value = "storage-mem,http,scripting", condition = { env_not_set = ["DEV_FEATURES"] } } +DEV_FEATURES={ value = "storage-mem,scripting,http,ml", condition = { env_not_set = ["DEV_FEATURES"] } } SURREAL_LOG={ value = "trace", condition = { env_not_set = ["SURREAL_LOG"] } } SURREAL_USER={ value = "root", condition = { env_not_set = ["SURREAL_USER"] } } SURREAL_PASS={ value = "root", condition = { env_not_set = ["SURREAL_PASS"] } } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 75cad02e..fdf1b3f4 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -37,6 +37,7 @@ scripting = ["dep:js"] http = ["dep:reqwest"] native-tls = ["dep:native-tls", "reqwest?/native-tls", "tokio-tungstenite?/native-tls"] rustls = ["dep:rustls", "reqwest?/rustls-tls", "tokio-tungstenite?/rustls-tls-webpki-roots"] +ml = ["surrealml-core", "ndarray"] # Private features kv-fdb = ["foundationdb", "tokio/time"] @@ -74,6 +75,7 @@ futures = "0.3.29" futures-concurrency = "7.4.3" fuzzy-matcher = "0.3.7" geo = { version = "0.27.0", features = ["use-serde"] } +hex = { version = "0.4.3", optional = false } indexmap = { version = "2.1.0", features = ["serde"] } indxdb = { version = "0.4.0", optional = true } ipnet = "2.9.0" @@ -84,8 +86,10 @@ lru = "0.12.1" md-5 = "0.10.6" nanoid = "0.4.0" native-tls = { version = "0.2.11", optional = true } +ndarray = { version = "0.15.6", optional = true } nom = { version = "7.1.3", features = ["alloc"] } num_cpus = "1.16.0" +object_store = { version = "0.8.0", optional = false } once_cell = "1.18.0" path-clean = "1.0.1" pbkdf2 = { version = "0.12.2", features = ["simple"] } @@ -109,6 +113,7 @@ sha2 = "0.10.8" snap = "1.1.0" speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true } storekey = "0.5.0" +surrealml-core = { version = "0.0.2", optional = true } thiserror = "1.0.50" tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true } tokio-util = { version = "0.7.10", optional = true, features = ["compat"] } diff --git a/lib/src/api/conn.rs b/lib/src/api/conn.rs index 23bb7f4f..76b98086 100644 --- a/lib/src/api/conn.rs +++ b/lib/src/api/conn.rs @@ -113,6 +113,16 @@ pub enum DbResponse { Other(Value), } +#[derive(Debug)] +#[allow(dead_code)] // used by ML model import and export functions +pub(crate) enum MlConfig { + Import, + Export { + name: String, + version: String, + }, +} + /// Holds the parameters given to the caller #[derive(Debug, Default)] #[allow(dead_code)] // used by the embedded and remote connections @@ -122,6 +132,7 @@ pub struct Param { pub(crate) file: Option, pub(crate) bytes_sender: Option>>>, pub(crate) notification_sender: Option>, + pub(crate) ml_config: Option, } impl Param { diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index 09d02e4d..d7dd4959 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -28,6 +28,8 @@ pub(crate) mod wasm; use crate::api::conn::DbResponse; use crate::api::conn::Method; +#[cfg(not(target_arch = "wasm32"))] +use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::engine::create_statement; use crate::api::engine::delete_statement; @@ -44,9 +46,27 @@ use crate::api::Surreal; use crate::dbs::Notification; use crate::dbs::Response; use crate::dbs::Session; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::iam::check::check_ns_db; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::iam::Action; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::iam::ResourceKind; use crate::kvs::Datastore; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::kvs::{LockType, TransactionType}; use crate::method::Stats; use crate::opt::IntoEndpoint; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::sql::statements::DefineModelStatement; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::sql::statements::DefineStatement; use crate::sql::statements::KillStatement; use crate::sql::Array; use crate::sql::Query; @@ -56,6 +76,9 @@ use crate::sql::Strand; use crate::sql::Uuid; use crate::sql::Value; use channel::Sender; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use futures::StreamExt; use indexmap::IndexMap; use std::collections::BTreeMap; use std::collections::HashMap; @@ -65,6 +88,9 @@ use std::mem; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use surrealml_core::storage::surml_file::SurMlFile; #[cfg(not(target_arch = "wasm32"))] use tokio::fs::OpenOptions; #[cfg(not(target_arch = "wasm32"))] @@ -405,17 +431,42 @@ async fn take(one: bool, responses: Vec) -> Result { async fn export( kvs: &Datastore, sess: &Session, - ns: String, - db: String, chn: channel::Sender>, + ml_config: Option, ) -> Result<()> { - if let Err(error) = kvs.export(sess, ns, db, chn).await?.await { - if let crate::error::Db::Channel(message) = error { - // This is not really an error. Just logging it for improved visibility. - trace!("{message}"); - return Ok(()); + match ml_config { + #[cfg(feature = "ml")] + Some(MlConfig::Export { + name, + version, + }) => { + // Ensure a NS and DB are set + let (nsv, dbv) = check_ns_db(sess)?; + // Check the permissions level + kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?; + // Start a new readonly transaction + let mut tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?; + // Attempt to get the model definition + let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?; + // Export the file data in to the store + let mut data = crate::obs::stream(info.hash.to_owned()).await?; + // Process all stream values + while let Some(Ok(bytes)) = data.next().await { + if chn.send(bytes.to_vec()).await.is_err() { + break; + } + } + } + _ => { + if let Err(error) = kvs.export(sess, chn).await?.await { + if let crate::error::Db::Channel(message) = error { + // This is not really an error. Just logging it for improved visibility. + trace!("{message}"); + return Ok(()); + } + return Err(error.into()); + } } - return Err(error.into()); } Ok(()) } @@ -563,8 +614,6 @@ async fn router( Method::Export | Method::Import => unreachable!(), #[cfg(not(target_arch = "wasm32"))] Method::Export => { - let ns = session.ns.clone().unwrap_or_default(); - let db = session.db.clone().unwrap_or_default(); let (tx, rx) = crate::channel::bounded(1); match (param.file, param.bytes_sender) { @@ -572,7 +621,7 @@ async fn router( let (mut writer, mut reader) = io::duplex(10_240); // Write to channel. - let export = export(kvs, session, ns, db, tx); + let export = export(kvs, session, tx, param.ml_config); // Read from channel and write to pipe. let bridge = async move { @@ -613,7 +662,7 @@ async fn router( let session = session.clone(); tokio::spawn(async move { let export = async { - if let Err(error) = export(&kvs, &session, ns, db, tx).await { + if let Err(error) = export(&kvs, &session, tx, param.ml_config).await { let _ = backup.send(Err(error)).await; } }; @@ -647,15 +696,63 @@ async fn router( .into()); } }; - let mut statements = String::new(); - if let Err(error) = file.read_to_string(&mut statements).await { - return Err(Error::FileRead { - path, - error, + let responses = match param.ml_config { + #[cfg(feature = "ml")] + Some(MlConfig::Import) => { + // Ensure a NS and DB are set + let (nsv, dbv) = check_ns_db(session)?; + // Check the permissions level + kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?; + // Create a new buffer + let mut buffer = Vec::new(); + // Load all the uploaded file chunks + if let Err(error) = file.read_to_end(&mut buffer).await { + return Err(Error::FileRead { + path, + error, + } + .into()); + } + // Check that the SurrealML file is valid + let file = match SurMlFile::from_bytes(buffer) { + Ok(file) => file, + Err(error) => { + return Err(Error::FileRead { + path, + error, + } + .into()); + } + }; + // Convert the file back in to raw bytes + let data = file.to_bytes(); + // Calculate the hash of the model file + let hash = crate::obs::hash(&data); + // Insert the file data in to the store + crate::obs::put(&hash, data).await?; + // Insert the model in to the database + let query = DefineStatement::Model(DefineModelStatement { + hash, + name: file.header.name.to_string().into(), + version: file.header.version.to_string(), + comment: Some(file.header.description.to_string().into()), + ..Default::default() + }) + .into(); + kvs.process(query, session, Some(vars.clone())).await? } - .into()); - } - let responses = kvs.execute(&statements, &*session, Some(vars.clone())).await?; + _ => { + let mut statements = String::new(); + if let Err(error) = file.read_to_string(&mut statements).await { + return Err(Error::FileRead { + path, + error, + } + .into()); + } + kvs.execute(&statements, &*session, Some(vars.clone())).await? + } + }; for response in responses { response.result?; } diff --git a/lib/src/api/engine/remote/http/mod.rs b/lib/src/api/engine/remote/http/mod.rs index ce4d4431..c41fd621 100644 --- a/lib/src/api/engine/remote/http/mod.rs +++ b/lib/src/api/engine/remote/http/mod.rs @@ -7,6 +7,9 @@ pub(crate) mod wasm; use crate::api::conn::DbResponse; use crate::api::conn::Method; +#[cfg(feature = "ml")] +#[cfg(not(target_arch = "wasm32"))] +use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::engine::create_statement; use crate::api::engine::delete_statement; @@ -516,7 +519,14 @@ async fn router( Method::Export | Method::Import => unreachable!(), #[cfg(not(target_arch = "wasm32"))] Method::Export => { - let path = base_url.join(Method::Export.as_str())?; + let path = match param.ml_config { + #[cfg(feature = "ml")] + Some(MlConfig::Export { + name, + version, + }) => base_url.join(&format!("ml/export/{name}/{version}"))?, + _ => base_url.join(Method::Export.as_str())?, + }; let request = client .get(path) .headers(headers.clone()) @@ -527,7 +537,11 @@ async fn router( } #[cfg(not(target_arch = "wasm32"))] Method::Import => { - let path = base_url.join(Method::Import.as_str())?; + let path = match param.ml_config { + #[cfg(feature = "ml")] + Some(MlConfig::Import) => base_url.join("ml/import")?, + _ => base_url.join(Method::Import.as_str())?, + }; let file = param.file.expect("file to import from"); let request = client .post(path) diff --git a/lib/src/api/method/export.rs b/lib/src/api/method/export.rs index 69b263fc..357a3612 100644 --- a/lib/src/api/method/export.rs +++ b/lib/src/api/method/export.rs @@ -1,15 +1,18 @@ use crate::api::conn::Method; +use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::Connection; use crate::api::Error; use crate::api::ExtraFeatures; use crate::api::Result; +use crate::method::Model; use crate::method::OnceLockExt; use crate::opt::ExportDestination; use crate::Surreal; use channel::Receiver; use futures::Stream; use futures::StreamExt; +use semver::Version; use std::borrow::Cow; use std::future::Future; use std::future::IntoFuture; @@ -22,18 +25,39 @@ use std::task::Poll; /// A database export future #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Export<'r, C: Connection, R> { +pub struct Export<'r, C: Connection, R, T = ()> { pub(super) client: Cow<'r, Surreal>, pub(super) target: ExportDestination, + pub(super) ml_config: Option, pub(super) response: PhantomData, + pub(super) export_type: PhantomData, } -impl Export<'_, C, R> +impl<'r, C, R> Export<'r, C, R> +where + C: Connection, +{ + /// Export machine learning model + pub fn ml(self, name: &str, version: Version) -> Export<'r, C, R, Model> { + Export { + client: self.client, + target: self.target, + ml_config: Some(MlConfig::Export { + name: name.to_owned(), + version: version.to_string(), + }), + response: self.response, + export_type: PhantomData, + } + } +} + +impl Export<'_, C, R, T> where C: Connection, { /// Converts to an owned type which can easily be moved to a different thread - pub fn into_owned(self) -> Export<'static, C, R> { + pub fn into_owned(self) -> Export<'static, C, R, T> { Export { client: Cow::Owned(self.client.into_owned()), ..self @@ -41,7 +65,7 @@ where } } -impl<'r, Client> IntoFuture for Export<'r, Client, PathBuf> +impl<'r, Client, T> IntoFuture for Export<'r, Client, PathBuf, T> where Client: Connection, { @@ -55,15 +79,17 @@ where return Err(Error::BackupsNotSupported.into()); } let mut conn = Client::new(Method::Export); - match self.target { - ExportDestination::File(path) => conn.execute_unit(router, Param::file(path)).await, + let mut param = match self.target { + ExportDestination::File(path) => Param::file(path), ExportDestination::Memory => unreachable!(), - } + }; + param.ml_config = self.ml_config; + conn.execute_unit(router, param).await }) } } -impl<'r, Client> IntoFuture for Export<'r, Client, ()> +impl<'r, Client, T> IntoFuture for Export<'r, Client, (), T> where Client: Connection, { @@ -81,7 +107,9 @@ where let ExportDestination::Memory = self.target else { unreachable!(); }; - conn.execute_unit(router, Param::bytes_sender(tx)).await?; + let mut param = Param::bytes_sender(tx); + param.ml_config = self.ml_config; + conn.execute_unit(router, param).await?; Ok(Backup { rx, }) diff --git a/lib/src/api/method/import.rs b/lib/src/api/method/import.rs index aeb77214..c4f3ee64 100644 --- a/lib/src/api/method/import.rs +++ b/lib/src/api/method/import.rs @@ -1,31 +1,51 @@ use crate::api::conn::Method; +use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::Connection; use crate::api::Error; use crate::api::ExtraFeatures; use crate::api::Result; +use crate::method::Model; use crate::method::OnceLockExt; use crate::Surreal; use std::borrow::Cow; use std::future::Future; use std::future::IntoFuture; +use std::marker::PhantomData; use std::path::PathBuf; use std::pin::Pin; /// An database import future #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Import<'r, C: Connection> { +pub struct Import<'r, C: Connection, T = ()> { pub(super) client: Cow<'r, Surreal>, pub(super) file: PathBuf, + pub(super) ml_config: Option, + pub(super) import_type: PhantomData, } -impl Import<'_, C> +impl<'r, C> Import<'r, C> +where + C: Connection, +{ + /// Import machine learning model + pub fn ml(self) -> Import<'r, C, Model> { + Import { + client: self.client, + file: self.file, + ml_config: Some(MlConfig::Import), + import_type: PhantomData, + } + } +} + +impl<'r, C, T> Import<'r, C, T> where C: Connection, { /// Converts to an owned type which can easily be moved to a different thread - pub fn into_owned(self) -> Import<'static, C> { + pub fn into_owned(self) -> Import<'static, C, T> { Import { client: Cow::Owned(self.client.into_owned()), ..self @@ -33,7 +53,7 @@ where } } -impl<'r, Client> IntoFuture for Import<'r, Client> +impl<'r, Client, T> IntoFuture for Import<'r, Client, T> where Client: Connection, { @@ -47,7 +67,9 @@ where return Err(Error::BackupsNotSupported.into()); } let mut conn = Client::new(Method::Import); - conn.execute_unit(router, Param::file(self.file)).await + let mut param = Param::file(self.file); + param.ml_config = self.ml_config; + conn.execute_unit(router, param).await }) } } diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index c52eda3b..45aa9504 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -90,6 +90,9 @@ pub struct Stats { pub execution_time: Duration, } +/// Machine learning model marker type for import and export types +pub struct Model; + /// Responses returned with statistics #[derive(Debug)] pub struct WithStats(T); @@ -1004,7 +1007,9 @@ where Export { client: Cow::Borrowed(self), target: target.into_export_destination(), + ml_config: None, response: PhantomData, + export_type: PhantomData, } } @@ -1034,6 +1039,8 @@ where Import { client: Cow::Borrowed(self), file: file.as_ref().to_owned(), + ml_config: None, + import_type: PhantomData, } } } diff --git a/lib/src/api/opt/mod.rs b/lib/src/api/opt/mod.rs index 8b1f38a4..3de16983 100644 --- a/lib/src/api/opt/mod.rs +++ b/lib/src/api/opt/mod.rs @@ -409,7 +409,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue { }, Value::Cast(cast) => json!(cast), Value::Function(function) => json!(function), - Value::MlModel(model) => json!(model), + Value::Model(model) => json!(model), Value::Query(query) => json!(query), Value::Subquery(subquery) => json!(subquery), Value::Expression(expression) => json!(expression), diff --git a/lib/src/cnf/mod.rs b/lib/src/cnf/mod.rs index 0cface7c..8f6f619e 100644 --- a/lib/src/cnf/mod.rs +++ b/lib/src/cnf/mod.rs @@ -38,8 +38,7 @@ pub const PROCESSOR_BATCH_SIZE: u32 = 50; /// Forward all signup/signin query errors to a client trying authenticate to a scope. Do not use in production. pub static INSECURE_FORWARD_SCOPE_ERRORS: Lazy = Lazy::new(|| { - let default = false; - std::env::var("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS") + .and_then(|s| s.parse::().ok()) + .unwrap_or(false) }); diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index cf55a4d8..56725aa3 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -12,6 +12,7 @@ use base64_lib::DecodeError as Base64Error; use bincode::Error as BincodeError; use fst::Error as FstError; use jsonwebtoken::errors::Error as JWTError; +use object_store::Error as ObjectStoreError; use revision::Error as RevisionError; use serde::Serialize; use std::io::Error as IoError; @@ -194,6 +195,12 @@ pub enum Error { message: String, }, + /// There was an error with the provided machine learning model + #[error("Problem with machine learning computation. {message}")] + InvalidModel { + message: String, + }, + /// There was a problem running the specified function #[error("There was a problem running the {name}() function. {message}")] InvalidFunction { @@ -316,6 +323,12 @@ pub enum Error { value: String, }, + /// The requested model does not exist + #[error("The model 'ml::{value}' does not exist")] + MlNotFound { + value: String, + }, + /// The requested scope does not exist #[error("The scope '{value}' does not exist")] ScNotFound { @@ -635,6 +648,14 @@ pub enum Error { #[error("Utf8 error: {0}")] Utf8Error(#[from] FromUtf8Error), + /// Represents an underlying error with the Object Store + #[error("Object Store error: {0}")] + ObsError(#[from] ObjectStoreError), + + /// There was an error with model computation + #[error("There was an error with model computation: {0}")] + ModelComputation(String), + /// The feature has not yet being implemented #[error("Feature not yet implemented: {feature}")] FeatureNotYetImplemented { diff --git a/lib/src/iam/check.rs b/lib/src/iam/check.rs new file mode 100644 index 00000000..12b92594 --- /dev/null +++ b/lib/src/iam/check.rs @@ -0,0 +1,17 @@ +use crate::dbs::Session; +use crate::err::Error; + +pub fn check_ns_db(sess: &Session) -> Result<(String, String), Error> { + // Ensure that a namespace was specified + let ns = match sess.ns.clone() { + Some(ns) => ns, + None => return Err(Error::NsEmpty), + }; + // Ensure that a database was specified + let db = match sess.db.clone() { + Some(db) => db, + None => return Err(Error::DbEmpty), + }; + // All ok + Ok((ns, db)) +} diff --git a/lib/src/iam/entities/resources/resource.rs b/lib/src/iam/entities/resources/resource.rs index c0203cf7..c71a33d5 100644 --- a/lib/src/iam/entities/resources/resource.rs +++ b/lib/src/iam/entities/resources/resource.rs @@ -23,6 +23,7 @@ pub enum ResourceKind { Function, Analyzer, Parameter, + Model, Event, Field, Index, @@ -44,6 +45,7 @@ impl std::fmt::Display for ResourceKind { ResourceKind::Function => write!(f, "Function"), ResourceKind::Analyzer => write!(f, "Analyzer"), ResourceKind::Parameter => write!(f, "Parameter"), + ResourceKind::Model => write!(f, "Model"), ResourceKind::Event => write!(f, "Event"), ResourceKind::Field => write!(f, "Field"), ResourceKind::Index => write!(f, "Index"), diff --git a/lib/src/iam/mod.rs b/lib/src/iam/mod.rs index cc7f8038..d161ddfe 100644 --- a/lib/src/iam/mod.rs +++ b/lib/src/iam/mod.rs @@ -4,6 +4,7 @@ use thiserror::Error; pub mod auth; pub mod base; +pub mod check; pub mod clear; pub mod entities; pub mod policies; diff --git a/lib/src/key/database/ml.rs b/lib/src/key/database/ml.rs new file mode 100644 index 00000000..05f8525d --- /dev/null +++ b/lib/src/key/database/ml.rs @@ -0,0 +1,89 @@ +/// Stores a DEFINE MODEL config definition +use crate::key::error::KeyCategory; +use crate::key::key_req::KeyRequirements; +use derive::Key; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)] +pub struct Ml<'a> { + __: u8, + _a: u8, + pub ns: &'a str, + _b: u8, + pub db: &'a str, + _c: u8, + _d: u8, + _e: u8, + pub ml: &'a str, + pub vn: &'a str, +} + +pub fn new<'a>(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Ml<'a> { + Ml::new(ns, db, ml, vn) +} + +pub fn prefix(ns: &str, db: &str) -> Vec { + let mut k = super::all::new(ns, db).encode().unwrap(); + k.extend_from_slice(&[b'!', b'm', b'l', 0x00]); + k +} + +pub fn suffix(ns: &str, db: &str) -> Vec { + let mut k = super::all::new(ns, db).encode().unwrap(); + k.extend_from_slice(&[b'!', b'm', b'l', 0xff]); + k +} + +impl KeyRequirements for Ml<'_> { + fn key_category(&self) -> KeyCategory { + KeyCategory::DatabaseModel + } +} + +impl<'a> Ml<'a> { + pub fn new(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Self { + Self { + __: b'/', + _a: b'*', + ns, + _b: b'*', + db, + _c: b'!', + _d: b'm', + _e: b'l', + ml, + vn, + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn key() { + use super::*; + #[rustfmt::skip] + let val = Ml::new( + "testns", + "testdb", + "testml", + "1.0.0", + ); + let enc = Ml::encode(&val).unwrap(); + assert_eq!(enc, b"/*testns\x00*testdb\x00!mltestml\x001.0.0\x00"); + let dec = Ml::decode(&enc).unwrap(); + assert_eq!(val, dec); + } + + #[test] + fn test_prefix() { + let val = super::prefix("testns", "testdb"); + assert_eq!(val, b"/*testns\0*testdb\0!ml\0"); + } + + #[test] + fn test_suffix() { + let val = super::suffix("testns", "testdb"); + assert_eq!(val, b"/*testns\0*testdb\0!ml\xff"); + } +} diff --git a/lib/src/key/database/mod.rs b/lib/src/key/database/mod.rs index 2c13da86..ec74c794 100644 --- a/lib/src/key/database/mod.rs +++ b/lib/src/key/database/mod.rs @@ -1,6 +1,7 @@ pub mod all; pub mod az; pub mod fc; +pub mod ml; pub mod pa; pub mod sc; pub mod tb; diff --git a/lib/src/key/error.rs b/lib/src/key/error.rs index b5e61187..a4dd4b09 100644 --- a/lib/src/key/error.rs +++ b/lib/src/key/error.rs @@ -44,6 +44,8 @@ pub enum KeyCategory { DatabaseFunction, /// crate::key::database::lg /*{ns}*{db}!lg{lg} DatabaseLog, + /// crate::key::database::ml /*{ns}*{db}!ml{ml}{vn} + DatabaseModel, /// crate::key::database::pa /*{ns}*{db}!pa{pa} DatabaseParameter, /// crate::key::database::sc /*{ns}*{db}!sc{sc} @@ -138,6 +140,7 @@ impl Display for KeyCategory { KeyCategory::DatabaseAnalyzer => "DatabaseAnalyzer", KeyCategory::DatabaseFunction => "DatabaseFunction", KeyCategory::DatabaseLog => "DatabaseLog", + KeyCategory::DatabaseModel => "DatabaseModel", KeyCategory::DatabaseParameter => "DatabaseParameter", KeyCategory::DatabaseScope => "DatabaseScope", KeyCategory::DatabaseTable => "DatabaseTable", diff --git a/lib/src/kvs/cache.rs b/lib/src/kvs/cache.rs index 0e9f5e73..42c16914 100644 --- a/lib/src/kvs/cache.rs +++ b/lib/src/kvs/cache.rs @@ -6,6 +6,7 @@ use crate::sql::statements::DefineEventStatement; use crate::sql::statements::DefineFieldStatement; use crate::sql::statements::DefineFunctionStatement; use crate::sql::statements::DefineIndexStatement; +use crate::sql::statements::DefineModelStatement; use crate::sql::statements::DefineNamespaceStatement; use crate::sql::statements::DefineParamStatement; use crate::sql::statements::DefineScopeStatement; @@ -22,6 +23,7 @@ pub enum Entry { Db(Arc), Fc(Arc), Ix(Arc), + Ml(Arc), Ns(Arc), Pa(Arc), Tb(Arc), @@ -36,6 +38,7 @@ pub enum Entry { Fts(Arc<[DefineTableStatement]>), Ixs(Arc<[DefineIndexStatement]>), Lvs(Arc<[LiveStatement]>), + Mls(Arc<[DefineModelStatement]>), Nss(Arc<[DefineNamespaceStatement]>), Nts(Arc<[DefineTokenStatement]>), Nus(Arc<[DefineUserStatement]>), diff --git a/lib/src/kvs/ds.rs b/lib/src/kvs/ds.rs index 043784ce..59c057e2 100644 --- a/lib/src/kvs/ds.rs +++ b/lib/src/kvs/ds.rs @@ -6,7 +6,7 @@ use crate::dbs::{ Variables, }; use crate::err::Error; -use crate::iam::{Action, Auth, Error as IamError, ResourceKind, Role}; +use crate::iam::{Action, Auth, Error as IamError, Resource, Role}; use crate::key::root::hb::Hb; use crate::kvs::clock::SizedClock; #[allow(unused_imports)] @@ -1232,9 +1232,11 @@ impl Datastore { self.notification_channel.as_ref().map(|v| v.1.clone()) } - #[allow(dead_code)] - pub(crate) fn live_sender(&self) -> Option>>> { - self.notification_channel.as_ref().map(|v| Arc::new(RwLock::new(v.0.clone()))) + /// Performs a database import from SQL + #[instrument(level = "debug", skip(self, sess, sql))] + pub async fn import(&self, sql: &str, sess: &Session) -> Result, Error> { + // Execute the SQL import + self.execute(sql, sess, None).await } /// Performs a full database export as SQL @@ -1242,15 +1244,10 @@ impl Datastore { pub async fn export( &self, sess: &Session, - ns: String, - db: String, chn: Sender>, ) -> Result>, Error> { - // Skip auth for Anonymous users if auth is disabled - let skip_auth = !self.is_auth_enabled() && sess.au.is_anon(); - if !skip_auth { - sess.au.is_allowed(Action::View, &ResourceKind::Any.on_db(&ns, &db))?; - } + // Retrieve the provided NS and DB + let (ns, db) = crate::iam::check::check_ns_db(sess)?; // Create a new readonly transaction let mut txn = self.transaction(Read, Optimistic).await?; // Return an async export job @@ -1262,18 +1259,15 @@ impl Datastore { }) } - /// Performs a database import from SQL - #[instrument(level = "debug", skip(self, sess, sql))] - pub async fn import(&self, sql: &str, sess: &Session) -> Result, Error> { + /// Checks the required permissions level for this session + #[instrument(level = "debug", skip(self, sess))] + pub fn check(&self, sess: &Session, action: Action, resource: Resource) -> Result<(), Error> { // Skip auth for Anonymous users if auth is disabled let skip_auth = !self.is_auth_enabled() && sess.au.is_anon(); if !skip_auth { - sess.au.is_allowed( - Action::Edit, - &ResourceKind::Any.on_level(sess.au.level().to_owned()), - )?; + sess.au.is_allowed(action, &resource)?; } - // Execute the SQL import - self.execute(sql, sess, None).await + // All ok + Ok(()) } } diff --git a/lib/src/kvs/rocksdb/cnf.rs b/lib/src/kvs/rocksdb/cnf.rs index cff8488a..78411eda 100644 --- a/lib/src/kvs/rocksdb/cnf.rs +++ b/lib/src/kvs/rocksdb/cnf.rs @@ -1,64 +1,55 @@ use once_cell::sync::Lazy; pub static ROCKSDB_THREAD_COUNT: Lazy = Lazy::new(|| { - let default = num_cpus::get() as i32; - std::env::var("SURREAL_ROCKSDB_THREAD_COUNT") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_THREAD_COUNT") + .and_then(|s| s.parse::().ok()) + .unwrap_or(num_cpus::get() as i32) }); pub static ROCKSDB_WRITE_BUFFER_SIZE: Lazy = Lazy::new(|| { - let default = 256 * 1024 * 1024; - std::env::var("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(256 * 1024 * 1024) }); pub static ROCKSDB_TARGET_FILE_SIZE_BASE: Lazy = Lazy::new(|| { - let default = 512 * 1024 * 1024; - std::env::var("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(512 * 1024 * 1024) }); pub static ROCKSDB_MAX_WRITE_BUFFER_NUMBER: Lazy = Lazy::new(|| { - let default = 32; - std::env::var("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER") + .and_then(|s| s.parse::().ok()) + .unwrap_or(32) }); pub static ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy = Lazy::new(|| { - let default = 4; - std::env::var("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(4) }); pub static ROCKSDB_ENABLE_PIPELINED_WRITES: Lazy = Lazy::new(|| { - let default = true; - std::env::var("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES") + .and_then(|s| s.parse::().ok()) + .unwrap_or(true) }); pub static ROCKSDB_ENABLE_BLOB_FILES: Lazy = Lazy::new(|| { - let default = true; - std::env::var("SURREAL_ROCKSDB_ENABLE_BLOB_FILES") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_ENABLE_BLOB_FILES") + .and_then(|s| s.parse::().ok()) + .unwrap_or(true) }); pub static ROCKSDB_MIN_BLOB_SIZE: Lazy = Lazy::new(|| { - let default = 4 * 1024; - std::env::var("SURREAL_ROCKSDB_MIN_BLOB_SIZE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_MIN_BLOB_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(4 * 1024) }); pub static ROCKSDB_KEEP_LOG_FILE_NUM: Lazy = Lazy::new(|| { - let default = 20; - std::env::var("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM") + .and_then(|s| s.parse::().ok()) + .unwrap_or(20) }); diff --git a/lib/src/kvs/speedb/cnf.rs b/lib/src/kvs/speedb/cnf.rs index b7b690c6..4fc4007e 100644 --- a/lib/src/kvs/speedb/cnf.rs +++ b/lib/src/kvs/speedb/cnf.rs @@ -1,64 +1,55 @@ use once_cell::sync::Lazy; pub static SPEEDB_THREAD_COUNT: Lazy = Lazy::new(|| { - let default = num_cpus::get() as i32; - std::env::var("SURREAL_SPEEDB_THREAD_COUNT") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_THREAD_COUNT") + .and_then(|s| s.parse::().ok()) + .unwrap_or(num_cpus::get() as i32) }); pub static SPEEDB_WRITE_BUFFER_SIZE: Lazy = Lazy::new(|| { - let default = 256 * 1024 * 1024; - std::env::var("SURREAL_SPEEDB_WRITE_BUFFER_SIZE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_WRITE_BUFFER_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(256 * 1024 * 1024) }); pub static SPEEDB_TARGET_FILE_SIZE_BASE: Lazy = Lazy::new(|| { - let default = 512 * 1024 * 1024; - std::env::var("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(512 * 1024 * 1024) }); pub static SPEEDB_MAX_WRITE_BUFFER_NUMBER: Lazy = Lazy::new(|| { - let default = 32; - std::env::var("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER") + .and_then(|s| s.parse::().ok()) + .unwrap_or(32) }); pub static SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy = Lazy::new(|| { - let default = 4; - std::env::var("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(4) }); pub static SPEEDB_ENABLE_PIPELINED_WRITES: Lazy = Lazy::new(|| { - let default = true; - std::env::var("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES") + .and_then(|s| s.parse::().ok()) + .unwrap_or(true) }); pub static SPEEDB_ENABLE_BLOB_FILES: Lazy = Lazy::new(|| { - let default = true; - std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_ENABLE_BLOB_FILES") + .and_then(|s| s.parse::().ok()) + .unwrap_or(true) }); pub static SPEEDB_MIN_BLOB_SIZE: Lazy = Lazy::new(|| { - let default = 4 * 1024; - std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_MIN_BLOB_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(4 * 1024) }); pub static SPEEDB_KEEP_LOG_FILE_NUM: Lazy = Lazy::new(|| { - let default = 20; - std::env::var("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM") + .and_then(|s| s.parse::().ok()) + .unwrap_or(20) }); diff --git a/lib/src/kvs/tx.rs b/lib/src/kvs/tx.rs index 96bf5d61..b2f99cf8 100644 --- a/lib/src/kvs/tx.rs +++ b/lib/src/kvs/tx.rs @@ -33,6 +33,7 @@ use sql::statements::DefineEventStatement; use sql::statements::DefineFieldStatement; use sql::statements::DefineFunctionStatement; use sql::statements::DefineIndexStatement; +use sql::statements::DefineModelStatement; use sql::statements::DefineNamespaceStatement; use sql::statements::DefineParamStatement; use sql::statements::DefineScopeStatement; @@ -1557,6 +1558,29 @@ impl Transaction { }) } + /// Retrieve all model definitions for a specific database. + pub async fn all_db_models( + &mut self, + ns: &str, + db: &str, + ) -> Result, Error> { + let key = crate::key::database::ml::prefix(ns, db); + Ok(if let Some(e) = self.cache.get(&key) { + if let Entry::Mls(v) = e { + v + } else { + unreachable!(); + } + } else { + let beg = crate::key::database::ml::prefix(ns, db); + let end = crate::key::database::ml::suffix(ns, db); + let val = self.getr(beg..end, u32::MAX).await?; + let val = val.convert().into(); + self.cache.set(key, Entry::Mls(Arc::clone(&val))); + val + }) + } + /// Retrieve all scope definitions for a specific database. pub async fn all_sc( &mut self, @@ -1840,6 +1864,21 @@ impl Transaction { Ok(val.into()) } + /// Retrieve a specific model definition from a database. + pub async fn get_db_model( + &mut self, + ns: &str, + db: &str, + ml: &str, + vn: &str, + ) -> Result { + let key = crate::key::database::ml::new(ns, db, ml, vn); + let val = self.get(key).await?.ok_or(Error::MlNotFound { + value: format!("{ml}<{vn}>"), + })?; + Ok(val.into()) + } + /// Retrieve a specific database token definition. pub async fn get_db_token( &mut self, @@ -2178,6 +2217,31 @@ impl Transaction { }) } + /// Retrieve a specific model definition. + pub async fn get_and_cache_db_model( + &mut self, + ns: &str, + db: &str, + ml: &str, + vn: &str, + ) -> Result, Error> { + let key = crate::key::database::ml::new(ns, db, ml, vn).encode()?; + Ok(if let Some(e) = self.cache.get(&key) { + if let Entry::Ml(v) = e { + v + } else { + unreachable!(); + } + } else { + let val = self.get(key.clone()).await?.ok_or(Error::MlNotFound { + value: format!("{ml}<{vn}>"), + })?; + let val: Arc = Arc::new(val.into()); + self.cache.set(key, Entry::Ml(Arc::clone(&val))); + val + }) + } + /// Retrieve a specific table index definition. pub async fn get_and_cache_tb_index( &mut self, diff --git a/lib/src/lib.rs b/lib/src/lib.rs index f0e5e949..b61ac576 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -134,6 +134,10 @@ pub mod idx; pub mod key; #[doc(hidden)] pub mod kvs; + +#[cfg(feature = "ml")] +#[doc(hidden)] +pub mod obs; #[doc(hidden)] pub mod syn; diff --git a/lib/src/obs/mod.rs b/lib/src/obs/mod.rs new file mode 100644 index 00000000..69dccd3d --- /dev/null +++ b/lib/src/obs/mod.rs @@ -0,0 +1,96 @@ +//! This module defines the operations for object storage using the [object_store](https://docs.rs/object_store/latest/object_store/) +//! crate. This will enable the user to store objects using local file storage, or cloud storage such as S3 or GCS. +use crate::err::Error; +use bytes::Bytes; +use futures::stream::BoxStream; +use object_store::local::LocalFileSystem; +use object_store::parse_url; +use object_store::path::Path; +use object_store::ObjectStore; +use once_cell::sync::Lazy; +use sha1::{Digest, Sha1}; +use std::env; +use std::fs; +use std::sync::Arc; +use url::Url; + +static STORE: Lazy> = + Lazy::new(|| match std::env::var("SURREAL_OBJECT_STORE") { + Ok(url) => { + let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE"); + let (store, _) = + parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE"); + Arc::new(store) + } + Err(_) => { + let path = env::current_dir().unwrap().join("store"); + if !path.exists() || !path.is_dir() { + fs::create_dir_all(&path) + .expect("Unable to create directory structure for SURREAL_OBJECT_STORE"); + } + // As long as the provided path is correct, the following should never panic + Arc::new(LocalFileSystem::new_with_prefix(path).unwrap()) + } + }); + +static CACHE: Lazy> = + Lazy::new(|| match std::env::var("SURREAL_OBJECT_CACHE") { + Ok(url) => { + let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE"); + let (store, _) = + parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE"); + Arc::new(store) + } + Err(_) => { + let path = env::current_dir().unwrap().join("cache"); + if !path.exists() || !path.is_dir() { + fs::create_dir_all(&path) + .expect("Unable to create directory structure for SURREAL_OBJECT_CACHE"); + } + // As long as the provided path is correct, the following should never panic + Arc::new(LocalFileSystem::new_with_prefix(path).unwrap()) + } + }); + +/// Gets the file from the local file system object storage. +pub async fn stream( + file: String, +) -> Result>, Error> { + match CACHE.get(&Path::from(file.as_str())).await { + Ok(data) => Ok(data.into_stream()), + _ => Ok(STORE.get(&Path::from(file.as_str())).await?.into_stream()), + } +} + +/// Gets the file from the local file system object storage. +pub async fn get(file: &str) -> Result, Error> { + match CACHE.get(&Path::from(file)).await { + Ok(data) => Ok(data.bytes().await?.to_vec()), + _ => { + let data = STORE.get(&Path::from(file)).await?; + CACHE.put(&Path::from(file), data.bytes().await?).await?; + Ok(CACHE.get(&Path::from(file)).await?.bytes().await?.to_vec()) + } + } +} + +/// Gets the file from the local file system object storage. +pub async fn put(file: &str, data: Vec) -> Result<(), Error> { + let _ = STORE.put(&Path::from(file), Bytes::from(data)).await?; + Ok(()) +} + +/// Gets the file from the local file system object storage. +pub async fn del(file: &str) -> Result<(), Error> { + Ok(STORE.delete(&Path::from(file)).await?) +} + +/// Hashes the bytes of a file to a string for the storage of a file. +pub fn hash(data: &Vec) -> String { + let mut hasher = Sha1::new(); + hasher.update(data); + let result = hasher.finalize(); + let mut output = hex::encode(result); + output.truncate(6); + output +} diff --git a/lib/src/sql/function.rs b/lib/src/sql/function.rs index 43c9db4e..f351c5d9 100644 --- a/lib/src/sql/function.rs +++ b/lib/src/sql/function.rs @@ -159,8 +159,10 @@ impl Function { fnc::run(ctx, opt, txn, doc, s, a).await } Self::Custom(s, x) => { + // Get the full name of this function + let name = format!("fn::{s}"); // Check this function is allowed - ctx.check_allowed_function(format!("fn::{s}").as_str())?; + ctx.check_allowed_function(name.as_str())?; // Get the function definition let val = { // Claim transaction @@ -189,15 +191,16 @@ impl Function { } } } - // Return the value - // Check the function arguments + // Get the number of function arguments let max_args_len = val.args.len(); + // Track the number of required arguments let mut min_args_len = 0; + // Check for any final optional arguments val.args.iter().rev().for_each(|(_, kind)| match kind { Kind::Option(_) if min_args_len == 0 => {} _ => min_args_len += 1, }); - + // Check the necessary arguments are passed if x.len() < min_args_len || max_args_len < x.len() { return Err(Error::InvalidArguments { name: format!("fn::{}", val.name), diff --git a/lib/src/sql/model.rs b/lib/src/sql/model.rs index 65310632..0ffc4e97 100644 --- a/lib/src/sql/model.rs +++ b/lib/src/sql/model.rs @@ -1,16 +1,29 @@ -use crate::{ - ctx::Context, - dbs::{Options, Transaction}, - doc::CursorDoc, - err::Error, - sql::value::Value, -}; -use async_recursion::async_recursion; +use crate::ctx::Context; +use crate::dbs::{Options, Transaction}; +use crate::doc::CursorDoc; +use crate::err::Error; +use crate::sql::value::Value; use derive::Store; use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt; +#[cfg(feature = "ml")] +use crate::iam::Action; +#[cfg(feature = "ml")] +use crate::sql::Permission; +#[cfg(feature = "ml")] +use futures::future::try_join_all; +#[cfg(feature = "ml")] +use std::collections::HashMap; +#[cfg(feature = "ml")] +use surrealml_core::execution::compute::ModelComputation; +#[cfg(feature = "ml")] +use surrealml_core::storage::surml_file::SurMlFile; + +#[cfg(feature = "ml")] +const ARGUMENTS: &str = "The model expects 1 argument. The argument can be either a number, an object, or an array of numbers."; + #[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[revisioned(revision = 1)] pub struct Model { @@ -33,15 +46,165 @@ impl fmt::Display for Model { } impl Model { - #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] - #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + #[cfg(feature = "ml")] + pub(crate) async fn compute( + &self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + doc: Option<&CursorDoc<'_>>, + ) -> Result { + // Ensure futures are run + let opt = &opt.new_with_futures(true); + // Get the full name of this model + let name = format!("ml::{}", self.name); + // Check this function is allowed + ctx.check_allowed_function(name.as_str())?; + // Get the model definition + let val = { + // Claim transaction + let mut run = txn.lock().await; + // Get the function definition + run.get_and_cache_db_model(opt.ns(), opt.db(), &self.name, &self.version).await? + }; + // Calculate the model path + let path = format!( + "ml/{}/{}/{}-{}-{}.surml", + opt.ns(), + opt.db(), + self.name, + self.version, + val.hash + ); + // Check permissions + if opt.check_perms(Action::View) { + match &val.permissions { + Permission::Full => (), + Permission::None => { + return Err(Error::FunctionPermissions { + name: self.name.to_owned(), + }) + } + Permission::Specific(e) => { + // Disable permissions + let opt = &opt.new_with_perms(false); + // Process the PERMISSION clause + if !e.compute(ctx, opt, txn, doc).await?.is_truthy() { + return Err(Error::FunctionPermissions { + name: self.name.to_owned(), + }); + } + } + } + } + // Compute the function arguments + let mut args = + try_join_all(self.args.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?; + // Check the minimum argument length + if args.len() != 1 { + return Err(Error::InvalidArguments { + name: format!("ml::{}<{}>", self.name, self.version), + message: ARGUMENTS.into(), + }); + } + // Take the first and only specified argument + match args.swap_remove(0) { + // Perform bufferered compute + Value::Object(v) => { + // Compute the model function arguments + let mut args = v + .into_iter() + .map(|(k, v)| Ok((k, Value::try_into(v)?))) + .collect::, Error>>() + .map_err(|_| Error::InvalidArguments { + name: format!("ml::{}<{}>", self.name, self.version), + message: ARGUMENTS.into(), + })?; + // Get the model file as bytes + let bytes = crate::obs::get(&path).await?; + // Run the compute in a blocking task + let outcome = tokio::task::spawn_blocking(move || { + let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let compute_unit = ModelComputation { + surml_file: &mut file, + }; + compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation) + }) + .await + .unwrap()?; + // Convert the output to a value + Ok(outcome[0].into()) + } + // Perform raw compute + Value::Number(v) => { + // Compute the model function arguments + let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments { + name: format!("ml::{}<{}>", self.name, self.version), + message: ARGUMENTS.into(), + })?; + // Get the model file as bytes + let bytes = crate::obs::get(&path).await?; + // Convert the argument to a tensor + let tensor = ndarray::arr1::(&[args]).into_dyn(); + // Run the compute in a blocking task + let outcome = tokio::task::spawn_blocking(move || { + let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let compute_unit = ModelComputation { + surml_file: &mut file, + }; + compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) + }) + .await + .unwrap()?; + // Convert the output to a value + Ok(outcome[0].into()) + } + // Perform raw compute + Value::Array(v) => { + // Compute the model function arguments + let args = v + .into_iter() + .map(Value::try_into) + .collect::, Error>>() + .map_err(|_| Error::InvalidArguments { + name: format!("ml::{}<{}>", self.name, self.version), + message: ARGUMENTS.into(), + })?; + // Get the model file as bytes + let bytes = crate::obs::get(&path).await?; + // Convert the argument to a tensor + let tensor = ndarray::arr1::(&args).into_dyn(); + // Run the compute in a blocking task + let outcome = tokio::task::spawn_blocking(move || { + let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let compute_unit = ModelComputation { + surml_file: &mut file, + }; + compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) + }) + .await + .unwrap()?; + // Convert the output to a value + Ok(outcome[0].into()) + } + // + _ => Err(Error::InvalidArguments { + name: format!("ml::{}<{}>", self.name, self.version), + message: ARGUMENTS.into(), + }), + } + } + + #[cfg(not(feature = "ml"))] pub(crate) async fn compute( &self, _ctx: &Context<'_>, _opt: &Options, _txn: &Transaction, - _doc: Option<&'async_recursion CursorDoc<'_>>, + _doc: Option<&CursorDoc<'_>>, ) -> Result { - Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string())) + Err(Error::InvalidModel { + message: String::from("Machine learning computation is not enabled."), + }) } } diff --git a/lib/src/sql/query.rs b/lib/src/sql/query.rs index 1958becb..dc09bd89 100644 --- a/lib/src/sql/query.rs +++ b/lib/src/sql/query.rs @@ -1,6 +1,6 @@ use crate::sql::fmt::Pretty; use crate::sql::statement::{Statement, Statements}; -use crate::sql::Value; +use crate::sql::statements::{DefineStatement, RemoveStatement}; use derive::Store; use revision::revisioned; use serde::{Deserialize, Serialize}; @@ -16,6 +16,18 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query"; #[serde(rename = "$surrealdb::private::sql::Query")] pub struct Query(pub Statements); +impl From for Query { + fn from(s: DefineStatement) -> Self { + Query(Statements(vec![Statement::Define(s)])) + } +} + +impl From for Query { + fn from(s: RemoveStatement) -> Self { + Query(Statements(vec![Statement::Remove(s)])) + } +} + impl Deref for Query { type Target = Vec; fn deref(&self) -> &Self::Target { @@ -31,12 +43,6 @@ impl IntoIterator for Query { } } -impl From for Value { - fn from(q: Query) -> Self { - Value::Query(q) - } -} - impl Display for Query { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(Pretty::from(f), "{}", &self.0) diff --git a/lib/src/sql/statements/define/mod.rs b/lib/src/sql/statements/define/mod.rs index 343fbab6..874755d2 100644 --- a/lib/src/sql/statements/define/mod.rs +++ b/lib/src/sql/statements/define/mod.rs @@ -52,7 +52,7 @@ pub enum DefineStatement { Field(DefineFieldStatement), Index(DefineIndexStatement), User(DefineUserStatement), - MlModel(DefineModelStatement), + Model(DefineModelStatement), } impl DefineStatement { @@ -81,7 +81,7 @@ impl DefineStatement { Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await, Self::User(ref v) => v.compute(ctx, opt, txn, doc).await, - Self::MlModel(ref v) => v.compute(ctx, opt, txn, doc).await, + Self::Model(ref v) => v.compute(ctx, opt, txn, doc).await, } } } @@ -101,7 +101,7 @@ impl Display for DefineStatement { Self::Field(v) => Display::fmt(v, f), Self::Index(v) => Display::fmt(v, f), Self::Analyzer(v) => Display::fmt(v, f), - Self::MlModel(v) => Display::fmt(v, f), + Self::Model(v) => Display::fmt(v, f), } } } diff --git a/lib/src/sql/statements/define/model.rs b/lib/src/sql/statements/define/model.rs index 2b21528c..fbe5967f 100644 --- a/lib/src/sql/statements/define/model.rs +++ b/lib/src/sql/statements/define/model.rs @@ -1,25 +1,21 @@ +use crate::ctx::Context; +use crate::dbs::{Options, Transaction}; +use crate::doc::CursorDoc; +use crate::err::Error; +use crate::iam::{Action, ResourceKind}; use crate::sql::{ fmt::{is_pretty, pretty_indent}, - Permission, + Base, Ident, Permission, Strand, Value, }; -use async_recursion::async_recursion; use derive::Store; use revision::revisioned; use serde::{Deserialize, Serialize}; -use std::fmt; -use std::fmt::Write; - -use crate::{ - ctx::Context, - dbs::{Options, Transaction}, - doc::CursorDoc, - err::Error, - sql::{Ident, Strand, Value}, -}; +use std::fmt::{self, Write}; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[revisioned(revision = 1)] pub struct DefineModelStatement { + pub hash: String, pub name: Ident, pub version: String, pub comment: Option, @@ -30,31 +26,42 @@ impl fmt::Display for DefineModelStatement { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "DEFINE MODEL ml::{}<{}>", self.name, self.version)?; if let Some(comment) = self.comment.as_ref() { - write!(f, "COMMENT {}", comment)?; - } - if !self.permissions.is_full() { - let _indent = if is_pretty() { - Some(pretty_indent()) - } else { - f.write_char(' ')?; - None - }; - write!(f, "PERMISSIONS {}", self.permissions)?; + write!(f, " COMMENT {}", comment)?; } + let _indent = if is_pretty() { + Some(pretty_indent()) + } else { + f.write_char(' ')?; + None + }; + write!(f, "PERMISSIONS {}", self.permissions)?; Ok(()) } } impl DefineModelStatement { - #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] - #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + /// Process this type returning a computed simple Value pub(crate) async fn compute( &self, _ctx: &Context<'_>, - _opt: &Options, - _txn: &Transaction, - _doc: Option<&'async_recursion CursorDoc<'_>>, + opt: &Options, + txn: &Transaction, + _doc: Option<&CursorDoc<'_>>, ) -> Result { - Err(Error::Unimplemented("Ml model definition not yet implemented".to_string())) + // Allowed to run? + opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?; + // Claim transaction + let mut run = txn.lock().await; + // Clear the cache + run.clear_cache(); + // Process the statement + let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version); + run.add_ns(opt.ns(), opt.strict).await?; + run.add_db(opt.ns(), opt.db(), opt.strict).await?; + run.set(key, self).await?; + // Store the model file + // TODO + // Ok all good + Ok(Value::None) } } diff --git a/lib/src/sql/statements/info.rs b/lib/src/sql/statements/info.rs index 5ed390bd..fae606bc 100644 --- a/lib/src/sql/statements/info.rs +++ b/lib/src/sql/statements/info.rs @@ -107,6 +107,12 @@ impl InfoStatement { tmp.insert(v.name.to_string(), v.to_string().into()); } res.insert("functions".to_owned(), tmp.into()); + // Process the models + let mut tmp = Object::default(); + for v in run.all_db_models(opt.ns(), opt.db()).await?.iter() { + tmp.insert(format!("{}<{}>", v.name, v.version), v.to_string().into()); + } + res.insert("models".to_owned(), tmp.into()); // Process the params let mut tmp = Object::default(); for v in run.all_db_params(opt.ns(), opt.db()).await?.iter() { diff --git a/lib/src/sql/statements/mod.rs b/lib/src/sql/statements/mod.rs index 5522402f..e777c783 100644 --- a/lib/src/sql/statements/mod.rs +++ b/lib/src/sql/statements/mod.rs @@ -52,14 +52,14 @@ pub use self::update::UpdateStatement; pub use self::define::{ DefineAnalyzerStatement, DefineDatabaseStatement, DefineEventStatement, DefineFieldStatement, - DefineFunctionStatement, DefineIndexStatement, DefineNamespaceStatement, DefineParamStatement, - DefineScopeStatement, DefineStatement, DefineTableStatement, DefineTokenStatement, - DefineUserStatement, + DefineFunctionStatement, DefineIndexStatement, DefineModelStatement, DefineNamespaceStatement, + DefineParamStatement, DefineScopeStatement, DefineStatement, DefineTableStatement, + DefineTokenStatement, DefineUserStatement, }; pub use self::remove::{ RemoveAnalyzerStatement, RemoveDatabaseStatement, RemoveEventStatement, RemoveFieldStatement, - RemoveFunctionStatement, RemoveIndexStatement, RemoveNamespaceStatement, RemoveParamStatement, - RemoveScopeStatement, RemoveStatement, RemoveTableStatement, RemoveTokenStatement, - RemoveUserStatement, + RemoveFunctionStatement, RemoveIndexStatement, RemoveModelStatement, RemoveNamespaceStatement, + RemoveParamStatement, RemoveScopeStatement, RemoveStatement, RemoveTableStatement, + RemoveTokenStatement, RemoveUserStatement, }; diff --git a/lib/src/sql/statements/remove/mod.rs b/lib/src/sql/statements/remove/mod.rs index 19086f93..fc6cdbac 100644 --- a/lib/src/sql/statements/remove/mod.rs +++ b/lib/src/sql/statements/remove/mod.rs @@ -4,6 +4,7 @@ mod event; mod field; mod function; mod index; +mod model; mod namespace; mod param; mod scope; @@ -17,6 +18,7 @@ pub use event::RemoveEventStatement; pub use field::RemoveFieldStatement; pub use function::RemoveFunctionStatement; pub use index::RemoveIndexStatement; +pub use model::RemoveModelStatement; pub use namespace::RemoveNamespaceStatement; pub use param::RemoveParamStatement; pub use scope::RemoveScopeStatement; @@ -49,6 +51,7 @@ pub enum RemoveStatement { Field(RemoveFieldStatement), Index(RemoveIndexStatement), User(RemoveUserStatement), + Model(RemoveModelStatement), } impl RemoveStatement { @@ -77,6 +80,7 @@ impl RemoveStatement { Self::Index(ref v) => v.compute(ctx, opt, txn).await, Self::Analyzer(ref v) => v.compute(ctx, opt, txn).await, Self::User(ref v) => v.compute(ctx, opt, txn).await, + Self::Model(ref v) => v.compute(ctx, opt, txn).await, } } } @@ -96,6 +100,7 @@ impl Display for RemoveStatement { Self::Index(v) => Display::fmt(v, f), Self::Analyzer(v) => Display::fmt(v, f), Self::User(v) => Display::fmt(v, f), + Self::Model(v) => Display::fmt(v, f), } } } diff --git a/lib/src/sql/statements/remove/model.rs b/lib/src/sql/statements/remove/model.rs new file mode 100644 index 00000000..8b56921a --- /dev/null +++ b/lib/src/sql/statements/remove/model.rs @@ -0,0 +1,47 @@ +use crate::ctx::Context; +use crate::dbs::{Options, Transaction}; +use crate::err::Error; +use crate::iam::{Action, ResourceKind}; +use crate::sql::{Base, Ident, Value}; +use derive::Store; +use revision::revisioned; +use serde::{Deserialize, Serialize}; +use std::fmt::{self, Display}; + +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] +#[revisioned(revision = 1)] +pub struct RemoveModelStatement { + pub name: Ident, + pub version: String, +} + +impl RemoveModelStatement { + /// Process this type returning a computed simple Value + pub(crate) async fn compute( + &self, + _ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + ) -> Result { + // Allowed to run? + opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?; + // Claim transaction + let mut run = txn.lock().await; + // Clear the cache + run.clear_cache(); + // Delete the definition + let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version); + run.del(key).await?; + // Remove the model file + // TODO + // Ok all good + Ok(Value::None) + } +} + +impl Display for RemoveModelStatement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Bypass ident display since we don't want backticks arround the ident. + write!(f, "REMOVE MODEL ml::{}<{}>", self.name.0, self.version) + } +} diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index d939453e..14f21387 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -99,12 +99,11 @@ pub enum Value { Edges(Box), Future(Box), Constant(Constant), - // Closure(Box), Function(Box), Subquery(Box), Expression(Box), Query(Query), - MlModel(Box), + Model(Box), // Add new variants here } @@ -257,7 +256,7 @@ impl From for Value { impl From for Value { fn from(v: Model) -> Self { - Value::MlModel(Box::new(v)) + Value::Model(Box::new(v)) } } @@ -505,6 +504,12 @@ impl From for Value { } } +impl From for Value { + fn from(q: Query) -> Self { + Value::Query(q) + } +} + impl TryFrom for i8 { type Error = Error; fn try_from(value: Value) -> Result { @@ -1035,7 +1040,7 @@ impl Value { pub fn can_start_idiom(&self) -> bool { match self { Value::Function(x) => !x.is_script(), - Value::MlModel(_) + Value::Model(_) | Value::Subquery(_) | Value::Constant(_) | Value::Datetime(_) @@ -2526,7 +2531,7 @@ impl fmt::Display for Value { Value::Edges(v) => write!(f, "{v}"), Value::Expression(v) => write!(f, "{v}"), Value::Function(v) => write!(f, "{v}"), - Value::MlModel(v) => write!(f, "{v}"), + Value::Model(v) => write!(f, "{v}"), Value::Future(v) => write!(f, "{v}"), Value::Geometry(v) => write!(f, "{v}"), Value::Idiom(v) => write!(f, "{v}"), @@ -2557,7 +2562,7 @@ impl Value { Value::Function(v) => { v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable) } - Value::MlModel(m) => m.args.iter().any(Value::writeable), + Value::Model(m) => m.args.iter().any(Value::writeable), Value::Subquery(v) => v.writeable(), Value::Expression(v) => v.writeable(), _ => false, @@ -2588,7 +2593,7 @@ impl Value { Value::Future(v) => v.compute(ctx, opt, txn, doc).await, Value::Constant(v) => v.compute(ctx, opt, txn, doc).await, Value::Function(v) => v.compute(ctx, opt, txn, doc).await, - Value::MlModel(v) => v.compute(ctx, opt, txn, doc).await, + Value::Model(v) => v.compute(ctx, opt, txn, doc).await, Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await, Value::Expression(v) => v.compute(ctx, opt, txn, doc).await, _ => Ok(self.to_owned()), diff --git a/lib/tests/api.rs b/lib/tests/api.rs index 8345b5df..c430d013 100644 --- a/lib/tests/api.rs +++ b/lib/tests/api.rs @@ -2,6 +2,7 @@ mod api_integration { use chrono::DateTime; use once_cell::sync::Lazy; + use semver::Version; use serde::Deserialize; use serde::Serialize; use serde_json::json; diff --git a/lib/tests/api/backup.rs b/lib/tests/api/backup.rs index 499e1cb0..355f633e 100644 --- a/lib/tests/api/backup.rs +++ b/lib/tests/api/backup.rs @@ -23,3 +23,17 @@ async fn export_import() { db.import(&file).await.unwrap(); remove_file(file).await.unwrap(); } + +#[test_log::test(tokio::test)] +#[cfg(feature = "ml")] +async fn ml_export_import() { + let (permit, db) = new_db().await; + let db_name = Ulid::new().to_string(); + db.use_ns(NS).use_db(&db_name).await.unwrap(); + db.import("../tests/linear_test.surml").ml().await.unwrap(); + drop(permit); + let file = format!("{db_name}.surml"); + db.export(&file).ml("Prediction", Version::new(0, 0, 1)).await.unwrap(); + db.import(&file).ml().await.unwrap(); + remove_file(file).await.unwrap(); +} diff --git a/lib/tests/define.rs b/lib/tests/define.rs index 76e06459..8dca0050 100644 --- a/lib/tests/define.rs +++ b/lib/tests/define.rs @@ -87,8 +87,7 @@ async fn define_statement_function() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: { test: 'DEFINE FUNCTION fn::test($first: string, $last: string) { RETURN $first + $last; } PERMISSIONS FULL' }, - params: {}, - scopes: {}, + models: {}, params: {}, scopes: {}, tables: {}, @@ -120,6 +119,7 @@ async fn define_statement_table_drop() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { test: 'DEFINE TABLE test DROP SCHEMALESS PERMISSIONS NONE' }, @@ -151,6 +151,7 @@ async fn define_statement_table_schemaless() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' }, @@ -186,6 +187,7 @@ async fn define_statement_table_schemafull() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' }, @@ -217,6 +219,7 @@ async fn define_statement_table_schemaful() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' }, @@ -256,6 +259,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { @@ -288,6 +292,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { @@ -1177,6 +1182,7 @@ async fn define_statement_analyzer() -> Result<(), Error> { functions: { stripHtml: "DEFINE FUNCTION fn::stripHtml($html: string) { RETURN string::replace($html, /<[^>]*>/, ''); } PERMISSIONS FULL" }, + models: {}, params: {}, scopes: {}, tables: {}, @@ -1496,8 +1502,8 @@ async fn permissions_checks_define_function() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1538,8 +1544,8 @@ async fn permissions_checks_define_analyzer() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1622,8 +1628,8 @@ async fn permissions_checks_define_token_db() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1748,8 +1754,8 @@ async fn permissions_checks_define_user_db() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1790,8 +1796,8 @@ async fn permissions_checks_define_scope() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1832,8 +1838,8 @@ async fn permissions_checks_define_param() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -1871,8 +1877,8 @@ async fn permissions_checks_define_table() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"] ]; let test_cases = [ @@ -2058,6 +2064,7 @@ async fn define_statement_table_permissions() -> Result<(), Error> { "{ analyzers: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { diff --git a/lib/tests/info.rs b/lib/tests/info.rs index d7dd08a1..97747bad 100644 --- a/lib/tests/info.rs +++ b/lib/tests/info.rs @@ -312,8 +312,8 @@ async fn permissions_checks_info_db() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], ]; let test_cases = [ diff --git a/lib/tests/param.rs b/lib/tests/param.rs index aa9d1bf7..ceb53a73 100644 --- a/lib/tests/param.rs +++ b/lib/tests/param.rs @@ -29,6 +29,7 @@ async fn define_global_param() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: { test: 'DEFINE PARAM $test VALUE 12345 PERMISSIONS FULL' }, scopes: {}, tables: {}, diff --git a/lib/tests/remove.rs b/lib/tests/remove.rs index 1b4de2e4..11924782 100644 --- a/lib/tests/remove.rs +++ b/lib/tests/remove.rs @@ -39,6 +39,7 @@ async fn remove_statement_table() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: {}, @@ -73,6 +74,7 @@ async fn remove_statement_analyzer() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: {}, @@ -222,8 +224,8 @@ async fn permissions_checks_remove_function() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], ]; let test_cases = [ @@ -264,8 +266,8 @@ async fn permissions_checks_remove_analyzer() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], ]; let test_cases = [ @@ -348,8 +350,8 @@ async fn permissions_checks_remove_db_token() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"], ]; let test_cases = [ @@ -474,8 +476,8 @@ async fn permissions_checks_remove_db_user() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"], ]; let test_cases = [ @@ -516,8 +518,8 @@ async fn permissions_checks_remove_scope() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"], ]; let test_cases = [ @@ -558,8 +560,8 @@ async fn permissions_checks_remove_param() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"], ]; let test_cases = [ @@ -600,8 +602,8 @@ async fn permissions_checks_remove_table() { // Define the expected results for the check statement when the test statement succeeded and when it failed let check_results = [ - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], - vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"], + vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"], ]; let test_cases = [ diff --git a/lib/tests/strict.rs b/lib/tests/strict.rs index daefd9b2..ece0e8a3 100644 --- a/lib/tests/strict.rs +++ b/lib/tests/strict.rs @@ -255,6 +255,7 @@ async fn loose_mode_all_ok() -> Result<(), Error> { analyzers: {}, tokens: {}, functions: {}, + models: {}, params: {}, scopes: {}, tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' }, diff --git a/src/cli/export.rs b/src/cli/export.rs index 84017f77..ba180d9a 100644 --- a/src/cli/export.rs +++ b/src/cli/export.rs @@ -10,7 +10,7 @@ use tokio::io::{self, AsyncWriteExt}; #[derive(Args, Debug)] pub struct ExportCommandArguments { - #[arg(help = "Path to the sql file to export. Use dash - to write into stdout.")] + #[arg(help = "Path to the SurrealQL file to export. Use dash - to write into stdout.")] #[arg(default_value = "-")] #[arg(index = 1)] file: String, @@ -87,7 +87,7 @@ pub async fn init( } else { client.export(file).await?; } - info!("The SQL file was exported successfully"); + info!("The SurrealQL file was exported successfully"); // Everything OK Ok(()) } diff --git a/src/cli/import.rs b/src/cli/import.rs index 826855eb..a835bfa3 100644 --- a/src/cli/import.rs +++ b/src/cli/import.rs @@ -10,7 +10,7 @@ use surrealdb::opt::Config; #[derive(Args, Debug)] pub struct ImportCommandArguments { - #[arg(help = "Path to the sql file to import")] + #[arg(help = "Path to the SurrealQL file to import")] #[arg(index = 1)] file: String, #[command(flatten)] @@ -75,7 +75,7 @@ pub async fn init( client.use_ns(namespace).use_db(database).await?; // Import the data into the database client.import(file).await?; - info!("The SQL file was imported successfully"); + info!("The SurrealQL file was imported successfully"); // Everything OK Ok(()) } diff --git a/src/cli/ml/export.rs b/src/cli/ml/export.rs new file mode 100644 index 00000000..6dd9f249 --- /dev/null +++ b/src/cli/ml/export.rs @@ -0,0 +1,117 @@ +use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel}; +use crate::cli::abstraction::{ + AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments, +}; +use crate::err::Error; +use clap::Args; +use futures_util::StreamExt; +use surrealdb::engine::any::{connect, IntoEndpoint}; +use tokio::io::{self, AsyncWriteExt}; + +#[derive(Args, Debug)] +pub struct ModelArguments { + #[arg(help = "The name of the model")] + #[arg(env = "SURREAL_NAME", long = "name")] + pub(crate) name: String, + #[arg(help = "The version of the model")] + #[arg(env = "SURREAL_VERSION", long = "version")] + pub(crate) version: String, +} + +#[derive(Args, Debug)] +pub struct ExportCommandArguments { + #[arg(help = "Path to the SurrealML file to export. Use dash - to write into stdout.")] + #[arg(default_value = "-")] + #[arg(index = 1)] + file: String, + + #[command(flatten)] + model: ModelArguments, + #[command(flatten)] + conn: DatabaseConnectionArguments, + #[command(flatten)] + auth: AuthArguments, + #[command(flatten)] + sel: DatabaseSelectionArguments, +} + +pub async fn init( + ExportCommandArguments { + file, + model: ModelArguments { + name, + version, + }, + conn: DatabaseConnectionArguments { + endpoint, + }, + auth: AuthArguments { + username, + password, + auth_level, + }, + sel: DatabaseSelectionArguments { + namespace, + database, + }, + }: ExportCommandArguments, +) -> Result<(), Error> { + // Initialize opentelemetry and logging + crate::telemetry::builder().with_log_level("error").init(); + + // If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate. + // If we are connecting directly to a datastore (i.e. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled. + let client = if username.is_some() + && password.is_some() + && !endpoint.clone().into_endpoint()?.parse_kind()?.is_local() + { + debug!("Connecting to the database engine with authentication"); + let creds = CredentialsBuilder::default() + .with_username(username.as_deref()) + .with_password(password.as_deref()) + .with_namespace(namespace.as_str()) + .with_database(database.as_str()); + + let client = connect(endpoint).await?; + + debug!("Signing in to the database engine at '{:?}' level", auth_level); + match auth_level { + CredentialsLevel::Root => client.signin(creds.root()?).await?, + CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?, + CredentialsLevel::Database => client.signin(creds.database()?).await?, + }; + + client + } else { + debug!("Connecting to the database engine without authentication"); + connect(endpoint).await? + }; + + // Parse model version + let version = match version.parse() { + Ok(version) => version, + Err(_) => { + return Err(Error::Other(format!("`{version}` is not a valid semantic version"))); + } + }; + + // Use the specified namespace / database + client.use_ns(namespace).use_db(database).await?; + // Export the data from the database + debug!("Exporting data from the database"); + if file == "-" { + // Prepare the backup + let mut backup = client.export(()).ml(&name, version).await?; + // Get a handle to standard output + let mut stdout = io::stdout(); + // Write the backup to standard output + while let Some(bytes) = backup.next().await { + stdout.write_all(&bytes?).await?; + } + } else { + client.export(file).ml(&name, version).await?; + } + info!("The SurrealML file was exported successfully"); + // Everything OK + Ok(()) +} diff --git a/src/cli/ml/import.rs b/src/cli/ml/import.rs new file mode 100644 index 00000000..f574e4fa --- /dev/null +++ b/src/cli/ml/import.rs @@ -0,0 +1,81 @@ +use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel}; +use crate::cli::abstraction::{ + AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments, +}; +use crate::err::Error; +use clap::Args; +use surrealdb::dbs::Capabilities; +use surrealdb::engine::any::{connect, IntoEndpoint}; +use surrealdb::opt::Config; + +#[derive(Args, Debug)] +pub struct ImportCommandArguments { + #[arg(help = "Path to the SurrealML file to import")] + #[arg(index = 1)] + file: String, + #[command(flatten)] + conn: DatabaseConnectionArguments, + #[command(flatten)] + auth: AuthArguments, + #[command(flatten)] + sel: DatabaseSelectionArguments, +} + +pub async fn init( + ImportCommandArguments { + file, + conn: DatabaseConnectionArguments { + endpoint, + }, + auth: AuthArguments { + username, + password, + auth_level, + }, + sel: DatabaseSelectionArguments { + namespace, + database, + }, + }: ImportCommandArguments, +) -> Result<(), Error> { + // Initialize opentelemetry and logging + crate::telemetry::builder().with_log_level("info").init(); + // Default datastore configuration for local engines + let config = Config::new().capabilities(Capabilities::all()); + + // If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate. + // If we are connecting directly to a datastore (i.e. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled. + let client = if username.is_some() + && password.is_some() + && !endpoint.clone().into_endpoint()?.parse_kind()?.is_local() + { + debug!("Connecting to the database engine with authentication"); + let creds = CredentialsBuilder::default() + .with_username(username.as_deref()) + .with_password(password.as_deref()) + .with_namespace(namespace.as_str()) + .with_database(database.as_str()); + + let client = connect(endpoint).await?; + + debug!("Signing in to the database engine at '{:?}' level", auth_level); + match auth_level { + CredentialsLevel::Root => client.signin(creds.root()?).await?, + CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?, + CredentialsLevel::Database => client.signin(creds.database()?).await?, + }; + + client + } else { + debug!("Connecting to the database engine without authentication"); + connect((endpoint, config)).await? + }; + + // Use the specified namespace / database + client.use_ns(namespace).use_db(database).await?; + // Import the data into the database + client.import(file).ml().await?; + info!("The SurrealML file was imported successfully"); + // Everything OK + Ok(()) +} diff --git a/src/cli/ml/mod.rs b/src/cli/ml/mod.rs new file mode 100644 index 00000000..57ae2092 --- /dev/null +++ b/src/cli/ml/mod.rs @@ -0,0 +1,22 @@ +mod export; +mod import; + +use self::export::ExportCommandArguments; +use self::import::ImportCommandArguments; +use crate::err::Error; +use clap::Subcommand; + +#[derive(Debug, Subcommand)] +pub enum MlCommand { + #[command(about = "Import a SurrealML model into an existing database")] + Import(ImportCommandArguments), + #[command(about = "Export a SurrealML model from an existing database")] + Export(ExportCommandArguments), +} + +pub async fn init(command: MlCommand) -> Result<(), Error> { + match command { + MlCommand::Import(args) => import::init(args).await, + MlCommand::Export(args) => export::init(args).await, + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 3672bf1a..06885208 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -4,6 +4,7 @@ mod config; mod export; mod import; mod isready; +mod ml; mod sql; #[cfg(feature = "has-storage")] mod start; @@ -20,6 +21,7 @@ pub use config::CF; use export::ExportCommandArguments; use import::ImportCommandArguments; use isready::IsReadyCommandArguments; +use ml::MlCommand; use sql::SqlCommandArguments; #[cfg(feature = "has-storage")] use start::StartCommandArguments; @@ -68,6 +70,8 @@ enum Commands { Upgrade(UpgradeCommandArguments), #[command(about = "Start an SQL REPL in your terminal with pipe support")] Sql(SqlCommandArguments), + #[command(subcommand, about = "Manage SurrealML models within an existing database")] + Ml(MlCommand), #[command( about = "Check if the SurrealDB server is ready to accept connections", visible_alias = "isready" @@ -88,6 +92,7 @@ pub async fn init() -> ExitCode { Commands::Version(args) => version::init(args).await, Commands::Upgrade(args) => upgrade::init(args).await, Commands::Sql(args) => sql::init(args).await, + Commands::Ml(args) => ml::init(args).await, Commands::IsReady(args) => isready::init(args).await, Commands::Validate(args) => validate::init(args).await, }; diff --git a/src/cnf/mod.rs b/src/cnf/mod.rs index d539e4e2..e8fe5459 100644 --- a/src/cnf/mod.rs +++ b/src/cnf/mod.rs @@ -29,31 +29,50 @@ pub const APP_ENDPOINT: &str = "https://surrealdb.com/app"; #[cfg(feature = "has-storage")] pub const WEBSOCKET_PING_FREQUENCY: Duration = Duration::from_secs(5); -/// Set the maximum WebSocket frame size to 16mb +/// What is the maximum WebSocket frame size (defaults to 16 MiB) #[cfg(feature = "has-storage")] pub static WEBSOCKET_MAX_FRAME_SIZE: Lazy = Lazy::new(|| { - let default = 16 << 20; - std::env::var("SURREAL_WEBSOCKET_MAX_FRAME_SIZE") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) + option_env!("SURREAL_WEBSOCKET_MAX_FRAME_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(16 << 20) }); -/// Set the maximum WebSocket frame size to 128mb +/// What is the maximum WebSocket message size (defaults to 128 MiB) #[cfg(feature = "has-storage")] pub static WEBSOCKET_MAX_MESSAGE_SIZE: Lazy = Lazy::new(|| { - let default = 128 << 20; - std::env::var("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE") - .map(|v| v.parse::().unwrap_or(default)) + option_env!("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE") + .and_then(|s| s.parse::().ok()) + .unwrap_or(128 << 20) +}); + +/// How many concurrent tasks can be handled on each WebSocket (defaults to 24) +#[cfg(feature = "has-storage")] +pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy = Lazy::new(|| { + option_env!("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS") + .and_then(|s| s.parse::().ok()) + .unwrap_or(24) +}); + +/// What is the runtime thread memory stack size (defaults to 10MiB) +#[cfg(feature = "has-storage")] +pub static RUNTIME_STACK_SIZE: Lazy = Lazy::new(|| { + // Stack frames are generally larger in debug mode. + let default = if cfg!(debug_assertions) { + 20 * 1024 * 1024 // 20MiB in debug mode + } else { + 10 * 1024 * 1024 // 10MiB in release mode + }; + option_env!("SURREAL_RUNTIME_STACK_SIZE") + .and_then(|s| s.parse::().ok()) .unwrap_or(default) }); -/// How many concurrent tasks can be handled in a WebSocket +/// How many threads which can be started for blocking operations (defaults to 512) #[cfg(feature = "has-storage")] -pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy = Lazy::new(|| { - let default = 24; - std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS") - .map(|v| v.parse::().unwrap_or(default)) - .unwrap_or(default) +pub static RUNTIME_MAX_BLOCKING_THREADS: Lazy = Lazy::new(|| { + option_env!("SURREAL_RUNTIME_MAX_BLOCKING_THREADS") + .and_then(|s| s.parse::().ok()) + .unwrap_or(512) }); /// The version identifier of this build diff --git a/src/err/mod.rs b/src/err/mod.rs index 5bae179e..f8fb94bb 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -1,6 +1,7 @@ use crate::cli::abstraction::auth::Error as SurrealAuthError; use axum::extract::rejection::TypedHeaderRejection; use axum::response::{IntoResponse, Response}; +use axum::Error as AxumError; use axum::Json; use base64::DecodeError as Base64Error; use http::{HeaderName, StatusCode}; @@ -48,6 +49,9 @@ pub enum Error { #[error("Couldn't open the specified file: {0}")] Io(#[from] IoError), + #[error("There was an error with the network: {0}")] + Axum(#[from] AxumError), + #[error("There was an error serializing to JSON: {0}")] Json(#[from] JsonError), @@ -60,11 +64,15 @@ pub enum Error { #[error("There was an error with the remote request: {0}")] Remote(#[from] ReqwestError), + #[error("There was an error with auth: {0}")] + Auth(#[from] SurrealAuthError), + #[error("There was an error with the node agent")] NodeAgent, - #[error("There was an error with auth: {0}")] - Auth(#[from] SurrealAuthError), + /// Statement has been deprecated + #[error("{0}")] + Other(String), } impl From for String { diff --git a/src/main.rs b/src/main.rs index 24364720..de8679da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,15 +42,12 @@ fn main() -> ExitCode { /// Rust's default thread stack size of 2MiB doesn't allow sufficient recursion depth. fn with_enough_stack(fut: impl Future + Send) -> T { - let stack_size = 10 * 1024 * 1024; - - // Stack frames are generally larger in debug mode. - #[cfg(debug_assertions)] - let stack_size = stack_size * 2; - + // Start a Tokio runtime with custom configuration tokio::runtime::Builder::new_multi_thread() .enable_all() - .thread_stack_size(stack_size) + .max_blocking_threads(*cnf::RUNTIME_MAX_BLOCKING_THREADS) + .thread_stack_size(*cnf::RUNTIME_STACK_SIZE) + .thread_name("surrealdb-worker") .build() .unwrap() .block_on(fut) diff --git a/src/net/export.rs b/src/net/export.rs index be3a0849..ed12d6bb 100644 --- a/src/net/export.rs +++ b/src/net/export.rs @@ -9,6 +9,9 @@ use http::StatusCode; use http_body::Body as HttpBody; use hyper::body::Body; use surrealdb::dbs::Session; +use surrealdb::iam::check::check_ns_db; +use surrealdb::iam::Action::View; +use surrealdb::iam::ResourceKind::Any; pub(super) fn router() -> Router where @@ -18,35 +21,27 @@ where Router::new().route("/export", get(handler)) } -async fn handler( - Extension(session): Extension, -) -> Result { +async fn handler(Extension(session): Extension) -> Result { // Get the datastore reference let db = DB.get().unwrap(); - // Extract the NS header value - let nsv = match session.ns.clone() { - Some(ns) => ns, - None => return Err(Error::NoNamespace), - }; - // Extract the DB header value - let dbv = match session.db.clone() { - Some(db) => db, - None => return Err(Error::NoDatabase), - }; // Create a chunked response - let (mut chn, bdy) = Body::channel(); + let (mut chn, body) = Body::channel(); + // Ensure a NS and DB are set + let (nsv, dbv) = check_ns_db(&session)?; + // Check the permissions level + db.check(&session, View, Any.on_db(&nsv, &dbv))?; // Create a new bounded channel let (snd, rcv) = surrealdb::channel::bounded(1); - - let export_job = db.export(&session, nsv, dbv, snd).await.map_err(Error::from)?; + // Start the export task + let task = db.export(&session, snd).await?; // Spawn a new database export job - tokio::spawn(export_job); - // Process all processed values + tokio::spawn(task); + // Process all chunk values tokio::spawn(async move { while let Ok(v) = rcv.recv().await { let _ = chn.send_data(Bytes::from(v)).await; } }); // Return the chunked body - Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap()) + Ok(Response::builder().status(StatusCode::OK).body(body).unwrap()) } diff --git a/src/net/import.rs b/src/net/import.rs index f6d0c406..1902b534 100644 --- a/src/net/import.rs +++ b/src/net/import.rs @@ -1,3 +1,4 @@ +use super::headers::Accept; use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; @@ -11,10 +12,10 @@ use axum::TypedHeader; use bytes::Bytes; use http_body::Body as HttpBody; use surrealdb::dbs::Session; +use surrealdb::iam::Action::Edit; +use surrealdb::iam::ResourceKind::Any; use tower_http::limit::RequestBodyLimitLayer; -use super::headers::Accept; - const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB pub(super) fn router() -> Router @@ -32,24 +33,26 @@ where async fn handler( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, sql: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); // Convert the body to a byte slice let sql = bytes_to_utf8(&sql)?; + // Check the permissions level + db.check(&session, Edit, Any.on_level(session.au.level().to_owned()))?; // Execute the sql query in the database match db.import(sql, &session).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), - // Internal serialization - Some(Accept::Surrealdb) => Ok(output::full(&res)), // Return nothing Some(Accept::ApplicationOctetStream) => Ok(output::none()), + // Internal serialization + Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested _ => Err(Error::InvalidType), }, diff --git a/src/net/key.rs b/src/net/key.rs index b74aee43..918416c0 100644 --- a/src/net/key.rs +++ b/src/net/key.rs @@ -13,6 +13,7 @@ use http_body::Body as HttpBody; use serde::Deserialize; use std::str; use surrealdb::dbs::Session; +use surrealdb::iam::check::check_ns_db; use surrealdb::sql::Value; use tower_http::limit::RequestBodyLimitLayer; @@ -68,12 +69,14 @@ where async fn select_all( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path(table): Path, Query(query): Query, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Specify the request statement let sql = match query.fields { None => "SELECT * FROM type::table($table) LIMIT $limit START $start", @@ -88,7 +91,7 @@ async fn select_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -105,13 +108,15 @@ async fn select_all( async fn create_all( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path(table): Path, Query(params): Query, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the request body as JSON @@ -127,7 +132,7 @@ async fn create_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -147,13 +152,15 @@ async fn create_all( async fn update_all( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path(table): Path, Query(params): Query, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the request body as JSON @@ -169,7 +176,7 @@ async fn update_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -189,13 +196,15 @@ async fn update_all( async fn modify_all( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path(table): Path, Query(params): Query, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the request body as JSON @@ -211,7 +220,7 @@ async fn modify_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -231,12 +240,14 @@ async fn modify_all( async fn delete_all( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path(table): Path, Query(params): Query, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Specify the request statement let sql = "DELETE type::table($table) RETURN BEFORE"; // Specify the request variables @@ -246,7 +257,7 @@ async fn delete_all( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -267,12 +278,14 @@ async fn delete_all( async fn select_one( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path((table, id)): Path<(String, String)>, Query(query): Query, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Specify the request statement let sql = match query.fields { None => "SELECT * FROM type::thing($table, $id)", @@ -291,7 +304,7 @@ async fn select_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -308,13 +321,15 @@ async fn select_one( async fn create_one( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Query(params): Query, Path((table, id)): Path<(String, String)>, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the Record ID as a SurrealQL value @@ -336,7 +351,7 @@ async fn create_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -356,13 +371,15 @@ async fn create_one( async fn update_one( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Query(params): Query, Path((table, id)): Path<(String, String)>, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the Record ID as a SurrealQL value @@ -384,7 +401,7 @@ async fn update_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -404,13 +421,15 @@ async fn update_one( async fn modify_one( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Query(params): Query, Path((table, id)): Path<(String, String)>, body: Bytes, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Convert the HTTP request body let data = bytes_to_utf8(&body)?; // Parse the Record ID as a SurrealQL value @@ -432,7 +451,7 @@ async fn modify_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), @@ -452,11 +471,13 @@ async fn modify_one( async fn delete_one( Extension(session): Extension, - maybe_output: Option>, + accept: Option>, Path((table, id)): Path<(String, String)>, ) -> Result { // Get the datastore reference let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let _ = check_ns_db(&session)?; // Specify the request statement let sql = "DELETE type::thing($table, $id) RETURN BEFORE"; // Parse the Record ID as a SurrealQL value @@ -471,7 +492,7 @@ async fn delete_one( }; // Execute the query and return the result match db.execute(sql, &session, Some(vars)).await { - Ok(res) => match maybe_output.as_deref() { + Ok(res) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), diff --git a/src/net/ml.rs b/src/net/ml.rs new file mode 100644 index 00000000..1a294212 --- /dev/null +++ b/src/net/ml.rs @@ -0,0 +1,123 @@ +//! This file defines the endpoints for the ML API for importing and exporting SurrealML models. +use crate::dbs::DB; +use crate::err::Error; +use crate::net::output; +use axum::extract::{BodyStream, DefaultBodyLimit, Path}; +use axum::response::IntoResponse; +use axum::response::Response; +use axum::routing::{get, post}; +use axum::Extension; +use axum::Router; +use bytes::Bytes; +use futures_util::StreamExt; +use http::StatusCode; +use http_body::Body as HttpBody; +use hyper::body::Body; +use surrealdb::dbs::Session; +use surrealdb::iam::check::check_ns_db; +use surrealdb::iam::Action::{Edit, View}; +use surrealdb::iam::ResourceKind::Model; +use surrealdb::kvs::{LockType::Optimistic, TransactionType::Read}; +use surrealdb::sql::statements::{DefineModelStatement, DefineStatement}; +use surrealml_core::storage::surml_file::SurMlFile; +use tower_http::limit::RequestBodyLimitLayer; + +const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB + +/// The router definition for the ML API endpoints. +pub(super) fn router() -> Router +where + B: HttpBody + Send + 'static, + B::Data: Send + Into, + B::Error: std::error::Error + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, +{ + Router::new() + .route("/ml/import", post(import)) + .route("/ml/export/:name/:version", get(export)) + .route_layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(MAX)) +} + +/// This endpoint allows the user to import a model into the database. +async fn import( + Extension(session): Extension, + mut stream: BodyStream, +) -> Result { + // Get the datastore reference + let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let (nsv, dbv) = check_ns_db(&session)?; + // Check the permissions level + db.check(&session, Edit, Model.on_db(&nsv, &dbv))?; + // Create a new buffer + let mut buffer = Vec::new(); + // Load all the uploaded file chunks + while let Some(chunk) = stream.next().await { + buffer.extend_from_slice(&chunk?); + } + // Check that the SurrealML file is valid + let file = match SurMlFile::from_bytes(buffer) { + Ok(file) => file, + Err(err) => return Err(Error::Other(err.to_string())), + }; + // Convert the file back in to raw bytes + let data = file.to_bytes(); + // Calculate the hash of the model file + let hash = surrealdb::obs::hash(&data); + // Calculate the path of the model file + let path = format!( + "ml/{nsv}/{dbv}/{}-{}-{hash}.surml", + file.header.name.to_string(), + file.header.version.to_string() + ); + // Insert the file data in to the store + surrealdb::obs::put(&path, data).await?; + // Insert the model in to the database + db.process( + DefineStatement::Model(DefineModelStatement { + hash, + name: file.header.name.to_string().into(), + version: file.header.version.to_string(), + comment: Some(file.header.description.to_string().into()), + ..Default::default() + }) + .into(), + &session, + None, + ) + .await?; + // + Ok(output::none()) +} + +/// This endpoint allows the user to export a model from the database. +async fn export( + Extension(session): Extension, + Path((name, version)): Path<(String, String)>, +) -> Result { + // Get the datastore reference + let db = DB.get().unwrap(); + // Ensure a NS and DB are set + let (nsv, dbv) = check_ns_db(&session)?; + // Check the permissions level + db.check(&session, View, Model.on_db(&nsv, &dbv))?; + // Start a new readonly transaction + let mut tx = db.transaction(Read, Optimistic).await?; + // Attempt to get the model definition + let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?; + // Calculate the path of the model file + let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{}.surml", info.hash); + // Export the file data in to the store + let mut data = surrealdb::obs::stream(path).await?; + // Create a chunked response + let (mut chn, body) = Body::channel(); + // Process all stream values + tokio::spawn(async move { + while let Some(Ok(v)) = data.next().await { + let _ = chn.send_data(v).await; + } + }); + // Return the streamed body + Ok(Response::builder().status(StatusCode::OK).body(body).unwrap()) +} diff --git a/src/net/mod.rs b/src/net/mod.rs index e62f325d..50583c57 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -17,6 +17,9 @@ mod sync; mod tracer; mod version; +#[cfg(feature = "ml")] +mod ml; + use axum::response::Redirect; use axum::routing::get; use axum::{middleware, Router}; @@ -150,8 +153,12 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> { .merge(sql::router()) .merge(signin::router()) .merge(signup::router()) - .merge(key::router()) - .layer(service); + .merge(key::router()); + + #[cfg(feature = "ml")] + let axum_app = axum_app.merge(ml::router()); + + let axum_app = axum_app.layer(service); // Setup the graceful shutdown let handle = Handle::new(); diff --git a/src/net/signin.rs b/src/net/signin.rs index 28ce4649..3129b1e5 100644 --- a/src/net/signin.rs +++ b/src/net/signin.rs @@ -51,7 +51,7 @@ where async fn handler( Extension(mut session): Extension, - maybe_output: Option>, + accept: Option>, body: Bytes, ) -> Result { // Get a database reference @@ -65,15 +65,15 @@ async fn handler( match surrealdb::iam::signin::signin(kvs, &mut session, vars).await.map_err(Error::from) { // Authentication was successful - Ok(v) => match maybe_output.as_deref() { + Ok(v) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))), Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))), - // Internal serialization - Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))), // Text serialization Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())), + // Internal serialization + Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))), // Return nothing None => Ok(output::none()), // An incorrect content-type was requested diff --git a/src/net/signup.rs b/src/net/signup.rs index ae992263..a1a07410 100644 --- a/src/net/signup.rs +++ b/src/net/signup.rs @@ -49,7 +49,7 @@ where async fn handler( Extension(mut session): Extension, - maybe_output: Option>, + accept: Option>, body: Bytes, ) -> Result { // Get a database reference @@ -63,15 +63,15 @@ async fn handler( match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from) { // Authentication was successful - Ok(v) => match maybe_output.as_deref() { + Ok(v) => match accept.as_deref() { // Simple serialization Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))), Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))), - // Internal serialization - Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))), // Text serialization Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())), + // Internal serialization + Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))), // Return nothing None => Ok(output::none()), // An incorrect content-type was requested diff --git a/tests/linear_test.surml b/tests/linear_test.surml new file mode 100644 index 0000000000000000000000000000000000000000..e9deb2851e2f2516ad2967e166ad75776e398762 GIT binary patch literal 552 zcmZ`!O-{l<7^M}Z4iQN=B&;T^5J@|&mI|gd5?3T7E?l{3rtM&BN~dm8kuHJvFw0mj5tXedSt^o8<_~9p@s~@K+5>MWZtI44?P~z zxX>X5DOxey1tEh=>XQg26@x(@ko=KbLb`w$mVZDOGvY%U!;mvrHG@Sp{6HX%eVE5} zNm?kXno8lGJad|QK`~l(L_Hi8V{xR0+*qUCeLYWuCSL-$0m2HjXdJI5lR4(|EOJJ2 zM~%Xm;8b02XvUL_JE+_G3-k^u0=NXtN};$RX5c@)EEwJ9kHxj`_xAEjJ}eVEk0Bp) zB%v)#x4mZd{?W(V07_yB?IfK^r>~#(J8WLvALm7603=C}>p+x5H(wZA3t=j@Ti>XJ BqqP74 literal 0 HcmV?d00001 diff --git a/tests/ml_integration.rs b/tests/ml_integration.rs new file mode 100644 index 00000000..631b43a5 --- /dev/null +++ b/tests/ml_integration.rs @@ -0,0 +1,149 @@ +// RUST_LOG=warn cargo make ci-ml-integration +mod common; + +#[cfg(feature = "ml")] +mod ml_integration { + + use super::*; + use http::{header, StatusCode}; + use hyper::Body; + use serde::{Deserialize, Serialize}; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::time::Duration; + use surrealml_core::storage::stream_adapter::StreamAdapter; + use test_log::test; + use ulid::Ulid; + + static LOCK: AtomicBool = AtomicBool::new(false); + + #[derive(Serialize, Deserialize, Debug)] + struct Data { + result: f64, + status: String, + time: String, + } + + struct LockHandle; + + impl LockHandle { + fn acquire_lock() -> Self { + while LOCK.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + != Ok(false) + { + std::thread::sleep(Duration::from_millis(100)); + } + LockHandle + } + } + + impl Drop for LockHandle { + fn drop(&mut self) { + LOCK.store(false, Ordering::Release); + } + } + + async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box> { + let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()); + let body = Body::wrap_stream(generator); + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .default_headers(headers) + .build()?; + // Send HTTP request + let res = client + .post(format!("http://{addr}/ml/import")) + .basic_auth(common::USER, Some(common::PASS)) + .body(body) + .send() + .await?; + // Check response code + assert_eq!(res.status(), StatusCode::OK); + Ok(()) + } + + #[test(tokio::test)] + async fn upload_model() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + upload_file(&addr, &ns, &db).await?; + Ok(()) + } + + #[test(tokio::test)] + async fn raw_compute() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + + upload_file(&addr, &ns, &db).await?; + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // perform an SQL query to check if the model is available + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(common::USER, Some(common::PASS)) + .body(r#"ml::Prediction<0.0.1>([1.0, 1.0]);"#) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + let body = res.text().await?; + let deserialized_data: Vec = serde_json::from_str(&body)?; + assert_eq!(deserialized_data[0].result, 0.9998061656951904); + } + Ok(()) + } + + #[test(tokio::test)] + async fn buffered_compute() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + + upload_file(&addr, &ns, &db).await?; + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + headers.insert(header::ACCEPT, "application/json".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // perform an SQL query to check if the model is available + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(common::USER, Some(common::PASS)) + .body(r#"ml::Prediction<0.0.1>({squarefoot: 500.0, num_floors: 1.0});"#) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + let body = res.text().await?; + let deserialized_data: Vec = serde_json::from_str(&body)?; + assert_eq!(deserialized_data[0].result, 177206.21875); + } + Ok(()) + } +}