From 74d8a36056a5e7a06b48f5542a1e4d73cd310527 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Mon, 24 May 2021 09:18:58 +0100 Subject: [PATCH] Improve string parsing and allow escaped characters --- src/sql/ident.rs | 28 +++++++++---- src/sql/model.rs | 4 +- src/sql/param.rs | 2 - src/sql/statements/define.rs | 28 ++++++------- src/sql/statements/info.rs | 4 +- src/sql/statements/remove.rs | 24 +++++------ src/sql/statements/yuse.rs | 8 ++-- src/sql/strand.rs | 81 +++++++++++++++++++++++++++++++++--- src/sql/table.rs | 8 ++++ src/sql/thing.rs | 18 ++++---- 10 files changed, 146 insertions(+), 59 deletions(-) diff --git a/src/sql/ident.rs b/src/sql/ident.rs index 28bbf191..b1f181d2 100644 --- a/src/sql/ident.rs +++ b/src/sql/ident.rs @@ -15,6 +15,14 @@ pub struct Ident { pub name: String, } +impl From<String> for Ident { + fn from(s: String) -> Self { + Ident { + name: s, + } + } +} + impl<'a> From<&'a str> for Ident { fn from(i: &str) -> Ident { Ident { @@ -34,20 +42,24 @@ pub fn ident(i: &str) -> IResult<&str, Ident> { Ok((i, Ident::from(v))) } -pub fn ident_raw(i: &str) -> IResult<&str, &str> { - alt((ident_default, ident_backtick, ident_brackets))(i) +pub fn ident_raw(i: &str) -> IResult<&str, String> { + let (i, v) = alt((ident_default, ident_backtick, ident_brackets))(i)?; + Ok((i, String::from(v))) } -fn ident_default(i: &str) -> IResult<&str, &str> { - take_while1(val_char)(i) +fn ident_default(i: &str) -> IResult<&str, String> { + let (i, v) = take_while1(val_char)(i)?; + Ok((i, String::from(v))) } -fn ident_backtick(i: &str) -> IResult<&str, &str> { - delimited(tag("`"), is_not("`"), tag("`"))(i) +fn ident_backtick(i: &str) -> IResult<&str, String> { + let (i, v) = delimited(tag("`"), is_not("`"), tag("`"))(i)?; + Ok((i, String::from(v))) } -fn ident_brackets(i: &str) -> IResult<&str, &str> { - delimited(tag("⟨"), is_not("⟩"), tag("⟩"))(i) +fn ident_brackets(i: &str) -> IResult<&str, String> { + let (i, v) = delimited(tag("⟨"), is_not("⟩"), 
tag("⟩"))(i)?; + Ok((i, String::from(v))) } #[cfg(test)] diff --git a/src/sql/model.rs b/src/sql/model.rs index e4db1a32..38f1b3ff 100644 --- a/src/sql/model.rs +++ b/src/sql/model.rs @@ -50,7 +50,7 @@ fn model_count(i: &str) -> IResult<&str, Model> { Ok(( i, Model { - table: String::from(t), + table: t, count: Some(c), range: None, }, @@ -68,7 +68,7 @@ fn model_range(i: &str) -> IResult<&str, Model> { Ok(( i, Model { - table: String::from(t), + table: t, count: None, range: Some((b, e)), }, diff --git a/src/sql/param.rs b/src/sql/param.rs index 901afa68..a21841db 100644 --- a/src/sql/param.rs +++ b/src/sql/param.rs @@ -3,11 +3,9 @@ use crate::dbs::Executor; use crate::dbs::Runtime; use crate::doc::Document; use crate::err::Error; -use crate::sql::common::val_char; use crate::sql::idiom::{idiom, Idiom}; use crate::sql::literal::Literal; use nom::bytes::complete::tag; -use nom::bytes::complete::take_while1; use nom::IResult; use serde::{Deserialize, Serialize}; use std::fmt; diff --git a/src/sql/statements/define.rs b/src/sql/statements/define.rs index 5ac0a512..5e5bec94 100644 --- a/src/sql/statements/define.rs +++ b/src/sql/statements/define.rs @@ -102,7 +102,7 @@ fn namespace(i: &str) -> IResult<&str, DefineNamespaceStatement> { Ok(( i, DefineNamespaceStatement { - name: String::from(name), + name, }, )) } @@ -131,7 +131,7 @@ fn database(i: &str) -> IResult<&str, DefineDatabaseStatement> { Ok(( i, DefineDatabaseStatement { - name: String::from(name), + name, }, )) } @@ -177,7 +177,7 @@ fn login(i: &str) -> IResult<&str, DefineLoginStatement> { Ok(( i, DefineLoginStatement { - name: String::from(name), + name, base, pass: match opts { DefineLoginOption::Password(ref v) => Some(v.to_owned()), @@ -206,7 +206,7 @@ fn login_pass(i: &str) -> IResult<&str, DefineLoginOption> { let (i, _) = tag_no_case("PASSWORD")(i)?; let (i, _) = shouldbespace(i)?; let (i, v) = strand_raw(i)?; - Ok((i, DefineLoginOption::Password(String::from(v)))) + Ok((i, 
DefineLoginOption::Password(v))) } fn login_hash(i: &str) -> IResult<&str, DefineLoginOption> { @@ -214,7 +214,7 @@ fn login_hash(i: &str) -> IResult<&str, DefineLoginOption> { let (i, _) = tag_no_case("PASSHASH")(i)?; let (i, _) = shouldbespace(i)?; let (i, v) = strand_raw(i)?; - Ok((i, DefineLoginOption::Passhash(String::from(v)))) + Ok((i, DefineLoginOption::Passhash(v))) } // -------------------------------------------------- @@ -260,10 +260,10 @@ fn token(i: &str) -> IResult<&str, DefineTokenStatement> { Ok(( i, DefineTokenStatement { - name: String::from(name), + name, base, kind, - code: String::from(code), + code, }, )) } @@ -314,7 +314,7 @@ fn scope(i: &str) -> IResult<&str, DefineScopeStatement> { Ok(( i, DefineScopeStatement { - name: String::from(name), + name, session: opts.iter().find_map(|x| match x { DefineScopeOption::Session(ref v) => Some(v.to_owned()), _ => None, @@ -423,7 +423,7 @@ fn table(i: &str) -> IResult<&str, DefineTableStatement> { Ok(( i, DefineTableStatement { - name: String::from(name), + name, drop: opts .iter() .find_map(|x| match x { @@ -540,8 +540,8 @@ fn event(i: &str) -> IResult<&str, DefineEventStatement> { Ok(( i, DefineEventStatement { - name: String::from(name), - what: String::from(what), + name, + what, when, then, }, @@ -599,7 +599,7 @@ fn field(i: &str) -> IResult<&str, DefineFieldStatement> { i, DefineFieldStatement { name, - what: String::from(what), + what, kind: opts.iter().find_map(|x| match x { DefineFieldOption::Kind(ref v) => Some(v.to_owned()), _ => None, @@ -722,8 +722,8 @@ fn index(i: &str) -> IResult<&str, DefineIndexStatement> { Ok(( i, DefineIndexStatement { - name: String::from(name), - what: String::from(what), + name, + what, cols, uniq: uniq.is_some(), }, diff --git a/src/sql/statements/info.rs b/src/sql/statements/info.rs index acd47126..deef6a5d 100644 --- a/src/sql/statements/info.rs +++ b/src/sql/statements/info.rs @@ -64,14 +64,14 @@ fn scope(i: &str) -> IResult<&str, InfoStatement> { let (i, _) = 
alt((tag_no_case("SCOPE"), tag_no_case("SC")))(i)?; let (i, _) = shouldbespace(i)?; let (i, scope) = ident_raw(i)?; - Ok((i, InfoStatement::Scope(String::from(scope)))) + Ok((i, InfoStatement::Scope(scope))) } fn table(i: &str) -> IResult<&str, InfoStatement> { let (i, _) = alt((tag_no_case("TABLE"), tag_no_case("TB")))(i)?; let (i, _) = shouldbespace(i)?; let (i, table) = ident_raw(i)?; - Ok((i, InfoStatement::Table(String::from(table)))) + Ok((i, InfoStatement::Table(table))) } #[cfg(test)] diff --git a/src/sql/statements/remove.rs b/src/sql/statements/remove.rs index 4edf4779..aa4c9197 100644 --- a/src/sql/statements/remove.rs +++ b/src/sql/statements/remove.rs @@ -92,7 +92,7 @@ fn namespace(i: &str) -> IResult<&str, RemoveNamespaceStatement> { Ok(( i, RemoveNamespaceStatement { - name: String::from(name), + name, }, )) } @@ -121,7 +121,7 @@ fn database(i: &str) -> IResult<&str, RemoveDatabaseStatement> { Ok(( i, RemoveDatabaseStatement { - name: String::from(name), + name, }, )) } @@ -155,7 +155,7 @@ fn login(i: &str) -> IResult<&str, RemoveLoginStatement> { Ok(( i, RemoveLoginStatement { - name: String::from(name), + name, base, }, )) @@ -190,7 +190,7 @@ fn token(i: &str) -> IResult<&str, RemoveTokenStatement> { Ok(( i, RemoveTokenStatement { - name: String::from(name), + name, base, }, )) @@ -220,7 +220,7 @@ fn scope(i: &str) -> IResult<&str, RemoveScopeStatement> { Ok(( i, RemoveScopeStatement { - name: String::from(name), + name, }, )) } @@ -249,7 +249,7 @@ fn table(i: &str) -> IResult<&str, RemoveTableStatement> { Ok(( i, RemoveTableStatement { - name: String::from(name), + name, }, )) } @@ -283,8 +283,8 @@ fn event(i: &str) -> IResult<&str, RemoveEventStatement> { Ok(( i, RemoveEventStatement { - name: String::from(name), - what: String::from(what), + name, + what, }, )) } @@ -318,8 +318,8 @@ fn field(i: &str) -> IResult<&str, RemoveFieldStatement> { Ok(( i, RemoveFieldStatement { - name: String::from(name), - what: String::from(what), + name, + what, }, 
)) } @@ -353,8 +353,8 @@ fn index(i: &str) -> IResult<&str, RemoveIndexStatement> { Ok(( i, RemoveIndexStatement { - name: String::from(name), - what: String::from(what), + name, + what, }, )) } diff --git a/src/sql/statements/yuse.rs b/src/sql/statements/yuse.rs index f4be4d1f..c324d508 100644 --- a/src/sql/statements/yuse.rs +++ b/src/sql/statements/yuse.rs @@ -61,8 +61,8 @@ fn both(i: &str) -> IResult<&str, UseStatement> { Ok(( i, UseStatement { - ns: Some(String::from(ns)), - db: Some(String::from(db)), + ns: Some(ns), + db: Some(db), }, )) } @@ -76,7 +76,7 @@ fn ns(i: &str) -> IResult<&str, UseStatement> { Ok(( i, UseStatement { - ns: Some(String::from(ns)), + ns: Some(ns), db: None, }, )) @@ -92,7 +92,7 @@ fn db(i: &str) -> IResult<&str, UseStatement> { i, UseStatement { ns: None, - db: Some(String::from(db)), + db: Some(db), }, )) } diff --git a/src/sql/strand.rs b/src/sql/strand.rs index 5f48f3e7..e97301f2 100644 --- a/src/sql/strand.rs +++ b/src/sql/strand.rs @@ -1,13 +1,20 @@ use nom::branch::alt; +use nom::bytes::complete::escaped; use nom::bytes::complete::is_not; use nom::bytes::complete::tag; -use nom::sequence::delimited; +use nom::character::complete::one_of; use nom::IResult; use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize}; use std::fmt; use std::str; +const SINGLE: &str = r#"'"#; +const SINGLE_ESC: &str = r#"\'"#; + +const DOUBLE: &str = r#"""#; +const DOUBLE_ESC: &str = r#"\""#; + #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)] pub struct Strand { pub value: String, @@ -55,14 +62,76 @@ pub fn strand(i: &str) -> IResult<&str, Strand> { Ok((i, Strand::from(v))) } -pub fn strand_raw(i: &str) -> IResult<&str, &str> { +pub fn strand_raw(i: &str) -> IResult<&str, String> { alt((strand_single, strand_double))(i) } -fn strand_single(i: &str) -> IResult<&str, &str> { - delimited(tag("\'"), is_not("\'"), tag("\'"))(i) +fn strand_single(i: &str) -> IResult<&str, String> { + let (i, _) = tag(SINGLE)(i)?; + let 
(i, v) = alt((escaped(is_not(SINGLE_ESC), '\\', one_of(SINGLE)), tag("")))(i)?; + let (i, _) = tag(SINGLE)(i)?; + Ok((i, String::from(v).replace(SINGLE_ESC, SINGLE))) } -fn strand_double(i: &str) -> IResult<&str, &str> { - delimited(tag("\""), is_not("\""), tag("\""))(i) +fn strand_double(i: &str) -> IResult<&str, String> { + let (i, _) = tag(DOUBLE)(i)?; + let (i, v) = alt((escaped(is_not(DOUBLE_ESC), '\\', one_of(DOUBLE)), tag("")))(i)?; + let (i, _) = tag(DOUBLE)(i)?; + Ok((i, String::from(v).replace(DOUBLE_ESC, DOUBLE))) +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn strand_empty() { + let sql = r#""""#; + let res = strand(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(r#""""#, format!("{}", out)); + assert_eq!(out, Strand::from("")); + } + + #[test] + fn strand_single() { + let sql = r#"'test'"#; + let res = strand(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(r#""test""#, format!("{}", out)); + assert_eq!(out, Strand::from("test")); + } + + #[test] + fn strand_double() { + let sql = r#""test""#; + let res = strand(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(r#""test""#, format!("{}", out)); + assert_eq!(out, Strand::from("test")); + } + + #[test] + fn strand_quoted_single() { + let sql = r#"'te\'st'"#; + let res = strand(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(r#""te'st""#, format!("{}", out)); + assert_eq!(out, Strand::from(r#"te'st"#)); + } + + #[test] + fn strand_quoted_double() { + let sql = r#""te\"st""#; + let res = strand(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(r#""te"st""#, format!("{}", out)); + assert_eq!(out, Strand::from(r#"te"st"#)); + } } diff --git a/src/sql/table.rs b/src/sql/table.rs index 28b7ba49..8654b333 100644 --- a/src/sql/table.rs +++ b/src/sql/table.rs @@ -27,6 +27,14 @@ pub struct Table { pub name: String, } +impl From<String> for Table { + fn from(s: String) -> Self { + Table { + name: 
s, + } + } +} + impl<'a> From<&'a str> for Table { fn from(t: &str) -> Table { Table { diff --git a/src/sql/thing.rs b/src/sql/thing.rs index c6a4fb8b..2e0074c7 100644 --- a/src/sql/thing.rs +++ b/src/sql/thing.rs @@ -9,13 +9,13 @@ use std::fmt; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)] pub struct Thing { - pub table: String, + pub tb: String, pub id: String, } impl fmt::Display for Thing { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let t = escape(&self.table, &val_char, "`"); + let t = escape(&self.tb, &val_char, "`"); let i = escape(&self.id, &val_char, "`"); write!(f, "{}:{}", t, i) } @@ -27,11 +27,11 @@ impl Serialize for Thing { S: serde::Serializer, { if serializer.is_human_readable() { - let output = format!("{}:{}", self.table, self.id); + let output = format!("{}:{}", self.tb, self.id); serializer.serialize_some(&output) } else { let mut val = serializer.serialize_struct("Thing", 2)?; - val.serialize_field("table", &self.table)?; + val.serialize_field("tb", &self.tb)?; val.serialize_field("id", &self.id)?; val.end() } @@ -45,8 +45,8 @@ pub fn thing(i: &str) -> IResult<&str, Thing> { Ok(( i, Thing { - table: String::from(t), - id: String::from(v), + tb: t, + id: v, }, )) } @@ -66,7 +66,7 @@ mod tests { assert_eq!( out, Thing { - table: String::from("test"), + tb: String::from("test"), id: String::from("id"), } ); @@ -82,7 +82,7 @@ mod tests { assert_eq!( out, Thing { - table: String::from("test"), + tb: String::from("test"), id: String::from("id"), } ); @@ -98,7 +98,7 @@ mod tests { assert_eq!( out, Thing { - table: String::from("test"), + tb: String::from("test"), id: String::from("id"), } );