Security - Limit parser depth. (#2369)

This commit is contained in:
Finn Bear 2023-08-21 15:05:11 -07:00 committed by GitHub
parent 77c889f356
commit 27cc21876d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 232 additions and 11 deletions

View file

@ -5,9 +5,18 @@ use once_cell::sync::Lazy;
/// Specifies how many concurrent jobs can be buffered in the worker channel.
///
/// NOTE(review): what happens once the buffer is full (block vs. reject) is not
/// visible here — confirm against where the worker channel is constructed.
pub const MAX_CONCURRENT_TASKS: usize = 64;
/// Specifies how deep various forms of computation will go before the query fails
/// with [`Error::ComputationDepthExceeded`].
///
/// For reference, budget roughly 15 units of depth per MiB of stack in release mode.
///
/// During query parsing, the total depth of calls to parse values (including arrays, expressions,
/// functions, objects, sub-queries), Javascript values, and geometry collections count against
/// this limit.
///
/// During query execution, all potentially-recursive code paths count against this limit. Whereas
/// parsing assigns equal weight to each recursion, certain expensive code paths are allowed to
/// count for more than one unit of depth during execution.
///
/// Can be overridden at build time via the `SURREAL_MAX_COMPUTATION_DEPTH` environment
/// variable (read with `option_env!`, so it must be set when compiling); defaults to 120.
pub static MAX_COMPUTATION_DEPTH: Lazy<u8> = Lazy::new(|| {
	option_env!("SURREAL_MAX_COMPUTATION_DEPTH").and_then(|s| s.parse::<u8>().ok()).unwrap_or(120)
});

View file

@ -6,6 +6,7 @@ use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error<I> {
Parser(I),
ExcessiveDepth,
Field(I, String),
Split(I, String),
Order(I, String),

View file

@ -228,6 +228,8 @@ pub fn unary(i: &str) -> IResult<&str, Expression> {
pub fn binary(i: &str) -> IResult<&str, Expression> {
let (i, l) = single(i)?;
let (i, o) = operator::binary(i)?;
// Make sure to dive if the query is a right-deep binary tree.
let _diving = crate::sql::parser::depth::dive()?;
let (i, r) = value(i)?;
let v = match r {
Value::Expression(r) => r.augment(l, o),

View file

@ -627,6 +627,7 @@ impl hash::Hash for Geometry {
}
/// Parse a [`Geometry`] value.
///
/// Counts one unit against the parser's recursion depth budget, since geometry
/// collections count toward the computation depth limit.
pub fn geometry(i: &str) -> IResult<&str, Geometry> {
	// Acquire a depth guard before descending; it is released when `_diving`
	// is dropped at the end of this function.
	let _diving = crate::sql::parser::depth::dive()?;
	alt((simple, normal))(i)
}

View file

@ -1,6 +1,6 @@
use crate::err::Error;
use crate::iam::Error as IamError;
use crate::sql::error::Error::{Field, Group, Order, Parser, Role, Split};
use crate::sql::error::Error::{ExcessiveDepth, Field, Group, Order, Parser, Role, Split};
use crate::sql::error::IResult;
use crate::sql::query::{query, Query};
use crate::sql::subquery::Subquery;
@ -11,6 +11,15 @@ use std::str;
use tracing::instrument;
/// Parses a SurrealQL [`Query`]
///
/// During query parsing, the total depth of calls to parse values (including arrays, expressions,
/// functions, objects, sub-queries), Javascript values, and geometry collections count against
/// a computation depth limit. If the limit is reached, parsing will return
/// [`Error::ComputationDepthExceeded`], as opposed to spending more time and potentially
/// overflowing the call stack.
///
/// If you encounter this limit and believe that it should be increased,
/// please [open an issue](https://github.com/surrealdb/surrealdb/issues)!
#[instrument(name = "parser", skip_all, fields(length = input.len()))]
pub fn parse(input: &str) -> Result<Query, Error> {
parse_impl(input, query)
@ -41,6 +50,9 @@ pub fn json(input: &str) -> Result<Value, Error> {
}
fn parse_impl<O>(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Result<O, Error> {
// Reset the parse depth limiter
depth::reset();
// Check the length of the input
match input.trim().len() {
// The input query was empty
@ -64,6 +76,8 @@ fn parse_impl<O>(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Resu
sql: s.to_string(),
}
}
// The query exceeded the maximum computation depth while parsing
ExcessiveDepth => Error::ComputationDepthExceeded,
// There was a FIELD selection error
Field(e, f) => Error::InvalidField {
line: locate(input, e).1,
@ -117,12 +131,76 @@ fn locate<'a>(input: &str, tried: &'a str) -> (&'a str, usize, usize) {
(tried, 0, 0)
}
pub(crate) mod depth {
	use crate::cnf::MAX_COMPUTATION_DEPTH;
	use crate::sql::Error::ExcessiveDepth;
	use nom::Err;
	use std::cell::Cell;
	use std::thread::panicking;

	thread_local! {
		/// The current recursion depth of parsing on this thread.
		static DEPTH: Cell<u8> = Cell::new(0);
	}

	/// Scale down `MAX_COMPUTATION_DEPTH` for parsing because:
	/// - Only a few intermediate parsers, collectively sufficient to limit depth, call dive.
	/// - Some of the depth budget during execution is for futures, graph traversal, and
	///   other operations that don't exist during parsing.
	/// - The parser currently runs in exponential time, so a lower limit guards against
	///   CPU-intensive, time-consuming parsing.
	const DEPTH_PER_DIVE: u8 = 4;

	/// Call when starting the parser to reset the recursion depth.
	#[inline(never)]
	pub(super) fn reset() {
		DEPTH.with(|cell| {
			// A non-zero counter here means a previous parse leaked guards.
			debug_assert_eq!(cell.get(), 0, "previous parsing stopped abruptly");
			cell.set(0)
		});
	}

	/// Call at least once in recursive parsing code paths to limit recursion depth.
	///
	/// On success, returns a [`Diving`] guard whose `Drop` impl gives the depth
	/// back; on budget exhaustion, returns a `nom` failure carrying
	/// [`ExcessiveDepth`].
	#[inline(never)]
	#[must_use = "must store and implicitly drop when returning"]
	pub(crate) fn dive() -> Result<Diving, Err<crate::sql::Error<&'static str>>> {
		DEPTH.with(|cell| {
			// Saturate rather than wrap so extreme depths still compare correctly.
			let next = cell.get().saturating_add(DEPTH_PER_DIVE);
			if next > *MAX_COMPUTATION_DEPTH {
				return Err(Err::Failure(ExcessiveDepth));
			}
			cell.set(next);
			Ok(Diving)
		})
	}

	/// RAII-style guard accounting for `DEPTH_PER_DIVE` units of parsing depth,
	/// released when dropped.
	#[must_use]
	#[non_exhaustive]
	pub(crate) struct Diving;

	impl Drop for Diving {
		fn drop(&mut self) {
			DEPTH.with(|cell| match cell.get().checked_sub(DEPTH_PER_DIVE) {
				Some(rest) => cell.set(rest),
				// Underflow should only be reachable while unwinding from a panic.
				None => debug_assert!(panicking()),
			});
		}
	}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Serialize;
use std::{collections::HashMap, time::Instant};
use std::{
collections::HashMap,
time::{Duration, Instant},
};
#[test]
fn no_ending() {
@ -159,6 +237,83 @@ mod tests {
assert!(res.is_err());
}
#[test]
fn parse_ok_recursion() {
	// Modest nesting stays well within the depth limit and must parse.
	let query = "SELECT * FROM ((SELECT * FROM (5))) * 5;";
	assert!(parse(query).is_ok());
}
#[test]
fn parse_ok_recursion_deeper() {
	// Deeper (but still legal) nesting, including a script value; it must both
	// succeed and stay under the historical debug-mode runtime bound.
	let query = "SELECT * FROM (((( SELECT * FROM ((5)) + ((5)) + ((5)) )))) * ((( function() {return 5;} )));";
	let started = Instant::now();
	let result = parse(query);
	let took = started.elapsed();
	assert!(result.is_ok());
	assert!(
		took < Duration::from_millis(2000),
		"took {}ms, previously took ~1000ms in debug",
		took.as_millis()
	)
}
#[test]
fn parse_recursion_cast() {
	// Chained casts: depths above 50 (here 100 and 500) must exceed the limit.
	for depth in [10, 100, 500] {
		recursive("SELECT * FROM ", "<int>", "5", "", depth, depth > 50);
	}
}
#[test]
fn parse_recursion_geometry() {
	// Nested geometry collections: depths above 25 must exceed the limit.
	for depth in [1, 50, 100] {
		recursive(
			"SELECT * FROM ",
			r#"{type: "GeometryCollection",geometries: ["#,
			r#"{type: "MultiPoint",coordinates: [[10.0, 11.2],[10.5, 11.9]]}"#,
			"]}",
			depth,
			depth > 25,
		);
	}
}
#[test]
fn parse_recursion_javascript() {
	// Nested script functions: depths above 500 must exceed the limit.
	for depth in [10, 1000] {
		recursive("SELECT * FROM ", "function() {", "return 5;", "}", depth, depth > 500);
	}
}
#[test]
fn parse_recursion_mixed() {
	// Interleaved SELECT subqueries and parenthesized expressions: depths
	// above 5 must exceed the limit.
	for depth in [3, 15, 75] {
		recursive("", "SELECT * FROM ((((", "5 * 5", ")))) * 5", depth, depth > 5);
	}
}
#[test]
fn parse_recursion_select() {
	// Nested SELECT subqueries: depths above 15 must exceed the limit.
	for depth in [5, 10, 100] {
		recursive("SELECT * FROM ", "(SELECT * FROM ", "5", ")", depth, depth > 15);
	}
}
#[test]
fn parse_recursion_value_subquery() {
	// Depths of 10, 100, 1000, 10000; only the shallowest should parse.
	for power in 1..=4 {
		recursive("SELECT * FROM ", "(", "5", ")", 10usize.pow(power), power > 1);
	}
}
#[test]
fn parse_recursion_if_subquery() {
	// Depths of 6, 36, and 216; only the shallowest should parse.
	for power in 1..=3 {
		recursive("SELECT * FROM ", "IF true THEN ", "5", " ELSE 4 END", 6usize.pow(power), power > 1);
	}
}
#[test]
fn parser_try() {
let sql = "
@ -246,4 +401,47 @@ mod tests {
println!("sql::json took {:.10}s/iter", benchmark(|s| crate::sql::json(s).unwrap()));
}
/// Try parsing a query with O(n) recursion depth and expect to fail if and only if
/// `excessive` is true.
fn recursive(
prefix: &str,
recursive_start: &str,
base: &str,
recursive_end: &str,
n: usize,
excessive: bool,
) {
let mut sql = String::from(prefix);
for _ in 0..n {
sql.push_str(recursive_start);
}
sql.push_str(base);
for _ in 0..n {
sql.push_str(recursive_end);
}
let start = Instant::now();
let res = parse(&sql);
let elapsed = start.elapsed();
if excessive {
assert!(
matches!(res, Err(Error::ComputationDepthExceeded)),
"expected computation depth exceeded, got {:?}",
res
);
} else {
res.unwrap();
}
// The parser can terminate faster in the excessive case.
let cutoff = if excessive {
500
} else {
1000
};
assert!(
elapsed < Duration::from_millis(cutoff),
"took {}ms, previously much faster to parse {n} in debug mode",
elapsed.as_millis()
)
}
}

View file

@ -60,6 +60,7 @@ pub fn script(i: &str) -> IResult<&str, Script> {
}
fn script_raw(i: &str) -> IResult<&str, &str> {
let _diving = crate::sql::parser::depth::dive()?;
recognize(many0(alt((
script_comment,
script_object,

View file

@ -2748,6 +2748,10 @@ pub fn value(i: &str) -> IResult<&str, Value> {
/// Parse any `Value` excluding binary expressions
pub fn single(i: &str) -> IResult<&str, Value> {
// Dive in `single` (as opposed to `value`) since it is directly
// called by `Cast`
let _diving = crate::sql::parser::depth::dive()?;
let (i, v) = alt((
alt((
terminated(
@ -2852,6 +2856,8 @@ pub fn what(i: &str) -> IResult<&str, Value> {
/// Used to parse any simple JSON-like value
pub fn json(i: &str) -> IResult<&str, Value> {
let _diving = crate::sql::parser::depth::dive()?;
// Use a specific parser for JSON objects
pub fn object(i: &str) -> IResult<&str, Object> {
let (i, _) = char('{')(i)?;

View file

@ -166,7 +166,7 @@ fn ok_graph_traversal_depth() -> Result<(), Error> {
fn ok_cast_chain_depth() -> Result<(), Error> {
// Ensure a good stack size for tests
with_enough_stack(async {
// Run a chasting query which succeeds
// Run a casting query which succeeds
let mut res = run_queries(&cast_chain(10)).await?;
//
assert_eq!(res.len(), 1);
@ -183,13 +183,16 @@ fn ok_cast_chain_depth() -> Result<(), Error> {
fn excessive_cast_chain_depth() -> Result<(), Error> {
// Ensure a good stack size for tests
with_enough_stack(async {
// Run a casting query which will fail
let mut res = run_queries(&cast_chain(125)).await?;
//
assert_eq!(res.len(), 1);
//
let tmp = res.next().unwrap();
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
// Run a casting query which will fail (either while parsing or executing)
match run_queries(&cast_chain(125)).await {
Ok(mut res) => {
assert_eq!(res.len(), 1);
//
let tmp = res.next().unwrap();
assert!(matches!(tmp, Err(Error::ComputationDepthExceeded)));
}
Err(e) => assert!(matches!(e, Error::ComputationDepthExceeded)),
}
//
Ok(())
})