From ec3bb1f659545c8a077184d854b06f089669b298 Mon Sep 17 00:00:00 2001 From: Rushmore Mushambi Date: Wed, 20 Mar 2024 13:24:24 +0200 Subject: [PATCH] Add `insert` method to the Rust SDK (#3720) --- lib/src/api/conn.rs | 2 + lib/src/api/engine/local/mod.rs | 8 ++ lib/src/api/engine/mod.rs | 19 +++ lib/src/api/engine/remote/http/mod.rs | 9 ++ lib/src/api/engine/remote/ws/native.rs | 26 +++- lib/src/api/engine/remote/ws/wasm.rs | 20 +++- lib/src/api/err/mod.rs | 12 ++ lib/src/api/method/insert.rs | 160 +++++++++++++++++++++++++ lib/src/api/method/mod.rs | 104 ++++++++++++++++ lib/src/api/method/tests/mod.rs | 6 + lib/src/api/method/tests/server.rs | 9 ++ lib/tests/api/mod.rs | 38 ++++++ 12 files changed, 407 insertions(+), 6 deletions(-) create mode 100644 lib/src/api/method/insert.rs diff --git a/lib/src/api/conn.rs b/lib/src/api/conn.rs index e98109da..9a8e621e 100644 --- a/lib/src/api/conn.rs +++ b/lib/src/api/conn.rs @@ -68,6 +68,8 @@ pub enum Method { Import, /// Invalidates a session Invalidate, + /// Inserts a record or records into a table + Insert, /// Kills a live query #[doc(hidden)] // Not supported yet Kill, diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index d9d93a00..59614a57 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -33,6 +33,7 @@ use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::engine::create_statement; use crate::api::engine::delete_statement; +use crate::api::engine::insert_statement; use crate::api::engine::merge_statement; use crate::api::engine::patch_statement; use crate::api::engine::select_statement; @@ -582,6 +583,13 @@ async fn router( let value = take(one, response).await?; Ok(DbResponse::Other(value)) } + Method::Insert => { + let (one, statement) = insert_statement(&mut params); + let query = Query(Statements(vec![Statement::Insert(statement)])); + let response = kvs.process(query, &*session, Some(vars.clone())).await?; + let value = take(one, response).await?; + Ok(DbResponse::Other(value)) + } Method::Patch => { let (one, statement) = patch_statement(&mut params); let query = Query(Statements(vec![Statement::Update(statement)])); diff --git a/lib/src/api/engine/mod.rs b/lib/src/api/engine/mod.rs index fd18de1e..1b47de38 100644 --- a/lib/src/api/engine/mod.rs +++ b/lib/src/api/engine/mod.rs @@ -18,6 +18,7 @@ pub mod tasks; use crate::sql::statements::CreateStatement; use crate::sql::statements::DeleteStatement; +use crate::sql::statements::InsertStatement; use crate::sql::statements::SelectStatement; use crate::sql::statements::UpdateStatement; use crate::sql::Array; @@ -89,6 +90,24 @@ fn update_statement(params: &mut [Value]) -> (bool, UpdateStatement) { ) } +#[allow(dead_code)] // used by the the embedded database and `http` +fn insert_statement(params: &mut [Value]) -> (bool, InsertStatement) { + let (what, data) = match params { + [what, data] => (mem::take(what), mem::take(data)), + _ => unreachable!(), + }; + let one = !data.is_array(); + ( + one, + InsertStatement { + into: what, + data: Data::SingleExpression(data), + output: Some(Output::After), + ..Default::default() + }, + ) +} + #[allow(dead_code)] // used by the the embedded database and `http` fn patch_statement(params: &mut [Value]) -> (bool, UpdateStatement) { let (one, what, data) = split_params(params); diff --git a/lib/src/api/engine/remote/http/mod.rs b/lib/src/api/engine/remote/http/mod.rs index 5511e206..baf05515 100644 --- a/lib/src/api/engine/remote/http/mod.rs +++ b/lib/src/api/engine/remote/http/mod.rs @@ -13,6 +13,7 @@ use crate::api::conn::MlConfig; use crate::api::conn::Param; use crate::api::engine::create_statement; use crate::api::engine::delete_statement; +use crate::api::engine::insert_statement; use crate::api::engine::merge_statement; use crate::api::engine::patch_statement; use crate::api::engine::remote::duration_from_str; @@ -474,6 +475,14 @@ async fn router( let value = take(one, request).await?; Ok(DbResponse::Other(value)) } + Method::Insert => { + let path = base_url.join(SQL_PATH)?; + let (one, statement) = insert_statement(&mut params); + let request = + client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); + let value = take(one, request).await?; + Ok(DbResponse::Other(value)) + } Method::Patch => { let path = base_url.join(SQL_PATH)?; let (one, statement) = patch_statement(&mut params); diff --git a/lib/src/api/engine/remote/ws/native.rs b/lib/src/api/engine/remote/ws/native.rs index b0a85867..f337f51b 100644 --- a/lib/src/api/engine/remote/ws/native.rs +++ b/lib/src/api/engine/remote/ws/native.rs @@ -21,6 +21,7 @@ use crate::api::Surreal; use crate::engine::remote::ws::Data; use crate::engine::IntervalStream; use crate::opt::WaitFor; +use crate::sql::Array; use crate::sql::Strand; use crate::sql::Value; use flume::Receiver; @@ -37,6 +38,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::future::Future; use std::marker::PhantomData; +use std::mem; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::Arc; @@ -354,11 +356,29 @@ pub(crate) fn router( } } // Send the response back to the caller + let mut response = response.result; + if matches!(method, Method::Insert) + { + // For insert, we need to flatten single responses in an array + if let Ok(Data::Other( + Value::Array(Array(value)), + )) = &mut response + { + if let [value] = + &mut value[..] + { + response = + Ok(Data::Other( + mem::take( + value, + ), + )); + } + } + } let _res = sender .into_send_async( - DbResponse::from( - response.result, - ), + DbResponse::from(response), ) .await; } diff --git a/lib/src/api/engine/remote/ws/wasm.rs b/lib/src/api/engine/remote/ws/wasm.rs index 5094a1fb..7469da7b 100644 --- a/lib/src/api/engine/remote/ws/wasm.rs +++ b/lib/src/api/engine/remote/ws/wasm.rs @@ -19,6 +19,7 @@ use crate::api::Surreal; use crate::engine::remote::ws::Data; use crate::engine::IntervalStream; use crate::opt::WaitFor; +use crate::sql::Array; use crate::sql::Strand; use crate::sql::Value; use flume::Receiver; @@ -38,6 +39,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::future::Future; use std::marker::PhantomData; +use std::mem; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::Arc; @@ -312,10 +314,22 @@ pub(crate) fn router( } } // Send the response back to the caller + let mut response = response.result; + if matches!(method, Method::Insert) { + // For insert, we need to flatten single responses in an array + if let Ok(Data::Other(Value::Array( + Array(value), + ))) = &mut response + { + if let [value] = &mut value[..] { + response = Ok(Data::Other( + mem::take(value), + )); + } + } + } let _res = sender - .into_send_async(DbResponse::from( - response.result, - )) + .into_send_async(DbResponse::from(response)) .await; } } diff --git a/lib/src/api/err/mod.rs b/lib/src/api/err/mod.rs index 9769d2f5..d6ee95e7 100644 --- a/lib/src/api/err/mod.rs +++ b/lib/src/api/err/mod.rs @@ -194,6 +194,18 @@ pub enum Error { /// Called `Response::take` or `Response::stream` on a query response more than once #[error("Tried to take a query response that has already been taken")] ResponseAlreadyTaken, + + /// Tried to insert on an object + #[error("Insert queries on objects not supported: {0}")] + InsertOnObject(Object), + + /// Tried to insert on an array + #[error("Insert queries on arrays not supported: {0}")] + InsertOnArray(Array), + + /// Tried to insert on an edge or edges + #[error("Insert queries on edges not supported: {0}")] + InsertOnEdges(Edges), } #[cfg(feature = "protocol-http")] diff --git a/lib/src/api/method/insert.rs b/lib/src/api/method/insert.rs new file mode 100644 index 00000000..c3accf38 --- /dev/null +++ b/lib/src/api/method/insert.rs @@ -0,0 +1,160 @@ +use crate::api::conn::Method; +use crate::api::conn::Param; +use crate::api::err::Error; +use crate::api::method::Content; +use crate::api::opt::Resource; +use crate::api::Connection; +use crate::api::Result; +use crate::method::OnceLockExt; +use crate::sql::Ident; +use crate::sql::Part; +use crate::sql::Table; +use crate::sql::Value; +use crate::Surreal; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::borrow::Cow; +use std::future::Future; +use std::future::IntoFuture; +use std::marker::PhantomData; +use std::pin::Pin; + +/// An insert future +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Insert<'r, C: Connection, R> { + pub(super) client: Cow<'r, Surreal>, + pub(super) resource: Result, + pub(super) response_type: PhantomData, +} + +impl Insert<'_, C, R> +where + C: Connection, +{ + /// Converts to an owned type which can easily be moved to a different thread + pub fn into_owned(self) -> Insert<'static, C, R> { + Insert { + client: Cow::Owned(self.client.into_owned()), + ..self + } + } +} + +macro_rules! into_future { + ($method:ident) => { + fn into_future(self) -> Self::IntoFuture { + let Insert { + client, + resource, + .. + } = self; + Box::pin(async move { + let (table, data) = match resource? { + Resource::Table(table) => (table.into(), Value::Object(Default::default())), + Resource::RecordId(record_id) => ( + Table(record_id.tb.clone()).into(), + crate::map! { String::from("id") => record_id.into() }.into(), + ), + Resource::Object(obj) => return Err(Error::InsertOnObject(obj).into()), + Resource::Array(arr) => return Err(Error::InsertOnArray(arr).into()), + Resource::Edges(edges) => return Err(Error::InsertOnEdges(edges).into()), + }; + let mut conn = Client::new(Method::Insert); + let param = vec![table, data]; + conn.$method(client.router.extract()?, Param::new(param)).await + }) + } + }; +} + +impl<'r, Client> IntoFuture for Insert<'r, Client, Value> +where + Client: Connection, +{ + type Output = Result; + type IntoFuture = Pin + Send + Sync + 'r>>; + + into_future! {execute_value} +} + +impl<'r, Client, R> IntoFuture for Insert<'r, Client, Option> +where + Client: Connection, + R: DeserializeOwned, +{ + type Output = Result>; + type IntoFuture = Pin + Send + Sync + 'r>>; + + into_future! {execute_opt} +} + +impl<'r, Client, R> IntoFuture for Insert<'r, Client, Vec> +where + Client: Connection, + R: DeserializeOwned, +{ + type Output = Result>; + type IntoFuture = Pin + Send + Sync + 'r>>; + + into_future! {execute_vec} +} + +impl<'r, C, R> Insert<'r, C, R> +where + C: Connection, + R: DeserializeOwned, +{ + /// Specifies the data to insert into the table + pub fn content(self, data: D) -> Content<'r, C, Value, R> + where + D: Serialize, + { + let mut content = Content { + client: self.client, + method: Method::Insert, + resource: self.resource, + range: None, + content: Value::None, + response_type: PhantomData, + }; + match crate::sql::to_value(data) { + Ok(mut data) => match content.resource { + Ok(Resource::Table(table)) => { + content.resource = Ok(table.into()); + content.content = data; + } + Ok(Resource::RecordId(record_id)) => match data.is_array() { + true => { + content.resource = Err(Error::InvalidParams( + "Tried to insert multiple records on a record ID".to_owned(), + ) + .into()); + } + false => { + content.resource = Ok(Table(record_id.tb.clone()).into()); + let id = Part::Field(Ident("id".to_owned())); + data.put(&[id], record_id.into()); + content.content = data; + } + }, + Ok(Resource::Object(obj)) => { + content.resource = Err(Error::InsertOnObject(obj).into()); + } + Ok(Resource::Array(arr)) => { + content.resource = Err(Error::InsertOnArray(arr).into()); + } + Ok(Resource::Edges(edges)) => { + content.resource = Err(Error::InsertOnEdges(edges).into()); + } + Err(error) => { + content.resource = Err(error); + } + }, + Err(error) => { + content.resource = Err(error.into()); + } + }; + content + } +} diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index e78f8b97..94fde2be 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -13,6 +13,7 @@ mod delete; mod export; mod health; mod import; +mod insert; mod invalidate; mod merge; mod patch; @@ -45,6 +46,7 @@ pub use export::Backup; pub use export::Export; pub use health::Health; pub use import::Import; +pub use insert::Insert; pub use invalidate::Invalidate; pub use live::Stream; pub use merge::Merge; @@ -113,6 +115,7 @@ impl Method { Method::Health => "health", Method::Import => "import", Method::Invalidate => "invalidate", + Method::Insert => "insert", Method::Kill => "kill", Method::Live => "live", Method::Merge => "merge", @@ -771,6 +774,107 @@ where } } + /// Insert a record or records into a table + /// + /// # Examples + /// + /// ```no_run + /// use serde::Serialize; + /// use surrealdb::sql; + /// + /// # #[derive(serde::Deserialize)] + /// # struct Person; + /// # + /// #[derive(Serialize)] + /// struct Settings { + /// active: bool, + /// marketing: bool, + /// } + /// + /// #[derive(Serialize)] + /// struct User<'a> { + /// name: &'a str, + /// settings: Settings, + /// } + /// + /// # #[tokio::main] + /// # async fn main() -> surrealdb::Result<()> { + /// # let db = surrealdb::engine::any::connect("mem://").await?; + /// # + /// // Select the namespace/database to use + /// db.use_ns("namespace").use_db("database").await?; + /// + /// // Insert a record with a specific ID + /// let person: Option = db.insert(("person", "tobie")) + /// .content(User { + /// name: "Tobie", + /// settings: Settings { + /// active: true, + /// marketing: true, + /// }, + /// }) + /// .await?; + /// + /// // Insert multiple records into the table + /// let people: Vec = db.insert("person") + /// .content(vec![ + /// User { + /// name: "Tobie", + /// settings: Settings { + /// active: true, + /// marketing: false, + /// }, + /// }, + /// User { + /// name: "Jaime", + /// settings: Settings { + /// active: true, + /// marketing: true, + /// }, + /// }, + /// ]) + /// .await?; + /// + /// // Insert multiple records with pre-defined IDs + /// #[derive(Serialize)] + /// struct UserWithId<'a> { + /// id: sql::Thing, + /// name: &'a str, + /// settings: Settings, + /// } + /// + /// let people: Vec = db.insert("person") + /// .content(vec![ + /// UserWithId { + /// id: sql::thing("person:tobie")?, + /// name: "Tobie", + /// settings: Settings { + /// active: true, + /// marketing: false, + /// }, + /// }, + /// UserWithId { + /// id: sql::thing("person:jaime")?, + /// name: "Jaime", + /// settings: Settings { + /// active: true, + /// marketing: true, + /// }, + /// }, + /// ]) + /// .await?; + /// # + /// # Ok(()) + /// # } + /// ``` + pub fn insert(&self, resource: impl opt::IntoResource) -> Insert { + Insert { + client: Cow::Borrowed(self), + resource: resource.into_resource(), + response_type: PhantomData, + } + } + /// Updates all records in a table, or a specific record /// /// # Examples diff --git a/lib/src/api/method/tests/mod.rs b/lib/src/api/method/tests/mod.rs index aee2557b..0ee4edfc 100644 --- a/lib/src/api/method/tests/mod.rs +++ b/lib/src/api/method/tests/mod.rs @@ -138,6 +138,12 @@ async fn api() { DB.update(USER).range("jane".."john").content(User::default()).await.unwrap(); let _: Option = DB.update((USER, "john")).content(User::default()).await.unwrap(); + // insert + let _: Vec = DB.insert(USER).await.unwrap(); + let _: Option = DB.insert((USER, "john")).await.unwrap(); + let _: Vec = DB.insert(USER).content(User::default()).await.unwrap(); + let _: Option = DB.insert((USER, "john")).content(User::default()).await.unwrap(); + // merge let _: Vec = DB.update(USER).merge(User::default()).await.unwrap(); let _: Vec = DB.update(USER).range("jane".."john").merge(User::default()).await.unwrap(); diff --git a/lib/src/api/method/tests/server.rs b/lib/src/api/method/tests/server.rs index 9fb419c1..4242d9c9 100644 --- a/lib/src/api/method/tests/server.rs +++ b/lib/src/api/method/tests/server.rs @@ -78,6 +78,15 @@ pub(super) fn mock(route_rx: Receiver>) { } _ => unreachable!(), }, + Method::Insert => match ¶ms[..] { + [Value::Table(..), Value::Array(..)] => { + Ok(DbResponse::Other(Value::Array(Array(Vec::new())))) + } + [Value::Table(..), _] => { + Ok(DbResponse::Other(to_value(User::default()).unwrap())) + } + _ => unreachable!(), + }, Method::Export | Method::Import => match param.file { Some(_) => Ok(DbResponse::Other(Value::None)), _ => unreachable!(), diff --git a/lib/tests/api/mod.rs b/lib/tests/api/mod.rs index 320f3a52..6a698873 100644 --- a/lib/tests/api/mod.rs +++ b/lib/tests/api/mod.rs @@ -486,6 +486,44 @@ async fn create_record_with_id_with_content() { assert_eq!(value.record(), thing("user:jane").ok()); } +#[test_log::test(tokio::test)] +async fn insert_table() { + let (permit, db) = new_db().await; + db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap(); + drop(permit); + let table = "user"; + let _: Vec = db.insert(table).await.unwrap(); + let _: Vec = db.insert(table).content(json!({ "foo": "bar" })).await.unwrap(); + let _: Vec = db.insert(table).content(json!([{ "foo": "bar" }])).await.unwrap(); + let _: Value = db.insert(Resource::from(table)).await.unwrap(); + let _: Value = db.insert(Resource::from(table)).content(json!({ "foo": "bar" })).await.unwrap(); + let _: Value = + db.insert(Resource::from(table)).content(json!([{ "foo": "bar" }])).await.unwrap(); + let users: Vec = db.insert(table).await.unwrap(); + assert!(!users.is_empty()); +} + +#[test_log::test(tokio::test)] +async fn insert_thing() { + let (permit, db) = new_db().await; + db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap(); + drop(permit); + let table = "user"; + let _: Option = db.insert((table, "user1")).await.unwrap(); + let _: Option = + db.insert((table, "user1")).content(json!({ "foo": "bar" })).await.unwrap(); + let _: Value = db.insert(Resource::from((table, "user2"))).await.unwrap(); + let _: Value = + db.insert(Resource::from((table, "user2"))).content(json!({ "foo": "bar" })).await.unwrap(); + let user: Option = db.insert((table, "user3")).await.unwrap(); + assert_eq!( + user, + Some(RecordId { + id: thing("user:user3").unwrap(), + }) + ); +} + #[test_log::test(tokio::test)] async fn select_table() { let (permit, db) = new_db().await;