diff --git a/lib/src/sql/strand.rs b/lib/src/sql/strand.rs index 7659b4cc..6dc557e5 100644 --- a/lib/src/sql/strand.rs +++ b/lib/src/sql/strand.rs @@ -2,16 +2,15 @@ use crate::sql::error::Error::Parser; use crate::sql::error::IResult; use crate::sql::escape::escape_str; use nom::branch::alt; -use nom::bytes::complete::escaped_transform; -use nom::bytes::complete::is_not; -use nom::bytes::complete::take_while_m_n; +use nom::bytes::complete::{escaped_transform, is_not, tag, take, take_while_m_n}; use nom::character::complete::char; use nom::combinator::value; +use nom::sequence::preceded; use nom::Err::Failure; use serde::{Deserialize, Serialize}; use std::fmt::{self, Display, Formatter}; -use std::ops; use std::ops::Deref; +use std::ops::{self, RangeInclusive}; use std::str; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Strand"; @@ -22,7 +21,8 @@ const SINGLE_ESC: &str = r#"\'"#; const DOUBLE: char = '"'; const DOUBLE_ESC: &str = r#"\""#; -const SURROGATES: [u32; 2] = [55296, 57343]; +const LEADING_SURROGATES: RangeInclusive = 0xD800..=0xDBFF; +const TRAILING_SURROGATES: RangeInclusive = 0xDC00..=0xDFFF; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] #[serde(rename = "$surrealdb::private::sql::Strand")] @@ -111,7 +111,7 @@ fn strand_single(i: &str) -> IResult<&str, String> { is_not(SINGLE_ESC), '\\', alt(( - strand_unicode, + char_unicode, value('\u{5c}', char('\\')), value('\u{27}', char('\'')), value('\u{2f}', char('/')), @@ -132,7 +132,7 @@ fn strand_double(i: &str) -> IResult<&str, String> { is_not(DOUBLE_ESC), '\\', alt(( - strand_unicode, + char_unicode, value('\u{5c}', char('\\')), value('\u{22}', char('\"')), value('\u{2f}', char('/')), @@ -147,30 +147,59 @@ fn strand_double(i: &str) -> IResult<&str, String> { Ok((i, v)) } -fn strand_unicode(i: &str) -> IResult<&str, char> { - // Read the \u character - let (i, _) = char('u')(i)?; - // Let's read the next 6 ascii hexadecimal characters +fn char_unicode(i: &str) -> IResult<&str, char> { + preceded(char('u'), alt((char_unicode_bracketed, char_unicode_bare)))(i) +} + +// \uABCD or \uDBFF\uDFFF (surrogate pair) +fn char_unicode_bare(i: &str) -> IResult<&str, char> { + // Take exactly 4 bytes + let (i, v) = take(4usize)(i)?; + // Parse them as hex, where an error indicates invalid hex digits + let v: u16 = u16::from_str_radix(v, 16).map_err(|_| Failure(Parser(i)))?; + + if LEADING_SURROGATES.contains(&v) { + let leading = v; + + // Read the next \u. + let (i, _) = tag("\\u")(i)?; + // Take exactly 4 more bytes + let (i, v) = take(4usize)(i)?; + // Parse them as hex, where an error indicates invalid hex digits + let trailing = u16::from_str_radix(v, 16).map_err(|_| Failure(Parser(i)))?; + if !TRAILING_SURROGATES.contains(&trailing) { + return Err(Failure(Parser(i))); + } + // Compute the codepoint. + // https://datacadamia.com/data/type/text/surrogate#from_surrogate_to_character_code + let codepoint = 0x10000 + + ((leading as u32 - *LEADING_SURROGATES.start() as u32) << 10) + + trailing as u32 + - *TRAILING_SURROGATES.start() as u32; + // Convert to char + let v = char::from_u32(codepoint).ok_or(Failure(Parser(i)))?; + // Return the char + Ok((i, v)) + } else { + // We can convert this to char or error in the case of invalid Unicode character + let v = char::from_u32(v as u32).ok_or(Failure(Parser(i)))?; + // Return the char + Ok((i, v)) + } +} + +// \u{10ffff} +fn char_unicode_bracketed(i: &str) -> IResult<&str, char> { + // Read the { character + let (i, _) = char('{')(i)?; + // Let's up to 6 ascii hexadecimal characters let (i, v) = take_while_m_n(1, 6, |c: char| c.is_ascii_hexdigit())(i)?; - // We can convert this to u32 as we only have 6 chars - let v = match u32::from_str_radix(v, 16) { - // We found an invalid unicode sequence - Err(_) => return Err(Failure(Parser(i))), - // The unicode sequence was valid - Ok(v) => match v { - // This is a surrogate, so convert to a space - v if v >= SURROGATES[0] && v <= SURROGATES[1] => 32, - // This is a valid UTF-8 / UTF-16 character - _ => v, - }, - }; - // We can convert this to char as we know it is valid - let v = match std::char::from_u32(v) { - // We found an invalid unicode sequence - None => return Err(Failure(Parser(i))), - // The unicode sequence was valid - Some(v) => v, - }; + // We can convert this to u32 as the max is 0xffffff + let v = u32::from_str_radix(v, 16).unwrap(); + // We can convert this to char or error in the case of invalid Unicode character + let v = char::from_u32(v).ok_or(Failure(Parser(i)))?; + // Read the } character + let (i, _) = char('}')(i)?; // Return the char Ok((i, v)) } @@ -239,4 +268,32 @@ mod tests { assert_eq!("'te\"st\n\tand\u{08}some\u{05d9}'", format!("{}", out)); assert_eq!(out, Strand::from("te\"st\n\tand\u{08}some\u{05d9}")); } + + #[test] + fn strand_fuzz_escape() { + for n in (0..=char::MAX as u32).step_by(101) { + if let Some(c) = char::from_u32(n) { + let expected = format!("a{c}b"); + + let utf32 = format!("\"a\\u{{{n:x}}}b\""); + let (rest, s) = strand(&utf32).unwrap(); + assert_eq!(rest, ""); + assert_eq!(s.as_str(), &expected); + + let mut utf16 = String::with_capacity(16); + utf16 += "\"a"; + let mut buf = [0; 2]; + for &mut n in c.encode_utf16(&mut buf) { + utf16 += &format!("\\u{n:04x}"); + } + utf16 += "b\""; + let (rest, s) = strand(&utf16).unwrap(); + assert_eq!(rest, ""); + assert_eq!(s.as_str(), &expected); + } + } + + // Unpaired surrogate. + assert!(strand("\"\\u{DBFF}\"").is_err()); + } }