diff --git a/Cargo.lock b/Cargo.lock index df8ac046..2e09ba1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3953,6 +3953,7 @@ dependencies = [ "async-channel", "async-executor", "async-recursion", + "base64 0.21.0", "bcrypt", "bigdecimal", "bung", diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 08ef0fc1..4789a498 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -54,6 +54,7 @@ addr = { version = "0.15.6", default-features = false, features = ["std"] } argon2 = "0.5.0" ascii = { version = "0.3.2", package = "any_ascii" } async-recursion = "1.0.4" +base64_lib = { version = "0.21.0", package = "base64" } bcrypt = "0.14.0" bigdecimal = { version = "0.3.0", features = ["serde", "string-only"] } bung = "0.1.0" diff --git a/lib/src/fnc/args.rs b/lib/src/fnc/args.rs index 8ce5f4c4..32157d73 100644 --- a/lib/src/fnc/args.rs +++ b/lib/src/fnc/args.rs @@ -1,6 +1,6 @@ use crate::err::Error; use crate::sql::value::Value; -use crate::sql::{Array, Datetime, Duration, Number, Strand, Thing}; +use crate::sql::{Array, Bytes, Datetime, Duration, Number, Strand, Thing}; /// Implemented by types that are commonly used, in a certain way, as arguments. pub trait FromArg: Sized { @@ -55,6 +55,12 @@ impl FromArg for Array { } } +impl FromArg for Bytes { + fn from_arg(arg: Value) -> Result { + arg.convert_to_bytes() + } +} + impl FromArg for i64 { fn from_arg(arg: Value) -> Result { arg.convert_to_i64() diff --git a/lib/src/fnc/bytes.rs b/lib/src/fnc/bytes.rs new file mode 100644 index 00000000..f041d659 --- /dev/null +++ b/lib/src/fnc/bytes.rs @@ -0,0 +1,6 @@ +use crate::err::Error; +use crate::sql::{Bytes, Value}; + +pub fn len((bytes,): (Bytes,)) -> Result { + Ok(bytes.len().into()) +} diff --git a/lib/src/fnc/encoding.rs b/lib/src/fnc/encoding.rs new file mode 100644 index 00000000..44cac549 --- /dev/null +++ b/lib/src/fnc/encoding.rs @@ -0,0 +1,18 @@ +pub mod base64 { + use crate::err::Error; + use crate::sql::{Bytes, Value}; + use base64_lib::{engine::general_purpose::STANDARD_NO_PAD, Engine}; + + pub fn encode((arg,): (Bytes,)) -> Result { + Ok(Value::from(STANDARD_NO_PAD.encode(&*arg))) + } + + pub fn decode((arg,): (String,)) -> Result { + Ok(Value::from(Bytes(STANDARD_NO_PAD.decode(arg).map_err(|_| { + Error::InvalidArguments { + name: "encoding::base64::decode".to_owned(), + message: "invalid base64".to_owned(), + } + })?))) + } +} diff --git a/lib/src/fnc/mod.rs b/lib/src/fnc/mod.rs index 20f59819..1c5fcd92 100644 --- a/lib/src/fnc/mod.rs +++ b/lib/src/fnc/mod.rs @@ -4,9 +4,11 @@ use crate::sql::value::Value; pub mod args; pub mod array; +pub mod bytes; pub mod count; pub mod crypto; pub mod duration; +pub mod encoding; pub mod geo; pub mod http; pub mod is; @@ -92,6 +94,8 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec) -> Result array::sort::asc, "array::sort::desc" => array::sort::desc, // + "bytes::len" => bytes::len, + // "count" => count::count, // "crypto::md5" => crypto::md5, @@ -117,6 +121,9 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec) -> Result duration::from::secs, "duration::from::weeks" => duration::from::weeks, // + "encoding::base64::decode" => encoding::base64::decode, + "encoding::base64::encode" => encoding::base64::encode, + // "geo::area" => geo::area, "geo::bearing" => geo::bearing, "geo::centroid" => geo::centroid, diff --git a/lib/src/fnc/script/modules/surrealdb/functions/bytes.rs b/lib/src/fnc/script/modules/surrealdb/functions/bytes.rs new file mode 100644 index 00000000..1599da83 --- /dev/null +++ b/lib/src/fnc/script/modules/surrealdb/functions/bytes.rs @@ -0,0 +1,10 @@ +use super::run; +use crate::fnc::script::modules::impl_module_def; + +pub struct Package; + +impl_module_def!( + Package, + "array", + "len" => run +); diff --git a/lib/src/fnc/script/modules/surrealdb/functions/encoding.rs b/lib/src/fnc/script/modules/surrealdb/functions/encoding.rs new file mode 100644 index 00000000..74ffd707 --- /dev/null +++ b/lib/src/fnc/script/modules/surrealdb/functions/encoding.rs @@ -0,0 +1,11 @@ +use crate::fnc::script::modules::impl_module_def; + +mod base64; + +pub struct Package; + +impl_module_def!( + Package, + "encoding", + "base64" => (base64::Package) +); diff --git a/lib/src/fnc/script/modules/surrealdb/functions/encoding/base64.rs b/lib/src/fnc/script/modules/surrealdb/functions/encoding/base64.rs new file mode 100644 index 00000000..99a403ff --- /dev/null +++ b/lib/src/fnc/script/modules/surrealdb/functions/encoding/base64.rs @@ -0,0 +1,11 @@ +use super::super::run; +use crate::fnc::script::modules::impl_module_def; + +pub struct Package; + +impl_module_def!( + Package, + "encoding::base64", + "decode" => run, + "encode" => run +); diff --git a/lib/src/fnc/script/modules/surrealdb/functions/mod.rs b/lib/src/fnc/script/modules/surrealdb/functions/mod.rs index 14a6089f..972f39c4 100644 --- a/lib/src/fnc/script/modules/surrealdb/functions/mod.rs +++ b/lib/src/fnc/script/modules/surrealdb/functions/mod.rs @@ -5,8 +5,10 @@ use crate::sql::Value; use js::{Async, Result}; mod array; +mod bytes; mod crypto; mod duration; +mod encoding; mod geo; mod http; mod is; @@ -25,9 +27,11 @@ impl_module_def!( Package, "", // root path "array" => (array::Package), + "bytes" => (bytes::Package), "count" => run, "crypto" => (crypto::Package), "duration" => (duration::Package), + "encoding" => (encoding::Package), "geo" => (geo::Package), "http" => (http::Package), "is" => (is::Package), diff --git a/lib/src/fnc/util/http/mod.rs b/lib/src/fnc/util/http/mod.rs index 5a1a8466..e586e7c7 100644 --- a/lib/src/fnc/util/http/mod.rs +++ b/lib/src/fnc/util/http/mod.rs @@ -1,16 +1,50 @@ use crate::ctx::Context; use crate::err::Error; -use crate::sql::json; use crate::sql::object::Object; use crate::sql::strand::Strand; use crate::sql::value::Value; +use crate::sql::{json, Bytes}; use reqwest::header::CONTENT_TYPE; -use reqwest::Client; +use reqwest::{Client, RequestBuilder, Response}; pub(crate) fn uri_is_valid(uri: &str) -> bool { reqwest::Url::parse(uri).is_ok() } +fn encode_body(req: RequestBuilder, body: Value) -> RequestBuilder { + match body { + Value::Bytes(bytes) => req.header(CONTENT_TYPE, "application/octet-stream").body(bytes.0), + _ if body.is_some() => req.json(&body), + _ => req, + } +} + +async fn decode_response(res: Response) -> Result { + match res.status() { + s if s.is_success() => match res.headers().get(CONTENT_TYPE) { + Some(mime) => match mime.to_str() { + Ok(v) if v.starts_with("application/json") => { + let txt = res.text().await?; + let val = json(&txt)?; + Ok(val) + } + Ok(v) if v.starts_with("application/octet-stream") => { + let bytes = res.bytes().await?; + Ok(Value::Bytes(Bytes(bytes.into()))) + } + Ok(v) if v.starts_with("text") => { + let txt = res.text().await?; + let val = txt.into(); + Ok(val) + } + _ => Ok(Value::None), + }, + _ => Ok(Value::None), + }, + s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), + } +} + pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into) -> Result { // Set a default client with no timeout let cli = Client::builder().build()?; @@ -56,26 +90,8 @@ pub async fn get(ctx: &Context<'_>, uri: Strand, opts: impl Into) -> Res Some(d) => req.timeout(d).send().await?, _ => req.send().await?, }; - // Check the response status - match res.status() { - s if s.is_success() => match res.headers().get(CONTENT_TYPE) { - Some(mime) => match mime.to_str() { - Ok(v) if v.starts_with("application/json") => { - let txt = res.text().await?; - let val = json(&txt)?; - Ok(val) - } - Ok(v) if v.starts_with("text") => { - let txt = res.text().await?; - let val = txt.into(); - Ok(val) - } - _ => Ok(Value::None), - }, - _ => Ok(Value::None), - }, - s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), - } + // Receive the response as a value + decode_response(res).await } pub async fn put( @@ -97,35 +113,15 @@ pub async fn put( req = req.header(k.as_str(), v.to_raw_string()); } // Submit the request body - if body.is_some() { - req = req.json(&body); - } + req = encode_body(req, body); // Send the request and wait let res = match ctx.timeout() { #[cfg(not(target_arch = "wasm32"))] Some(d) => req.timeout(d).send().await?, _ => req.send().await?, }; - // Check the response status - match res.status() { - s if s.is_success() => match res.headers().get(CONTENT_TYPE) { - Some(mime) => match mime.to_str() { - Ok(v) if v.starts_with("application/json") => { - let txt = res.text().await?; - let val = json(&txt)?; - Ok(val) - } - Ok(v) if v.starts_with("text") => { - let txt = res.text().await?; - let val = txt.into(); - Ok(val) - } - _ => Ok(Value::None), - }, - _ => Ok(Value::None), - }, - s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), - } + // Receive the response as a value + decode_response(res).await } pub async fn post( @@ -147,35 +143,15 @@ pub async fn post( req = req.header(k.as_str(), v.to_raw_string()); } // Submit the request body - if body.is_some() { - req = req.json(&body); - } + req = encode_body(req, body); // Send the request and wait let res = match ctx.timeout() { #[cfg(not(target_arch = "wasm32"))] Some(d) => req.timeout(d).send().await?, _ => req.send().await?, }; - // Check the response status - match res.status() { - s if s.is_success() => match res.headers().get(CONTENT_TYPE) { - Some(mime) => match mime.to_str() { - Ok(v) if v.starts_with("application/json") => { - let txt = res.text().await?; - let val = json(&txt)?; - Ok(val) - } - Ok(v) if v.starts_with("text") => { - let txt = res.text().await?; - let val = txt.into(); - Ok(val) - } - _ => Ok(Value::None), - }, - _ => Ok(Value::None), - }, - s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), - } + // Receive the response as a value + decode_response(res).await } pub async fn patch( @@ -197,35 +173,15 @@ pub async fn patch( req = req.header(k.as_str(), v.to_raw_string()); } // Submit the request body - if body.is_some() { - req = req.json(&body); - } + req = encode_body(req, body); // Send the request and wait let res = match ctx.timeout() { #[cfg(not(target_arch = "wasm32"))] Some(d) => req.timeout(d).send().await?, _ => req.send().await?, }; - // Check the response status - match res.status() { - s if s.is_success() => match res.headers().get(CONTENT_TYPE) { - Some(mime) => match mime.to_str() { - Ok(v) if v.starts_with("application/json") => { - let txt = res.text().await?; - let val = json(&txt)?; - Ok(val) - } - Ok(v) if v.starts_with("text") => { - let txt = res.text().await?; - let val = txt.into(); - Ok(val) - } - _ => Ok(Value::None), - }, - _ => Ok(Value::None), - }, - s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), - } + // Receive the response as a value + decode_response(res).await } pub async fn delete( @@ -251,24 +207,6 @@ pub async fn delete( Some(d) => req.timeout(d).send().await?, _ => req.send().await?, }; - // Check the response status - match res.status() { - s if s.is_success() => match res.headers().get(CONTENT_TYPE) { - Some(mime) => match mime.to_str() { - Ok(v) if v.starts_with("application/json") => { - let txt = res.text().await?; - let val = json(&txt)?; - Ok(val) - } - Ok(v) if v.starts_with("text") => { - let txt = res.text().await?; - let val = txt.into(); - Ok(val) - } - _ => Ok(Value::None), - }, - _ => Ok(Value::None), - }, - s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())), - } + // Receive the response as a value + decode_response(res).await } diff --git a/lib/src/sql/bytes.rs b/lib/src/sql/bytes.rs index 67324b66..8f5ebc1d 100644 --- a/lib/src/sql/bytes.rs +++ b/lib/src/sql/bytes.rs @@ -1,8 +1,39 @@ -use serde::Deserialize; -use serde::Serialize; +use base64_lib::{engine::general_purpose::STANDARD_NO_PAD, Engine}; +use serde::{ + de::{self, Visitor}, + Deserialize, Serialize, +}; +use std::fmt::{self, Display, Formatter}; +use std::ops::Deref; -#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize, Hash)] -pub struct Bytes(pub(super) Vec); +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Hash)] +pub struct Bytes(pub(crate) Vec); + +impl Bytes { + pub fn into_inner(self) -> Vec { + self.0 + } +} + +impl From> for Bytes { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl Deref for Bytes { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Display for Bytes { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "encoding::base64::decode(\"{}\")", STANDARD_NO_PAD.encode(&self.0)) + } +} impl Serialize for Bytes { fn serialize(&self, serializer: S) -> Result @@ -12,3 +43,50 @@ impl Serialize for Bytes { serializer.serialize_bytes(&self.0) } } + +impl<'de> Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct RawBytesVisitor; + + impl<'de> Visitor<'de> for RawBytesVisitor { + type Value = Bytes; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("bytes") + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: de::Error, + { + Ok(Bytes(v)) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + Ok(Bytes(v.to_owned())) + } + } + + deserializer.deserialize_byte_buf(RawBytesVisitor) + } +} + +#[cfg(test)] +mod tests { + use crate::sql::{Bytes, Value}; + + #[test] + fn serialize() { + let val = Value::Bytes(Bytes(vec![1, 2, 3, 5])); + let serialized: Vec = val.clone().into(); + println!("{serialized:?}"); + let deserialized = Value::from(serialized); + assert_eq!(val, deserialized); + } +} diff --git a/lib/src/sql/function.rs b/lib/src/sql/function.rs index cfbfdfd3..a261bd46 100644 --- a/lib/src/sql/function.rs +++ b/lib/src/sql/function.rs @@ -271,8 +271,10 @@ fn cast(i: &str) -> IResult<&str, Function> { pub(crate) fn function_names(i: &str) -> IResult<&str, &str> { recognize(alt(( preceded(tag("array::"), function_array), + preceded(tag("bytes::"), function_bytes), preceded(tag("crypto::"), function_crypto), preceded(tag("duration::"), function_duration), + preceded(tag("encoding::"), function_encoding), preceded(tag("geo::"), function_geo), preceded(tag("http::"), function_http), preceded(tag("is::"), function_is), @@ -327,6 +329,10 @@ fn function_array(i: &str) -> IResult<&str, &str> { ))(i) } +fn function_bytes(i: &str) -> IResult<&str, &str> { + alt((tag("len"),))(i) +} + fn function_crypto(i: &str) -> IResult<&str, &str> { alt(( preceded(tag("argon2::"), alt((tag("compare"), tag("generate")))), @@ -367,6 +373,10 @@ fn function_duration(i: &str) -> IResult<&str, &str> { ))(i) } +fn function_encoding(i: &str) -> IResult<&str, &str> { + alt((preceded(tag("base64::"), alt((tag("decode"), tag("encode")))),))(i) +} + fn function_geo(i: &str) -> IResult<&str, &str> { alt(( tag("area"), diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 66846574..8d8494ac 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -1493,6 +1493,8 @@ impl Value { match self { // Bytes are allowed Value::Bytes(v) => Ok(v), + // Strings can be converted to bytes + Value::Strand(s) => Ok(Bytes(s.0.into_bytes())), // Anything else raises an error _ => Err(Error::ConvertTo { from: self, @@ -1934,7 +1936,7 @@ impl fmt::Display for Value { Value::Function(v) => write!(f, "{v}"), Value::Subquery(v) => write!(f, "{v}"), Value::Expression(v) => write!(f, "{v}"), - Value::Bytes(_) => write!(f, ""), + Value::Bytes(v) => write!(f, "{v}"), } } } diff --git a/lib/tests/function.rs b/lib/tests/function.rs index 94dce8d3..63c88225 100644 --- a/lib/tests/function.rs +++ b/lib/tests/function.rs @@ -855,6 +855,44 @@ async fn function_array_union() -> Result<(), Error> { Ok(()) } +// -------------------------------------------------- +// bytes +// -------------------------------------------------- + +#[tokio::test] +async fn function_bytes_len() -> Result<(), Error> { + let sql = r#" + RETURN bytes::len(""); + RETURN bytes::len(true); + RETURN bytes::len("π"); + RETURN bytes::len("ππ"); + "#; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 4); + // + let tmp = res.remove(0).result?; + let val = Value::parse("0"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(matches!( + tmp.err(), + Some(e) if e.to_string() == "Incorrect arguments for function bytes::len(). Argument 1 was the wrong type. Expected a bytes but failed to convert true into a bytes" + )); + // + let tmp = res.remove(0).result?; + let val = Value::parse("2"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse("4"); + assert_eq!(tmp, val); + // + Ok(()) +} + // -------------------------------------------------- // count // -------------------------------------------------- @@ -1391,6 +1429,54 @@ async fn function_duration_from_weeks() -> Result<(), Error> { Ok(()) } +// -------------------------------------------------- +// encoding +// -------------------------------------------------- + +#[tokio::test] +async fn function_encoding_base64_decode() -> Result<(), Error> { + let sql = r#" + RETURN encoding::base64::decode(""); + RETURN encoding::base64::decode("aGVsbG8") = "hello"; + "#; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 2); + // + let tmp = res.remove(0).result?; + let val = Value::Bytes(Vec::new().into()); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::from(true); + assert_eq!(tmp, val); + // + Ok(()) +} + +#[tokio::test] +async fn function_encoding_base64_encode() -> Result<(), Error> { + let sql = r#" + RETURN encoding::base64::encode(""); + RETURN encoding::base64::encode("hello"); + "#; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 2); + // + let tmp = res.remove(0).result?; + let val = Value::parse("''"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result?; + let val = Value::parse("'aGVsbG8'"); + assert_eq!(tmp, val); + // + Ok(()) +} + // -------------------------------------------------- // geo // --------------------------------------------------