updating surrealml-core and adding checking for version and name (#3773)
This commit is contained in:
parent
1157d70b06
commit
e201366602
10 changed files with 207 additions and 26 deletions
49
Cargo.lock
generated
49
Cargo.lock
generated
|
@ -1100,9 +1100,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cedar-policy"
|
||||
version = "2.4.3"
|
||||
version = "2.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31ff2003d0aba0a4b2e5212660321d63dc7c36efe636d6ca1882d489cbc0bef8"
|
||||
checksum = "3d91e3b10a0f7f2911774d5e49713c4d25753466f9e11d1cd2ec627f8a2dc857"
|
||||
dependencies = [
|
||||
"cedar-policy-core",
|
||||
"cedar-policy-validator",
|
||||
|
@ -1117,9 +1117,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cedar-policy-core"
|
||||
version = "2.4.3"
|
||||
version = "2.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c52f9666c7cb1b6f14a6e77d3ffcffa20fd3e1012ac8dcc393498c33ff632c3"
|
||||
checksum = "cd2315591c6b7e18f8038f0a0529f254235fd902b6c217aabc04f2459b0d9995"
|
||||
dependencies = [
|
||||
"either",
|
||||
"ipnet",
|
||||
|
@ -1140,9 +1140,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cedar-policy-validator"
|
||||
version = "2.4.3"
|
||||
version = "2.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a63c1a72bcafda800830cbdde316162074b341b7d59bd4b1cea6156f22dfa7"
|
||||
checksum = "e756e1b2a5da742ed97e65199ad6d0893e9aa4bd6b34be1de9e70bd1e6adc7df"
|
||||
dependencies = [
|
||||
"cedar-policy-core",
|
||||
"itertools 0.10.5",
|
||||
|
@ -5272,9 +5272,9 @@ checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73"
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.193"
|
||||
version = "1.0.197"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
|
||||
checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
@ -5290,9 +5290,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.193"
|
||||
version = "1.0.197"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||
checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -5886,7 +5886,7 @@ dependencies = [
|
|||
"surrealdb-jsonwebtoken",
|
||||
"surrealdb-tikv-client",
|
||||
"surrealkv",
|
||||
"surrealml-core",
|
||||
"surrealml-core 0.0.8",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
@ -5972,7 +5972,7 @@ dependencies = [
|
|||
"surrealdb-jsonwebtoken",
|
||||
"surrealdb-tikv-client",
|
||||
"surrealkv",
|
||||
"surrealml-core",
|
||||
"surrealml-core 0.1.1",
|
||||
"temp-dir",
|
||||
"test-log",
|
||||
"thiserror",
|
||||
|
@ -6085,6 +6085,23 @@ dependencies = [
|
|||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "surrealml-core"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5af53365d57a1bd7473366d3a413a4c858d5ddb3823e54322977d9934f75e279"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"ndarray",
|
||||
"once_cell",
|
||||
"ort",
|
||||
"regex",
|
||||
"serde",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "symbolic-common"
|
||||
version = "12.8.0"
|
||||
|
@ -6274,18 +6291,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.56"
|
||||
version = "1.0.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad"
|
||||
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.56"
|
||||
version = "1.0.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471"
|
||||
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
|
@ -135,7 +135,7 @@ sha2 = "0.10.8"
|
|||
snap = "1.1.0"
|
||||
speedb = { version = "0.0.4", features = ["lz4", "snappy"], optional = true }
|
||||
storekey = "0.5.0"
|
||||
surrealml-core1 = { version = "0.0.8", optional = true, package = "surrealml-core" }
|
||||
surrealml-core1 = { version = "0.1.1", optional = true, package = "surrealml-core" }
|
||||
surrealkv = { version = "0.1.3", optional = true }
|
||||
thiserror = "1.0.50"
|
||||
tikv = { version = "0.2.0-surreal.2", default-features = false, package = "surrealdb-tikv-client", optional = true }
|
||||
|
|
|
@ -11,6 +11,8 @@ use std::fmt;
|
|||
#[cfg(any(feature = "ml", feature = "ml2"))]
|
||||
use crate::iam::Action;
|
||||
#[cfg(any(feature = "ml", feature = "ml2"))]
|
||||
use crate::ml::errors::error::SurrealError;
|
||||
#[cfg(any(feature = "ml", feature = "ml2"))]
|
||||
use crate::ml::execution::compute::ModelComputation;
|
||||
#[cfg(any(feature = "ml", feature = "ml2"))]
|
||||
use crate::ml::storage::surml_file::SurMlFile;
|
||||
|
@ -125,11 +127,15 @@ impl Model {
|
|||
let bytes = crate::obs::get(&path).await?;
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
let compute_unit = ModelComputation {
|
||||
surml_file: &mut file,
|
||||
};
|
||||
compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation)
|
||||
compute_unit.buffered_compute(&mut args).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
@ -149,11 +155,15 @@ impl Model {
|
|||
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
let compute_unit = ModelComputation {
|
||||
surml_file: &mut file,
|
||||
};
|
||||
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
|
||||
compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
@ -177,11 +187,15 @@ impl Model {
|
|||
let tensor = ndarray::arr1::<f32>(&args).into_dyn();
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).unwrap();
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
let compute_unit = ModelComputation {
|
||||
surml_file: &mut file,
|
||||
};
|
||||
compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation)
|
||||
compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
|
|
@ -738,8 +738,11 @@ async fn router(
|
|||
Ok(file) => file,
|
||||
Err(error) => {
|
||||
return Err(Error::FileRead {
|
||||
path,
|
||||
error,
|
||||
path: PathBuf::from(path),
|
||||
error: io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
error.message.to_string(),
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
|
|
@ -61,6 +61,12 @@ async fn import(
|
|||
Ok(file) => file,
|
||||
Err(err) => return Err(Error::Other(err.to_string())),
|
||||
};
|
||||
|
||||
// reject the file if there is no model name or version
|
||||
if file.header.name.to_string() == "" || file.header.version.to_string() == "" {
|
||||
return Err(Error::Other("Model name and version must be set".to_string()));
|
||||
}
|
||||
|
||||
// Convert the file back in to raw bytes
|
||||
let data = file.to_bytes();
|
||||
// Calculate the hash of the model file
|
||||
|
|
|
@ -14,6 +14,14 @@ mod ml_integration {
|
|||
use test_log::test;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct ErrorResponse {
|
||||
code: u16,
|
||||
details: String,
|
||||
description: String,
|
||||
information: String,
|
||||
}
|
||||
|
||||
static LOCK: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
|
@ -43,7 +51,7 @@ mod ml_integration {
|
|||
}
|
||||
|
||||
async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string());
|
||||
let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()).unwrap();
|
||||
let body = Body::wrap_stream(generator);
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
|
@ -75,6 +83,138 @@ mod ml_integration {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn upload_bad_file() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _lock = LockHandle::acquire_lock();
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
let generator = StreamAdapter::new(5, "./tests/should_crash.surml".to_string()).unwrap();
|
||||
let body = Body::wrap_stream(generator);
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("NS", ns.parse()?);
|
||||
headers.insert("DB", db.parse()?);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
// Send HTTP request
|
||||
let res = client
|
||||
.post(format!("http://{addr}/ml/import"))
|
||||
.basic_auth(common::USER, Some(common::PASS))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
// Check response code
|
||||
let raw_data = res.text().await?;
|
||||
let response: ErrorResponse = serde_json::from_str(&raw_data)?;
|
||||
|
||||
assert_eq!(response.code, 400);
|
||||
assert_eq!(
|
||||
"Not enough bytes to read for header, maybe the file format is not correct".to_string(),
|
||||
response.information
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn upload_file_with_no_name() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _lock = LockHandle::acquire_lock();
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
let generator = StreamAdapter::new(5, "./tests/no_name.surml".to_string()).unwrap();
|
||||
let body = Body::wrap_stream(generator);
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("NS", ns.parse()?);
|
||||
headers.insert("DB", db.parse()?);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
// Send HTTP request
|
||||
let res = client
|
||||
.post(format!("http://{addr}/ml/import"))
|
||||
.basic_auth(common::USER, Some(common::PASS))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
// Check response code
|
||||
let raw_data = res.text().await?;
|
||||
let response: ErrorResponse = serde_json::from_str(&raw_data)?;
|
||||
|
||||
assert_eq!(response.code, 400);
|
||||
assert_eq!("Model name and version must be set".to_string(), response.information);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn upload_file_with_no_version() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _lock = LockHandle::acquire_lock();
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
let generator = StreamAdapter::new(5, "./tests/no_version.surml".to_string()).unwrap();
|
||||
let body = Body::wrap_stream(generator);
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("NS", ns.parse()?);
|
||||
headers.insert("DB", db.parse()?);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
// Send HTTP request
|
||||
let res = client
|
||||
.post(format!("http://{addr}/ml/import"))
|
||||
.basic_auth(common::USER, Some(common::PASS))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
// Check response code
|
||||
let raw_data = res.text().await?;
|
||||
let response: ErrorResponse = serde_json::from_str(&raw_data)?;
|
||||
|
||||
assert_eq!(response.code, 400);
|
||||
assert_eq!("Model name and version must be set".to_string(), response.information);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn upload_file_with_no_version_or_name() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _lock = LockHandle::acquire_lock();
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
let generator =
|
||||
StreamAdapter::new(5, "./tests/no_name_or_version.surml".to_string()).unwrap();
|
||||
let body = Body::wrap_stream(generator);
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("NS", ns.parse()?);
|
||||
headers.insert("DB", db.parse()?);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
// Send HTTP request
|
||||
let res = client
|
||||
.post(format!("http://{addr}/ml/import"))
|
||||
.basic_auth(common::USER, Some(common::PASS))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
// Check response code
|
||||
let raw_data = res.text().await?;
|
||||
let response: ErrorResponse = serde_json::from_str(&raw_data)?;
|
||||
|
||||
assert_eq!(response.code, 400);
|
||||
assert_eq!("Model name and version must be set".to_string(), response.information);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _lock = LockHandle::acquire_lock();
|
||||
|
|
BIN
tests/no_name.surml
Normal file
BIN
tests/no_name.surml
Normal file
Binary file not shown.
BIN
tests/no_name_or_version.surml
Normal file
BIN
tests/no_name_or_version.surml
Normal file
Binary file not shown.
BIN
tests/no_version.surml
Normal file
BIN
tests/no_version.surml
Normal file
Binary file not shown.
1
tests/should_crash.surml
Normal file
1
tests/should_crash.surml
Normal file
|
@ -0,0 +1 @@
|
|||
this should crash
|
Loading…
Reference in a new issue