From 89f1de825a9a7c83f54133293e37c72e32846ee0 Mon Sep 17 00:00:00 2001 From: Micha de Vries Date: Tue, 13 Aug 2024 11:50:40 +0200 Subject: [PATCH] Fix `record` casting from string (#4496) Co-authored-by: Tobie Morgan Hitchcock --- core/src/sql/thing.rs | 6 ++++ core/src/sql/value/value.rs | 58 +++++++++++++++++++++++-------------- lib/tests/cast.rs | 40 +++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 21 deletions(-) diff --git a/core/src/sql/thing.rs b/core/src/sql/thing.rs index 3d3972a4..6bfadb2d 100644 --- a/core/src/sql/thing.rs +++ b/core/src/sql/thing.rs @@ -11,6 +11,8 @@ use serde::{Deserialize, Serialize}; use std::fmt; use std::str::FromStr; +use super::Table; + pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Thing"; #[revisioned(revision = 1)] @@ -90,6 +92,10 @@ impl Thing { pub fn to_raw(&self) -> String { self.to_string() } + + pub fn is_record_type(&self, types: &[Table]) -> bool { + types.is_empty() || types.iter().any(|tb| tb.0 == self.tb) + } } impl fmt::Display for Thing { diff --git a/core/src/sql/value/value.rs b/core/src/sql/value/value.rs index 62bd6ab9..456be10f 100644 --- a/core/src/sql/value/value.rs +++ b/core/src/sql/value/value.rs @@ -1028,7 +1028,7 @@ impl Value { /// Check if this Value is a Thing of a specific type pub fn is_record_type(&self, types: &[Table]) -> bool { match self { - Value::Thing(v) => types.is_empty() || types.iter().any(|tb| tb.0 == v.tb), + Value::Thing(v) => v.is_record_type(types), _ => false, } } @@ -1870,7 +1870,7 @@ impl Value { Value::Bool(v) => Ok(v), // Attempt to convert a string value Value::Strand(ref v) => match v.parse::() { - // The string can be represented as a Float + // The string can be parsed as a Float Ok(v) => Ok(v), // This string is not a float _ => Err(Error::ConvertTo { @@ -1895,7 +1895,7 @@ impl Value { Value::Number(Number::Float(v)) if v.fract() == 0.0 => Ok(Number::Int(v as i64)), // Attempt to convert a decimal number Value::Number(Number::Decimal(v)) if v.is_integer() => match v.try_into() { - // The Decimal can be represented as an Int + // The Decimal can be parsed as an Int Ok(v) => Ok(Number::Int(v)), // The Decimal is out of bounds _ => Err(Error::ConvertTo { @@ -1905,7 +1905,7 @@ impl Value { }, // Attempt to convert a string value Value::Strand(ref v) => match v.parse::() { - // The string can be represented as a Float + // The string can be parsed as a Float Ok(v) => Ok(Number::Int(v)), // This string is not a float _ => Err(Error::ConvertTo { @@ -1930,7 +1930,7 @@ impl Value { Value::Number(Number::Int(v)) => Ok(Number::Float(v as f64)), // Attempt to convert a decimal number Value::Number(Number::Decimal(v)) => match v.try_into() { - // The Decimal can be represented as a Float + // The Decimal can be parsed as a Float Ok(v) => Ok(Number::Float(v)), // The Decimal loses precision _ => Err(Error::ConvertTo { @@ -1940,7 +1940,7 @@ impl Value { }, // Attempt to convert a string value Value::Strand(ref v) => match v.parse::() { - // The string can be represented as a Float + // The string can be parsed as a Float Ok(v) => Ok(Number::Float(v)), // This string is not a float _ => Err(Error::ConvertTo { @@ -1975,7 +1975,7 @@ impl Value { }, // Attempt to convert a string value Value::Strand(ref v) => match Decimal::from_str(v) { - // The string can be represented as a Decimal + // The string can be parsed as a Decimal Ok(v) => Ok(Number::Decimal(v)), // This string is not a Decimal _ => Err(Error::ConvertTo { @@ -1998,7 +1998,7 @@ impl Value { Value::Number(v) => Ok(v), // Attempt to convert a string value Value::Strand(ref v) => match Number::from_str(v) { - // The string can be represented as a Float + // The string can be parsed as a number Ok(v) => Ok(v), // This string is not a float _ => Err(Error::ConvertTo { @@ -2073,11 +2073,11 @@ impl Value { // Uuids are allowed Value::Uuid(v) => Ok(v), // Attempt to parse a string - Value::Strand(ref v) => match Uuid::try_from(v.as_str()) { - // The string can be represented as a uuid + Value::Strand(ref v) => match Uuid::from_str(v) { + // The string can be parsed as a uuid Ok(v) => Ok(v), // This string is not a uuid - Err(_) => Err(Error::ConvertTo { + _ => Err(Error::ConvertTo { from: self, into: "uuid".into(), }), @@ -2109,11 +2109,11 @@ impl Value { // Datetimes are allowed Value::Datetime(v) => Ok(v), // Attempt to parse a string - Value::Strand(ref v) => match Datetime::try_from(v.as_str()) { - // The string can be represented as a datetime + Value::Strand(ref v) => match Datetime::from_str(v) { + // The string can be parsed as a datetime Ok(v) => Ok(v), // This string is not a datetime - Err(_) => Err(Error::ConvertTo { + _ => Err(Error::ConvertTo { from: self, into: "datetime".into(), }), @@ -2132,11 +2132,11 @@ impl Value { // Durations are allowed Value::Duration(v) => Ok(v), // Attempt to parse a string - Value::Strand(ref v) => match Duration::try_from(v.as_str()) { - // The string can be represented as a duration + Value::Strand(ref v) => match Duration::from_str(v) { + // The string can be parsed as a duration Ok(v) => Ok(v), // This string is not a duration - Err(_) => Err(Error::ConvertTo { + _ => Err(Error::ConvertTo { from: self, into: "duration".into(), }), @@ -2218,10 +2218,16 @@ impl Value { match self { // Records are allowed Value::Thing(v) => Ok(v), - Value::Strand(v) => Thing::try_from(v.as_str()).map_err(move |_| Error::ConvertTo { - from: Value::Strand(v), - into: "record".into(), - }), + // Attempt to parse a string + Value::Strand(ref v) => match Thing::from_str(v) { + // The string can be parsed as a record + Ok(v) => Ok(v), + // This string is not a record + _ => Err(Error::ConvertTo { + from: self, + into: "record".into(), + }), + }, // Anything else raises an error _ => Err(Error::ConvertTo { from: self, @@ -2248,6 +2254,16 @@ impl Value { match self { // Records are allowed if correct type Value::Thing(v) if self.is_record_type(val) => Ok(v), + // Attempt to parse a string + Value::Strand(ref v) => match Thing::from_str(v) { + // The string can be parsed as a record of this type + Ok(v) if v.is_record_type(val) => Ok(v), + // This string is not a record of this type + _ => Err(Error::ConvertTo { + from: self, + into: "record".into(), + }), + }, // Anything else raises an error _ => Err(Error::ConvertTo { from: self, diff --git a/lib/tests/cast.rs b/lib/tests/cast.rs index ee787ff1..02fd8ccc 100644 --- a/lib/tests/cast.rs +++ b/lib/tests/cast.rs @@ -22,3 +22,43 @@ async fn cast_string_to_record() -> Result<(), Error> { // Ok(()) } + +#[tokio::test] +async fn cast_to_record_table() -> Result<(), Error> { + let sql = r#" + > a:1; + > "a:1"; + > a:1; + > "a:1"; + "#; + let dbs = new_ds().await?; + let ses = Session::owner().with_ns("test").with_db("test"); + let res = &mut dbs.execute(sql, &ses, None).await?; + assert_eq!(res.len(), 4); + // + let tmp = res.remove(0).result?; + let val = Value::parse("a:1"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse("a:1"); + assert_eq!(tmp, val); + // + match res.remove(0).result { + Err(Error::ConvertTo { + from, + into, + }) if into == "record" && from == Value::parse("a:1") => (), + _ => panic!("Casting should have failed with error: Expected a record but cannot convert a:1 into a record"), + } + // + match res.remove(0).result { + Err(Error::ConvertTo { + from, + into, + }) if into == "record" && from == Value::parse("'a:1'") => (), + _ => panic!("Casting should have failed with error: Expected a record but cannot convert 'a:1' into a record"), + } + // + Ok(()) +}