From c63fc47bc0028de96b8f805a223ba0c3fc091a46 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Mon, 17 Oct 2022 00:04:07 +0100 Subject: [PATCH] Ensure Record IDs with string-based integers, are output correctly Closes #1327 --- lib/src/sql/common.rs | 8 ++-- lib/src/sql/escape.rs | 64 +++++++++++++++++++++----------- lib/src/sql/id.rs | 18 ++++++++- lib/src/sql/thing.rs | 4 +- lib/tests/escape.rs | 86 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 29 deletions(-) create mode 100644 lib/tests/escape.rs diff --git a/lib/src/sql/common.rs b/lib/src/sql/common.rs index fd233b3d..fb496216 100644 --- a/lib/src/sql/common.rs +++ b/lib/src/sql/common.rs @@ -40,13 +40,13 @@ pub fn is_digit(chr: char) -> bool { } #[inline] -pub fn val_char(chr: char) -> bool { - chr.is_ascii_alphanumeric() || chr == '_' +pub fn val_u8(chr: u8) -> bool { + is_alphanumeric(chr) || chr == b'_' } #[inline] -pub fn val_u8(chr: u8) -> bool { - is_alphanumeric(chr) || chr == b'_' +pub fn val_char(chr: char) -> bool { + chr.is_ascii_alphanumeric() || chr == '_' } pub fn take_u64(i: &str) -> IResult<&str, u64> { diff --git a/lib/src/sql/escape.rs b/lib/src/sql/escape.rs index 5b276a2d..87c8506e 100644 --- a/lib/src/sql/escape.rs +++ b/lib/src/sql/escape.rs @@ -1,8 +1,10 @@ use crate::sql::common::val_u8; +use nom::character::is_digit; use std::borrow::Cow; -const BRACKET_L: char = '⟨'; -const BRACKET_R: char = '⟩'; +const BRACKETL: char = '⟨'; +const BRACKETR: char = '⟩'; +const BRACKET_ESC: &str = r#"\⟩"#; const DOUBLE: char = '"'; const DOUBLE_ESC: &str = r#"\""#; @@ -16,36 +18,56 @@ pub fn escape_str(s: &str) -> String { } #[inline] -pub fn escape_id(s: &str) -> Cow<'_, str> { - for x in s.bytes() { - if !val_u8(x) { - return Cow::Owned(format!("{}{}{}", BRACKET_L, s, BRACKET_R)); - } - } - Cow::Borrowed(s) -} - -#[inline] +/// Escapes a key if necessary pub fn escape_key(s: &str) -> Cow<'_, str> { + escape_normal(s, DOUBLE, DOUBLE, DOUBLE_ESC) +} + +#[inline] +/// Escapes an id if necessary +pub fn escape_rid(s: &str) -> Cow<'_, str> { + escape_numeric(s, BRACKETL, BRACKETR, BRACKET_ESC) +} + +#[inline] +/// Escapes an ident if necessary +pub fn escape_ident(s: &str) -> Cow<'_, str> { + escape_numeric(s, BACKTICK, BACKTICK, BACKTICK_ESC) +} + +#[inline] +pub fn escape_normal<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> { + // Loop over each character for x in s.bytes() { + // Check if character is allowed if !val_u8(x) { - return Cow::Owned(format!("{}{}{}", DOUBLE, s.replace(DOUBLE, DOUBLE_ESC), DOUBLE)); + return Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r)); } } + // Output the value Cow::Borrowed(s) } #[inline] -pub fn escape_ident(s: &str) -> Cow<'_, str> { +pub fn escape_numeric<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> { + // Presume this is numeric + let mut numeric = true; + // Loop over each character for x in s.bytes() { + // Check if character is allowed if !val_u8(x) { - return Cow::Owned(format!( - "{}{}{}", - BACKTICK, - s.replace(BACKTICK, BACKTICK_ESC), - BACKTICK - )); + return Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r)); + } + // Check if character is non-numeric + if !is_digit(x) { + numeric = false; } } - Cow::Borrowed(s) + // Output the id value + match numeric { + // This is numeric so escape it + true => Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r)), + // No need to escape the value + _ => Cow::Borrowed(s), + } } diff --git a/lib/src/sql/id.rs b/lib/src/sql/id.rs index 323750ff..852e71f1 100644 --- a/lib/src/sql/id.rs +++ b/lib/src/sql/id.rs @@ -1,7 +1,7 @@ use crate::cnf::ID_CHARS; use crate::sql::array::{array, Array}; use crate::sql::error::IResult; -use crate::sql::escape::escape_id; +use crate::sql::escape::escape_rid; use crate::sql::ident::ident_raw; use crate::sql::number::integer; use crate::sql::object::{object, Object}; @@ -100,7 +100,7 @@ impl Display for Id { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { Self::Number(v) => Display::fmt(v, f), - Self::String(v) => Display::fmt(&escape_id(v), f), + Self::String(v) => Display::fmt(&escape_rid(v), f), Self::Object(v) => Display::fmt(v, f), Self::Array(v) => Display::fmt(v, f), } @@ -128,6 +128,7 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!(Id::from(1), out); + assert_eq!("1", format!("{}", out)); } #[test] @@ -137,6 +138,7 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!(Id::from(100), out); + assert_eq!("100", format!("{}", out)); } #[test] @@ -146,6 +148,17 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!(Id::from("test"), out); + assert_eq!("test", format!("{}", out)); + } + + #[test] + fn id_numeric() { + let sql = "⟨100⟩"; + let res = id(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(Id::from("100"), out); + assert_eq!("⟨100⟩", format!("{}", out)); } #[test] @@ -155,5 +168,6 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!(Id::from("100test"), out); + assert_eq!("100test", format!("{}", out)); } } diff --git a/lib/src/sql/thing.rs b/lib/src/sql/thing.rs index 88690fd3..111cf5e4 100644 --- a/lib/src/sql/thing.rs +++ b/lib/src/sql/thing.rs @@ -3,7 +3,7 @@ use crate::dbs::Options; use crate::dbs::Transaction; use crate::err::Error; use crate::sql::error::IResult; -use crate::sql::escape::escape_id; +use crate::sql::escape::escape_rid; use crate::sql::id::{id, Id}; use crate::sql::ident::ident_raw; use crate::sql::serde::is_internal_serialization; @@ -51,7 +51,7 @@ impl Thing { impl fmt::Display for Thing { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}:{}", escape_id(&self.tb), self.id) + write!(f, "{}:{}", escape_rid(&self.tb), self.id) } } diff --git a/lib/tests/escape.rs b/lib/tests/escape.rs new file mode 100644 index 00000000..77762e04 --- /dev/null +++ b/lib/tests/escape.rs @@ -0,0 +1,86 @@ +mod parse; +use parse::Parse; +use surrealdb::sql::Value; +use surrealdb::Datastore; +use surrealdb::Error; +use surrealdb::Session; + +#[tokio::test] +async fn complex_string() -> Result<(), Error> { + let sql = r#" + CREATE person:100 SET test = 'One'; + CREATE person:00100; + CREATE 'person:100'; + CREATE "person:100"; + CREATE person:⟨100⟩ SET test = 'Two'; + CREATE person:`100`; + SELECT * FROM person; + "#; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 7); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + id: person:100, + test: 'One' + } + ]", + ); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(matches!( + tmp.err(), + Some(e) if e.to_string() == r#"Database record `person:100` already exists"# + )); + // + let tmp = res.remove(0).result; + assert!(matches!( + tmp.err(), + Some(e) if e.to_string() == r#"Database record `person:100` already exists"# + )); + // + let tmp = res.remove(0).result; + assert!(matches!( + tmp.err(), + Some(e) if e.to_string() == r#"Database record `person:100` already exists"# + )); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + id: person:⟨100⟩, + test: 'Two' + } + ]", + ); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(matches!( + tmp.err(), + Some(e) if e.to_string() == r#"Database record `person:⟨100⟩` already exists"# + )); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + id: person:100, + test: 'One' + }, + { + id: person:⟨100⟩, + test: 'Two' + } + ]", + ); + assert_eq!(tmp, val); + // + Ok(()) +}