diff --git a/core/src/err/mod.rs b/core/src/err/mod.rs index 58db059f..526f3855 100644 --- a/core/src/err/mod.rs +++ b/core/src/err/mod.rs @@ -1103,6 +1103,9 @@ pub enum Error { /// There was an outdated storage version stored in the database #[error("The data stored on disk is out-of-date with this version. Please follow the upgrade guides in the documentation")] OutdatedStorageVersion, + + #[error("Found a non-computed value where they are not allowed")] + NonComputed, } impl From for String { diff --git a/core/src/iam/signin.rs b/core/src/iam/signin.rs index 87f7209d..b426066a 100644 --- a/core/src/iam/signin.rs +++ b/core/src/iam/signin.rs @@ -21,6 +21,8 @@ use subtle::ConstantTimeEq; use uuid::Uuid; pub async fn signin(kvs: &Datastore, session: &mut Session, vars: Object) -> Result { + // Check vars contains only computed values + vars.validate_computed()?; // Parse the specified variables let ns = vars.get("NS").or_else(|| vars.get("ns")); let db = vars.get("DB").or_else(|| vars.get("db")); diff --git a/core/src/iam/signup.rs b/core/src/iam/signup.rs index b83dfb87..15cd11ff 100644 --- a/core/src/iam/signup.rs +++ b/core/src/iam/signup.rs @@ -19,6 +19,8 @@ pub async fn signup( session: &mut Session, vars: Object, ) -> Result, Error> { + // Check vars contains only computed values + vars.validate_computed()?; // Parse the specified variables let ns = vars.get("NS").or_else(|| vars.get("ns")); let db = vars.get("DB").or_else(|| vars.get("db")); diff --git a/core/src/sql/array.rs b/core/src/sql/array.rs index ae13f50a..04fec8f6 100644 --- a/core/src/sql/array.rs +++ b/core/src/sql/array.rs @@ -168,6 +168,11 @@ impl Array { pub(crate) fn is_static(&self) -> bool { self.iter().all(Value::is_static) } + + /// Validate that an Array contains only computed Values + pub fn validate_computed(&self) -> Result<(), Error> { + self.iter().try_for_each(|v| v.validate_computed()) + } } impl Display for Array { diff --git a/core/src/sql/object.rs b/core/src/sql/object.rs index 7eb39c4d..dedf491a 100644 --- a/core/src/sql/object.rs +++ b/core/src/sql/object.rs @@ -235,6 +235,11 @@ impl Object { pub(crate) fn is_static(&self) -> bool { self.values().all(Value::is_static) } + + /// Validate that a Object contains only computed Values + pub(crate) fn validate_computed(&self) -> Result<(), Error> { + self.values().try_for_each(|v| v.validate_computed()) + } } impl Display for Object { diff --git a/core/src/sql/range.rs b/core/src/sql/range.rs index 1dee74d3..7c3159d9 100644 --- a/core/src/sql/range.rs +++ b/core/src/sql/range.rs @@ -115,6 +115,20 @@ impl Range { }, }))) } + + /// Validate that a Range contains only computed Values + pub fn validate_computed(&self) -> Result<(), Error> { + match &self.beg { + Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?, + Bound::Unbounded => {} + } + match &self.end { + Bound::Included(ref v) | Bound::Excluded(ref v) => v.validate_computed()?, + Bound::Unbounded => {} + } + + Ok(()) + } } impl PartialOrd for Range { diff --git a/core/src/sql/value/value.rs b/core/src/sql/value/value.rs index a9e32950..af05eeed 100644 --- a/core/src/sql/value/value.rs +++ b/core/src/sql/value/value.rs @@ -2890,6 +2890,21 @@ impl InfoStructure for Value { } } +impl Value { + /// Validate that a Value is computed or contains only computed Values + pub fn validate_computed(&self) -> Result<(), Error> { + use Value::*; + match self { + None | Null | Bool(_) | Number(_) | Strand(_) | Duration(_) | Datetime(_) | Uuid(_) + | Geometry(_) | Bytes(_) | Thing(_) => Ok(()), + Array(a) => a.validate_computed(), + Object(o) => o.validate_computed(), + Range(r) => r.validate_computed(), + _ => Err(Error::NonComputed), + } + } +} + impl Value { /// Check if we require a writeable transaction pub(crate) fn writeable(&self) -> bool { diff --git a/tests/http_integration.rs b/tests/http_integration.rs index 105baa2b..e824d49e 100644 --- a/tests/http_integration.rs +++ b/tests/http_integration.rs @@ -9,6 +9,7 @@ mod http_integration { use reqwest::Client; use serde_json::json; use surrealdb::headers::{AUTH_DB, AUTH_NS}; + use surrealdb::sql; use test_log::test; use ulid::Ulid; @@ -1727,4 +1728,80 @@ mod http_integration { Ok(()) } + + #[test(tokio::test)] + async fn signup_mal() -> Result<(), Box> { + let (addr, _server) = common::start_server_with_defaults().await.unwrap(); + let rpc_url = &format!("http://{addr}/rpc"); + + let ns = Ulid::new().to_string(); + let db = Ulid::new().to_string(); + + // Prepare HTTP client + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("surreal-ns", ns.parse()?); + headers.insert("surreal-db", db.parse()?); + headers.insert(header::ACCEPT, "application/surrealdb".parse()?); + headers.insert(header::CONTENT_TYPE, "application/surrealdb".parse()?); + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_millis(10)) + .default_headers(headers) + .build()?; + + // Define a record access method + { + let res = client + .post(format!("http://{addr}/sql")) + .basic_auth(USER, Some(PASS)) + .body( + r#" + DEFINE ACCESS user ON DATABASE TYPE RECORD + SIGNUP ( CREATE user SET email = $email, pass = crypto::argon2::generate($pass) ) + SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) ) + DURATION FOR SESSION 12h + ; + "#, + ) + .send() + .await?; + assert!(res.status().is_success(), "body: {}", res.text().await?); + } + + { + let mut request = sql::Object::default(); + request.insert("method".to_string(), "signup".into()); + + let stmt: sql::Statement = { + let mut tmp = sql::statements::CreateStatement::default(); + let rid = sql::thing("foo:42").unwrap(); + let mut tmp_values = sql::Values::default(); + tmp_values.0 = vec![rid.into()]; + tmp.what = tmp_values; + sql::Statement::Create(tmp) + }; + + let mut obj = sql::Object::default(); + obj.insert("email".to_string(), sql::Value::Query(stmt.into())); + obj.insert("pass".to_string(), "foo".into()); + request.insert( + "params".to_string(), + sql::Value::Array(vec![sql::Value::Object(obj)].into()), + ); + + let req: sql::Value = sql::Value::Object(request); + + let req = sql::serde::serialize(&req).unwrap(); + + let res = client.post(rpc_url).body(req).send().await?; + + let body = res.text().await?; + + assert!( + body.contains("Found a non-computed value where they are not allowed"), + "{body:?}" + ); + } + + Ok(()) + } }