diff --git a/lib/src/dbs/executor.rs b/lib/src/dbs/executor.rs index 09890587..5a08b0eb 100644 --- a/lib/src/dbs/executor.rs +++ b/lib/src/dbs/executor.rs @@ -7,6 +7,8 @@ use crate::dbs::Transaction; use crate::dbs::LOG; use crate::err::Error; use crate::kvs::Datastore; +use crate::sql::paths::DB; +use crate::sql::paths::NS; use crate::sql::query::Query; use crate::sql::statement::Statement; use crate::sql::value::Value; @@ -112,6 +114,20 @@ impl<'a> Executor<'a> { } } + async fn set_ns(&self, ctx: &mut Context<'_>, opt: &mut Options, ns: &str) { + let mut session = ctx.value("session").unwrap_or(&Value::None).clone(); + session.put(NS.as_ref(), ns.to_owned().into()); + ctx.add_value(String::from("session"), session); + opt.ns = Some(Arc::new(ns.to_owned())); + } + + async fn set_db(&self, ctx: &mut Context<'_>, opt: &mut Options, db: &str) { + let mut session = ctx.value("session").unwrap_or(&Value::None).clone(); + session.put(DB.as_ref(), db.to_owned().into()); + ctx.add_value(String::from("session"), session); + opt.db = Some(Arc::new(db.to_owned())); + } + pub async fn execute( &mut self, mut ctx: Context<'_>, @@ -178,9 +194,9 @@ impl<'a> Executor<'a> { Statement::Use(stm) => { if let Some(ref ns) = stm.ns { match &*opt.auth { - Auth::No => opt.ns = Some(Arc::new(ns.to_owned())), - Auth::Kv => opt.ns = Some(Arc::new(ns.to_owned())), - Auth::Ns(v) if v == ns => opt.ns = Some(Arc::new(ns.to_owned())), + Auth::No => self.set_ns(&mut ctx, &mut opt, ns).await, + Auth::Kv => self.set_ns(&mut ctx, &mut opt, ns).await, + Auth::Ns(v) if v == ns => self.set_ns(&mut ctx, &mut opt, ns).await, _ => { opt.ns = None; return Err(Error::NsNotAllowed { @@ -191,10 +207,10 @@ impl<'a> Executor<'a> { } if let Some(ref db) = stm.db { match &*opt.auth { - Auth::No => opt.db = Some(Arc::new(db.to_owned())), - Auth::Kv => opt.db = Some(Arc::new(db.to_owned())), - Auth::Ns(_) => opt.db = Some(Arc::new(db.to_owned())), - Auth::Db(_, v) if v == db => opt.db = Some(Arc::new(db.to_owned())), + Auth::No => self.set_db(&mut ctx, &mut opt, db).await, + Auth::Kv => self.set_db(&mut ctx, &mut opt, db).await, + Auth::Ns(_) => self.set_db(&mut ctx, &mut opt, db).await, + Auth::Db(_, v) if v == db => self.set_db(&mut ctx, &mut opt, db).await, _ => { opt.db = None; return Err(Error::DbNotAllowed { diff --git a/lib/src/sql/value/mod.rs b/lib/src/sql/value/mod.rs index a29c1b91..a0fa56cb 100644 --- a/lib/src/sql/value/mod.rs +++ b/lib/src/sql/value/mod.rs @@ -22,6 +22,7 @@ mod merge; mod object; mod patch; mod pick; +mod put; mod replace; mod retable; mod set; diff --git a/lib/src/sql/value/put.rs b/lib/src/sql/value/put.rs new file mode 100644 index 00000000..52dd400a --- /dev/null +++ b/lib/src/sql/value/put.rs @@ -0,0 +1,195 @@ +use crate::sql::part::Next; +use crate::sql::part::Part; +use crate::sql::value::Value; + +impl Value { + pub fn put(&mut self, path: &[Part], val: Value) { + match path.first() { + // Get the current path part + Some(p) => match self { + // Current path part is an object + Value::Object(v) => match p { + Part::Thing(t) => match v.get_mut(t.to_raw().as_str()) { + Some(v) if v.is_some() => v.put(path.next(), val), + _ => { + let mut obj = Value::base(); + obj.put(path.next(), val); + v.insert(t.to_raw(), obj); + } + }, + Part::Graph(g) => match v.get_mut(g.to_raw().as_str()) { + Some(v) if v.is_some() => v.put(path.next(), val), + _ => { + let mut obj = Value::base(); + obj.put(path.next(), val); + v.insert(g.to_raw(), obj); + } + }, + Part::Field(f) => match v.get_mut(f.to_raw().as_str()) { + Some(v) if v.is_some() => v.put(path.next(), val), + _ => { + let mut obj = Value::base(); + obj.put(path.next(), val); + v.insert(f.to_raw(), obj); + } + }, + _ => (), + }, + // Current path part is an array + Value::Array(v) => match p { + Part::All => { + let path = path.next(); + v.iter_mut().for_each(|v| v.put(path, val.clone())); + } + Part::First => match v.first_mut() { + Some(v) => v.put(path.next(), val), + None => (), + }, + Part::Last => match v.last_mut() { + Some(v) => v.put(path.next(), val), + None => (), + }, + Part::Index(i) => match v.get_mut(i.to_usize()) { + Some(v) => v.put(path.next(), val), + None => (), + }, + _ => { + v.iter_mut().for_each(|v| v.put(path, val.clone())); + } + }, + // Current path part is empty + Value::Null => { + *self = Value::base(); + self.put(path, val) + } + // Current path part is empty + Value::None => { + *self = Value::base(); + self.put(path, val) + } + // Ignore everything else + _ => (), + }, + // No more parts so put the value + None => { + *self = val; + } + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::sql::idiom::Idiom; + use crate::sql::test::Parse; + + #[tokio::test] + async fn put_none() { + let idi = Idiom::default(); + let mut val = Value::parse("{ test: { other: null, something: 123 } }"); + let res = Value::parse("999"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_empty() { + let idi = Idiom::parse("test"); + let mut val = Value::None; + let res = Value::parse("{ test: 999 }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_blank() { + let idi = Idiom::parse("test.something"); + let mut val = Value::None; + let res = Value::parse("{ test: { something: 999 } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_reput() { + let idi = Idiom::parse("test"); + let mut val = Value::parse("{ test: { other: null, something: 123 } }"); + let res = Value::parse("{ test: 999 }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_basic() { + let idi = Idiom::parse("test.something"); + let mut val = Value::parse("{ test: { other: null, something: 123 } }"); + let res = Value::parse("{ test: { other: null, something: 999 } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_allow() { + let idi = Idiom::parse("test.something.allow"); + let mut val = Value::parse("{ test: { other: null } }"); + let res = Value::parse("{ test: { other: null, something: { allow: 999 } } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_wrong() { + let idi = Idiom::parse("test.something.wrong"); + let mut val = Value::parse("{ test: { other: null, something: 123 } }"); + let res = Value::parse("{ test: { other: null, something: 123 } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_other() { + let idi = Idiom::parse("test.other.something"); + let mut val = Value::parse("{ test: { other: null, something: 123 } }"); + let res = Value::parse("{ test: { other: { something: 999 }, something: 123 } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_array() { + let idi = Idiom::parse("test.something[1]"); + let mut val = Value::parse("{ test: { something: [123, 456, 789] } }"); + let res = Value::parse("{ test: { something: [123, 999, 789] } }"); + val.put(&idi, Value::from(999)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_array_field() { + let idi = Idiom::parse("test.something[1].age"); + let mut val = Value::parse("{ test: { something: [{ age: 34 }, { age: 36 }] } }"); + let res = Value::parse("{ test: { something: [{ age: 34 }, { age: 21 }] } }"); + val.put(&idi, Value::from(21)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_array_fields() { + let idi = Idiom::parse("test.something[*].age"); + let mut val = Value::parse("{ test: { something: [{ age: 34 }, { age: 36 }] } }"); + let res = Value::parse("{ test: { something: [{ age: 21 }, { age: 21 }] } }"); + val.put(&idi, Value::from(21)); + assert_eq!(res, val); + } + + #[tokio::test] + async fn put_array_fields_flat() { + let idi = Idiom::parse("test.something.age"); + let mut val = Value::parse("{ test: { something: [{ age: 34 }, { age: 36 }] } }"); + let res = Value::parse("{ test: { something: [{ age: 21 }, { age: 21 }] } }"); + val.put(&idi, Value::from(21)); + assert_eq!(res, val); + } +} diff --git a/lib/tests/yuse.rs b/lib/tests/yuse.rs new file mode 100644 index 00000000..24422354 --- /dev/null +++ b/lib/tests/yuse.rs @@ -0,0 +1,84 @@ +mod parse; +use parse::Parse; +use surrealdb::sql::Value; +use surrealdb::Datastore; +use surrealdb::Error; +use surrealdb::Session; + +#[tokio::test] +async fn use_statement_set_ns() -> Result<(), Error> { + let sql = " + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + USE NS my_ns; + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + "; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 3); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['test', 'test', 'test', 'test']"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(tmp.is_ok()); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['my_ns', 'my_ns', 'test', 'test']"); + assert_eq!(tmp, val); + // + Ok(()) +} + +#[tokio::test] +async fn use_statement_set_db() -> Result<(), Error> { + let sql = " + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + USE DB my_db; + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + "; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 3); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['test', 'test', 'test', 'test']"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(tmp.is_ok()); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['test', 'test', 'my_db', 'my_db']"); + assert_eq!(tmp, val); + // + Ok(()) +} + +#[tokio::test] +async fn use_statement_set_both() -> Result<(), Error> { + let sql = " + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + USE NS my_ns DB my_db; + SELECT * FROM $session.ns, session::ns(), $session.db, session::db(); + "; + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = &mut dbs.execute(&sql, &ses, None, false).await?; + assert_eq!(res.len(), 3); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['test', 'test', 'test', 'test']"); + assert_eq!(tmp, val); + // + let tmp = res.remove(0).result; + assert!(tmp.is_ok()); + // + let tmp = res.remove(0).result?; + let val = Value::parse("['my_ns', 'my_ns', 'my_db', 'my_db']"); + assert_eq!(tmp, val); + // + Ok(()) +}