Improve RETURN logic (#4298)

Co-authored-by: Mees Delzenne <DelSkayn@users.noreply.github.com>
This commit is contained in:
Micha de Vries 2024-07-04 17:05:47 +02:00 committed by GitHub
parent a926525a83
commit 12ddb94508
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 124 additions and 16 deletions

View file

@ -226,6 +226,9 @@ impl<'a> Executor<'a> {
// Initialise array of responses
let mut out: Vec<Response> = vec![];
let mut live_queries: Vec<TrackedResult> = vec![];
// Do we fast-forward a transaction?
// Set to true when we encounter a return statement in a transaction
let mut ff_txn = false;
// Process all statements in query
for stm in qry.into_iter() {
// Log the statement
@ -242,6 +245,13 @@ impl<'a> Executor<'a> {
let is_stm_kill = matches!(stm, Statement::Kill(_));
// Check if this is a RETURN statement
let is_stm_output = matches!(stm, Statement::Output(_));
// Has this statement returned a value
let mut has_returned = false;
// Do we skip this statement?
if ff_txn && !matches!(stm, Statement::Commit(_) | Statement::Cancel(_)) {
debug!("Skipping statement due to fast forwarded transaction");
continue;
}
// Process a single statement
let res = match stm {
// Specify runtime options
@ -289,6 +299,7 @@ impl<'a> Executor<'a> {
out.append(&mut buf);
debug_assert!(self.txn.is_none(), "commit(true) should have unset txn");
self.txn = None;
ff_txn = false;
continue;
}
// Switch to a different NS or DB
@ -406,10 +417,21 @@ impl<'a> Executor<'a> {
.await
}
};
// Check if this is a RETURN statement
let can_return =
matches!(stm, Statement::Output(_) | Statement::Value(_));
// Catch global timeout
let res = match ctx.is_timedout() {
true => Err(Error::QueryTimedout),
false => res,
false => match res {
Err(Error::Return {
value,
}) if can_return => {
has_returned = true;
Ok(value)
}
res => res,
},
};
// Finalise transaction and return the result.
if res.is_ok() && stm.writeable() {
@ -470,8 +492,9 @@ impl<'a> Executor<'a> {
};
// Output the response
if self.txn.is_some() {
if is_stm_output {
if is_stm_output || has_returned {
buf.clear();
ff_txn = true;
}
buf.push(res);
} else {

View file

@ -1010,6 +1010,13 @@ pub enum Error {
TbInvalid {
value: String,
},
/// This error is used for breaking execution when a value is returned
#[doc(hidden)]
#[error("Return statement has been reached")]
Return {
value: Value,
},
}
impl From<Error> for String {

View file

@ -167,7 +167,7 @@ pub async fn db_access(
sess.or.clone_from(&session.or);
// Compute the value with the params
match kvs.evaluate(au, &sess, None).await {
Ok(val) => match val.record() {
Ok(val) => match val.record() {
Some(id) => {
// Update rid with result from AUTHENTICATE clause
rid = id;

View file

@ -106,7 +106,7 @@ pub async fn db_access(
sess.or.clone_from(&session.or);
// Compute the value with the params
match kvs.evaluate(au, &sess, None).await {
Ok(val) => match val.record() {
Ok(val) => match val.record() {
Some(id) => {
// Update rid with result from AUTHENTICATE clause
rid = id;

View file

@ -112,16 +112,15 @@ impl Block {
v.compute(&ctx, opt, doc).await?;
}
Entry::Output(v) => {
// Return the RETURN value
return v.compute(stk, &ctx, opt, doc).await;
v.compute(stk, &ctx, opt, doc).await?;
}
Entry::Value(v) => {
if i == self.len() - 1 {
// If the last entry then return the value
return v.compute(stk, &ctx, opt, doc).await;
return v.compute_unbordered(stk, &ctx, opt, doc).await;
} else {
// Otherwise just process the value
v.compute(stk, &ctx, opt, doc).await?;
v.compute_unbordered(stk, &ctx, opt, doc).await?;
}
}
}

View file

@ -274,7 +274,12 @@ impl Function {
ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
}
// Run the custom function
stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await
match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
Err(Error::Return {
value,
}) => Ok(value),
res => res,
}
}
#[allow(unused_variables)]
Self::Script(s, x) => {

View file

@ -175,7 +175,7 @@ impl Statement {
// Ensure futures are processed
let opt = &opt.new_with_futures(true);
// Process the output value
v.compute(stk, ctx, opt, doc).await
v.compute_unbordered(stk, ctx, opt, doc).await
}
_ => unreachable!(),
}

View file

@ -48,11 +48,11 @@ impl IfelseStatement {
for (ref cond, ref then) in &self.exprs {
let v = cond.compute(stk, ctx, opt, doc).await?;
if v.is_truthy() {
return then.compute(stk, ctx, opt, doc).await;
return then.compute_unbordered(stk, ctx, opt, doc).await;
}
}
match self.close {
Some(ref v) => v.compute(stk, ctx, opt, doc).await,
Some(ref v) => v.compute_unbordered(stk, ctx, opt, doc).await,
None => Ok(Value::None),
}
}

View file

@ -35,15 +35,17 @@ impl OutputStatement {
// Ensure futures are processed
let opt = &opt.new_with_futures(true);
// Process the output value
let mut val = self.what.compute(stk, ctx, opt, doc).await?;
let mut value = self.what.compute(stk, ctx, opt, doc).await?;
// Fetch any
if let Some(fetchs) = &self.fetch {
for fetch in fetchs.iter() {
val.fetch(stk, ctx, opt, fetch).await?;
value.fetch(stk, ctx, opt, fetch).await?;
}
}
//
Ok(val)
Err(Error::Return {
value,
})
}
}

View file

@ -2642,7 +2642,7 @@ impl Value {
/// Process this type returning a computed simple Value
///
/// Is used recursively.
pub(crate) async fn compute(
pub(crate) async fn compute_unbordered(
&self,
stk: &mut Stk,
ctx: &Context<'_>,
@ -2670,6 +2670,21 @@ impl Value {
_ => Ok(self.to_owned()),
}
}
pub(crate) async fn compute(
&self,
stk: &mut Stk,
ctx: &Context<'_>,
opt: &Options,
doc: Option<&CursorDoc<'_>>,
) -> Result<Value, Error> {
match self.compute_unbordered(stk, ctx, opt, doc).await {
Err(Error::Return {
value,
}) => Ok(value),
res => res,
}
}
}
// ------------------------------

View file

@ -159,3 +159,60 @@ async fn return_subquery_only() -> Result<(), Error> {
//
Ok(())
}
#[tokio::test]
async fn return_breaks_nested_execution() -> Result<(), Error> {
let sql = "
DEFINE FUNCTION fn::test() {
{
RETURN 1;
};
RETURN 2;
};
RETURN fn::test();
BEGIN;
CREATE ONLY a:1;
RETURN 1;
CREATE ONLY a:2;
COMMIT;
{
RETURN 1;
};
SELECT VALUE {
IF $this % 2 == 0 {
RETURN $this;
} ELSE {
RETURN $this + 1;
}
} FROM [1, 2, 3, 4];
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
let tmp = res.remove(0).result;
assert!(tmp.is_ok());
//
let tmp = res.remove(0).result?;
let val = Value::parse("1");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("1");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("1");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("[2, 2, 4, 4]");
assert_eq!(tmp, val);
//
Ok(())
}