Merge commit from fork
* fix * move check into iam * refactor * fix typo * add test * remove is_computed * impl validate_computed for array * range support * clippy * tidy * tidy * add comments --------- Co-authored-by: Raphael Darley <raphael.darley@surrealdb.com>
This commit is contained in:
parent
714bf9ebc5
commit
b7583a653a
8 changed files with 123 additions and 0 deletions
|
@ -1103,6 +1103,9 @@ pub enum Error {
|
|||
/// There was an outdated storage version stored in the database
|
||||
#[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")]
|
||||
OutdatedStorageVersion,
|
||||
|
||||
#[error("Found a non-computed value where they are not allowed")]
|
||||
NonComputed,
|
||||
}
|
||||
|
||||
impl From<Error> for String {
|
||||
|
|
|
@ -21,6 +21,8 @@ use subtle::ConstantTimeEq;
|
|||
use uuid::Uuid;
|
||||
|
||||
pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result<String, Error> {
|
||||
// Check vars contains only computed values
|
||||
vars.validate_computed()?;
|
||||
// Parse the specified variables
|
||||
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
||||
let db = vars.get("DB").or_else(|| vars.get("db"));
|
||||
|
|
|
@ -19,6 +19,8 @@ pub async fn signup(
|
|||
session: &mut Session,
|
||||
vars: Object,
|
||||
) -> Result<Option<String>, Error> {
|
||||
// Check vars contains only computed values
|
||||
vars.validate_computed()?;
|
||||
// Parse the specified variables
|
||||
let ns = vars.get("NS").or_else(|| vars.get("ns"));
|
||||
let db = vars.get("DB").or_else(|| vars.get("db"));
|
||||
|
|
|
@ -168,6 +168,11 @@ impl Array {
|
|||
pub(crate) fn is_static(&self) -> bool {
|
||||
self.iter().all(Value::is_static)
|
||||
}
|
||||
|
||||
/// Validate that an Array contains only computed Values
|
||||
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||
self.iter().try_for_each(|v| v.validate_computed())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Array {
|
||||
|
|
|
@ -235,6 +235,11 @@ impl Object {
|
|||
pub(crate) fn is_static(&self) -> bool {
|
||||
self.values().all(Value::is_static)
|
||||
}
|
||||
|
||||
/// Validate that a Object contains only computed Values
|
||||
pub(crate) fn validate_computed(&self) -> Result<(), Error> {
|
||||
self.values().try_for_each(|v| v.validate_computed())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Object {
|
||||
|
|
|
@ -115,6 +115,20 @@ impl Range {
|
|||
},
|
||||
})))
|
||||
}
|
||||
|
||||
/// Validate that a Range contains only computed Values
|
||||
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||
match &self.beg {
|
||||
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
|
||||
Bound::Unbounded => {}
|
||||
}
|
||||
match &self.end {
|
||||
Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?,
|
||||
Bound::Unbounded => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Range {
|
||||
|
|
|
@ -2890,6 +2890,21 @@ impl InfoStructure for Value {
|
|||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
/// Validate that a Value is computed or contains only computed Values
|
||||
pub fn validate_computed(&self) -> Result<(), Error> {
|
||||
use Value::*;
|
||||
match self {
|
||||
None | Null | Bool(_) | Number(_) | Strand(_) | Duration(_) | Datetime(_) | Uuid(_)
|
||||
| Geometry(_) | Bytes(_) | Thing(_) => Ok(()),
|
||||
Array(a) => a.validate_computed(),
|
||||
Object(o) => o.validate_computed(),
|
||||
Range(r) => r.validate_computed(),
|
||||
_ => Err(Error::NonComputed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
/// Check if we require a writeable transaction
|
||||
pub(crate) fn writeable(&self) -> bool {
|
||||
|
|
|
@ -9,6 +9,7 @@ mod http_integration {
|
|||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use surrealdb::headers::{AUTH_DB, AUTH_NS};
|
||||
use surrealdb::sql;
|
||||
use test_log::test;
|
||||
use ulid::Ulid;
|
||||
|
||||
|
@ -1727,4 +1728,80 @@ mod http_integration {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn signup_mal() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
let rpc_url = &format!("http://{addr}/rpc");
|
||||
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
|
||||
// Prepare HTTP client
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("surreal-ns", ns.parse()?);
|
||||
headers.insert("surreal-db", db.parse()?);
|
||||
headers.insert(header::ACCEPT, "application/surrealdb".parse()?);
|
||||
headers.insert(header::CONTENT_TYPE, "application/surrealdb".parse()?);
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_millis(10))
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
|
||||
// Define a record access method
|
||||
{
|
||||
let res = client
|
||||
.post(format!("http://{addr}/sql"))
|
||||
.basic_auth(USER, Some(PASS))
|
||||
.body(
|
||||
r#"
|
||||
DEFINE ACCESS user ON DATABASE TYPE RECORD
|
||||
SIGNUP ( CREATE user SET email = $email, pass = crypto::argon2::generate($pass) )
|
||||
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
|
||||
DURATION FOR SESSION 12h
|
||||
;
|
||||
"#,
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||
}
|
||||
|
||||
{
|
||||
let mut request = sql::Object::default();
|
||||
request.insert("method".to_string(), "signup".into());
|
||||
|
||||
let stmt: sql::Statement = {
|
||||
let mut tmp = sql::statements::CreateStatement::default();
|
||||
let rid = sql::thing("foo:42").unwrap();
|
||||
let mut tmp_values = sql::Values::default();
|
||||
tmp_values.0 = vec![rid.into()];
|
||||
tmp.what = tmp_values;
|
||||
sql::Statement::Create(tmp)
|
||||
};
|
||||
|
||||
let mut obj = sql::Object::default();
|
||||
obj.insert("email".to_string(), sql::Value::Query(stmt.into()));
|
||||
obj.insert("pass".to_string(), "foo".into());
|
||||
request.insert(
|
||||
"params".to_string(),
|
||||
sql::Value::Array(vec![sql::Value::Object(obj)].into()),
|
||||
);
|
||||
|
||||
let req: sql::Value = sql::Value::Object(request);
|
||||
|
||||
let req = sql::serde::serialize(&req).unwrap();
|
||||
|
||||
let res = client.post(rpc_url).body(req).send().await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
|
||||
assert!(
|
||||
body.contains("Found a non-computed value where they are not allowed"),
|
||||
"{body:?}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue