Add relation rpc method (#3775)

Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
Co-authored-by: Salvador Girones Gil <salvadorgirones@gmail.com>
Co-authored-by: Mees Delzenne <DelSkayn@users.noreply.github.com>
Co-authored-by: Micha de Vries <micha@devrie.sh>
This commit is contained in:
Raphael Darley 2024-04-18 12:48:12 +02:00 committed by GitHub
parent 52dc064005
commit 31cc0e37e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 213 additions and 17 deletions

View file

@ -9,6 +9,7 @@ pub trait Take {
fn needs_three(self) -> Result<(Value, Value, Value), RpcError>; fn needs_three(self) -> Result<(Value, Value, Value), RpcError>;
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>; fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>; fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
fn needs_three_or_four(self) -> Result<(Value, Value, Value, Value), RpcError>;
} }
impl Take for Array { impl Take for Array {
@ -71,4 +72,16 @@ impl Take for Array {
(_, _, _) => Ok((Value::None, Value::None, Value::None)), (_, _, _) => Ok((Value::None, Value::None, Value::None)),
} }
} }
/// Convert the array to four arguments
fn needs_three_or_four(self) -> Result<(Value, Value, Value, Value), RpcError> {
if self.len() < 3 || self.len() > 4 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match (x.next(), x.next(), x.next(), x.next()) {
(Some(a), Some(b), Some(c), Some(d)) => Ok((a, b, c, d)),
(Some(a), Some(b), Some(c), None) => Ok((a, b, c, Value::None)),
(_, _, _, _) => Ok((Value::None, Value::None, Value::None, Value::None)),
}
}
} }

View file

@ -1,12 +1,13 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use uuid::Uuid;
use crate::{ use crate::{
dbs::{QueryType, Response, Session}, dbs::{QueryType, Response, Session},
kvs::Datastore, kvs::Datastore,
rpc::args::Take, rpc::args::Take,
sql::{Array, Function, Model, Statement, Strand, Value}, sql::{Array, Function, Model, Statement, Strand, Value},
}; };
use uuid::Uuid;
use super::{method::Method, response::Data, rpc_error::RpcError}; use super::{method::Method, response::Data, rpc_error::RpcError};
@ -417,6 +418,41 @@ pub trait RpcContext {
Ok(res) Ok(res)
} }
// ------------------------------
// Methods for relating
// ------------------------------
async fn relate(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((from, kind, to, data)) = params.needs_three_or_four() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = kind.is_thing();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"RELATE $from->$kind->$to"
} else {
"RELATE $from->$kind->$to CONTENT $data"
};
// Specify the query parameters
let var = Some(map! {
String::from("from") => from,
String::from("kind") => kind.could_be_table(),
String::from("to") => to,
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------ // ------------------------------
// Methods for deleting // Methods for deleting
// ------------------------------ // ------------------------------
@ -482,15 +518,6 @@ pub trait RpcContext {
self.query_inner(query, vars).await self.query_inner(query, vars).await
} }
// ------------------------------
// Methods for relating
// ------------------------------
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
let out: Result<Value, RpcError> = Err(RpcError::MethodNotFound);
out
}
// ------------------------------ // ------------------------------
// Methods for running functions // Methods for running functions
// ------------------------------ // ------------------------------
@ -528,9 +555,7 @@ pub trait RpcContext {
.kvs() .kvs()
.process(Statement::Value(func).into(), self.session(), Some(self.vars().clone())) .process(Statement::Value(func).into(), self.session(), Some(self.vars().clone()))
.await?; .await?;
let out = res.remove(0).result?; res.remove(0).result.map_err(Into::into)
Ok(out)
} }
// ------------------------------ // ------------------------------

View file

@ -132,7 +132,7 @@ impl RelateStatement {
for w in with.iter() { for w in with.iter() {
let f = f.clone(); let f = f.clone();
let w = w.clone(); let w = w.clone();
match &self.kind { match &self.kind.compute(ctx, opt, txn, doc).await? {
// The relation has a specific record id // The relation has a specific record id
Value::Thing(id) => i.ingest(Iterable::Relatable(f, id.to_owned(), w)), Value::Thing(id) => i.ingest(Iterable::Relatable(f, id.to_owned(), w)),
// The relation does not have a specific record id // The relation does not have a specific record id
@ -149,7 +149,11 @@ impl RelateStatement {
None => i.ingest(Iterable::Relatable(f, tb.generate(), w)), None => i.ingest(Iterable::Relatable(f, tb.generate(), w)),
}, },
// The relation can not be any other type // The relation can not be any other type
_ => unreachable!(), v => {
return Err(Error::RelateStatement {
value: v.to_string(),
})
}
}; };
} }
} }

View file

@ -41,7 +41,7 @@ impl Parser<'_> {
t!("<-") => false, t!("<-") => false,
x => unexpected!(self, x, "a relation arrow"), x => unexpected!(self, x, "a relation arrow"),
}; };
let kind = self.parse_thing_or_table(stk).await?; let kind = self.parse_relate_kind(stk).await?;
if is_o { if is_o {
expected!(self, t!("->")) expected!(self, t!("->"))
} else { } else {
@ -55,6 +55,20 @@ impl Parser<'_> {
} }
} }
pub async fn parse_relate_kind(&mut self, ctx: &mut Stk) -> ParseResult<Value> {
match self.peek_kind() {
t!("$param") => self.next_token_value().map(Value::Param),
t!("(") => {
let span = self.pop_peek().span;
let res = self
.parse_inner_subquery(ctx, Some(span))
.await
.map(|x| Value::Subquery(Box::new(x)))?;
Ok(res)
}
_ => self.parse_thing_or_table(ctx).await,
}
}
pub async fn parse_relate_value(&mut self, ctx: &mut Stk) -> ParseResult<Value> { pub async fn parse_relate_value(&mut self, ctx: &mut Stk) -> ParseResult<Value> {
match self.peek_kind() { match self.peek_kind() {
t!("[") => { t!("[") => {

View file

@ -1,6 +1,7 @@
mod parse; mod parse;
use parse::Parse; use parse::Parse;
mod helpers; mod helpers;
use helpers::new_ds; use helpers::new_ds;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::err::Error; use surrealdb::err::Error;
@ -105,3 +106,77 @@ async fn relate_and_overwrite() -> Result<(), Error> {
// //
Ok(()) Ok(())
} }
#[tokio::test]
async fn relate_with_param_or_subquery() -> Result<(), Error> {
let sql = r#"
LET $tobie = person:tobie;
LET $jaime = person:jaime;
LET $relation = type::table("knows");
RELATE $tobie->$relation->$jaime;
RELATE $tobie->(type::table("knows"))->$jaime;
LET $relation = type::thing("knows:foo");
RELATE $tobie->$relation->$jaime;
RELATE $tobie->(type::thing("knows:bar"))->$jaime;
"#;
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 8);
//
for _ in 0..3 {
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
}
//
for _ in 0..2 {
let tmp = res.remove(0).result?;
let Value::Array(v) = tmp else {
panic!("response should be array:{tmp:?}")
};
assert_eq!(v.len(), 1);
let tmp = v.into_iter().next().unwrap();
let Value::Object(o) = tmp else {
panic!("should be object {tmp:?}")
};
assert_eq!(o.get("in").unwrap(), &Value::parse("person:tobie"));
assert_eq!(o.get("out").unwrap(), &Value::parse("person:jaime"));
let id = o.get("id").unwrap();
let Value::Thing(t) = id else {
panic!("should be thing {id:?}")
};
assert_eq!(t.tb, "knows");
}
//
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: knows:foo,
in: person:tobie,
out: person:jaime,
}
]",
);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: knows:bar,
in: person:tobie,
out: person:jaime,
}
]",
);
//
assert_eq!(tmp, val);
Ok(())
}

View file

@ -421,6 +421,7 @@ impl Socket {
} }
} }
} }
pub async fn send_message_run( pub async fn send_message_run(
&mut self, &mut self,
fn_name: &str, fn_name: &str,
@ -449,4 +450,38 @@ impl Socket {
} }
} }
} }
pub async fn send_message_relate(
&mut self,
from: serde_json::Value,
kind: serde_json::Value,
with: serde_json::Value,
content: Option<serde_json::Value>,
) -> Result<serde_json::Value> {
// Send message and receive response
let msg = if let Some(content) = content {
self.send_request("relate", json!([from, kind, with, content])).await?
} else {
self.send_request("relate", json!([from, kind, with])).await?
};
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
} }

View file

@ -1498,7 +1498,7 @@ async fn run_functions() {
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val")); assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val"));
// normal functions // normal functions
let res = socket.send_message_run("math::abs", None, vec![(-42).into()]).await.unwrap(); let res = socket.send_message_run("math::abs", None, vec![42.into()]).await.unwrap();
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42))); assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42)));
let res = socket let res = socket
.send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()]) .send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()])
@ -1509,3 +1509,33 @@ async fn run_functions() {
// Test passed // Test passed
server.finish().unwrap(); server.finish().unwrap();
} }
#[test(tokio::test)]
async fn relate_rpc() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
// create records and relate
socket.send_message_query("CREATE foo:a, foo:b").await.unwrap();
socket
.send_message_relate("foo:a".into(), "bar".into(), "foo:b".into(), Some(json!({"val": 42})))
.await
.unwrap();
// test
let mut res = socket.send_message_query("RETURN foo:a->bar.val").await.unwrap();
let expected = json!(42);
assert_eq!(res.remove(0)["result"], expected);
let mut res = socket.send_message_query("RETURN foo:a->bar->foo").await.unwrap();
let expected = json!(["foo:b"]);
assert_eq!(res.remove(0)["result"], expected);
// Test passed
server.finish().unwrap();
}