diff --git a/core/src/rpc/method.rs b/core/src/rpc/method.rs index a7263d47..054e584f 100644 --- a/core/src/rpc/method.rs +++ b/core/src/rpc/method.rs @@ -25,6 +25,7 @@ pub enum Method { Relate, Run, GraphQL, + InsertRelation, } impl Method { @@ -57,6 +58,7 @@ impl Method { "relate" => Self::Relate, "run" => Self::Run, "graphql" => Self::GraphQL, + "insert_relation" => Self::InsertRelation, _ => Self::Unknown, } } @@ -90,6 +92,7 @@ impl Method { Self::Relate => "relate", Self::Run => "run", Self::GraphQL => "graphql", + Self::InsertRelation => "insert_relation", } } } @@ -115,6 +118,7 @@ impl Method { | Method::Delete | Method::Version | Method::Query | Method::Relate | Method::Run | Method::GraphQL + | Method::InsertRelation | Method::Unknown ) } diff --git a/core/src/rpc/rpc_context.rs b/core/src/rpc/rpc_context.rs index 61c8271a..bdc95a80 100644 --- a/core/src/rpc/rpc_context.rs +++ b/core/src/rpc/rpc_context.rs @@ -68,6 +68,9 @@ pub trait RpcContext { Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), Method::Run => self.run(params).await.map(Into::into).map_err(Into::into), Method::GraphQL => self.graphql(params).await.map(Into::into).map_err(Into::into), + Method::InsertRelation => { + self.insert_relation(params).await.map(Into::into).map_err(Into::into) + } Method::Unknown => Err(RpcError::MethodNotFound), } } @@ -89,6 +92,9 @@ pub trait RpcContext { Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), Method::Run => self.run(params).await.map(Into::into).map_err(Into::into), Method::GraphQL => self.graphql(params).await.map(Into::into).map_err(Into::into), + Method::InsertRelation => { + self.insert_relation(params).await.map(Into::into).map_err(Into::into) + } Method::Unknown => Err(RpcError::MethodNotFound), _ => Err(RpcError::MethodNotFound), } @@ -326,6 +332,41 @@ pub trait RpcContext { Ok(res.into()) } + async fn insert_relation(&self, params: Array) -> Result, RpcError> { + let Ok((what, data)) = params.needs_two() else { + return Err(RpcError::InvalidParams); + }; + + let one = data.is_single(); + + let mut res = match what { + Value::None | Value::Null => { + let sql = "INSERT RELATION $data RETURN AFTER"; + let vars = Some(map! { + String::from("data") => data, + => &self.vars() + }); + self.kvs().execute(sql, self.session(), vars).await? + } + Value::Table(_) | Value::Strand(_) => { + let sql = "INSERT RELATION INTO $what $data RETURN AFTER"; + let vars = Some(map! { + String::from("data") => data, + String::from("what") => what.could_be_table(), + => &self.vars() + }); + self.kvs().execute(sql, self.session(), vars).await? + } + _ => return Err(RpcError::InvalidParams), + }; + + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + Ok(res) + } + // ------------------------------ // Methods for creating // ------------------------------ diff --git a/core/src/sql/value/value.rs b/core/src/sql/value/value.rs index fc2c07dc..e0b900f8 100644 --- a/core/src/sql/value/value.rs +++ b/core/src/sql/value/value.rs @@ -1154,6 +1154,14 @@ impl Value { } } + pub fn is_single(&self) -> bool { + match self { + Value::Object(_) => true, + Value::Array(a) if a.len() == 1 => true, + _ => false, + } + } + // ----------------------------------- // Simple conversion of value // ----------------------------------- diff --git a/sdk/src/api/conn/cmd.rs b/sdk/src/api/conn/cmd.rs index ec55ae63..f5bc6142 100644 --- a/sdk/src/api/conn/cmd.rs +++ b/sdk/src/api/conn/cmd.rs @@ -46,6 +46,10 @@ pub(crate) enum Command { what: Option, data: CoreValue, }, + InsertRelation { + what: Option, + data: CoreValue, + }, Patch { what: Resource, data: Option, @@ -214,6 +218,26 @@ impl Command { params: Some(params.into()), } } + Command::InsertRelation { + what, + data, + } => { + let table = match what { + Some(w) => { + let mut tmp = CoreTable::default(); + tmp.0 = w.clone(); + CoreValue::from(tmp) + } + None => CoreValue::None, + }; + let params = vec![table, data]; + + RouterRequest { + id, + method: "insert_relation", + params: Some(params.into()), + } + } Command::Patch { what, data, diff --git a/sdk/src/api/engine/local/mod.rs b/sdk/src/api/engine/local/mod.rs index c2924d89..78adca62 100644 --- a/sdk/src/api/engine/local/mod.rs +++ b/sdk/src/api/engine/local/mod.rs @@ -611,6 +611,25 @@ async fn router( let value = take(one, response).await?; Ok(DbResponse::Other(value)) } + Command::InsertRelation { + what, + data, + } => { + let mut query = Query::default(); + let one = !data.is_array(); + let statement = { + let mut stmt = InsertStatement::default(); + stmt.into = what.map(|w| Table(w).into_core().into()); + stmt.data = Data::SingleExpression(data); + stmt.output = Some(Output::After); + stmt.relation = true; + stmt + }; + query.0 .0 = vec![Statement::Insert(statement)]; + let response = kvs.process(query, &*session, Some(vars.clone())).await?; + let value = take(one, response).await?; + Ok(DbResponse::Other(value)) + } Command::Patch { what, data, diff --git a/sdk/src/api/method/insert.rs b/sdk/src/api/method/insert.rs index 4f91bcf9..f5e1e434 100644 --- a/sdk/src/api/method/insert.rs +++ b/sdk/src/api/method/insert.rs @@ -15,6 +15,8 @@ use std::future::IntoFuture; use std::marker::PhantomData; use surrealdb_core::sql::{to_value as to_core_value, Object as CoreObject, Value as CoreValue}; +use super::insert_relation::InsertRelation; + /// An insert future #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] @@ -155,3 +157,51 @@ where }) } } + +impl<'r, C, R> Insert<'r, C, R> +where + C: Connection, + R: DeserializeOwned, +{ + /// Specifies the data to insert into the table + pub fn relation(self, data: D) -> InsertRelation<'r, C, R> + where + D: Serialize + 'static, + { + InsertRelation::from_closure(self.client, || { + let mut data = to_core_value(data)?; + match self.resource? { + Resource::Table(table) => Ok(Command::InsertRelation { + what: Some(table), + data, + }), + Resource::RecordId(thing) => { + if data.is_array() { + Err(Error::InvalidParams( + "Tried to insert multiple records on a record ID".to_owned(), + ) + .into()) + } else { + let thing = thing.into_inner(); + if let CoreValue::Object(ref mut x) = data { + x.insert("id".to_string(), thing.id.into()); + } + + Ok(Command::InsertRelation { + what: Some(thing.tb), + data, + }) + } + } + Resource::Unspecified => Ok(Command::InsertRelation { + what: None, + data, + }), + Resource::Object(_) => Err(Error::InsertOnObject.into()), + Resource::Array(_) => Err(Error::InsertOnArray.into()), + Resource::Edge(_) => Err(Error::InsertOnEdges.into()), + Resource::Range(_) => Err(Error::InsertOnRange.into()), + } + }) + } +} diff --git a/sdk/src/api/method/insert_relation.rs b/sdk/src/api/method/insert_relation.rs new file mode 100644 index 00000000..2458b1ce --- /dev/null +++ b/sdk/src/api/method/insert_relation.rs @@ -0,0 +1,95 @@ +use crate::api::conn::Command; +use crate::api::Connection; +use crate::api::Result; +use crate::method::OnceLockExt; +use crate::Surreal; +use crate::Value; +use serde::de::DeserializeOwned; +use std::borrow::Cow; +use std::future::IntoFuture; +use std::marker::PhantomData; + +use super::BoxFuture; + +/// An Insert Relation future +/// +/// +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct InsertRelation<'r, C: Connection, R> { + pub(super) client: Cow<'r, Surreal>, + pub(super) command: Result, + pub(super) response_type: PhantomData, +} + +impl<'r, C, R> InsertRelation<'r, C, R> +where + C: Connection, +{ + pub(crate) fn from_closure(client: Cow<'r, Surreal>, f: F) -> Self + where + F: FnOnce() -> Result, + { + InsertRelation { + client, + command: f(), + response_type: PhantomData, + } + } + + /// Converts to an owned type which can easily be moved to a different thread + pub fn into_owned(self) -> InsertRelation<'static, C, R> { + InsertRelation { + client: Cow::Owned(self.client.into_owned()), + ..self + } + } +} + +macro_rules! into_future { + ($method:ident) => { + fn into_future(self) -> Self::IntoFuture { + let InsertRelation { + client, + command, + .. + } = self; + Box::pin(async move { + let router = client.router.extract()?; + router.$method(command?).await + }) + } + }; +} + +impl<'r, Client> IntoFuture for InsertRelation<'r, Client, Value> +where + Client: Connection, +{ + type Output = Result; + type IntoFuture = BoxFuture<'r, Self::Output>; + + into_future! {execute_value} +} + +impl<'r, Client, R> IntoFuture for InsertRelation<'r, Client, Option> +where + Client: Connection, + R: DeserializeOwned, +{ + type Output = Result>; + type IntoFuture = BoxFuture<'r, Self::Output>; + + into_future! {execute_opt} +} + +impl<'r, Client, R> IntoFuture for InsertRelation<'r, Client, Vec> +where + Client: Connection, + R: DeserializeOwned, +{ + type Output = Result>; + type IntoFuture = BoxFuture<'r, Self::Output>; + + into_future! {execute_vec} +} diff --git a/sdk/src/api/method/mod.rs b/sdk/src/api/method/mod.rs index eda5b20a..a36c3ff4 100644 --- a/sdk/src/api/method/mod.rs +++ b/sdk/src/api/method/mod.rs @@ -35,6 +35,7 @@ mod export; mod health; mod import; mod insert; +mod insert_relation; mod invalidate; mod merge; mod patch; @@ -761,10 +762,10 @@ where /// # Examples /// /// ```no_run - /// use serde::Serialize; + /// use serde::{Serialize, Deserialize}; /// use surrealdb::sql; /// - /// # #[derive(serde::Deserialize)] + /// # #[derive(Deserialize)] /// # struct Person; /// # /// #[derive(Serialize)] @@ -866,6 +867,29 @@ where /// ]) /// .await?; /// + /// + /// // Insert relations + /// #[derive(Serialize, Deserialize)] + /// struct Founded { + /// #[serde(rename = "in")] + /// founder: sql::Thing, + /// #[serde(rename = "out")] + /// company: sql::Thing, + /// } + /// + /// let founded: Vec = db.insert("founded") + /// .relation(vec![ + /// Founded { + /// founder: sql::thing("person:tobie")?, + /// company: sql::thing("company:surrealdb")?, + /// }, + /// Founded { + /// founder: sql::thing("person:jaime")?, + /// company: sql::thing("company:surrealdb")?, + /// }, + /// ]) + /// .await?; + /// /// # /// # Ok(()) /// # } diff --git a/sdk/src/api/method/tests/server.rs b/sdk/src/api/method/tests/server.rs index fd6b6fbf..59481b7d 100644 --- a/sdk/src/api/method/tests/server.rs +++ b/sdk/src/api/method/tests/server.rs @@ -100,6 +100,15 @@ pub(super) fn mock(route_rx: Receiver) { } _ => Ok(DbResponse::Other(to_core_value(User::default()).unwrap())), }, + Command::InsertRelation { + data, + .. + } => match data { + CoreValue::Array(..) => { + Ok(DbResponse::Other(CoreValue::Array(Default::default()))) + } + _ => Ok(DbResponse::Other(to_core_value(User::default()).unwrap())), + }, Command::Run { .. } => Ok(DbResponse::Other(CoreValue::None)), diff --git a/sdk/tests/api/mod.rs b/sdk/tests/api/mod.rs index a987d78f..cb1afe96 100644 --- a/sdk/tests/api/mod.rs +++ b/sdk/tests/api/mod.rs @@ -575,6 +575,23 @@ async fn insert_unspecified() { assert_eq!(tmp, val); } +#[test_log::test(tokio::test)] +async fn insert_relation_table() { + let (permit, db) = new_db().await; + db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap(); + drop(permit); + let tmp: Result, _> = db.insert("likes").relation("{}".parse::().unwrap()).await; + tmp.unwrap_err(); + let val = "{in: person:a, out: thing:a}".parse::().unwrap(); + let _: Vec = db.insert("likes").relation(val).await.unwrap(); + + let vals = + "[{in: person:b, out: thing:a}, {id: likes:2, in: person:a, out: thing:a}, {id: hates:3, in: person:a, out: thing:a}]" + .parse::() + .unwrap(); + let _: Vec = db.insert("likes").relation(vals).await.unwrap(); +} + #[test_log::test(tokio::test)] async fn select_table() { let (permit, db) = new_db().await;