Add support for ML model storage and execution (#3015)

Maxwell Flitton 2023-12-12 13:51:43 +00:00 committed by GitHub
parent fc66e2f4ea
commit 2ae8416791
64 changed files with 1815 additions and 320 deletions


@ -199,6 +199,34 @@ jobs:
- name: Run HTTP integration tests
run: cargo make ci-http-integration
ml-support:
name: ML integration tests
runs-on: ubuntu-latest
steps:
- name: Install stable toolchain
uses: dtolnay/rust-toolchain@stable
with:
toolchain: 1.71.1
- name: Checkout sources
uses: actions/checkout@v3
- name: Setup cache
uses: Swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
- name: Install dependencies
run: |
sudo apt-get -y update
- name: Install cargo-make
run: cargo install --debug --locked cargo-make
- name: Run ML integration tests
run: cargo make ci-ml-integration
ws-server:
name: WebSocket integration tests
runs-on: ubuntu-latest

.gitignore vendored (2 changes)

@ -45,5 +45,7 @@ Temporary Items
# Specific
# -----------------------------------
/cache/
/store/
surreal
history.txt

Cargo.lock generated (187 changes)

@ -1813,6 +1813,18 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "filetime"
version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4029edd3e734da6fe05b6cd7bd2960760a616bd2ddd0d59a0124746d6272af0"
dependencies = [
"cfg-if",
"libc",
"redox_syscall 0.3.5",
"windows-sys 0.48.0",
]
[[package]]
name = "findshlibs"
version = "0.10.2"
@ -2917,7 +2929,7 @@ checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8"
dependencies = [
"bitflags 2.4.1",
"libc",
"redox_syscall",
"redox_syscall 0.4.1",
]
[[package]]
@ -3036,6 +3048,16 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "md-5"
version = "0.10.6"
@ -3181,6 +3203,19 @@ dependencies = [
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "new_debug_unreachable"
version = "1.0.4"
@ -3272,6 +3307,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-complex"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
dependencies = [
"num-traits",
]
[[package]]
name = "num-format"
version = "0.4.4"
@ -3332,6 +3376,27 @@ dependencies = [
"memchr",
]
[[package]]
name = "object_store"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050"
dependencies = [
"async-trait",
"bytes",
"chrono",
"futures",
"humantime",
"itertools 0.11.0",
"parking_lot",
"percent-encoding",
"snafu",
"tokio",
"tracing",
"url",
"walkdir",
]
[[package]]
name = "once_cell"
version = "1.18.0"
@ -3467,6 +3532,24 @@ dependencies = [
"tokio-stream",
]
[[package]]
name = "ort"
version = "1.16.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "889dca4c98efa21b1ba54ddb2bde44fd4920d910f492b618351f839d8428d79d"
dependencies = [
"flate2",
"lazy_static",
"libc",
"libloading",
"ndarray",
"tar",
"thiserror",
"tracing",
"vswhom",
"winapi",
]
[[package]]
name = "overload"
version = "0.1.1"
@ -3497,7 +3580,7 @@ checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"redox_syscall 0.4.1",
"smallvec",
"windows-targets 0.48.5",
]
@ -4017,6 +4100,12 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.8.0"
@ -4049,6 +4138,15 @@ dependencies = [
"yasna",
]
[[package]]
name = "redox_syscall"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "redox_syscall"
version = "0.4.1"
@ -4950,6 +5048,28 @@ dependencies = [
"serde",
]
[[package]]
name = "snafu"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6"
dependencies = [
"doc-comment",
"snafu-derive",
]
[[package]]
name = "snafu-derive"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "snap"
version = "1.1.0"
@ -5120,6 +5240,7 @@ dependencies = [
"ipnet",
"jemallocator",
"mimalloc",
"ndarray",
"nix 0.27.1",
"once_cell",
"opentelemetry",
@ -5136,6 +5257,7 @@ dependencies = [
"serde_json",
"serial_test",
"surrealdb",
"surrealml-core",
"temp-env",
"tempfile",
"test-log",
@ -5184,6 +5306,7 @@ dependencies = [
"futures-concurrency",
"fuzzy-matcher",
"geo 0.27.0",
"hex",
"indexmap 2.1.0",
"indxdb",
"ipnet",
@ -5192,8 +5315,10 @@ dependencies = [
"md-5",
"nanoid",
"native-tls",
"ndarray",
"nom",
"num_cpus",
"object_store",
"once_cell",
"path-clean",
"pbkdf2",
@ -5224,6 +5349,7 @@ dependencies = [
"surrealdb-derive",
"surrealdb-jsonwebtoken",
"surrealdb-tikv-client",
"surrealml-core",
"temp-dir",
"test-log",
"thiserror",
@ -5299,6 +5425,21 @@ dependencies = [
"tonic 0.9.2",
]
[[package]]
name = "surrealml-core"
version = "0.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2aabe3c44a73d6f7a3d069400a4966d2ffcf643fb1d60399b1746787abc9dd12"
dependencies = [
"bytes",
"futures-core",
"futures-util",
"ndarray",
"once_cell",
"ort",
"regex",
]
[[package]]
name = "symbolic-common"
version = "12.7.0"
@ -5389,6 +5530,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tar"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
dependencies = [
"filetime",
"libc",
"xattr",
]
[[package]]
name = "temp-dir"
version = "0.1.11"
@ -5413,7 +5565,7 @@ checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5"
dependencies = [
"cfg-if",
"fastrand 2.0.1",
"redox_syscall",
"redox_syscall 0.4.1",
"rustix",
"windows-sys 0.48.0",
]
@ -6103,6 +6255,26 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "vswhom"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b"
dependencies = [
"libc",
"vswhom-sys",
]
[[package]]
name = "vswhom-sys"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "waker-fn"
version = "1.1.1"
@ -6501,6 +6673,15 @@ dependencies = [
"tap",
]
[[package]]
name = "xattr"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985"
dependencies = [
"libc",
]
[[package]]
name = "xml-rs"
version = "0.8.19"


@ -8,7 +8,7 @@ authors = ["Tobie Morgan Hitchcock <tobie@surrealdb.com>"]
[features]
# Public features
default = ["storage-mem", "storage-rocksdb", "scripting", "http"]
default = ["storage-mem", "storage-rocksdb", "scripting", "http", "ml"]
storage-mem = ["surrealdb/kv-mem", "has-storage"]
storage-rocksdb = ["surrealdb/kv-rocksdb", "has-storage"]
storage-speedb = ["surrealdb/kv-speedb", "has-storage"]
@ -17,6 +17,7 @@ storage-fdb = ["surrealdb/kv-fdb-7_1", "has-storage"]
scripting = ["surrealdb/scripting"]
http = ["surrealdb/http"]
http-compression = []
ml = ["surrealdb/ml", "surrealml-core"]
# Private features
has-storage = []
@ -49,6 +50,7 @@ http = "0.2.11"
http-body = "0.4.5"
hyper = "0.14.27"
ipnet = "2.9.0"
ndarray = { version = "0.15.6", optional = true }
once_cell = "1.18.0"
opentelemetry = { version = "0.19", features = ["rt-tokio"] }
opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
@ -61,6 +63,7 @@ serde_cbor = "0.11.2"
serde_json = "1.0.108"
serde_pack = { version = "1.1.2", package = "rmp-serde" }
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
surrealml-core = { version = "0.0.2", optional = true}
tempfile = "3.8.1"
thiserror = "1.0.50"
tokio = { version = "1.34.0", features = ["macros", "signal"] }


@ -39,6 +39,12 @@ command = "cargo"
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
[tasks.ci-ml-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
[tasks.ci-workspace-coverage]
category = "CI - INTEGRATION TESTS"
command = "cargo"


@ -10,7 +10,7 @@ reduce_output = true
default_to_workspace = false
[env]
DEV_FEATURES={ value = "storage-mem,http,scripting", condition = { env_not_set = ["DEV_FEATURES"] } }
DEV_FEATURES={ value = "storage-mem,scripting,http,ml", condition = { env_not_set = ["DEV_FEATURES"] } }
SURREAL_LOG={ value = "trace", condition = { env_not_set = ["SURREAL_LOG"] } }
SURREAL_USER={ value = "root", condition = { env_not_set = ["SURREAL_USER"] } }
SURREAL_PASS={ value = "root", condition = { env_not_set = ["SURREAL_PASS"] } }


@ -37,6 +37,7 @@ scripting = ["dep:js"]
http = ["dep:reqwest"]
native-tls = ["dep:native-tls", "reqwest?/native-tls", "tokio-tungstenite?/native-tls"]
rustls = ["dep:rustls", "reqwest?/rustls-tls", "tokio-tungstenite?/rustls-tls-webpki-roots"]
ml = ["surrealml-core", "ndarray"]
# Private features
kv-fdb = ["foundationdb", "tokio/time"]
@ -74,6 +75,7 @@ futures = "0.3.29"
futures-concurrency = "7.4.3"
fuzzy-matcher = "0.3.7"
geo = { version = "0.27.0", features = ["use-serde"] }
hex = { version = "0.4.3", optional = false }
indexmap = { version = "2.1.0", features = ["serde"] }
indxdb = { version = "0.4.0", optional = true }
ipnet = "2.9.0"
@ -84,8 +86,10 @@ lru = "0.12.1"
md-5 = "0.10.6"
nanoid = "0.4.0"
native-tls = { version = "0.2.11", optional = true }
ndarray = { version = "0.15.6", optional = true }
nom = { version = "7.1.3", features = ["alloc"] }
num_cpus = "1.16.0"
object_store = { version = "0.8.0", optional = false }
once_cell = "1.18.0"
path-clean = "1.0.1"
pbkdf2 = { version = "0.12.2", features = ["simple"] }
@ -109,6 +113,7 @@ sha2 = "0.10.8"
snap = "1.1.0"
speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true }
storekey = "0.5.0"
surrealml-core = { version = "0.0.2", optional = true }
thiserror = "1.0.50"
tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true }
tokio-util = { version = "0.7.10", optional = true, features = ["compat"] }


@ -113,6 +113,16 @@ pub enum DbResponse {
Other(Value),
}
#[derive(Debug)]
#[allow(dead_code)] // used by ML model import and export functions
pub(crate) enum MlConfig {
Import,
Export {
name: String,
version: String,
},
}
/// Holds the parameters given to the caller
#[derive(Debug, Default)]
#[allow(dead_code)] // used by the embedded and remote connections
@ -122,6 +132,7 @@ pub struct Param {
pub(crate) file: Option<PathBuf>,
pub(crate) bytes_sender: Option<channel::Sender<Result<Vec<u8>>>>,
pub(crate) notification_sender: Option<channel::Sender<Notification>>,
pub(crate) ml_config: Option<MlConfig>,
}
impl Param {


@ -28,6 +28,8 @@ pub(crate) mod wasm;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
#[cfg(not(target_arch = "wasm32"))]
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::engine::create_statement;
use crate::api::engine::delete_statement;
@ -44,9 +46,27 @@ use crate::api::Surreal;
use crate::dbs::Notification;
use crate::dbs::Response;
use crate::dbs::Session;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::check::check_ns_db;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::Action;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::ResourceKind;
use crate::kvs::Datastore;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::kvs::{LockType, TransactionType};
use crate::method::Stats;
use crate::opt::IntoEndpoint;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineModelStatement;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineStatement;
use crate::sql::statements::KillStatement;
use crate::sql::Array;
use crate::sql::Query;
@ -56,6 +76,9 @@ use crate::sql::Strand;
use crate::sql::Uuid;
use crate::sql::Value;
use channel::Sender;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use futures::StreamExt;
use indexmap::IndexMap;
use std::collections::BTreeMap;
use std::collections::HashMap;
@ -65,6 +88,9 @@ use std::mem;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use surrealml_core::storage::surml_file::SurMlFile;
#[cfg(not(target_arch = "wasm32"))]
use tokio::fs::OpenOptions;
#[cfg(not(target_arch = "wasm32"))]
@ -405,11 +431,34 @@ async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
async fn export(
kvs: &Datastore,
sess: &Session,
ns: String,
db: String,
chn: channel::Sender<Vec<u8>>,
ml_config: Option<MlConfig>,
) -> Result<()> {
if let Err(error) = kvs.export(sess, ns, db, chn).await?.await {
match ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Export {
name,
version,
}) => {
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(sess)?;
// Check the permissions level
kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Start a new readonly transaction
let mut tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?;
// Attempt to get the model definition
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
// Stream the model file data from the store
let mut data = crate::obs::stream(info.hash.to_owned()).await?;
// Process all stream values
while let Some(Ok(bytes)) = data.next().await {
if chn.send(bytes.to_vec()).await.is_err() {
break;
}
}
}
_ => {
if let Err(error) = kvs.export(sess, chn).await?.await {
if let crate::error::Db::Channel(message) = error {
// This is not really an error. Just logging it for improved visibility.
trace!("{message}");
@ -417,6 +466,8 @@ async fn export(
}
return Err(error.into());
}
}
}
Ok(())
}
@ -563,8 +614,6 @@ async fn router(
Method::Export | Method::Import => unreachable!(),
#[cfg(not(target_arch = "wasm32"))]
Method::Export => {
let ns = session.ns.clone().unwrap_or_default();
let db = session.db.clone().unwrap_or_default();
let (tx, rx) = crate::channel::bounded(1);
match (param.file, param.bytes_sender) {
@ -572,7 +621,7 @@ async fn router(
let (mut writer, mut reader) = io::duplex(10_240);
// Write to channel.
let export = export(kvs, session, ns, db, tx);
let export = export(kvs, session, tx, param.ml_config);
// Read from channel and write to pipe.
let bridge = async move {
@ -613,7 +662,7 @@ async fn router(
let session = session.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export(&kvs, &session, ns, db, tx).await {
if let Err(error) = export(&kvs, &session, tx, param.ml_config).await {
let _ = backup.send(Err(error)).await;
}
};
@ -647,6 +696,52 @@ async fn router(
.into());
}
};
let responses = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Import) => {
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(session)?;
// Check the permissions level
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(Error::FileRead {
path,
error,
}
.into());
}
};
// Convert the file back into raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = crate::obs::hash(&data);
// Insert the file data into the store
crate::obs::put(&hash, data).await?;
// Insert the model into the database
let query = DefineStatement::Model(DefineModelStatement {
hash,
name: file.header.name.to_string().into(),
version: file.header.version.to_string(),
comment: Some(file.header.description.to_string().into()),
..Default::default()
})
.into();
kvs.process(query, session, Some(vars.clone())).await?
}
_ => {
let mut statements = String::new();
if let Err(error) = file.read_to_string(&mut statements).await {
return Err(Error::FileRead {
@ -655,7 +750,9 @@ async fn router(
}
.into());
}
let responses = kvs.execute(&statements, &*session, Some(vars.clone())).await?;
kvs.execute(&statements, &*session, Some(vars.clone())).await?
}
};
for response in responses {
response.result?;
}


@ -7,6 +7,9 @@ pub(crate) mod wasm;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::engine::create_statement;
use crate::api::engine::delete_statement;
@ -516,7 +519,14 @@ async fn router(
Method::Export | Method::Import => unreachable!(),
#[cfg(not(target_arch = "wasm32"))]
Method::Export => {
let path = base_url.join(Method::Export.as_str())?;
let path = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Export {
name,
version,
}) => base_url.join(&format!("ml/export/{name}/{version}"))?,
_ => base_url.join(Method::Export.as_str())?,
};
let request = client
.get(path)
.headers(headers.clone())
@ -527,7 +537,11 @@ async fn router(
}
#[cfg(not(target_arch = "wasm32"))]
Method::Import => {
let path = base_url.join(Method::Import.as_str())?;
let path = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Import) => base_url.join("ml/import")?,
_ => base_url.join(Method::Import.as_str())?,
};
let file = param.file.expect("file to import from");
let request = client
.post(path)

View file

@ -1,15 +1,18 @@
use crate::api::conn::Method;
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::Connection;
use crate::api::Error;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::method::Model;
use crate::method::OnceLockExt;
use crate::opt::ExportDestination;
use crate::Surreal;
use channel::Receiver;
use futures::Stream;
use futures::StreamExt;
use semver::Version;
use std::borrow::Cow;
use std::future::Future;
use std::future::IntoFuture;
@ -22,18 +25,39 @@ use std::task::Poll;
/// A database export future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Export<'r, C: Connection, R> {
pub struct Export<'r, C: Connection, R, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) target: ExportDestination,
pub(super) ml_config: Option<MlConfig>,
pub(super) response: PhantomData<R>,
pub(super) export_type: PhantomData<T>,
}
impl<C, R> Export<'_, C, R>
impl<'r, C, R> Export<'r, C, R>
where
C: Connection,
{
/// Export machine learning model
pub fn ml(self, name: &str, version: Version) -> Export<'r, C, R, Model> {
Export {
client: self.client,
target: self.target,
ml_config: Some(MlConfig::Export {
name: name.to_owned(),
version: version.to_string(),
}),
response: self.response,
export_type: PhantomData,
}
}
}
impl<C, R, T> Export<'_, C, R, T>
where
C: Connection,
{
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Export<'static, C, R> {
pub fn into_owned(self) -> Export<'static, C, R, T> {
Export {
client: Cow::Owned(self.client.into_owned()),
..self
@ -41,7 +65,7 @@ where
}
}
impl<'r, Client> IntoFuture for Export<'r, Client, PathBuf>
impl<'r, Client, T> IntoFuture for Export<'r, Client, PathBuf, T>
where
Client: Connection,
{
@ -55,15 +79,17 @@ where
return Err(Error::BackupsNotSupported.into());
}
let mut conn = Client::new(Method::Export);
match self.target {
ExportDestination::File(path) => conn.execute_unit(router, Param::file(path)).await,
let mut param = match self.target {
ExportDestination::File(path) => Param::file(path),
ExportDestination::Memory => unreachable!(),
}
};
param.ml_config = self.ml_config;
conn.execute_unit(router, param).await
})
}
}
impl<'r, Client> IntoFuture for Export<'r, Client, ()>
impl<'r, Client, T> IntoFuture for Export<'r, Client, (), T>
where
Client: Connection,
{
@ -81,7 +107,9 @@ where
let ExportDestination::Memory = self.target else {
unreachable!();
};
conn.execute_unit(router, Param::bytes_sender(tx)).await?;
let mut param = Param::bytes_sender(tx);
param.ml_config = self.ml_config;
conn.execute_unit(router, param).await?;
Ok(Backup {
rx,
})

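For orientation, a minimal usage sketch of the export API added above, assuming a connected client with the `ml` feature enabled and an already-defined model; the model name and file path are illustrative:

use semver::Version;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;

// Minimal sketch: stream a stored model version to a local SurrealML file.
async fn export_model(db: &Surreal<Client>) -> surrealdb::Result<()> {
    db.export("linear.surml").ml("linear", Version::new(1, 0, 0)).await?;
    Ok(())
}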

@ -1,31 +1,51 @@
use crate::api::conn::Method;
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::Connection;
use crate::api::Error;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::method::Model;
use crate::method::OnceLockExt;
use crate::Surreal;
use std::borrow::Cow;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::pin::Pin;
/// A database import future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Import<'r, C: Connection> {
pub struct Import<'r, C: Connection, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) file: PathBuf,
pub(super) ml_config: Option<MlConfig>,
pub(super) import_type: PhantomData<T>,
}
impl<C> Import<'_, C>
impl<'r, C> Import<'r, C>
where
C: Connection,
{
/// Import machine learning model
pub fn ml(self) -> Import<'r, C, Model> {
Import {
client: self.client,
file: self.file,
ml_config: Some(MlConfig::Import),
import_type: PhantomData,
}
}
}
impl<'r, C, T> Import<'r, C, T>
where
C: Connection,
{
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Import<'static, C> {
pub fn into_owned(self) -> Import<'static, C, T> {
Import {
client: Cow::Owned(self.client.into_owned()),
..self
@ -33,7 +53,7 @@ where
}
}
impl<'r, Client> IntoFuture for Import<'r, Client>
impl<'r, Client, T> IntoFuture for Import<'r, Client, T>
where
Client: Connection,
{
@ -47,7 +67,9 @@ where
return Err(Error::BackupsNotSupported.into());
}
let mut conn = Client::new(Method::Import);
conn.execute_unit(router, Param::file(self.file)).await
let mut param = Param::file(self.file);
param.ml_config = self.ml_config;
conn.execute_unit(router, param).await
})
}
}

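The corresponding import side, sketched under the same assumptions (the `.surml` path is illustrative); the server validates the upload with SurMlFile::from_bytes, stores the bytes, and defines the model:

use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;

// Minimal sketch: upload a SurrealML file to the connected datastore.
async fn import_model(db: &Surreal<Client>) -> surrealdb::Result<()> {
    db.import("linear.surml").ml().await?;
    Ok(())
}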

@ -90,6 +90,9 @@ pub struct Stats {
pub execution_time: Duration,
}
/// Machine learning model marker type for the import and export futures
pub struct Model;
/// Responses returned with statistics
#[derive(Debug)]
pub struct WithStats<T>(T);
@ -1004,7 +1007,9 @@ where
Export {
client: Cow::Borrowed(self),
target: target.into_export_destination(),
ml_config: None,
response: PhantomData,
export_type: PhantomData,
}
}
@ -1034,6 +1039,8 @@ where
Import {
client: Cow::Borrowed(self),
file: file.as_ref().to_owned(),
ml_config: None,
import_type: PhantomData,
}
}
}


@ -409,7 +409,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
},
Value::Cast(cast) => json!(cast),
Value::Function(function) => json!(function),
Value::MlModel(model) => json!(model),
Value::Model(model) => json!(model),
Value::Query(query) => json!(query),
Value::Subquery(subquery) => json!(subquery),
Value::Expression(expression) => json!(expression),


@ -38,8 +38,7 @@ pub const PROCESSOR_BATCH_SIZE: u32 = 50;
/// Forward all signup/signin query errors to a client trying to authenticate to a scope. Do not use in production.
pub static INSECURE_FORWARD_SCOPE_ERRORS: Lazy<bool> = Lazy::new(|| {
let default = false;
std::env::var("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS")
.map(|v| v.parse::<bool>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_INSECURE_FORWARD_SCOPE_ERRORS")
.and_then(|s| s.parse::<bool>().ok())
.unwrap_or(false)
});


@ -12,6 +12,7 @@ use base64_lib::DecodeError as Base64Error;
use bincode::Error as BincodeError;
use fst::Error as FstError;
use jsonwebtoken::errors::Error as JWTError;
use object_store::Error as ObjectStoreError;
use revision::Error as RevisionError;
use serde::Serialize;
use std::io::Error as IoError;
@ -194,6 +195,12 @@ pub enum Error {
message: String,
},
/// There was an error with the provided machine learning model
#[error("Problem with machine learning computation. {message}")]
InvalidModel {
message: String,
},
/// There was a problem running the specified function
#[error("There was a problem running the {name}() function. {message}")]
InvalidFunction {
@ -316,6 +323,12 @@ pub enum Error {
value: String,
},
/// The requested model does not exist
#[error("The model 'ml::{value}' does not exist")]
MlNotFound {
value: String,
},
/// The requested scope does not exist
#[error("The scope '{value}' does not exist")]
ScNotFound {
@ -635,6 +648,14 @@ pub enum Error {
#[error("Utf8 error: {0}")]
Utf8Error(#[from] FromUtf8Error),
/// Represents an underlying error with the Object Store
#[error("Object Store error: {0}")]
ObsError(#[from] ObjectStoreError),
/// There was an error with model computation
#[error("There was an error with model computation: {0}")]
ModelComputation(String),
/// The feature has not yet been implemented
#[error("Feature not yet implemented: {feature}")]
FeatureNotYetImplemented {

lib/src/iam/check.rs (new file, 17 lines)

@ -0,0 +1,17 @@
use crate::dbs::Session;
use crate::err::Error;
pub fn check_ns_db(sess: &Session) -> Result<(String, String), Error> {
// Ensure that a namespace was specified
let ns = match sess.ns.clone() {
Some(ns) => ns,
None => return Err(Error::NsEmpty),
};
// Ensure that a database was specified
let db = match sess.db.clone() {
Some(db) => db,
None => return Err(Error::DbEmpty),
};
// All ok
Ok((ns, db))
}


@ -23,6 +23,7 @@ pub enum ResourceKind {
Function,
Analyzer,
Parameter,
Model,
Event,
Field,
Index,
@ -44,6 +45,7 @@ impl std::fmt::Display for ResourceKind {
ResourceKind::Function => write!(f, "Function"),
ResourceKind::Analyzer => write!(f, "Analyzer"),
ResourceKind::Parameter => write!(f, "Parameter"),
ResourceKind::Model => write!(f, "Model"),
ResourceKind::Event => write!(f, "Event"),
ResourceKind::Field => write!(f, "Field"),
ResourceKind::Index => write!(f, "Index"),


@ -4,6 +4,7 @@ use thiserror::Error;
pub mod auth;
pub mod base;
pub mod check;
pub mod clear;
pub mod entities;
pub mod policies;


@ -0,0 +1,89 @@
//! Stores a DEFINE MODEL config definition
use crate::key::error::KeyCategory;
use crate::key::key_req::KeyRequirements;
use derive::Key;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)]
pub struct Ml<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
_d: u8,
_e: u8,
pub ml: &'a str,
pub vn: &'a str,
}
pub fn new<'a>(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Ml<'a> {
Ml::new(ns, db, ml, vn)
}
pub fn prefix(ns: &str, db: &str) -> Vec<u8> {
let mut k = super::all::new(ns, db).encode().unwrap();
k.extend_from_slice(&[b'!', b'm', b'l', 0x00]);
k
}
pub fn suffix(ns: &str, db: &str) -> Vec<u8> {
let mut k = super::all::new(ns, db).encode().unwrap();
k.extend_from_slice(&[b'!', b'm', b'l', 0xff]);
k
}
impl KeyRequirements for Ml<'_> {
fn key_category(&self) -> KeyCategory {
KeyCategory::DatabaseModel
}
}
impl<'a> Ml<'a> {
pub fn new(ns: &'a str, db: &'a str, ml: &'a str, vn: &'a str) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'!',
_d: b'm',
_e: b'l',
ml,
vn,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
#[rustfmt::skip]
let val = Ml::new(
"testns",
"testdb",
"testml",
"1.0.0",
);
let enc = Ml::encode(&val).unwrap();
assert_eq!(enc, b"/*testns\x00*testdb\x00!mltestml\x001.0.0\x00");
let dec = Ml::decode(&enc).unwrap();
assert_eq!(val, dec);
}
#[test]
fn test_prefix() {
let val = super::prefix("testns", "testdb");
assert_eq!(val, b"/*testns\0*testdb\0!ml\0");
}
#[test]
fn test_suffix() {
let val = super::suffix("testns", "testdb");
assert_eq!(val, b"/*testns\0*testdb\0!ml\xff");
}
}


@ -1,6 +1,7 @@
pub mod all;
pub mod az;
pub mod fc;
pub mod ml;
pub mod pa;
pub mod sc;
pub mod tb;


@ -44,6 +44,8 @@ pub enum KeyCategory {
DatabaseFunction,
/// crate::key::database::lg /*{ns}*{db}!lg{lg}
DatabaseLog,
/// crate::key::database::ml /*{ns}*{db}!ml{ml}{vn}
DatabaseModel,
/// crate::key::database::pa /*{ns}*{db}!pa{pa}
DatabaseParameter,
/// crate::key::database::sc /*{ns}*{db}!sc{sc}
@ -138,6 +140,7 @@ impl Display for KeyCategory {
KeyCategory::DatabaseAnalyzer => "DatabaseAnalyzer",
KeyCategory::DatabaseFunction => "DatabaseFunction",
KeyCategory::DatabaseLog => "DatabaseLog",
KeyCategory::DatabaseModel => "DatabaseModel",
KeyCategory::DatabaseParameter => "DatabaseParameter",
KeyCategory::DatabaseScope => "DatabaseScope",
KeyCategory::DatabaseTable => "DatabaseTable",


@ -6,6 +6,7 @@ use crate::sql::statements::DefineEventStatement;
use crate::sql::statements::DefineFieldStatement;
use crate::sql::statements::DefineFunctionStatement;
use crate::sql::statements::DefineIndexStatement;
use crate::sql::statements::DefineModelStatement;
use crate::sql::statements::DefineNamespaceStatement;
use crate::sql::statements::DefineParamStatement;
use crate::sql::statements::DefineScopeStatement;
@ -22,6 +23,7 @@ pub enum Entry {
Db(Arc<DefineDatabaseStatement>),
Fc(Arc<DefineFunctionStatement>),
Ix(Arc<DefineIndexStatement>),
Ml(Arc<DefineModelStatement>),
Ns(Arc<DefineNamespaceStatement>),
Pa(Arc<DefineParamStatement>),
Tb(Arc<DefineTableStatement>),
@ -36,6 +38,7 @@ pub enum Entry {
Fts(Arc<[DefineTableStatement]>),
Ixs(Arc<[DefineIndexStatement]>),
Lvs(Arc<[LiveStatement]>),
Mls(Arc<[DefineModelStatement]>),
Nss(Arc<[DefineNamespaceStatement]>),
Nts(Arc<[DefineTokenStatement]>),
Nus(Arc<[DefineUserStatement]>),


@ -6,7 +6,7 @@ use crate::dbs::{
Variables,
};
use crate::err::Error;
use crate::iam::{Action, Auth, Error as IamError, ResourceKind, Role};
use crate::iam::{Action, Auth, Error as IamError, Resource, Role};
use crate::key::root::hb::Hb;
use crate::kvs::clock::SizedClock;
#[allow(unused_imports)]
@ -1232,9 +1232,11 @@ impl Datastore {
self.notification_channel.as_ref().map(|v| v.1.clone())
}
#[allow(dead_code)]
pub(crate) fn live_sender(&self) -> Option<Arc<RwLock<Sender<Notification>>>> {
self.notification_channel.as_ref().map(|v| Arc::new(RwLock::new(v.0.clone())))
/// Performs a database import from SQL
#[instrument(level = "debug", skip(self, sess, sql))]
pub async fn import(&self, sql: &str, sess: &Session) -> Result<Vec<Response>, Error> {
// Execute the SQL import
self.execute(sql, sess, None).await
}
/// Performs a full database export as SQL
@ -1242,15 +1244,10 @@ impl Datastore {
pub async fn export(
&self,
sess: &Session,
ns: String,
db: String,
chn: Sender<Vec<u8>>,
) -> Result<impl Future<Output = Result<(), Error>>, Error> {
// Skip auth for Anonymous users if auth is disabled
let skip_auth = !self.is_auth_enabled() && sess.au.is_anon();
if !skip_auth {
sess.au.is_allowed(Action::View, &ResourceKind::Any.on_db(&ns, &db))?;
}
// Retrieve the provided NS and DB
let (ns, db) = crate::iam::check::check_ns_db(sess)?;
// Create a new readonly transaction
let mut txn = self.transaction(Read, Optimistic).await?;
// Return an async export job
@ -1262,18 +1259,15 @@ impl Datastore {
})
}
/// Performs a database import from SQL
#[instrument(level = "debug", skip(self, sess, sql))]
pub async fn import(&self, sql: &str, sess: &Session) -> Result<Vec<Response>, Error> {
/// Checks the required permissions level for this session
#[instrument(level = "debug", skip(self, sess))]
pub fn check(&self, sess: &Session, action: Action, resource: Resource) -> Result<(), Error> {
// Skip auth for Anonymous users if auth is disabled
let skip_auth = !self.is_auth_enabled() && sess.au.is_anon();
if !skip_auth {
sess.au.is_allowed(
Action::Edit,
&ResourceKind::Any.on_level(sess.au.level().to_owned()),
)?;
sess.au.is_allowed(action, &resource)?;
}
// Execute the SQL import
self.execute(sql, sess, None).await
// All ok
Ok(())
}
}


@ -1,64 +1,55 @@
use once_cell::sync::Lazy;
pub static ROCKSDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
let default = num_cpus::get() as i32;
std::env::var("SURREAL_ROCKSDB_THREAD_COUNT")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_THREAD_COUNT")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(num_cpus::get() as i32)
});
pub static ROCKSDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
let default = 256 * 1024 * 1024;
std::env::var("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_WRITE_BUFFER_SIZE")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(256 * 1024 * 1024)
});
pub static ROCKSDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
let default = 512 * 1024 * 1024;
std::env::var("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE")
.map(|v| v.parse::<u64>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_TARGET_FILE_SIZE_BASE")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(512 * 1024 * 1024)
});
pub static ROCKSDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
let default = 32;
std::env::var("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_MAX_WRITE_BUFFER_NUMBER")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(32)
});
pub static ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
let default = 4;
std::env::var("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(4)
});
pub static ROCKSDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
let default = true;
std::env::var("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES")
.map(|v| v.parse::<bool>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_ENABLE_PIPELINED_WRITES")
.and_then(|s| s.parse::<bool>().ok())
.unwrap_or(true)
});
pub static ROCKSDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
let default = true;
std::env::var("SURREAL_ROCKSDB_ENABLE_BLOB_FILES")
.map(|v| v.parse::<bool>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_ENABLE_BLOB_FILES")
.and_then(|s| s.parse::<bool>().ok())
.unwrap_or(true)
});
pub static ROCKSDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
let default = 4 * 1024;
std::env::var("SURREAL_ROCKSDB_MIN_BLOB_SIZE")
.map(|v| v.parse::<u64>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_MIN_BLOB_SIZE")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(4 * 1024)
});
pub static ROCKSDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
let default = 20;
std::env::var("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_ROCKSDB_KEEP_LOG_FILE_NUM")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(20)
});


@ -1,64 +1,55 @@
use once_cell::sync::Lazy;
pub static SPEEDB_THREAD_COUNT: Lazy<i32> = Lazy::new(|| {
let default = num_cpus::get() as i32;
std::env::var("SURREAL_SPEEDB_THREAD_COUNT")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_THREAD_COUNT")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(num_cpus::get() as i32)
});
pub static SPEEDB_WRITE_BUFFER_SIZE: Lazy<usize> = Lazy::new(|| {
let default = 256 * 1024 * 1024;
std::env::var("SURREAL_SPEEDB_WRITE_BUFFER_SIZE")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_WRITE_BUFFER_SIZE")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(256 * 1024 * 1024)
});
pub static SPEEDB_TARGET_FILE_SIZE_BASE: Lazy<u64> = Lazy::new(|| {
let default = 512 * 1024 * 1024;
std::env::var("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE")
.map(|v| v.parse::<u64>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_TARGET_FILE_SIZE_BASE")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(512 * 1024 * 1024)
});
pub static SPEEDB_MAX_WRITE_BUFFER_NUMBER: Lazy<i32> = Lazy::new(|| {
let default = 32;
std::env::var("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_MAX_WRITE_BUFFER_NUMBER")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(32)
});
pub static SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE: Lazy<i32> = Lazy::new(|| {
let default = 4;
std::env::var("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
.map(|v| v.parse::<i32>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_MIN_WRITE_BUFFER_NUMBER_TO_MERGE")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(4)
});
pub static SPEEDB_ENABLE_PIPELINED_WRITES: Lazy<bool> = Lazy::new(|| {
let default = true;
std::env::var("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES")
.map(|v| v.parse::<bool>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_ENABLE_PIPELINED_WRITES")
.and_then(|s| s.parse::<bool>().ok())
.unwrap_or(true)
});
pub static SPEEDB_ENABLE_BLOB_FILES: Lazy<bool> = Lazy::new(|| {
let default = true;
std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
.map(|v| v.parse::<bool>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
.and_then(|s| s.parse::<bool>().ok())
.unwrap_or(true)
});
pub static SPEEDB_MIN_BLOB_SIZE: Lazy<u64> = Lazy::new(|| {
let default = 4 * 1024;
std::env::var("SURREAL_SPEEDB_ENABLE_BLOB_FILES")
.map(|v| v.parse::<u64>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_MIN_BLOB_SIZE")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(4 * 1024)
});
pub static SPEEDB_KEEP_LOG_FILE_NUM: Lazy<usize> = Lazy::new(|| {
let default = 20;
std::env::var("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_SPEEDB_KEEP_LOG_FILE_NUM")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(20)
});


@ -33,6 +33,7 @@ use sql::statements::DefineEventStatement;
use sql::statements::DefineFieldStatement;
use sql::statements::DefineFunctionStatement;
use sql::statements::DefineIndexStatement;
use sql::statements::DefineModelStatement;
use sql::statements::DefineNamespaceStatement;
use sql::statements::DefineParamStatement;
use sql::statements::DefineScopeStatement;
@ -1557,6 +1558,29 @@ impl Transaction {
})
}
/// Retrieve all model definitions for a specific database.
pub async fn all_db_models(
&mut self,
ns: &str,
db: &str,
) -> Result<Arc<[DefineModelStatement]>, Error> {
let key = crate::key::database::ml::prefix(ns, db);
Ok(if let Some(e) = self.cache.get(&key) {
if let Entry::Mls(v) = e {
v
} else {
unreachable!();
}
} else {
let beg = crate::key::database::ml::prefix(ns, db);
let end = crate::key::database::ml::suffix(ns, db);
let val = self.getr(beg..end, u32::MAX).await?;
let val = val.convert().into();
self.cache.set(key, Entry::Mls(Arc::clone(&val)));
val
})
}
/// Retrieve all scope definitions for a specific database.
pub async fn all_sc(
&mut self,
@ -1840,6 +1864,21 @@ impl Transaction {
Ok(val.into())
}
/// Retrieve a specific model definition from a database.
pub async fn get_db_model(
&mut self,
ns: &str,
db: &str,
ml: &str,
vn: &str,
) -> Result<DefineModelStatement, Error> {
let key = crate::key::database::ml::new(ns, db, ml, vn);
let val = self.get(key).await?.ok_or(Error::MlNotFound {
value: format!("{ml}<{vn}>"),
})?;
Ok(val.into())
}
/// Retrieve a specific database token definition.
pub async fn get_db_token(
&mut self,
@ -2178,6 +2217,31 @@ impl Transaction {
})
}
/// Retrieve a specific model definition.
pub async fn get_and_cache_db_model(
&mut self,
ns: &str,
db: &str,
ml: &str,
vn: &str,
) -> Result<Arc<DefineModelStatement>, Error> {
let key = crate::key::database::ml::new(ns, db, ml, vn).encode()?;
Ok(if let Some(e) = self.cache.get(&key) {
if let Entry::Ml(v) = e {
v
} else {
unreachable!();
}
} else {
let val = self.get(key.clone()).await?.ok_or(Error::MlNotFound {
value: format!("{ml}<{vn}>"),
})?;
let val: Arc<DefineModelStatement> = Arc::new(val.into());
self.cache.set(key, Entry::Ml(Arc::clone(&val)));
val
})
}
/// Retrieve a specific table index definition.
pub async fn get_and_cache_tb_index(
&mut self,


@ -134,6 +134,10 @@ pub mod idx;
pub mod key;
#[doc(hidden)]
pub mod kvs;
#[cfg(feature = "ml")]
#[doc(hidden)]
pub mod obs;
#[doc(hidden)]
pub mod syn;

lib/src/obs/mod.rs (new file, 96 lines)

@ -0,0 +1,96 @@
//! This module defines the operations for object storage using the [object_store](https://docs.rs/object_store/latest/object_store/)
//! crate. This enables the user to store objects using local file storage, or cloud storage such as S3 or GCS.
use crate::err::Error;
use bytes::Bytes;
use futures::stream::BoxStream;
use object_store::local::LocalFileSystem;
use object_store::parse_url;
use object_store::path::Path;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use sha1::{Digest, Sha1};
use std::env;
use std::fs;
use std::sync::Arc;
use url::Url;
static STORE: Lazy<Arc<dyn ObjectStore>> =
Lazy::new(|| match std::env::var("SURREAL_OBJECT_STORE") {
Ok(url) => {
let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE");
let (store, _) =
parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_STORE");
Arc::new(store)
}
Err(_) => {
let path = env::current_dir().unwrap().join("store");
if !path.exists() || !path.is_dir() {
fs::create_dir_all(&path)
.expect("Unable to create directory structure for SURREAL_OBJECT_STORE");
}
// As long as the provided path is correct, the following should never panic
Arc::new(LocalFileSystem::new_with_prefix(path).unwrap())
}
});
static CACHE: Lazy<Arc<dyn ObjectStore>> =
Lazy::new(|| match std::env::var("SURREAL_OBJECT_CACHE") {
Ok(url) => {
let url = Url::parse(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE");
let (store, _) =
parse_url(&url).expect("Expected a valid url for SURREAL_OBJECT_CACHE");
Arc::new(store)
}
Err(_) => {
let path = env::current_dir().unwrap().join("cache");
if !path.exists() || !path.is_dir() {
fs::create_dir_all(&path)
.expect("Unable to create directory structure for SURREAL_OBJECT_CACHE");
}
// As long as the provided path is correct, the following should never panic
Arc::new(LocalFileSystem::new_with_prefix(path).unwrap())
}
});
/// Streams the file from the cache if present, otherwise from the object storage.
pub async fn stream(
file: String,
) -> Result<BoxStream<'static, Result<Bytes, object_store::Error>>, Error> {
match CACHE.get(&Path::from(file.as_str())).await {
Ok(data) => Ok(data.into_stream()),
_ => Ok(STORE.get(&Path::from(file.as_str())).await?.into_stream()),
}
}
/// Gets the file from the local file system object storage.
pub async fn get(file: &str) -> Result<Vec<u8>, Error> {
match CACHE.get(&Path::from(file)).await {
Ok(data) => Ok(data.bytes().await?.to_vec()),
_ => {
let data = STORE.get(&Path::from(file)).await?;
CACHE.put(&Path::from(file), data.bytes().await?).await?;
Ok(CACHE.get(&Path::from(file)).await?.bytes().await?.to_vec())
}
}
}
/// Puts the file into the local file system object storage.
pub async fn put(file: &str, data: Vec<u8>) -> Result<(), Error> {
let _ = STORE.put(&Path::from(file), Bytes::from(data)).await?;
Ok(())
}
/// Deletes the file from the local file system object storage.
pub async fn del(file: &str) -> Result<(), Error> {
Ok(STORE.delete(&Path::from(file)).await?)
}
/// Hashes the bytes of a file to a short hex string used as its storage key.
pub fn hash(data: &Vec<u8>) -> String {
let mut hasher = Sha1::new();
hasher.update(data);
let result = hasher.finalize();
let mut output = hex::encode(result);
output.truncate(6);
output
}

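A hedged sketch of how these helpers compose; the module is `#[doc(hidden)]` and the payload is illustrative:

use surrealdb::obs;

// Round-trip a payload through the object store (sketch).
async fn roundtrip() -> Result<(), Box<dyn std::error::Error>> {
    let bytes = b"model-bytes".to_vec();
    // SHA-1, hex-encoded and truncated to 6 chars (e.g. "hello" -> "aaf4c6")
    let key = obs::hash(&bytes);
    obs::put(&key, bytes.clone()).await?; // write to STORE
    let back = obs::get(&key).await?; // populates CACHE on first read
    assert_eq!(back, bytes);
    obs::del(&key).await?;
    Ok(())
}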

@ -159,8 +159,10 @@ impl Function {
fnc::run(ctx, opt, txn, doc, s, a).await
}
Self::Custom(s, x) => {
// Get the full name of this function
let name = format!("fn::{s}");
// Check this function is allowed
ctx.check_allowed_function(format!("fn::{s}").as_str())?;
ctx.check_allowed_function(name.as_str())?;
// Get the function definition
let val = {
// Claim transaction
@ -189,15 +191,16 @@ impl Function {
}
}
}
// Return the value
// Check the function arguments
// Get the number of function arguments
let max_args_len = val.args.len();
// Track the number of required arguments
let mut min_args_len = 0;
// Check for any final optional arguments
val.args.iter().rev().for_each(|(_, kind)| match kind {
Kind::Option(_) if min_args_len == 0 => {}
_ => min_args_len += 1,
});
// Check the necessary arguments are passed
if x.len() < min_args_len || max_args_len < x.len() {
return Err(Error::InvalidArguments {
name: format!("fn::{}", val.name),


@ -1,16 +1,29 @@
use crate::{
ctx::Context,
dbs::{Options, Transaction},
doc::CursorDoc,
err::Error,
sql::value::Value,
};
use async_recursion::async_recursion;
use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::sql::value::Value;
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
#[cfg(feature = "ml")]
use crate::iam::Action;
#[cfg(feature = "ml")]
use crate::sql::Permission;
#[cfg(feature = "ml")]
use futures::future::try_join_all;
#[cfg(feature = "ml")]
use std::collections::HashMap;
#[cfg(feature = "ml")]
use surrealml_core::execution::compute::ModelComputation;
#[cfg(feature = "ml")]
use surrealml_core::storage::surml_file::SurMlFile;
#[cfg(feature = "ml")]
const ARGUMENTS: &str = "The model expects 1 argument. The argument can be either a number, an object, or an array of numbers.";
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct Model {
@ -33,15 +46,165 @@ impl fmt::Display for Model {
}
impl Model {
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
#[cfg(feature = "ml")]
pub(crate) async fn compute(
&self,
ctx: &Context<'_>,
opt: &Options,
txn: &Transaction,
doc: Option<&CursorDoc<'_>>,
) -> Result<Value, Error> {
// Ensure futures are run
let opt = &opt.new_with_futures(true);
// Get the full name of this model
let name = format!("ml::{}", self.name);
// Check this function is allowed
ctx.check_allowed_function(name.as_str())?;
// Get the model definition
let val = {
// Claim transaction
let mut run = txn.lock().await;
// Get the function definition
run.get_and_cache_db_model(opt.ns(), opt.db(), &self.name, &self.version).await?
};
// Calculate the model path
let path = format!(
"ml/{}/{}/{}-{}-{}.surml",
opt.ns(),
opt.db(),
self.name,
self.version,
val.hash
);
// Check permissions
if opt.check_perms(Action::View) {
match &val.permissions {
Permission::Full => (),
Permission::None => {
return Err(Error::FunctionPermissions {
name: self.name.to_owned(),
})
}
Permission::Specific(e) => {
// Disable permissions
let opt = &opt.new_with_perms(false);
// Process the PERMISSION clause
if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
return Err(Error::FunctionPermissions {
name: self.name.to_owned(),
});
}
}
}
}
// Compute the function arguments
let mut args =
try_join_all(self.args.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
// Check the minimum argument length
if args.len() != 1 {
return Err(Error::InvalidArguments {
name: format!("ml::{}<{}>", self.name, self.version),
message: ARGUMENTS.into(),
});
}
// Take the first and only specified argument
match args.swap_remove(0) {
// Perform buffered compute
Value::Object(v) => {
// Compute the model function arguments
let mut args = v
.into_iter()
.map(|(k, v)| Ok((k, Value::try_into(v)?)))
.collect::<Result<HashMap<String, f32>, Error>>()
.map_err(|_| Error::InvalidArguments {
name: format!("ml::{}<{}>", self.name, self.version),
message: ARGUMENTS.into(),
})?;
// Get the model file as bytes
let bytes = crate::obs::get(&path).await?;
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap();
let compute_unit = ModelComputation {
surml_file: &mut file,
};
compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation)
})
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
}
// Perform raw compute
Value::Number(v) => {
// Compute the model function arguments
let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments {
name: format!("ml::{}<{}>", self.name, self.version),
message: ARGUMENTS.into(),
})?;
// Get the model file as bytes
let bytes = crate::obs::get(&path).await?;
// Convert the argument to a tensor
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap();
let compute_unit = ModelComputation {
surml_file: &mut file,
};
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
})
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
}
// Perform raw compute
Value::Array(v) => {
// Compute the model function arguments
let args = v
.into_iter()
.map(Value::try_into)
.collect::<Result<Vec<f32>, Error>>()
.map_err(|_| Error::InvalidArguments {
name: format!("ml::{}<{}>", self.name, self.version),
message: ARGUMENTS.into(),
})?;
// Get the model file as bytes
let bytes = crate::obs::get(&path).await?;
// Convert the argument to a tensor
let tensor = ndarray::arr1::<f32>(&args).into_dyn();
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap();
let compute_unit = ModelComputation {
surml_file: &mut file,
};
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
})
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
}
// Any other argument type is invalid
_ => Err(Error::InvalidArguments {
name: format!("ml::{}<{}>", self.name, self.version),
message: ARGUMENTS.into(),
}),
}
}
#[cfg(not(feature = "ml"))]
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_txn: &Transaction,
_doc: Option<&'async_recursion CursorDoc<'_>>,
_doc: Option<&CursorDoc<'_>>,
) -> Result<Value, Error> {
Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string()))
Err(Error::InvalidModel {
message: String::from("Machine learning computation is not enabled."),
})
}
}

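The three accepted argument shapes map onto the branches above: an object drives buffered_compute with named features, while a number or an array of numbers drives raw_compute on an ndarray tensor. A hedged SDK-side sketch (model name, version, and features are illustrative):

use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;

// Minimal sketch: invoke a stored model from SurrealQL.
async fn run_model(db: &Surreal<Client>) -> surrealdb::Result<()> {
    // Object argument -> buffered compute
    db.query("RETURN ml::linear<1.0.0>({ temperature: 20.5 })").await?;
    // Array argument -> raw compute
    db.query("RETURN ml::linear<1.0.0>([20.5, 30.1])").await?;
    Ok(())
}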

@ -1,6 +1,6 @@
use crate::sql::fmt::Pretty;
use crate::sql::statement::{Statement, Statements};
use crate::sql::Value;
use crate::sql::statements::{DefineStatement, RemoveStatement};
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
@ -16,6 +16,18 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
#[serde(rename = "$surrealdb::private::sql::Query")]
pub struct Query(pub Statements);
impl From<DefineStatement> for Query {
fn from(s: DefineStatement) -> Self {
Query(Statements(vec![Statement::Define(s)]))
}
}
impl From<RemoveStatement> for Query {
fn from(s: RemoveStatement) -> Self {
Query(Statements(vec![Statement::Remove(s)]))
}
}
impl Deref for Query {
type Target = Vec<Statement>;
fn deref(&self) -> &Self::Target {
@ -31,12 +43,6 @@ impl IntoIterator for Query {
}
}
impl From<Query> for Value {
fn from(q: Query) -> Self {
Value::Query(q)
}
}
impl Display for Query {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(Pretty::from(f), "{}", &self.0)


@ -52,7 +52,7 @@ pub enum DefineStatement {
Field(DefineFieldStatement),
Index(DefineIndexStatement),
User(DefineUserStatement),
MlModel(DefineModelStatement),
Model(DefineModelStatement),
}
impl DefineStatement {
@ -81,7 +81,7 @@ impl DefineStatement {
Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::User(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::MlModel(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::Model(ref v) => v.compute(ctx, opt, txn, doc).await,
}
}
}
@ -101,7 +101,7 @@ impl Display for DefineStatement {
Self::Field(v) => Display::fmt(v, f),
Self::Index(v) => Display::fmt(v, f),
Self::Analyzer(v) => Display::fmt(v, f),
Self::MlModel(v) => Display::fmt(v, f),
Self::Model(v) => Display::fmt(v, f),
}
}
}


@ -1,25 +1,21 @@
use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::iam::{Action, ResourceKind};
use crate::sql::{
fmt::{is_pretty, pretty_indent},
Permission,
Base, Ident, Permission, Strand, Value,
};
use async_recursion::async_recursion;
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::fmt::Write;
use crate::{
ctx::Context,
dbs::{Options, Transaction},
doc::CursorDoc,
err::Error,
sql::{Ident, Strand, Value},
};
use std::fmt::{self, Write};
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct DefineModelStatement {
pub hash: String,
pub name: Ident,
pub version: String,
pub comment: Option<Strand>,
@ -32,7 +28,6 @@ impl fmt::Display for DefineModelStatement {
if let Some(comment) = self.comment.as_ref() {
write!(f, " COMMENT {}", comment)?;
}
if !self.permissions.is_full() {
let _indent = if is_pretty() {
Some(pretty_indent())
} else {
@ -40,21 +35,33 @@ impl fmt::Display for DefineModelStatement {
None
};
write!(f, "PERMISSIONS {}", self.permissions)?;
}
Ok(())
}
}
impl DefineModelStatement {
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
/// Process this type returning a computed simple Value
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_txn: &Transaction,
_doc: Option<&'async_recursion CursorDoc<'_>>,
opt: &Options,
txn: &Transaction,
_doc: Option<&CursorDoc<'_>>,
) -> Result<Value, Error> {
Err(Error::Unimplemented("Ml model definition not yet implemented".to_string()))
// Allowed to run?
opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?;
// Claim transaction
let mut run = txn.lock().await;
// Clear the cache
run.clear_cache();
// Process the statement
let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version);
run.add_ns(opt.ns(), opt.strict).await?;
run.add_db(opt.ns(), opt.db(), opt.strict).await?;
run.set(key, self).await?;
// Store the model file
// TODO
// Ok all good
Ok(Value::None)
}
}


@ -107,6 +107,12 @@ impl InfoStatement {
tmp.insert(v.name.to_string(), v.to_string().into());
}
res.insert("functions".to_owned(), tmp.into());
// Process the models
let mut tmp = Object::default();
for v in run.all_db_models(opt.ns(), opt.db()).await?.iter() {
tmp.insert(format!("{}<{}>", v.name, v.version), v.to_string().into());
}
res.insert("models".to_owned(), tmp.into());
// Process the params
let mut tmp = Object::default();
for v in run.all_db_params(opt.ns(), opt.db()).await?.iter() {


@ -52,14 +52,14 @@ pub use self::update::UpdateStatement;
pub use self::define::{
DefineAnalyzerStatement, DefineDatabaseStatement, DefineEventStatement, DefineFieldStatement,
DefineFunctionStatement, DefineIndexStatement, DefineNamespaceStatement, DefineParamStatement,
DefineScopeStatement, DefineStatement, DefineTableStatement, DefineTokenStatement,
DefineUserStatement,
DefineFunctionStatement, DefineIndexStatement, DefineModelStatement, DefineNamespaceStatement,
DefineParamStatement, DefineScopeStatement, DefineStatement, DefineTableStatement,
DefineTokenStatement, DefineUserStatement,
};
pub use self::remove::{
RemoveAnalyzerStatement, RemoveDatabaseStatement, RemoveEventStatement, RemoveFieldStatement,
RemoveFunctionStatement, RemoveIndexStatement, RemoveNamespaceStatement, RemoveParamStatement,
RemoveScopeStatement, RemoveStatement, RemoveTableStatement, RemoveTokenStatement,
RemoveUserStatement,
RemoveFunctionStatement, RemoveIndexStatement, RemoveModelStatement, RemoveNamespaceStatement,
RemoveParamStatement, RemoveScopeStatement, RemoveStatement, RemoveTableStatement,
RemoveTokenStatement, RemoveUserStatement,
};

View file

@ -4,6 +4,7 @@ mod event;
mod field;
mod function;
mod index;
mod model;
mod namespace;
mod param;
mod scope;
@ -17,6 +18,7 @@ pub use event::RemoveEventStatement;
pub use field::RemoveFieldStatement;
pub use function::RemoveFunctionStatement;
pub use index::RemoveIndexStatement;
pub use model::RemoveModelStatement;
pub use namespace::RemoveNamespaceStatement;
pub use param::RemoveParamStatement;
pub use scope::RemoveScopeStatement;
@ -49,6 +51,7 @@ pub enum RemoveStatement {
Field(RemoveFieldStatement),
Index(RemoveIndexStatement),
User(RemoveUserStatement),
Model(RemoveModelStatement),
}
impl RemoveStatement {
@ -77,6 +80,7 @@ impl RemoveStatement {
Self::Index(ref v) => v.compute(ctx, opt, txn).await,
Self::Analyzer(ref v) => v.compute(ctx, opt, txn).await,
Self::User(ref v) => v.compute(ctx, opt, txn).await,
Self::Model(ref v) => v.compute(ctx, opt, txn).await,
}
}
}
@ -96,6 +100,7 @@ impl Display for RemoveStatement {
Self::Index(v) => Display::fmt(v, f),
Self::Analyzer(v) => Display::fmt(v, f),
Self::User(v) => Display::fmt(v, f),
Self::Model(v) => Display::fmt(v, f),
}
}
}

View file

@ -0,0 +1,47 @@
use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::err::Error;
use crate::iam::{Action, ResourceKind};
use crate::sql::{Base, Ident, Value};
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct RemoveModelStatement {
pub name: Ident,
pub version: String,
}
impl RemoveModelStatement {
/// Process this type returning a computed simple Value
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
opt: &Options,
txn: &Transaction,
) -> Result<Value, Error> {
// Allowed to run?
opt.is_allowed(Action::Edit, ResourceKind::Model, &Base::Db)?;
// Claim transaction
let mut run = txn.lock().await;
// Clear the cache
run.clear_cache();
// Delete the definition
let key = crate::key::database::ml::new(opt.ns(), opt.db(), &self.name, &self.version);
run.del(key).await?;
// Remove the model file
// TODO
// Ok all good
Ok(Value::None)
}
}
impl Display for RemoveModelStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Bypass ident display since we don't want backticks around the ident.
write!(f, "REMOVE MODEL ml::{}<{}>", self.name.0, self.version)
}
}
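Given the Display impl above, removing the model used in the integration tests later in this commit round-trips as:

REMOVE MODEL ml::Prediction<0.0.1>;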

View file

@ -99,12 +99,11 @@ pub enum Value {
Edges(Box<Edges>),
Future(Box<Future>),
Constant(Constant),
// Closure(Box<Closure>),
Function(Box<Function>),
Subquery(Box<Subquery>),
Expression(Box<Expression>),
Query(Query),
MlModel(Box<Model>),
Model(Box<Model>),
// Add new variants here
}
@ -257,7 +256,7 @@ impl From<Function> for Value {
impl From<Model> for Value {
fn from(v: Model) -> Self {
Value::MlModel(Box::new(v))
Value::Model(Box::new(v))
}
}
@ -505,6 +504,12 @@ impl From<Id> for Value {
}
}
impl From<Query> for Value {
fn from(q: Query) -> Self {
Value::Query(q)
}
}
impl TryFrom<Value> for i8 {
type Error = Error;
fn try_from(value: Value) -> Result<Self, Self::Error> {
@ -1035,7 +1040,7 @@ impl Value {
pub fn can_start_idiom(&self) -> bool {
match self {
Value::Function(x) => !x.is_script(),
Value::MlModel(_)
Value::Model(_)
| Value::Subquery(_)
| Value::Constant(_)
| Value::Datetime(_)
@ -2526,7 +2531,7 @@ impl fmt::Display for Value {
Value::Edges(v) => write!(f, "{v}"),
Value::Expression(v) => write!(f, "{v}"),
Value::Function(v) => write!(f, "{v}"),
Value::MlModel(v) => write!(f, "{v}"),
Value::Model(v) => write!(f, "{v}"),
Value::Future(v) => write!(f, "{v}"),
Value::Geometry(v) => write!(f, "{v}"),
Value::Idiom(v) => write!(f, "{v}"),
@ -2557,7 +2562,7 @@ impl Value {
Value::Function(v) => {
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
}
Value::MlModel(m) => m.args.iter().any(Value::writeable),
Value::Model(m) => m.args.iter().any(Value::writeable),
Value::Subquery(v) => v.writeable(),
Value::Expression(v) => v.writeable(),
_ => false,
@ -2588,7 +2593,7 @@ impl Value {
Value::Future(v) => v.compute(ctx, opt, txn, doc).await,
Value::Constant(v) => v.compute(ctx, opt, txn, doc).await,
Value::Function(v) => v.compute(ctx, opt, txn, doc).await,
Value::MlModel(v) => v.compute(ctx, opt, txn, doc).await,
Value::Model(v) => v.compute(ctx, opt, txn, doc).await,
Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await,
Value::Expression(v) => v.compute(ctx, opt, txn, doc).await,
_ => Ok(self.to_owned()),

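A model invocation is just another expression node: a statement such as the one below (taken verbatim from the ML integration tests in this commit) parses into Value::Model, computes like a function call, and is considered writeable only if one of its arguments is.

ml::Prediction<0.0.1>([1.0, 1.0]);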
View file

@ -2,6 +2,7 @@
mod api_integration {
use chrono::DateTime;
use once_cell::sync::Lazy;
use semver::Version;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;

View file

@ -23,3 +23,17 @@ async fn export_import() {
db.import(&file).await.unwrap();
remove_file(file).await.unwrap();
}
#[test_log::test(tokio::test)]
#[cfg(feature = "ml")]
async fn ml_export_import() {
let (permit, db) = new_db().await;
let db_name = Ulid::new().to_string();
db.use_ns(NS).use_db(&db_name).await.unwrap();
db.import("../tests/linear_test.surml").ml().await.unwrap();
drop(permit);
let file = format!("{db_name}.surml");
db.export(&file).ml("Prediction", Version::new(0, 0, 1)).await.unwrap();
db.import(&file).ml().await.unwrap();
remove_file(file).await.unwrap();
}

View file

@ -87,8 +87,7 @@ async fn define_statement_function() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: { test: 'DEFINE FUNCTION fn::test($first: string, $last: string) { RETURN $first + $last; } PERMISSIONS FULL' },
params: {},
scopes: {},
models: {},
params: {},
scopes: {},
tables: {},
@ -120,6 +119,7 @@ async fn define_statement_table_drop() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: { test: 'DEFINE TABLE test DROP SCHEMALESS PERMISSIONS NONE' },
@ -151,6 +151,7 @@ async fn define_statement_table_schemaless() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },
@ -186,6 +187,7 @@ async fn define_statement_table_schemafull() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
@ -217,6 +219,7 @@ async fn define_statement_table_schemaful() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: { test: 'DEFINE TABLE test SCHEMAFULL PERMISSIONS NONE' },
@ -256,6 +259,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: {
@ -288,6 +292,7 @@ async fn define_statement_table_foreigntable() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: {
@ -1177,6 +1182,7 @@ async fn define_statement_analyzer() -> Result<(), Error> {
functions: {
stripHtml: "DEFINE FUNCTION fn::stripHtml($html: string) { RETURN string::replace($html, /<[^>]*>/, ''); } PERMISSIONS FULL"
},
models: {},
params: {},
scopes: {},
tables: {},
@ -1496,8 +1502,8 @@ async fn permissions_checks_define_function() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1538,8 +1544,8 @@ async fn permissions_checks_define_analyzer() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1622,8 +1628,8 @@ async fn permissions_checks_define_token_db() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1748,8 +1754,8 @@ async fn permissions_checks_define_user_db() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1790,8 +1796,8 @@ async fn permissions_checks_define_scope() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1832,8 +1838,8 @@ async fn permissions_checks_define_param() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -1871,8 +1877,8 @@ async fn permissions_checks_define_table() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"]
];
let test_cases = [
@ -2058,6 +2064,7 @@ async fn define_statement_table_permissions() -> Result<(), Error> {
"{
analyzers: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: {

View file

@ -312,8 +312,8 @@ async fn permissions_checks_info_db() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
];
let test_cases = [

View file

@ -29,6 +29,7 @@ async fn define_global_param() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: { test: 'DEFINE PARAM $test VALUE 12345 PERMISSIONS FULL' },
scopes: {},
tables: {},

View file

@ -39,6 +39,7 @@ async fn remove_statement_table() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: {},
@ -73,6 +74,7 @@ async fn remove_statement_analyzer() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: {},
@ -222,8 +224,8 @@ async fn permissions_checks_remove_function() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { greet: \"DEFINE FUNCTION fn::greet() { RETURN 'Hello'; } PERMISSIONS FULL\" }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
];
let test_cases = [
@ -264,8 +266,8 @@ async fn permissions_checks_remove_analyzer() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { analyzer: 'DEFINE ANALYZER analyzer TOKENIZERS BLANK' }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
];
let test_cases = [
@ -348,8 +350,8 @@ async fn permissions_checks_remove_db_token() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { token: \"DEFINE TOKEN token ON DATABASE TYPE HS512 VALUE 'secret'\" }, users: { } }"],
];
let test_cases = [
@ -474,8 +476,8 @@ async fn permissions_checks_remove_db_user() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { user: \"DEFINE USER user ON DATABASE PASSHASH 'secret' ROLES VIEWER\" } }"],
];
let test_cases = [
@ -516,8 +518,8 @@ async fn permissions_checks_remove_scope() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { account: 'DEFINE SCOPE account SESSION 1h' }, tables: { }, tokens: { }, users: { } }"],
];
let test_cases = [
@ -558,8 +560,8 @@ async fn permissions_checks_remove_param() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { param: \"DEFINE PARAM $param VALUE 'foo' PERMISSIONS FULL\" }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
];
let test_cases = [
@ -600,8 +602,8 @@ async fn permissions_checks_remove_table() {
// Define the expected results for the check statement when the test statement succeeded and when it failed
let check_results = [
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { }, tokens: { }, users: { } }"],
vec!["{ analyzers: { }, functions: { }, models: { }, params: { }, scopes: { }, tables: { TB: 'DEFINE TABLE TB SCHEMALESS PERMISSIONS NONE' }, tokens: { }, users: { } }"],
];
let test_cases = [

View file

@ -255,6 +255,7 @@ async fn loose_mode_all_ok() -> Result<(), Error> {
analyzers: {},
tokens: {},
functions: {},
models: {},
params: {},
scopes: {},
tables: { test: 'DEFINE TABLE test SCHEMALESS PERMISSIONS NONE' },

View file

@ -10,7 +10,7 @@ use tokio::io::{self, AsyncWriteExt};
#[derive(Args, Debug)]
pub struct ExportCommandArguments {
#[arg(help = "Path to the sql file to export. Use dash - to write into stdout.")]
#[arg(help = "Path to the SurrealQL file to export. Use dash - to write into stdout.")]
#[arg(default_value = "-")]
#[arg(index = 1)]
file: String,
@ -87,7 +87,7 @@ pub async fn init(
} else {
client.export(file).await?;
}
info!("The SQL file was exported successfully");
info!("The SurrealQL file was exported successfully");
// Everything OK
Ok(())
}

View file

@ -10,7 +10,7 @@ use surrealdb::opt::Config;
#[derive(Args, Debug)]
pub struct ImportCommandArguments {
#[arg(help = "Path to the sql file to import")]
#[arg(help = "Path to the SurrealQL file to import")]
#[arg(index = 1)]
file: String,
#[command(flatten)]
@ -75,7 +75,7 @@ pub async fn init(
client.use_ns(namespace).use_db(database).await?;
// Import the data into the database
client.import(file).await?;
info!("The SQL file was imported successfully");
info!("The SurrealQL file was imported successfully");
// Everything OK
Ok(())
}

117
src/cli/ml/export.rs Normal file
View file

@ -0,0 +1,117 @@
use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel};
use crate::cli::abstraction::{
AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments,
};
use crate::err::Error;
use clap::Args;
use futures_util::StreamExt;
use surrealdb::engine::any::{connect, IntoEndpoint};
use tokio::io::{self, AsyncWriteExt};
#[derive(Args, Debug)]
pub struct ModelArguments {
#[arg(help = "The name of the model")]
#[arg(env = "SURREAL_NAME", long = "name")]
pub(crate) name: String,
#[arg(help = "The version of the model")]
#[arg(env = "SURREAL_VERSION", long = "version")]
pub(crate) version: String,
}
#[derive(Args, Debug)]
pub struct ExportCommandArguments {
#[arg(help = "Path to the SurrealML file to export. Use dash - to write into stdout.")]
#[arg(default_value = "-")]
#[arg(index = 1)]
file: String,
#[command(flatten)]
model: ModelArguments,
#[command(flatten)]
conn: DatabaseConnectionArguments,
#[command(flatten)]
auth: AuthArguments,
#[command(flatten)]
sel: DatabaseSelectionArguments,
}
pub async fn init(
ExportCommandArguments {
file,
model: ModelArguments {
name,
version,
},
conn: DatabaseConnectionArguments {
endpoint,
},
auth: AuthArguments {
username,
password,
auth_level,
},
sel: DatabaseSelectionArguments {
namespace,
database,
},
}: ExportCommandArguments,
) -> Result<(), Error> {
// Initialize opentelemetry and logging
crate::telemetry::builder().with_log_level("error").init();
// If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate.
// If we are connecting directly to a datastore (e.g. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled.
let client = if username.is_some()
&& password.is_some()
&& !endpoint.clone().into_endpoint()?.parse_kind()?.is_local()
{
debug!("Connecting to the database engine with authentication");
let creds = CredentialsBuilder::default()
.with_username(username.as_deref())
.with_password(password.as_deref())
.with_namespace(namespace.as_str())
.with_database(database.as_str());
let client = connect(endpoint).await?;
debug!("Signing in to the database engine at '{:?}' level", auth_level);
match auth_level {
CredentialsLevel::Root => client.signin(creds.root()?).await?,
CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?,
CredentialsLevel::Database => client.signin(creds.database()?).await?,
};
client
} else {
debug!("Connecting to the database engine without authentication");
connect(endpoint).await?
};
// Parse model version
let version = match version.parse() {
Ok(version) => version,
Err(_) => {
return Err(Error::Other(format!("`{version}` is not a valid semantic version")));
}
};
// Use the specified namespace / database
client.use_ns(namespace).use_db(database).await?;
// Export the data from the database
debug!("Exporting data from the database");
if file == "-" {
// Prepare the backup
let mut backup = client.export(()).ml(&name, version).await?;
// Get a handle to standard output
let mut stdout = io::stdout();
// Write the backup to standard output
while let Some(bytes) = backup.next().await {
stdout.write_all(&bytes?).await?;
}
} else {
client.export(file).ml(&name, version).await?;
}
info!("The SurrealML file was exported successfully");
// Everything OK
Ok(())
}
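An illustrative invocation; --name and --version are defined above, while the connection, auth, and selection flags are assumed to follow the existing surreal export conventions:

surreal ml export --conn http://localhost:8000 --user root --pass root --ns test --db test --name Prediction --version 0.0.1 model.surml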

81
src/cli/ml/import.rs Normal file
View file

@ -0,0 +1,81 @@
use crate::cli::abstraction::auth::{CredentialsBuilder, CredentialsLevel};
use crate::cli::abstraction::{
AuthArguments, DatabaseConnectionArguments, DatabaseSelectionArguments,
};
use crate::err::Error;
use clap::Args;
use surrealdb::dbs::Capabilities;
use surrealdb::engine::any::{connect, IntoEndpoint};
use surrealdb::opt::Config;
#[derive(Args, Debug)]
pub struct ImportCommandArguments {
#[arg(help = "Path to the SurrealML file to import")]
#[arg(index = 1)]
file: String,
#[command(flatten)]
conn: DatabaseConnectionArguments,
#[command(flatten)]
auth: AuthArguments,
#[command(flatten)]
sel: DatabaseSelectionArguments,
}
pub async fn init(
ImportCommandArguments {
file,
conn: DatabaseConnectionArguments {
endpoint,
},
auth: AuthArguments {
username,
password,
auth_level,
},
sel: DatabaseSelectionArguments {
namespace,
database,
},
}: ImportCommandArguments,
) -> Result<(), Error> {
// Initialize opentelemetry and logging
crate::telemetry::builder().with_log_level("info").init();
// Default datastore configuration for local engines
let config = Config::new().capabilities(Capabilities::all());
// If username and password are specified, and we are connecting to a remote SurrealDB server, then we need to authenticate.
// If we are connecting directly to a datastore (e.g. file://local.db or tikv://...), then we don't need to authenticate because we use an embedded (local) SurrealDB instance with auth disabled.
let client = if username.is_some()
&& password.is_some()
&& !endpoint.clone().into_endpoint()?.parse_kind()?.is_local()
{
debug!("Connecting to the database engine with authentication");
let creds = CredentialsBuilder::default()
.with_username(username.as_deref())
.with_password(password.as_deref())
.with_namespace(namespace.as_str())
.with_database(database.as_str());
let client = connect(endpoint).await?;
debug!("Signing in to the database engine at '{:?}' level", auth_level);
match auth_level {
CredentialsLevel::Root => client.signin(creds.root()?).await?,
CredentialsLevel::Namespace => client.signin(creds.namespace()?).await?,
CredentialsLevel::Database => client.signin(creds.database()?).await?,
};
client
} else {
debug!("Connecting to the database engine without authentication");
connect((endpoint, config)).await?
};
// Use the specified namespace / database
client.use_ns(namespace).use_db(database).await?;
// Import the data into the database
client.import(file).ml().await?;
info!("The SurrealML file was imported successfully");
// Everything OK
Ok(())
}
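An illustrative invocation, under the same flag assumptions as the export subcommand:

surreal ml import --conn http://localhost:8000 --user root --pass root --ns test --db test tests/linear_test.surml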

22
src/cli/ml/mod.rs Normal file
View file

@ -0,0 +1,22 @@
mod export;
mod import;
use self::export::ExportCommandArguments;
use self::import::ImportCommandArguments;
use crate::err::Error;
use clap::Subcommand;
#[derive(Debug, Subcommand)]
pub enum MlCommand {
#[command(about = "Import a SurrealML model into an existing database")]
Import(ImportCommandArguments),
#[command(about = "Export a SurrealML model from an existing database")]
Export(ExportCommandArguments),
}
pub async fn init(command: MlCommand) -> Result<(), Error> {
match command {
MlCommand::Import(args) => import::init(args).await,
MlCommand::Export(args) => export::init(args).await,
}
}

View file

@ -4,6 +4,7 @@ mod config;
mod export;
mod import;
mod isready;
mod ml;
mod sql;
#[cfg(feature = "has-storage")]
mod start;
@ -20,6 +21,7 @@ pub use config::CF;
use export::ExportCommandArguments;
use import::ImportCommandArguments;
use isready::IsReadyCommandArguments;
use ml::MlCommand;
use sql::SqlCommandArguments;
#[cfg(feature = "has-storage")]
use start::StartCommandArguments;
@ -68,6 +70,8 @@ enum Commands {
Upgrade(UpgradeCommandArguments),
#[command(about = "Start an SQL REPL in your terminal with pipe support")]
Sql(SqlCommandArguments),
#[command(subcommand, about = "Manage SurrealML models within an existing database")]
Ml(MlCommand),
#[command(
about = "Check if the SurrealDB server is ready to accept connections",
visible_alias = "isready"
@ -88,6 +92,7 @@ pub async fn init() -> ExitCode {
Commands::Version(args) => version::init(args).await,
Commands::Upgrade(args) => upgrade::init(args).await,
Commands::Sql(args) => sql::init(args).await,
Commands::Ml(args) => ml::init(args).await,
Commands::IsReady(args) => isready::init(args).await,
Commands::Validate(args) => validate::init(args).await,
};

View file

@ -29,31 +29,50 @@ pub const APP_ENDPOINT: &str = "https://surrealdb.com/app";
#[cfg(feature = "has-storage")]
pub const WEBSOCKET_PING_FREQUENCY: Duration = Duration::from_secs(5);
/// Set the maximum WebSocket frame size to 16mb
/// What is the maximum WebSocket frame size (defaults to 16 MiB)
#[cfg(feature = "has-storage")]
pub static WEBSOCKET_MAX_FRAME_SIZE: Lazy<usize> = Lazy::new(|| {
let default = 16 << 20;
std::env::var("SURREAL_WEBSOCKET_MAX_FRAME_SIZE")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
option_env!("SURREAL_WEBSOCKET_MAX_FRAME_SIZE")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(16 << 20)
});
/// Set the maximum WebSocket frame size to 128mb
/// What is the maximum WebSocket message size (defaults to 128 MiB)
#[cfg(feature = "has-storage")]
pub static WEBSOCKET_MAX_MESSAGE_SIZE: Lazy<usize> = Lazy::new(|| {
let default = 128 << 20;
std::env::var("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE")
.map(|v| v.parse::<usize>().unwrap_or(default))
option_env!("SURREAL_WEBSOCKET_MAX_MESSAGE_SIZE")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(128 << 20)
});
/// How many concurrent tasks can be handled on each WebSocket (defaults to 24)
#[cfg(feature = "has-storage")]
pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy<usize> = Lazy::new(|| {
option_env!("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(24)
});
/// What is the runtime thread memory stack size (defaults to 10 MiB)
#[cfg(feature = "has-storage")]
pub static RUNTIME_STACK_SIZE: Lazy<usize> = Lazy::new(|| {
// Stack frames are generally larger in debug mode.
let default = if cfg!(debug_assertions) {
20 * 1024 * 1024 // 20MiB in debug mode
} else {
10 * 1024 * 1024 // 10MiB in release mode
};
option_env!("SURREAL_RUNTIME_STACK_SIZE")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(default)
});
/// How many concurrent tasks can be handled in a WebSocket
/// How many threads can be started for blocking operations (defaults to 512)
#[cfg(feature = "has-storage")]
pub static WEBSOCKET_MAX_CONCURRENT_REQUESTS: Lazy<usize> = Lazy::new(|| {
let default = 24;
std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
.map(|v| v.parse::<usize>().unwrap_or(default))
.unwrap_or(default)
pub static RUNTIME_MAX_BLOCKING_THREADS: Lazy<usize> = Lazy::new(|| {
option_env!("SURREAL_RUNTIME_MAX_BLOCKING_THREADS")
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(512)
});
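All of these settings share one pattern: option_env! reads the SURREAL_-prefixed variable at build time, and the value falls back to a default when the variable is unset or unparsable. A minimal sketch of the pattern with a hypothetical setting:

use once_cell::sync::Lazy;

/// A hypothetical tunable following the same compile-time pattern (defaults to 42)
pub static EXAMPLE_TUNABLE: Lazy<usize> = Lazy::new(|| {
    option_env!("SURREAL_EXAMPLE_TUNABLE")
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(42)
});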
/// The version identifier of this build

View file

@ -1,6 +1,7 @@
use crate::cli::abstraction::auth::Error as SurrealAuthError;
use axum::extract::rejection::TypedHeaderRejection;
use axum::response::{IntoResponse, Response};
use axum::Error as AxumError;
use axum::Json;
use base64::DecodeError as Base64Error;
use http::{HeaderName, StatusCode};
@ -48,6 +49,9 @@ pub enum Error {
#[error("Couldn't open the specified file: {0}")]
Io(#[from] IoError),
#[error("There was an error with the network: {0}")]
Axum(#[from] AxumError),
#[error("There was an error serializing to JSON: {0}")]
Json(#[from] JsonError),
@ -60,11 +64,15 @@ pub enum Error {
#[error("There was an error with the remote request: {0}")]
Remote(#[from] ReqwestError),
#[error("There was an error with auth: {0}")]
Auth(#[from] SurrealAuthError),
#[error("There was an error with the node agent")]
NodeAgent,
#[error("There was an error with auth: {0}")]
Auth(#[from] SurrealAuthError),
/// A generic error with a custom message
#[error("{0}")]
Other(String),
}
impl From<Error> for String {

View file

@ -42,15 +42,12 @@ fn main() -> ExitCode {
/// Rust's default thread stack size of 2MiB doesn't allow sufficient recursion depth.
fn with_enough_stack<T>(fut: impl Future<Output = T> + Send) -> T {
let stack_size = 10 * 1024 * 1024;
// Stack frames are generally larger in debug mode.
#[cfg(debug_assertions)]
let stack_size = stack_size * 2;
// Start a Tokio runtime with custom configuration
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_stack_size(stack_size)
.max_blocking_threads(*cnf::RUNTIME_MAX_BLOCKING_THREADS)
.thread_stack_size(*cnf::RUNTIME_STACK_SIZE)
.thread_name("surrealdb-worker")
.build()
.unwrap()
.block_on(fut)

View file

@ -9,6 +9,9 @@ use http::StatusCode;
use http_body::Body as HttpBody;
use hyper::body::Body;
use surrealdb::dbs::Session;
use surrealdb::iam::check::check_ns_db;
use surrealdb::iam::Action::View;
use surrealdb::iam::ResourceKind::Any;
pub(super) fn router<S, B>() -> Router<S, B>
where
@ -18,35 +21,27 @@ where
Router::new().route("/export", get(handler))
}
async fn handler(
Extension(session): Extension<Session>,
) -> Result<impl IntoResponse, impl IntoResponse> {
async fn handler(Extension(session): Extension<Session>) -> Result<impl IntoResponse, Error> {
// Get the datastore reference
let db = DB.get().unwrap();
// Extract the NS header value
let nsv = match session.ns.clone() {
Some(ns) => ns,
None => return Err(Error::NoNamespace),
};
// Extract the DB header value
let dbv = match session.db.clone() {
Some(db) => db,
None => return Err(Error::NoDatabase),
};
// Create a chunked response
let (mut chn, bdy) = Body::channel();
let (mut chn, body) = Body::channel();
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(&session)?;
// Check the permissions level
db.check(&session, View, Any.on_db(&nsv, &dbv))?;
// Create a new bounded channel
let (snd, rcv) = surrealdb::channel::bounded(1);
let export_job = db.export(&session, nsv, dbv, snd).await.map_err(Error::from)?;
// Start the export task
let task = db.export(&session, snd).await?;
// Spawn a new database export job
tokio::spawn(export_job);
// Process all processed values
tokio::spawn(task);
// Process all chunk values
tokio::spawn(async move {
while let Ok(v) = rcv.recv().await {
let _ = chn.send_data(Bytes::from(v)).await;
}
});
// Return the chunked body
Ok(Response::builder().status(StatusCode::OK).body(bdy).unwrap())
Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
}
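An illustrative request against this route, assuming the default root credentials and the NS/DB headers used by the integration tests:

curl -u root:root -H "NS: test" -H "DB: test" -o backup.surql http://localhost:8000/export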

View file

@ -1,3 +1,4 @@
use super::headers::Accept;
use crate::dbs::DB;
use crate::err::Error;
use crate::net::input::bytes_to_utf8;
@ -11,10 +12,10 @@ use axum::TypedHeader;
use bytes::Bytes;
use http_body::Body as HttpBody;
use surrealdb::dbs::Session;
use surrealdb::iam::Action::Edit;
use surrealdb::iam::ResourceKind::Any;
use tower_http::limit::RequestBodyLimitLayer;
use super::headers::Accept;
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
pub(super) fn router<S, B>() -> Router<S, B>
@ -32,24 +33,26 @@ where
async fn handler(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
sql: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Convert the body to a byte slice
let sql = bytes_to_utf8(&sql)?;
// Check the permissions level
db.check(&session, Edit, Any.on_level(session.au.level().to_owned()))?;
// Execute the sql query in the database
match db.import(sql, &session).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&res)),
// Return nothing
Some(Accept::ApplicationOctetStream) => Ok(output::none()),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested
_ => Err(Error::InvalidType),
},

View file

@ -13,6 +13,7 @@ use http_body::Body as HttpBody;
use serde::Deserialize;
use std::str;
use surrealdb::dbs::Session;
use surrealdb::iam::check::check_ns_db;
use surrealdb::sql::Value;
use tower_http::limit::RequestBodyLimitLayer;
@ -68,12 +69,14 @@ where
async fn select_all(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(query): Query<QueryOptions>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Specify the request statement
let sql = match query.fields {
None => "SELECT * FROM type::table($table) LIMIT $limit START $start",
@ -88,7 +91,7 @@ async fn select_all(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -105,13 +108,15 @@ async fn select_all(
async fn create_all(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the request body as JSON
@ -127,7 +132,7 @@ async fn create_all(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -147,13 +152,15 @@ async fn create_all(
async fn update_all(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the request body as JSON
@ -169,7 +176,7 @@ async fn update_all(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -189,13 +196,15 @@ async fn update_all(
async fn modify_all(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the request body as JSON
@ -211,7 +220,7 @@ async fn modify_all(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -231,12 +240,14 @@ async fn modify_all(
async fn delete_all(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>,
Query(params): Query<Params>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Specify the request statement
let sql = "DELETE type::table($table) RETURN BEFORE";
// Specify the request variables
@ -246,7 +257,7 @@ async fn delete_all(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -267,12 +278,14 @@ async fn delete_all(
async fn select_one(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path((table, id)): Path<(String, String)>,
Query(query): Query<QueryOptions>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Specify the request statement
let sql = match query.fields {
None => "SELECT * FROM type::thing($table, $id)",
@ -291,7 +304,7 @@ async fn select_one(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -308,13 +321,15 @@ async fn select_one(
async fn create_one(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the Record ID as a SurrealQL value
@ -336,7 +351,7 @@ async fn create_one(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -356,13 +371,15 @@ async fn create_one(
async fn update_one(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the Record ID as a SurrealQL value
@ -384,7 +401,7 @@ async fn update_one(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -404,13 +421,15 @@ async fn update_one(
async fn modify_one(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>,
Path((table, id)): Path<(String, String)>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Convert the HTTP request body
let data = bytes_to_utf8(&body)?;
// Parse the Record ID as a SurrealQL value
@ -432,7 +451,7 @@ async fn modify_one(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
@ -452,11 +471,13 @@ async fn modify_one(
async fn delete_one(
Extension(session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
Path((table, id)): Path<(String, String)>,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let _ = check_ns_db(&session)?;
// Specify the request statement
let sql = "DELETE type::thing($table, $id) RETURN BEFORE";
// Parse the Record ID as a SurrealQL value
@ -471,7 +492,7 @@ async fn delete_one(
};
// Execute the query and return the result
match db.execute(sql, &session, Some(vars)).await {
Ok(res) => match maybe_output.as_deref() {
Ok(res) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&output::simplify(res))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),

123
src/net/ml.rs Normal file
View file

@ -0,0 +1,123 @@
//! This module defines the ML API endpoints for importing and exporting SurrealML models.
use crate::dbs::DB;
use crate::err::Error;
use crate::net::output;
use axum::extract::{BodyStream, DefaultBodyLimit, Path};
use axum::response::IntoResponse;
use axum::response::Response;
use axum::routing::{get, post};
use axum::Extension;
use axum::Router;
use bytes::Bytes;
use futures_util::StreamExt;
use http::StatusCode;
use http_body::Body as HttpBody;
use hyper::body::Body;
use surrealdb::dbs::Session;
use surrealdb::iam::check::check_ns_db;
use surrealdb::iam::Action::{Edit, View};
use surrealdb::iam::ResourceKind::Model;
use surrealdb::kvs::{LockType::Optimistic, TransactionType::Read};
use surrealdb::sql::statements::{DefineModelStatement, DefineStatement};
use surrealml_core::storage::surml_file::SurMlFile;
use tower_http::limit::RequestBodyLimitLayer;
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
/// The router definition for the ML API endpoints.
pub(super) fn router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
B::Data: Send + Into<Bytes>,
B::Error: std::error::Error + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
Router::new()
.route("/ml/import", post(import))
.route("/ml/export/:name/:version", get(export))
.route_layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(MAX))
}
/// This endpoint allows the user to import a model into the database.
async fn import(
Extension(session): Extension<Session>,
mut stream: BodyStream,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(&session)?;
// Check the permissions level
db.check(&session, Edit, Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
while let Some(chunk) = stream.next().await {
buffer.extend_from_slice(&chunk?);
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(err) => return Err(Error::Other(err.to_string())),
};
// Convert the file back into raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = surrealdb::obs::hash(&data);
// Calculate the path of the model file
let path = format!(
"ml/{nsv}/{dbv}/{}-{}-{hash}.surml",
file.header.name.to_string(),
file.header.version.to_string()
);
// Insert the file data into the store
surrealdb::obs::put(&path, data).await?;
// Insert the model definition into the database
db.process(
DefineStatement::Model(DefineModelStatement {
hash,
name: file.header.name.to_string().into(),
version: file.header.version.to_string(),
comment: Some(file.header.description.to_string().into()),
..Default::default()
})
.into(),
&session,
None,
)
.await?;
// Return nothing
Ok(output::none())
}
/// This endpoint allows the user to export a model from the database.
async fn export(
Extension(session): Extension<Session>,
Path((name, version)): Path<(String, String)>,
) -> Result<impl IntoResponse, Error> {
// Get the datastore reference
let db = DB.get().unwrap();
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(&session)?;
// Check the permissions level
db.check(&session, View, Model.on_db(&nsv, &dbv))?;
// Start a new readonly transaction
let mut tx = db.transaction(Read, Optimistic).await?;
// Attempt to get the model definition
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
// Calculate the path of the model file
let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{}.surml", info.hash);
// Stream the file data from the store
let mut data = surrealdb::obs::stream(path).await?;
// Create a chunked response
let (mut chn, body) = Body::channel();
// Process all stream values
tokio::spawn(async move {
while let Some(Ok(v)) = data.next().await {
let _ = chn.send_data(v).await;
}
});
// Return the streamed body
Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
}
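Illustrative requests against these routes; credentials and the NS/DB headers mirror the integration tests below, and the exact curl flags are an assumption:

curl -X POST -u root:root -H "NS: test" -H "DB: test" --data-binary @linear_test.surml http://localhost:8000/ml/import
curl -u root:root -H "NS: test" -H "DB: test" -o model.surml http://localhost:8000/ml/export/Prediction/0.0.1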

View file

@ -17,6 +17,9 @@ mod sync;
mod tracer;
mod version;
#[cfg(feature = "ml")]
mod ml;
use axum::response::Redirect;
use axum::routing::get;
use axum::{middleware, Router};
@ -150,8 +153,12 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.merge(sql::router())
.merge(signin::router())
.merge(signup::router())
.merge(key::router())
.layer(service);
.merge(key::router());
#[cfg(feature = "ml")]
let axum_app = axum_app.merge(ml::router());
let axum_app = axum_app.layer(service);
// Setup the graceful shutdown
let handle = Handle::new();

View file

@ -51,7 +51,7 @@ where
async fn handler(
Extension(mut session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get a database reference
@ -65,15 +65,15 @@ async fn handler(
match surrealdb::iam::signin::signin(kvs, &mut session, vars).await.map_err(Error::from)
{
// Authentication was successful
Ok(v) => match maybe_output.as_deref() {
Ok(v) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Text serialization
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Return nothing
None => Ok(output::none()),
// An incorrect content-type was requested

View file

@ -49,7 +49,7 @@ where
async fn handler(
Extension(mut session): Extension<Session>,
maybe_output: Option<TypedHeader<Accept>>,
accept: Option<TypedHeader<Accept>>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Get a database reference
@ -63,15 +63,15 @@ async fn handler(
match surrealdb::iam::signup::signup(kvs, &mut session, vars).await.map_err(Error::from)
{
// Authentication was successful
Ok(v) => match maybe_output.as_deref() {
Ok(v) => match accept.as_deref() {
// Simple serialization
Some(Accept::ApplicationJson) => Ok(output::json(&Success::new(v))),
Some(Accept::ApplicationCbor) => Ok(output::cbor(&Success::new(v))),
Some(Accept::ApplicationPack) => Ok(output::pack(&Success::new(v))),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Text serialization
Some(Accept::TextPlain) => Ok(output::text(v.unwrap_or_default())),
// Internal serialization
Some(Accept::Surrealdb) => Ok(output::full(&Success::new(v))),
// Return nothing
None => Ok(output::none()),
// An incorrect content-type was requested

BIN
tests/linear_test.surml Normal file

Binary file not shown.

149
tests/ml_integration.rs Normal file
View file

@ -0,0 +1,149 @@
// RUST_LOG=warn cargo make ci-ml-integration
mod common;
#[cfg(feature = "ml")]
mod ml_integration {
use super::*;
use http::{header, StatusCode};
use hyper::Body;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use surrealml_core::storage::stream_adapter::StreamAdapter;
use test_log::test;
use ulid::Ulid;
static LOCK: AtomicBool = AtomicBool::new(false);
#[derive(Serialize, Deserialize, Debug)]
struct Data {
result: f64,
status: String,
time: String,
}
/// Guard that ensures the ML integration tests run one at a time.
struct LockHandle;
impl LockHandle {
fn acquire_lock() -> Self {
// Spin until the lock is atomically flipped from false to true
while LOCK.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
!= Ok(false)
{
std::thread::sleep(Duration::from_millis(100));
}
LockHandle
}
}
impl Drop for LockHandle {
fn drop(&mut self) {
LOCK.store(false, Ordering::Release);
}
}
async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> {
let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string());
let body = Body::wrap_stream(generator);
// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("NS", ns.parse()?);
headers.insert("DB", db.parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(1))
.default_headers(headers)
.build()?;
// Send HTTP request
let res = client
.post(format!("http://{addr}/ml/import"))
.basic_auth(common::USER, Some(common::PASS))
.body(body)
.send()
.await?;
// Check response code
assert_eq!(res.status(), StatusCode::OK);
Ok(())
}
#[test(tokio::test)]
async fn upload_model() -> Result<(), Box<dyn std::error::Error>> {
let _lock = LockHandle::acquire_lock();
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
upload_file(&addr, &ns, &db).await?;
Ok(())
}
#[test(tokio::test)]
async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> {
let _lock = LockHandle::acquire_lock();
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
upload_file(&addr, &ns, &db).await?;
// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("NS", ns.parse()?);
headers.insert("DB", db.parse()?);
headers.insert(header::ACCEPT, "application/json".parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()?;
// Perform an SQL query to check that the model is available
{
let res = client
.post(format!("http://{addr}/sql"))
.basic_auth(common::USER, Some(common::PASS))
.body(r#"ml::Prediction<0.0.1>([1.0, 1.0]);"#)
.send()
.await?;
let status = res.status();
let body = res.text().await?;
assert!(status.is_success(), "body: {body}");
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
assert_eq!(deserialized_data[0].result, 0.9998061656951904);
}
Ok(())
}
#[test(tokio::test)]
async fn buffered_compute() -> Result<(), Box<dyn std::error::Error>> {
let _lock = LockHandle::acquire_lock();
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
upload_file(&addr, &ns, &db).await?;
// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("NS", ns.parse()?);
headers.insert("DB", db.parse()?);
headers.insert(header::ACCEPT, "application/json".parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()?;
// Perform an SQL query to check that the model is available
{
let res = client
.post(format!("http://{addr}/sql"))
.basic_auth(common::USER, Some(common::PASS))
.body(r#"ml::Prediction<0.0.1>({squarefoot: 500.0, num_floors: 1.0});"#)
.send()
.await?;
let status = res.status();
let body = res.text().await?;
assert!(status.is_success(), "body: {body}");
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
assert_eq!(deserialized_data[0].result, 177206.21875);
}
Ok(())
}
}