Feature - Expand support for Bytes (#1898)

This commit is contained in:
Finn Bear 2023-05-09 13:43:16 -07:00 committed by GitHub
parent 73374d4799
commit ccc16fa9a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 306 additions and 117 deletions

1
Cargo.lock generated
View file

@ -3953,6 +3953,7 @@ dependencies = [
"async-channel", "async-channel",
"async-executor", "async-executor",
"async-recursion", "async-recursion",
"base64 0.21.0",
"bcrypt", "bcrypt",
"bigdecimal", "bigdecimal",
"bung", "bung",

View file

@ -54,6 +54,7 @@ addr = { version = "0.15.6", default-features = false, features = ["std"] }
argon2 = "0.5.0" argon2 = "0.5.0"
ascii = { version = "0.3.2", package = "any_ascii" } ascii = { version = "0.3.2", package = "any_ascii" }
async-recursion = "1.0.4" async-recursion = "1.0.4"
base64_lib = { version = "0.21.0", package = "base64" }
bcrypt = "0.14.0" bcrypt = "0.14.0"
bigdecimal = { version = "0.3.0", features = ["serde", "string-only"] } bigdecimal = { version = "0.3.0", features = ["serde", "string-only"] }
bung = "0.1.0" bung = "0.1.0"

View file

@ -1,6 +1,6 @@
use crate::err::Error; use crate::err::Error;
use crate::sql::value::Value; use crate::sql::value::Value;
use crate::sql::{Array, Datetime, Duration, Number, Strand, Thing}; use crate::sql::{Array, Bytes, Datetime, Duration, Number, Strand, Thing};
/// Implemented by types that are commonly used, in a certain way, as arguments. /// Implemented by types that are commonly used, in a certain way, as arguments.
pub trait FromArg: Sized { pub trait FromArg: Sized {
@ -55,6 +55,12 @@ impl FromArg for Array {
} }
} }
impl FromArg for Bytes {
fn from_arg(arg: Value) -> Result<Self, Error> {
arg.convert_to_bytes()
}
}
impl FromArg for i64 { impl FromArg for i64 {
fn from_arg(arg: Value) -> Result<Self, Error> { fn from_arg(arg: Value) -> Result<Self, Error> {
arg.convert_to_i64() arg.convert_to_i64()

6
lib/src/fnc/bytes.rs Normal file
View file

@ -0,0 +1,6 @@
use crate::err::Error;
use crate::sql::{Bytes, Value};
pub fn len((bytes,): (Bytes,)) -> Result<Value, Error> {
Ok(bytes.len().into())
}

18
lib/src/fnc/encoding.rs Normal file
View file

@ -0,0 +1,18 @@
pub mod base64 {
use crate::err::Error;
use crate::sql::{Bytes, Value};
use base64_lib::{engine::general_purpose::STANDARD_NO_PAD, Engine};
pub fn encode((arg,): (Bytes,)) -> Result<Value, Error> {
Ok(Value::from(STANDARD_NO_PAD.encode(&*arg)))
}
pub fn decode((arg,): (String,)) -> Result<Value, Error> {
Ok(Value::from(Bytes(STANDARD_NO_PAD.decode(arg).map_err(|_| {
Error::InvalidArguments {
name: "encoding::base64::decode".to_owned(),
message: "invalid base64".to_owned(),
}
})?)))
}
}

View file

@ -4,9 +4,11 @@ use crate::sql::value::Value;
pub mod args; pub mod args;
pub mod array; pub mod array;
pub mod bytes;
pub mod count; pub mod count;
pub mod crypto; pub mod crypto;
pub mod duration; pub mod duration;
pub mod encoding;
pub mod geo; pub mod geo;
pub mod http; pub mod http;
pub mod is; pub mod is;
@ -92,6 +94,8 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec<Value>) -> Result<Va
"array::sort::asc" => array::sort::asc, "array::sort::asc" => array::sort::asc,
"array::sort::desc" => array::sort::desc, "array::sort::desc" => array::sort::desc,
// //
"bytes::len" => bytes::len,
//
"count" => count::count, "count" => count::count,
// //
"crypto::md5" => crypto::md5, "crypto::md5" => crypto::md5,
@ -117,6 +121,9 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec<Value>) -> Result<Va
"duration::from::secs" => duration::from::secs, "duration::from::secs" => duration::from::secs,
"duration::from::weeks" => duration::from::weeks, "duration::from::weeks" => duration::from::weeks,
// //
"encoding::base64::decode" => encoding::base64::decode,
"encoding::base64::encode" => encoding::base64::encode,
//
"geo::area" => geo::area, "geo::area" => geo::area,
"geo::bearing" => geo::bearing, "geo::bearing" => geo::bearing,
"geo::centroid" => geo::centroid, "geo::centroid" => geo::centroid,

View file

@ -0,0 +1,10 @@
use super::run;
use crate::fnc::script::modules::impl_module_def;
pub struct Package;
impl_module_def!(
Package,
"array",
"len" => run
);

View file

@ -0,0 +1,11 @@
use crate::fnc::script::modules::impl_module_def;
mod base64;
pub struct Package;
impl_module_def!(
Package,
"encoding",
"base64" => (base64::Package)
);

View file

@ -0,0 +1,11 @@
use super::super::run;
use crate::fnc::script::modules::impl_module_def;
pub struct Package;
impl_module_def!(
Package,
"encoding::base64",
"decode" => run,
"encode" => run
);

View file

@ -5,8 +5,10 @@ use crate::sql::Value;
use js::{Async, Result}; use js::{Async, Result};
mod array; mod array;
mod bytes;
mod crypto; mod crypto;
mod duration; mod duration;
mod encoding;
mod geo; mod geo;
mod http; mod http;
mod is; mod is;
@ -25,9 +27,11 @@ impl_module_def!(
Package, Package,
"", // root path "", // root path
"array" => (array::Package), "array" => (array::Package),
"bytes" => (bytes::Package),
"count" => run, "count" => run,
"crypto" => (crypto::Package), "crypto" => (crypto::Package),
"duration" => (duration::Package), "duration" => (duration::Package),
"encoding" => (encoding::Package),
"geo" => (geo::Package), "geo" => (geo::Package),
"http" => (http::Package), "http" => (http::Package),
"is" => (is::Package), "is" => (is::Package),

View file

@ -1,16 +1,50 @@
use crate::ctx::Context; use crate::ctx::Context;
use crate::err::Error; use crate::err::Error;
use crate::sql::json;
use crate::sql::object::Object; use crate::sql::object::Object;
use crate::sql::strand::Strand; use crate::sql::strand::Strand;
use crate::sql::value::Value; use crate::sql::value::Value;
use crate::sql::{json, Bytes};
use reqwest::header::CONTENT_TYPE; use reqwest::header::CONTENT_TYPE;
use reqwest::Client; use reqwest::{Client, RequestBuilder, Response};
pub(crate) fn uri_is_valid(uri: &str) -> bool { pub(crate) fn uri_is_valid(uri: &str) -> bool {
reqwest::Url::parse(uri).is_ok() reqwest::Url::parse(uri).is_ok()
} }
fn encode_body(req: RequestBuilder, body: Value) -> RequestBuilder {
match body {
Value::Bytes(bytes) => req.header(CONTENT_TYPE, "application/octet-stream").body(bytes.0),
_ if body.is_some() => req.json(&body),
_ => req,
}
}
async fn decode_response(res: Response) -> Result<Value, Error> {
match res.status() {
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("application/octet-stream") => {
let bytes = res.bytes().await?;
Ok(Value::Bytes(Bytes(bytes.into())))
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
}
pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> { pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> {
// Set a default client with no timeout // Set a default client with no timeout
let cli = Client::builder().build()?; let cli = Client::builder().build()?;
@ -56,26 +90,8 @@ pub async fn get(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Res
Some(d) => req.timeout(d).send().await?, Some(d) => req.timeout(d).send().await?,
_ => req.send().await?, _ => req.send().await?,
}; };
// Check the response status // Receive the response as a value
match res.status() { decode_response(res).await
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
} }
pub async fn put( pub async fn put(
@ -97,35 +113,15 @@ pub async fn put(
req = req.header(k.as_str(), v.to_raw_string()); req = req.header(k.as_str(), v.to_raw_string());
} }
// Submit the request body // Submit the request body
if body.is_some() { req = encode_body(req, body);
req = req.json(&body);
}
// Send the request and wait // Send the request and wait
let res = match ctx.timeout() { let res = match ctx.timeout() {
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Some(d) => req.timeout(d).send().await?, Some(d) => req.timeout(d).send().await?,
_ => req.send().await?, _ => req.send().await?,
}; };
// Check the response status // Receive the response as a value
match res.status() { decode_response(res).await
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
} }
pub async fn post( pub async fn post(
@ -147,35 +143,15 @@ pub async fn post(
req = req.header(k.as_str(), v.to_raw_string()); req = req.header(k.as_str(), v.to_raw_string());
} }
// Submit the request body // Submit the request body
if body.is_some() { req = encode_body(req, body);
req = req.json(&body);
}
// Send the request and wait // Send the request and wait
let res = match ctx.timeout() { let res = match ctx.timeout() {
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Some(d) => req.timeout(d).send().await?, Some(d) => req.timeout(d).send().await?,
_ => req.send().await?, _ => req.send().await?,
}; };
// Check the response status // Receive the response as a value
match res.status() { decode_response(res).await
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
} }
pub async fn patch( pub async fn patch(
@ -197,35 +173,15 @@ pub async fn patch(
req = req.header(k.as_str(), v.to_raw_string()); req = req.header(k.as_str(), v.to_raw_string());
} }
// Submit the request body // Submit the request body
if body.is_some() { req = encode_body(req, body);
req = req.json(&body);
}
// Send the request and wait // Send the request and wait
let res = match ctx.timeout() { let res = match ctx.timeout() {
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Some(d) => req.timeout(d).send().await?, Some(d) => req.timeout(d).send().await?,
_ => req.send().await?, _ => req.send().await?,
}; };
// Check the response status // Receive the response as a value
match res.status() { decode_response(res).await
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
} }
pub async fn delete( pub async fn delete(
@ -251,24 +207,6 @@ pub async fn delete(
Some(d) => req.timeout(d).send().await?, Some(d) => req.timeout(d).send().await?,
_ => req.send().await?, _ => req.send().await?,
}; };
// Check the response status // Receive the response as a value
match res.status() { decode_response(res).await
s if s.is_success() => match res.headers().get(CONTENT_TYPE) {
Some(mime) => match mime.to_str() {
Ok(v) if v.starts_with("application/json") => {
let txt = res.text().await?;
let val = json(&txt)?;
Ok(val)
}
Ok(v) if v.starts_with("text") => {
let txt = res.text().await?;
let val = txt.into();
Ok(val)
}
_ => Ok(Value::None),
},
_ => Ok(Value::None),
},
s => Err(Error::Http(s.canonical_reason().unwrap_or_default().to_owned())),
}
} }

View file

@ -1,8 +1,39 @@
use serde::Deserialize; use base64_lib::{engine::general_purpose::STANDARD_NO_PAD, Engine};
use serde::Serialize; use serde::{
de::{self, Visitor},
Deserialize, Serialize,
};
use std::fmt::{self, Display, Formatter};
use std::ops::Deref;
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize, Hash)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Hash)]
pub struct Bytes(pub(super) Vec<u8>); pub struct Bytes(pub(crate) Vec<u8>);
impl Bytes {
pub fn into_inner(self) -> Vec<u8> {
self.0
}
}
impl From<Vec<u8>> for Bytes {
fn from(v: Vec<u8>) -> Self {
Self(v)
}
}
impl Deref for Bytes {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Display for Bytes {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "encoding::base64::decode(\"{}\")", STANDARD_NO_PAD.encode(&self.0))
}
}
impl Serialize for Bytes { impl Serialize for Bytes {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@ -12,3 +43,50 @@ impl Serialize for Bytes {
serializer.serialize_bytes(&self.0) serializer.serialize_bytes(&self.0)
} }
} }
impl<'de> Deserialize<'de> for Bytes {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct RawBytesVisitor;
impl<'de> Visitor<'de> for RawBytesVisitor {
type Value = Bytes;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("bytes")
}
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Bytes(v))
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Bytes(v.to_owned()))
}
}
deserializer.deserialize_byte_buf(RawBytesVisitor)
}
}
#[cfg(test)]
mod tests {
use crate::sql::{Bytes, Value};
#[test]
fn serialize() {
let val = Value::Bytes(Bytes(vec![1, 2, 3, 5]));
let serialized: Vec<u8> = val.clone().into();
println!("{serialized:?}");
let deserialized = Value::from(serialized);
assert_eq!(val, deserialized);
}
}

View file

@ -271,8 +271,10 @@ fn cast(i: &str) -> IResult<&str, Function> {
pub(crate) fn function_names(i: &str) -> IResult<&str, &str> { pub(crate) fn function_names(i: &str) -> IResult<&str, &str> {
recognize(alt(( recognize(alt((
preceded(tag("array::"), function_array), preceded(tag("array::"), function_array),
preceded(tag("bytes::"), function_bytes),
preceded(tag("crypto::"), function_crypto), preceded(tag("crypto::"), function_crypto),
preceded(tag("duration::"), function_duration), preceded(tag("duration::"), function_duration),
preceded(tag("encoding::"), function_encoding),
preceded(tag("geo::"), function_geo), preceded(tag("geo::"), function_geo),
preceded(tag("http::"), function_http), preceded(tag("http::"), function_http),
preceded(tag("is::"), function_is), preceded(tag("is::"), function_is),
@ -327,6 +329,10 @@ fn function_array(i: &str) -> IResult<&str, &str> {
))(i) ))(i)
} }
fn function_bytes(i: &str) -> IResult<&str, &str> {
alt((tag("len"),))(i)
}
fn function_crypto(i: &str) -> IResult<&str, &str> { fn function_crypto(i: &str) -> IResult<&str, &str> {
alt(( alt((
preceded(tag("argon2::"), alt((tag("compare"), tag("generate")))), preceded(tag("argon2::"), alt((tag("compare"), tag("generate")))),
@ -367,6 +373,10 @@ fn function_duration(i: &str) -> IResult<&str, &str> {
))(i) ))(i)
} }
fn function_encoding(i: &str) -> IResult<&str, &str> {
alt((preceded(tag("base64::"), alt((tag("decode"), tag("encode")))),))(i)
}
fn function_geo(i: &str) -> IResult<&str, &str> { fn function_geo(i: &str) -> IResult<&str, &str> {
alt(( alt((
tag("area"), tag("area"),

View file

@ -1493,6 +1493,8 @@ impl Value {
match self { match self {
// Bytes are allowed // Bytes are allowed
Value::Bytes(v) => Ok(v), Value::Bytes(v) => Ok(v),
// Strings can be converted to bytes
Value::Strand(s) => Ok(Bytes(s.0.into_bytes())),
// Anything else raises an error // Anything else raises an error
_ => Err(Error::ConvertTo { _ => Err(Error::ConvertTo {
from: self, from: self,
@ -1934,7 +1936,7 @@ impl fmt::Display for Value {
Value::Function(v) => write!(f, "{v}"), Value::Function(v) => write!(f, "{v}"),
Value::Subquery(v) => write!(f, "{v}"), Value::Subquery(v) => write!(f, "{v}"),
Value::Expression(v) => write!(f, "{v}"), Value::Expression(v) => write!(f, "{v}"),
Value::Bytes(_) => write!(f, "<bytes>"), Value::Bytes(v) => write!(f, "{v}"),
} }
} }
} }

View file

@ -855,6 +855,44 @@ async fn function_array_union() -> Result<(), Error> {
Ok(()) Ok(())
} }
// --------------------------------------------------
// bytes
// --------------------------------------------------
#[tokio::test]
async fn function_bytes_len() -> Result<(), Error> {
let sql = r#"
RETURN bytes::len(<bytes>"");
RETURN bytes::len(true);
RETURN bytes::len(<bytes>"π");
RETURN bytes::len("ππ");
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None, false).await?;
assert_eq!(res.len(), 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse("0");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == "Incorrect arguments for function bytes::len(). Argument 1 was the wrong type. Expected a bytes but failed to convert true into a bytes"
));
//
let tmp = res.remove(0).result?;
let val = Value::parse("2");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("4");
assert_eq!(tmp, val);
//
Ok(())
}
// -------------------------------------------------- // --------------------------------------------------
// count // count
// -------------------------------------------------- // --------------------------------------------------
@ -1391,6 +1429,54 @@ async fn function_duration_from_weeks() -> Result<(), Error> {
Ok(()) Ok(())
} }
// --------------------------------------------------
// encoding
// --------------------------------------------------
#[tokio::test]
async fn function_encoding_base64_decode() -> Result<(), Error> {
let sql = r#"
RETURN encoding::base64::decode("");
RETURN encoding::base64::decode("aGVsbG8") = <bytes>"hello";
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None, false).await?;
assert_eq!(res.len(), 2);
//
let tmp = res.remove(0).result?;
let val = Value::Bytes(Vec::new().into());
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(true);
assert_eq!(tmp, val);
//
Ok(())
}
#[tokio::test]
async fn function_encoding_base64_encode() -> Result<(), Error> {
let sql = r#"
RETURN encoding::base64::encode("");
RETURN encoding::base64::encode("hello");
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None, false).await?;
assert_eq!(res.len(), 2);
//
let tmp = res.remove(0).result?;
let val = Value::parse("''");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("'aGVsbG8'");
assert_eq!(tmp, val);
//
Ok(())
}
// -------------------------------------------------- // --------------------------------------------------
// geo // geo
// -------------------------------------------------- // --------------------------------------------------