Ensure Record IDs with string-based integers, are output correctly

Closes #1327
This commit is contained in:
Tobie Morgan Hitchcock 2022-10-17 00:04:07 +01:00
parent d148ca9ab9
commit c63fc47bc0
5 changed files with 151 additions and 29 deletions

View file

@ -40,13 +40,13 @@ pub fn is_digit(chr: char) -> bool {
} }
#[inline] #[inline]
pub fn val_char(chr: char) -> bool { pub fn val_u8(chr: u8) -> bool {
chr.is_ascii_alphanumeric() || chr == '_' is_alphanumeric(chr) || chr == b'_'
} }
#[inline] #[inline]
pub fn val_u8(chr: u8) -> bool { pub fn val_char(chr: char) -> bool {
is_alphanumeric(chr) || chr == b'_' chr.is_ascii_alphanumeric() || chr == '_'
} }
pub fn take_u64(i: &str) -> IResult<&str, u64> { pub fn take_u64(i: &str) -> IResult<&str, u64> {

View file

@ -1,8 +1,10 @@
use crate::sql::common::val_u8; use crate::sql::common::val_u8;
use nom::character::is_digit;
use std::borrow::Cow; use std::borrow::Cow;
const BRACKET_L: char = '⟨'; const BRACKETL: char = '⟨';
const BRACKET_R: char = '⟩'; const BRACKETR: char = '⟩';
const BRACKET_ESC: &str = r#"\⟩"#;
const DOUBLE: char = '"'; const DOUBLE: char = '"';
const DOUBLE_ESC: &str = r#"\""#; const DOUBLE_ESC: &str = r#"\""#;
@ -16,36 +18,56 @@ pub fn escape_str(s: &str) -> String {
} }
#[inline] #[inline]
pub fn escape_id(s: &str) -> Cow<'_, str> { /// Escapes a key if necessary
for x in s.bytes() {
if !val_u8(x) {
return Cow::Owned(format!("{}{}{}", BRACKET_L, s, BRACKET_R));
}
}
Cow::Borrowed(s)
}
#[inline]
pub fn escape_key(s: &str) -> Cow<'_, str> { pub fn escape_key(s: &str) -> Cow<'_, str> {
escape_normal(s, DOUBLE, DOUBLE, DOUBLE_ESC)
}
#[inline]
/// Escapes an id if necessary
pub fn escape_rid(s: &str) -> Cow<'_, str> {
escape_numeric(s, BRACKETL, BRACKETR, BRACKET_ESC)
}
#[inline]
/// Escapes an ident if necessary
pub fn escape_ident(s: &str) -> Cow<'_, str> {
escape_numeric(s, BACKTICK, BACKTICK, BACKTICK_ESC)
}
#[inline]
pub fn escape_normal<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> {
// Loop over each character
for x in s.bytes() { for x in s.bytes() {
// Check if character is allowed
if !val_u8(x) { if !val_u8(x) {
return Cow::Owned(format!("{}{}{}", DOUBLE, s.replace(DOUBLE, DOUBLE_ESC), DOUBLE)); return Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r));
} }
} }
// Output the value
Cow::Borrowed(s) Cow::Borrowed(s)
} }
#[inline] #[inline]
pub fn escape_ident(s: &str) -> Cow<'_, str> { pub fn escape_numeric<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> {
// Presume this is numeric
let mut numeric = true;
// Loop over each character
for x in s.bytes() { for x in s.bytes() {
// Check if character is allowed
if !val_u8(x) { if !val_u8(x) {
return Cow::Owned(format!( return Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r));
"{}{}{}", }
BACKTICK, // Check if character is non-numeric
s.replace(BACKTICK, BACKTICK_ESC), if !is_digit(x) {
BACKTICK numeric = false;
));
} }
} }
Cow::Borrowed(s) // Output the id value
match numeric {
// This is numeric so escape it
true => Cow::Owned(format!("{}{}{}", l, s.replace(r, e), r)),
// No need to escape the value
_ => Cow::Borrowed(s),
}
} }

View file

@ -1,7 +1,7 @@
use crate::cnf::ID_CHARS; use crate::cnf::ID_CHARS;
use crate::sql::array::{array, Array}; use crate::sql::array::{array, Array};
use crate::sql::error::IResult; use crate::sql::error::IResult;
use crate::sql::escape::escape_id; use crate::sql::escape::escape_rid;
use crate::sql::ident::ident_raw; use crate::sql::ident::ident_raw;
use crate::sql::number::integer; use crate::sql::number::integer;
use crate::sql::object::{object, Object}; use crate::sql::object::{object, Object};
@ -100,7 +100,7 @@ impl Display for Id {
fn fmt(&self, f: &mut Formatter) -> fmt::Result { fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self { match self {
Self::Number(v) => Display::fmt(v, f), Self::Number(v) => Display::fmt(v, f),
Self::String(v) => Display::fmt(&escape_id(v), f), Self::String(v) => Display::fmt(&escape_rid(v), f),
Self::Object(v) => Display::fmt(v, f), Self::Object(v) => Display::fmt(v, f),
Self::Array(v) => Display::fmt(v, f), Self::Array(v) => Display::fmt(v, f),
} }
@ -128,6 +128,7 @@ mod tests {
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!(Id::from(1), out); assert_eq!(Id::from(1), out);
assert_eq!("1", format!("{}", out));
} }
#[test] #[test]
@ -137,6 +138,7 @@ mod tests {
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!(Id::from(100), out); assert_eq!(Id::from(100), out);
assert_eq!("100", format!("{}", out));
} }
#[test] #[test]
@ -146,6 +148,17 @@ mod tests {
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!(Id::from("test"), out); assert_eq!(Id::from("test"), out);
assert_eq!("test", format!("{}", out));
}
#[test]
fn id_numeric() {
let sql = "⟨100⟩";
let res = id(sql);
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!(Id::from("100"), out);
assert_eq!("⟨100⟩", format!("{}", out));
} }
#[test] #[test]
@ -155,5 +168,6 @@ mod tests {
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!(Id::from("100test"), out); assert_eq!(Id::from("100test"), out);
assert_eq!("100test", format!("{}", out));
} }
} }

View file

@ -3,7 +3,7 @@ use crate::dbs::Options;
use crate::dbs::Transaction; use crate::dbs::Transaction;
use crate::err::Error; use crate::err::Error;
use crate::sql::error::IResult; use crate::sql::error::IResult;
use crate::sql::escape::escape_id; use crate::sql::escape::escape_rid;
use crate::sql::id::{id, Id}; use crate::sql::id::{id, Id};
use crate::sql::ident::ident_raw; use crate::sql::ident::ident_raw;
use crate::sql::serde::is_internal_serialization; use crate::sql::serde::is_internal_serialization;
@ -51,7 +51,7 @@ impl Thing {
impl fmt::Display for Thing { impl fmt::Display for Thing {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}:{}", escape_id(&self.tb), self.id) write!(f, "{}:{}", escape_rid(&self.tb), self.id)
} }
} }

86
lib/tests/escape.rs Normal file
View file

@ -0,0 +1,86 @@
mod parse;
use parse::Parse;
use surrealdb::sql::Value;
use surrealdb::Datastore;
use surrealdb::Error;
use surrealdb::Session;
#[tokio::test]
async fn complex_string() -> Result<(), Error> {
let sql = r#"
CREATE person:100 SET test = 'One';
CREATE person:00100;
CREATE 'person:100';
CREATE "person:100";
CREATE person:100 SET test = 'Two';
CREATE person:`100`;
SELECT * FROM person;
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None, false).await?;
assert_eq!(res.len(), 7);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: person:100,
test: 'One'
}
]",
);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database record `person:100` already exists"#
));
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database record `person:100` already exists"#
));
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database record `person:100` already exists"#
));
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: person:100,
test: 'Two'
}
]",
);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database record `person:⟨100⟩` already exists"#
));
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: person:100,
test: 'One'
},
{
id: person:100,
test: 'Two'
}
]",
);
assert_eq!(tmp, val);
//
Ok(())
}