diff --git a/lib/src/fnc/script/from.rs b/lib/src/fnc/script/from.rs
index 2e57a4a7..cf168a36 100644
--- a/lib/src/fnc/script/from.rs
+++ b/lib/src/fnc/script/from.rs
@@ -3,12 +3,26 @@ use crate::sql::array::Array;
 use crate::sql::datetime::Datetime;
 use crate::sql::object::Object;
 use crate::sql::value::Value;
+use crate::sql::Id;
 use chrono::{TimeZone, Utc};
 use js::Ctx;
 use js::Error;
 use js::FromAtom;
 use js::FromJs;
 
+/// Rejects strings that contain a NUL byte.
+///
+/// On failure the error is `Error::InvalidString`, wrapping the `NulError`
+/// that `CString::new` reports for the same input.
+fn check_nul(s: &str) -> Result<(), Error> {
+    if s.contains('\0') {
+        // `s` is known to contain a NUL byte here, so `CString::new` cannot succeed.
+        Err(Error::InvalidString(std::ffi::CString::new(s).unwrap_err()))
+    } else {
+        Ok(())
+    }
+}
+
 impl<'js> FromJs<'js> for Value {
     fn from_js(ctx: Ctx<'js>, val: js::Value<'js>) -> Result {
         match val {
@@ -16,7 +30,10 @@ impl<'js> FromJs<'js> for Value {
             val if val.type_name() == "undefined" => Ok(Value::None),
             val if val.is_bool() => Ok(val.as_bool().unwrap().into()),
             val if val.is_string() => match val.into_string().unwrap().to_string() {
-                Ok(v) => Ok(Value::from(v)),
+                Ok(v) => {
+                    check_nul(&v)?;
+                    Ok(Value::from(v))
+                }
                 Err(e) => Err(e),
             },
             val if val.is_int() => Ok(val.as_int().unwrap().into()),
@@ -48,6 +65,10 @@ impl<'js> FromJs<'js> for Value {
                 if (v).instance_of::() {
                     let v = v.into_instance::().unwrap();
                     let v: &classes::record::record::Record = v.as_ref();
+                    check_nul(&v.value.tb)?;
+                    if let Id::String(s) = &v.value.id {
+                        check_nul(s)?;
+                    }
                     return Ok(v.value.clone().into());
                 }
                 // Check to see if this object is a duration
@@ -95,6 +116,7 @@ impl<'js> FromJs<'js> for Value {
                     for i in v.props() {
                         let (k, v) = i?;
                         let k = String::from_atom(k)?;
+                        check_nul(&k)?;
                         let v = Value::from_js(ctx, v)?;
                         x.insert(k, v);
                     }
diff --git a/lib/src/sql/id.rs b/lib/src/sql/id.rs
index bd286e79..5195a610 100644
--- a/lib/src/sql/id.rs
+++ b/lib/src/sql/id.rs
@@ -24,6 +24,7 @@ use ulid::Ulid;
 #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
 pub enum Id {
     Number(i64),
+    /// Invariant: Doesn't contain NUL bytes.
     String(String),
     Array(Array),
     Object(Object),
diff --git a/lib/src/sql/ident.rs b/lib/src/sql/ident.rs
index 4f4c9c9c..ca5cd331 100644
--- a/lib/src/sql/ident.rs
+++ b/lib/src/sql/ident.rs
@@ -1,6 +1,7 @@
 use crate::sql::common::val_char;
 use crate::sql::error::IResult;
 use crate::sql::escape::escape_ident;
+use crate::sql::strand::no_nul_bytes;
 use nom::branch::alt;
 use nom::bytes::complete::escaped_transform;
 use nom::bytes::complete::is_not;
@@ -18,13 +19,13 @@ use std::str;
 
 const BRACKET_L: char = '⟨';
 const BRACKET_R: char = '⟩';
-const BRACKET_END: &str = r#"⟩"#;
+const BRACKET_END_NUL: &str = "⟩\0";
 
 const BACKTICK: char = '`';
-const BACKTICK_ESC: &str = r#"\`"#;
+const BACKTICK_ESC_NUL: &str = "`\\\0";
 
 #[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
-pub struct Ident(pub String);
+pub struct Ident(#[serde(with = "no_nul_bytes")] pub String);
 
 impl From for Ident {
     fn from(v: String) -> Self {
@@ -90,7 +91,7 @@ fn ident_default(i: &str) -> IResult<&str, String> {
 fn ident_backtick(i: &str) -> IResult<&str, String> {
     let (i, _) = char(BACKTICK)(i)?;
     let (i, v) = escaped_transform(
-        is_not(BACKTICK_ESC),
+        is_not(BACKTICK_ESC_NUL),
         '\\',
         alt((
             value('\u{5c}', char('\\')),
@@ -108,7 +109,7 @@ fn ident_backtick(i: &str) -> IResult<&str, String> {
 }
 
 fn ident_brackets(i: &str) -> IResult<&str, String> {
-    let (i, v) = delimited(char(BRACKET_L), is_not(BRACKET_END), char(BRACKET_R))(i)?;
+    let (i, v) = delimited(char(BRACKET_L), is_not(BRACKET_END_NUL), char(BRACKET_R))(i)?;
     Ok((i, String::from(v)))
 }
 
diff --git a/lib/src/sql/object.rs b/lib/src/sql/object.rs
index 0e0ae9d2..cf56c4ff 100644
--- a/lib/src/sql/object.rs
+++ b/lib/src/sql/object.rs
@@ -26,9 +26,10 @@ use std::ops::DerefMut;
 
 pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
 
+/// Invariant: Keys never contain NUL bytes.
 #[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
 #[serde(rename = "$surrealdb::private::sql::Object")]
-pub struct Object(pub BTreeMap);
+pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap);
 
 impl From> for Object {
     fn from(v: BTreeMap) -> Self {
@@ -51,13 +52,13 @@ impl From> for Object {
 impl From for Object {
     fn from(v: Operation) -> Self {
         Self(map! {
-            String::from("op") => match v.op {
-                Op::None => Value::from("none"),
-                Op::Add => Value::from("add"),
-                Op::Remove => Value::from("remove"),
-                Op::Replace => Value::from("replace"),
-                Op::Change => Value::from("change"),
-            },
+            String::from("op") => Value::from(match v.op {
+                Op::None => "none",
+                Op::Add => "add",
+                Op::Remove => "remove",
+                Op::Replace => "replace",
+                Op::Change => "change",
+            }),
             String::from("path") => v.path.to_path().into(),
             String::from("value") => v.value,
         })
@@ -167,6 +168,71 @@ impl Display for Object {
     }
 }
 
+/// Serde `with` helpers for the map inside `Object`, whose keys must never
+/// contain NUL bytes.
+///
+/// Serializing only debug-asserts the invariant; deserializing returns an
+/// error if any key contains a NUL byte.
+mod no_nul_bytes_in_keys {
+    use serde::{
+        de::{self, Visitor},
+        ser::SerializeMap,
+        Deserializer, Serializer,
+    };
+    use std::{collections::BTreeMap, fmt};
+
+    use crate::sql::Value;
+
+    /// Serializes the map unchanged; keys are only checked in debug builds.
+    pub(crate) fn serialize<S>(
+        m: &BTreeMap<String, Value>,
+        serializer: S,
+    ) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut s = serializer.serialize_map(Some(m.len()))?;
+        for (k, v) in m {
+            debug_assert!(!k.contains('\0'));
+            s.serialize_entry(k, v)?;
+        }
+        s.end()
+    }
+
+    /// Deserializes the map, failing on the first key that contains a NUL byte.
+    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<BTreeMap<String, Value>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct NoNulBytesInKeysVisitor;
+
+        impl<'de> Visitor<'de> for NoNulBytesInKeysVisitor {
+            type Value = BTreeMap<String, Value>;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a map without any NUL bytes in its keys")
+            }
+
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: de::MapAccess<'de>,
+            {
+                let mut ret = BTreeMap::new();
+                while let Some((k, v)) = map.next_entry::<String, Value>()? {
+                    // Enforce the invariant here: encoded input is not guaranteed to uphold it.
+                    if k.contains('\0') {
+                        return Err(de::Error::custom("contained NUL byte"));
+                    }
+                    ret.insert(k, v);
+                }
+                Ok(ret)
+            }
+        }
+
+        deserializer.deserialize_map(NoNulBytesInKeysVisitor)
+    }
+}
+
 pub fn object(i: &str) -> IResult<&str, Object> {
     let (i, _) = char('{')(i)?;
     let (i, _) = mightbespace(i)?;
@@ -194,11 +260,11 @@ fn key_none(i: &str) -> IResult<&str, &str> {
 }
 
 fn key_single(i: &str) -> IResult<&str, &str> {
-    delimited(char('\''), is_not("\'"), char('\''))(i)
+    delimited(char('\''), is_not("\'\0"), char('\''))(i)
 }
 
 fn key_double(i: &str) -> IResult<&str, &str> {
-    delimited(char('\"'), is_not("\""), char('\"'))(i)
+    delimited(char('\"'), is_not("\"\0"), char('\"'))(i)
 }
 
 #[cfg(test)]
diff --git a/lib/src/sql/part.rs b/lib/src/sql/part.rs
index 463e9c8b..d6f17a99 100644
--- a/lib/src/sql/part.rs
+++ b/lib/src/sql/part.rs
@@ -6,6 +6,7 @@ use crate::sql::graph::{self, Graph};
 use crate::sql::ident::{self, Ident};
 use crate::sql::idiom::Idiom;
 use crate::sql::number::{number, Number};
+use crate::sql::strand::no_nul_bytes;
 use crate::sql::value::{self, Value};
 use nom::branch::alt;
 use nom::bytes::complete::tag;
@@ -27,7 +28,7 @@ pub enum Part {
     Where(Value),
     Graph(Graph),
     Value(Value),
-    Method(String, Vec),
+    Method(#[serde(with = "no_nul_bytes")] String, Vec),
 }
 
 impl From for Part {
diff --git a/lib/src/sql/range.rs b/lib/src/sql/range.rs
index e4b8841a..02aa3b91 100644
--- a/lib/src/sql/range.rs
+++ b/lib/src/sql/range.rs
@@ -5,6 +5,7 @@ use crate::err::Error;
 use crate::sql::error::IResult;
 use crate::sql::id::{id, Id};
 use crate::sql::ident::ident_raw;
+use crate::sql::strand::no_nul_bytes;
 use crate::sql::value::Value;
 use nom::branch::alt;
 use nom::character::complete::char;
@@ -22,6 +23,7 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Range";
 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
 #[serde(rename = "$surrealdb::private::sql::Range")]
 pub struct Range {
+    #[serde(with = "no_nul_bytes")]
     pub tb: String,
     pub beg: Bound,
     pub end: Bound,
diff --git a/lib/src/sql/regex.rs b/lib/src/sql/regex.rs
index 1afdaaf2..7b7df2db 100644
--- a/lib/src/sql/regex.rs
+++ b/lib/src/sql/regex.rs
@@ -30,7 +30,11 @@ impl FromStr for Regex {
     type Err = ::Err;
 
     fn from_str(s: &str) -> Result {
-        regex::Regex::new(&s.replace("\\/", "/")).map(Self)
+        if s.contains('\0') {
+            Err(regex::Error::Syntax("regex contained NUL byte".to_owned()))
+        } else {
+            regex::Regex::new(&s.replace("\\/", "/")).map(Self)
+        }
     }
 }
 
diff --git a/lib/src/sql/script.rs b/lib/src/sql/script.rs
index 0ecf68f5..48e0bff5 100644
--- a/lib/src/sql/script.rs
+++ b/lib/src/sql/script.rs
@@ -1,5 +1,6 @@
 use crate::sql::comment::{block, slash};
 use crate::sql::error::IResult;
+use crate::sql::strand::no_nul_bytes;
 use nom::branch::alt;
 use nom::bytes::complete::escaped;
 use nom::bytes::complete::is_not;
@@ -16,19 +17,19 @@ use std::ops::Deref;
 use std::str;
 
 const SINGLE: char = '\'';
-const SINGLE_ESC: &str = r#"\'"#;
+const SINGLE_ESC_NUL: &str = "'\\\0";
 
 const DOUBLE: char = '"';
-const DOUBLE_ESC: &str = r#"\""#;
+const DOUBLE_ESC_NUL: &str = "\"\\\0";
 
 const BACKTICK: char = '`';
-const BACKTICK_ESC: &str = r#"\`"#;
+const BACKTICK_ESC_NUL: &str = "`\\\0";
 
 const OBJECT_BEG: char = '{';
 const OBJECT_END: char = '}';
 
 #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
-pub struct Script(pub String);
+pub struct Script(#[serde(with = "no_nul_bytes")] pub String);
 
 impl From for Script {
     fn from(s: String) -> Self {
@@ -100,19 +101,19 @@ fn script_string(i: &str) -> IResult<&str, &str> {
         },
         |i| {
             let (i, _) = char(SINGLE)(i)?;
-            let (i, v) = escaped(is_not(SINGLE_ESC), '\\', char(SINGLE))(i)?;
+            let (i, v) = escaped(is_not(SINGLE_ESC_NUL), '\\', char(SINGLE))(i)?;
             let (i, _) = char(SINGLE)(i)?;
             Ok((i, v))
         },
         |i| {
             let (i, _) = char(DOUBLE)(i)?;
-            let (i, v) = escaped(is_not(DOUBLE_ESC), '\\', char(DOUBLE))(i)?;
+            let (i, v) = escaped(is_not(DOUBLE_ESC_NUL), '\\', char(DOUBLE))(i)?;
             let (i, _) = char(DOUBLE)(i)?;
             Ok((i, v))
         },
         |i| {
             let (i, _) = char(BACKTICK)(i)?;
-            let (i, v) = escaped(is_not(BACKTICK_ESC), '\\', char(BACKTICK))(i)?;
+            let (i, v) = escaped(is_not(BACKTICK_ESC_NUL), '\\', char(BACKTICK))(i)?;
             let (i, _) = char(BACKTICK)(i)?;
             Ok((i, v))
         },
diff --git a/lib/src/sql/strand.rs b/lib/src/sql/strand.rs
index 6dc557e5..edd8a1d8 100644
--- a/lib/src/sql/strand.rs
+++ b/lib/src/sql/strand.rs
@@ -16,26 +16,29 @@ use std::str;
 pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Strand";
 
 const SINGLE: char = '\'';
-const SINGLE_ESC: &str = r#"\'"#;
+const SINGLE_ESC_NUL: &str = "'\\\0";
 
 const DOUBLE: char = '"';
-const DOUBLE_ESC: &str = r#"\""#;
+const DOUBLE_ESC_NUL: &str = "\"\\\0";
 
 const LEADING_SURROGATES: RangeInclusive = 0xD800..=0xDBFF;
 const TRAILING_SURROGATES: RangeInclusive = 0xDC00..=0xDFFF;
 
-#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
+/// A string that doesn't contain NUL bytes.
+#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
 #[serde(rename = "$surrealdb::private::sql::Strand")]
-pub struct Strand(pub String);
+pub struct Strand(#[serde(with = "no_nul_bytes")] pub String);
 
 impl From for Strand {
     fn from(s: String) -> Self {
+        debug_assert!(!s.contains('\0'));
         Strand(s)
     }
 }
 
 impl From<&str> for Strand {
     fn from(s: &str) -> Self {
+        debug_assert!(!s.contains('\0'));
         Self::from(String::from(s))
     }
 }
@@ -76,8 +79,9 @@ impl Display for Strand {
 
 impl ops::Add for Strand {
     type Output = Self;
-    fn add(self, other: Self) -> Self {
-        Strand::from(self.0 + &other.0)
+    fn add(mut self, other: Self) -> Self {
+        self.0.push_str(other.as_str());
+        self
     }
 }
 
@@ -108,7 +112,7 @@ fn strand_blank(i: &str) -> IResult<&str, String> {
 fn strand_single(i: &str) -> IResult<&str, String> {
     let (i, _) = char(SINGLE)(i)?;
     let (i, v) = escaped_transform(
-        is_not(SINGLE_ESC),
+        is_not(SINGLE_ESC_NUL),
         '\\',
         alt((
             char_unicode,
@@ -129,7 +133,7 @@ fn strand_single(i: &str) -> IResult<&str, String> {
 fn strand_double(i: &str) -> IResult<&str, String> {
     let (i, _) = char(DOUBLE)(i)?;
     let (i, v) = escaped_transform(
-        is_not(DOUBLE_ESC),
+        is_not(DOUBLE_ESC_NUL),
         '\\',
         alt((
             char_unicode,
@@ -182,7 +186,7 @@ fn char_unicode_bare(i: &str) -> IResult<&str, char> {
         Ok((i, v))
     } else {
         // We can convert this to char or error in the case of invalid Unicode character
-        let v = char::from_u32(v as u32).ok_or(Failure(Parser(i)))?;
+        let v = char::from_u32(v as u32).filter(|c| *c != '\0').ok_or(Failure(Parser(i)))?;
         // Return the char
         Ok((i, v))
     }
@@ -197,13 +201,74 @@ fn char_unicode_bracketed(i: &str) -> IResult<&str, char> {
     // We can convert this to u32 as the max is 0xffffff
     let v = u32::from_str_radix(v, 16).unwrap();
     // We can convert this to char or error in the case of invalid Unicode character
-    let v = char::from_u32(v).ok_or(Failure(Parser(i)))?;
+    let v = char::from_u32(v).filter(|c| *c != '\0').ok_or(Failure(Parser(i)))?;
     // Read the } character
     let (i, _) = char('}')(i)?;
     // Return the char
     Ok((i, v))
 }
 
+/// Serde `with` helpers for `String` fields that must not contain NUL bytes.
+///
+/// Serializing only debug-asserts the invariant; deserializing returns an
+/// error if the decoded string contains a NUL byte.
+pub(crate) mod no_nul_bytes {
+    use serde::{
+        de::{self, Visitor},
+        Deserializer, Serializer,
+    };
+    use std::fmt;
+
+    /// Serializes the string unchanged; it is only checked in debug builds.
+    pub(crate) fn serialize<S>(s: &String, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        debug_assert!(!s.contains('\0'));
+        serializer.serialize_str(s)
+    }
+
+    /// Deserializes a string, failing if it contains a NUL byte.
+    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct NoNulBytesVisitor;
+
+        impl<'de> Visitor<'de> for NoNulBytesVisitor {
+            type Value = String;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a string without any NUL bytes")
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                if value.contains('\0') {
+                    Err(de::Error::custom("contained NUL byte"))
+                } else {
+                    Ok(value.to_owned())
+                }
+            }
+
+            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                if value.contains('\0') {
+                    Err(de::Error::custom("contained NUL byte"))
+                } else {
+                    Ok(value)
+                }
+            }
+        }
+
+        deserializer.deserialize_string(NoNulBytesVisitor)
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
@@ -269,9 +334,16 @@ mod tests {
         assert_eq!(out, Strand::from("te\"st\n\tand\u{08}some\u{05d9}"));
     }
 
+    #[test]
+    fn strand_nul_byte() {
+        assert!(strand("'a\0b'").is_err());
+        assert!(strand("'a\\u0000b'").is_err());
+        assert!(strand("'a\\u{0}b'").is_err());
+    }
+
     #[test]
     fn strand_fuzz_escape() {
-        for n in (0..=char::MAX as u32).step_by(101) {
+        for n in (1..=char::MAX as u32).step_by(101) {
             if let Some(c) = char::from_u32(n) {
                 let expected = format!("a{c}b");
 
diff --git a/lib/src/sql/table.rs b/lib/src/sql/table.rs
index 57ac91a6..928ab8f0 100644
--- a/lib/src/sql/table.rs
+++ b/lib/src/sql/table.rs
@@ -4,6 +4,7 @@ use crate::sql::escape::escape_ident;
 use crate::sql::fmt::Fmt;
 use crate::sql::id::Id;
 use crate::sql::ident::{ident_raw, Ident};
+use crate::sql::strand::no_nul_bytes;
 use crate::sql::thing::Thing;
 use nom::multi::separated_list1;
 use serde::{Deserialize, Serialize};
@@ -42,7 +43,7 @@ pub fn tables(i: &str) -> IResult<&str, Tables> {
 
 #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
 #[serde(rename = "$surrealdb::private::sql::Table")]
-pub struct Table(pub String);
+pub struct Table(#[serde(with = "no_nul_bytes")] pub String);
 
 impl From for Table {
     fn from(v: String) -> Self {
diff --git a/lib/src/sql/value/compare.rs b/lib/src/sql/value/compare.rs
index e85dd5bc..44b8ec32 100644
--- a/lib/src/sql/value/compare.rs
+++ b/lib/src/sql/value/compare.rs
@@ -16,7 +16,7 @@ impl Value {
             Some(p) => match (self, other) {
                 // Current path part is an object
                 (Value::Object(a), Value::Object(b)) => match p {
-                    Part::Field(f) => match (a.get(f as &str), b.get(f as &str)) {
+                    Part::Field(f) => match (a.get(f.as_str()), b.get(f.as_str())) {
                         (Some(a), Some(b)) => a.compare(b, path.next(), collate, numeric),
                         (Some(_), None) => Some(Ordering::Greater),
                         (None, Some(_)) => Some(Ordering::Less),
diff --git a/lib/src/sql/value/cut.rs b/lib/src/sql/value/cut.rs
index dfc2a306..108861a8 100644
--- a/lib/src/sql/value/cut.rs
+++ b/lib/src/sql/value/cut.rs
@@ -13,10 +13,10 @@ impl Value {
                     if let Part::Field(f) = p {
                         match path.len() {
                             1 => {
-                                v.remove(f as &str);
+                                v.remove(f.as_str());
                             }
                             _ => {
-                                if let Some(v) = v.get_mut(f as &str) {
+                                if let Some(v) = v.get_mut(f.as_str()) {
                                     v.cut(path.next())
                                 }
                             }
diff --git a/lib/src/sql/value/del.rs b/lib/src/sql/value/del.rs
index 4890f99e..47034033 100644
--- a/lib/src/sql/value/del.rs
+++ b/lib/src/sql/value/del.rs
@@ -28,10 +28,10 @@ impl Value {
                 Value::Object(v) => match p {
                     Part::Field(f) => match path.len() {
                         1 => {
-                            v.remove(f as &str);
+                            v.remove(f.as_str());
                             Ok(())
                         }
-                        _ => match v.get_mut(f as &str) {
+                        _ => match v.get_mut(f.as_str()) {
                             Some(v) if v.is_some() => v.del(ctx, opt, txn, path.next()).await,
                             _ => Ok(()),
                         },
diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs
index 16382de3..66846574 100644
--- a/lib/src/sql/value/value.rs
+++ b/lib/src/sql/value/value.rs
@@ -402,13 +402,13 @@ impl From for Value {
 
 impl From for Value {
     fn from(v: String) -> Self {
-        Value::Strand(Strand::from(v))
+        Self::Strand(Strand::from(v))
     }
 }
 
 impl From<&str> for Value {
     fn from(v: &str) -> Self {
-        Value::Strand(Strand::from(v))
+        Self::Strand(Strand::from(v))
     }
 }
 