Bugfix - Prevent NUL bytes from existing in UTF-8 strings (#1941)

This commit is contained in:
Finn Bear 2023-05-09 10:48:14 -07:00 committed by GitHub
parent 3d76645908
commit 73374d4799
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 194 additions and 44 deletions

View file

@ -3,12 +3,21 @@ use crate::sql::array::Array;
use crate::sql::datetime::Datetime;
use crate::sql::object::Object;
use crate::sql::value::Value;
use crate::sql::Id;
use chrono::{TimeZone, Utc};
use js::Ctx;
use js::Error;
use js::FromAtom;
use js::FromJs;
fn check_nul(s: &str) -> Result<(), Error> {
if s.contains('\0') {
Err(Error::InvalidString(std::ffi::CString::new(s).unwrap_err()))
} else {
Ok(())
}
}
impl<'js> FromJs<'js> for Value {
fn from_js(ctx: Ctx<'js>, val: js::Value<'js>) -> Result<Self, Error> {
match val {
@ -16,7 +25,10 @@ impl<'js> FromJs<'js> for Value {
val if val.type_name() == "undefined" => Ok(Value::None),
val if val.is_bool() => Ok(val.as_bool().unwrap().into()),
val if val.is_string() => match val.into_string().unwrap().to_string() {
Ok(v) => Ok(Value::from(v)),
Ok(v) => {
check_nul(&v)?;
Ok(Value::from(v))
}
Err(e) => Err(e),
},
val if val.is_int() => Ok(val.as_int().unwrap().into()),
@ -48,6 +60,10 @@ impl<'js> FromJs<'js> for Value {
if (v).instance_of::<classes::record::record::Record>() {
let v = v.into_instance::<classes::record::record::Record>().unwrap();
let v: &classes::record::record::Record = v.as_ref();
check_nul(&v.value.tb)?;
if let Id::String(s) = &v.value.id {
check_nul(&s)?;
}
return Ok(v.value.clone().into());
}
// Check to see if this object is a duration
@ -95,6 +111,7 @@ impl<'js> FromJs<'js> for Value {
for i in v.props() {
let (k, v) = i?;
let k = String::from_atom(k)?;
check_nul(&k)?;
let v = Value::from_js(ctx, v)?;
x.insert(k, v);
}

View file

@ -24,6 +24,7 @@ use ulid::Ulid;
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
pub enum Id {
Number(i64),
/// Invariant: Doesn't contain NUL bytes.
String(String),
Array(Array),
Object(Object),

View file

@ -1,6 +1,7 @@
use crate::sql::common::val_char;
use crate::sql::error::IResult;
use crate::sql::escape::escape_ident;
use crate::sql::strand::no_nul_bytes;
use nom::branch::alt;
use nom::bytes::complete::escaped_transform;
use nom::bytes::complete::is_not;
@ -18,13 +19,13 @@ use std::str;
const BRACKET_L: char = '⟨';
const BRACKET_R: char = '⟩';
const BRACKET_END: &str = r#"⟩"#;
const BRACKET_END_NUL: &str = "\0";
const BACKTICK: char = '`';
const BACKTICK_ESC: &str = r#"\`"#;
const BACKTICK_ESC_NUL: &str = "`\\\0";
#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
pub struct Ident(pub String);
pub struct Ident(#[serde(with = "no_nul_bytes")] pub String);
impl From<String> for Ident {
fn from(v: String) -> Self {
@ -90,7 +91,7 @@ fn ident_default(i: &str) -> IResult<&str, String> {
fn ident_backtick(i: &str) -> IResult<&str, String> {
let (i, _) = char(BACKTICK)(i)?;
let (i, v) = escaped_transform(
is_not(BACKTICK_ESC),
is_not(BACKTICK_ESC_NUL),
'\\',
alt((
value('\u{5c}', char('\\')),
@ -108,7 +109,7 @@ fn ident_backtick(i: &str) -> IResult<&str, String> {
}
fn ident_brackets(i: &str) -> IResult<&str, String> {
let (i, v) = delimited(char(BRACKET_L), is_not(BRACKET_END), char(BRACKET_R))(i)?;
let (i, v) = delimited(char(BRACKET_L), is_not(BRACKET_END_NUL), char(BRACKET_R))(i)?;
Ok((i, String::from(v)))
}

View file

@ -26,9 +26,10 @@ use std::ops::DerefMut;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
/// Invariant: Keys never contain NUL bytes.
#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Object")]
pub struct Object(pub BTreeMap<String, Value>);
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
impl From<BTreeMap<String, Value>> for Object {
fn from(v: BTreeMap<String, Value>) -> Self {
@ -51,13 +52,13 @@ impl From<Option<Self>> for Object {
impl From<Operation> for Object {
fn from(v: Operation) -> Self {
Self(map! {
String::from("op") => match v.op {
Op::None => Value::from("none"),
Op::Add => Value::from("add"),
Op::Remove => Value::from("remove"),
Op::Replace => Value::from("replace"),
Op::Change => Value::from("change"),
},
String::from("op") => Value::from(match v.op {
Op::None => "none",
Op::Add => "add",
Op::Remove => "remove",
Op::Replace => "replace",
Op::Change => "change",
}),
String::from("path") => v.path.to_path().into(),
String::from("value") => v.value,
})
@ -167,6 +168,60 @@ impl Display for Object {
}
}
/// Serde `with` adapter for a `BTreeMap<String, Value>` whose keys must not
/// contain NUL bytes.
mod no_nul_bytes_in_keys {
    use serde::{
        de::{self, Visitor},
        ser::SerializeMap,
        Deserializer, Serializer,
    };
    use std::{collections::BTreeMap, fmt};

    use crate::sql::Value;

    /// Serializes `m` as a map, entry by entry.
    ///
    /// In debug builds, asserts that no key contains a NUL character; release
    /// builds serialize the keys unchecked.
    pub(crate) fn serialize<S>(
        m: &BTreeMap<String, Value>,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut s = serializer.serialize_map(Some(m.len()))?;
        for (k, v) in m {
            debug_assert!(!k.contains('\0'));
            s.serialize_entry(k, v)?;
        }
        s.end()
    }

    /// Deserializes a map, failing with a custom error as soon as a key that
    /// contains a NUL character is encountered.
    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<BTreeMap<String, Value>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct NoNulBytesInKeysVisitor;

        impl<'de> Visitor<'de> for NoNulBytesInKeysVisitor {
            type Value = BTreeMap<String, Value>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a map without any NUL bytes in its keys")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: de::MapAccess<'de>,
            {
                let mut ret = BTreeMap::new();
                // The entry types are spelled out so that `k` is known to be a
                // `String` before it is inspected below.
                while let Some((k, v)) = map.next_entry::<String, Value>()? {
                    // Enforce what `expecting` promises: untrusted input must
                    // not be able to introduce a key containing NUL.
                    if k.contains('\0') {
                        return Err(de::Error::custom("contained NUL byte"));
                    }
                    ret.insert(k, v);
                }
                Ok(ret)
            }
        }

        deserializer.deserialize_map(NoNulBytesInKeysVisitor)
    }
}
pub fn object(i: &str) -> IResult<&str, Object> {
let (i, _) = char('{')(i)?;
let (i, _) = mightbespace(i)?;
@ -194,11 +249,11 @@ fn key_none(i: &str) -> IResult<&str, &str> {
}
fn key_single(i: &str) -> IResult<&str, &str> {
delimited(char('\''), is_not("\'"), char('\''))(i)
delimited(char('\''), is_not("\'\0"), char('\''))(i)
}
fn key_double(i: &str) -> IResult<&str, &str> {
delimited(char('\"'), is_not("\""), char('\"'))(i)
delimited(char('\"'), is_not("\"\0"), char('\"'))(i)
}
#[cfg(test)]

View file

@ -6,6 +6,7 @@ use crate::sql::graph::{self, Graph};
use crate::sql::ident::{self, Ident};
use crate::sql::idiom::Idiom;
use crate::sql::number::{number, Number};
use crate::sql::strand::no_nul_bytes;
use crate::sql::value::{self, Value};
use nom::branch::alt;
use nom::bytes::complete::tag;
@ -27,7 +28,7 @@ pub enum Part {
Where(Value),
Graph(Graph),
Value(Value),
Method(String, Vec<Value>),
Method(#[serde(with = "no_nul_bytes")] String, Vec<Value>),
}
impl From<i32> for Part {

View file

@ -5,6 +5,7 @@ use crate::err::Error;
use crate::sql::error::IResult;
use crate::sql::id::{id, Id};
use crate::sql::ident::ident_raw;
use crate::sql::strand::no_nul_bytes;
use crate::sql::value::Value;
use nom::branch::alt;
use nom::character::complete::char;
@ -22,6 +23,7 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Range";
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Range")]
pub struct Range {
#[serde(with = "no_nul_bytes")]
pub tb: String,
pub beg: Bound<Id>,
pub end: Bound<Id>,

View file

@ -30,9 +30,13 @@ impl FromStr for Regex {
type Err = <regex::Regex as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.contains('\0') {
Err(regex::Error::Syntax("regex contained NUL byte".to_owned()))
} else {
regex::Regex::new(&s.replace("\\/", "/")).map(Self)
}
}
}
impl PartialEq for Regex {
fn eq(&self, other: &Self) -> bool {

View file

@ -1,5 +1,6 @@
use crate::sql::comment::{block, slash};
use crate::sql::error::IResult;
use crate::sql::strand::no_nul_bytes;
use nom::branch::alt;
use nom::bytes::complete::escaped;
use nom::bytes::complete::is_not;
@ -16,19 +17,19 @@ use std::ops::Deref;
use std::str;
const SINGLE: char = '\'';
const SINGLE_ESC: &str = r#"\'"#;
const SINGLE_ESC_NUL: &str = "'\\\0";
const DOUBLE: char = '"';
const DOUBLE_ESC: &str = r#"\""#;
const DOUBLE_ESC_NUL: &str = "\"\\\0";
const BACKTICK: char = '`';
const BACKTICK_ESC: &str = r#"\`"#;
const BACKTICK_ESC_NUL: &str = "`\\\0";
const OBJECT_BEG: char = '{';
const OBJECT_END: char = '}';
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
pub struct Script(pub String);
pub struct Script(#[serde(with = "no_nul_bytes")] pub String);
impl From<String> for Script {
fn from(s: String) -> Self {
@ -100,19 +101,19 @@ fn script_string(i: &str) -> IResult<&str, &str> {
},
|i| {
let (i, _) = char(SINGLE)(i)?;
let (i, v) = escaped(is_not(SINGLE_ESC), '\\', char(SINGLE))(i)?;
let (i, v) = escaped(is_not(SINGLE_ESC_NUL), '\\', char(SINGLE))(i)?;
let (i, _) = char(SINGLE)(i)?;
Ok((i, v))
},
|i| {
let (i, _) = char(DOUBLE)(i)?;
let (i, v) = escaped(is_not(DOUBLE_ESC), '\\', char(DOUBLE))(i)?;
let (i, v) = escaped(is_not(DOUBLE_ESC_NUL), '\\', char(DOUBLE))(i)?;
let (i, _) = char(DOUBLE)(i)?;
Ok((i, v))
},
|i| {
let (i, _) = char(BACKTICK)(i)?;
let (i, v) = escaped(is_not(BACKTICK_ESC), '\\', char(BACKTICK))(i)?;
let (i, v) = escaped(is_not(BACKTICK_ESC_NUL), '\\', char(BACKTICK))(i)?;
let (i, _) = char(BACKTICK)(i)?;
Ok((i, v))
},

View file

@ -16,26 +16,29 @@ use std::str;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Strand";
const SINGLE: char = '\'';
const SINGLE_ESC: &str = r#"\'"#;
const SINGLE_ESC_NUL: &str = "'\\\0";
const DOUBLE: char = '"';
const DOUBLE_ESC: &str = r#"\""#;
const DOUBLE_ESC_NUL: &str = "\"\\\0";
const LEADING_SURROGATES: RangeInclusive<u16> = 0xD800..=0xDBFF;
const TRAILING_SURROGATES: RangeInclusive<u16> = 0xDC00..=0xDFFF;
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
/// A string that doesn't contain NUL bytes.
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Strand")]
pub struct Strand(pub String);
pub struct Strand(#[serde(with = "no_nul_bytes")] pub String);
impl From<String> for Strand {
fn from(s: String) -> Self {
debug_assert!(!s.contains('\0'));
Strand(s)
}
}
/// Copies a string slice into a new `Strand`.
impl From<&str> for Strand {
    fn from(s: &str) -> Self {
        // Debug-only guard: the input is expected to be free of NUL characters.
        debug_assert!(!s.contains('\0'));
        let owned = String::from(s);
        Self::from(owned)
    }
}
@ -76,8 +79,9 @@ impl Display for Strand {
impl ops::Add for Strand {
type Output = Self;
fn add(self, other: Self) -> Self {
Strand::from(self.0 + &other.0)
fn add(mut self, other: Self) -> Self {
self.0.push_str(other.as_str());
self
}
}
@ -108,7 +112,7 @@ fn strand_blank(i: &str) -> IResult<&str, String> {
fn strand_single(i: &str) -> IResult<&str, String> {
let (i, _) = char(SINGLE)(i)?;
let (i, v) = escaped_transform(
is_not(SINGLE_ESC),
is_not(SINGLE_ESC_NUL),
'\\',
alt((
char_unicode,
@ -129,7 +133,7 @@ fn strand_single(i: &str) -> IResult<&str, String> {
fn strand_double(i: &str) -> IResult<&str, String> {
let (i, _) = char(DOUBLE)(i)?;
let (i, v) = escaped_transform(
is_not(DOUBLE_ESC),
is_not(DOUBLE_ESC_NUL),
'\\',
alt((
char_unicode,
@ -182,7 +186,7 @@ fn char_unicode_bare(i: &str) -> IResult<&str, char> {
Ok((i, v))
} else {
// We can convert this to char or error in the case of invalid Unicode character
let v = char::from_u32(v as u32).ok_or(Failure(Parser(i)))?;
let v = char::from_u32(v as u32).filter(|c| *c != 0 as char).ok_or(Failure(Parser(i)))?;
// Return the char
Ok((i, v))
}
@ -197,13 +201,69 @@ fn char_unicode_bracketed(i: &str) -> IResult<&str, char> {
// We can convert this to u32 as the max is 0xffffff
let v = u32::from_str_radix(v, 16).unwrap();
// We can convert this to char or error in the case of invalid Unicode character
let v = char::from_u32(v).ok_or(Failure(Parser(i)))?;
let v = char::from_u32(v).filter(|c| *c != 0 as char).ok_or(Failure(Parser(i)))?;
// Read the } character
let (i, _) = char('}')(i)?;
// Return the char
Ok((i, v))
}
/// Serde `with` adapter for `String` fields: `serde(with = "no_nul_bytes")`
/// (de)serializes the field as a string and refuses to deserialize any value
/// that contains a NUL character.
pub(crate) mod no_nul_bytes {
    use serde::{
        de::{self, Visitor},
        Deserializer, Serializer,
    };
    use std::fmt;

    /// Serializes `s` as a string.
    ///
    /// In debug builds, asserts that `s` contains no NUL character.
    pub(crate) fn serialize<S>(s: &String, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        debug_assert!(!s.contains('\0'));
        serializer.serialize_str(s)
    }

    /// Fails with a custom deserialization error when `value` contains NUL.
    fn reject_nul<E>(value: &str) -> Result<(), E>
    where
        E: de::Error,
    {
        if value.contains('\0') {
            return Err(de::Error::custom("contained NUL byte"));
        }
        Ok(())
    }

    /// Deserializes a string, returning an error if it contains a NUL
    /// character.
    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct NoNulBytesVisitor;

        impl<'de> Visitor<'de> for NoNulBytesVisitor {
            type Value = String;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string without any NUL bytes")
            }

            // Borrowed input: validate first, then copy into an owned string.
            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                reject_nul(value)?;
                Ok(value.to_owned())
            }

            // Owned input: validate and hand the same allocation back.
            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                reject_nul(&value)?;
                Ok(value)
            }
        }

        deserializer.deserialize_string(NoNulBytesVisitor)
    }
}
#[cfg(test)]
mod tests {
@ -269,9 +329,16 @@ mod tests {
assert_eq!(out, Strand::from("te\"st\n\tand\u{08}some\u{05d9}"));
}
// Checks that `strand` returns an error for every spelling of a NUL
// character exercised below.
#[test]
fn strand_nul_byte() {
// A literal NUL character between the quotes.
assert!(strand("'a\0b'").is_err());
// NUL written as a four-digit `\u0000` escape in the parsed text.
assert!(strand("'a\\u0000b'").is_err());
// NUL written as a bracketed `\u{0}` escape in the parsed text.
assert!(strand("'a\\u{0}b'").is_err());
}
#[test]
fn strand_fuzz_escape() {
for n in (0..=char::MAX as u32).step_by(101) {
for n in (1..=char::MAX as u32).step_by(101) {
if let Some(c) = char::from_u32(n) {
let expected = format!("a{c}b");

View file

@ -4,6 +4,7 @@ use crate::sql::escape::escape_ident;
use crate::sql::fmt::Fmt;
use crate::sql::id::Id;
use crate::sql::ident::{ident_raw, Ident};
use crate::sql::strand::no_nul_bytes;
use crate::sql::thing::Thing;
use nom::multi::separated_list1;
use serde::{Deserialize, Serialize};
@ -42,7 +43,7 @@ pub fn tables(i: &str) -> IResult<&str, Tables> {
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Table")]
pub struct Table(pub String);
pub struct Table(#[serde(with = "no_nul_bytes")] pub String);
impl From<String> for Table {
fn from(v: String) -> Self {

View file

@ -16,7 +16,7 @@ impl Value {
Some(p) => match (self, other) {
// Current path part is an object
(Value::Object(a), Value::Object(b)) => match p {
Part::Field(f) => match (a.get(f as &str), b.get(f as &str)) {
Part::Field(f) => match (a.get(f.as_str()), b.get(f.as_str())) {
(Some(a), Some(b)) => a.compare(b, path.next(), collate, numeric),
(Some(_), None) => Some(Ordering::Greater),
(None, Some(_)) => Some(Ordering::Less),

View file

@ -13,10 +13,10 @@ impl Value {
if let Part::Field(f) = p {
match path.len() {
1 => {
v.remove(f as &str);
v.remove(f.as_str());
}
_ => {
if let Some(v) = v.get_mut(f as &str) {
if let Some(v) = v.get_mut(f.as_str()) {
v.cut(path.next())
}
}

View file

@ -28,10 +28,10 @@ impl Value {
Value::Object(v) => match p {
Part::Field(f) => match path.len() {
1 => {
v.remove(f as &str);
v.remove(f.as_str());
Ok(())
}
_ => match v.get_mut(f as &str) {
_ => match v.get_mut(f.as_str()) {
Some(v) if v.is_some() => v.del(ctx, opt, txn, path.next()).await,
_ => Ok(()),
},

View file

@ -402,13 +402,13 @@ impl From<BigDecimal> for Value {
impl From<String> for Value {
fn from(v: String) -> Self {
Value::Strand(Strand::from(v))
Self::Strand(Strand::from(v))
}
}
impl From<&str> for Value {
fn from(v: &str) -> Self {
Value::Strand(Strand::from(v))
Self::Strand(Strand::from(v))
}
}