// RUST_LOG=warn cargo make ci-ml-integration mod common; #[cfg(any(feature = "ml", feature = "ml2"))] mod ml_integration { use super::*; use http::{header, StatusCode}; use hyper::Body; use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use surrealdb::ml::storage::stream_adapter::StreamAdapter; use test_log::test; use ulid::Ulid; static LOCK: AtomicBool = AtomicBool::new(false); #[derive(Serialize, Deserialize, Debug)] struct Data { result: f64, status: String, time: String, } struct LockHandle; impl LockHandle { fn acquire_lock() -> Self { while LOCK.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) != Ok(false) { std::thread::sleep(Duration::from_millis(100)); } LockHandle } } impl Drop for LockHandle { fn drop(&mut self) { LOCK.store(false, Ordering::Release); } } async fn upload_file(addr: &str, ns: &str, db: &str) -> Result<(), Box<dyn std::error::Error>> { let generator = StreamAdapter::new(5, "./tests/linear_test.surml".to_string()); let body = Body::wrap_stream(generator); // Prepare HTTP client let mut headers = reqwest::header::HeaderMap::new(); headers.insert("NS", ns.parse()?); headers.insert("DB", db.parse()?); let client = reqwest::Client::builder() .connect_timeout(Duration::from_secs(1)) .default_headers(headers) .build()?; // Send HTTP request let res = client .post(format!("http://{addr}/ml/import")) .basic_auth(common::USER, Some(common::PASS)) .body(body) .send() .await?; // Check response code assert_eq!(res.status(), StatusCode::OK); Ok(()) } #[test(tokio::test)] async fn upload_model() -> Result<(), Box<dyn std::error::Error>> { let _lock = LockHandle::acquire_lock(); let (addr, _server) = common::start_server_with_defaults().await.unwrap(); let ns = Ulid::new().to_string(); let db = Ulid::new().to_string(); upload_file(&addr, &ns, &db).await?; Ok(()) } #[test(tokio::test)] async fn raw_compute() -> Result<(), Box<dyn std::error::Error>> { let _lock = LockHandle::acquire_lock(); let (addr, _server) = common::start_server_with_defaults().await.unwrap(); let ns = Ulid::new().to_string(); let db = Ulid::new().to_string(); upload_file(&addr, &ns, &db).await?; // Prepare HTTP client let mut headers = reqwest::header::HeaderMap::new(); headers.insert("NS", ns.parse()?); headers.insert("DB", db.parse()?); headers.insert(header::ACCEPT, "application/json".parse()?); let client = reqwest::Client::builder() .connect_timeout(Duration::from_millis(10)) .default_headers(headers) .build()?; // perform an SQL query to check if the model is available { let res = client .post(format!("http://{addr}/sql")) .basic_auth(common::USER, Some(common::PASS)) .body(r#"ml::Prediction<0.0.1>([1.0, 1.0]);"#) .send() .await?; assert!(res.status().is_success(), "body: {}", res.text().await?); let body = res.text().await?; let deserialized_data: Vec<Data> = serde_json::from_str(&body)?; assert_eq!(deserialized_data[0].result, 0.9998061656951904); } Ok(()) } #[test(tokio::test)] async fn buffered_compute() -> Result<(), Box<dyn std::error::Error>> { let _lock = LockHandle::acquire_lock(); let (addr, _server) = common::start_server_with_defaults().await.unwrap(); let ns = Ulid::new().to_string(); let db = Ulid::new().to_string(); upload_file(&addr, &ns, &db).await?; // Prepare HTTP client let mut headers = reqwest::header::HeaderMap::new(); headers.insert("NS", ns.parse()?); headers.insert("DB", db.parse()?); headers.insert(header::ACCEPT, "application/json".parse()?); let client = reqwest::Client::builder() .connect_timeout(Duration::from_millis(10)) .default_headers(headers) .build()?; // perform an SQL query to check if the model is available { let res = client .post(format!("http://{addr}/sql")) .basic_auth(common::USER, Some(common::PASS)) .body(r#"ml::Prediction<0.0.1>({squarefoot: 500.0, num_floors: 1.0});"#) .send() .await?; assert!(res.status().is_success(), "body: {}", res.text().await?); let body = res.text().await?; let deserialized_data: Vec<Data> = serde_json::from_str(&body)?; assert_eq!(deserialized_data[0].result, 177206.21875); } Ok(()) } }