From 27cc21876d32b1851d45b5b283b76e8f564cdb6a Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 21 Aug 2023 15:05:11 -0700 Subject: [PATCH] Security - Limit parser depth. (#2369) --- lib/src/cnf/mod.rs | 11 +- lib/src/sql/error.rs | 1 + lib/src/sql/expression.rs | 2 + lib/src/sql/geometry.rs | 1 + lib/src/sql/parser.rs | 202 ++++++++++++++++++++++++++++++++++++- lib/src/sql/script.rs | 1 + lib/src/sql/value/value.rs | 6 ++ lib/tests/complex.rs | 19 ++-- 8 files changed, 232 insertions(+), 11 deletions(-) diff --git a/lib/src/cnf/mod.rs b/lib/src/cnf/mod.rs index 4ad9c412..d5071307 100644 --- a/lib/src/cnf/mod.rs +++ b/lib/src/cnf/mod.rs @@ -5,9 +5,18 @@ use once_cell::sync::Lazy; /// Specifies how many concurrent jobs can be buffered in the worker channel. pub const MAX_CONCURRENT_TASKS: usize = 64; -/// Specifies how deep various forms of computation will go before the query fails. +/// Specifies how deep various forms of computation will go before the query fails +/// with [`Error::ComputationDepthExceeded`]. /// /// For reference, use ~15 per MiB of stack in release mode. +/// +/// During query parsing, the total depth of calls to parse values (including arrays, expressions, +/// functions, objects, sub-queries), Javascript values, and geometry collections count against +/// this limit. +/// +/// During query execution, all potentially-recursive code paths count against this limit. Whereas +/// parsing assigns equal weight to each recursion, certain expensive code paths are allowed to +/// count for more than one unit of depth during execution. pub static MAX_COMPUTATION_DEPTH: Lazy = Lazy::new(|| { option_env!("SURREAL_MAX_COMPUTATION_DEPTH").and_then(|s| s.parse::().ok()).unwrap_or(120) }); diff --git a/lib/src/sql/error.rs b/lib/src/sql/error.rs index e6e70f8f..9d72360e 100644 --- a/lib/src/sql/error.rs +++ b/lib/src/sql/error.rs @@ -6,6 +6,7 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum Error { Parser(I), + ExcessiveDepth, Field(I, String), Split(I, String), Order(I, String), diff --git a/lib/src/sql/expression.rs b/lib/src/sql/expression.rs index 75e400cd..878dc7c5 100644 --- a/lib/src/sql/expression.rs +++ b/lib/src/sql/expression.rs @@ -228,6 +228,8 @@ pub fn unary(i: &str) -> IResult<&str, Expression> { pub fn binary(i: &str) -> IResult<&str, Expression> { let (i, l) = single(i)?; let (i, o) = operator::binary(i)?; + // Make sure to dive if the query is a right-deep binary tree. + let _diving = crate::sql::parser::depth::dive()?; let (i, r) = value(i)?; let v = match r { Value::Expression(r) => r.augment(l, o), diff --git a/lib/src/sql/geometry.rs b/lib/src/sql/geometry.rs index 66d1ca8e..36360d13 100644 --- a/lib/src/sql/geometry.rs +++ b/lib/src/sql/geometry.rs @@ -627,6 +627,7 @@ impl hash::Hash for Geometry { } pub fn geometry(i: &str) -> IResult<&str, Geometry> { + let _diving = crate::sql::parser::depth::dive()?; alt((simple, normal))(i) } diff --git a/lib/src/sql/parser.rs b/lib/src/sql/parser.rs index 451f5188..9588452e 100644 --- a/lib/src/sql/parser.rs +++ b/lib/src/sql/parser.rs @@ -1,6 +1,6 @@ use crate::err::Error; use crate::iam::Error as IamError; -use crate::sql::error::Error::{Field, Group, Order, Parser, Role, Split}; +use crate::sql::error::Error::{ExcessiveDepth, Field, Group, Order, Parser, Role, Split}; use crate::sql::error::IResult; use crate::sql::query::{query, Query}; use crate::sql::subquery::Subquery; @@ -11,6 +11,15 @@ use std::str; use tracing::instrument; /// Parses a SurrealQL [`Query`] +/// +/// During query parsing, the total depth of calls to parse values (including arrays, expressions, +/// functions, objects, sub-queries), Javascript values, and geometry collections count against +/// a computation depth limit. If the limit is reached, parsing will return +/// [`Error::ComputationDepthExceeded`], as opposed to spending more time and potentially +/// overflowing the call stack. +/// +/// If you encounter this limit and believe that it should be increased, +/// please [open an issue](https://github.com/surrealdb/surrealdb/issues)! #[instrument(name = "parser", skip_all, fields(length = input.len()))] pub fn parse(input: &str) -> Result { parse_impl(input, query) @@ -41,6 +50,9 @@ pub fn json(input: &str) -> Result { } fn parse_impl(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Result { + // Reset the parse depth limiter + depth::reset(); + // Check the length of the input match input.trim().len() { // The input query was empty @@ -64,6 +76,8 @@ fn parse_impl(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Resu sql: s.to_string(), } } + // There was a parsing error + ExcessiveDepth => Error::ComputationDepthExceeded, // There was a SPLIT ON error Field(e, f) => Error::InvalidField { line: locate(input, e).1, @@ -117,12 +131,76 @@ fn locate<'a>(input: &str, tried: &'a str) -> (&'a str, usize, usize) { (tried, 0, 0) } +pub(crate) mod depth { + use crate::cnf::MAX_COMPUTATION_DEPTH; + use crate::sql::Error::ExcessiveDepth; + use nom::Err; + use std::cell::Cell; + use std::thread::panicking; + + thread_local! { + /// How many recursion levels deep parsing is currently. + static DEPTH: Cell = Cell::default(); + } + + /// Scale down `MAX_COMPUTATION_DEPTH` for parsing because: + /// - Only a few intermediate parsers, collectively sufficient to limit depth, call dive. + /// - Some of the depth budget during execution is for futures, graph traversal, and + /// other operations that don't exist during parsing. + /// - The parser currently runs in exponential time, so a lower limit guards against + /// CPU-intensive, time-consuming parsing. + const DEPTH_PER_DIVE: u8 = 4; + + /// Call when starting the parser to reset the recursion depth. + #[inline(never)] + pub(super) fn reset() { + DEPTH.with(|cell| { + debug_assert_eq!(cell.get(), 0, "previous parsing stopped abruptly"); + cell.set(0) + }); + } + + /// Call at least once in recursive parsing code paths to limit recursion depth. + #[inline(never)] + #[must_use = "must store and implicitly drop when returning"] + pub(crate) fn dive() -> Result>> { + DEPTH.with(|cell| { + let depth = cell.get().saturating_add(DEPTH_PER_DIVE); + if depth <= *MAX_COMPUTATION_DEPTH { + cell.replace(depth); + Ok(Diving) + } else { + Err(Err::Failure(ExcessiveDepth)) + } + }) + } + + #[must_use] + #[non_exhaustive] + pub(crate) struct Diving; + + impl Drop for Diving { + fn drop(&mut self) { + DEPTH.with(|cell| { + if let Some(depth) = cell.get().checked_sub(DEPTH_PER_DIVE) { + cell.replace(depth); + } else { + debug_assert!(panicking()); + } + }); + } + } +} + #[cfg(test)] mod tests { use super::*; use serde::Serialize; - use std::{collections::HashMap, time::Instant}; + use std::{ + collections::HashMap, + time::{Duration, Instant}, + }; #[test] fn no_ending() { @@ -159,6 +237,83 @@ mod tests { assert!(res.is_err()); } + #[test] + fn parse_ok_recursion() { + let sql = "SELECT * FROM ((SELECT * FROM (5))) * 5;"; + let res = parse(sql); + assert!(res.is_ok()); + } + + #[test] + fn parse_ok_recursion_deeper() { + let sql = "SELECT * FROM (((( SELECT * FROM ((5)) + ((5)) + ((5)) )))) * ((( function() {return 5;} )));"; + let start = Instant::now(); + let res = parse(sql); + let elapsed = start.elapsed(); + assert!(res.is_ok()); + assert!( + elapsed < Duration::from_millis(2000), + "took {}ms, previously took ~1000ms in debug", + elapsed.as_millis() + ) + } + + #[test] + fn parse_recursion_cast() { + for n in [10, 100, 500] { + recursive("SELECT * FROM ", "", "5", "", n, n > 50); + } + } + + #[test] + fn parse_recursion_geometry() { + for n in [1, 50, 100] { + recursive( + "SELECT * FROM ", + r#"{type: "GeometryCollection",geometries: ["#, + r#"{type: "MultiPoint",coordinates: [[10.0, 11.2],[10.5, 11.9]]}"#, + "]}", + n, + n > 25, + ); + } + } + + #[test] + fn parse_recursion_javascript() { + for n in [10, 1000] { + recursive("SELECT * FROM ", "function() {", "return 5;", "}", n, n > 500); + } + } + + #[test] + fn parse_recursion_mixed() { + for n in [3, 15, 75] { + recursive("", "SELECT * FROM ((((", "5 * 5", ")))) * 5", n, n > 5); + } + } + + #[test] + fn parse_recursion_select() { + for n in [5, 10, 100] { + recursive("SELECT * FROM ", "(SELECT * FROM ", "5", ")", n, n > 15); + } + } + + #[test] + fn parse_recursion_value_subquery() { + for p in 1..=4 { + recursive("SELECT * FROM ", "(", "5", ")", 10usize.pow(p), p > 1); + } + } + + #[test] + fn parse_recursion_if_subquery() { + for p in 1..=3 { + recursive("SELECT * FROM ", "IF true THEN ", "5", " ELSE 4 END", 6usize.pow(p), p > 1); + } + } + #[test] fn parser_try() { let sql = " @@ -246,4 +401,47 @@ mod tests { println!("sql::json took {:.10}s/iter", benchmark(|s| crate::sql::json(s).unwrap())); } + + /// Try parsing a query with O(n) recursion depth and expect to fail if and only if + /// `excessive` is true. + fn recursive( + prefix: &str, + recursive_start: &str, + base: &str, + recursive_end: &str, + n: usize, + excessive: bool, + ) { + let mut sql = String::from(prefix); + for _ in 0..n { + sql.push_str(recursive_start); + } + sql.push_str(base); + for _ in 0..n { + sql.push_str(recursive_end); + } + let start = Instant::now(); + let res = parse(&sql); + let elapsed = start.elapsed(); + if excessive { + assert!( + matches!(res, Err(Error::ComputationDepthExceeded)), + "expected computation depth exceeded, got {:?}", + res + ); + } else { + res.unwrap(); + } + // The parser can terminate faster in the excessive case. + let cutoff = if excessive { + 500 + } else { + 1000 + }; + assert!( + elapsed < Duration::from_millis(cutoff), + "took {}ms, previously much faster to parse {n} in debug mode", + elapsed.as_millis() + ) + } } diff --git a/lib/src/sql/script.rs b/lib/src/sql/script.rs index 61ed8362..96b3f6ff 100644 --- a/lib/src/sql/script.rs +++ b/lib/src/sql/script.rs @@ -60,6 +60,7 @@ pub fn script(i: &str) -> IResult<&str, Script> { } fn script_raw(i: &str) -> IResult<&str, &str> { + let _diving = crate::sql::parser::depth::dive()?; recognize(many0(alt(( script_comment, script_object, diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index b21ca8a3..20c29d29 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -2748,6 +2748,10 @@ pub fn value(i: &str) -> IResult<&str, Value> { /// Parse any `Value` excluding binary expressions pub fn single(i: &str) -> IResult<&str, Value> { + // Dive in `single` (as opposed to `value`) since it is directly + // called by `Cast` + let _diving = crate::sql::parser::depth::dive()?; + let (i, v) = alt(( alt(( terminated( @@ -2852,6 +2856,8 @@ pub fn what(i: &str) -> IResult<&str, Value> { /// Used to parse any simple JSON-like value pub fn json(i: &str) -> IResult<&str, Value> { + let _diving = crate::sql::parser::depth::dive()?; + // Use a specific parser for JSON objects pub fn object(i: &str) -> IResult<&str, Object> { let (i, _) = char('{')(i)?; diff --git a/lib/tests/complex.rs b/lib/tests/complex.rs index 0309f6cf..8d5ffcc2 100644 --- a/lib/tests/complex.rs +++ b/lib/tests/complex.rs @@ -166,7 +166,7 @@ fn ok_graph_traversal_depth() -> Result<(), Error> { fn ok_cast_chain_depth() -> Result<(), Error> { // Ensure a good stack size for tests with_enough_stack(async { - // Run a chasting query which succeeds + // Run a casting query which succeeds let mut res = run_queries(&cast_chain(10)).await?; // assert_eq!(res.len(), 1); @@ -183,13 +183,16 @@ fn ok_cast_chain_depth() -> Result<(), Error> { fn excessive_cast_chain_depth() -> Result<(), Error> { // Ensure a good stack size for tests with_enough_stack(async { - // Run a casting query which will fail - let mut res = run_queries(&cast_chain(125)).await?; - // - assert_eq!(res.len(), 1); - // - let tmp = res.next().unwrap(); - assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + // Run a casting query which will fail (either while parsing or executing) + match run_queries(&cast_chain(125)).await { + Ok(mut res) => { + assert_eq!(res.len(), 1); + // + let tmp = res.next().unwrap(); + assert!(matches!(tmp, Err(Error::ComputationDepthExceeded))); + } + Err(e) => assert!(matches!(e, Error::ComputationDepthExceeded)), + } // Ok(()) })