Merge commit from fork

* fix

* move check into iam

* refactor

* fix typo

* add test

* remove is_computed

* impl validate_computed for array

* range support

* clippy

* tidy

* tidy

* add comments

---------

Co-authored-by: Raphael Darley <raphael.darley@surrealdb.com>
This commit is contained in:
Gerard Guillemas Martos 2024-09-10 17:19:45 +02:00 committed by GitHub
parent 714bf9ebc5
commit b7583a653a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 123 additions and 0 deletions

View file

@ -1103,6 +1103,9 @@ pub enum Error {
/// There was an outdated storage version stored in the database /// There was an outdated storage version stored in the database
#[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")] #[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")]
OutdatedStorageVersion, OutdatedStorageVersion,
#[error("Found a non-computed value where they are not allowed")]
NonComputed,
} }
impl From<Error> for String { impl From<Error> for String {

View file

@ -21,6 +21,8 @@ use subtle::ConstantTimeEq;
use uuid::Uuid; use uuid::Uuid;
pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> { pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
// Check vars contains only computed values
vars.validate_computed()?;
// Parse the specified variables // Parse the specified variables
let ns = vars.get("NS").or_else(|| vars.get("ns")); let ns = vars.get("NS").or_else(|| vars.get("ns"));
let db = vars.get("DB").or_else(|| vars.get("db")); let db = vars.get("DB").or_else(|| vars.get("db"));

View file

@ -19,6 +19,8 @@ pub async fn signup(
session: &mut Session, session: &mut Session,
vars: Object, vars: Object,
) -> Result<Option<String>, Error> { ) -> Result<Option<String>, Error> {
// Check vars contains only computed values
vars.validate_computed()?;
// Parse the specified variables // Parse the specified variables
let ns = vars.get("NS").or_else(|| vars.get("ns")); let ns = vars.get("NS").or_else(|| vars.get("ns"));
let db = vars.get("DB").or_else(|| vars.get("db")); let db = vars.get("DB").or_else(|| vars.get("db"));

View file

@ -168,6 +168,11 @@ impl Array {
pub(crate) fn is_static(&self) -> bool { pub(crate) fn is_static(&self) -> bool {
self.iter().all(Value::is_static) self.iter().all(Value::is_static)
} }
/// Validate that an Array contains only computed Values
pub fn validate_computed(&self) -> Result<(), Error> {
self.iter().try_for_each(|v| v.validate_computed())
}
} }
impl Display for Array { impl Display for Array {

View file

@ -235,6 +235,11 @@ impl Object {
pub(crate) fn is_static(&self) -> bool { pub(crate) fn is_static(&self) -> bool {
self.values().all(Value::is_static) self.values().all(Value::is_static)
} }
/// Validate that a Object contains only computed Values
pub(crate) fn validate_computed(&self) -> Result<(), Error> {
self.values().try_for_each(|v| v.validate_computed())
}
} }
impl Display for Object { impl Display for Object {

View file

@ -115,6 +115,20 @@ impl Range {
}, },
}))) })))
} }
/// Validate that a Range contains only computed Values
pub fn validate_computed(&self) -> Result<(), Error> {
match &self.beg {
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
Bound::Unbounded => {}
}
match &self.end {
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
Bound::Unbounded => {}
}
Ok(())
}
} }
impl PartialOrd for Range { impl PartialOrd for Range {

View file

@ -2890,6 +2890,21 @@ impl InfoStructure for Value {
} }
} }
impl Value {
/// Validate that a Value is computed or contains only computed Values
pub fn validate_computed(&self) -> Result<(), Error> {
use Value::*;
match self {
None | Null | Bool(_) | Number(_) | Strand(_) | Duration(_) | Datetime(_) | Uuid(_)
| Geometry(_) | Bytes(_) | Thing(_) => Ok(()),
Array(a) => a.validate_computed(),
Object(o) => o.validate_computed(),
Range(r) => r.validate_computed(),
_ => Err(Error::NonComputed),
}
}
}
impl Value { impl Value {
/// Check if we require a writeable transaction /// Check if we require a writeable transaction
pub(crate) fn writeable(&self) -> bool { pub(crate) fn writeable(&self) -> bool {

View file

@ -9,6 +9,7 @@ mod http_integration {
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use surrealdb::headers::{AUTH_DB, AUTH_NS}; use surrealdb::headers::{AUTH_DB, AUTH_NS};
use surrealdb::sql;
use test_log::test; use test_log::test;
use ulid::Ulid; use ulid::Ulid;
@ -1727,4 +1728,80 @@ mod http_integration {
Ok(()) Ok(())
} }
#[test(tokio::test)]
async fn signup_mal() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let rpc_url = &format!("http://{addr}/rpc");
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
// Prepare HTTP client
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("surreal-ns", ns.parse()?);
headers.insert("surreal-db", db.parse()?);
headers.insert(header::ACCEPT, "application/surrealdb".parse()?);
headers.insert(header::CONTENT_TYPE, "application/surrealdb".parse()?);
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()?;
// Define a record access method
{
let res = client
.post(format!("http://{addr}/sql"))
.basic_auth(USER, Some(PASS))
.body(
r#"
DEFINE ACCESS user ON DATABASE TYPE RECORD
SIGNUP ( CREATE user SET email = $email, pass = crypto::argon2::generate($pass) )
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
DURATION FOR SESSION 12h
;
"#,
)
.send()
.await?;
assert!(res.status().is_success(), "body: {}", res.text().await?);
}
{
let mut request = sql::Object::default();
request.insert("method".to_string(), "signup".into());
let stmt: sql::Statement = {
let mut tmp = sql::statements::CreateStatement::default();
let rid = sql::thing("foo:42").unwrap();
let mut tmp_values = sql::Values::default();
tmp_values.0 = vec![rid.into()];
tmp.what = tmp_values;
sql::Statement::Create(tmp)
};
let mut obj = sql::Object::default();
obj.insert("email".to_string(), sql::Value::Query(stmt.into()));
obj.insert("pass".to_string(), "foo".into());
request.insert(
"params".to_string(),
sql::Value::Array(vec![sql::Value::Object(obj)].into()),
);
let req: sql::Value = sql::Value::Object(request);
let req = sql::serde::serialize(&req).unwrap();
let res = client.post(rpc_url).body(req).send().await?;
let body = res.text().await?;
assert!(
body.contains("Found a non-computed value where they are not allowed"),
"{body:?}"
);
}
Ok(())
}
} }