diff --git a/core/src/iam/verify.rs b/core/src/iam/verify.rs index 8fde54ea..42096109 100644 --- a/core/src/iam/verify.rs +++ b/core/src/iam/verify.rs @@ -1467,7 +1467,7 @@ mod tests { // Test with generic user identifier // { - let resource_id = "user:⟨2k9qnabxuxh8k4d5gfto⟩".to_string(); + let resource_id = "user:2k9qnabxuxh8k4d5gfto".to_string(); // Prepare the claims object let mut claims = claims.clone(); claims.id = Some(resource_id.clone()); diff --git a/core/src/sql/escape.rs b/core/src/sql/escape.rs index f1df3b50..1badad4b 100644 --- a/core/src/sql/escape.rs +++ b/core/src/sql/escape.rs @@ -65,7 +65,7 @@ pub fn escape_key(s: &str) -> Cow<'_, str> { #[inline] /// Escapes an id if necessary pub fn escape_rid(s: &str) -> Cow<'_, str> { - escape_numeric(s, BRACKETL, BRACKETR, BRACKET_ESC) + escape_full_numeric(s, BRACKETL, BRACKETR, BRACKET_ESC) } #[inline] @@ -74,7 +74,7 @@ pub fn escape_ident(s: &str) -> Cow<'_, str> { if let Some(x) = escape_reserved_keyword(s) { return Cow::Owned(x); } - escape_numeric(s, BACKTICK, BACKTICK, BACKTICK_ESC) + escape_starts_numeric(s, BACKTICK, BACKTICK, BACKTICK_ESC) } #[inline] @@ -95,7 +95,29 @@ pub fn escape_reserved_keyword(s: &str) -> Option { } #[inline] -pub fn escape_numeric<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> { +pub fn escape_full_numeric<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> { + let mut numeric = true; + // Loop over each character + for x in s.bytes() { + // Check if character is allowed + if !(x.is_ascii_alphanumeric() || x == b'_') { + return Cow::Owned(format!("{l}{}{r}", s.replace(r, e))); + } + // For every character, we need to check if it is a digit until we encounter a non-digit + if numeric && !x.is_ascii_digit() { + numeric = false; + } + } + + // If all characters are digits, then we need to escape the string + if numeric { + return Cow::Owned(format!("{l}{}{r}", s.replace(r, e))); + } + Cow::Borrowed(s) +} + +#[inline] +pub fn escape_starts_numeric<'a>(s: &'a str, l: char, r: char, e: &str) -> Cow<'a, str> { // Loop over each character for (idx, x) in s.bytes().enumerate() { // the first character is not allowed to be a digit.