diff --git a/lib/src/doc/table.rs b/lib/src/doc/table.rs index df437137..5ce2f687 100644 --- a/lib/src/doc/table.rs +++ b/lib/src/doc/table.rs @@ -321,7 +321,7 @@ impl<'a> Document<'a> { Operator::Equal, Value::Subquery(Box::new(Subquery::Ifelse(IfelseStatement { exprs: vec![( - Value::Expression(Box::new(Expression { + Value::Expression(Box::new(Expression::Binary { l: Value::Idiom(key.clone()), o: Operator::MoreThan, r: val.clone(), @@ -341,7 +341,7 @@ impl<'a> Document<'a> { Operator::Equal, Value::Subquery(Box::new(Subquery::Ifelse(IfelseStatement { exprs: vec![( - Value::Expression(Box::new(Expression { + Value::Expression(Box::new(Expression::Binary { l: Value::Idiom(key.clone()), o: Operator::LessThan, r: val.clone(), @@ -363,13 +363,13 @@ impl<'a> Document<'a> { ops.push(( key.clone(), Operator::Equal, - Value::Expression(Box::new(Expression { + Value::Expression(Box::new(Expression::Binary { l: Value::Subquery(Box::new(Subquery::Value(Value::Expression(Box::new( - Expression { + Expression::Binary { l: Value::Subquery(Box::new(Subquery::Value(Value::Expression(Box::new( - Expression { + Expression::Binary { l: Value::Subquery(Box::new(Subquery::Value(Value::Expression( - Box::new(Expression { + Box::new(Expression::Binary { l: Value::Idiom(key), o: Operator::Nco, r: Value::Number(Number::Int(0)), @@ -377,7 +377,7 @@ impl<'a> Document<'a> { )))), o: Operator::Mul, r: Value::Subquery(Box::new(Subquery::Value(Value::Expression( - Box::new(Expression { + Box::new(Expression::Binary { l: Value::Idiom(key_c.clone()), o: Operator::Nco, r: Value::Number(Number::Int(0)), @@ -395,9 +395,9 @@ impl<'a> Document<'a> { ))))), o: Operator::Div, r: Value::Subquery(Box::new(Subquery::Value(Value::Expression(Box::new( - Expression { + Expression::Binary { l: Value::Subquery(Box::new(Subquery::Value(Value::Expression(Box::new( - Expression { + Expression::Binary { l: Value::Idiom(key_c.clone()), o: Operator::Nco, r: Value::Number(Number::Int(0)), diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index 4de11613..7a1ed90a 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -430,6 +430,10 @@ pub enum Error { #[error("Cannot raise the value '{0}' with '{1}'")] TryPow(String, String), + /// Cannot perform negation + #[error("Cannot negate the value '{0}'")] + TryNeg(String), + /// It's is not possible to convert between the two types #[error("Cannot convert from '{0}' to '{1}'")] TryFrom(String, &'static str), diff --git a/lib/src/fnc/operate.rs b/lib/src/fnc/operate.rs index 1e159c29..f7ac7f71 100644 --- a/lib/src/fnc/operate.rs +++ b/lib/src/fnc/operate.rs @@ -3,11 +3,20 @@ use crate::err::Error; use crate::sql::value::TryAdd; use crate::sql::value::TryDiv; use crate::sql::value::TryMul; +use crate::sql::value::TryNeg; use crate::sql::value::TryPow; use crate::sql::value::TrySub; use crate::sql::value::Value; use crate::sql::Expression; +pub fn neg(a: Value) -> Result { + a.try_neg() +} + +pub fn not(a: Value) -> Result { + super::not::not((a,)) +} + pub fn or(a: Value, b: Value) -> Result { Ok(match a.is_truthy() { true => a, diff --git a/lib/src/idx/planner/tree.rs b/lib/src/idx/planner/tree.rs index 05ecd470..4e005092 100644 --- a/lib/src/idx/planner/tree.rs +++ b/lib/src/idx/planner/tree.rs @@ -87,27 +87,42 @@ impl<'a> TreeBuilder<'a> { } async fn eval_expression(&mut self, e: &Expression) -> Result { - let left = self.eval_value(&e.l).await?; - let right = self.eval_value(&e.r).await?; - let mut index_option = None; - if let Some(ix) = left.is_indexed_field() { - if let Some(io) = IndexOption::found(ix, &e.o, &right, e) { - index_option = Some(io.clone()); - self.add_index(e, io); + match e { + Expression::Unary { + .. + } => { + return Err(Error::FeatureNotYetImplemented { + feature: "unary expressions in index", + }); + } + Expression::Binary { + l, + o, + r, + } => { + let left = self.eval_value(l).await?; + let right = self.eval_value(r).await?; + let mut index_option = None; + if let Some(ix) = left.is_indexed_field() { + if let Some(io) = IndexOption::found(ix, o, &right, e) { + index_option = Some(io.clone()); + self.add_index(e, io); + } + } + if let Some(ix) = right.is_indexed_field() { + if let Some(io) = IndexOption::found(ix, o, &left, e) { + index_option = Some(io.clone()); + self.add_index(e, io); + } + } + Ok(Node::Expression { + index_option, + left: Box::new(left), + right: Box::new(right), + operator: o.to_owned(), + }) } } - if let Some(ix) = right.is_indexed_field() { - if let Some(io) = IndexOption::found(ix, &e.o, &left, e) { - index_option = Some(io.clone()); - self.add_index(e, io); - } - } - Ok(Node::Expression { - index_option, - left: Box::new(left), - right: Box::new(right), - operator: e.o.to_owned(), - }) } fn add_index(&mut self, e: &Expression, io: IndexOption) { diff --git a/lib/src/sql/cast.rs b/lib/src/sql/cast.rs index f16442b6..f19931a4 100644 --- a/lib/src/sql/cast.rs +++ b/lib/src/sql/cast.rs @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::fmt; +pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Cast"; + #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] #[serde(rename = "$surrealdb::private::sql::Cast")] pub struct Cast(pub Kind, pub Value); diff --git a/lib/src/sql/ending.rs b/lib/src/sql/ending.rs index 79953d7c..b7ee1539 100644 --- a/lib/src/sql/ending.rs +++ b/lib/src/sql/ending.rs @@ -1,7 +1,7 @@ use crate::sql::comment::comment; use crate::sql::comment::{mightbespace, shouldbespace}; use crate::sql::error::IResult; -use crate::sql::operator::{assigner, operator}; +use crate::sql::operator::{assigner, binary}; use nom::branch::alt; use nom::bytes::complete::tag; use nom::bytes::complete::tag_no_case; @@ -15,7 +15,7 @@ use nom::sequence::preceded; pub fn number(i: &str) -> IResult<&str, ()> { peek(alt(( map(multispace1, |_| ()), - map(operator, |_| ()), + map(binary, |_| ()), map(assigner, |_| ()), map(comment, |_| ()), map(char(')'), |_| ()), @@ -33,7 +33,7 @@ pub fn number(i: &str) -> IResult<&str, ()> { pub fn ident(i: &str) -> IResult<&str, ()> { peek(alt(( map(multispace1, |_| ()), - map(operator, |_| ()), + map(binary, |_| ()), map(assigner, |_| ()), map(comment, |_| ()), map(char(')'), |_| ()), @@ -51,7 +51,7 @@ pub fn ident(i: &str) -> IResult<&str, ()> { pub fn duration(i: &str) -> IResult<&str, ()> { peek(alt(( map(multispace1, |_| ()), - map(operator, |_| ()), + map(binary, |_| ()), map(assigner, |_| ()), map(comment, |_| ()), map(char(')'), |_| ()), diff --git a/lib/src/sql/expression.rs b/lib/src/sql/expression.rs index cc3f5f45..6d8b3f69 100644 --- a/lib/src/sql/expression.rs +++ b/lib/src/sql/expression.rs @@ -2,8 +2,9 @@ use crate::ctx::Context; use crate::dbs::Options; use crate::err::Error; use crate::fnc; +use crate::sql::comment::mightbespace; use crate::sql::error::IResult; -use crate::sql::operator::{operator, Operator}; +use crate::sql::operator::{self, Operator}; use crate::sql::value::{single, value, Value}; use serde::{Deserialize, Serialize}; use std::fmt; @@ -11,17 +12,24 @@ use std::str; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Expression"; +/// Binary expressions. #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] #[serde(rename = "$surrealdb::private::sql::Expression")] -pub struct Expression { - pub l: Value, - pub o: Operator, - pub r: Value, +pub enum Expression { + Unary { + o: Operator, + v: Value, + }, + Binary { + l: Value, + o: Operator, + r: Value, + }, } impl Default for Expression { fn default() -> Expression { - Expression { + Expression::Binary { l: Value::Null, o: Operator::default(), r: Value::Null, @@ -30,9 +38,9 @@ impl Default for Expression { } impl Expression { - /// Create a new expression + /// Create a new binary expression pub(crate) fn new(l: Value, o: Operator, r: Value) -> Self { - Self { + Self::Binary { l, o, r, @@ -40,29 +48,67 @@ impl Expression { } /// Augment an existing expression fn augment(mut self, l: Value, o: Operator) -> Self { - if o.precedence() >= self.o.precedence() { - match self.l { + match &mut self { + Self::Binary { + l: left, + o: op, + .. + } if o.precedence() >= op.precedence() => match left { Value::Expression(x) => { - self.l = x.augment(l, o).into(); + *x.as_mut() = std::mem::take(x).augment(l, o); self } _ => { - self.l = Self::new(l, o, self.l).into(); + *left = Self::new(l, o, std::mem::take(left)).into(); self } + }, + e => { + let r = Value::from(std::mem::take(e)); + Self::new(l, o, r) } - } else { - let r = Value::from(self); - Self::new(l, o, r) } } } impl Expression { + pub(crate) fn writeable(&self) -> bool { + match self { + Self::Unary { + v, + .. + } => v.writeable(), + Self::Binary { + l, + r, + .. + } => l.writeable() || r.writeable(), + } + } + /// Process this type returning a computed simple Value pub(crate) async fn compute(&self, ctx: &Context<'_>, opt: &Options) -> Result { - let l = self.l.compute(ctx, opt).await?; - match self.o { + let (l, o, r) = match self { + Self::Unary { + o, + v, + } => { + let operand = v.compute(ctx, opt).await?; + return match o { + Operator::Neg => fnc::operate::neg(operand), + Operator::Not => fnc::operate::not(operand), + op => unreachable!("{op:?} is not a unary op"), + }; + } + Self::Binary { + l, + o, + r, + } => (l, o, r), + }; + + let l = l.compute(ctx, opt).await?; + match o { Operator::Or => { if let true = l.is_truthy() { return Ok(l); @@ -85,8 +131,8 @@ impl Expression { } _ => {} // Continue } - let r = self.r.compute(ctx, opt).await?; - match self.o { + let r = r.compute(ctx, opt).await?; + match o { Operator::Or => fnc::operate::or(l, r), Operator::And => fnc::operate::and(l, r), Operator::Tco => fnc::operate::tco(l, r), @@ -129,13 +175,36 @@ impl Expression { impl fmt::Display for Expression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {} {}", self.l, self.o, self.r) + match self { + Self::Unary { + o, + v, + } => write!(f, "{o}{v}"), + Self::Binary { + l, + o, + r, + } => write!(f, "{l} {o} {r}"), + } } } -pub fn expression(i: &str) -> IResult<&str, Expression> { +pub fn unary(i: &str) -> IResult<&str, Expression> { + let (i, o) = operator::unary(i)?; + let (i, _) = mightbespace(i)?; + let (i, v) = single(i)?; + Ok(( + i, + Expression::Unary { + o, + v, + }, + )) +} + +pub fn binary(i: &str) -> IResult<&str, Expression> { let (i, l) = single(i)?; - let (i, o) = operator(i)?; + let (i, o) = operator::binary(i)?; let (i, r) = value(i)?; let v = match r { Value::Expression(r) => r.augment(l, o), @@ -152,7 +221,7 @@ mod tests { #[test] fn expression_statement() { let sql = "true AND false"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("true AND false", format!("{}", out)); @@ -161,7 +230,7 @@ mod tests { #[test] fn expression_left_opened() { let sql = "3 * 3 * 3 = 27"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("3 * 3 * 3 = 27", format!("{}", out)); @@ -170,7 +239,7 @@ mod tests { #[test] fn expression_left_closed() { let sql = "(3 * 3 * 3) = 27"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("(3 * 3 * 3) = 27", format!("{}", out)); @@ -179,7 +248,7 @@ mod tests { #[test] fn expression_right_opened() { let sql = "27 = 3 * 3 * 3"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("27 = 3 * 3 * 3", format!("{}", out)); @@ -188,7 +257,7 @@ mod tests { #[test] fn expression_right_closed() { let sql = "27 = (3 * 3 * 3)"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("27 = (3 * 3 * 3)", format!("{}", out)); @@ -197,7 +266,7 @@ mod tests { #[test] fn expression_both_opened() { let sql = "3 * 3 * 3 = 3 * 3 * 3"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("3 * 3 * 3 = 3 * 3 * 3", format!("{}", out)); @@ -206,9 +275,27 @@ mod tests { #[test] fn expression_both_closed() { let sql = "(3 * 3 * 3) = (3 * 3 * 3)"; - let res = expression(sql); + let res = binary(sql); assert!(res.is_ok()); let out = res.unwrap().1; assert_eq!("(3 * 3 * 3) = (3 * 3 * 3)", format!("{}", out)); } + + #[test] + fn expression_unary() { + let sql = "-a"; + let res = unary(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(sql, format!("{}", out)); + } + + #[test] + fn expression_with_unary() { + let sql = "-(5) + 5"; + let res = binary(sql); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!(sql, format!("{}", out)); + } } diff --git a/lib/src/sql/number.rs b/lib/src/sql/number.rs index f48e8999..d9fd758a 100644 --- a/lib/src/sql/number.rs +++ b/lib/src/sql/number.rs @@ -16,7 +16,7 @@ use std::fmt::{self, Display, Formatter}; use std::hash; use std::iter::Product; use std::iter::Sum; -use std::ops; +use std::ops::{self, Neg}; use std::str::FromStr; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Number"; @@ -572,6 +572,18 @@ impl<'a, 'b> ops::Div<&'b Number> for &'a Number { } } +impl Neg for Number { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + Self::Int(n) => Number::Int(-n), + Self::Float(n) => Number::Float(-n), + Self::Decimal(n) => Number::Decimal(-n), + } + } +} + // ------------------------------ impl Sum for Number { diff --git a/lib/src/sql/operator.rs b/lib/src/sql/operator.rs index c8932567..0b6815ca 100644 --- a/lib/src/sql/operator.rs +++ b/lib/src/sql/operator.rs @@ -11,8 +11,12 @@ use serde::{Deserialize, Serialize}; use std::fmt; use std::fmt::Write; +/// Binary operators. #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] pub enum Operator { + // + Neg, // - + Not, // ! // Or, // || And, // && @@ -70,14 +74,14 @@ impl Operator { #[inline] pub fn precedence(&self) -> u8 { match self { - Operator::Or => 1, - Operator::And => 2, - Operator::Tco => 3, - Operator::Nco => 4, - Operator::Sub => 6, - Operator::Add => 7, - Operator::Mul => 8, - Operator::Div => 9, + Self::Or => 1, + Self::And => 2, + Self::Tco => 3, + Self::Nco => 4, + Self::Sub => 6, + Self::Add => 7, + Self::Mul => 8, + Self::Div => 9, _ => 5, } } @@ -86,6 +90,8 @@ impl Operator { impl fmt::Display for Operator { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + Self::Neg => f.write_str("-"), + Self::Not => f.write_str("!"), Self::Or => f.write_str("OR"), Self::And => f.write_str("AND"), Self::Tco => f.write_str("?:"), @@ -143,11 +149,23 @@ pub fn assigner(i: &str) -> IResult<&str, Operator> { ))(i) } -pub fn operator(i: &str) -> IResult<&str, Operator> { - alt((symbols, phrases))(i) +pub fn unary(i: &str) -> IResult<&str, Operator> { + unary_symbols(i) } -pub fn symbols(i: &str) -> IResult<&str, Operator> { +pub fn unary_symbols(i: &str) -> IResult<&str, Operator> { + let (i, _) = mightbespace(i)?; + let (i, v) = + alt((alt((map(tag("-"), |_| Operator::Neg), map(tag("!"), |_| Operator::Not))),))(i)?; + let (i, _) = mightbespace(i)?; + Ok((i, v)) +} + +pub fn binary(i: &str) -> IResult<&str, Operator> { + alt((binary_symbols, binary_phrases))(i) +} + +pub fn binary_symbols(i: &str) -> IResult<&str, Operator> { let (i, _) = mightbespace(i)?; let (i, v) = alt(( alt(( @@ -203,7 +221,7 @@ pub fn symbols(i: &str) -> IResult<&str, Operator> { Ok((i, v)) } -pub fn phrases(i: &str) -> IResult<&str, Operator> { +pub fn binary_phrases(i: &str) -> IResult<&str, Operator> { let (i, _) = shouldbespace(i)?; let (i, v) = alt(( alt(( diff --git a/lib/src/sql/test.rs b/lib/src/sql/test.rs index 646e77dd..a33b5afb 100644 --- a/lib/src/sql/test.rs +++ b/lib/src/sql/test.rs @@ -1,5 +1,5 @@ use crate::sql::array::{array, Array}; -use crate::sql::expression::{expression, Expression}; +use crate::sql::expression::{binary, Expression}; use crate::sql::idiom::{idiom, Idiom}; use crate::sql::param::{param, Param}; use crate::sql::script::{script, Script}; @@ -48,6 +48,6 @@ impl Parse for Thing { impl Parse for Expression { fn parse(val: &str) -> Self { - expression(val).unwrap().1 + binary(val).unwrap().1 } } diff --git a/lib/src/sql/value/serde/ser/expression/mod.rs b/lib/src/sql/value/serde/ser/expression/mod.rs index 429350c2..391913e6 100644 --- a/lib/src/sql/value/serde/ser/expression/mod.rs +++ b/lib/src/sql/value/serde/ser/expression/mod.rs @@ -19,29 +19,104 @@ impl ser::Serializer for Serializer { type SerializeTupleStruct = Impossible; type SerializeTupleVariant = Impossible; type SerializeMap = Impossible; - type SerializeStruct = SerializeExpression; - type SerializeStructVariant = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = SerializeExpression; - const EXPECTED: &'static str = "a struct `Expression`"; + const EXPECTED: &'static str = "an enum `Expression`"; #[inline] - fn serialize_struct( + fn serialize_struct_variant( self, - _name: &'static str, + name: &'static str, + _variant_index: u32, + variant: &'static str, _len: usize, - ) -> Result { - Ok(SerializeExpression::default()) + ) -> Result { + debug_assert_eq!(name, crate::sql::expression::TOKEN); + match variant { + "Unary" => Ok(SerializeExpression::Unary(Default::default())), + "Binary" => Ok(SerializeExpression::Binary(Default::default())), + _ => Err(Error::custom(format!("unexpected `Expression::{name}`"))), + } + } +} + +pub(super) enum SerializeExpression { + Unary(SerializeUnary), + Binary(SerializeBinary), +} + +impl serde::ser::SerializeStructVariant for SerializeExpression { + type Ok = Expression; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Unary(unary) => unary.serialize_field(key, value), + Self::Binary(binary) => binary.serialize_field(key, value), + } + } + + fn end(self) -> Result { + match self { + Self::Unary(unary) => unary.end(), + Self::Binary(binary) => binary.end(), + } } } #[derive(Default)] -pub(super) struct SerializeExpression { +pub(super) struct SerializeUnary { + o: Option, + v: Option, +} + +impl serde::ser::SerializeStructVariant for SerializeUnary { + type Ok = Expression; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Error> + where + T: ?Sized + Serialize, + { + match key { + "o" => { + self.o = Some(value.serialize(ser::operator::Serializer.wrap())?); + } + "v" => { + self.v = Some(value.serialize(ser::value::Serializer.wrap())?); + } + key => { + return Err(Error::custom(format!( + "unexpected field `Expression::Unary{{{key}}}`" + ))); + } + } + Ok(()) + } + + fn end(self) -> Result { + match (self.o, self.v) { + (Some(o), Some(v)) => Ok(Expression::Unary { + o, + v, + }), + _ => Err(Error::custom("`Expression::Unary` missing required field(s)")), + } + } +} + +#[derive(Default)] +pub(super) struct SerializeBinary { l: Option, o: Option, r: Option, } -impl serde::ser::SerializeStruct for SerializeExpression { +impl serde::ser::SerializeStructVariant for SerializeBinary { type Ok = Expression; type Error = Error; @@ -60,7 +135,9 @@ impl serde::ser::SerializeStruct for SerializeExpression { self.r = Some(value.serialize(ser::value::Serializer.wrap())?); } key => { - return Err(Error::custom(format!("unexpected field `Expression::{key}`"))); + return Err(Error::custom(format!( + "unexpected field `Expression::Binary{{{key}}}`" + ))); } } Ok(()) @@ -68,12 +145,12 @@ impl serde::ser::SerializeStruct for SerializeExpression { fn end(self) -> Result { match (self.l, self.o, self.r) { - (Some(l), Some(o), Some(r)) => Ok(Expression { + (Some(l), Some(o), Some(r)) => Ok(Expression::Binary { l, o, r, }), - _ => Err(Error::custom("`Expression` missing required field(s)")), + _ => Err(Error::custom("`Expression::Binary` missing required field(s)")), } } } @@ -90,9 +167,19 @@ mod tests { assert_eq!(expression, serialized); } + #[test] + fn unary() { + let expression = Expression::Unary { + o: Operator::Not, + v: "Bar".into(), + }; + let serialized = expression.serialize(Serializer.wrap()).unwrap(); + assert_eq!(expression, serialized); + } + #[test] fn foo_equals_bar() { - let expression = Expression { + let expression = Expression::Binary { l: "foo".into(), o: Operator::Equal, r: "Bar".into(), diff --git a/lib/src/sql/value/serde/ser/operator/mod.rs b/lib/src/sql/value/serde/ser/operator/mod.rs index 167eb662..dc860030 100644 --- a/lib/src/sql/value/serde/ser/operator/mod.rs +++ b/lib/src/sql/value/serde/ser/operator/mod.rs @@ -28,6 +28,8 @@ impl ser::Serializer for Serializer { variant: &'static str, ) -> Result { match variant { + "Neg" => Ok(Operator::Neg), + "Not" => Ok(Operator::Not), "Or" => Ok(Operator::Or), "And" => Ok(Operator::And), "Tco" => Ok(Operator::Tco), diff --git a/lib/src/sql/value/serde/ser/value/mod.rs b/lib/src/sql/value/serde/ser/value/mod.rs index 4dbc1338..cf67b319 100644 --- a/lib/src/sql/value/serde/ser/value/mod.rs +++ b/lib/src/sql/value/serde/ser/value/mod.rs @@ -20,6 +20,7 @@ use crate::sql::Table; use crate::sql::Uuid; use map::SerializeValueMap; use rust_decimal::Decimal; +use ser::cast::SerializeCast; use ser::edges::SerializeEdges; use ser::expression::SerializeExpression; use ser::function::SerializeFunction; @@ -60,7 +61,7 @@ impl ser::Serializer for Serializer { type SerializeSeq = SerializeArray; type SerializeTuple = SerializeArray; - type SerializeTupleStruct = SerializeArray; + type SerializeTupleStruct = SerializeTupleStruct; type SerializeTupleVariant = SerializeTupleVariant; type SerializeMap = SerializeMap; type SerializeStruct = SerializeStruct; @@ -299,10 +300,13 @@ impl ser::Serializer for Serializer { fn serialize_tuple_struct( self, - _name: &'static str, - len: usize, + name: &'static str, + _len: usize, ) -> Result { - self.serialize_seq(Some(len)) + match name { + sql::cast::TOKEN => Ok(SerializeTupleStruct::Cast(Default::default())), + _ => Ok(SerializeTupleStruct::Array(Default::default())), + } } fn serialize_tuple_variant( @@ -347,7 +351,6 @@ impl ser::Serializer for Serializer { ) -> Result { Ok(match name { sql::thing::TOKEN => SerializeStruct::Thing(Default::default()), - sql::expression::TOKEN => SerializeStruct::Expression(Default::default()), sql::edges::TOKEN => SerializeStruct::Edges(Default::default()), sql::range::TOKEN => SerializeStruct::Range(Default::default()), _ => SerializeStruct::Unknown(Default::default()), @@ -356,18 +359,27 @@ impl ser::Serializer for Serializer { fn serialize_struct_variant( self, - _name: &'static str, + name: &'static str, _variant_index: u32, variant: &'static str, _len: usize, ) -> Result { - Ok(SerializeStructVariant { - name: String::from(variant), - map: Object::default(), + Ok(if name == sql::expression::TOKEN { + SerializeStructVariant::Expression(match variant { + "Unary" => SerializeExpression::Unary(Default::default()), + "Binary" => SerializeExpression::Binary(Default::default()), + _ => return Err(Error::custom(format!("unexpected `Expression::{name}`"))), + }) + } else { + SerializeStructVariant::Object { + name: String::from(variant), + map: Object::default(), + } }) } } +#[derive(Default)] pub(super) struct SerializeArray(vec::SerializeValueVec); impl serde::ser::SerializeSeq for SerializeArray { @@ -452,6 +464,33 @@ pub(super) enum SerializeTupleVariant { }, } +pub(super) enum SerializeTupleStruct { + Cast(SerializeCast), + Array(SerializeArray), +} + +impl serde::ser::SerializeTupleStruct for SerializeTupleStruct { + type Ok = Value; + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Cast(cast) => cast.serialize_field(value), + Self::Array(array) => array.serialize_field(value), + } + } + + fn end(self) -> Result { + match self { + Self::Cast(cast) => Ok(Value::Cast(Box::new(cast.end()?))), + Self::Array(array) => Ok(serde::ser::SerializeTupleStruct::end(array)?), + } + } +} + impl serde::ser::SerializeTupleVariant for SerializeTupleVariant { type Ok = Value; type Error = Error; @@ -461,9 +500,9 @@ impl serde::ser::SerializeTupleVariant for SerializeTupleVariant { T: ?Sized + Serialize, { match self { - SerializeTupleVariant::Model(model) => model.serialize_field(value), - SerializeTupleVariant::Function(function) => function.serialize_field(value), - SerializeTupleVariant::Unknown { + Self::Model(model) => model.serialize_field(value), + Self::Function(function) => function.serialize_field(value), + Self::Unknown { ref mut fields, .. } => fields.serialize_element(value), @@ -472,11 +511,9 @@ impl serde::ser::SerializeTupleVariant for SerializeTupleVariant { fn end(self) -> Result { match self { - SerializeTupleVariant::Model(model) => Ok(Value::Model(model.end()?)), - SerializeTupleVariant::Function(function) => { - Ok(Value::Function(Box::new(function.end()?))) - } - SerializeTupleVariant::Unknown { + Self::Model(model) => Ok(Value::Model(model.end()?)), + Self::Function(function) => Ok(Value::Function(Box::new(function.end()?))), + Self::Unknown { variant, fields, } => Ok(map! { @@ -489,7 +526,6 @@ impl serde::ser::SerializeTupleVariant for SerializeTupleVariant { pub(super) enum SerializeStruct { Thing(SerializeThing), - Expression(SerializeExpression), Edges(SerializeEdges), Range(SerializeRange), Unknown(SerializeValueMap), @@ -504,28 +540,29 @@ impl serde::ser::SerializeStruct for SerializeStruct { T: ?Sized + Serialize, { match self { - SerializeStruct::Thing(thing) => thing.serialize_field(key, value), - SerializeStruct::Expression(expr) => expr.serialize_field(key, value), - SerializeStruct::Edges(edges) => edges.serialize_field(key, value), - SerializeStruct::Range(range) => range.serialize_field(key, value), - SerializeStruct::Unknown(map) => map.serialize_entry(key, value), + Self::Thing(thing) => thing.serialize_field(key, value), + Self::Edges(edges) => edges.serialize_field(key, value), + Self::Range(range) => range.serialize_field(key, value), + Self::Unknown(map) => map.serialize_entry(key, value), } } fn end(self) -> Result { match self { - SerializeStruct::Thing(thing) => Ok(Value::Thing(thing.end()?)), - SerializeStruct::Expression(expr) => Ok(Value::Expression(Box::new(expr.end()?))), - SerializeStruct::Edges(edges) => Ok(Value::Edges(Box::new(edges.end()?))), - SerializeStruct::Range(range) => Ok(Value::Range(Box::new(range.end()?))), - SerializeStruct::Unknown(map) => Ok(Value::Object(Object(map.end()?))), + Self::Thing(thing) => Ok(Value::Thing(thing.end()?)), + Self::Edges(edges) => Ok(Value::Edges(Box::new(edges.end()?))), + Self::Range(range) => Ok(Value::Range(Box::new(range.end()?))), + Self::Unknown(map) => Ok(Value::Object(Object(map.end()?))), } } } -pub(super) struct SerializeStructVariant { - name: String, - map: Object, +pub(super) enum SerializeStructVariant { + Expression(SerializeExpression), + Object { + name: String, + map: Object, + }, } impl serde::ser::SerializeStructVariant for SerializeStructVariant { @@ -536,16 +573,32 @@ impl serde::ser::SerializeStructVariant for SerializeStructVariant { where T: ?Sized + Serialize, { - self.map.0.insert(String::from(key), value.serialize(Serializer.wrap())?); - Ok(()) + match self { + Self::Expression(expression) => expression.serialize_field(key, value), + Self::Object { + map, + .. + } => { + map.0.insert(String::from(key), value.serialize(Serializer.wrap())?); + Ok(()) + } + } } fn end(self) -> Result { - let mut object = Object::default(); + match self { + Self::Expression(expression) => Ok(Value::from(expression.end()?)), + Self::Object { + name, + map, + } => { + let mut object = Object::default(); - object.insert(self.name, Value::Object(self.map)); + object.insert(name, Value::Object(map)); - Ok(Value::Object(object)) + Ok(Value::Object(object)) + } + } } } @@ -825,7 +878,7 @@ mod tests { #[test] fn expression() { - let expression = Box::new(Expression { + let expression = Box::new(Expression::Binary { l: "foo".into(), o: Operator::Equal, r: "Bar".into(), diff --git a/lib/src/sql/value/serde/ser/value/vec.rs b/lib/src/sql/value/serde/ser/value/vec.rs index 30b53e7a..604c9050 100644 --- a/lib/src/sql/value/serde/ser/value/vec.rs +++ b/lib/src/sql/value/serde/ser/value/vec.rs @@ -38,6 +38,7 @@ impl ser::Serializer for Serializer { } } +#[derive(Default)] pub struct SerializeValueVec(pub Vec); impl serde::ser::SerializeSeq for SerializeValueVec { diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index d08bfbb3..a4f0acb8 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -15,7 +15,7 @@ use crate::sql::datetime::{datetime, Datetime}; use crate::sql::duration::{duration, Duration}; use crate::sql::edges::{edges, Edges}; use crate::sql::error::IResult; -use crate::sql::expression::{expression, Expression}; +use crate::sql::expression::{binary, unary, Expression}; use crate::sql::fmt::{Fmt, Pretty}; use crate::sql::function::{self, function, Function}; use crate::sql::future::{future, Future}; @@ -58,6 +58,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::fmt::{self, Display, Formatter, Write}; use std::ops::Deref; +use std::ops::Neg; use std::str::FromStr; static MATCHER: Lazy = Lazy::new(|| SkimMatcherV2::default().ignore_case()); @@ -2477,7 +2478,7 @@ impl Value { Value::Object(v) => v.iter().any(|(_, v)| v.writeable()), Value::Function(v) => v.is_custom() || v.args().iter().any(Value::writeable), Value::Subquery(v) => v.writeable(), - Value::Expression(v) => v.l.writeable() || v.r.writeable(), + Value::Expression(v) => v.writeable(), _ => false, } } @@ -2658,9 +2659,24 @@ impl TryPow for Value { // ------------------------------ -/// Parse any `Value` including binary expressions +pub(crate) trait TryNeg { + type Output; + fn try_neg(self) -> Result; +} + +impl TryNeg for Value { + type Output = Self; + fn try_neg(self) -> Result { + match self { + Self::Number(n) if !matches!(n, Number::Int(i64::MIN)) => Ok(Self::Number(n.neg())), + v => Err(Error::TryNeg(v.to_string())), + } + } +} + +/// Parse any `Value` including expressions pub fn value(i: &str) -> IResult<&str, Value> { - alt((map(expression, Value::from), single))(i) + alt((map(binary, Value::from), single))(i) } /// Parse any `Value` excluding binary expressions @@ -2684,8 +2700,11 @@ pub fn single(i: &str) -> IResult<&str, Value> { map(future, Value::from), map(unique, Value::from), map(number, Value::from), + map(unary, Value::from), map(object, Value::from), map(array, Value::from), + )), + alt(( map(block, Value::from), map(param, Value::from), map(regex, Value::from), @@ -2702,7 +2721,8 @@ pub fn single(i: &str) -> IResult<&str, Value> { pub fn select(i: &str) -> IResult<&str, Value> { alt(( alt(( - map(expression, Value::from), + map(unary, Value::from), + map(binary, Value::from), map(tag_no_case("NONE"), |_| Value::None), map(tag_no_case("NULL"), |_| Value::Null), map(tag_no_case("true"), |_| Value::Bool(true)), diff --git a/lib/tests/select.rs b/lib/tests/select.rs index 5a252a95..3a2cde64 100644 --- a/lib/tests/select.rs +++ b/lib/tests/select.rs @@ -65,6 +65,64 @@ async fn select_field_value() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn select_expression_value() -> Result<(), Error> { + let sql = " + CREATE thing:a SET number = 5, boolean = true; + CREATE thing:b SET number = -5, boolean = false; + SELECT VALUE -number FROM thing; + SELECT VALUE !boolean FROM thing; + "; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 4); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + boolean: true, + id: thing:a, + number: 5 + } + ]", + ); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + boolean: false, + id: thing:b, + number: -5 + } + ]", + ); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + -5, + 5, + ]", + ); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + false, + true + ]", + ); + assert_eq!(tmp, val); + // + Ok(()) +} + #[tokio::test] async fn select_writeable_subqueries() -> Result<(), Error> { let sql = "