Add insert method to the Rust SDK (#3720)

This commit is contained in:
Rushmore Mushambi 2024-03-20 13:24:24 +02:00 committed by GitHub
parent 8b13546327
commit ec3bb1f659
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 407 additions and 6 deletions

View file

@ -68,6 +68,8 @@ pub enum Method {
Import, Import,
/// Invalidates a session /// Invalidates a session
Invalidate, Invalidate,
/// Inserts a record or records into a table
Insert,
/// Kills a live query /// Kills a live query
#[doc(hidden)] // Not supported yet #[doc(hidden)] // Not supported yet
Kill, Kill,

View file

@ -33,6 +33,7 @@ use crate::api::conn::MlConfig;
use crate::api::conn::Param; use crate::api::conn::Param;
use crate::api::engine::create_statement; use crate::api::engine::create_statement;
use crate::api::engine::delete_statement; use crate::api::engine::delete_statement;
use crate::api::engine::insert_statement;
use crate::api::engine::merge_statement; use crate::api::engine::merge_statement;
use crate::api::engine::patch_statement; use crate::api::engine::patch_statement;
use crate::api::engine::select_statement; use crate::api::engine::select_statement;
@ -582,6 +583,13 @@ async fn router(
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Insert => {
let (one, statement) = insert_statement(&mut params);
let query = Query(Statements(vec![Statement::Insert(statement)]));
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Patch => { Method::Patch => {
let (one, statement) = patch_statement(&mut params); let (one, statement) = patch_statement(&mut params);
let query = Query(Statements(vec![Statement::Update(statement)])); let query = Query(Statements(vec![Statement::Update(statement)]));

View file

@ -18,6 +18,7 @@ pub mod tasks;
use crate::sql::statements::CreateStatement; use crate::sql::statements::CreateStatement;
use crate::sql::statements::DeleteStatement; use crate::sql::statements::DeleteStatement;
use crate::sql::statements::InsertStatement;
use crate::sql::statements::SelectStatement; use crate::sql::statements::SelectStatement;
use crate::sql::statements::UpdateStatement; use crate::sql::statements::UpdateStatement;
use crate::sql::Array; use crate::sql::Array;
@ -89,6 +90,24 @@ fn update_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
) )
} }
#[allow(dead_code)] // used by the the embedded database and `http`
fn insert_statement(params: &mut [Value]) -> (bool, InsertStatement) {
let (what, data) = match params {
[what, data] => (mem::take(what), mem::take(data)),
_ => unreachable!(),
};
let one = !data.is_array();
(
one,
InsertStatement {
into: what,
data: Data::SingleExpression(data),
output: Some(Output::After),
..Default::default()
},
)
}
#[allow(dead_code)] // used by the the embedded database and `http` #[allow(dead_code)] // used by the the embedded database and `http`
fn patch_statement(params: &mut [Value]) -> (bool, UpdateStatement) { fn patch_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params); let (one, what, data) = split_params(params);

View file

@ -13,6 +13,7 @@ use crate::api::conn::MlConfig;
use crate::api::conn::Param; use crate::api::conn::Param;
use crate::api::engine::create_statement; use crate::api::engine::create_statement;
use crate::api::engine::delete_statement; use crate::api::engine::delete_statement;
use crate::api::engine::insert_statement;
use crate::api::engine::merge_statement; use crate::api::engine::merge_statement;
use crate::api::engine::patch_statement; use crate::api::engine::patch_statement;
use crate::api::engine::remote::duration_from_str; use crate::api::engine::remote::duration_from_str;
@ -474,6 +475,14 @@ async fn router(
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Insert => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = insert_statement(&mut params);
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Patch => { Method::Patch => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = patch_statement(&mut params); let (one, statement) = patch_statement(&mut params);

View file

@ -21,6 +21,7 @@ use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Array;
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
use flume::Receiver; use flume::Receiver;
@ -37,6 +38,7 @@ use std::collections::HashMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
use std::sync::Arc; use std::sync::Arc;
@ -354,11 +356,29 @@ pub(crate) fn router(
} }
} }
// Send the response back to the caller // Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert)
{
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(
Value::Array(Array(value)),
)) = &mut response
{
if let [value] =
&mut value[..]
{
response =
Ok(Data::Other(
mem::take(
value,
),
));
}
}
}
let _res = sender let _res = sender
.into_send_async( .into_send_async(
DbResponse::from( DbResponse::from(response),
response.result,
),
) )
.await; .await;
} }

View file

@ -19,6 +19,7 @@ use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Array;
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
use flume::Receiver; use flume::Receiver;
@ -38,6 +39,7 @@ use std::collections::HashMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
use std::sync::Arc; use std::sync::Arc;
@ -312,10 +314,22 @@ pub(crate) fn router(
} }
} }
// Send the response back to the caller // Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert) {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(
Array(value),
))) = &mut response
{
if let [value] = &mut value[..] {
response = Ok(Data::Other(
mem::take(value),
));
}
}
}
let _res = sender let _res = sender
.into_send_async(DbResponse::from( .into_send_async(DbResponse::from(response))
response.result,
))
.await; .await;
} }
} }

View file

@ -194,6 +194,18 @@ pub enum Error {
/// Called `Response::take` or `Response::stream` on a query response more than once /// Called `Response::take` or `Response::stream` on a query response more than once
#[error("Tried to take a query response that has already been taken")] #[error("Tried to take a query response that has already been taken")]
ResponseAlreadyTaken, ResponseAlreadyTaken,
/// Tried to insert on an object
#[error("Insert queries on objects not supported: {0}")]
InsertOnObject(Object),
/// Tried to insert on an array
#[error("Insert queries on arrays not supported: {0}")]
InsertOnArray(Array),
/// Tried to insert on an edge or edges
#[error("Insert queries on edges not supported: {0}")]
InsertOnEdges(Edges),
} }
#[cfg(feature = "protocol-http")] #[cfg(feature = "protocol-http")]

View file

@ -0,0 +1,160 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::err::Error;
use crate::api::method::Content;
use crate::api::opt::Resource;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::Ident;
use crate::sql::Part;
use crate::sql::Table;
use crate::sql::Value;
use crate::Surreal;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::borrow::Cow;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::pin::Pin;
/// An insert future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Insert<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) resource: Result<Resource>,
pub(super) response_type: PhantomData<R>,
}
impl<C, R> Insert<'_, C, R>
where
C: Connection,
{
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Insert<'static, C, R> {
Insert {
client: Cow::Owned(self.client.into_owned()),
..self
}
}
}
macro_rules! into_future {
($method:ident) => {
fn into_future(self) -> Self::IntoFuture {
let Insert {
client,
resource,
..
} = self;
Box::pin(async move {
let (table, data) = match resource? {
Resource::Table(table) => (table.into(), Value::Object(Default::default())),
Resource::RecordId(record_id) => (
Table(record_id.tb.clone()).into(),
crate::map! { String::from("id") => record_id.into() }.into(),
),
Resource::Object(obj) => return Err(Error::InsertOnObject(obj).into()),
Resource::Array(arr) => return Err(Error::InsertOnArray(arr).into()),
Resource::Edges(edges) => return Err(Error::InsertOnEdges(edges).into()),
};
let mut conn = Client::new(Method::Insert);
let param = vec![table, data];
conn.$method(client.router.extract()?, Param::new(param)).await
})
}
};
}
impl<'r, Client> IntoFuture for Insert<'r, Client, Value>
where
Client: Connection,
{
type Output = Result<Value>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
into_future! {execute_value}
}
impl<'r, Client, R> IntoFuture for Insert<'r, Client, Option<R>>
where
Client: Connection,
R: DeserializeOwned,
{
type Output = Result<Option<R>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
into_future! {execute_opt}
}
impl<'r, Client, R> IntoFuture for Insert<'r, Client, Vec<R>>
where
Client: Connection,
R: DeserializeOwned,
{
type Output = Result<Vec<R>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
into_future! {execute_vec}
}
impl<'r, C, R> Insert<'r, C, R>
where
C: Connection,
R: DeserializeOwned,
{
/// Specifies the data to insert into the table
pub fn content<D>(self, data: D) -> Content<'r, C, Value, R>
where
D: Serialize,
{
let mut content = Content {
client: self.client,
method: Method::Insert,
resource: self.resource,
range: None,
content: Value::None,
response_type: PhantomData,
};
match crate::sql::to_value(data) {
Ok(mut data) => match content.resource {
Ok(Resource::Table(table)) => {
content.resource = Ok(table.into());
content.content = data;
}
Ok(Resource::RecordId(record_id)) => match data.is_array() {
true => {
content.resource = Err(Error::InvalidParams(
"Tried to insert multiple records on a record ID".to_owned(),
)
.into());
}
false => {
content.resource = Ok(Table(record_id.tb.clone()).into());
let id = Part::Field(Ident("id".to_owned()));
data.put(&[id], record_id.into());
content.content = data;
}
},
Ok(Resource::Object(obj)) => {
content.resource = Err(Error::InsertOnObject(obj).into());
}
Ok(Resource::Array(arr)) => {
content.resource = Err(Error::InsertOnArray(arr).into());
}
Ok(Resource::Edges(edges)) => {
content.resource = Err(Error::InsertOnEdges(edges).into());
}
Err(error) => {
content.resource = Err(error);
}
},
Err(error) => {
content.resource = Err(error.into());
}
};
content
}
}

View file

@ -13,6 +13,7 @@ mod delete;
mod export; mod export;
mod health; mod health;
mod import; mod import;
mod insert;
mod invalidate; mod invalidate;
mod merge; mod merge;
mod patch; mod patch;
@ -45,6 +46,7 @@ pub use export::Backup;
pub use export::Export; pub use export::Export;
pub use health::Health; pub use health::Health;
pub use import::Import; pub use import::Import;
pub use insert::Insert;
pub use invalidate::Invalidate; pub use invalidate::Invalidate;
pub use live::Stream; pub use live::Stream;
pub use merge::Merge; pub use merge::Merge;
@ -113,6 +115,7 @@ impl Method {
Method::Health => "health", Method::Health => "health",
Method::Import => "import", Method::Import => "import",
Method::Invalidate => "invalidate", Method::Invalidate => "invalidate",
Method::Insert => "insert",
Method::Kill => "kill", Method::Kill => "kill",
Method::Live => "live", Method::Live => "live",
Method::Merge => "merge", Method::Merge => "merge",
@ -771,6 +774,107 @@ where
} }
} }
/// Insert a record or records into a table
///
/// # Examples
///
/// ```no_run
/// use serde::Serialize;
/// use surrealdb::sql;
///
/// # #[derive(serde::Deserialize)]
/// # struct Person;
/// #
/// #[derive(Serialize)]
/// struct Settings {
/// active: bool,
/// marketing: bool,
/// }
///
/// #[derive(Serialize)]
/// struct User<'a> {
/// name: &'a str,
/// settings: Settings,
/// }
///
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// #
/// // Select the namespace/database to use
/// db.use_ns("namespace").use_db("database").await?;
///
/// // Insert a record with a specific ID
/// let person: Option<Person> = db.insert(("person", "tobie"))
/// .content(User {
/// name: "Tobie",
/// settings: Settings {
/// active: true,
/// marketing: true,
/// },
/// })
/// .await?;
///
/// // Insert multiple records into the table
/// let people: Vec<Person> = db.insert("person")
/// .content(vec![
/// User {
/// name: "Tobie",
/// settings: Settings {
/// active: true,
/// marketing: false,
/// },
/// },
/// User {
/// name: "Jaime",
/// settings: Settings {
/// active: true,
/// marketing: true,
/// },
/// },
/// ])
/// .await?;
///
/// // Insert multiple records with pre-defined IDs
/// #[derive(Serialize)]
/// struct UserWithId<'a> {
/// id: sql::Thing,
/// name: &'a str,
/// settings: Settings,
/// }
///
/// let people: Vec<Person> = db.insert("person")
/// .content(vec![
/// UserWithId {
/// id: sql::thing("person:tobie")?,
/// name: "Tobie",
/// settings: Settings {
/// active: true,
/// marketing: false,
/// },
/// },
/// UserWithId {
/// id: sql::thing("person:jaime")?,
/// name: "Jaime",
/// settings: Settings {
/// active: true,
/// marketing: true,
/// },
/// },
/// ])
/// .await?;
/// #
/// # Ok(())
/// # }
/// ```
pub fn insert<R>(&self, resource: impl opt::IntoResource<R>) -> Insert<C, R> {
Insert {
client: Cow::Borrowed(self),
resource: resource.into_resource(),
response_type: PhantomData,
}
}
/// Updates all records in a table, or a specific record /// Updates all records in a table, or a specific record
/// ///
/// # Examples /// # Examples

View file

@ -138,6 +138,12 @@ async fn api() {
DB.update(USER).range("jane".."john").content(User::default()).await.unwrap(); DB.update(USER).range("jane".."john").content(User::default()).await.unwrap();
let _: Option<User> = DB.update((USER, "john")).content(User::default()).await.unwrap(); let _: Option<User> = DB.update((USER, "john")).content(User::default()).await.unwrap();
// insert
let _: Vec<User> = DB.insert(USER).await.unwrap();
let _: Option<User> = DB.insert((USER, "john")).await.unwrap();
let _: Vec<User> = DB.insert(USER).content(User::default()).await.unwrap();
let _: Option<User> = DB.insert((USER, "john")).content(User::default()).await.unwrap();
// merge // merge
let _: Vec<User> = DB.update(USER).merge(User::default()).await.unwrap(); let _: Vec<User> = DB.update(USER).merge(User::default()).await.unwrap();
let _: Vec<User> = DB.update(USER).range("jane".."john").merge(User::default()).await.unwrap(); let _: Vec<User> = DB.update(USER).range("jane".."john").merge(User::default()).await.unwrap();

View file

@ -78,6 +78,15 @@ pub(super) fn mock(route_rx: Receiver<Option<Route>>) {
} }
_ => unreachable!(), _ => unreachable!(),
}, },
Method::Insert => match &params[..] {
[Value::Table(..), Value::Array(..)] => {
Ok(DbResponse::Other(Value::Array(Array(Vec::new()))))
}
[Value::Table(..), _] => {
Ok(DbResponse::Other(to_value(User::default()).unwrap()))
}
_ => unreachable!(),
},
Method::Export | Method::Import => match param.file { Method::Export | Method::Import => match param.file {
Some(_) => Ok(DbResponse::Other(Value::None)), Some(_) => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(), _ => unreachable!(),

View file

@ -486,6 +486,44 @@ async fn create_record_with_id_with_content() {
assert_eq!(value.record(), thing("user:jane").ok()); assert_eq!(value.record(), thing("user:jane").ok());
} }
#[test_log::test(tokio::test)]
async fn insert_table() {
let (permit, db) = new_db().await;
db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap();
drop(permit);
let table = "user";
let _: Vec<RecordId> = db.insert(table).await.unwrap();
let _: Vec<RecordId> = db.insert(table).content(json!({ "foo": "bar" })).await.unwrap();
let _: Vec<RecordId> = db.insert(table).content(json!([{ "foo": "bar" }])).await.unwrap();
let _: Value = db.insert(Resource::from(table)).await.unwrap();
let _: Value = db.insert(Resource::from(table)).content(json!({ "foo": "bar" })).await.unwrap();
let _: Value =
db.insert(Resource::from(table)).content(json!([{ "foo": "bar" }])).await.unwrap();
let users: Vec<RecordId> = db.insert(table).await.unwrap();
assert!(!users.is_empty());
}
#[test_log::test(tokio::test)]
async fn insert_thing() {
let (permit, db) = new_db().await;
db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap();
drop(permit);
let table = "user";
let _: Option<RecordId> = db.insert((table, "user1")).await.unwrap();
let _: Option<RecordId> =
db.insert((table, "user1")).content(json!({ "foo": "bar" })).await.unwrap();
let _: Value = db.insert(Resource::from((table, "user2"))).await.unwrap();
let _: Value =
db.insert(Resource::from((table, "user2"))).content(json!({ "foo": "bar" })).await.unwrap();
let user: Option<RecordId> = db.insert((table, "user3")).await.unwrap();
assert_eq!(
user,
Some(RecordId {
id: thing("user:user3").unwrap(),
})
);
}
#[test_log::test(tokio::test)] #[test_log::test(tokio::test)]
async fn select_table() { async fn select_table() {
let (permit, db) = new_db().await; let (permit, db) = new_db().await;