updating surrealml-core and adding checking for version and name (#3773)

This commit is contained in:
Maxwell Flitton 2024-03-28 13:57:26 +00:00 committed by GitHub
parent 1157d70b06
commit e201366602
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 207 additions and 26 deletions

49
Cargo.lock generated
View file

@ -1100,9 +1100,9 @@ dependencies = [
[[package]] [[package]]
name = "cedar-policy" name = "cedar-policy"
version = "2.4.3" version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31ff2003d0aba0a4b2e5212660321d63dc7c36efe636d6ca1882d489cbc0bef8" checksum = "3d91e3b10a0f7f2911774d5e49713c4d25753466f9e11d1cd2ec627f8a2dc857"
dependencies = [ dependencies = [
"cedar-policy-core", "cedar-policy-core",
"cedar-policy-validator", "cedar-policy-validator",
@ -1117,9 +1117,9 @@ dependencies = [
[[package]] [[package]]
name = "cedar-policy-core" name = "cedar-policy-core"
version = "2.4.3" version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c52f9666c7cb1b6f14a6e77d3ffcffa20fd3e1012ac8dcc393498c33ff632c3" checksum = "cd2315591c6b7e18f8038f0a0529f254235fd902b6c217aabc04f2459b0d9995"
dependencies = [ dependencies = [
"either", "either",
"ipnet", "ipnet",
@ -1140,9 +1140,9 @@ dependencies = [
[[package]] [[package]]
name = "cedar-policy-validator" name = "cedar-policy-validator"
version = "2.4.3" version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a63c1a72bcafda800830cbdde316162074b341b7d59bd4b1cea6156f22dfa7" checksum = "e756e1b2a5da742ed97e65199ad6d0893e9aa4bd6b34be1de9e70bd1e6adc7df"
dependencies = [ dependencies = [
"cedar-policy-core", "cedar-policy-core",
"itertools 0.10.5", "itertools 0.10.5",
@ -5272,9 +5272,9 @@ checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.193" version = "1.0.197"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
@ -5290,9 +5290,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.193" version = "1.0.197"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -5886,7 +5886,7 @@ dependencies = [
"surrealdb-jsonwebtoken", "surrealdb-jsonwebtoken",
"surrealdb-tikv-client", "surrealdb-tikv-client",
"surrealkv", "surrealkv",
"surrealml-core", "surrealml-core 0.0.8",
"thiserror", "thiserror",
"tokio", "tokio",
"tracing", "tracing",
@ -5972,7 +5972,7 @@ dependencies = [
"surrealdb-jsonwebtoken", "surrealdb-jsonwebtoken",
"surrealdb-tikv-client", "surrealdb-tikv-client",
"surrealkv", "surrealkv",
"surrealml-core", "surrealml-core 0.1.1",
"temp-dir", "temp-dir",
"test-log", "test-log",
"thiserror", "thiserror",
@ -6085,6 +6085,23 @@ dependencies = [
"regex", "regex",
] ]
[[package]]
name = "surrealml-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5af53365d57a1bd7473366d3a413a4c858d5ddb3823e54322977d9934f75e279"
dependencies = [
"bytes",
"futures-core",
"futures-util",
"ndarray",
"once_cell",
"ort",
"regex",
"serde",
"thiserror",
]
[[package]] [[package]]
name = "symbolic-common" name = "symbolic-common"
version = "12.8.0" version = "12.8.0"
@ -6274,18 +6291,18 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.56" version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.56" version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

View file

@ -135,7 +135,7 @@ sha2 = "0.10.8"
snap = "1.1.0" snap = "1.1.0"
speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true } speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true }
storekey = "0.5.0" storekey = "0.5.0"
surrealml-core1 = { version = "0.0.8", optional = true, package = "surrealml-core" } surrealml-core1 = { version = "0.1.1", optional = true, package = "surrealml-core" }
surrealkv = { version = "0.1.3", optional = true } surrealkv = { version = "0.1.3", optional = true }
thiserror = "1.0.50" thiserror = "1.0.50"
tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true } tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true }

View file

@ -11,6 +11,8 @@ use std::fmt;
#[cfg(any(feature = "ml", feature = "ml2"))] #[cfg(any(feature = "ml", feature = "ml2"))]
use crate::iam::Action; use crate::iam::Action;
#[cfg(any(feature = "ml", feature = "ml2"))] #[cfg(any(feature = "ml", feature = "ml2"))]
use crate::ml::errors::error::SurrealError;
#[cfg(any(feature = "ml", feature = "ml2"))]
use crate::ml::execution::compute::ModelComputation; use crate::ml::execution::compute::ModelComputation;
#[cfg(any(feature = "ml", feature = "ml2"))] #[cfg(any(feature = "ml", feature = "ml2"))]
use crate::ml::storage::surml_file::SurMlFile; use crate::ml::storage::surml_file::SurMlFile;
@ -125,11 +127,15 @@ impl Model {
let bytes = crate::obs::get(&path).await?; let bytes = crate::obs::get(&path).await?;
// Run the compute in a blocking task // Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || { let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap(); let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
let compute_unit = ModelComputation { let compute_unit = ModelComputation {
surml_file: &mut file, surml_file: &mut file,
}; };
compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation) compute_unit.buffered_compute(&mut args).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})
}) })
.await .await
.unwrap()?; .unwrap()?;
@ -149,11 +155,15 @@ impl Model {
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn(); let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
// Run the compute in a blocking task // Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || { let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap(); let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
let compute_unit = ModelComputation { let compute_unit = ModelComputation {
surml_file: &mut file, surml_file: &mut file,
}; };
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})
}) })
.await .await
.unwrap()?; .unwrap()?;
@ -177,11 +187,15 @@ impl Model {
let tensor = ndarray::arr1::<f32>(&args).into_dyn(); let tensor = ndarray::arr1::<f32>(&args).into_dyn();
// Run the compute in a blocking task // Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || { let outcome = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).unwrap(); let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
let compute_unit = ModelComputation { let compute_unit = ModelComputation {
surml_file: &mut file, surml_file: &mut file,
}; };
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})
}) })
.await .await
.unwrap()?; .unwrap()?;

View file

@ -738,8 +738,11 @@ async fn router(
Ok(file) => file, Ok(file) => file,
Err(error) => { Err(error) => {
return Err(Error::FileRead { return Err(Error::FileRead {
path, path: PathBuf::from(path),
error, error: io::Error::new(
io::ErrorKind::InvalidData,
error.message.to_string(),
),
} }
.into()); .into());
} }

View file

@ -61,6 +61,12 @@ async fn import(
Ok(file) => file, Ok(file) => file,
Err(err) => return Err(Error::Other(err.to_string())), Err(err) => return Err(Error::Other(err.to_string())),
}; };
// reject the file if there is no model name or version
if file.header.name.to_string() == "" || file.header.version.to_string() == "" {
return Err(Error::Other("Model name and version must be set".to_string()));
}
// Convert the file back in to raw bytes // Convert the file back in to raw bytes
let data = file.to_bytes(); let data = file.to_bytes();
// Calculate the hash of the model file // Calculate the hash of the model file

View file

@ -14,6 +14,14 @@ mod ml_integration {
use test_log::test; use test_log::test;
use ulid::Ulid; use ulid::Ulid;
/// JSON error body that the upload tests in this module deserialize from the
/// text of the server's response to `/ml/import`.
#[derive(Serialize, Deserialize, Debug)]
struct ErrorResponse {
// Numeric code carried in the body; the upload tests assert it equals 400.
code: u16,
// NOTE(review): not inspected by the tests visible here — meaning to be confirmed against the server.
details: String,
// NOTE(review): not inspected by the tests visible here — meaning to be confirmed against the server.
description: String,
// Message text that the upload tests compare against the expected error string.
information: String,
}
static LOCK: AtomicBool = AtomicBool::new(false); static LOCK: AtomicBool = AtomicBool::new(false);
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -43,7 +51,7 @@ mod ml_integration {
} }
async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> { async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> {
let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()); let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()).unwrap();
let body = Body::wrap_stream(generator); let body = Body::wrap_stream(generator);
// Prepare HTTP client // Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
@ -75,6 +83,138 @@ mod ml_integration {
Ok(()) Ok(())
} }
/// Posts `./tests/should_crash.surml` to `/ml/import` and expects an error body
/// whose `code` is 400 and whose `information` is the header-parsing failure message.
#[test(tokio::test)]
async fn upload_bad_file() -> Result<(), Box<dyn std::error::Error>> {
    // Bind the handle to a named variable so it lives until the end of the test.
    let _guard = LockHandle::acquire_lock();
    let (addr, _server) = common::start_server_with_defaults().await.unwrap();
    // Freshly generated namespace and database names for this run.
    let ns = Ulid::new().to_string();
    let db = Ulid::new().to_string();
    // Wrap the fixture file in a stream that becomes the request body.
    let stream = StreamAdapter::new(5, String::from("./tests/should_crash.surml")).unwrap();
    let payload = Body::wrap_stream(stream);
    // The generated names travel as the NS / DB default headers.
    let mut default_headers = reqwest::header::HeaderMap::new();
    for (key, value) in [("NS", &ns), ("DB", &db)] {
        default_headers.insert(key, value.parse()?);
    }
    let http = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(Duration::from_secs(1))
        .build()?;
    // Issue the import request using basic auth.
    let endpoint = format!("http://{addr}/ml/import");
    let request = http
        .post(endpoint)
        .basic_auth(common::USER, Some(common::PASS))
        .body(payload);
    let reply = request.send().await?;
    // Decode the JSON error body, then verify its code and message.
    let text = reply.text().await?;
    let parsed: ErrorResponse = serde_json::from_str(&text)?;
    assert_eq!(parsed.code, 400);
    assert_eq!(
        String::from("Not enough bytes to read for header, maybe the file format is not correct"),
        parsed.information
    );
    Ok(())
}
/// Posts `./tests/no_name.surml` to `/ml/import` and expects an error body whose
/// `code` is 400 and whose `information` says the model name and version must be set.
#[test(tokio::test)]
async fn upload_file_with_no_name() -> Result<(), Box<dyn std::error::Error>> {
    // Bind the handle to a named variable so it lives until the end of the test.
    let _guard = LockHandle::acquire_lock();
    let (addr, _server) = common::start_server_with_defaults().await.unwrap();
    // Freshly generated namespace and database names for this run.
    let ns = Ulid::new().to_string();
    let db = Ulid::new().to_string();
    // Wrap the fixture file in a stream that becomes the request body.
    let stream = StreamAdapter::new(5, String::from("./tests/no_name.surml")).unwrap();
    let payload = Body::wrap_stream(stream);
    // The generated names travel as the NS / DB default headers.
    let mut default_headers = reqwest::header::HeaderMap::new();
    for (key, value) in [("NS", &ns), ("DB", &db)] {
        default_headers.insert(key, value.parse()?);
    }
    let http = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(Duration::from_secs(1))
        .build()?;
    // Issue the import request using basic auth.
    let endpoint = format!("http://{addr}/ml/import");
    let request = http
        .post(endpoint)
        .basic_auth(common::USER, Some(common::PASS))
        .body(payload);
    let reply = request.send().await?;
    // Decode the JSON error body, then verify its code and message.
    let text = reply.text().await?;
    let parsed: ErrorResponse = serde_json::from_str(&text)?;
    assert_eq!(parsed.code, 400);
    assert_eq!(String::from("Model name and version must be set"), parsed.information);
    Ok(())
}
/// Posts `./tests/no_version.surml` to `/ml/import` and expects an error body whose
/// `code` is 400 and whose `information` says the model name and version must be set.
#[test(tokio::test)]
async fn upload_file_with_no_version() -> Result<(), Box<dyn std::error::Error>> {
    // Bind the handle to a named variable so it lives until the end of the test.
    let _guard = LockHandle::acquire_lock();
    let (addr, _server) = common::start_server_with_defaults().await.unwrap();
    // Freshly generated namespace and database names for this run.
    let ns = Ulid::new().to_string();
    let db = Ulid::new().to_string();
    // Wrap the fixture file in a stream that becomes the request body.
    let stream = StreamAdapter::new(5, String::from("./tests/no_version.surml")).unwrap();
    let payload = Body::wrap_stream(stream);
    // The generated names travel as the NS / DB default headers.
    let mut default_headers = reqwest::header::HeaderMap::new();
    for (key, value) in [("NS", &ns), ("DB", &db)] {
        default_headers.insert(key, value.parse()?);
    }
    let http = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(Duration::from_secs(1))
        .build()?;
    // Issue the import request using basic auth.
    let endpoint = format!("http://{addr}/ml/import");
    let request = http
        .post(endpoint)
        .basic_auth(common::USER, Some(common::PASS))
        .body(payload);
    let reply = request.send().await?;
    // Decode the JSON error body, then verify its code and message.
    let text = reply.text().await?;
    let parsed: ErrorResponse = serde_json::from_str(&text)?;
    assert_eq!(parsed.code, 400);
    assert_eq!(String::from("Model name and version must be set"), parsed.information);
    Ok(())
}
/// Posts `./tests/no_name_or_version.surml` to `/ml/import` and expects an error body whose
/// `code` is 400 and whose `information` says the model name and version must be set.
#[test(tokio::test)]
async fn upload_file_with_no_version_or_name() -> Result<(), Box<dyn std::error::Error>> {
    // Bind the handle to a named variable so it lives until the end of the test.
    let _guard = LockHandle::acquire_lock();
    let (addr, _server) = common::start_server_with_defaults().await.unwrap();
    // Freshly generated namespace and database names for this run.
    let ns = Ulid::new().to_string();
    let db = Ulid::new().to_string();
    // Wrap the fixture file in a stream that becomes the request body.
    let fixture = String::from("./tests/no_name_or_version.surml");
    let stream = StreamAdapter::new(5, fixture).unwrap();
    let payload = Body::wrap_stream(stream);
    // The generated names travel as the NS / DB default headers.
    let mut default_headers = reqwest::header::HeaderMap::new();
    for (key, value) in [("NS", &ns), ("DB", &db)] {
        default_headers.insert(key, value.parse()?);
    }
    let http = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(Duration::from_secs(1))
        .build()?;
    // Issue the import request using basic auth.
    let endpoint = format!("http://{addr}/ml/import");
    let request = http
        .post(endpoint)
        .basic_auth(common::USER, Some(common::PASS))
        .body(payload);
    let reply = request.send().await?;
    // Decode the JSON error body, then verify its code and message.
    let text = reply.text().await?;
    let parsed: ErrorResponse = serde_json::from_str(&text)?;
    assert_eq!(parsed.code, 400);
    assert_eq!(String::from("Model name and version must be set"), parsed.information);
    Ok(())
}
#[test(tokio::test)] #[test(tokio::test)]
async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> { async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> {
let _lock = LockHandle::acquire_lock(); let _lock = LockHandle::acquire_lock();

BIN
tests/no_name.surml Normal file

Binary file not shown.

Binary file not shown.

BIN
tests/no_version.surml Normal file

Binary file not shown.

1
tests/should_crash.surml Normal file
View file

@ -0,0 +1 @@
this should crash