Add post rpc (#3697)

This commit is contained in:
Raphael Darley 2024-03-19 15:17:38 +00:00 committed by GitHub
parent 0fc410dec2
commit da483716c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 1490 additions and 672 deletions

View file

@ -83,7 +83,7 @@ surrealdb = { version = "1", path = "lib", features = [
"protocol-http", "protocol-http",
"protocol-ws", "protocol-ws",
"rustls", "rustls",
"sql2" "sql2",
] } ] }
tempfile = "3.8.1" tempfile = "3.8.1"
thiserror = "1.0.50" thiserror = "1.0.50"

View file

@ -40,6 +40,8 @@ pub mod obs;
#[doc(hidden)] #[doc(hidden)]
pub mod options; pub mod options;
#[doc(hidden)] #[doc(hidden)]
pub mod rpc;
#[doc(hidden)]
pub mod syn; pub mod syn;
#[doc(hidden)] #[doc(hidden)]

62
core/src/rpc/args.rs Normal file
View file

@ -0,0 +1,62 @@
use crate::sql::Array;
use crate::sql::Value;
use super::rpc_error::RpcError;
pub trait Take {
fn needs_one(self) -> Result<Value, RpcError>;
fn needs_two(self) -> Result<(Value, Value), RpcError>;
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
}
impl Take for Array {
/// Convert the array to one argument
fn needs_one(self) -> Result<Value, RpcError> {
if self.len() != 1 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match x.next() {
Some(a) => Ok(a),
None => Ok(Value::None),
}
}
/// Convert the array to two arguments
fn needs_two(self) -> Result<(Value, Value), RpcError> {
if self.len() != 2 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match (x.next(), x.next()) {
(Some(a), Some(b)) => Ok((a, b)),
(Some(a), None) => Ok((a, Value::None)),
(_, _) => Ok((Value::None, Value::None)),
}
}
/// Convert the array to two arguments
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> {
if self.is_empty() && self.len() > 2 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match (x.next(), x.next()) {
(Some(a), Some(b)) => Ok((a, b)),
(Some(a), None) => Ok((a, Value::None)),
(_, _) => Ok((Value::None, Value::None)),
}
}
/// Convert the array to three arguments
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> {
if self.is_empty() && self.len() > 3 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match (x.next(), x.next(), x.next()) {
(Some(a), Some(b), Some(c)) => Ok((a, b, c)),
(Some(a), Some(b), None) => Ok((a, b, Value::None)),
(Some(a), None, None) => Ok((a, Value::None, Value::None)),
(_, _, _) => Ok((Value::None, Value::None, Value::None)),
}
}
}

View file

@ -0,0 +1,94 @@
use std::collections::BTreeMap;
use crate::{
dbs::Session,
kvs::Datastore,
rpc::RpcContext,
sql::{Array, Value},
};
use super::{args::Take, Data, RpcError};
pub struct BasicRpcContext<'a> {
pub kvs: &'a Datastore,
pub session: Session,
pub vars: BTreeMap<String, Value>,
pub version_string: String,
}
impl<'a> BasicRpcContext<'a> {
pub fn new(
kvs: &'a Datastore,
session: Session,
vars: BTreeMap<String, Value>,
version_string: String,
) -> Self {
Self {
kvs,
session,
vars,
version_string,
}
}
}
impl RpcContext for BasicRpcContext<'_> {
fn kvs(&self) -> &Datastore {
self.kvs
}
fn session(&self) -> &Session {
&self.session
}
fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
fn vars(&self) -> &BTreeMap<String, Value> {
&self.vars
}
fn vars_mut(&mut self) -> &mut BTreeMap<String, Value> {
&mut self.vars
}
fn version_data(&self) -> impl Into<super::Data> {
Value::Strand(self.version_string.clone().into())
}
// reimplimentaions:
async fn signup(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
crate::iam::signup::signup(self.kvs, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn signin(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
crate::iam::signin::signin(self.kvs, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn authenticate(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
crate::iam::verify::token(self.kvs, &mut self.session, &token.0).await?;
Ok(Value::None)
}
}

111
core/src/rpc/method.rs Normal file
View file

@ -0,0 +1,111 @@
#[non_exhaustive]
pub enum Method {
Unknown,
Ping,
Info,
Use,
Signup,
Signin,
Invalidate,
Authenticate,
Kill,
Live,
Set,
Unset,
Select,
Insert,
Create,
Update,
Merge,
Patch,
Delete,
Version,
Query,
Relate,
}
impl Method {
pub fn parse<S>(s: S) -> Self
where
S: AsRef<str>,
{
match s.as_ref().to_lowercase().as_str() {
"ping" => Self::Ping,
"info" => Self::Info,
"use" => Self::Use,
"signup" => Self::Signup,
"signin" => Self::Signin,
"invalidate" => Self::Invalidate,
"authenticate" => Self::Authenticate,
"kill" => Self::Kill,
"live" => Self::Live,
"let" | "set" => Self::Set,
"unset" => Self::Unset,
"select" => Self::Select,
"insert" => Self::Insert,
"create" => Self::Create,
"update" => Self::Update,
"merge" => Self::Merge,
"patch" => Self::Patch,
"delete" => Self::Delete,
"version" => Self::Version,
"query" => Self::Query,
"relate" => Self::Relate,
_ => Self::Unknown,
}
}
}
impl Method {
pub fn to_str(&self) -> &str {
match self {
Self::Unknown => "unknown",
Self::Ping => "ping",
Self::Info => "info",
Self::Use => "use",
Self::Signup => "signup",
Self::Signin => "signin",
Self::Invalidate => "invalidate",
Self::Authenticate => "authenticate",
Self::Kill => "kill",
Self::Live => "live",
Self::Set => "set",
Self::Unset => "unset",
Self::Select => "select",
Self::Insert => "insert",
Self::Create => "create",
Self::Update => "update",
Self::Merge => "merge",
Self::Patch => "patch",
Self::Delete => "delete",
Self::Version => "version",
Self::Query => "query",
Self::Relate => "relate",
}
}
}
impl Method {
pub fn is_valid(&self) -> bool {
!matches!(self, Self::Unknown)
}
pub fn needs_mut(&self) -> bool {
!self.can_be_immut()
}
// should be the same as execute_immut
pub fn can_be_immut(&self) -> bool {
matches!(
self,
Method::Ping
| Method::Info | Method::Select
| Method::Insert | Method::Create
| Method::Update | Method::Merge
| Method::Patch | Method::Delete
| Method::Version
| Method::Query | Method::Relate
| Method::Unknown
)
}
}

11
core/src/rpc/mod.rs Normal file
View file

@ -0,0 +1,11 @@
pub mod args;
pub mod basic_context;
pub mod method;
mod response;
pub mod rpc_context;
mod rpc_error;
pub use basic_context::BasicRpcContext;
pub use response::Data;
pub use rpc_context::RpcContext;
pub use rpc_error::RpcError;

55
core/src/rpc/response.rs Normal file
View file

@ -0,0 +1,55 @@
use crate::dbs;
use crate::dbs::Notification;
use crate::sql;
use crate::sql::Value;
use revision::revisioned;
use serde::Serialize;
/// The data returned by the database
// The variants here should be in exactly the same order as `crate::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync.
#[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub enum Data {
/// Generally methods return a `sql::Value`
Other(Value),
/// The query methods, `query` and `query_with` return a `Vec` of responses
Query(Vec<dbs::Response>),
/// Live queries return a notification
Live(Notification),
// Add new variants here
}
impl From<Value> for Data {
fn from(v: Value) -> Self {
Data::Other(v)
}
}
impl From<String> for Data {
fn from(v: String) -> Self {
Data::Other(Value::from(v))
}
}
impl From<Notification> for Data {
fn from(n: Notification) -> Self {
Data::Live(n)
}
}
impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v)
}
}
impl From<Data> for Value {
fn from(val: Data) -> Self {
match val {
Data::Query(v) => sql::to_value(v).unwrap(),
Data::Live(v) => sql::to_value(v).unwrap(),
Data::Other(v) => v,
}
}
}

536
core/src/rpc/rpc_context.rs Normal file
View file

@ -0,0 +1,536 @@
use std::collections::BTreeMap;
use crate::{
dbs::{QueryType, Response, Session},
kvs::Datastore,
rpc::args::Take,
sql::{Array, Value},
};
use uuid::Uuid;
use super::{method::Method, response::Data, rpc_error::RpcError};
macro_rules! mrg {
($($m:expr, $x:expr)+) => {{
$($m.extend($x.iter().map(|(k, v)| (k.clone(), v.clone())));)+
$($m)+
}};
}
#[allow(async_fn_in_trait)]
pub trait RpcContext {
fn kvs(&self) -> &Datastore;
fn session(&self) -> &Session;
fn session_mut(&mut self) -> &mut Session;
fn vars(&self) -> &BTreeMap<String, Value>;
fn vars_mut(&mut self) -> &mut BTreeMap<String, Value>;
fn version_data(&self) -> impl Into<Data>;
const LQ_SUPPORT: bool = false;
fn handle_live(&self, _lqid: &Uuid) -> impl std::future::Future<Output = ()> + Send {
async { unreachable!() }
}
fn handle_kill(&self, _lqid: &Uuid) -> impl std::future::Future<Output = ()> + Send {
async { unreachable!() }
}
async fn execute(&mut self, method: Method, params: Array) -> Result<Data, RpcError> {
match method {
Method::Ping => Ok(Value::None.into()),
Method::Info => self.info().await.map(Into::into).map_err(Into::into),
Method::Use => self.yuse(params).await.map(Into::into).map_err(Into::into),
Method::Signup => self.signup(params).await.map(Into::into).map_err(Into::into),
Method::Signin => self.signin(params).await.map(Into::into).map_err(Into::into),
Method::Invalidate => self.invalidate().await.map(Into::into).map_err(Into::into),
Method::Authenticate => {
self.authenticate(params).await.map(Into::into).map_err(Into::into)
}
Method::Kill => self.kill(params).await.map(Into::into).map_err(Into::into),
Method::Live => self.live(params).await.map(Into::into).map_err(Into::into),
Method::Set => self.set(params).await.map(Into::into).map_err(Into::into),
Method::Unset => self.unset(params).await.map(Into::into).map_err(Into::into),
Method::Select => self.select(params).await.map(Into::into).map_err(Into::into),
Method::Insert => self.insert(params).await.map(Into::into).map_err(Into::into),
Method::Create => self.create(params).await.map(Into::into).map_err(Into::into),
Method::Update => self.update(params).await.map(Into::into).map_err(Into::into),
Method::Merge => self.merge(params).await.map(Into::into).map_err(Into::into),
Method::Patch => self.patch(params).await.map(Into::into).map_err(Into::into),
Method::Delete => self.delete(params).await.map(Into::into).map_err(Into::into),
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
Method::Unknown => Err(RpcError::MethodNotFound),
}
}
async fn execute_immut(&self, method: Method, params: Array) -> Result<Data, RpcError> {
match method {
Method::Ping => Ok(Value::None.into()),
Method::Info => self.info().await.map(Into::into).map_err(Into::into),
Method::Select => self.select(params).await.map(Into::into).map_err(Into::into),
Method::Insert => self.insert(params).await.map(Into::into).map_err(Into::into),
Method::Create => self.create(params).await.map(Into::into).map_err(Into::into),
Method::Update => self.update(params).await.map(Into::into).map_err(Into::into),
Method::Merge => self.merge(params).await.map(Into::into).map_err(Into::into),
Method::Patch => self.patch(params).await.map(Into::into).map_err(Into::into),
Method::Delete => self.delete(params).await.map(Into::into).map_err(Into::into),
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
Method::Unknown => Err(RpcError::MethodNotFound),
_ => Err(RpcError::MethodNotFound),
}
}
// ------------------------------
// Methods for authentication
// ------------------------------
async fn yuse(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let (ns, db) = params.needs_two()?;
if let Value::Strand(ns) = ns {
self.session_mut().ns = Some(ns.0);
}
if let Value::Strand(db) = db {
self.session_mut().db = Some(db.0);
}
Ok(Value::None)
}
async fn signup(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let mut tmp_session = self.session().clone();
let out: Result<Value, RpcError> =
crate::iam::signup::signup(self.kvs(), &mut tmp_session, v)
.await
.map(Into::into)
.map_err(Into::into);
*self.session_mut() = tmp_session;
out
}
async fn signin(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let mut tmp_session = self.session().clone();
let out: Result<Value, RpcError> =
crate::iam::signin::signin(self.kvs(), &mut tmp_session, v)
.await
.map(Into::into)
.map_err(Into::into);
*self.session_mut() = tmp_session;
out
}
async fn invalidate(&mut self) -> Result<impl Into<Data>, RpcError> {
crate::iam::clear::clear(self.session_mut())?;
Ok(Value::None)
}
async fn authenticate(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let mut tmp_session = self.session().clone();
crate::iam::verify::token(self.kvs(), &mut tmp_session, &token.0).await?;
*self.session_mut() = tmp_session;
Ok(Value::None)
}
// ------------------------------
// Methods for identification
// ------------------------------
async fn info(&self) -> Result<impl Into<Data>, RpcError> {
// Specify the SQL query string
let sql = "SELECT * FROM $auth";
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), None).await?;
// Extract the first value from the result
let res = res.remove(0).result?.first();
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for setting variables
// ------------------------------
async fn set(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((Value::Strand(key), val)) = params.needs_one_or_two() else {
return Err(RpcError::InvalidParams);
};
// Specify the query parameters
let var = Some(map! {
key.0.clone() => Value::None,
=> &self.vars()
});
// Compute the specified parameter
match self.kvs().compute(val, self.session(), var).await? {
// Remove the variable if undefined
Value::None => self.vars_mut().remove(&key.0),
// Store the variable if defined
v => self.vars_mut().insert(key.0, v),
};
Ok(Value::Null)
}
async fn unset(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(key)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
self.vars_mut().remove(&key.0);
Ok(Value::Null)
}
// ------------------------------
// Methods for live queries
// ------------------------------
async fn kill(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let id = params.needs_one()?;
// Specify the SQL query string
let sql = "KILL $id";
// Specify the query parameters
let var = map! {
String::from("id") => id,
=> &self.vars()
};
// Execute the query on the database
// let mut res = self.query_with(Value::from(sql), Object::from(var)).await?;
let mut res = self.query_inner(Value::from(sql), Some(var)).await?;
// Extract the first query result
let response = res.remove(0);
response.result.map_err(Into::into)
}
async fn live(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let (tb, diff) = params.needs_one_or_two()?;
// Specify the SQL query string
let sql = match diff.is_true() {
true => "LIVE SELECT DIFF FROM $tb",
false => "LIVE SELECT * FROM $tb",
};
// Specify the query parameters
let var = map! {
String::from("tb") => tb.could_be_table(),
=> &self.vars()
};
// Execute the query on the database
let mut res = self.query_inner(Value::from(sql), Some(var)).await?;
// Extract the first query result
let response = res.remove(0);
response.result.map_err(Into::into)
}
// ------------------------------
// Methods for selecting
// ------------------------------
async fn select(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(what) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = "SELECT * FROM $what";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for inserting
// ------------------------------
async fn insert(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((what, data)) = params.needs_two() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = "INSERT INTO $what $data RETURN AFTER";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for creating
// ------------------------------
async fn create(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((what, data)) = params.needs_one_or_two() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"CREATE $what RETURN AFTER"
} else {
"CREATE $what CONTENT $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for updating
// ------------------------------
async fn update(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((what, data)) = params.needs_one_or_two() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"UPDATE $what RETURN AFTER"
} else {
"UPDATE $what CONTENT $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for merging
// ------------------------------
async fn merge(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((what, data)) = params.needs_one_or_two() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"UPDATE $what RETURN AFTER"
} else {
"UPDATE $what MERGE $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for patching
// ------------------------------
async fn patch(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((what, data, diff)) = params.needs_one_two_or_three() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = match diff.is_true() {
true => "UPDATE $what PATCH $data RETURN DIFF",
false => "UPDATE $what PATCH $data RETURN AFTER",
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for deleting
// ------------------------------
async fn delete(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(what) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
// Return a single result?
let one = what.is_thing();
// Specify the SQL query string
let sql = "DELETE $what RETURN BEFORE";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
=> &self.vars()
});
// Execute the query on the database
let mut res = self.kvs().execute(sql, self.session(), var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for getting info
// ------------------------------
async fn version(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
match params.len() {
0 => Ok(self.version_data()),
_ => Err(RpcError::InvalidParams),
}
}
// ------------------------------
// Methods for querying
// ------------------------------
async fn query(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((query, o)) = params.needs_one_or_two() else {
return Err(RpcError::InvalidParams);
};
if !(query.is_query() || query.is_strand()) {
return Err(RpcError::InvalidParams);
}
let o = match o {
Value::Object(v) => Some(v),
Value::None | Value::Null => None,
_ => return Err(RpcError::InvalidParams),
};
// Specify the query parameters
let vars = match o {
Some(mut v) => Some(mrg! {v.0, &self.vars()}),
None => Some(self.vars().clone()),
};
self.query_inner(query, vars).await
}
// ------------------------------
// Methods for querying
// ------------------------------
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
let out: Result<Value, RpcError> = Err(RpcError::MethodNotFound);
out
}
// ------------------------------
// Private methods
// ------------------------------
async fn query_inner(
&self,
query: Value,
vars: Option<BTreeMap<String, Value>>,
) -> Result<Vec<Response>, RpcError> {
// If no live query handler force realtime off
if !Self::LQ_SUPPORT && self.session().rt {
return Err(RpcError::BadLQConfig);
}
// Execute the query on the database
let res = match query {
Value::Query(sql) => self.kvs().process(sql, self.session(), vars).await?,
Value::Strand(sql) => self.kvs().execute(&sql, self.session(), vars).await?,
_ => unreachable!(),
};
// Post-process hooks for web layer
for response in &res {
// This error should be unreachable because we shouldn't proceed if there's no handler
self.handle_live_query_results(response).await;
}
// Return the result to the client
Ok(res)
}
async fn handle_live_query_results(&self, res: &Response) {
match &res.query_type {
QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result {
self.handle_live(&lqid.0).await;
}
}
QueryType::Kill => {
if let Ok(Value::Uuid(lqid)) = &res.result {
self.handle_kill(&lqid.0).await;
}
}
_ => {}
}
}
}

52
core/src/rpc/rpc_error.rs Normal file
View file

@ -0,0 +1,52 @@
use thiserror::Error;
use crate::err;
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RpcError {
#[error("Parse error")]
ParseError,
#[error("Invalid request")]
InvalidRequest,
#[error("Method not found")]
MethodNotFound,
#[error("Invalid params")]
InvalidParams,
#[error("There was a problem with the database: {0}")]
InternalError(err::Error),
#[error("Live Query was made, but is not supported")]
LqNotSuported,
#[error("RT is enabled for the session, but LQ is not supported with the context")]
BadLQConfig,
#[error("Error: {0}")]
Thrown(String),
}
impl From<err::Error> for RpcError {
fn from(e: err::Error) -> Self {
use err::Error;
match e {
Error::RealtimeDisabled => RpcError::LqNotSuported,
_ => RpcError::InternalError(e),
}
}
}
impl From<&str> for RpcError {
fn from(e: &str) -> Self {
RpcError::Thrown(e.to_string())
}
}
impl From<RpcError> for err::Error {
fn from(value: RpcError) -> Self {
use err::Error;
match value {
RpcError::InternalError(e) => e,
RpcError::Thrown(e) => Error::Thrown(e),
_ => Error::Thrown(value.to_string()),
}
}
}

View file

@ -135,6 +135,17 @@ impl From<surrealdb::error::Db> for Error {
} }
} }
impl From<surrealdb::rpc::RpcError> for Error {
fn from(value: surrealdb::rpc::RpcError) -> Self {
use surrealdb::rpc::RpcError;
match value {
RpcError::InternalError(e) => Error::Db(surrealdb::Error::Db(e)),
RpcError::Thrown(e) => Error::Other(e),
_ => Error::Other(value.to_string()),
}
}
}
impl Serialize for Error { impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where

View file

@ -6,10 +6,3 @@ macro_rules! map {
m m
}}; }};
} }
macro_rules! mrg {
($($m:expr, $x:expr)+) => {{
$($m.extend($x.iter().map(|(k, v)| (k.clone(), v.clone())));)+
$($m)+
}};
}

View file

@ -0,0 +1,70 @@
use axum::headers;
use axum::headers::Header;
use http::HeaderName;
use http::HeaderValue;
/// Typed header implementation for the `ContentType` header.
pub enum ContentType {
TextPlain,
ApplicationJson,
ApplicationCbor,
ApplicationPack,
ApplicationOctetStream,
Surrealdb,
}
impl std::fmt::Display for ContentType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentType::TextPlain => write!(f, "text/plain"),
ContentType::ApplicationJson => write!(f, "application/json"),
ContentType::ApplicationCbor => write!(f, "application/cbor"),
ContentType::ApplicationPack => write!(f, "application/pack"),
ContentType::ApplicationOctetStream => write!(f, "application/octet-stream"),
ContentType::Surrealdb => write!(f, "application/surrealdb"),
}
}
}
impl Header for ContentType {
fn name() -> &'static HeaderName {
&http::header::CONTENT_TYPE
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
I: Iterator<Item = &'i HeaderValue>,
{
let value = values.next().ok_or_else(headers::Error::invalid)?;
match value.to_str().map_err(|_| headers::Error::invalid())? {
"text/plain" => Ok(ContentType::TextPlain),
"application/json" => Ok(ContentType::ApplicationJson),
"application/cbor" => Ok(ContentType::ApplicationCbor),
"application/pack" => Ok(ContentType::ApplicationPack),
"application/octet-stream" => Ok(ContentType::ApplicationOctetStream),
"application/surrealdb" => Ok(ContentType::Surrealdb),
// TODO: Support more (all?) mime-types
_ => Err(headers::Error::invalid()),
}
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
values.extend(std::iter::once(self.into()));
}
}
impl From<ContentType> for HeaderValue {
fn from(value: ContentType) -> Self {
HeaderValue::from(&value)
}
}
impl From<&ContentType> for HeaderValue {
fn from(value: &ContentType) -> Self {
HeaderValue::from_str(value.to_string().as_str()).unwrap()
}
}

View file

@ -14,6 +14,7 @@ use tower_http::set_header::SetResponseHeaderLayer;
mod accept; mod accept;
mod auth_db; mod auth_db;
mod auth_ns; mod auth_ns;
mod content_type;
mod db; mod db;
mod id; mod id;
mod ns; mod ns;
@ -21,6 +22,7 @@ mod ns;
pub use accept::Accept; pub use accept::Accept;
pub use auth_db::SurrealAuthDatabase; pub use auth_db::SurrealAuthDatabase;
pub use auth_ns::SurrealAuthNamespace; pub use auth_ns::SurrealAuthNamespace;
pub use content_type::ContentType;
pub use db::{SurrealDatabase, SurrealDatabaseLegacy}; pub use db::{SurrealDatabase, SurrealDatabaseLegacy};
pub use id::{SurrealId, SurrealIdLegacy}; pub use id::{SurrealId, SurrealIdLegacy};
pub use ns::{SurrealNamespace, SurrealNamespaceLegacy}; pub use ns::{SurrealNamespace, SurrealNamespaceLegacy};

View file

@ -97,6 +97,7 @@ async fn select_all(
Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))), Some(Accept::ApplicationCbor) => Ok(output::cbor(&output::simplify(res))),
Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))), Some(Accept::ApplicationPack) => Ok(output::pack(&output::simplify(res))),
// Internal serialization // Internal serialization
// TODO: remove format in 2.0.0
Some(Accept::Surrealdb) => Ok(output::full(&res)), Some(Accept::Surrealdb) => Ok(output::full(&res)),
// An incorrect content-type was requested // An incorrect content-type was requested
_ => Err(Error::InvalidType), _ => Err(Error::InvalidType),

View file

@ -1,12 +1,12 @@
mod auth; mod auth;
pub mod client_ip; pub mod client_ip;
mod export; mod export;
mod headers; pub(crate) mod headers;
mod health; mod health;
mod import; mod import;
mod input; mod input;
mod key; mod key;
mod output; pub(crate) mod output;
mod params; mod params;
mod rpc; mod rpc;
mod signals; mod signals;

View file

@ -1,30 +1,47 @@
use std::collections::BTreeMap;
use std::ops::Deref;
use crate::cnf; use crate::cnf;
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::rpc::connection::Connection; use crate::rpc::connection::Connection;
use crate::rpc::format::Format; use crate::rpc::format::Format;
use crate::rpc::format::PROTOCOLS; use crate::rpc::format::PROTOCOLS;
use crate::rpc::post_context::PostRpcContext;
use crate::rpc::response::IntoRpcResponse;
use crate::rpc::WEBSOCKETS; use crate::rpc::WEBSOCKETS;
use axum::routing::get; use axum::routing::get;
use axum::routing::post;
use axum::TypedHeader;
use axum::{ use axum::{
extract::ws::{WebSocket, WebSocketUpgrade}, extract::ws::{WebSocket, WebSocketUpgrade},
response::IntoResponse, response::IntoResponse,
Extension, Router, Extension, Router,
}; };
use bytes::Bytes;
use http::HeaderValue; use http::HeaderValue;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::rpc::method::Method;
use tower_http::request_id::RequestId; use tower_http::request_id::RequestId;
use uuid::Uuid; use uuid::Uuid;
use super::headers::Accept;
use super::headers::ContentType;
use surrealdb::rpc::rpc_context::RpcContext;
pub(super) fn router<S, B>() -> Router<S, B> pub(super) fn router<S, B>() -> Router<S, B>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
{ {
Router::new().route("/rpc", get(handler)) Router::new().route("/rpc", get(get_handler)).route("/rpc", post(post_handler))
} }
async fn handler( async fn get_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(id): Extension<RequestId>, Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>, Extension(sess): Extension<Session>,
@ -70,9 +87,37 @@ async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
// No protocol format was specified // No protocol format was specified
_ => Format::None, _ => Format::None,
}; };
// // Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here
// Create a new connection instance // Create a new connection instance
let rpc = Connection::new(id, sess, format); let rpc = Connection::new(id, sess, format);
// Serve the socket connection requests // Serve the socket connection requests
Connection::serve(rpc, ws).await; Connection::serve(rpc, ws).await;
} }
async fn post_handler(
Extension(session): Extension<Session>,
output: Option<TypedHeader<Accept>>,
content_type: TypedHeader<ContentType>,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
let fmt: Format = content_type.deref().into();
let out_fmt: Option<Format> = output.as_deref().map(Into::into);
if let Some(out_fmt) = out_fmt {
if fmt != out_fmt {
return Err(Error::InvalidType);
}
}
if fmt == Format::Unsupported || fmt == Format::None {
return Err(Error::InvalidType);
}
let mut rpc_ctx = PostRpcContext::new(DB.get().unwrap(), session, BTreeMap::new());
match fmt.req_http(body) {
Ok(req) => {
let res = rpc_ctx.execute(Method::parse(req.method), req.params).await;
fmt.res_http(res.into_response(None)).map_err(Error::from)
}
Err(err) => Err(Error::from(err)),
}
}

View file

@ -1,12 +1,10 @@
use crate::cnf::PKG_NAME; use crate::cnf::{
use crate::cnf::PKG_VERSION; PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY,
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY}; };
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error;
use crate::rpc::args::Take;
use crate::rpc::failure::Failure; use crate::rpc::failure::Failure;
use crate::rpc::format::Format; use crate::rpc::format::Format;
use crate::rpc::response::{failure, Data, IntoRpcResponse}; use crate::rpc::response::{failure, IntoRpcResponse};
use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
use crate::telemetry; use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext; use crate::telemetry::metrics::ws::RequestContext;
@ -19,12 +17,13 @@ use opentelemetry::Context as TelemetryContext;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use surrealdb::channel::{self, Receiver, Sender}; use surrealdb::channel::{self, Receiver, Sender};
use surrealdb::dbs::QueryType;
use surrealdb::dbs::Response;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
use surrealdb::rpc::args::Take;
use surrealdb::rpc::method::Method;
use surrealdb::rpc::RpcContext;
use surrealdb::rpc::{Data, RpcError};
use surrealdb::sql::Array; use surrealdb::sql::Array;
use surrealdb::sql::Object;
use surrealdb::sql::Strand;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tokio::sync::{RwLock, Semaphore}; use tokio::sync::{RwLock, Semaphore};
use tokio::task::JoinSet; use tokio::task::JoinSet;
@ -295,7 +294,7 @@ impl Connection {
let req_cx = RequestContext::default(); let req_cx = RequestContext::default();
let otel_cx = Arc::new(TelemetryContext::new().with_value(req_cx.clone())); let otel_cx = Arc::new(TelemetryContext::new().with_value(req_cx.clone()));
// Parse the RPC request structure // Parse the RPC request structure
match fmt.req(msg) { match fmt.req_ws(msg) {
Ok(req) => { Ok(req) => {
// Now that we know the method, we can update the span and create otel context // Now that we know the method, we can update the span and create otel context
span.record("rpc.method", &req.method); span.record("rpc.method", &req.method);
@ -337,547 +336,90 @@ impl Connection {
params: Array, params: Array,
) -> Result<Data, Failure> { ) -> Result<Data, Failure> {
debug!("Process RPC request"); debug!("Process RPC request");
// Match the method to a function let method = Method::parse(method);
match method { if !method.is_valid() {
// Handle a surrealdb ping message return Err(Failure::METHOD_NOT_FOUND);
//
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
// By sending a ping message every few seconds from the client, we can force a connection check and trigger an on_close event if the ping can't be sent.
//
"ping" => Ok(Value::None.into()),
// Retrieve the current auth record
"info" => match params.len() {
0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Switch to a specific namespace and database
"use" => match params.needs_two() {
Ok((ns, db)) => {
rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Signup to a specific authentication scope
"signup" => match params.needs_one() {
Ok(Value::Object(v)) => {
rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Signin as a root, namespace, database or scope user
"signin" => match params.needs_one() {
Ok(Value::Object(v)) => {
rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Invalidate the current authentication session
"invalidate" => match params.len() {
0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Authenticate using an authentication token
"authenticate" => match params.needs_one() {
Ok(Value::Strand(v)) => {
rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Kill a live query using a query id
"kill" => match params.needs_one() {
Ok(v) => rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Setup a live query on a specific table
"live" => match params.needs_one_or_two() {
Ok((v, d)) if v.is_table() => {
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
}
Ok((v, d)) if v.is_strand() => {
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Specify a connection-wide parameter
"let" | "set" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => {
rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Unset and clear a connection-wide parameter
"unset" => match params.needs_one() {
Ok(Value::Strand(s)) => {
rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Select a value or values from the database
"select" => match params.needs_one() {
Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Insert a value or values in the database
"insert" => match params.needs_one_or_two() {
Ok((v, o)) => {
rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Create a value or values in the database
"create" => match params.needs_one_or_two() {
Ok((v, o)) => {
rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `CONTENT`
"update" => match params.needs_one_or_two() {
Ok((v, o)) => {
rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `MERGE`
"merge" => match params.needs_one_or_two() {
Ok((v, o)) => {
rpc.read().await.merge(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `PATCH`
"patch" => match params.needs_one_two_or_three() {
Ok((v, o, d)) => {
rpc.read().await.patch(v, o, d).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Delete a value or values from the database
"delete" => match params.needs_one() {
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Get the current server version
"version" => match params.len() {
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
_ => Err(Failure::INVALID_PARAMS),
},
// Run a full SurrealQL query against the database
"query" => match params.needs_one_or_two() {
Ok((v, o)) if (v.is_strand() || v.is_query()) && o.is_none_or_null() => {
rpc.read().await.query(v).await.map(Into::into).map_err(Into::into)
}
Ok((v, Value::Object(o))) if v.is_strand() || v.is_query() => {
rpc.read().await.query_with(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
_ => Err(Failure::METHOD_NOT_FOUND),
}
} }
// ------------------------------ // if the write lock is a bottleneck then execute could be refactored into execute_mut and execute
// Methods for authentication // rpc.write().await.execute(method, params).await.map_err(Into::into)
// ------------------------------ match method.needs_mut() {
true => rpc.write().await.execute(method, params).await.map_err(Into::into),
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> { false => rpc.read().await.execute_immut(method, params).await.map_err(Into::into),
if let Value::Strand(ns) = ns {
self.session.ns = Some(ns.0);
}
if let Value::Strand(db) = db {
self.session.db = Some(db.0);
}
Ok(Value::None)
}
async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
.await
.map(Into::into)
.map_err(Into::into)
}
async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
.await
.map(Into::into)
.map_err(Into::into)
}
async fn invalidate(&mut self) -> Result<Value, Error> {
surrealdb::iam::clear::clear(&mut self.session)?;
Ok(Value::None)
}
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
Ok(Value::None)
}
// ------------------------------
// Methods for identification
// ------------------------------
async fn info(&self) -> Result<Value, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = "SELECT * FROM $auth";
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, None).await?;
// Extract the first value from the result
let res = res.remove(0).result?.first();
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for setting variables
// ------------------------------
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the query parameters
let var = Some(map! {
key.0.clone() => Value::None,
=> &self.vars
});
// Compute the specified parameter
match kvs.compute(val, &self.session, var).await? {
// Remove the variable if undefined
Value::None => self.vars.remove(&key.0),
// Store the variable if defined
v => self.vars.insert(key.0, v),
};
Ok(Value::Null)
}
async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
self.vars.remove(&key.0);
Ok(Value::Null)
}
// ------------------------------
// Methods for live queries
// ------------------------------
async fn kill(&self, id: Value) -> Result<Value, Error> {
// Specify the SQL query string
let sql = "KILL $id";
// Specify the query parameters
let var = map! {
String::from("id") => id,
=> &self.vars
};
// Execute the query on the database
let mut res = self.query_with(Value::from(sql), Object::from(var)).await?;
// Extract the first query result
let response = res.remove(0);
match response.result {
Ok(v) => Ok(v),
Err(e) => Err(Error::from(e)),
}
}
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
// Specify the SQL query string
let sql = match diff.is_true() {
true => "LIVE SELECT DIFF FROM $tb",
false => "LIVE SELECT * FROM $tb",
};
// Specify the query parameters
let var = map! {
String::from("tb") => tb.could_be_table(),
=> &self.vars
};
// Execute the query on the database
let mut res = self.query_with(Value::from(sql), Object::from(var)).await?;
// Extract the first query result
let response = res.remove(0);
match response.result {
Ok(v) => Ok(v),
Err(e) => Err(Error::from(e)),
}
}
// ------------------------------
// Methods for selecting
// ------------------------------
async fn select(&self, what: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = "SELECT * FROM $what";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for inserting
// ------------------------------
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = "INSERT INTO $what $data RETURN AFTER";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for creating
// ------------------------------
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"CREATE $what RETURN AFTER"
} else {
"CREATE $what CONTENT $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for updating
// ------------------------------
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"UPDATE $what RETURN AFTER"
} else {
"UPDATE $what CONTENT $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for merging
// ------------------------------
async fn merge(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = if data.is_none_or_null() {
"UPDATE $what RETURN AFTER"
} else {
"UPDATE $what MERGE $data RETURN AFTER"
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for patching
// ------------------------------
async fn patch(&self, what: Value, data: Value, diff: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = match diff.is_true() {
true => "UPDATE $what PATCH $data RETURN DIFF",
false => "UPDATE $what PATCH $data RETURN AFTER",
};
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
String::from("data") => data,
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for deleting
// ------------------------------
async fn delete(&self, what: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the SQL query string
let sql = "DELETE $what RETURN BEFORE";
// Specify the query parameters
let var = Some(map! {
String::from("what") => what.could_be_table(),
=> &self.vars
});
// Execute the query on the database
let mut res = kvs.execute(sql, &self.session, var).await?;
// Extract the first query result
let res = match one {
true => res.remove(0).result?.first(),
false => res.remove(0).result?,
};
// Return the result to the client
Ok(res)
}
// ------------------------------
// Methods for querying
// ------------------------------
async fn query(&self, sql: Value) -> Result<Vec<Response>, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the query parameters
let var = Some(self.vars.clone());
// Execute the query on the database
let res = match sql {
Value::Query(sql) => kvs.process(sql, &self.session, var).await?,
Value::Strand(sql) => kvs.execute(&sql, &self.session, var).await?,
_ => unreachable!(),
};
// Post-process hooks for web layer
for response in &res {
self.handle_live_query_results(response).await;
}
// Return the result to the client
Ok(res)
}
async fn query_with(&self, sql: Value, mut vars: Object) -> Result<Vec<Response>, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
// Specify the query parameters
let var = Some(mrg! { vars.0, &self.vars });
// Execute the query on the database
let res = match sql {
Value::Query(sql) => kvs.process(sql, &self.session, var).await?,
Value::Strand(sql) => kvs.execute(&sql, &self.session, var).await?,
_ => unreachable!(),
};
// Post-process hooks for web layer
for response in &res {
self.handle_live_query_results(response).await;
}
// Return the result to the client
Ok(res)
}
// ------------------------------
// Private methods
// ------------------------------
async fn handle_live_query_results(&self, res: &Response) {
match &res.query_type {
QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result {
// Match on Uuid type
LIVE_QUERIES.write().await.insert(lqid.0, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}
}
QueryType::Kill => {
if let Ok(Value::Uuid(lqid)) = &res.result {
if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}
}
_ => {}
} }
} }
} }
impl RpcContext for Connection {
fn kvs(&self) -> &Datastore {
DB.get().unwrap()
}
fn session(&self) -> &Session {
&self.session
}
fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
fn vars(&self) -> &BTreeMap<String, Value> {
&self.vars
}
fn vars_mut(&mut self) -> &mut BTreeMap<String, Value> {
&mut self.vars
}
fn version_data(&self) -> impl Into<Data> {
format!("{PKG_NAME}-{}", *PKG_VERSION)
}
const LQ_SUPPORT: bool = true;
async fn handle_live(&self, lqid: &Uuid) {
LIVE_QUERIES.write().await.insert(*lqid, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}
async fn handle_kill(&self, lqid: &Uuid) {
if let Some(id) = LIVE_QUERIES.write().await.remove(lqid) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}
// reimplimentaions
async fn signup(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn signin(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn authenticate(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?;
Ok(Value::None)
}
}

View file

@ -3,6 +3,7 @@ use revision::revisioned;
use revision::Revisioned; use revision::Revisioned;
use serde::Serialize; use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Value; use surrealdb::sql::Value;
#[derive(Clone, Debug, Serialize)] #[derive(Clone, Debug, Serialize)]
@ -51,6 +52,20 @@ impl From<Error> for Failure {
} }
} }
impl From<RpcError> for Failure {
fn from(err: RpcError) -> Self {
match err {
RpcError::ParseError => Failure::PARSE_ERROR,
RpcError::InvalidRequest => Failure::INVALID_REQUEST,
RpcError::MethodNotFound => Failure::METHOD_NOT_FOUND,
RpcError::InvalidParams => Failure::INVALID_PARAMS,
RpcError::InternalError(_) => Failure::custom(err.to_string()),
RpcError::Thrown(_) => Failure::custom(err.to_string()),
_ => Failure::custom(err.to_string()),
}
}
}
impl From<Failure> for Value { impl From<Failure> for Value {
fn from(err: Failure) -> Self { fn from(err: Failure) -> Self {
map! { map! {

View file

@ -1,22 +1,40 @@
use crate::rpc::failure::Failure; use crate::net::headers::ContentType;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use surrealdb::sql::serde::deserialize; use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Value; use surrealdb::sql::Value;
pub fn req(msg: Message) -> Result<Request, Failure> { pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg { match msg {
Message::Binary(val) => { Message::Binary(val) => {
deserialize::<Value>(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() deserialize::<Value>(&val).map_err(|_| RpcError::ParseError)?.try_into()
} }
_ => Err(Failure::INVALID_REQUEST), _ => Err(RpcError::InvalidRequest),
} }
} }
pub fn res(res: Response) -> Result<(usize, Message), Failure> { pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Serialize the response with full internal type information // Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap(); let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary // Return the message length, and message as binary
Ok((res.len(), Message::Binary(res))) Ok((res.len(), Message::Binary(res)))
} }
pub fn req_http(val: &Bytes) -> Result<Request, RpcError> {
deserialize::<Value>(val).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary
// TODO: Check what this header should be, I'm being consistent with /sql
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response())
}

View file

@ -1,27 +1,32 @@
mod convert; mod convert;
use bytes::Bytes;
pub use convert::Cbor; pub use convert::Cbor;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use crate::rpc::failure::Failure; use crate::net::headers::ContentType;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
use ciborium::Value as Data; use ciborium::Value as Data;
pub fn req(msg: Message) -> Result<Request, Failure> { pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg { match msg {
Message::Text(val) => { Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into()
} }
Message::Binary(val) => ciborium::from_reader::<Data, _>(&mut val.as_slice()) Message::Binary(val) => ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR) .map_err(|_| RpcError::ParseError)
.map(Cbor)? .map(Cbor)?
.try_into(), .try_into(),
_ => Err(Failure::INVALID_REQUEST), _ => Err(RpcError::InvalidRequest),
} }
} }
pub fn res(res: Response) -> Result<(usize, Message), Failure> { pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into a value // Convert the response into a value
let val: Cbor = res.into_value().try_into()?; let val: Cbor = res.into_value().try_into()?;
// Create a new vector for encoding output // Create a new vector for encoding output
@ -31,3 +36,22 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Return the message length, and message as binary // Return the message length, and message as binary
Ok((res.len(), Message::Binary(res))) Ok((res.len(), Message::Binary(res)))
} }
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Cbor)?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into a value
let val: Cbor = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val.0, &mut res).unwrap();
// Return the message length, and message as binary
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response())
}

View file

@ -1,18 +1,23 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use http::StatusCode;
use surrealdb::rpc::RpcError;
use surrealdb::sql;
pub fn req(msg: Message) -> Result<Request, Failure> { pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg { match msg {
Message::Text(val) => { Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into()
} }
_ => Err(Failure::INVALID_REQUEST), _ => Err(RpcError::InvalidRequest),
} }
} }
pub fn res(res: Response) -> Result<(usize, Message), Failure> { pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into simplified JSON // Convert the response into simplified JSON
let val = res.into_json(); let val = res.into_json();
// Serialize the response with simplified type information // Serialize the response with simplified type information
@ -20,3 +25,18 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Return the message length, and message as binary // Return the message length, and message as binary
Ok((res.len(), Message::Text(res))) Ok((res.len(), Message::Text(res)))
} }
pub fn req_http(val: &Bytes) -> Result<Request, RpcError> {
sql::value(std::str::from_utf8(val).or(Err(RpcError::ParseError))?)
.or(Err(RpcError::ParseError))?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok((StatusCode::OK, res).into_response())
}

View file

@ -4,10 +4,14 @@ mod json;
pub mod msgpack; pub mod msgpack;
mod revision; mod revision;
use crate::net::headers::{Accept, ContentType};
use crate::rpc::failure::Failure; use crate::rpc::failure::Failure;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use surrealdb::rpc::RpcError;
pub const PROTOCOLS: [&str; 5] = [ pub const PROTOCOLS: [&str; 5] = [
"json", // For basic JSON serialisation "json", // For basic JSON serialisation
@ -25,6 +29,33 @@ pub enum Format {
Msgpack, // For basic Msgpack serialisation Msgpack, // For basic Msgpack serialisation
Bincode, // For full internal serialisation Bincode, // For full internal serialisation
Revision, // For full versioned serialisation Revision, // For full versioned serialisation
Unsupported, // Unsupported format
}
impl From<&Accept> for Format {
fn from(value: &Accept) -> Self {
match value {
Accept::TextPlain => Format::None,
Accept::ApplicationJson => Format::Json,
Accept::ApplicationCbor => Format::Cbor,
Accept::ApplicationPack => Format::Msgpack,
Accept::ApplicationOctetStream => Format::Unsupported,
Accept::Surrealdb => Format::Bincode,
}
}
}
impl From<&ContentType> for Format {
fn from(value: &ContentType) -> Self {
match value {
ContentType::TextPlain => Format::None,
ContentType::ApplicationJson => Format::Json,
ContentType::ApplicationCbor => Format::Cbor,
ContentType::ApplicationPack => Format::Msgpack,
ContentType::ApplicationOctetStream => Format::Unsupported,
ContentType::Surrealdb => Format::Bincode,
}
}
} }
impl From<&str> for Format { impl From<&str> for Format {
@ -46,25 +77,53 @@ impl Format {
matches!(self, Format::None) matches!(self, Format::None)
} }
/// Process a request using the specified format /// Process a request using the specified format
pub fn req(&self, msg: Message) -> Result<Request, Failure> { pub fn req_ws(&self, msg: Message) -> Result<Request, Failure> {
match self { match self {
Self::None => unreachable!(), // We should never arrive at this code Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::req(msg), Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Cbor => cbor::req(msg), Self::Json => json::req_ws(msg),
Self::Msgpack => msgpack::req(msg), Self::Cbor => cbor::req_ws(msg),
Self::Bincode => bincode::req(msg), Self::Msgpack => msgpack::req_ws(msg),
Self::Revision => revision::req(msg), Self::Bincode => bincode::req_ws(msg),
Self::Revision => revision::req_ws(msg),
}
.map_err(Into::into)
}
/// Process a response using the specified format
pub fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::res_ws(res),
Self::Cbor => cbor::res_ws(res),
Self::Msgpack => msgpack::res_ws(res),
Self::Bincode => bincode::res_ws(res),
Self::Revision => revision::res_ws(res),
}
.map_err(Into::into)
}
/// Process a request using the specified format
pub fn req_http(&self, body: Bytes) -> Result<Request, RpcError> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::req_http(&body),
Self::Cbor => cbor::req_http(body),
Self::Msgpack => msgpack::req_http(body),
Self::Bincode => bincode::req_http(&body),
Self::Revision => revision::req_http(body),
} }
} }
/// Process a response using the specified format /// Process a response using the specified format
pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> { pub fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError> {
match self { match self {
Self::None => unreachable!(), // We should never arrive at this code Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::res(res), Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Cbor => cbor::res(res), Self::Json => json::res_http(res),
Self::Msgpack => msgpack::res(res), Self::Cbor => cbor::res_http(res),
Self::Bincode => bincode::res(res), Self::Msgpack => msgpack::res_http(res),
Self::Revision => revision::res(res), Self::Bincode => bincode::res_http(res),
Self::Revision => revision::res_http(res),
} }
} }
} }

View file

@ -1,26 +1,31 @@
mod convert; mod convert;
use bytes::Bytes;
pub use convert::Pack; pub use convert::Pack;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use crate::rpc::failure::Failure; use crate::net::headers::ContentType;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
pub fn req(msg: Message) -> Result<Request, Failure> { pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg { match msg {
Message::Text(val) => { Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into() surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into()
} }
Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice()) Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR) .map_err(|_| RpcError::ParseError)
.map(Pack)? .map(Pack)?
.try_into(), .try_into(),
_ => Err(Failure::INVALID_REQUEST), _ => Err(RpcError::InvalidRequest),
} }
} }
pub fn res(res: Response) -> Result<(usize, Message), Failure> { pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into a value // Convert the response into a value
let val: Pack = res.into_value().try_into()?; let val: Pack = res.into_value().try_into()?;
// Create a new vector for encoding output // Create a new vector for encoding output
@ -30,3 +35,21 @@ pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Return the message length, and message as binary // Return the message length, and message as binary
Ok((res.len(), Message::Binary(res))) Ok((res.len(), Message::Binary(res)))
} }
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Pack)?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into a value
let val: Pack = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into MsgPack binary data
rmpv::encode::write_value(&mut res, &val.0).unwrap();
// Return the message length, and message as binary
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationPack))], res).into_response())
}

View file

@ -1,23 +1,42 @@
use crate::rpc::failure::Failure; use crate::net::headers::ContentType;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use revision::Revisioned; use revision::Revisioned;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Value; use surrealdb::sql::Value;
pub fn req(msg: Message) -> Result<Request, Failure> { pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg { match msg {
Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice()) Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR)? .map_err(|_| RpcError::ParseError)?
.try_into(), .try_into(),
_ => Err(Failure::INVALID_REQUEST), _ => Err(RpcError::InvalidRequest),
} }
} }
pub fn res(res: Response) -> Result<(usize, Message), Failure> { pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Serialize the response with full internal type information // Serialize the response with full internal type information
let mut buf = Vec::new(); let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap(); res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary // Return the message length, and message as binary
Ok((buf.len(), Message::Binary(buf))) Ok((buf.len(), Message::Binary(buf)))
} }
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Serialize the response with full internal type information
let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary
// TODO: Check what this header should be, new header needed for revisioned
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], buf).into_response())
}

View file

@ -2,6 +2,7 @@ pub mod args;
pub mod connection; pub mod connection;
pub mod failure; pub mod failure;
pub mod format; pub mod format;
pub mod post_context;
pub mod request; pub mod request;
pub mod response; pub mod response;

103
src/rpc/post_context.rs Normal file
View file

@ -0,0 +1,103 @@
use std::collections::BTreeMap;
use crate::cnf::{PKG_NAME, PKG_VERSION};
use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
use surrealdb::rpc::args::Take;
use surrealdb::rpc::Data;
use surrealdb::rpc::RpcContext;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Array;
use surrealdb::sql::Value;
pub struct PostRpcContext<'a> {
pub kvs: &'a Datastore,
pub session: Session,
pub vars: BTreeMap<String, Value>,
}
impl<'a> PostRpcContext<'a> {
pub fn new(kvs: &'a Datastore, session: Session, vars: BTreeMap<String, Value>) -> Self {
Self {
kvs,
session,
vars,
}
}
}
impl RpcContext for PostRpcContext<'_> {
fn kvs(&self) -> &Datastore {
self.kvs
}
fn session(&self) -> &Session {
&self.session
}
fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
fn vars(&self) -> &BTreeMap<String, Value> {
&self.vars
}
fn vars_mut(&mut self) -> &mut BTreeMap<String, Value> {
&mut self.vars
}
fn version_data(&self) -> impl Into<Data> {
let val: Value = format!("{PKG_NAME}-{}", *PKG_VERSION).into();
val
}
// disable:
// doesn't do anything so shouldn't be supported
async fn set(&mut self, _params: Array) -> Result<impl Into<Data>, RpcError> {
let out: Result<Value, RpcError> = Err(RpcError::MethodNotFound);
out
}
// doesn't do anything so shouldn't be supported
async fn unset(&mut self, _params: Array) -> Result<impl Into<Data>, RpcError> {
let out: Result<Value, RpcError> = Err(RpcError::MethodNotFound);
out
}
// reimplimentaions:
async fn signup(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signup::signup(self.kvs, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn signin(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signin::signin(self.kvs, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn authenticate(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
surrealdb::iam::verify::token(self.kvs, &mut self.session, &token.0).await?;
Ok(Value::None)
}
}

View file

@ -1,7 +1,7 @@
use crate::rpc::failure::Failure;
use crate::rpc::format::cbor::Cbor; use crate::rpc::format::cbor::Cbor;
use crate::rpc::format::msgpack::Pack; use crate::rpc::format::msgpack::Pack;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Part; use surrealdb::sql::Part;
use surrealdb::sql::{Array, Value}; use surrealdb::sql::{Array, Value};
@ -16,22 +16,22 @@ pub struct Request {
} }
impl TryFrom<Cbor> for Request { impl TryFrom<Cbor> for Request {
type Error = Failure; type Error = RpcError;
fn try_from(val: Cbor) -> Result<Self, Failure> { fn try_from(val: Cbor) -> Result<Self, RpcError> {
<Cbor as TryInto<Value>>::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into() <Cbor as TryInto<Value>>::try_into(val).map_err(|_| RpcError::InvalidRequest)?.try_into()
} }
} }
impl TryFrom<Pack> for Request { impl TryFrom<Pack> for Request {
type Error = Failure; type Error = RpcError;
fn try_from(val: Pack) -> Result<Self, Failure> { fn try_from(val: Pack) -> Result<Self, RpcError> {
<Pack as TryInto<Value>>::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into() <Pack as TryInto<Value>>::try_into(val).map_err(|_| RpcError::InvalidRequest)?.try_into()
} }
} }
impl TryFrom<Value> for Request { impl TryFrom<Value> for Request {
type Error = Failure; type Error = RpcError;
fn try_from(val: Value) -> Result<Self, Failure> { fn try_from(val: Value) -> Result<Self, RpcError> {
// Fetch the 'id' argument // Fetch the 'id' argument
let id = match val.pick(&*ID) { let id = match val.pick(&*ID) {
v if v.is_none() => None, v if v.is_none() => None,
@ -40,12 +40,12 @@ impl TryFrom<Value> for Request {
v if v.is_number() => Some(v), v if v.is_number() => Some(v),
v if v.is_strand() => Some(v), v if v.is_strand() => Some(v),
v if v.is_datetime() => Some(v), v if v.is_datetime() => Some(v),
_ => return Err(Failure::INVALID_REQUEST), _ => return Err(RpcError::InvalidRequest),
}; };
// Fetch the 'method' argument // Fetch the 'method' argument
let method = match val.pick(&*METHOD) { let method = match val.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(), Value::Strand(v) => v.to_raw(),
_ => return Err(Failure::INVALID_REQUEST), _ => return Err(RpcError::InvalidRequest),
}; };
// Fetch the 'params' argument // Fetch the 'params' argument
let params = match val.pick(&*PARAMS) { let params = match val.pick(&*PARAMS) {

View file

@ -8,61 +8,10 @@ use serde::Serialize;
use serde_json::Value as Json; use serde_json::Value as Json;
use std::sync::Arc; use std::sync::Arc;
use surrealdb::channel::Sender; use surrealdb::channel::Sender;
use surrealdb::dbs; use surrealdb::rpc::Data;
use surrealdb::dbs::Notification;
use surrealdb::sql;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tracing::Span; use tracing::Span;
/// The data returned by the database
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync.
#[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub enum Data {
/// Generally methods return a `sql::Value`
Other(Value),
/// The query methods, `query` and `query_with` return a `Vec` of responses
Query(Vec<dbs::Response>),
/// Live queries return a notification
Live(Notification),
// Add new variants here
}
impl From<Value> for Data {
fn from(v: Value) -> Self {
Data::Other(v)
}
}
impl From<String> for Data {
fn from(v: String) -> Self {
Data::Other(Value::from(v))
}
}
impl From<Notification> for Data {
fn from(n: Notification) -> Self {
Data::Live(n)
}
}
impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v)
}
}
impl From<Data> for Value {
fn from(val: Data) -> Self {
match val {
Data::Query(v) => sql::to_value(v).unwrap(),
Data::Live(v) => sql::to_value(v).unwrap(),
Data::Other(v) => v,
}
}
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[revisioned(revision = 1)] #[revisioned(revision = 1)]
pub struct Response { pub struct Response {
@ -111,7 +60,7 @@ impl Response {
span.record("rpc.error_message", err.message.as_ref()); span.record("rpc.error_message", err.message.as_ref());
} }
// Process the response for the format // Process the response for the format
let (len, msg) = fmt.res(self).unwrap(); let (len, msg) = fmt.res_ws(self).unwrap();
// Send the message to the write channel // Send the message to the write channel
if chn.send(msg).await.is_ok() { if chn.send(msg).await.is_ok() {
record_rpc(cx.as_ref(), len, is_error); record_rpc(cx.as_ref(), len, is_error);