diff --git a/lib/src/cnf/mod.rs b/lib/src/cnf/mod.rs index 7176ce75..f380e5a0 100644 --- a/lib/src/cnf/mod.rs +++ b/lib/src/cnf/mod.rs @@ -2,8 +2,8 @@ // Specifies how many concurrent jobs can be buffered in the worker channel. pub const MAX_CONCURRENT_TASKS: usize = 64; -// Specifies how many subqueries will be processed recursively before the query fails. -pub const MAX_RECURSIVE_QUERIES: usize = 16; +// Specifies how deep various forms of computation will go before the query fails. +pub const MAX_COMPUTATION_DEPTH: u8 = 30; // The characters which are supported in server record IDs. pub const ID_CHARS: [char; 36] = [ diff --git a/lib/src/dbs/iterator.rs b/lib/src/dbs/iterator.rs index 21be6baa..c884f971 100644 --- a/lib/src/dbs/iterator.rs +++ b/lib/src/dbs/iterator.rs @@ -325,6 +325,8 @@ impl Iterator { txn: &Transaction, stm: &Statement<'_>, ) -> Result<(), Error> { + // Prevent deep recursion + let opt = &opt.dive(4)?; // Process all prepared values for v in mem::take(&mut self.entries) { v.iterate(ctx, opt, txn, stm, self).await?; @@ -343,6 +345,9 @@ impl Iterator { txn: &Transaction, stm: &Statement<'_>, ) -> Result<(), Error> { + // Prevent deep recursion + let opt = &opt.dive(4)?; + // Check if iterating in parallel match stm.parallel() { // Run statements sequentially false => { diff --git a/lib/src/dbs/options.rs b/lib/src/dbs/options.rs index 86c677b8..39d57274 100644 --- a/lib/src/dbs/options.rs +++ b/lib/src/dbs/options.rs @@ -19,8 +19,8 @@ pub struct Options { pub db: Option>, // Connection authentication data pub auth: Arc, - // How many subqueries have we gone into? - pub dive: usize, + // Approximately how large is the current call stack? + dive: u8, // Whether live queries are allowed? pub live: bool, // Should we debug query response SQL? @@ -80,18 +80,22 @@ impl Options { self.db.as_ref().unwrap() } - /// Create a new Options object for a subquery - pub fn dive(&self) -> Result { - if self.dive < cnf::MAX_RECURSIVE_QUERIES { + /// Create a new Options object for a function/subquery/future/etc. + /// + /// The parameter is the approximate cost of the operation (more concretely, the size of the + /// stack frame it uses relative to a simple function call). When in doubt, use a value of 1. + pub fn dive(&self, cost: u8) -> Result { + let dive = self.dive.saturating_add(cost); + if dive <= cnf::MAX_COMPUTATION_DEPTH { Ok(Options { auth: self.auth.clone(), ns: self.ns.clone(), db: self.db.clone(), - dive: self.dive + 1, + dive, ..*self }) } else { - Err(Error::TooManySubqueries) + Err(Error::ComputationDepthExceeded) } } diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index dbd880ce..a71a2098 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -152,9 +152,9 @@ pub enum Error { #[error("Unable to perform the realtime query")] RealtimeDisabled, - /// Too many recursive subqueries have been processed - #[error("Too many recursive subqueries have been processed")] - TooManySubqueries, + /// Reached excessive computation depth due to functions, subqueries, or futures + #[error("Reached excessive computation depth due to functions, subqueries, or futures")] + ComputationDepthExceeded, /// Can not execute CREATE query using the specified value #[error("Can not execute CREATE query using value '{value}'")] diff --git a/lib/src/sql/function.rs b/lib/src/sql/function.rs index e4a6a9dd..99879ee1 100644 --- a/lib/src/sql/function.rs +++ b/lib/src/sql/function.rs @@ -112,6 +112,9 @@ impl Function { txn: &Transaction, doc: Option<&Value>, ) -> Result { + // Prevent long function chains + let opt = &opt.dive(1)?; + // Process the function type match self { Self::Future(v) => match opt.futures { true => { diff --git a/lib/src/sql/subquery.rs b/lib/src/sql/subquery.rs index ed615e38..f5fa2ec6 100644 --- a/lib/src/sql/subquery.rs +++ b/lib/src/sql/subquery.rs @@ -60,12 +60,13 @@ impl Subquery { txn: &Transaction, doc: Option<&Value>, ) -> Result { + // Prevent deep recursion + let opt = &opt.dive(2)?; + // Process the subquery match self { Self::Value(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Ifelse(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Select(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -73,22 +74,20 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - let res = v.compute(&ctx, &opt, txn, doc).await?; + let res = v.compute(&ctx, opt, txn, doc).await?; // Process result match v.limit() { 1 => match v.expr.single() { - Some(v) => res.first().get(&ctx, &opt, txn, &v).await, + Some(v) => res.first().get(&ctx, opt, txn, &v).await, None => res.first().ok(), }, _ => match v.expr.single() { - Some(v) => res.get(&ctx, &opt, txn, &v).await, + Some(v) => res.get(&ctx, opt, txn, &v).await, None => res.ok(), }, } } Self::Create(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -96,7 +95,7 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - match v.compute(&ctx, &opt, txn, doc).await? { + match v.compute(&ctx, opt, txn, doc).await? { Value::Array(mut v) => match v.len() { 1 => Ok(v.remove(0).pick(ID.as_ref())), _ => Ok(Value::from(v).pick(ID.as_ref())), @@ -105,8 +104,6 @@ impl Subquery { } } Self::Update(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -114,7 +111,7 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - match v.compute(&ctx, &opt, txn, doc).await? { + match v.compute(&ctx, opt, txn, doc).await? { Value::Array(mut v) => match v.len() { 1 => Ok(v.remove(0).pick(ID.as_ref())), _ => Ok(Value::from(v).pick(ID.as_ref())), @@ -123,8 +120,6 @@ impl Subquery { } } Self::Delete(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -132,7 +127,7 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - match v.compute(&ctx, &opt, txn, doc).await? { + match v.compute(&ctx, opt, txn, doc).await? { Value::Array(mut v) => match v.len() { 1 => Ok(v.remove(0).pick(ID.as_ref())), _ => Ok(Value::from(v).pick(ID.as_ref())), @@ -141,8 +136,6 @@ impl Subquery { } } Self::Relate(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -150,7 +143,7 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - match v.compute(&ctx, &opt, txn, doc).await? { + match v.compute(&ctx, opt, txn, doc).await? { Value::Array(mut v) => match v.len() { 1 => Ok(v.remove(0).pick(ID.as_ref())), _ => Ok(Value::from(v).pick(ID.as_ref())), @@ -159,8 +152,6 @@ impl Subquery { } } Self::Insert(ref v) => { - // Duplicate options - let opt = opt.dive()?; // Duplicate context let mut ctx = Context::new(ctx); // Add parent document @@ -168,7 +159,7 @@ impl Subquery { ctx.add_value("parent".into(), doc); } // Process subquery - match v.compute(&ctx, &opt, txn, doc).await? { + match v.compute(&ctx, opt, txn, doc).await? { Value::Array(mut v) => match v.len() { 1 => Ok(v.remove(0).pick(ID.as_ref())), _ => Ok(Value::from(v).pick(ID.as_ref())), diff --git a/lib/tests/complex.rs b/lib/tests/complex.rs new file mode 100644 index 00000000..94ca8232 --- /dev/null +++ b/lib/tests/complex.rs @@ -0,0 +1,235 @@ +mod parse; +use parse::Parse; +use std::future::Future; +use std::thread::Builder; +use surrealdb::sql::Value; +use surrealdb::Datastore; +use surrealdb::Error; +use surrealdb::Session; + +#[test] +fn self_referential_field() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + let mut res = run_queries( + " + CREATE pet:dog SET tail = { tail }; + ", + ) + .await?; + // + assert_eq!(res.len(), 1); + // + let tmp = res.next().unwrap(); + assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + // + Ok(()) + }) +} + +#[test] +fn cyclic_fields() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + let mut res = run_queries( + " + CREATE recycle SET consume = { produce }, produce = { consume }; + ", + ) + .await?; + // + assert_eq!(res.len(), 1); + // + let tmp = res.next().unwrap(); + assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + // + Ok(()) + }) +} + +#[test] +fn cyclic_records() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + let mut res = run_queries( + " + CREATE thing:one SET friend = { thing:two.friend }; + CREATE thing:two SET friend = { thing:one.friend }; + ", + ) + .await?; + // + assert_eq!(res.len(), 2); + // + let tmp = res.next().unwrap(); + assert!(tmp.is_ok()); + // + let tmp = res.next().unwrap(); + assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + // + Ok(()) + }) +} + +#[test] +fn ok_future_graph_subquery_recursion_depth() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + let mut res = run_queries( + r#" + CREATE thing:three SET fut = { friends[0].fut }, friends = [thing:four, thing:two]; + CREATE thing:four SET fut = { (friend) }, friend = { 42 }; + CREATE thing:two SET fut = { friend }, friend = { thing:three.fut }; + + CREATE thing:one SET foo = "bar"; + RELATE thing:one->friend->thing:two SET timestamp = time::now(); + + CREATE thing:zero SET foo = "baz"; + RELATE thing:zero->enemy->thing:one SET timestamp = time::now(); + + SELECT * FROM (SELECT * FROM (SELECT ->enemy->thing->friend->thing.fut as fut FROM thing:zero)); + "#, + ) + .await?; + // + assert_eq!(res.len(), 8); + // + for i in 0..7 { + let tmp = res.next().unwrap(); + assert!(tmp.is_ok(), "Statement {} resulted in {:?}", i, tmp); + } + // + let tmp = res.next().unwrap()?; + let val = Value::parse("[ [42] ]"); + assert_eq!(tmp, val); + // + Ok(()) + }) +} + +#[test] +fn ok_graph_traversal_depth() -> Result<(), Error> { + // Build the SQL traversal query + fn graph_traversal(n: usize) -> String { + let mut ret = String::from("CREATE node:0;\n"); + for i in 1..=n { + let prev = i - 1; + ret.push_str(&format!("CREATE node:{i};\n")); + ret.push_str(&format!("RELATE node:{prev}->edge{i}->node:{i};\n")); + } + ret.push_str("SELECT "); + for i in 1..=n { + ret.push_str(&format!("->edge{i}->node")); + } + ret.push_str(" AS res FROM node:0;\n"); + ret + } + // Test different traveral depths + for n in 1..=40 { + // Ensure a good stack size for tests + with_enough_stack(async move { + // Run the graph traversal queries + let mut res = run_queries(&graph_traversal(n)).await?; + // Remove the last result + let tmp = res.next_back().unwrap(); + // Check all other queries + assert!(res.all(|r| r.is_ok())); + // + match tmp { + Ok(res) => { + let val = Value::parse(&format!( + "[ + {{ + res: [node:{n}], + }} + ]" + )); + assert_eq!(res, val); + } + Err(res) => { + assert!(matches!(res, Error::ComputationDepthExceeded)); + assert!(n > 10, "Max traversals: {}", n - 1); + } + } + + Ok(()) + }) + .unwrap(); + } + + Ok(()) +} + +#[test] +fn ok_cast_chain_depth() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + // Run a chasting query which succeeds + let mut res = run_queries(&cast_chain(10)).await?; + // + assert_eq!(res.len(), 1); + // + let tmp = res.next().unwrap()?; + let val = Value::from(vec![Value::from(5)]); + assert_eq!(tmp, val); + // + Ok(()) + }) +} + +#[test] +fn excessive_cast_chain_depth() -> Result<(), Error> { + // Ensure a good stack size for tests + with_enough_stack(async { + // Run a casting query which will fail + let mut res = run_queries(&cast_chain(35)).await?; + // + assert_eq!(res.len(), 1); + // + let tmp = res.next().unwrap(); + assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + // + Ok(()) + }) +} + +async fn run_queries( + sql: &str, +) -> Result< + impl Iterator> + ExactSizeIterator + DoubleEndedIterator + 'static, + Error, +> { + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + dbs.execute(&sql, &ses, None, false).await.map(|v| v.into_iter().map(|res| res.result)) +} + +fn with_enough_stack( + fut: impl Future> + Send + 'static, +) -> Result<(), Error> { + #[allow(unused_mut)] + let mut builder = Builder::new(); + + #[cfg(debug_assertions)] + { + builder = builder.stack_size(16_000_000); + } + + builder + .spawn(|| { + let runtime = tokio::runtime::Builder::new_current_thread().build().unwrap(); + runtime.block_on(fut) + }) + .unwrap() + .join() + .unwrap() +} + +fn cast_chain(n: usize) -> String { + let mut sql = String::from("SELECT * FROM "); + for _ in 0..n { + sql.push_str(""); + } + sql.push_str("5;"); + sql +}