diff --git a/lib/src/fnc/type.rs b/lib/src/fnc/type.rs index b29351d2..de6e844a 100644 --- a/lib/src/fnc/type.rs +++ b/lib/src/fnc/type.rs @@ -66,10 +66,10 @@ pub fn point((arg1, arg2): (Value, Option)) -> Result { } pub fn regex((arg,): (Value,)) -> Result { - match arg { - Value::Strand(v) => Ok(Value::Regex(v.as_str().into())), - _ => Ok(Value::None), - } + Ok(match arg { + Value::Strand(v) => v.parse().map(Value::Regex).unwrap_or(Value::None), + _ => Value::None, + }) } pub fn string((arg,): (Strand,)) -> Result { diff --git a/lib/src/sql/regex.rs b/lib/src/sql/regex.rs index 77f9bac8..ad228521 100644 --- a/lib/src/sql/regex.rs +++ b/lib/src/sql/regex.rs @@ -4,59 +4,120 @@ use nom::bytes::complete::escaped; use nom::bytes::complete::is_not; use nom::character::complete::anychar; use nom::character::complete::char; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::ops::Deref; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::fmt::{self, Display, Formatter}; +use std::hash::{Hash, Hasher}; use std::str; +use std::str::FromStr; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Regex"; -#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Deserialize, Hash)] -pub struct Regex(pub(super) String); +#[derive(Clone)] +pub struct Regex(pub(super) regex::Regex); -impl From<&str> for Regex { - fn from(r: &str) -> Self { - Self(r.replace("\\/", "/")) - } -} - -impl Deref for Regex { - type Target = String; - fn deref(&self) -> &Self::Target { +impl Regex { + // Deref would expose `regex::Regex::as_str` which wouldn't have the '/' delimiters. + pub fn regex(&self) -> ®ex::Regex { &self.0 } } -impl fmt::Display for Regex { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "/{}/", &self.0) +impl FromStr for Regex { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + regex::Regex::new(&s.replace("\\/", "/")).map(Self) } } -impl Regex { - pub fn regex(&self) -> Option { - regex::Regex::new(&self.0).ok() +impl PartialEq for Regex { + fn eq(&self, other: &Self) -> bool { + self.0.as_str().eq(other.0.as_str()) + } +} + +impl Eq for Regex {} + +impl Ord for Regex { + fn cmp(&self, other: &Self) -> Ordering { + self.0.as_str().cmp(other.0.as_str()) + } +} + +impl PartialOrd for Regex { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Hash for Regex { + fn hash(&self, state: &mut H) { + self.0.as_str().hash(state); + } +} + +impl Debug for Regex { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + Display::fmt(self, f) + } +} + +impl Display for Regex { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "/{}/", &self.0) } } impl Serialize for Regex { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer, + S: Serializer, { if is_internal_serialization() { - serializer.serialize_newtype_struct(TOKEN, &self.0) + serializer.serialize_newtype_struct(TOKEN, self.0.as_str()) } else { serializer.serialize_none() } } } +impl<'de> Deserialize<'de> for Regex { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct RegexVisitor; + + impl<'de> Visitor<'de> for RegexVisitor { + type Value = Regex; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a regex str") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Regex::from_str(value).map_err(|_| de::Error::custom("invalid regex")) + } + } + + deserializer.deserialize_str(RegexVisitor) + } +} + pub fn regex(i: &str) -> IResult<&str, Regex> { let (i, _) = char('/')(i)?; let (i, v) = escaped(is_not("\\/"), '\\', anychar)(i)?; let (i, _) = char('/')(i)?; - Ok((i, Regex::from(v))) + let regex = v.parse().map_err(|_| nom::Err::Error(crate::sql::Error::Parser(v)))?; + Ok((i, regex)) } #[cfg(test)] @@ -71,7 +132,7 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("/test/", format!("{}", out)); - assert_eq!(out, Regex::from("test")); + assert_eq!(out, "test".parse().unwrap()); } #[test] @@ -81,6 +142,6 @@ mod tests { assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!(r"/(?i)test/[a-z]+/\s\d\w{1}.*/", format!("{}", out)); - assert_eq!(out, Regex::from(r"(?i)test/[a-z]+/\s\d\w{1}.*")); + assert_eq!(out, r"(?i)test/[a-z]+/\s\d\w{1}.*".parse().unwrap()); } } diff --git a/lib/src/sql/value/serde/ser/value/mod.rs b/lib/src/sql/value/serde/ser/value/mod.rs index 5f30b605..e9aa1ac9 100644 --- a/lib/src/sql/value/serde/ser/value/mod.rs +++ b/lib/src/sql/value/serde/ser/value/mod.rs @@ -17,7 +17,6 @@ use crate::sql::Future; use crate::sql::Ident; use crate::sql::Idiom; use crate::sql::Param; -use crate::sql::Regex; use crate::sql::Strand; use crate::sql::Table; use crate::sql::Uuid; @@ -211,7 +210,7 @@ impl ser::Serializer for Serializer { value.serialize(ser::block::entry::vec::Serializer.wrap())?, ))))), sql::regex::TOKEN => { - Ok(Value::Regex(Regex(value.serialize(ser::string::Serializer.wrap())?))) + Ok(Value::Regex(value.serialize(ser::string::Serializer.wrap())?.parse().unwrap())) } sql::table::TOKEN => { Ok(Value::Table(Table(value.serialize(ser::string::Serializer.wrap())?))) @@ -743,7 +742,7 @@ mod tests { #[test] fn regex() { - let regex = Regex::default(); + let regex = "abc".parse().unwrap(); let value = to_value(®ex).unwrap(); let expected = Value::Regex(regex); assert_eq!(value, expected); diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 38f287d5..72cc588f 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -1179,22 +1179,13 @@ impl Value { Value::False => other.is_false(), Value::Thing(v) => match other { Value::Thing(w) => v == w, - Value::Regex(w) => match w.regex() { - Some(ref r) => r.is_match(v.to_string().as_str()), - None => false, - }, + Value::Regex(w) => w.regex().is_match(v.to_string().as_str()), _ => false, }, Value::Regex(v) => match other { Value::Regex(w) => v == w, - Value::Number(w) => match v.regex() { - Some(ref r) => r.is_match(w.to_string().as_str()), - None => false, - }, - Value::Strand(w) => match v.regex() { - Some(ref r) => r.is_match(w.as_str()), - None => false, - }, + Value::Number(w) => v.regex().is_match(w.to_string().as_str()), + Value::Strand(w) => v.regex().is_match(w.as_str()), _ => false, }, Value::Uuid(v) => match other { @@ -1211,19 +1202,13 @@ impl Value { }, Value::Strand(v) => match other { Value::Strand(w) => v == w, - Value::Regex(w) => match w.regex() { - Some(ref r) => r.is_match(v.as_str()), - None => false, - }, + Value::Regex(w) => w.regex().is_match(v.as_str()), _ => v == &other.to_strand(), }, Value::Number(v) => match other { Value::Number(w) => v == w, Value::Strand(_) => v == &other.to_number(), - Value::Regex(w) => match w.regex() { - Some(ref r) => r.is_match(v.to_string().as_str()), - None => false, - }, + Value::Regex(w) => w.regex().is_match(v.to_string().as_str()), _ => false, }, Value::Geometry(v) => match other { @@ -1886,7 +1871,7 @@ mod tests { assert_eq!(24, std::mem::size_of::()); assert_eq!(56, std::mem::size_of::()); assert_eq!(48, std::mem::size_of::()); - assert_eq!(24, std::mem::size_of::()); + assert_eq!(16, std::mem::size_of::()); assert_eq!(8, std::mem::size_of::>()); assert_eq!(8, std::mem::size_of::>()); assert_eq!(8, std::mem::size_of::>());