From e201366602eb547047354cd1c1d3da6c956d870a Mon Sep 17 00:00:00 2001 From: Maxwell Flitton Date: Thu, 28 Mar 2024 13:57:26 +0000 Subject: [PATCH] updating surrealml-core and adding checking for version and name (#3773) --- Cargo.lock | 49 +++++++---- core/Cargo.toml | 2 +- core/src/sql/model.rs | 26 ++++-- lib/src/api/engine/local/mod.rs | 7 +- src/net/ml.rs | 6 ++ tests/ml_integration.rs | 142 +++++++++++++++++++++++++++++++- tests/no_name.surml | Bin 0 -> 550 bytes tests/no_name_or_version.surml | Bin 0 -> 545 bytes tests/no_version.surml | Bin 0 -> 555 bytes tests/should_crash.surml | 1 + 10 files changed, 207 insertions(+), 26 deletions(-) create mode 100644 tests/no_name.surml create mode 100644 tests/no_name_or_version.surml create mode 100644 tests/no_version.surml create mode 100644 tests/should_crash.surml diff --git a/Cargo.lock b/Cargo.lock index e41c9974..51291cef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1100,9 +1100,9 @@ dependencies = [ [[package]] name = "cedar-policy" -version = "2.4.3" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31ff2003d0aba0a4b2e5212660321d63dc7c36efe636d6ca1882d489cbc0bef8" +checksum = "3d91e3b10a0f7f2911774d5e49713c4d25753466f9e11d1cd2ec627f8a2dc857" dependencies = [ "cedar-policy-core", "cedar-policy-validator", @@ -1117,9 +1117,9 @@ dependencies = [ [[package]] name = "cedar-policy-core" -version = "2.4.3" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c52f9666c7cb1b6f14a6e77d3ffcffa20fd3e1012ac8dcc393498c33ff632c3" +checksum = "cd2315591c6b7e18f8038f0a0529f254235fd902b6c217aabc04f2459b0d9995" dependencies = [ "either", "ipnet", @@ -1140,9 +1140,9 @@ dependencies = [ [[package]] name = "cedar-policy-validator" -version = "2.4.3" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a63c1a72bcafda800830cbdde316162074b341b7d59bd4b1cea6156f22dfa7" +checksum = "e756e1b2a5da742ed97e65199ad6d0893e9aa4bd6b34be1de9e70bd1e6adc7df" dependencies = [ "cedar-policy-core", "itertools 0.10.5", @@ -5272,9 +5272,9 @@ checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] @@ -5290,9 +5290,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", @@ -5886,7 +5886,7 @@ dependencies = [ "surrealdb-jsonwebtoken", "surrealdb-tikv-client", "surrealkv", - "surrealml-core", + "surrealml-core 0.0.8", "thiserror", "tokio", "tracing", @@ -5972,7 +5972,7 @@ dependencies = [ "surrealdb-jsonwebtoken", "surrealdb-tikv-client", "surrealkv", - "surrealml-core", + "surrealml-core 0.1.1", "temp-dir", "test-log", "thiserror", @@ -6085,6 +6085,23 @@ dependencies = [ "regex", ] +[[package]] +name = "surrealml-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5af53365d57a1bd7473366d3a413a4c858d5ddb3823e54322977d9934f75e279" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "ndarray", + "once_cell", + "ort", + "regex", + "serde", + "thiserror", +] + [[package]] name = "symbolic-common" version = "12.8.0" @@ -6274,18 +6291,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", diff --git a/core/Cargo.toml b/core/Cargo.toml index 638cea1a..9ecd5fbb 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -135,7 +135,7 @@ sha2 = "0.10.8" snap = "1.1.0" speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true } storekey = "0.5.0" -surrealml-core1 = { version = "0.0.8", optional = true, package = "surrealml-core" } +surrealml-core1 = { version = "0.1.1", optional = true, package = "surrealml-core" } surrealkv = { version = "0.1.3", optional = true } thiserror = "1.0.50" tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true } diff --git a/core/src/sql/model.rs b/core/src/sql/model.rs index 960e0617..3ecc3e1c 100644 --- a/core/src/sql/model.rs +++ b/core/src/sql/model.rs @@ -11,6 +11,8 @@ use std::fmt; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::iam::Action; #[cfg(any(feature = "ml", feature = "ml2"))] +use crate::ml::errors::error::SurrealError; +#[cfg(any(feature = "ml", feature = "ml2"))] use crate::ml::execution::compute::ModelComputation; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::ml::storage::surml_file::SurMlFile; @@ -125,11 +127,15 @@ impl Model { let bytes = crate::obs::get(&path).await?; // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { - let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + })?; let compute_unit = ModelComputation { surml_file: &mut file, }; - compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation) + compute_unit.buffered_compute(&mut args).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + }) }) .await .unwrap()?; @@ -149,11 +155,15 @@ impl Model { let tensor = ndarray::arr1::(&[args]).into_dyn(); // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { - let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + })?; let compute_unit = ModelComputation { surml_file: &mut file, }; - compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) + compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + }) }) .await .unwrap()?; @@ -177,11 +187,15 @@ impl Model { let tensor = ndarray::arr1::(&args).into_dyn(); // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { - let mut file = SurMlFile::from_bytes(bytes).unwrap(); + let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + })?; let compute_unit = ModelComputation { surml_file: &mut file, }; - compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) + compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| { + Error::ModelComputation(err.message.to_string()) + }) }) .await .unwrap()?; diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index 59614a57..399830bc 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -738,8 +738,11 @@ async fn router( Ok(file) => file, Err(error) => { return Err(Error::FileRead { - path, - error, + path: PathBuf::from(path), + error: io::Error::new( + io::ErrorKind::InvalidData, + error.message.to_string(), + ), } .into()); } diff --git a/src/net/ml.rs b/src/net/ml.rs index 95e1a70c..e3442901 100644 --- a/src/net/ml.rs +++ b/src/net/ml.rs @@ -61,6 +61,12 @@ async fn import( Ok(file) => file, Err(err) => return Err(Error::Other(err.to_string())), }; + + // reject the file if there is no model name or version + if file.header.name.to_string() == "" || file.header.version.to_string() == "" { + return Err(Error::Other("Model name and version must be set".to_string())); + } + // Convert the file back in to raw bytes let data = file.to_bytes(); // Calculate the hash of the model file diff --git a/tests/ml_integration.rs b/tests/ml_integration.rs index f5894b56..41dc78c6 100644 --- a/tests/ml_integration.rs +++ b/tests/ml_integration.rs @@ -14,6 +14,14 @@ mod ml_integration { use test_log::test; use ulid::Ulid; + #[derive(Serialize, Deserialize, Debug)] + struct ErrorResponse { + code: u16, + details: String, + description: String, + information: String, + } + static LOCK: AtomicBool = AtomicBool::new(false); #[derive(Serialize, Deserialize, Debug)] @@ -43,7 +51,7 @@ mod ml_integration { } async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box> { - let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()); + let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()).unwrap(); let body = Body::wrap_stream(generator); // Prepare HTTP client let mut headers = reqwest::header::HeaderMap::new(); @@ -75,6 +83,138 @@ mod ml_integration { Ok(()) } + #[test(tokio::test)] + async fn upload_bad_file() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + let generator = StreamAdapter::new(5, "./tests/should_crash.surml".to_string()).unwrap(); + let body = Body::wrap_stream(generator); + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .default_headers(headers) + .build()?; + // Send HTTP request + let res = client + .post(format!("http://{addr}/ml/import")) + .basic_auth(common::USER, Some(common::PASS)) + .body(body) + .send() + .await?; + // Check response code + let raw_data = res.text().await?; + let response: ErrorResponse = serde_json::from_str(&raw_data)?; + + assert_eq!(response.code, 400); + assert_eq!( + "Not enough bytes to read for header, maybe the file format is not correct".to_string(), + response.information + ); + Ok(()) + } + + #[test(tokio::test)] + async fn upload_file_with_no_name() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + let generator = StreamAdapter::new(5, "./tests/no_name.surml".to_string()).unwrap(); + let body = Body::wrap_stream(generator); + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .default_headers(headers) + .build()?; + // Send HTTP request + let res = client + .post(format!("http://{addr}/ml/import")) + .basic_auth(common::USER, Some(common::PASS)) + .body(body) + .send() + .await?; + // Check response code + let raw_data = res.text().await?; + let response: ErrorResponse = serde_json::from_str(&raw_data)?; + + assert_eq!(response.code, 400); + assert_eq!("Model name and version must be set".to_string(), response.information); + Ok(()) + } + + #[test(tokio::test)] + async fn upload_file_with_no_version() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + let generator = StreamAdapter::new(5, "./tests/no_version.surml".to_string()).unwrap(); + let body = Body::wrap_stream(generator); + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .default_headers(headers) + .build()?; + // Send HTTP request + let res = client + .post(format!("http://{addr}/ml/import")) + .basic_auth(common::USER, Some(common::PASS)) + .body(body) + .send() + .await?; + // Check response code + let raw_data = res.text().await?; + let response: ErrorResponse = serde_json::from_str(&raw_data)?; + + assert_eq!(response.code, 400); + assert_eq!("Model name and version must be set".to_string(), response.information); + Ok(()) + } + + #[test(tokio::test)] + async fn upload_file_with_no_version_or_name() -> Result<(), Box> { + let _lock = LockHandle::acquire_lock(); + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + let generator = + StreamAdapter::new(5, "./tests/no_name_or_version.surml".to_string()).unwrap(); + let body = Body::wrap_stream(generator); + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("NS", ns.parse()?); + headers.insert("DB", db.parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .default_headers(headers) + .build()?; + // Send HTTP request + let res = client + .post(format!("http://{addr}/ml/import")) + .basic_auth(common::USER, Some(common::PASS)) + .body(body) + .send() + .await?; + // Check response code + let raw_data = res.text().await?; + let response: ErrorResponse = serde_json::from_str(&raw_data)?; + + assert_eq!(response.code, 400); + assert_eq!("Model name and version must be set".to_string(), response.information); + Ok(()) + } + #[test(tokio::test)] async fn raw_compute() -> Result<(), Box> { let _lock = LockHandle::acquire_lock(); diff --git a/tests/no_name.surml b/tests/no_name.surml new file mode 100644 index 0000000000000000000000000000000000000000..406f3b356aa34e35ed89fa006b41c53fc15ac77d GIT binary patch literal 550 zcmZ`!O;5rw7^V|tmITQzB%B&fh-B-gV}jX?#1jcOPhOU8D{f}2a~}!%v%DLR{v7{; z?S_ITCT;WfeLnI$jRruPW;^*)Niw=(tgxM=T)V57u`K_Kytui~GCDML1CK5JQpKh* zJ=L_G!ghW*W2{=(((%NyIK_5CR_4@Av&g4A*7?N37>{*qnkH7K7Zp>mv5Hk)40mCa zqcsa?jM9t-kzeGf2r2(?Bg9sSpvoC}B&Pvl35qzQ?t?@j@<0(u0>ndlG!u%toGRSY z=YnN^C`Q|fnNMOdiYvF2))ws_8axf2eZIc|cniS>c36^ZX0s(JmSyZtl%5tx2_>1f z-LSGq%fiL|&Tn9FR1@G8Y}X3ai%1Ut(<{L2x4$iJf?#mKALY2p>^?@ts3(IinD2Wn vnnU>UIzqBkL4T4i-5>#M6z|$F~Mv`;)#TtCofC46*n{1xpoQqvAi3PejLAo z-3EdrCVkE8`~B(r)oK#dG}~b>InKzEQf@m*zH*l_r5QU!-dyI>j0_CjKqE`PR*`8; z&NXePu$^Dd2&opbbTqasL9v~X=8U*$7Wrhyx*S^wp^=VE(?sgzs$>@Iq8`GC!4(Zi z4AYDRk?3Bj4Rnb4f4 zId@U7{TJw;R0MDXT9rceBFw;ldPOpOt*ymf5cH44rJR;KyU!6Hc4es}&Gx5>#M6z|$F~Mv`;)#UBgC{Rbw-q-t*12{G`mwwlkA57# zg53s+Bqn{$>-+ub`_*a^)HK^+&pFP>l2UFvNxpKIF{K&XM_yg#(~JxZ-9RHtzf_TF zOinfJS7AFpoDotjWa(&ZS%P9aA*2{A0~p@$g>AWdMz8LUFVA{Bll5GMf?n|dTG6jegC za8F-2&HS)vZ6~Haj<^5HO{KX(dxwUKR5bnic$dH}kk+70lVm-eE-+u@u{&0}S{x+= zXWC{$bDrkhMZNZ4pnp^mz!hj!3e}4+1OMq2$?UbZ7B@lAKM