diff --git a/Cargo.toml b/Cargo.toml index 9529d23e..7ac7d4d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,7 +83,7 @@ surrealdb = { version = "1", path = "lib", features = [ "protocol-http", "protocol-ws", "rustls", - "sql2" + "sql2", ] } tempfile = "3.8.1" thiserror = "1.0.50" diff --git a/core/src/lib.rs b/core/src/lib.rs index ecb002c9..67d6b5bc 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -40,6 +40,8 @@ pub mod obs; #[doc(hidden)] pub mod options; #[doc(hidden)] +pub mod rpc; +#[doc(hidden)] pub mod syn; #[doc(hidden)] diff --git a/core/src/rpc/args.rs b/core/src/rpc/args.rs new file mode 100644 index 00000000..f446a5bf --- /dev/null +++ b/core/src/rpc/args.rs @@ -0,0 +1,62 @@ +use crate::sql::Array; +use crate::sql::Value; + +use super::rpc_error::RpcError; + +pub trait Take { + fn needs_one(self) -> Result; + fn needs_two(self) -> Result<(Value, Value), RpcError>; + fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>; + fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>; +} + +impl Take for Array { + /// Convert the array to one argument + fn needs_one(self) -> Result { + if self.len() != 1 { + return Err(RpcError::InvalidParams); + } + let mut x = self.into_iter(); + match x.next() { + Some(a) => Ok(a), + None => Ok(Value::None), + } + } + /// Convert the array to two arguments + fn needs_two(self) -> Result<(Value, Value), RpcError> { + if self.len() != 2 { + return Err(RpcError::InvalidParams); + } + let mut x = self.into_iter(); + match (x.next(), x.next()) { + (Some(a), Some(b)) => Ok((a, b)), + (Some(a), None) => Ok((a, Value::None)), + (_, _) => Ok((Value::None, Value::None)), + } + } + /// Convert the array to two arguments + fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> { + if self.is_empty() && self.len() > 2 { + return Err(RpcError::InvalidParams); + } + let mut x = self.into_iter(); + match (x.next(), x.next()) { + (Some(a), Some(b)) => Ok((a, b)), + (Some(a), None) => Ok((a, Value::None)), + (_, _) => Ok((Value::None, Value::None)), + } + } + /// Convert the array to three arguments + fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> { + if self.is_empty() && self.len() > 3 { + return Err(RpcError::InvalidParams); + } + let mut x = self.into_iter(); + match (x.next(), x.next(), x.next()) { + (Some(a), Some(b), Some(c)) => Ok((a, b, c)), + (Some(a), Some(b), None) => Ok((a, b, Value::None)), + (Some(a), None, None) => Ok((a, Value::None, Value::None)), + (_, _, _) => Ok((Value::None, Value::None, Value::None)), + } + } +} diff --git a/core/src/rpc/basic_context.rs b/core/src/rpc/basic_context.rs new file mode 100644 index 00000000..f0e50c4b --- /dev/null +++ b/core/src/rpc/basic_context.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; + +use crate::{ + dbs::Session, + kvs::Datastore, + rpc::RpcContext, + sql::{Array, Value}, +}; + +use super::{args::Take, Data, RpcError}; + +pub struct BasicRpcContext<'a> { + pub kvs: &'a Datastore, + pub session: Session, + pub vars: BTreeMap, + pub version_string: String, +} + +impl<'a> BasicRpcContext<'a> { + pub fn new( + kvs: &'a Datastore, + session: Session, + vars: BTreeMap, + version_string: String, + ) -> Self { + Self { + kvs, + session, + vars, + version_string, + } + } +} + +impl RpcContext for BasicRpcContext<'_> { + fn kvs(&self) -> &Datastore { + self.kvs + } + + fn session(&self) -> &Session { + &self.session + } + + fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + fn vars(&self) -> &BTreeMap { + &self.vars + } + + fn vars_mut(&mut self) -> &mut BTreeMap { + &mut self.vars + } + + fn version_data(&self) -> impl Into { + Value::Strand(self.version_string.clone().into()) + } + + // reimplimentaions: + + async fn signup(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + crate::iam::signup::signup(self.kvs, &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + + out + } + + async fn signin(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + crate::iam::signin::signin(self.kvs, &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + out + } + + async fn authenticate(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Strand(token)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + crate::iam::verify::token(self.kvs, &mut self.session, &token.0).await?; + Ok(Value::None) + } +} diff --git a/core/src/rpc/method.rs b/core/src/rpc/method.rs new file mode 100644 index 00000000..a7e7f41c --- /dev/null +++ b/core/src/rpc/method.rs @@ -0,0 +1,111 @@ +#[non_exhaustive] +pub enum Method { + Unknown, + Ping, + Info, + Use, + Signup, + Signin, + Invalidate, + Authenticate, + Kill, + Live, + Set, + Unset, + Select, + Insert, + Create, + Update, + Merge, + Patch, + Delete, + Version, + Query, + Relate, +} + +impl Method { + pub fn parse(s: S) -> Self + where + S: AsRef, + { + match s.as_ref().to_lowercase().as_str() { + "ping" => Self::Ping, + "info" => Self::Info, + "use" => Self::Use, + "signup" => Self::Signup, + "signin" => Self::Signin, + "invalidate" => Self::Invalidate, + "authenticate" => Self::Authenticate, + "kill" => Self::Kill, + "live" => Self::Live, + "let" | "set" => Self::Set, + "unset" => Self::Unset, + "select" => Self::Select, + "insert" => Self::Insert, + "create" => Self::Create, + "update" => Self::Update, + "merge" => Self::Merge, + "patch" => Self::Patch, + "delete" => Self::Delete, + "version" => Self::Version, + "query" => Self::Query, + "relate" => Self::Relate, + _ => Self::Unknown, + } + } +} + +impl Method { + pub fn to_str(&self) -> &str { + match self { + Self::Unknown => "unknown", + Self::Ping => "ping", + Self::Info => "info", + Self::Use => "use", + Self::Signup => "signup", + Self::Signin => "signin", + Self::Invalidate => "invalidate", + Self::Authenticate => "authenticate", + Self::Kill => "kill", + Self::Live => "live", + Self::Set => "set", + Self::Unset => "unset", + Self::Select => "select", + Self::Insert => "insert", + Self::Create => "create", + Self::Update => "update", + Self::Merge => "merge", + Self::Patch => "patch", + Self::Delete => "delete", + Self::Version => "version", + Self::Query => "query", + Self::Relate => "relate", + } + } +} + +impl Method { + pub fn is_valid(&self) -> bool { + !matches!(self, Self::Unknown) + } + + pub fn needs_mut(&self) -> bool { + !self.can_be_immut() + } + + // should be the same as execute_immut + pub fn can_be_immut(&self) -> bool { + matches!( + self, + Method::Ping + | Method::Info | Method::Select + | Method::Insert | Method::Create + | Method::Update | Method::Merge + | Method::Patch | Method::Delete + | Method::Version + | Method::Query | Method::Relate + | Method::Unknown + ) + } +} diff --git a/core/src/rpc/mod.rs b/core/src/rpc/mod.rs new file mode 100644 index 00000000..2b6a8875 --- /dev/null +++ b/core/src/rpc/mod.rs @@ -0,0 +1,11 @@ +pub mod args; +pub mod basic_context; +pub mod method; +mod response; +pub mod rpc_context; +mod rpc_error; + +pub use basic_context::BasicRpcContext; +pub use response::Data; +pub use rpc_context::RpcContext; +pub use rpc_error::RpcError; diff --git a/core/src/rpc/response.rs b/core/src/rpc/response.rs new file mode 100644 index 00000000..c10dddd7 --- /dev/null +++ b/core/src/rpc/response.rs @@ -0,0 +1,55 @@ +use crate::dbs; +use crate::dbs::Notification; +use crate::sql; +use crate::sql::Value; +use revision::revisioned; +use serde::Serialize; + +/// The data returned by the database +// The variants here should be in exactly the same order as `crate::engine::remote::ws::Data` +// In future, they will possibly be merged to avoid having to keep them in sync. +#[derive(Debug, Serialize)] +#[revisioned(revision = 1)] +pub enum Data { + /// Generally methods return a `sql::Value` + Other(Value), + /// The query methods, `query` and `query_with` return a `Vec` of responses + Query(Vec), + /// Live queries return a notification + Live(Notification), + // Add new variants here +} + +impl From for Data { + fn from(v: Value) -> Self { + Data::Other(v) + } +} + +impl From for Data { + fn from(v: String) -> Self { + Data::Other(Value::from(v)) + } +} + +impl From for Data { + fn from(n: Notification) -> Self { + Data::Live(n) + } +} + +impl From> for Data { + fn from(v: Vec) -> Self { + Data::Query(v) + } +} + +impl From for Value { + fn from(val: Data) -> Self { + match val { + Data::Query(v) => sql::to_value(v).unwrap(), + Data::Live(v) => sql::to_value(v).unwrap(), + Data::Other(v) => v, + } + } +} diff --git a/core/src/rpc/rpc_context.rs b/core/src/rpc/rpc_context.rs new file mode 100644 index 00000000..86298577 --- /dev/null +++ b/core/src/rpc/rpc_context.rs @@ -0,0 +1,536 @@ +use std::collections::BTreeMap; + +use crate::{ + dbs::{QueryType, Response, Session}, + kvs::Datastore, + rpc::args::Take, + sql::{Array, Value}, +}; +use uuid::Uuid; + +use super::{method::Method, response::Data, rpc_error::RpcError}; + +macro_rules! mrg { + ($($m:expr, $x:expr)+) => {{ + $($m.extend($x.iter().map(|(k, v)| (k.clone(), v.clone())));)+ + $($m)+ + }}; +} + +#[allow(async_fn_in_trait)] +pub trait RpcContext { + fn kvs(&self) -> &Datastore; + fn session(&self) -> &Session; + fn session_mut(&mut self) -> &mut Session; + fn vars(&self) -> &BTreeMap; + fn vars_mut(&mut self) -> &mut BTreeMap; + fn version_data(&self) -> impl Into; + + const LQ_SUPPORT: bool = false; + fn handle_live(&self, _lqid: &Uuid) -> impl std::future::Future + Send { + async { unreachable!() } + } + fn handle_kill(&self, _lqid: &Uuid) -> impl std::future::Future + Send { + async { unreachable!() } + } + + async fn execute(&mut self, method: Method, params: Array) -> Result { + match method { + Method::Ping => Ok(Value::None.into()), + Method::Info => self.info().await.map(Into::into).map_err(Into::into), + Method::Use => self.yuse(params).await.map(Into::into).map_err(Into::into), + Method::Signup => self.signup(params).await.map(Into::into).map_err(Into::into), + Method::Signin => self.signin(params).await.map(Into::into).map_err(Into::into), + Method::Invalidate => self.invalidate().await.map(Into::into).map_err(Into::into), + Method::Authenticate => { + self.authenticate(params).await.map(Into::into).map_err(Into::into) + } + Method::Kill => self.kill(params).await.map(Into::into).map_err(Into::into), + Method::Live => self.live(params).await.map(Into::into).map_err(Into::into), + Method::Set => self.set(params).await.map(Into::into).map_err(Into::into), + Method::Unset => self.unset(params).await.map(Into::into).map_err(Into::into), + Method::Select => self.select(params).await.map(Into::into).map_err(Into::into), + Method::Insert => self.insert(params).await.map(Into::into).map_err(Into::into), + Method::Create => self.create(params).await.map(Into::into).map_err(Into::into), + Method::Update => self.update(params).await.map(Into::into).map_err(Into::into), + Method::Merge => self.merge(params).await.map(Into::into).map_err(Into::into), + Method::Patch => self.patch(params).await.map(Into::into).map_err(Into::into), + Method::Delete => self.delete(params).await.map(Into::into).map_err(Into::into), + Method::Version => self.version(params).await.map(Into::into).map_err(Into::into), + Method::Query => self.query(params).await.map(Into::into).map_err(Into::into), + Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), + Method::Unknown => Err(RpcError::MethodNotFound), + } + } + + async fn execute_immut(&self, method: Method, params: Array) -> Result { + match method { + Method::Ping => Ok(Value::None.into()), + Method::Info => self.info().await.map(Into::into).map_err(Into::into), + Method::Select => self.select(params).await.map(Into::into).map_err(Into::into), + Method::Insert => self.insert(params).await.map(Into::into).map_err(Into::into), + Method::Create => self.create(params).await.map(Into::into).map_err(Into::into), + Method::Update => self.update(params).await.map(Into::into).map_err(Into::into), + Method::Merge => self.merge(params).await.map(Into::into).map_err(Into::into), + Method::Patch => self.patch(params).await.map(Into::into).map_err(Into::into), + Method::Delete => self.delete(params).await.map(Into::into).map_err(Into::into), + Method::Version => self.version(params).await.map(Into::into).map_err(Into::into), + Method::Query => self.query(params).await.map(Into::into).map_err(Into::into), + Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), + Method::Unknown => Err(RpcError::MethodNotFound), + _ => Err(RpcError::MethodNotFound), + } + } + + // ------------------------------ + // Methods for authentication + // ------------------------------ + + async fn yuse(&mut self, params: Array) -> Result, RpcError> { + let (ns, db) = params.needs_two()?; + if let Value::Strand(ns) = ns { + self.session_mut().ns = Some(ns.0); + } + if let Value::Strand(db) = db { + self.session_mut().db = Some(db.0); + } + Ok(Value::None) + } + + async fn signup(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let mut tmp_session = self.session().clone(); + let out: Result = + crate::iam::signup::signup(self.kvs(), &mut tmp_session, v) + .await + .map(Into::into) + .map_err(Into::into); + *self.session_mut() = tmp_session; + + out + } + + async fn signin(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let mut tmp_session = self.session().clone(); + let out: Result = + crate::iam::signin::signin(self.kvs(), &mut tmp_session, v) + .await + .map(Into::into) + .map_err(Into::into); + *self.session_mut() = tmp_session; + out + } + + async fn invalidate(&mut self) -> Result, RpcError> { + crate::iam::clear::clear(self.session_mut())?; + Ok(Value::None) + } + + async fn authenticate(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Strand(token)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let mut tmp_session = self.session().clone(); + crate::iam::verify::token(self.kvs(), &mut tmp_session, &token.0).await?; + *self.session_mut() = tmp_session; + Ok(Value::None) + } + + // ------------------------------ + // Methods for identification + // ------------------------------ + + async fn info(&self) -> Result, RpcError> { + // Specify the SQL query string + let sql = "SELECT * FROM $auth"; + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), None).await?; + // Extract the first value from the result + let res = res.remove(0).result?.first(); + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for setting variables + // ------------------------------ + + async fn set(&mut self, params: Array) -> Result, RpcError> { + let Ok((Value::Strand(key), val)) = params.needs_one_or_two() else { + return Err(RpcError::InvalidParams); + }; + // Specify the query parameters + let var = Some(map! { + key.0.clone() => Value::None, + => &self.vars() + }); + // Compute the specified parameter + match self.kvs().compute(val, self.session(), var).await? { + // Remove the variable if undefined + Value::None => self.vars_mut().remove(&key.0), + // Store the variable if defined + v => self.vars_mut().insert(key.0, v), + }; + Ok(Value::Null) + } + + async fn unset(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Strand(key)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + self.vars_mut().remove(&key.0); + Ok(Value::Null) + } + + // ------------------------------ + // Methods for live queries + // ------------------------------ + + async fn kill(&mut self, params: Array) -> Result, RpcError> { + let id = params.needs_one()?; + // Specify the SQL query string + let sql = "KILL $id"; + // Specify the query parameters + let var = map! { + String::from("id") => id, + => &self.vars() + }; + // Execute the query on the database + // let mut res = self.query_with(Value::from(sql), Object::from(var)).await?; + let mut res = self.query_inner(Value::from(sql), Some(var)).await?; + // Extract the first query result + let response = res.remove(0); + response.result.map_err(Into::into) + } + + async fn live(&mut self, params: Array) -> Result, RpcError> { + let (tb, diff) = params.needs_one_or_two()?; + // Specify the SQL query string + let sql = match diff.is_true() { + true => "LIVE SELECT DIFF FROM $tb", + false => "LIVE SELECT * FROM $tb", + }; + // Specify the query parameters + let var = map! { + String::from("tb") => tb.could_be_table(), + => &self.vars() + }; + // Execute the query on the database + let mut res = self.query_inner(Value::from(sql), Some(var)).await?; + // Extract the first query result + let response = res.remove(0); + response.result.map_err(Into::into) + } + + // ------------------------------ + // Methods for selecting + // ------------------------------ + + async fn select(&self, params: Array) -> Result, RpcError> { + let Ok(what) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = "SELECT * FROM $what"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for inserting + // ------------------------------ + + async fn insert(&self, params: Array) -> Result, RpcError> { + let Ok((what, data)) = params.needs_two() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = "INSERT INTO $what $data RETURN AFTER"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for creating + // ------------------------------ + + async fn create(&self, params: Array) -> Result, RpcError> { + let Ok((what, data)) = params.needs_one_or_two() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = if data.is_none_or_null() { + "CREATE $what RETURN AFTER" + } else { + "CREATE $what CONTENT $data RETURN AFTER" + }; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for updating + // ------------------------------ + + async fn update(&self, params: Array) -> Result, RpcError> { + let Ok((what, data)) = params.needs_one_or_two() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = if data.is_none_or_null() { + "UPDATE $what RETURN AFTER" + } else { + "UPDATE $what CONTENT $data RETURN AFTER" + }; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for merging + // ------------------------------ + + async fn merge(&self, params: Array) -> Result, RpcError> { + let Ok((what, data)) = params.needs_one_or_two() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = if data.is_none_or_null() { + "UPDATE $what RETURN AFTER" + } else { + "UPDATE $what MERGE $data RETURN AFTER" + }; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for patching + // ------------------------------ + + async fn patch(&self, params: Array) -> Result, RpcError> { + let Ok((what, data, diff)) = params.needs_one_two_or_three() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = match diff.is_true() { + true => "UPDATE $what PATCH $data RETURN DIFF", + false => "UPDATE $what PATCH $data RETURN AFTER", + }; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for deleting + // ------------------------------ + + async fn delete(&self, params: Array) -> Result, RpcError> { + let Ok(what) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + // Return a single result? + let one = what.is_thing(); + // Specify the SQL query string + let sql = "DELETE $what RETURN BEFORE"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + => &self.vars() + }); + // Execute the query on the database + let mut res = self.kvs().execute(sql, self.session(), var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for getting info + // ------------------------------ + + async fn version(&self, params: Array) -> Result, RpcError> { + match params.len() { + 0 => Ok(self.version_data()), + _ => Err(RpcError::InvalidParams), + } + } + + // ------------------------------ + // Methods for querying + // ------------------------------ + + async fn query(&self, params: Array) -> Result, RpcError> { + let Ok((query, o)) = params.needs_one_or_two() else { + return Err(RpcError::InvalidParams); + }; + if !(query.is_query() || query.is_strand()) { + return Err(RpcError::InvalidParams); + } + + let o = match o { + Value::Object(v) => Some(v), + Value::None | Value::Null => None, + _ => return Err(RpcError::InvalidParams), + }; + + // Specify the query parameters + let vars = match o { + Some(mut v) => Some(mrg! {v.0, &self.vars()}), + None => Some(self.vars().clone()), + }; + self.query_inner(query, vars).await + } + + // ------------------------------ + // Methods for querying + // ------------------------------ + + async fn relate(&self, _params: Array) -> Result, RpcError> { + let out: Result = Err(RpcError::MethodNotFound); + out + } + + // ------------------------------ + // Private methods + // ------------------------------ + + async fn query_inner( + &self, + query: Value, + vars: Option>, + ) -> Result, RpcError> { + // If no live query handler force realtime off + if !Self::LQ_SUPPORT && self.session().rt { + return Err(RpcError::BadLQConfig); + } + // Execute the query on the database + let res = match query { + Value::Query(sql) => self.kvs().process(sql, self.session(), vars).await?, + Value::Strand(sql) => self.kvs().execute(&sql, self.session(), vars).await?, + _ => unreachable!(), + }; + + // Post-process hooks for web layer + for response in &res { + // This error should be unreachable because we shouldn't proceed if there's no handler + self.handle_live_query_results(response).await; + } + // Return the result to the client + Ok(res) + } + + async fn handle_live_query_results(&self, res: &Response) { + match &res.query_type { + QueryType::Live => { + if let Ok(Value::Uuid(lqid)) = &res.result { + self.handle_live(&lqid.0).await; + } + } + QueryType::Kill => { + if let Ok(Value::Uuid(lqid)) = &res.result { + self.handle_kill(&lqid.0).await; + } + } + _ => {} + } + } +} diff --git a/core/src/rpc/rpc_error.rs b/core/src/rpc/rpc_error.rs new file mode 100644 index 00000000..9c24e6a5 --- /dev/null +++ b/core/src/rpc/rpc_error.rs @@ -0,0 +1,52 @@ +use thiserror::Error; + +use crate::err; + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum RpcError { + #[error("Parse error")] + ParseError, + #[error("Invalid request")] + InvalidRequest, + #[error("Method not found")] + MethodNotFound, + #[error("Invalid params")] + InvalidParams, + #[error("There was a problem with the database: {0}")] + InternalError(err::Error), + #[error("Live Query was made, but is not supported")] + LqNotSuported, + #[error("RT is enabled for the session, but LQ is not supported with the context")] + BadLQConfig, + #[error("Error: {0}")] + Thrown(String), +} + +impl From for RpcError { + fn from(e: err::Error) -> Self { + use err::Error; + match e { + Error::RealtimeDisabled => RpcError::LqNotSuported, + _ => RpcError::InternalError(e), + } + } +} + +impl From<&str> for RpcError { + fn from(e: &str) -> Self { + RpcError::Thrown(e.to_string()) + } +} + +impl From for err::Error { + fn from(value: RpcError) -> Self { + use err::Error; + match value { + RpcError::InternalError(e) => e, + RpcError::Thrown(e) => Error::Thrown(e), + + _ => Error::Thrown(value.to_string()), + } + } +} diff --git a/src/err/mod.rs b/src/err/mod.rs index c834ee75..f65191b2 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -135,6 +135,17 @@ impl From for Error { } } +impl From for Error { + fn from(value: surrealdb::rpc::RpcError) -> Self { + use surrealdb::rpc::RpcError; + match value { + RpcError::InternalError(e) => Error::Db(surrealdb::Error::Db(e)), + RpcError::Thrown(e) => Error::Other(e), + _ => Error::Other(value.to_string()), + } + } +} + impl Serialize for Error { fn serialize(&self, serializer: S) -> Result where diff --git a/src/mac/mod.rs b/src/mac/mod.rs index 877f3b03..de1257c0 100644 --- a/src/mac/mod.rs +++ b/src/mac/mod.rs @@ -6,10 +6,3 @@ macro_rules! map { m }}; } - -macro_rules! mrg { - ($($m:expr, $x:expr)+) => {{ - $($m.extend($x.iter().map(|(k, v)| (k.clone(), v.clone())));)+ - $($m)+ - }}; -} diff --git a/src/net/headers/content_type.rs b/src/net/headers/content_type.rs new file mode 100644 index 00000000..9bff7b57 --- /dev/null +++ b/src/net/headers/content_type.rs @@ -0,0 +1,70 @@ +use axum::headers; +use axum::headers::Header; +use http::HeaderName; +use http::HeaderValue; + +/// Typed header implementation for the `ContentType` header. +pub enum ContentType { + TextPlain, + ApplicationJson, + ApplicationCbor, + ApplicationPack, + ApplicationOctetStream, + Surrealdb, +} + +impl std::fmt::Display for ContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentType::TextPlain => write!(f, "text/plain"), + ContentType::ApplicationJson => write!(f, "application/json"), + ContentType::ApplicationCbor => write!(f, "application/cbor"), + ContentType::ApplicationPack => write!(f, "application/pack"), + ContentType::ApplicationOctetStream => write!(f, "application/octet-stream"), + ContentType::Surrealdb => write!(f, "application/surrealdb"), + } + } +} + +impl Header for ContentType { + fn name() -> &'static HeaderName { + &http::header::CONTENT_TYPE + } + + fn decode<'i, I>(values: &mut I) -> Result + where + I: Iterator, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + + match value.to_str().map_err(|_| headers::Error::invalid())? { + "text/plain" => Ok(ContentType::TextPlain), + "application/json" => Ok(ContentType::ApplicationJson), + "application/cbor" => Ok(ContentType::ApplicationCbor), + "application/pack" => Ok(ContentType::ApplicationPack), + "application/octet-stream" => Ok(ContentType::ApplicationOctetStream), + "application/surrealdb" => Ok(ContentType::Surrealdb), + // TODO: Support more (all?) mime-types + _ => Err(headers::Error::invalid()), + } + } + + fn encode(&self, values: &mut E) + where + E: Extend, + { + values.extend(std::iter::once(self.into())); + } +} + +impl From for HeaderValue { + fn from(value: ContentType) -> Self { + HeaderValue::from(&value) + } +} + +impl From<&ContentType> for HeaderValue { + fn from(value: &ContentType) -> Self { + HeaderValue::from_str(value.to_string().as_str()).unwrap() + } +} diff --git a/src/net/headers/mod.rs b/src/net/headers/mod.rs index 5c545402..d7728d7e 100644 --- a/src/net/headers/mod.rs +++ b/src/net/headers/mod.rs @@ -14,6 +14,7 @@ use tower_http::set_header::SetResponseHeaderLayer; mod accept; mod auth_db; mod auth_ns; +mod content_type; mod db; mod id; mod ns; @@ -21,6 +22,7 @@ mod ns; pub use accept::Accept; pub use auth_db::SurrealAuthDatabase; pub use auth_ns::SurrealAuthNamespace; +pub use content_type::ContentType; pub use db::{SurrealDatabase, SurrealDatabaseLegacy}; pub use id::{SurrealId, SurrealIdLegacy}; pub use ns::{SurrealNamespace, SurrealNamespaceLegacy}; diff --git a/src/net/key.rs b/src/net/key.rs index 918416c0..0179c1fa 100644 --- a/src/net/key.rs +++ b/src/net/key.rs @@ -97,6 +97,7 @@ async fn select_all( Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), // Internal serialization + // TODO: remove format in 2.0.0 Some(Accept::Surrealdb) => Ok(output::full(&res)), // An incorrect content-type was requested _ => Err(Error::InvalidType), diff --git a/src/net/mod.rs b/src/net/mod.rs index b9aba369..14834960 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,12 +1,12 @@ mod auth; pub mod client_ip; mod export; -mod headers; +pub(crate) mod headers; mod health; mod import; mod input; mod key; -mod output; +pub(crate) mod output; mod params; mod rpc; mod signals; diff --git a/src/net/rpc.rs b/src/net/rpc.rs index 06651586..dc74325f 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -1,30 +1,47 @@ +use std::collections::BTreeMap; +use std::ops::Deref; + use crate::cnf; +use crate::dbs::DB; use crate::err::Error; use crate::rpc::connection::Connection; use crate::rpc::format::Format; use crate::rpc::format::PROTOCOLS; +use crate::rpc::post_context::PostRpcContext; +use crate::rpc::response::IntoRpcResponse; use crate::rpc::WEBSOCKETS; use axum::routing::get; +use axum::routing::post; +use axum::TypedHeader; use axum::{ extract::ws::{WebSocket, WebSocketUpgrade}, response::IntoResponse, Extension, Router, }; +use bytes::Bytes; use http::HeaderValue; use http_body::Body as HttpBody; use surrealdb::dbs::Session; +use surrealdb::rpc::method::Method; use tower_http::request_id::RequestId; use uuid::Uuid; +use super::headers::Accept; +use super::headers::ContentType; + +use surrealdb::rpc::rpc_context::RpcContext; + pub(super) fn router() -> Router where B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: std::error::Error + Send + Sync + 'static, S: Clone + Send + Sync + 'static, { - Router::new().route("/rpc", get(handler)) + Router::new().route("/rpc", get(get_handler)).route("/rpc", post(post_handler)) } -async fn handler( +async fn get_handler( ws: WebSocketUpgrade, Extension(id): Extension, Extension(sess): Extension, @@ -70,9 +87,37 @@ async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) { // No protocol format was specified _ => Format::None, }; - // + // Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here // Create a new connection instance let rpc = Connection::new(id, sess, format); // Serve the socket connection requests Connection::serve(rpc, ws).await; } + +async fn post_handler( + Extension(session): Extension, + output: Option>, + content_type: TypedHeader, + body: Bytes, +) -> Result { + let fmt: Format = content_type.deref().into(); + let out_fmt: Option = output.as_deref().map(Into::into); + if let Some(out_fmt) = out_fmt { + if fmt != out_fmt { + return Err(Error::InvalidType); + } + } + if fmt == Format::Unsupported || fmt == Format::None { + return Err(Error::InvalidType); + } + + let mut rpc_ctx = PostRpcContext::new(DB.get().unwrap(), session, BTreeMap::new()); + + match fmt.req_http(body) { + Ok(req) => { + let res = rpc_ctx.execute(Method::parse(req.method), req.params).await; + fmt.res_http(res.into_response(None)).map_err(Error::from) + } + Err(err) => Err(Error::from(err)), + } +} diff --git a/src/rpc/connection.rs b/src/rpc/connection.rs index 89577e6b..85cd4872 100644 --- a/src/rpc/connection.rs +++ b/src/rpc/connection.rs @@ -1,12 +1,10 @@ -use crate::cnf::PKG_NAME; -use crate::cnf::PKG_VERSION; -use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY}; +use crate::cnf::{ + PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY, +}; use crate::dbs::DB; -use crate::err::Error; -use crate::rpc::args::Take; use crate::rpc::failure::Failure; use crate::rpc::format::Format; -use crate::rpc::response::{failure, Data, IntoRpcResponse}; +use crate::rpc::response::{failure, IntoRpcResponse}; use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::telemetry; use crate::telemetry::metrics::ws::RequestContext; @@ -19,12 +17,13 @@ use opentelemetry::Context as TelemetryContext; use std::collections::BTreeMap; use std::sync::Arc; use surrealdb::channel::{self, Receiver, Sender}; -use surrealdb::dbs::QueryType; -use surrealdb::dbs::Response; use surrealdb::dbs::Session; +use surrealdb::kvs::Datastore; +use surrealdb::rpc::args::Take; +use surrealdb::rpc::method::Method; +use surrealdb::rpc::RpcContext; +use surrealdb::rpc::{Data, RpcError}; use surrealdb::sql::Array; -use surrealdb::sql::Object; -use surrealdb::sql::Strand; use surrealdb::sql::Value; use tokio::sync::{RwLock, Semaphore}; use tokio::task::JoinSet; @@ -295,7 +294,7 @@ impl Connection { let req_cx = RequestContext::default(); let otel_cx = Arc::new(TelemetryContext::new().with_value(req_cx.clone())); // Parse the RPC request structure - match fmt.req(msg) { + match fmt.req_ws(msg) { Ok(req) => { // Now that we know the method, we can update the span and create otel context span.record("rpc.method", &req.method); @@ -337,547 +336,90 @@ impl Connection { params: Array, ) -> Result { debug!("Process RPC request"); - // Match the method to a function - match method { - // Handle a surrealdb ping message - // - // This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough. - // For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios. - // By sending a ping message every few seconds from the client, we can force a connection check and trigger an on_close event if the ping can't be sent. - // - "ping" => Ok(Value::None.into()), - // Retrieve the current auth record - "info" => match params.len() { - 0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Switch to a specific namespace and database - "use" => match params.needs_two() { - Ok((ns, db)) => { - rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Signup to a specific authentication scope - "signup" => match params.needs_one() { - Ok(Value::Object(v)) => { - rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Signin as a root, namespace, database or scope user - "signin" => match params.needs_one() { - Ok(Value::Object(v)) => { - rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Invalidate the current authentication session - "invalidate" => match params.len() { - 0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Authenticate using an authentication token - "authenticate" => match params.needs_one() { - Ok(Value::Strand(v)) => { - rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Kill a live query using a query id - "kill" => match params.needs_one() { - Ok(v) => rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Setup a live query on a specific table - "live" => match params.needs_one_or_two() { - Ok((v, d)) if v.is_table() => { - rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into) - } - Ok((v, d)) if v.is_strand() => { - rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Specify a connection-wide parameter - "let" | "set" => match params.needs_one_or_two() { - Ok((Value::Strand(s), v)) => { - rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Unset and clear a connection-wide parameter - "unset" => match params.needs_one() { - Ok(Value::Strand(s)) => { - rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Select a value or values from the database - "select" => match params.needs_one() { - Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Insert a value or values in the database - "insert" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Create a value or values in the database - "create" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `CONTENT` - "update" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `MERGE` - "merge" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.merge(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `PATCH` - "patch" => match params.needs_one_two_or_three() { - Ok((v, o, d)) => { - rpc.read().await.patch(v, o, d).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Delete a value or values from the database - "delete" => match params.needs_one() { - Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Get the current server version - "version" => match params.len() { - 0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()), - _ => Err(Failure::INVALID_PARAMS), - }, - // Run a full SurrealQL query against the database - "query" => match params.needs_one_or_two() { - Ok((v, o)) if (v.is_strand() || v.is_query()) && o.is_none_or_null() => { - rpc.read().await.query(v).await.map(Into::into).map_err(Into::into) - } - Ok((v, Value::Object(o))) if v.is_strand() || v.is_query() => { - rpc.read().await.query_with(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - _ => Err(Failure::METHOD_NOT_FOUND), + let method = Method::parse(method); + if !method.is_valid() { + return Err(Failure::METHOD_NOT_FOUND); } - } - // ------------------------------ - // Methods for authentication - // ------------------------------ - - async fn yuse(&mut self, ns: Value, db: Value) -> Result { - if let Value::Strand(ns) = ns { - self.session.ns = Some(ns.0); - } - if let Value::Strand(db) = db { - self.session.db = Some(db.0); - } - Ok(Value::None) - } - - async fn signup(&mut self, vars: Object) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::signup::signup(kvs, &mut self.session, vars) - .await - .map(Into::into) - .map_err(Into::into) - } - - async fn signin(&mut self, vars: Object) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::signin::signin(kvs, &mut self.session, vars) - .await - .map(Into::into) - .map_err(Into::into) - } - - async fn invalidate(&mut self) -> Result { - surrealdb::iam::clear::clear(&mut self.session)?; - Ok(Value::None) - } - - async fn authenticate(&mut self, token: Strand) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?; - Ok(Value::None) - } - - // ------------------------------ - // Methods for identification - // ------------------------------ - - async fn info(&self) -> Result { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "SELECT * FROM $auth"; - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, None).await?; - // Extract the first value from the result - let res = res.remove(0).result?.first(); - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for setting variables - // ------------------------------ - - async fn set(&mut self, key: Strand, val: Value) -> Result { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the query parameters - let var = Some(map! { - key.0.clone() => Value::None, - => &self.vars - }); - // Compute the specified parameter - match kvs.compute(val, &self.session, var).await? { - // Remove the variable if undefined - Value::None => self.vars.remove(&key.0), - // Store the variable if defined - v => self.vars.insert(key.0, v), - }; - Ok(Value::Null) - } - - async fn unset(&mut self, key: Strand) -> Result { - self.vars.remove(&key.0); - Ok(Value::Null) - } - - // ------------------------------ - // Methods for live queries - // ------------------------------ - - async fn kill(&self, id: Value) -> Result { - // Specify the SQL query string - let sql = "KILL $id"; - // Specify the query parameters - let var = map! { - String::from("id") => id, - => &self.vars - }; - // Execute the query on the database - let mut res = self.query_with(Value::from(sql), Object::from(var)).await?; - // Extract the first query result - let response = res.remove(0); - match response.result { - Ok(v) => Ok(v), - Err(e) => Err(Error::from(e)), - } - } - - async fn live(&self, tb: Value, diff: Value) -> Result { - // Specify the SQL query string - let sql = match diff.is_true() { - true => "LIVE SELECT DIFF FROM $tb", - false => "LIVE SELECT * FROM $tb", - }; - // Specify the query parameters - let var = map! { - String::from("tb") => tb.could_be_table(), - => &self.vars - }; - // Execute the query on the database - let mut res = self.query_with(Value::from(sql), Object::from(var)).await?; - // Extract the first query result - let response = res.remove(0); - match response.result { - Ok(v) => Ok(v), - Err(e) => Err(Error::from(e)), - } - } - - // ------------------------------ - // Methods for selecting - // ------------------------------ - - async fn select(&self, what: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "SELECT * FROM $what"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for inserting - // ------------------------------ - - async fn insert(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "INSERT INTO $what $data RETURN AFTER"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for creating - // ------------------------------ - - async fn create(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = if data.is_none_or_null() { - "CREATE $what RETURN AFTER" - } else { - "CREATE $what CONTENT $data RETURN AFTER" - }; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for updating - // ------------------------------ - - async fn update(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = if data.is_none_or_null() { - "UPDATE $what RETURN AFTER" - } else { - "UPDATE $what CONTENT $data RETURN AFTER" - }; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for merging - // ------------------------------ - - async fn merge(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = if data.is_none_or_null() { - "UPDATE $what RETURN AFTER" - } else { - "UPDATE $what MERGE $data RETURN AFTER" - }; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for patching - // ------------------------------ - - async fn patch(&self, what: Value, data: Value, diff: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = match diff.is_true() { - true => "UPDATE $what PATCH $data RETURN DIFF", - false => "UPDATE $what PATCH $data RETURN AFTER", - }; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for deleting - // ------------------------------ - - async fn delete(&self, what: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "DELETE $what RETURN BEFORE"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for querying - // ------------------------------ - - async fn query(&self, sql: Value) -> Result, Error> { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the query parameters - let var = Some(self.vars.clone()); - // Execute the query on the database - let res = match sql { - Value::Query(sql) => kvs.process(sql, &self.session, var).await?, - Value::Strand(sql) => kvs.execute(&sql, &self.session, var).await?, - _ => unreachable!(), - }; - - // Post-process hooks for web layer - for response in &res { - self.handle_live_query_results(response).await; - } - // Return the result to the client - Ok(res) - } - - async fn query_with(&self, sql: Value, mut vars: Object) -> Result, Error> { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the query parameters - let var = Some(mrg! { vars.0, &self.vars }); - // Execute the query on the database - let res = match sql { - Value::Query(sql) => kvs.process(sql, &self.session, var).await?, - Value::Strand(sql) => kvs.execute(&sql, &self.session, var).await?, - _ => unreachable!(), - }; - // Post-process hooks for web layer - for response in &res { - self.handle_live_query_results(response).await; - } - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Private methods - // ------------------------------ - - async fn handle_live_query_results(&self, res: &Response) { - match &res.query_type { - QueryType::Live => { - if let Ok(Value::Uuid(lqid)) = &res.result { - // Match on Uuid type - LIVE_QUERIES.write().await.insert(lqid.0, self.id); - trace!("Registered live query {} on websocket {}", lqid, self.id); - } - } - QueryType::Kill => { - if let Ok(Value::Uuid(lqid)) = &res.result { - if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) { - trace!("Unregistered live query {} on websocket {}", lqid, id); - } - } - } - _ => {} + // if the write lock is a bottleneck then execute could be refactored into execute_mut and execute + // rpc.write().await.execute(method, params).await.map_err(Into::into) + match method.needs_mut() { + true => rpc.write().await.execute(method, params).await.map_err(Into::into), + false => rpc.read().await.execute_immut(method, params).await.map_err(Into::into), } } } + +impl RpcContext for Connection { + fn kvs(&self) -> &Datastore { + DB.get().unwrap() + } + + fn session(&self) -> &Session { + &self.session + } + + fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + fn vars(&self) -> &BTreeMap { + &self.vars + } + + fn vars_mut(&mut self) -> &mut BTreeMap { + &mut self.vars + } + + fn version_data(&self) -> impl Into { + format!("{PKG_NAME}-{}", *PKG_VERSION) + } + + const LQ_SUPPORT: bool = true; + + async fn handle_live(&self, lqid: &Uuid) { + LIVE_QUERIES.write().await.insert(*lqid, self.id); + trace!("Registered live query {} on websocket {}", lqid, self.id); + } + + async fn handle_kill(&self, lqid: &Uuid) { + if let Some(id) = LIVE_QUERIES.write().await.remove(lqid) { + trace!("Unregistered live query {} on websocket {}", lqid, id); + } + } + + // reimplimentaions + + async fn signup(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + + out + } + + async fn signin(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + out + } + + async fn authenticate(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Strand(token)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?; + Ok(Value::None) + } +} diff --git a/src/rpc/failure.rs b/src/rpc/failure.rs index 1ebd4088..3387b18d 100644 --- a/src/rpc/failure.rs +++ b/src/rpc/failure.rs @@ -3,6 +3,7 @@ use revision::revisioned; use revision::Revisioned; use serde::Serialize; use std::borrow::Cow; +use surrealdb::rpc::RpcError; use surrealdb::sql::Value; #[derive(Clone, Debug, Serialize)] @@ -51,6 +52,20 @@ impl From for Failure { } } +impl From for Failure { + fn from(err: RpcError) -> Self { + match err { + RpcError::ParseError => Failure::PARSE_ERROR, + RpcError::InvalidRequest => Failure::INVALID_REQUEST, + RpcError::MethodNotFound => Failure::METHOD_NOT_FOUND, + RpcError::InvalidParams => Failure::INVALID_PARAMS, + RpcError::InternalError(_) => Failure::custom(err.to_string()), + RpcError::Thrown(_) => Failure::custom(err.to_string()), + _ => Failure::custom(err.to_string()), + } + } +} + impl From for Value { fn from(err: Failure) -> Self { map! { diff --git a/src/rpc/format/bincode.rs b/src/rpc/format/bincode.rs index df4b6cbf..4760a89f 100644 --- a/src/rpc/format/bincode.rs +++ b/src/rpc/format/bincode.rs @@ -1,22 +1,40 @@ -use crate::rpc::failure::Failure; +use crate::net::headers::ContentType; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::IntoResponse; +use axum::response::Response as AxumResponse; +use bytes::Bytes; +use http::header::CONTENT_TYPE; +use http::HeaderValue; +use surrealdb::rpc::RpcError; use surrealdb::sql::serde::deserialize; use surrealdb::sql::Value; -pub fn req(msg: Message) -> Result { +pub fn req_ws(msg: Message) -> Result { match msg { Message::Binary(val) => { - deserialize::(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() + deserialize::(&val).map_err(|_| RpcError::ParseError)?.try_into() } - _ => Err(Failure::INVALID_REQUEST), + _ => Err(RpcError::InvalidRequest), } } -pub fn res(res: Response) -> Result<(usize, Message), Failure> { +pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { // Serialize the response with full internal type information let res = surrealdb::sql::serde::serialize(&res).unwrap(); // Return the message length, and message as binary Ok((res.len(), Message::Binary(res))) } + +pub fn req_http(val: &Bytes) -> Result { + deserialize::(val).map_err(|_| RpcError::ParseError)?.try_into() +} + +pub fn res_http(res: Response) -> Result { + // Serialize the response with full internal type information + let res = surrealdb::sql::serde::serialize(&res).unwrap(); + // Return the message length, and message as binary + // TODO: Check what this header should be, I'm being consistent with /sql + Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response()) +} diff --git a/src/rpc/format/cbor/mod.rs b/src/rpc/format/cbor/mod.rs index 6a2a9ea5..1bb0f950 100644 --- a/src/rpc/format/cbor/mod.rs +++ b/src/rpc/format/cbor/mod.rs @@ -1,27 +1,32 @@ mod convert; +use bytes::Bytes; pub use convert::Cbor; +use http::header::CONTENT_TYPE; +use http::HeaderValue; +use surrealdb::rpc::RpcError; -use crate::rpc::failure::Failure; +use crate::net::headers::ContentType; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::{IntoResponse, Response as AxumResponse}; use ciborium::Value as Data; -pub fn req(msg: Message) -> Result { +pub fn req_ws(msg: Message) -> Result { match msg { Message::Text(val) => { - surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() + surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into() } Message::Binary(val) => ciborium::from_reader::(&mut val.as_slice()) - .map_err(|_| Failure::PARSE_ERROR) + .map_err(|_| RpcError::ParseError) .map(Cbor)? .try_into(), - _ => Err(Failure::INVALID_REQUEST), + _ => Err(RpcError::InvalidRequest), } } -pub fn res(res: Response) -> Result<(usize, Message), Failure> { +pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { // Convert the response into a value let val: Cbor = res.into_value().try_into()?; // Create a new vector for encoding output @@ -31,3 +36,22 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> { // Return the message length, and message as binary Ok((res.len(), Message::Binary(res))) } + +pub fn req_http(body: Bytes) -> Result { + let val: Vec = body.into(); + ciborium::from_reader::(&mut val.as_slice()) + .map_err(|_| RpcError::ParseError) + .map(Cbor)? + .try_into() +} + +pub fn res_http(res: Response) -> Result { + // Convert the response into a value + let val: Cbor = res.into_value().try_into()?; + // Create a new vector for encoding output + let mut res = Vec::new(); + // Serialize the value into CBOR binary data + ciborium::into_writer(&val.0, &mut res).unwrap(); + // Return the message length, and message as binary + Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response()) +} diff --git a/src/rpc/format/json.rs b/src/rpc/format/json.rs index 46f80fb2..89ed1115 100644 --- a/src/rpc/format/json.rs +++ b/src/rpc/format/json.rs @@ -1,18 +1,23 @@ -use crate::rpc::failure::Failure; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::IntoResponse; +use axum::response::Response as AxumResponse; +use bytes::Bytes; +use http::StatusCode; +use surrealdb::rpc::RpcError; +use surrealdb::sql; -pub fn req(msg: Message) -> Result { +pub fn req_ws(msg: Message) -> Result { match msg { Message::Text(val) => { - surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() + surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into() } - _ => Err(Failure::INVALID_REQUEST), + _ => Err(RpcError::InvalidRequest), } } -pub fn res(res: Response) -> Result<(usize, Message), Failure> { +pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { // Convert the response into simplified JSON let val = res.into_json(); // Serialize the response with simplified type information @@ -20,3 +25,18 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> { // Return the message length, and message as binary Ok((res.len(), Message::Text(res))) } + +pub fn req_http(val: &Bytes) -> Result { + sql::value(std::str::from_utf8(val).or(Err(RpcError::ParseError))?) + .or(Err(RpcError::ParseError))? + .try_into() +} + +pub fn res_http(res: Response) -> Result { + // Convert the response into simplified JSON + let val = res.into_json(); + // Serialize the response with simplified type information + let res = serde_json::to_string(&val).unwrap(); + // Return the message length, and message as binary + Ok((StatusCode::OK, res).into_response()) +} diff --git a/src/rpc/format/mod.rs b/src/rpc/format/mod.rs index 766abba3..a562afc9 100644 --- a/src/rpc/format/mod.rs +++ b/src/rpc/format/mod.rs @@ -4,10 +4,14 @@ mod json; pub mod msgpack; mod revision; +use crate::net::headers::{Accept, ContentType}; use crate::rpc::failure::Failure; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::Response as AxumResponse; +use bytes::Bytes; +use surrealdb::rpc::RpcError; pub const PROTOCOLS: [&str; 5] = [ "json", // For basic JSON serialisation @@ -19,12 +23,39 @@ pub const PROTOCOLS: [&str; 5] = [ #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Format { - None, // No format is specified yet - Json, // For basic JSON serialisation - Cbor, // For basic CBOR serialisation - Msgpack, // For basic Msgpack serialisation - Bincode, // For full internal serialisation - Revision, // For full versioned serialisation + None, // No format is specified yet + Json, // For basic JSON serialisation + Cbor, // For basic CBOR serialisation + Msgpack, // For basic Msgpack serialisation + Bincode, // For full internal serialisation + Revision, // For full versioned serialisation + Unsupported, // Unsupported format +} + +impl From<&Accept> for Format { + fn from(value: &Accept) -> Self { + match value { + Accept::TextPlain => Format::None, + Accept::ApplicationJson => Format::Json, + Accept::ApplicationCbor => Format::Cbor, + Accept::ApplicationPack => Format::Msgpack, + Accept::ApplicationOctetStream => Format::Unsupported, + Accept::Surrealdb => Format::Bincode, + } + } +} + +impl From<&ContentType> for Format { + fn from(value: &ContentType) -> Self { + match value { + ContentType::TextPlain => Format::None, + ContentType::ApplicationJson => Format::Json, + ContentType::ApplicationCbor => Format::Cbor, + ContentType::ApplicationPack => Format::Msgpack, + ContentType::ApplicationOctetStream => Format::Unsupported, + ContentType::Surrealdb => Format::Bincode, + } + } } impl From<&str> for Format { @@ -46,25 +77,53 @@ impl Format { matches!(self, Format::None) } /// Process a request using the specified format - pub fn req(&self, msg: Message) -> Result { + pub fn req_ws(&self, msg: Message) -> Result { match self { Self::None => unreachable!(), // We should never arrive at this code - Self::Json => json::req(msg), - Self::Cbor => cbor::req(msg), - Self::Msgpack => msgpack::req(msg), - Self::Bincode => bincode::req(msg), - Self::Revision => revision::req(msg), + Self::Unsupported => unreachable!(), // We should never arrive at this code + Self::Json => json::req_ws(msg), + Self::Cbor => cbor::req_ws(msg), + Self::Msgpack => msgpack::req_ws(msg), + Self::Bincode => bincode::req_ws(msg), + Self::Revision => revision::req_ws(msg), + } + .map_err(Into::into) + } + /// Process a response using the specified format + pub fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> { + match self { + Self::None => unreachable!(), // We should never arrive at this code + Self::Unsupported => unreachable!(), // We should never arrive at this code + Self::Json => json::res_ws(res), + Self::Cbor => cbor::res_ws(res), + Self::Msgpack => msgpack::res_ws(res), + Self::Bincode => bincode::res_ws(res), + Self::Revision => revision::res_ws(res), + } + .map_err(Into::into) + } + /// Process a request using the specified format + pub fn req_http(&self, body: Bytes) -> Result { + match self { + Self::None => unreachable!(), // We should never arrive at this code + Self::Unsupported => unreachable!(), // We should never arrive at this code + Self::Json => json::req_http(&body), + Self::Cbor => cbor::req_http(body), + Self::Msgpack => msgpack::req_http(body), + Self::Bincode => bincode::req_http(&body), + Self::Revision => revision::req_http(body), } } /// Process a response using the specified format - pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> { + pub fn res_http(&self, res: Response) -> Result { match self { Self::None => unreachable!(), // We should never arrive at this code - Self::Json => json::res(res), - Self::Cbor => cbor::res(res), - Self::Msgpack => msgpack::res(res), - Self::Bincode => bincode::res(res), - Self::Revision => revision::res(res), + Self::Unsupported => unreachable!(), // We should never arrive at this code + Self::Json => json::res_http(res), + Self::Cbor => cbor::res_http(res), + Self::Msgpack => msgpack::res_http(res), + Self::Bincode => bincode::res_http(res), + Self::Revision => revision::res_http(res), } } } diff --git a/src/rpc/format/msgpack/mod.rs b/src/rpc/format/msgpack/mod.rs index 5a3775eb..66a54608 100644 --- a/src/rpc/format/msgpack/mod.rs +++ b/src/rpc/format/msgpack/mod.rs @@ -1,26 +1,31 @@ mod convert; +use bytes::Bytes; pub use convert::Pack; +use http::header::CONTENT_TYPE; +use http::HeaderValue; +use surrealdb::rpc::RpcError; -use crate::rpc::failure::Failure; +use crate::net::headers::ContentType; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::{IntoResponse, Response as AxumResponse}; -pub fn req(msg: Message) -> Result { +pub fn req_ws(msg: Message) -> Result { match msg { Message::Text(val) => { - surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() + surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into() } Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice()) - .map_err(|_| Failure::PARSE_ERROR) + .map_err(|_| RpcError::ParseError) .map(Pack)? .try_into(), - _ => Err(Failure::INVALID_REQUEST), + _ => Err(RpcError::InvalidRequest), } } -pub fn res(res: Response) -> Result<(usize, Message), Failure> { +pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { // Convert the response into a value let val: Pack = res.into_value().try_into()?; // Create a new vector for encoding output @@ -30,3 +35,21 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> { // Return the message length, and message as binary Ok((res.len(), Message::Binary(res))) } +pub fn req_http(body: Bytes) -> Result { + let val: Vec = body.into(); + rmpv::decode::read_value(&mut val.as_slice()) + .map_err(|_| RpcError::ParseError) + .map(Pack)? + .try_into() +} + +pub fn res_http(res: Response) -> Result { + // Convert the response into a value + let val: Pack = res.into_value().try_into()?; + // Create a new vector for encoding output + let mut res = Vec::new(); + // Serialize the value into MsgPack binary data + rmpv::encode::write_value(&mut res, &val.0).unwrap(); + // Return the message length, and message as binary + Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationPack))], res).into_response()) +} diff --git a/src/rpc/format/revision.rs b/src/rpc/format/revision.rs index ac5cb984..bbc650cf 100644 --- a/src/rpc/format/revision.rs +++ b/src/rpc/format/revision.rs @@ -1,23 +1,42 @@ -use crate::rpc::failure::Failure; +use crate::net::headers::ContentType; use crate::rpc::request::Request; use crate::rpc::response::Response; use axum::extract::ws::Message; +use axum::response::{IntoResponse, Response as AxumResponse}; +use bytes::Bytes; +use http::header::CONTENT_TYPE; +use http::HeaderValue; use revision::Revisioned; +use surrealdb::rpc::RpcError; use surrealdb::sql::Value; -pub fn req(msg: Message) -> Result { +pub fn req_ws(msg: Message) -> Result { match msg { Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice()) - .map_err(|_| Failure::PARSE_ERROR)? + .map_err(|_| RpcError::ParseError)? .try_into(), - _ => Err(Failure::INVALID_REQUEST), + _ => Err(RpcError::InvalidRequest), } } -pub fn res(res: Response) -> Result<(usize, Message), Failure> { +pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { // Serialize the response with full internal type information let mut buf = Vec::new(); res.serialize_revisioned(&mut buf).unwrap(); // Return the message length, and message as binary Ok((buf.len(), Message::Binary(buf))) } + +pub fn req_http(body: Bytes) -> Result { + let val: Vec = body.into(); + Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into() +} + +pub fn res_http(res: Response) -> Result { + // Serialize the response with full internal type information + let mut buf = Vec::new(); + res.serialize_revisioned(&mut buf).unwrap(); + // Return the message length, and message as binary + // TODO: Check what this header should be, new header needed for revisioned + Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], buf).into_response()) +} diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs index 35b87b23..0e417411 100644 --- a/src/rpc/mod.rs +++ b/src/rpc/mod.rs @@ -2,6 +2,7 @@ pub mod args; pub mod connection; pub mod failure; pub mod format; +pub mod post_context; pub mod request; pub mod response; diff --git a/src/rpc/post_context.rs b/src/rpc/post_context.rs new file mode 100644 index 00000000..1cb62b3f --- /dev/null +++ b/src/rpc/post_context.rs @@ -0,0 +1,103 @@ +use std::collections::BTreeMap; + +use crate::cnf::{PKG_NAME, PKG_VERSION}; +use surrealdb::dbs::Session; +use surrealdb::kvs::Datastore; +use surrealdb::rpc::args::Take; +use surrealdb::rpc::Data; +use surrealdb::rpc::RpcContext; +use surrealdb::rpc::RpcError; +use surrealdb::sql::Array; +use surrealdb::sql::Value; + +pub struct PostRpcContext<'a> { + pub kvs: &'a Datastore, + pub session: Session, + pub vars: BTreeMap, +} + +impl<'a> PostRpcContext<'a> { + pub fn new(kvs: &'a Datastore, session: Session, vars: BTreeMap) -> Self { + Self { + kvs, + session, + vars, + } + } +} + +impl RpcContext for PostRpcContext<'_> { + fn kvs(&self) -> &Datastore { + self.kvs + } + + fn session(&self) -> &Session { + &self.session + } + + fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + fn vars(&self) -> &BTreeMap { + &self.vars + } + + fn vars_mut(&mut self) -> &mut BTreeMap { + &mut self.vars + } + + fn version_data(&self) -> impl Into { + let val: Value = format!("{PKG_NAME}-{}", *PKG_VERSION).into(); + val + } + + // disable: + + // doesn't do anything so shouldn't be supported + async fn set(&mut self, _params: Array) -> Result, RpcError> { + let out: Result = Err(RpcError::MethodNotFound); + out + } + + // doesn't do anything so shouldn't be supported + async fn unset(&mut self, _params: Array) -> Result, RpcError> { + let out: Result = Err(RpcError::MethodNotFound); + out + } + + // reimplimentaions: + + async fn signup(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + surrealdb::iam::signup::signup(self.kvs, &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + + out + } + + async fn signin(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Object(v)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + let out: Result = + surrealdb::iam::signin::signin(self.kvs, &mut self.session, v) + .await + .map(Into::into) + .map_err(Into::into); + out + } + + async fn authenticate(&mut self, params: Array) -> Result, RpcError> { + let Ok(Value::Strand(token)) = params.needs_one() else { + return Err(RpcError::InvalidParams); + }; + surrealdb::iam::verify::token(self.kvs, &mut self.session, &token.0).await?; + Ok(Value::None) + } +} diff --git a/src/rpc/request.rs b/src/rpc/request.rs index 99f7c5eb..b621cb91 100644 --- a/src/rpc/request.rs +++ b/src/rpc/request.rs @@ -1,7 +1,7 @@ -use crate::rpc::failure::Failure; use crate::rpc::format::cbor::Cbor; use crate::rpc::format::msgpack::Pack; use once_cell::sync::Lazy; +use surrealdb::rpc::RpcError; use surrealdb::sql::Part; use surrealdb::sql::{Array, Value}; @@ -16,22 +16,22 @@ pub struct Request { } impl TryFrom for Request { - type Error = Failure; - fn try_from(val: Cbor) -> Result { - >::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into() + type Error = RpcError; + fn try_from(val: Cbor) -> Result { + >::try_into(val).map_err(|_| RpcError::InvalidRequest)?.try_into() } } impl TryFrom for Request { - type Error = Failure; - fn try_from(val: Pack) -> Result { - >::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into() + type Error = RpcError; + fn try_from(val: Pack) -> Result { + >::try_into(val).map_err(|_| RpcError::InvalidRequest)?.try_into() } } impl TryFrom for Request { - type Error = Failure; - fn try_from(val: Value) -> Result { + type Error = RpcError; + fn try_from(val: Value) -> Result { // Fetch the 'id' argument let id = match val.pick(&*ID) { v if v.is_none() => None, @@ -40,12 +40,12 @@ impl TryFrom for Request { v if v.is_number() => Some(v), v if v.is_strand() => Some(v), v if v.is_datetime() => Some(v), - _ => return Err(Failure::INVALID_REQUEST), + _ => return Err(RpcError::InvalidRequest), }; // Fetch the 'method' argument let method = match val.pick(&*METHOD) { Value::Strand(v) => v.to_raw(), - _ => return Err(Failure::INVALID_REQUEST), + _ => return Err(RpcError::InvalidRequest), }; // Fetch the 'params' argument let params = match val.pick(&*PARAMS) { diff --git a/src/rpc/response.rs b/src/rpc/response.rs index f012bc46..c7dbac2d 100644 --- a/src/rpc/response.rs +++ b/src/rpc/response.rs @@ -8,61 +8,10 @@ use serde::Serialize; use serde_json::Value as Json; use std::sync::Arc; use surrealdb::channel::Sender; -use surrealdb::dbs; -use surrealdb::dbs::Notification; -use surrealdb::sql; +use surrealdb::rpc::Data; use surrealdb::sql::Value; use tracing::Span; -/// The data returned by the database -// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data` -// In future, they will possibly be merged to avoid having to keep them in sync. -#[derive(Debug, Serialize)] -#[revisioned(revision = 1)] -pub enum Data { - /// Generally methods return a `sql::Value` - Other(Value), - /// The query methods, `query` and `query_with` return a `Vec` of responses - Query(Vec), - /// Live queries return a notification - Live(Notification), - // Add new variants here -} - -impl From for Data { - fn from(v: Value) -> Self { - Data::Other(v) - } -} - -impl From for Data { - fn from(v: String) -> Self { - Data::Other(Value::from(v)) - } -} - -impl From for Data { - fn from(n: Notification) -> Self { - Data::Live(n) - } -} - -impl From> for Data { - fn from(v: Vec) -> Self { - Data::Query(v) - } -} - -impl From for Value { - fn from(val: Data) -> Self { - match val { - Data::Query(v) => sql::to_value(v).unwrap(), - Data::Live(v) => sql::to_value(v).unwrap(), - Data::Other(v) => v, - } - } -} - #[derive(Debug, Serialize)] #[revisioned(revision = 1)] pub struct Response { @@ -111,7 +60,7 @@ impl Response { span.record("rpc.error_message", err.message.as_ref()); } // Process the response for the format - let (len, msg) = fmt.res(self).unwrap(); + let (len, msg) = fmt.res_ws(self).unwrap(); // Send the message to the write channel if chn.send(msg).await.is_ok() { record_rpc(cx.as_ref(), len, is_error);