2023-12-12 13:51:43 +00:00
|
|
|
//! This file defines the endpoints for the ML API for importing and exporting SurrealML models.
|
|
|
|
use crate::dbs::DB;
|
|
|
|
use crate::err::Error;
|
|
|
|
use crate::net::output;
|
|
|
|
use axum::extract::{BodyStream, DefaultBodyLimit, Path};
|
|
|
|
use axum::response::IntoResponse;
|
|
|
|
use axum::response::Response;
|
|
|
|
use axum::routing::{get, post};
|
|
|
|
use axum::Extension;
|
|
|
|
use axum::Router;
|
|
|
|
use bytes::Bytes;
|
|
|
|
use futures_util::StreamExt;
|
|
|
|
use http::StatusCode;
|
|
|
|
use http_body::Body as HttpBody;
|
|
|
|
use hyper::body::Body;
|
|
|
|
use surrealdb::dbs::Session;
|
|
|
|
use surrealdb::iam::check::check_ns_db;
|
|
|
|
use surrealdb::iam::Action::{Edit, View};
|
|
|
|
use surrealdb::iam::ResourceKind::Model;
|
|
|
|
use surrealdb::kvs::{LockType::Optimistic, TransactionType::Read};
|
2024-02-02 22:10:47 +00:00
|
|
|
use surrealdb::ml::storage::surml_file::SurMlFile;
|
2023-12-12 13:51:43 +00:00
|
|
|
use surrealdb::sql::statements::{DefineModelStatement, DefineStatement};
|
|
|
|
use tower_http::limit::RequestBodyLimitLayer;
|
|
|
|
|
|
|
|
const MAX: usize = 1024 * 1024 * 1024 * 4; // 4 GiB
|
|
|
|
|
|
|
|
/// The router definition for the ML API endpoints.
|
|
|
|
pub(super) fn router<S, B>() -> Router<S, B>
|
|
|
|
where
|
|
|
|
B: HttpBody + Send + 'static,
|
|
|
|
B::Data: Send + Into<Bytes>,
|
|
|
|
B::Error: std::error::Error + Send + Sync + 'static,
|
|
|
|
S: Clone + Send + Sync + 'static,
|
|
|
|
{
|
|
|
|
Router::new()
|
|
|
|
.route("/ml/import", post(import))
|
|
|
|
.route("/ml/export/:name/:version", get(export))
|
|
|
|
.route_layer(DefaultBodyLimit::disable())
|
|
|
|
.layer(RequestBodyLimitLayer::new(MAX))
|
|
|
|
}
|
|
|
|
|
|
|
|
/// This endpoint allows the user to import a model into the database.
|
|
|
|
async fn import(
|
|
|
|
Extension(session): Extension<Session>,
|
|
|
|
mut stream: BodyStream,
|
|
|
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
|
|
|
// Get the datastore reference
|
|
|
|
let db = DB.get().unwrap();
|
|
|
|
// Ensure a NS and DB are set
|
|
|
|
let (nsv, dbv) = check_ns_db(&session)?;
|
|
|
|
// Check the permissions level
|
|
|
|
db.check(&session, Edit, Model.on_db(&nsv, &dbv))?;
|
|
|
|
// Create a new buffer
|
|
|
|
let mut buffer = Vec::new();
|
|
|
|
// Load all the uploaded file chunks
|
|
|
|
while let Some(chunk) = stream.next().await {
|
|
|
|
buffer.extend_from_slice(&chunk?);
|
|
|
|
}
|
|
|
|
// Check that the SurrealML file is valid
|
|
|
|
let file = match SurMlFile::from_bytes(buffer) {
|
|
|
|
Ok(file) => file,
|
|
|
|
Err(err) => return Err(Error::Other(err.to_string())),
|
|
|
|
};
|
|
|
|
// Convert the file back in to raw bytes
|
|
|
|
let data = file.to_bytes();
|
|
|
|
// Calculate the hash of the model file
|
|
|
|
let hash = surrealdb::obs::hash(&data);
|
|
|
|
// Calculate the path of the model file
|
|
|
|
let path = format!(
|
|
|
|
"ml/{nsv}/{dbv}/{}-{}-{hash}.surml",
|
|
|
|
file.header.name.to_string(),
|
|
|
|
file.header.version.to_string()
|
|
|
|
);
|
|
|
|
// Insert the file data in to the store
|
|
|
|
surrealdb::obs::put(&path, data).await?;
|
|
|
|
// Insert the model in to the database
|
|
|
|
db.process(
|
|
|
|
DefineStatement::Model(DefineModelStatement {
|
|
|
|
hash,
|
|
|
|
name: file.header.name.to_string().into(),
|
|
|
|
version: file.header.version.to_string(),
|
|
|
|
comment: Some(file.header.description.to_string().into()),
|
|
|
|
..Default::default()
|
|
|
|
})
|
|
|
|
.into(),
|
|
|
|
&session,
|
|
|
|
None,
|
|
|
|
)
|
|
|
|
.await?;
|
|
|
|
//
|
|
|
|
Ok(output::none())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// This endpoint allows the user to export a model from the database.
|
|
|
|
async fn export(
|
|
|
|
Extension(session): Extension<Session>,
|
|
|
|
Path((name, version)): Path<(String, String)>,
|
|
|
|
) -> Result<impl IntoResponse, Error> {
|
|
|
|
// Get the datastore reference
|
|
|
|
let db = DB.get().unwrap();
|
|
|
|
// Ensure a NS and DB are set
|
|
|
|
let (nsv, dbv) = check_ns_db(&session)?;
|
|
|
|
// Check the permissions level
|
|
|
|
db.check(&session, View, Model.on_db(&nsv, &dbv))?;
|
|
|
|
// Start a new readonly transaction
|
|
|
|
let mut tx = db.transaction(Read, Optimistic).await?;
|
|
|
|
// Attempt to get the model definition
|
|
|
|
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
|
|
|
|
// Calculate the path of the model file
|
|
|
|
let path = format!("ml/{nsv}/{dbv}/{name}-{version}-{}.surml", info.hash);
|
|
|
|
// Export the file data in to the store
|
|
|
|
let mut data = surrealdb::obs::stream(path).await?;
|
|
|
|
// Create a chunked response
|
|
|
|
let (mut chn, body) = Body::channel();
|
|
|
|
// Process all stream values
|
|
|
|
tokio::spawn(async move {
|
|
|
|
while let Some(Ok(v)) = data.next().await {
|
|
|
|
let _ = chn.send_data(v).await;
|
|
|
|
}
|
|
|
|
});
|
|
|
|
// Return the streamed body
|
|
|
|
Ok(Response::builder().status(StatusCode::OK).body(body).unwrap())
|
|
|
|
}
|