Limit computation depth in functions, futures, and subqueries. (#241)
This commit is contained in:
parent
dfa42f1733
commit
88100854a8
7 changed files with 270 additions and 32 deletions
|
@ -2,8 +2,8 @@
|
|||
// Specifies how many concurrent jobs can be buffered in the worker channel.
|
||||
pub const MAX_CONCURRENT_TASKS: usize = 64;
|
||||
|
||||
// Specifies how many subqueries will be processed recursively before the query fails.
|
||||
pub const MAX_RECURSIVE_QUERIES: usize = 16;
|
||||
// Specifies how deep various forms of computation will go before the query fails.
|
||||
pub const MAX_COMPUTATION_DEPTH: u8 = 30;
|
||||
|
||||
// The characters which are supported in server record IDs.
|
||||
pub const ID_CHARS: [char; 36] = [
|
||||
|
|
|
@ -325,6 +325,8 @@ impl Iterator {
|
|||
txn: &Transaction,
|
||||
stm: &Statement<'_>,
|
||||
) -> Result<(), Error> {
|
||||
// Prevent deep recursion
|
||||
let opt = &opt.dive(4)?;
|
||||
// Process all prepared values
|
||||
for v in mem::take(&mut self.entries) {
|
||||
v.iterate(ctx, opt, txn, stm, self).await?;
|
||||
|
@ -343,6 +345,9 @@ impl Iterator {
|
|||
txn: &Transaction,
|
||||
stm: &Statement<'_>,
|
||||
) -> Result<(), Error> {
|
||||
// Prevent deep recursion
|
||||
let opt = &opt.dive(4)?;
|
||||
// Check if iterating in parallel
|
||||
match stm.parallel() {
|
||||
// Run statements sequentially
|
||||
false => {
|
||||
|
|
|
@ -19,8 +19,8 @@ pub struct Options {
|
|||
pub db: Option<Arc<str>>,
|
||||
// Connection authentication data
|
||||
pub auth: Arc<Auth>,
|
||||
// How many subqueries have we gone into?
|
||||
pub dive: usize,
|
||||
// Approximately how large is the current call stack?
|
||||
dive: u8,
|
||||
// Whether live queries are allowed?
|
||||
pub live: bool,
|
||||
// Should we debug query response SQL?
|
||||
|
@ -80,18 +80,22 @@ impl Options {
|
|||
self.db.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Create a new Options object for a subquery
|
||||
pub fn dive(&self) -> Result<Options, Error> {
|
||||
if self.dive < cnf::MAX_RECURSIVE_QUERIES {
|
||||
/// Create a new Options object for a function/subquery/future/etc.
|
||||
///
|
||||
/// The parameter is the approximate cost of the operation (more concretely, the size of the
|
||||
/// stack frame it uses relative to a simple function call). When in doubt, use a value of 1.
|
||||
pub fn dive(&self, cost: u8) -> Result<Options, Error> {
|
||||
let dive = self.dive.saturating_add(cost);
|
||||
if dive <= cnf::MAX_COMPUTATION_DEPTH {
|
||||
Ok(Options {
|
||||
auth: self.auth.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
dive: self.dive + 1,
|
||||
dive,
|
||||
..*self
|
||||
})
|
||||
} else {
|
||||
Err(Error::TooManySubqueries)
|
||||
Err(Error::ComputationDepthExceeded)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -152,9 +152,9 @@ pub enum Error {
|
|||
#[error("Unable to perform the realtime query")]
|
||||
RealtimeDisabled,
|
||||
|
||||
/// Too many recursive subqueries have been processed
|
||||
#[error("Too many recursive subqueries have been processed")]
|
||||
TooManySubqueries,
|
||||
/// Reached excessive computation depth due to functions, subqueries, or futures
|
||||
#[error("Reached excessive computation depth due to functions, subqueries, or futures")]
|
||||
ComputationDepthExceeded,
|
||||
|
||||
/// Can not execute CREATE query using the specified value
|
||||
#[error("Can not execute CREATE query using value '{value}'")]
|
||||
|
|
|
@ -112,6 +112,9 @@ impl Function {
|
|||
txn: &Transaction,
|
||||
doc: Option<&Value>,
|
||||
) -> Result<Value, Error> {
|
||||
// Prevent long function chains
|
||||
let opt = &opt.dive(1)?;
|
||||
// Process the function type
|
||||
match self {
|
||||
Self::Future(v) => match opt.futures {
|
||||
true => {
|
||||
|
|
|
@ -60,12 +60,13 @@ impl Subquery {
|
|||
txn: &Transaction,
|
||||
doc: Option<&Value>,
|
||||
) -> Result<Value, Error> {
|
||||
// Prevent deep recursion
|
||||
let opt = &opt.dive(2)?;
|
||||
// Process the subquery
|
||||
match self {
|
||||
Self::Value(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||
Self::Ifelse(ref v) => v.compute(ctx, opt, txn, doc).await,
|
||||
Self::Select(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -73,22 +74,20 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
let res = v.compute(&ctx, &opt, txn, doc).await?;
|
||||
let res = v.compute(&ctx, opt, txn, doc).await?;
|
||||
// Process result
|
||||
match v.limit() {
|
||||
1 => match v.expr.single() {
|
||||
Some(v) => res.first().get(&ctx, &opt, txn, &v).await,
|
||||
Some(v) => res.first().get(&ctx, opt, txn, &v).await,
|
||||
None => res.first().ok(),
|
||||
},
|
||||
_ => match v.expr.single() {
|
||||
Some(v) => res.get(&ctx, &opt, txn, &v).await,
|
||||
Some(v) => res.get(&ctx, opt, txn, &v).await,
|
||||
None => res.ok(),
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::Create(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -96,7 +95,7 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
match v.compute(&ctx, &opt, txn, doc).await? {
|
||||
match v.compute(&ctx, opt, txn, doc).await? {
|
||||
Value::Array(mut v) => match v.len() {
|
||||
1 => Ok(v.remove(0).pick(ID.as_ref())),
|
||||
_ => Ok(Value::from(v).pick(ID.as_ref())),
|
||||
|
@ -105,8 +104,6 @@ impl Subquery {
|
|||
}
|
||||
}
|
||||
Self::Update(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -114,7 +111,7 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
match v.compute(&ctx, &opt, txn, doc).await? {
|
||||
match v.compute(&ctx, opt, txn, doc).await? {
|
||||
Value::Array(mut v) => match v.len() {
|
||||
1 => Ok(v.remove(0).pick(ID.as_ref())),
|
||||
_ => Ok(Value::from(v).pick(ID.as_ref())),
|
||||
|
@ -123,8 +120,6 @@ impl Subquery {
|
|||
}
|
||||
}
|
||||
Self::Delete(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -132,7 +127,7 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
match v.compute(&ctx, &opt, txn, doc).await? {
|
||||
match v.compute(&ctx, opt, txn, doc).await? {
|
||||
Value::Array(mut v) => match v.len() {
|
||||
1 => Ok(v.remove(0).pick(ID.as_ref())),
|
||||
_ => Ok(Value::from(v).pick(ID.as_ref())),
|
||||
|
@ -141,8 +136,6 @@ impl Subquery {
|
|||
}
|
||||
}
|
||||
Self::Relate(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -150,7 +143,7 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
match v.compute(&ctx, &opt, txn, doc).await? {
|
||||
match v.compute(&ctx, opt, txn, doc).await? {
|
||||
Value::Array(mut v) => match v.len() {
|
||||
1 => Ok(v.remove(0).pick(ID.as_ref())),
|
||||
_ => Ok(Value::from(v).pick(ID.as_ref())),
|
||||
|
@ -159,8 +152,6 @@ impl Subquery {
|
|||
}
|
||||
}
|
||||
Self::Insert(ref v) => {
|
||||
// Duplicate options
|
||||
let opt = opt.dive()?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
// Add parent document
|
||||
|
@ -168,7 +159,7 @@ impl Subquery {
|
|||
ctx.add_value("parent".into(), doc);
|
||||
}
|
||||
// Process subquery
|
||||
match v.compute(&ctx, &opt, txn, doc).await? {
|
||||
match v.compute(&ctx, opt, txn, doc).await? {
|
||||
Value::Array(mut v) => match v.len() {
|
||||
1 => Ok(v.remove(0).pick(ID.as_ref())),
|
||||
_ => Ok(Value::from(v).pick(ID.as_ref())),
|
||||
|
|
235
lib/tests/complex.rs
Normal file
235
lib/tests/complex.rs
Normal file
|
@ -0,0 +1,235 @@
|
|||
mod parse;
|
||||
use parse::Parse;
|
||||
use std::future::Future;
|
||||
use std::thread::Builder;
|
||||
use surrealdb::sql::Value;
|
||||
use surrealdb::Datastore;
|
||||
use surrealdb::Error;
|
||||
use surrealdb::Session;
|
||||
|
||||
#[test]
|
||||
fn self_referential_field() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
let mut res = run_queries(
|
||||
"
|
||||
CREATE pet:dog SET tail = <future> { tail };
|
||||
",
|
||||
)
|
||||
.await?;
|
||||
//
|
||||
assert_eq!(res.len(), 1);
|
||||
//
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cyclic_fields() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
let mut res = run_queries(
|
||||
"
|
||||
CREATE recycle SET consume = <future> { produce }, produce = <future> { consume };
|
||||
",
|
||||
)
|
||||
.await?;
|
||||
//
|
||||
assert_eq!(res.len(), 1);
|
||||
//
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cyclic_records() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
let mut res = run_queries(
|
||||
"
|
||||
CREATE thing:one SET friend = <future> { thing:two.friend };
|
||||
CREATE thing:two SET friend = <future> { thing:one.friend };
|
||||
",
|
||||
)
|
||||
.await?;
|
||||
//
|
||||
assert_eq!(res.len(), 2);
|
||||
//
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(tmp.is_ok());
|
||||
//
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ok_future_graph_subquery_recursion_depth() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
let mut res = run_queries(
|
||||
r#"
|
||||
CREATE thing:three SET fut = <future> { friends[0].fut }, friends = [thing:four, thing:two];
|
||||
CREATE thing:four SET fut = <future> { (friend) }, friend = <future> { 42 };
|
||||
CREATE thing:two SET fut = <future> { friend }, friend = <future> { thing:three.fut };
|
||||
|
||||
CREATE thing:one SET foo = "bar";
|
||||
RELATE thing:one->friend->thing:two SET timestamp = time::now();
|
||||
|
||||
CREATE thing:zero SET foo = "baz";
|
||||
RELATE thing:zero->enemy->thing:one SET timestamp = time::now();
|
||||
|
||||
SELECT * FROM (SELECT * FROM (SELECT ->enemy->thing->friend->thing.fut as fut FROM thing:zero));
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
//
|
||||
assert_eq!(res.len(), 8);
|
||||
//
|
||||
for i in 0..7 {
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(tmp.is_ok(), "Statement {} resulted in {:?}", i, tmp);
|
||||
}
|
||||
//
|
||||
let tmp = res.next().unwrap()?;
|
||||
let val = Value::parse("[ [42] ]");
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ok_graph_traversal_depth() -> Result<(), Error> {
|
||||
// Build the SQL traversal query
|
||||
fn graph_traversal(n: usize) -> String {
|
||||
let mut ret = String::from("CREATE node:0;\n");
|
||||
for i in 1..=n {
|
||||
let prev = i - 1;
|
||||
ret.push_str(&format!("CREATE node:{i};\n"));
|
||||
ret.push_str(&format!("RELATE node:{prev}->edge{i}->node:{i};\n"));
|
||||
}
|
||||
ret.push_str("SELECT ");
|
||||
for i in 1..=n {
|
||||
ret.push_str(&format!("->edge{i}->node"));
|
||||
}
|
||||
ret.push_str(" AS res FROM node:0;\n");
|
||||
ret
|
||||
}
|
||||
// Test different traveral depths
|
||||
for n in 1..=40 {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async move {
|
||||
// Run the graph traversal queries
|
||||
let mut res = run_queries(&graph_traversal(n)).await?;
|
||||
// Remove the last result
|
||||
let tmp = res.next_back().unwrap();
|
||||
// Check all other queries
|
||||
assert!(res.all(|r| r.is_ok()));
|
||||
//
|
||||
match tmp {
|
||||
Ok(res) => {
|
||||
let val = Value::parse(&format!(
|
||||
"[
|
||||
{{
|
||||
res: [node:{n}],
|
||||
}}
|
||||
]"
|
||||
));
|
||||
assert_eq!(res, val);
|
||||
}
|
||||
Err(res) => {
|
||||
assert!(matches!(res, Error::ComputationDepthExceeded));
|
||||
assert!(n > 10, "Max traversals: {}", n - 1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ok_cast_chain_depth() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
// Run a chasting query which succeeds
|
||||
let mut res = run_queries(&cast_chain(10)).await?;
|
||||
//
|
||||
assert_eq!(res.len(), 1);
|
||||
//
|
||||
let tmp = res.next().unwrap()?;
|
||||
let val = Value::from(vec![Value::from(5)]);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn excessive_cast_chain_depth() -> Result<(), Error> {
|
||||
// Ensure a good stack size for tests
|
||||
with_enough_stack(async {
|
||||
// Run a casting query which will fail
|
||||
let mut res = run_queries(&cast_chain(35)).await?;
|
||||
//
|
||||
assert_eq!(res.len(), 1);
|
||||
//
|
||||
let tmp = res.next().unwrap();
|
||||
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
|
||||
//
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_queries(
|
||||
sql: &str,
|
||||
) -> Result<
|
||||
impl Iterator<Item = Result<Value, Error>> + ExactSizeIterator + DoubleEndedIterator + 'static,
|
||||
Error,
|
||||
> {
|
||||
let dbs = Datastore::new("memory").await?;
|
||||
let ses = Session::for_kv().with_ns("test").with_db("test");
|
||||
dbs.execute(&sql, &ses, None, false).await.map(|v| v.into_iter().map(|res| res.result))
|
||||
}
|
||||
|
||||
fn with_enough_stack(
|
||||
fut: impl Future<Output = Result<(), Error>> + Send + 'static,
|
||||
) -> Result<(), Error> {
|
||||
#[allow(unused_mut)]
|
||||
let mut builder = Builder::new();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
builder = builder.stack_size(16_000_000);
|
||||
}
|
||||
|
||||
builder
|
||||
.spawn(|| {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread().build().unwrap();
|
||||
runtime.block_on(fut)
|
||||
})
|
||||
.unwrap()
|
||||
.join()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn cast_chain(n: usize) -> String {
|
||||
let mut sql = String::from("SELECT * FROM ");
|
||||
for _ in 0..n {
|
||||
sql.push_str("<int>");
|
||||
}
|
||||
sql.push_str("5;");
|
||||
sql
|
||||
}
|
Loading…
Reference in a new issue