Merge commit from fork
* fix * move check into iam * refactor * fix typo * add test * remove is_computed * impl validate_computed for array * range support * clippy * tidy * tidy * add comments --------- Co-authored-by: Raphael Darley <raphael.darley@surrealdb.com>
This commit is contained in:
parent
714bf9ebc5
commit
b7583a653a
8 changed files with 123 additions and 0 deletions
|
@ -1103,6 +1103,9 @@ pub enum Error {
|
||||||
/// There was an outdated storage version stored in the database
|
/// There was an outdated storage version stored in the database
|
||||||
#[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")]
|
#[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")]
|
||||||
OutdatedStorageVersion,
|
OutdatedStorageVersion,
|
||||||
|
|
||||||
|
#[error("Found a non-computed value where they are not allowed")]
|
||||||
|
NonComputed,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Error> for String {
|
impl From<Error> for String {
|
||||||
|
|
|
@ -21,6 +21,8 @@ use subtle::ConstantTimeEq;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
|
pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
|
||||||
|
// Check vars contains only computed values
|
||||||
|
vars.validate_computed()?;
|
||||||
// Parse the specified variables
|
// Parse the specified variables
|
||||||
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
||||||
let db = vars.get("DB").or_else(|| vars.get("db"));
|
let db = vars.get("DB").or_else(|| vars.get("db"));
|
||||||
|
|
|
@ -19,6 +19,8 @@ pub async fn signup(
|
||||||
session: &mut Session,
|
session: &mut Session,
|
||||||
vars: Object,
|
vars: Object,
|
||||||
) -> Result<Option<String>, Error> {
|
) -> Result<Option<String>, Error> {
|
||||||
|
// Check vars contains only computed values
|
||||||
|
vars.validate_computed()?;
|
||||||
// Parse the specified variables
|
// Parse the specified variables
|
||||||
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
||||||
let db = vars.get("DB").or_else(|| vars.get("db"));
|
let db = vars.get("DB").or_else(|| vars.get("db"));
|
||||||
|
|
|
@ -168,6 +168,11 @@ impl Array {
|
||||||
pub(crate) fn is_static(&self) -> bool {
|
pub(crate) fn is_static(&self) -> bool {
|
||||||
self.iter().all(Value::is_static)
|
self.iter().all(Value::is_static)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate that an Array contains only computed Values
|
||||||
|
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||||
|
self.iter().try_for_each(|v| v.validate_computed())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Array {
|
impl Display for Array {
|
||||||
|
|
|
@ -235,6 +235,11 @@ impl Object {
|
||||||
pub(crate) fn is_static(&self) -> bool {
|
pub(crate) fn is_static(&self) -> bool {
|
||||||
self.values().all(Value::is_static)
|
self.values().all(Value::is_static)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate that a Object contains only computed Values
|
||||||
|
pub(crate) fn validate_computed(&self) -> Result<(), Error> {
|
||||||
|
self.values().try_for_each(|v| v.validate_computed())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Object {
|
impl Display for Object {
|
||||||
|
|
|
@ -115,6 +115,20 @@ impl Range {
|
||||||
},
|
},
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate that a Range contains only computed Values
|
||||||
|
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||||
|
match &self.beg {
|
||||||
|
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
|
||||||
|
Bound::Unbounded => {}
|
||||||
|
}
|
||||||
|
match &self.end {
|
||||||
|
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
|
||||||
|
Bound::Unbounded => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialOrd for Range {
|
impl PartialOrd for Range {
|
||||||
|
|
|
@ -2890,6 +2890,21 @@ impl InfoStructure for Value {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Value {
|
||||||
|
/// Validate that a Value is computed or contains only computed Values
|
||||||
|
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||||
|
use Value::*;
|
||||||
|
match self {
|
||||||
|
None | Null | Bool(_) | Number(_) | Strand(_) | Duration(_) | Datetime(_) | Uuid(_)
|
||||||
|
| Geometry(_) | Bytes(_) | Thing(_) => Ok(()),
|
||||||
|
Array(a) => a.validate_computed(),
|
||||||
|
Object(o) => o.validate_computed(),
|
||||||
|
Range(r) => r.validate_computed(),
|
||||||
|
_ => Err(Error::NonComputed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Value {
|
impl Value {
|
||||||
/// Check if we require a writeable transaction
|
/// Check if we require a writeable transaction
|
||||||
pub(crate) fn writeable(&self) -> bool {
|
pub(crate) fn writeable(&self) -> bool {
|
||||||
|
|
|
@ -9,6 +9,7 @@ mod http_integration {
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use surrealdb::headers::{AUTH_DB, AUTH_NS};
|
use surrealdb::headers::{AUTH_DB, AUTH_NS};
|
||||||
|
use surrealdb::sql;
|
||||||
use test_log::test;
|
use test_log::test;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
@ -1727,4 +1728,80 @@ mod http_integration {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn signup_mal() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||||
|
let rpc_url = &format!("http://{addr}/rpc");
|
||||||
|
|
||||||
|
let ns = Ulid::new().to_string();
|
||||||
|
let db = Ulid::new().to_string();
|
||||||
|
|
||||||
|
// Prepare HTTP client
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert("surreal-ns", ns.parse()?);
|
||||||
|
headers.insert("surreal-db", db.parse()?);
|
||||||
|
headers.insert(header::ACCEPT, "application/surrealdb".parse()?);
|
||||||
|
headers.insert(header::CONTENT_TYPE, "application/surrealdb".parse()?);
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(Duration::from_millis(10))
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
// Define a record access method
|
||||||
|
{
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{addr}/sql"))
|
||||||
|
.basic_auth(USER, Some(PASS))
|
||||||
|
.body(
|
||||||
|
r#"
|
||||||
|
DEFINE ACCESS user ON DATABASE TYPE RECORD
|
||||||
|
SIGNUP ( CREATE user SET email = $email, pass = crypto::argon2::generate($pass) )
|
||||||
|
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
|
||||||
|
DURATION FOR SESSION 12h
|
||||||
|
;
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut request = sql::Object::default();
|
||||||
|
request.insert("method".to_string(), "signup".into());
|
||||||
|
|
||||||
|
let stmt: sql::Statement = {
|
||||||
|
let mut tmp = sql::statements::CreateStatement::default();
|
||||||
|
let rid = sql::thing("foo:42").unwrap();
|
||||||
|
let mut tmp_values = sql::Values::default();
|
||||||
|
tmp_values.0 = vec![rid.into()];
|
||||||
|
tmp.what = tmp_values;
|
||||||
|
sql::Statement::Create(tmp)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut obj = sql::Object::default();
|
||||||
|
obj.insert("email".to_string(), sql::Value::Query(stmt.into()));
|
||||||
|
obj.insert("pass".to_string(), "foo".into());
|
||||||
|
request.insert(
|
||||||
|
"params".to_string(),
|
||||||
|
sql::Value::Array(vec![sql::Value::Object(obj)].into()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let req: sql::Value = sql::Value::Object(request);
|
||||||
|
|
||||||
|
let req = sql::serde::serialize(&req).unwrap();
|
||||||
|
|
||||||
|
let res = client.post(rpc_url).body(req).send().await?;
|
||||||
|
|
||||||
|
let body = res.text().await?;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
body.contains("Found a non-computed value where they are not allowed"),
|
||||||
|
"{body:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue