[security] Introduce the Datastore capabilities (#2489)
This commit is contained in:
parent
5945146459
commit
b5b6f6f1d4
25 changed files with 4522 additions and 3188 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -5109,6 +5109,7 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -5141,6 +5142,7 @@ dependencies = [
|
|||
"geo 0.25.1",
|
||||
"indexmap 1.9.3",
|
||||
"indxdb",
|
||||
"ipnet",
|
||||
"jsonwebtoken",
|
||||
"lexicmp",
|
||||
"lru",
|
||||
|
|
|
@ -87,6 +87,7 @@ test-log = { version = "0.2.12", features = ["trace"] }
|
|||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
tokio-tungstenite = { version = "0.18.0" }
|
||||
tonic = "0.8.3"
|
||||
wiremock = "0.5.19"
|
||||
|
||||
[package.metadata.deb]
|
||||
maintainer-scripts = "pkg/deb/"
|
||||
|
|
|
@ -24,20 +24,20 @@ args = ["clippy", "--all-targets", "--all-features", "--", "-D", "warnings"]
|
|||
[tasks.ci-cli-integration]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
command = "cargo"
|
||||
env = { RUST_LOG = "cli_integration=debug" }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "cli_integration", "--", "--nocapture"]
|
||||
env = { RUST_LOG={ value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http,scripting", "--workspace", "--test", "cli_integration", "--", "cli_integration", "--nocapture"]
|
||||
|
||||
[tasks.ci-http-integration]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
command = "cargo"
|
||||
env = { RUST_LOG = "http_integration=debug" }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "http_integration", "--", "--nocapture"]
|
||||
env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
|
||||
|
||||
[tasks.ci-ws-integration]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
command = "cargo"
|
||||
env = { RUST_LOG = "ws_integration=debug" }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "--nocapture"]
|
||||
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
||||
|
||||
[tasks.ci-workspace-coverage]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
|
|
|
@ -69,25 +69,25 @@ args = ["bench", "--package", "surrealdb", "--no-default-features", "--features"
|
|||
[tasks.serve]
|
||||
category = "LOCAL USAGE"
|
||||
command = "cargo"
|
||||
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start"]
|
||||
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start", "${@}"]
|
||||
|
||||
# SQL
|
||||
[tasks.sql]
|
||||
category = "LOCAL USAGE"
|
||||
command = "cargo"
|
||||
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--conn", "ws://0.0.0.0:8000", "--multi", "--pretty"]
|
||||
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--conn", "ws://0.0.0.0:8000", "--multi", "--pretty", "${@}"]
|
||||
|
||||
# Quick
|
||||
[tasks.quick]
|
||||
category = "LOCAL USAGE"
|
||||
command = "cargo"
|
||||
args = ["build"]
|
||||
args = ["build", "${@}"]
|
||||
|
||||
# Build
|
||||
[tasks.build]
|
||||
category = "LOCAL USAGE"
|
||||
command = "cargo"
|
||||
args = ["build", "--release"]
|
||||
args = ["build", "--release", "${@}"]
|
||||
|
||||
# Default
|
||||
[tasks.default]
|
||||
|
|
|
@ -76,6 +76,7 @@ fuzzy-matcher = "0.3.7"
|
|||
geo = { version = "0.25.1", features = ["use-serde"] }
|
||||
indexmap = { version = "1.9.3", features = ["serde"] }
|
||||
indxdb = { version = "0.3.0", optional = true }
|
||||
ipnet = "2.8.0"
|
||||
js = { version = "0.4.0-beta.3", package = "rquickjs", features = ["array-buffer", "bindgen", "classes", "futures", "loader", "macro", "parallel", "properties","rust-alloc"], optional = true }
|
||||
jsonwebtoken = "8.3.0"
|
||||
lexicmp = "0.1.0"
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
use crate::ctx::canceller::Canceller;
|
||||
use crate::ctx::reason::Reason;
|
||||
use crate::dbs::Notification;
|
||||
use crate::dbs::capabilities::{FuncTarget, NetTarget};
|
||||
use crate::dbs::{Capabilities, Notification};
|
||||
use crate::err::Error;
|
||||
use crate::idx::planner::QueryPlanner;
|
||||
use crate::sql::value::Value;
|
||||
use channel::Sender;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{self, Debug};
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use trice::Instant;
|
||||
use url::Url;
|
||||
|
||||
impl<'a> From<Value> for Cow<'a, Value> {
|
||||
fn from(v: Value) -> Cow<'a, Value> {
|
||||
|
@ -36,6 +40,8 @@ pub struct Context<'a> {
|
|||
notifications: Option<Sender<Notification>>,
|
||||
// An optional query planner
|
||||
query_planner: Option<&'a QueryPlanner<'a>>,
|
||||
// Capabilities
|
||||
capabilities: Arc<Capabilities>,
|
||||
}
|
||||
|
||||
impl<'a> Default for Context<'a> {
|
||||
|
@ -65,6 +71,7 @@ impl<'a> Context<'a> {
|
|||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
notifications: None,
|
||||
query_planner: None,
|
||||
capabilities: Arc::new(Capabilities::default()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,6 +84,7 @@ impl<'a> Context<'a> {
|
|||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
notifications: parent.notifications.clone(),
|
||||
query_planner: parent.query_planner,
|
||||
capabilities: parent.capabilities.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -190,4 +198,54 @@ impl<'a> Context<'a> {
|
|||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
//
|
||||
// Capabilities
|
||||
//
|
||||
|
||||
/// Set the capabilities for this context
|
||||
pub fn add_capabilities(&mut self, caps: Capabilities) {
|
||||
self.capabilities = Arc::new(caps);
|
||||
}
|
||||
|
||||
/// Get the capabilities for this context
|
||||
pub fn get_capabilities(&self) -> Arc<Capabilities> {
|
||||
self.capabilities.clone()
|
||||
}
|
||||
|
||||
/// Check if scripting is allowed
|
||||
pub fn check_allowed_scripting(&self) -> Result<(), Error> {
|
||||
if !self.capabilities.is_allowed_scripting() {
|
||||
return Err(Error::ScriptingNotAllowed);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a function is allowed
|
||||
pub fn check_allowed_function(&self, target: &str) -> Result<(), Error> {
|
||||
let func_target = FuncTarget::from_str(target).map_err(|_| Error::InvalidFunction {
|
||||
name: target.to_string(),
|
||||
message: "Invalid function name".to_string(),
|
||||
})?;
|
||||
|
||||
if !self.capabilities.is_allowed_func(&func_target) {
|
||||
return Err(Error::FunctionNotAllowed(target.to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a network target is allowed
|
||||
pub fn check_allowed_net(&self, target: &Url) -> Result<(), Error> {
|
||||
match target.host() {
|
||||
Some(host)
|
||||
if self.capabilities.is_allowed_net(&NetTarget::Host(
|
||||
host.to_owned(),
|
||||
target.port_or_known_default(),
|
||||
)) =>
|
||||
{
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(Error::NetTargetNotAllowed(target.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
479
lib/src/dbs/capabilities.rs
Normal file
479
lib/src/dbs/capabilities.rs
Normal file
|
@ -0,0 +1,479 @@
|
|||
use std::hash::Hash;
|
||||
use std::net::IpAddr;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use ipnet::IpNet;
|
||||
use url::Url;
|
||||
|
||||
pub trait Target {
|
||||
fn matches(&self, elem: &Self) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct FuncTarget(pub String, pub Option<String>);
|
||||
|
||||
impl Target for FuncTarget {
|
||||
fn matches(&self, elem: &Self) -> bool {
|
||||
match self {
|
||||
Self(family, Some(name)) => {
|
||||
family == &elem.0 && (elem.1.as_ref().is_some_and(|n| n == name))
|
||||
}
|
||||
Self(family, None) => family == &elem.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for FuncTarget {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// 'family::*' is treated as 'family'. They both match all functions in the family.
|
||||
let s = s.replace("::*", "");
|
||||
|
||||
let target = match s.split_once("::") {
|
||||
Some((family, name)) => Self(family.to_string(), Some(name.to_string())),
|
||||
_ => Self(s.to_string(), None),
|
||||
};
|
||||
Ok(target)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub enum NetTarget {
|
||||
Host(url::Host<String>, Option<u16>),
|
||||
IPNet(ipnet::IpNet),
|
||||
}
|
||||
|
||||
impl Target for NetTarget {
|
||||
fn matches(&self, elem: &Self) -> bool {
|
||||
match self {
|
||||
// If self contains a host and port, the elem must match both the host and port
|
||||
Self::Host(host, Some(port)) => match elem {
|
||||
Self::Host(_host, Some(_port)) => host == _host && port == _port,
|
||||
_ => false,
|
||||
},
|
||||
// If self contains a host but no port, the elem must match the host only
|
||||
Self::Host(host, None) => match elem {
|
||||
Self::Host(_host, _) => host == _host,
|
||||
_ => false,
|
||||
},
|
||||
// If self is an IPNet, it can match both an IPNet or a Host elem that contains an IPAddr
|
||||
Self::IPNet(ipnet) => match elem {
|
||||
Self::IPNet(_ipnet) => ipnet.contains(_ipnet),
|
||||
Self::Host(host, _) => match host {
|
||||
url::Host::Ipv4(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
|
||||
url::Host::Ipv6(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
|
||||
_ => false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for NetTarget {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// If it's a valid IPNet, return it
|
||||
if let Ok(ipnet) = s.parse::<IpNet>() {
|
||||
return Ok(NetTarget::IPNet(ipnet));
|
||||
}
|
||||
|
||||
// If it's a valid IPAddr, return it as an IPNet
|
||||
if let Ok(ipnet) = s.parse::<IpAddr>() {
|
||||
return Ok(NetTarget::IPNet(IpNet::from(ipnet)));
|
||||
}
|
||||
|
||||
// Parse the host and port parts from a string in the form of 'host' or 'host:port'
|
||||
if let Ok(url) = Url::parse(format!("http://{s}").as_str()) {
|
||||
if let Some(host) = url.host() {
|
||||
// Url::parse will return port=None if the provided port was 80 (given we are using the http scheme). Get the original port from the string.
|
||||
if let Some(Ok(port)) = s.split(':').last().map(|p| p.parse::<u16>()) {
|
||||
return Ok(NetTarget::Host(host.to_owned(), Some(port)));
|
||||
} else {
|
||||
return Ok(NetTarget::Host(host.to_owned(), None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!(
|
||||
"The provided network target `{s}` is not a valid host, ip address or ip network"
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum Targets<T: Target + Hash + Eq + PartialEq> {
|
||||
None,
|
||||
Some(HashSet<T>),
|
||||
All,
|
||||
}
|
||||
|
||||
impl<T: Target + Hash + Eq + PartialEq + std::fmt::Debug> Targets<T> {
|
||||
fn matches(&self, elem: &T) -> bool {
|
||||
match self {
|
||||
Self::None => false,
|
||||
Self::All => true,
|
||||
Self::Some(targets) => targets.iter().any(|t| t.matches(elem)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Capabilities are used to limit what a user can do to the system.
|
||||
///
|
||||
/// Capabilities are split into 3 categories:
|
||||
/// - Scripting: Whether or not the user can execute scripts
|
||||
/// - Functions: Whether or not the user can execute certain functions
|
||||
/// - Network: Whether or not the user can access certain network addresses
|
||||
///
|
||||
/// Capabilities are configured globally. By default, capabilities are configured as:
|
||||
/// - Scripting: true
|
||||
/// - Functions: All functions are allowed
|
||||
/// - Network: All network addresses are allowed
|
||||
///
|
||||
/// The capabilities are defined using allow/deny lists for fine-grained control.
|
||||
///
|
||||
/// Examples:
|
||||
/// - Allow all functions: `--allow-funcs`
|
||||
/// - Allow all functions except `http.*`: `--allow-funcs --deny-funcs 'http.*'`
|
||||
/// - Allow all network addresses except AWS metadata endpoint: `--allow-net --deny-net='169.254.169.254'`
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Capabilities {
|
||||
scripting: bool,
|
||||
|
||||
allow_funcs: Arc<Targets<FuncTarget>>,
|
||||
deny_funcs: Arc<Targets<FuncTarget>>,
|
||||
allow_net: Arc<Targets<NetTarget>>,
|
||||
deny_net: Arc<Targets<NetTarget>>,
|
||||
}
|
||||
|
||||
impl Default for Capabilities {
|
||||
// By default, enable all capabilities
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
scripting: true,
|
||||
|
||||
allow_funcs: Arc::new(Targets::All),
|
||||
deny_funcs: Arc::new(Targets::None),
|
||||
allow_net: Arc::new(Targets::All),
|
||||
deny_net: Arc::new(Targets::None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Capabilities {
|
||||
pub fn with_scripting(mut self, scripting: bool) -> Self {
|
||||
self.scripting = scripting;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_allow_funcs(mut self, allow_funcs: Targets<FuncTarget>) -> Self {
|
||||
self.allow_funcs = Arc::new(allow_funcs);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_deny_funcs(mut self, deny_funcs: Targets<FuncTarget>) -> Self {
|
||||
self.deny_funcs = Arc::new(deny_funcs);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_allow_net(mut self, allow_net: Targets<NetTarget>) -> Self {
|
||||
self.allow_net = Arc::new(allow_net);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_deny_net(mut self, deny_net: Targets<NetTarget>) -> Self {
|
||||
self.deny_net = Arc::new(deny_net);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_allowed_scripting(&self) -> bool {
|
||||
self.scripting
|
||||
}
|
||||
|
||||
pub fn is_allowed_func(&self, target: &FuncTarget) -> bool {
|
||||
self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
|
||||
}
|
||||
|
||||
pub fn is_allowed_net(&self, target: &NetTarget) -> bool {
|
||||
self.allow_net.matches(target) && !self.deny_net.matches(target)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
use test_log::test;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_func_target() {
|
||||
assert!(FuncTarget::from_str("test")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test").unwrap()));
|
||||
assert!(!FuncTarget::from_str("test")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test2").unwrap()));
|
||||
|
||||
assert!(!FuncTarget::from_str("test::")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test").unwrap()));
|
||||
|
||||
assert!(FuncTarget::from_str("test::*")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test::name").unwrap()));
|
||||
assert!(!FuncTarget::from_str("test::*")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test2::name").unwrap()));
|
||||
|
||||
assert!(FuncTarget::from_str("test::name")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test::name").unwrap()));
|
||||
assert!(!FuncTarget::from_str("test::name")
|
||||
.unwrap()
|
||||
.matches(&FuncTarget::from_str("test::name2").unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_net_target() {
|
||||
// IPNet IPv4
|
||||
assert!(NetTarget::from_str("10.0.0.0/8")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("10.0.1.0/24").unwrap()));
|
||||
assert!(NetTarget::from_str("10.0.0.0/8")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("10.0.1.2").unwrap()));
|
||||
assert!(!NetTarget::from_str("10.0.0.0/8")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("20.0.1.0/24").unwrap()));
|
||||
assert!(!NetTarget::from_str("10.0.0.0/8")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("20.0.1.0").unwrap()));
|
||||
|
||||
// IPNet IPv6
|
||||
assert!(NetTarget::from_str("2001:db8::1")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
|
||||
assert!(NetTarget::from_str("2001:db8::/32")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
|
||||
assert!(NetTarget::from_str("2001:db8::/32")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("2001:db8:abcd:12::/64").unwrap()));
|
||||
assert!(!NetTarget::from_str("2001:db8::/32")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("2001:db9::1").unwrap()));
|
||||
assert!(!NetTarget::from_str("2001:db8::/32")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("2001:db9:abcd:12::1/64").unwrap()));
|
||||
|
||||
// Host domain with and without port
|
||||
assert!(NetTarget::from_str("example.com")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(NetTarget::from_str("example.com")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("example.com:80").unwrap()));
|
||||
assert!(!NetTarget::from_str("example.com")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("www.example.com").unwrap()));
|
||||
assert!(!NetTarget::from_str("example.com")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("www.example.com:80").unwrap()));
|
||||
assert!(NetTarget::from_str("example.com:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("example.com:80").unwrap()));
|
||||
assert!(!NetTarget::from_str("example.com:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("example.com:443").unwrap()));
|
||||
assert!(!NetTarget::from_str("example.com:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("example.com").unwrap()));
|
||||
|
||||
// Host IPv4 with and without port
|
||||
assert!(
|
||||
NetTarget::from_str("127.0.0.1")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
|
||||
"Host IPv4 without port matches itself"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("127.0.0.1")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
|
||||
"Host IPv4 without port matches Host IPv4 with port"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("10.0.0.0/8")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("10.0.0.1:80").unwrap()),
|
||||
"IPv4 network matches Host IPv4 with port"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("127.0.0.1:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
|
||||
"Host IPv4 with port matches itself"
|
||||
);
|
||||
assert!(
|
||||
!NetTarget::from_str("127.0.0.1:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
|
||||
"Host IPv4 with port does not match Host IPv4 without port"
|
||||
);
|
||||
assert!(
|
||||
!NetTarget::from_str("127.0.0.1:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("127.0.0.1:443").unwrap()),
|
||||
"Host IPv4 with port does not match Host IPv4 with different port"
|
||||
);
|
||||
|
||||
// Host IPv6 with and without port
|
||||
assert!(
|
||||
NetTarget::from_str("[2001:db8::1]")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
|
||||
"Host IPv6 without port matches itself"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("[2001:db8::1]")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
|
||||
"Host IPv6 without port matches Host IPv6 with port"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("2001:db8::1")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
|
||||
"IPv6 addr matches Host IPv6 with port"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("2001:db8::/64")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
|
||||
"IPv6 network matches Host IPv6 with port"
|
||||
);
|
||||
assert!(
|
||||
NetTarget::from_str("[2001:db8::1]:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
|
||||
"Host IPv6 with port matches itself"
|
||||
);
|
||||
assert!(
|
||||
!NetTarget::from_str("[2001:db8::1]:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
|
||||
"Host IPv6 with port does not match Host IPv6 without port"
|
||||
);
|
||||
assert!(
|
||||
!NetTarget::from_str("[2001:db8::1]:80")
|
||||
.unwrap()
|
||||
.matches(&NetTarget::from_str("[2001:db8::1]:443").unwrap()),
|
||||
"Host IPv6 with port does not match Host IPv6 with different port"
|
||||
);
|
||||
|
||||
// Test invalid targets
|
||||
assert!(NetTarget::from_str("exam^ple.com").is_err());
|
||||
assert!(NetTarget::from_str("example.com:80:80").is_err());
|
||||
assert!(NetTarget::from_str("11111.3.4.5").is_err());
|
||||
assert!(NetTarget::from_str("2001:db8::1/129").is_err());
|
||||
assert!(NetTarget::from_str("[2001:db8::1").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_targets() {
|
||||
assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(Targets::<FuncTarget>::All.matches(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(!Targets::<FuncTarget>::None.matches(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
|
||||
.matches(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
|
||||
.matches(&NetTarget::from_str("www.example.com").unwrap()));
|
||||
assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
|
||||
.matches(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
|
||||
.matches(&FuncTarget::from_str("http::post").unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities() {
|
||||
// When scripting is allowed
|
||||
{
|
||||
let caps = Capabilities::default().with_scripting(true);
|
||||
assert!(caps.is_allowed_scripting());
|
||||
}
|
||||
|
||||
// When scripting is denied
|
||||
{
|
||||
let caps = Capabilities::default().with_scripting(false);
|
||||
assert!(!caps.is_allowed_scripting());
|
||||
}
|
||||
|
||||
// When all nets are allowed
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::All)
|
||||
.with_deny_net(Targets::<NetTarget>::None);
|
||||
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
|
||||
}
|
||||
|
||||
// When all nets are allowed and denied at the same time
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::All)
|
||||
.with_deny_net(Targets::<NetTarget>::All);
|
||||
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
|
||||
}
|
||||
|
||||
// When some nets are allowed and some are denied, deny overrides the allow rules
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::Some(
|
||||
[NetTarget::from_str("example.com").unwrap()].into(),
|
||||
))
|
||||
.with_deny_net(Targets::<NetTarget>::Some(
|
||||
[NetTarget::from_str("example.com:80").unwrap()].into(),
|
||||
));
|
||||
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
|
||||
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com:443").unwrap()));
|
||||
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
|
||||
}
|
||||
|
||||
// When all funcs are allowed
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::All)
|
||||
.with_deny_funcs(Targets::<FuncTarget>::None);
|
||||
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
|
||||
}
|
||||
|
||||
// When all funcs are allowed and denied at the same time
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::All)
|
||||
.with_deny_funcs(Targets::<FuncTarget>::All);
|
||||
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
|
||||
}
|
||||
|
||||
// When some funcs are allowed and some are denied, deny overrides the allow rules
|
||||
{
|
||||
let caps = Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::*").unwrap()].into(),
|
||||
))
|
||||
.with_deny_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::post").unwrap()].into(),
|
||||
));
|
||||
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
|
||||
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::put").unwrap()));
|
||||
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,12 +19,14 @@ pub use self::options::*;
|
|||
pub use self::response::*;
|
||||
pub use self::session::*;
|
||||
|
||||
pub(crate) use self::capabilities::Capabilities;
|
||||
pub(crate) use self::executor::*;
|
||||
pub(crate) use self::iterator::*;
|
||||
pub(crate) use self::statement::*;
|
||||
pub(crate) use self::transaction::*;
|
||||
pub(crate) use self::variables::*;
|
||||
|
||||
pub mod capabilities;
|
||||
pub mod node;
|
||||
|
||||
mod processor;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use super::capabilities::Capabilities;
|
||||
use crate::cnf;
|
||||
use crate::dbs::Notification;
|
||||
use crate::err::Error;
|
||||
|
@ -46,6 +47,8 @@ pub struct Options {
|
|||
pub futures: bool,
|
||||
/// The channel over which we send notifications
|
||||
pub sender: Option<Sender<Notification>>,
|
||||
/// Datastore capabilities
|
||||
pub capabilities: Arc<Capabilities>,
|
||||
}
|
||||
|
||||
impl Default for Options {
|
||||
|
@ -74,6 +77,7 @@ impl Options {
|
|||
auth_enabled: true,
|
||||
sender: None,
|
||||
auth: Arc::new(Auth::default()),
|
||||
capabilities: Arc::new(Capabilities::default()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,6 +212,12 @@ impl Options {
|
|||
self
|
||||
}
|
||||
|
||||
/// Create a new Options object with the given Capabilities
|
||||
pub fn with_capabilities(mut self, capabilities: Arc<Capabilities>) -> Self {
|
||||
self.capabilities = capabilities;
|
||||
self
|
||||
}
|
||||
|
||||
// --------------------------------------------------
|
||||
|
||||
/// Create a new Options object for a subquery
|
||||
|
@ -215,6 +225,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
perms,
|
||||
|
@ -227,6 +238,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
force,
|
||||
|
@ -239,6 +251,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
strict,
|
||||
|
@ -251,6 +264,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
fields,
|
||||
|
@ -263,6 +277,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
events,
|
||||
|
@ -275,6 +290,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
tables,
|
||||
|
@ -287,6 +303,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
indexes,
|
||||
|
@ -299,6 +316,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
futures,
|
||||
|
@ -311,6 +329,7 @@ impl Options {
|
|||
Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
fields: !import,
|
||||
|
@ -324,6 +343,7 @@ impl Options {
|
|||
pub fn new_with_sender(&self, sender: Sender<Notification>) -> Self {
|
||||
Self {
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
sender: Some(sender),
|
||||
|
@ -351,6 +371,7 @@ impl Options {
|
|||
Ok(Self {
|
||||
sender: self.sender.clone(),
|
||||
auth: self.auth.clone(),
|
||||
capabilities: self.capabilities.clone(),
|
||||
ns: self.ns.clone(),
|
||||
db: self.db.clone(),
|
||||
dive,
|
||||
|
|
|
@ -193,6 +193,10 @@ pub enum Error {
|
|||
message: String,
|
||||
},
|
||||
|
||||
/// The URL is invalid
|
||||
#[error("The URL `{0}` is invalid")]
|
||||
InvalidUrl(String),
|
||||
|
||||
/// The query timedout
|
||||
#[error("The query was not executed because it exceeded the timeout")]
|
||||
QueryTimedout,
|
||||
|
@ -597,6 +601,21 @@ pub enum Error {
|
|||
/// Represents an underlying IAM error
|
||||
#[error("IAM error: {0}")]
|
||||
IamError(#[from] IamError),
|
||||
|
||||
//
|
||||
// Capabilities
|
||||
//
|
||||
/// Scripting is not allowed
|
||||
#[error("Scripting functions are not allowed")]
|
||||
ScriptingNotAllowed,
|
||||
|
||||
/// Function is not allowed
|
||||
#[error("Function '{0}' is not allowed to be executed")]
|
||||
FunctionNotAllowed(String),
|
||||
|
||||
/// Network target is not allowed
|
||||
#[error("Access to network target '{0}' is not allowed")]
|
||||
NetTargetNotAllowed(String),
|
||||
}
|
||||
|
||||
impl From<Error> for String {
|
||||
|
|
|
@ -395,6 +395,7 @@ impl<'js> Request<'js> {
|
|||
let url_str = url.to_string()?;
|
||||
let url = Url::parse(&url_str)
|
||||
.map_err(|e| Exception::throw_type(&ctx, &format!("failed to parse url: {e}")))?;
|
||||
|
||||
if !url.username().is_empty() || !url.password().map(str::is_empty).unwrap_or(true) {
|
||||
// url cannot contain non empty username and passwords
|
||||
return Err(Exception::throw_type(&ctx, "Url contained credentials."));
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
//! Contains the actual fetch function.
|
||||
|
||||
use crate::fnc::script::fetch::{
|
||||
body::{Body, BodyData, BodyKind},
|
||||
classes::{self, Request, RequestInit, Response, ResponseInit, ResponseType},
|
||||
RequestError,
|
||||
use crate::fnc::script::{
|
||||
fetch::{
|
||||
body::{Body, BodyData, BodyKind},
|
||||
classes::{self, Request, RequestInit, Response, ResponseInit, ResponseType},
|
||||
RequestError,
|
||||
},
|
||||
modules::surrealdb::query::{QueryContext, QUERY_DATA_PROP_NAME},
|
||||
};
|
||||
use futures::TryStreamExt;
|
||||
use js::{function::Opt, Class, Ctx, Exception, Result, Value};
|
||||
use js::{class::OwnedBorrow, function::Opt, Class, Ctx, Exception, Result, Value};
|
||||
use reqwest::{
|
||||
header::{HeaderValue, CONTENT_TYPE},
|
||||
redirect, Body as ReqBody,
|
||||
|
@ -27,6 +30,19 @@ pub async fn fetch<'js>(
|
|||
|
||||
let url = js_req.url;
|
||||
|
||||
// Check if the url is allowed to be fetched.
|
||||
if ctx.globals().contains_key(QUERY_DATA_PROP_NAME)? {
|
||||
let query_ctx =
|
||||
ctx.globals().get::<_, OwnedBorrow<'js, QueryContext<'js>>>(QUERY_DATA_PROP_NAME)?;
|
||||
query_ctx
|
||||
.context
|
||||
.check_allowed_net(&url)
|
||||
.map_err(|e| Exception::throw_message(&ctx, &e.to_string()))?;
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
panic!("Trying to fetch a URL but no QueryContext is present. QueryContext is required for checking if the URL is allowed to be fetched.")
|
||||
}
|
||||
|
||||
let req = reqwest::Request::new(js_req.init.method, url.clone());
|
||||
|
||||
// SurrealDB Implementation keeps all javascript parts inside the context::with scope so this
|
||||
|
@ -118,146 +134,3 @@ pub async fn fetch<'js>(
|
|||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::fnc::script::fetch::test::create_test_context;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_get() {
|
||||
use js::{promise::Promise, CatchResultExt};
|
||||
use wiremock::{
|
||||
matchers::{header, method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let server_ref = &server;
|
||||
|
||||
create_test_context!(ctx => {
|
||||
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
|
||||
|
||||
ctx.eval::<Promise<()>,_>(r#"
|
||||
(async () => {
|
||||
let res = await fetch(SERVER_URL + '/hello',{
|
||||
headers: {
|
||||
"some-header": "some-value",
|
||||
}
|
||||
});
|
||||
assert.seq(res.status,200);
|
||||
let body = await res.text();
|
||||
assert.seq(body,'some body once told me');
|
||||
})()
|
||||
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
|
||||
})
|
||||
.await;
|
||||
|
||||
server.verify().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_put() {
|
||||
use js::{promise::Promise, CatchResultExt};
|
||||
use wiremock::{
|
||||
matchers::{body_string, header, method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("PUT"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.and(body_string("some text"))
|
||||
.respond_with(ResponseTemplate::new(201).set_body_string("some body once told me"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let server_ref = &server;
|
||||
|
||||
create_test_context!(ctx => {
|
||||
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
|
||||
|
||||
ctx.eval::<Promise<()>,_>(r#"
|
||||
(async () => {
|
||||
let res = await fetch(SERVER_URL + '/hello',{
|
||||
method: "PuT",
|
||||
headers: {
|
||||
"some-header": "some-value",
|
||||
},
|
||||
body: "some text",
|
||||
});
|
||||
assert.seq(res.status,201);
|
||||
assert(res.ok);
|
||||
let body = await res.text();
|
||||
assert.seq(body,'some body once told me');
|
||||
})()
|
||||
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
|
||||
})
|
||||
.await;
|
||||
|
||||
server.verify().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_error() {
|
||||
use js::{promise::Promise, CatchResultExt};
|
||||
use wiremock::{
|
||||
matchers::{body_string, header, method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("PROPPATCH"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.and(body_string("some text"))
|
||||
.respond_with(ResponseTemplate::new(500).set_body_json(serde_json::json!({
|
||||
"foo": "bar",
|
||||
"baz": 2,
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let server_ref = &server;
|
||||
|
||||
create_test_context!(ctx => {
|
||||
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
|
||||
|
||||
ctx.eval::<Promise<()>,_>(r#"
|
||||
(async () => {
|
||||
let req = new Request(SERVER_URL + '/hello',{
|
||||
method: "PROPPATCH",
|
||||
headers: {
|
||||
"some-header": "some-value",
|
||||
},
|
||||
body: "some text",
|
||||
})
|
||||
let res = await fetch(req);
|
||||
assert.seq(res.status,500);
|
||||
assert(!res.ok);
|
||||
let body = await res.json();
|
||||
assert(body.foo !== undefined, "body.foo not defined");
|
||||
assert(body.baz !== undefined, "body.foo not defined");
|
||||
assert.seq(body.foo, "bar");
|
||||
assert.seq(body.baz, 2);
|
||||
})()
|
||||
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
|
||||
})
|
||||
.await;
|
||||
|
||||
server.verify().await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,3 +16,6 @@ mod fetch;
|
|||
mod fetch_stub;
|
||||
#[cfg(not(feature = "http"))]
|
||||
use self::fetch_stub as fetch;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
|
203
lib/src/fnc/script/tests/fetch.rs
Normal file
203
lib/src/fnc/script/tests/fetch.rs
Normal file
|
@ -0,0 +1,203 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use wiremock::{
|
||||
matchers::{body_string, header, method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
dbs::{
|
||||
capabilities::{NetTarget, Targets},
|
||||
Capabilities, Session,
|
||||
},
|
||||
kvs::Datastore,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_get() {
|
||||
// Prepare mock server
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Execute test
|
||||
let ds = Datastore::new("memory").await.unwrap();
|
||||
let sess = Session::owner();
|
||||
let sql = format!(
|
||||
r#"
|
||||
RETURN function() {{
|
||||
let res = await fetch('{}/hello',{{
|
||||
headers: {{
|
||||
"some-header": "some-value",
|
||||
}}
|
||||
}});
|
||||
let body = await res.text();
|
||||
|
||||
return {{ status: res.status, body: body }};
|
||||
}}
|
||||
"#,
|
||||
server.uri()
|
||||
);
|
||||
let res = ds.execute(&sql, &sess, None).await;
|
||||
|
||||
let res = res.unwrap().remove(0).output().unwrap();
|
||||
|
||||
server.verify().await;
|
||||
|
||||
assert_eq!(
|
||||
res.to_string(),
|
||||
"{ body: 'some body once told me', status: 200f }",
|
||||
"Unexpected result: {:?}",
|
||||
res
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_put() {
|
||||
// Prepare mock server
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("PUT"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.and(body_string("some text"))
|
||||
.respond_with(ResponseTemplate::new(201).set_body_string("some body once told me"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Execute test
|
||||
let ds = Datastore::new("memory").await.unwrap();
|
||||
let sess = Session::owner();
|
||||
let sql = format!(
|
||||
r#"
|
||||
RETURN function() {{
|
||||
let res = await fetch('{}/hello',{{
|
||||
method: "PuT",
|
||||
headers: {{
|
||||
"some-header": "some-value",
|
||||
}},
|
||||
body: "some text",
|
||||
}});
|
||||
let body = await res.text();
|
||||
|
||||
return {{ status: res.status, body: body }};
|
||||
}}
|
||||
"#,
|
||||
server.uri()
|
||||
);
|
||||
let res = ds.execute(&sql, &sess, None).await;
|
||||
|
||||
let res = res.unwrap().remove(0).output().unwrap();
|
||||
|
||||
server.verify().await;
|
||||
|
||||
assert_eq!(
|
||||
res.to_string(),
|
||||
"{ body: 'some body once told me', status: 201f }",
|
||||
"Unexpected result: {:?}",
|
||||
res
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_error() {
|
||||
// Prepare mock server
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("PROPPATCH"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.and(body_string("some text"))
|
||||
.respond_with(ResponseTemplate::new(500).set_body_json(serde_json::json!({
|
||||
"foo": "bar",
|
||||
"baz": 2,
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Execute test
|
||||
let ds = Datastore::new("memory").await.unwrap();
|
||||
let sess = Session::owner();
|
||||
let sql = format!(
|
||||
r#"
|
||||
RETURN function() {{
|
||||
let res = await fetch('{}/hello',{{
|
||||
method: "PROPPATCH",
|
||||
headers: {{
|
||||
"some-header": "some-value",
|
||||
}},
|
||||
body: "some text",
|
||||
}});
|
||||
let body = await res.text();
|
||||
|
||||
return {{ status: res.status, body: body }};
|
||||
}}
|
||||
"#,
|
||||
server.uri()
|
||||
);
|
||||
let res = ds.execute(&sql, &sess, None).await;
|
||||
|
||||
let res = res.unwrap().remove(0).output().unwrap();
|
||||
|
||||
server.verify().await;
|
||||
|
||||
assert_eq!(
|
||||
res.to_string(),
|
||||
"{ body: '{\"foo\":\"bar\",\"baz\":2}', status: 500f }",
|
||||
"Unexpected result: {:?}",
|
||||
res
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_denied() {
|
||||
// Prepare mock server
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/hello"))
|
||||
.and(header("some-header", "some-value"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
|
||||
.expect(0)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Execute test
|
||||
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
|
||||
Capabilities::default().with_deny_net(Targets::Some(
|
||||
[NetTarget::from_str(&server.address().to_string()).unwrap()].into(),
|
||||
)),
|
||||
);
|
||||
let sess = Session::owner();
|
||||
let sql = format!(
|
||||
r#"
|
||||
RETURN function() {{
|
||||
let res = await fetch('{}/hello',{{
|
||||
headers: {{
|
||||
"some-header": "some-value",
|
||||
}}
|
||||
}});
|
||||
let body = await res.text();
|
||||
|
||||
return {{ status: res.status, body: body }};
|
||||
}}
|
||||
"#,
|
||||
server.uri()
|
||||
);
|
||||
let res = ds.execute(&sql, &sess, None).await;
|
||||
|
||||
let res = res.unwrap().remove(0).output().unwrap_err();
|
||||
|
||||
server.verify().await;
|
||||
|
||||
assert!(
|
||||
res.to_string()
|
||||
.contains(&format!("Access to network target '{}/hello' is not allowed", server.uri())),
|
||||
"Unexpected result: {:?}",
|
||||
res
|
||||
);
|
||||
}
|
1
lib/src/fnc/script/tests/mod.rs
Normal file
1
lib/src/fnc/script/tests/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
mod fetch;
|
|
@ -6,6 +6,7 @@ use crate::sql::value::Value;
|
|||
use crate::sql::{json, Bytes};
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use reqwest::{Client, RequestBuilder, Response};
|
||||
use url::Url;
|
||||
|
||||
pub(crate) fn uri_is_valid(uri: &str) -> bool {
|
||||
reqwest::Url::parse(uri).is_ok()
|
||||
|
@ -46,10 +47,13 @@ async fn decode_response(res: Response) -> Result<Value, Error> {
|
|||
}
|
||||
|
||||
pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new HEAD request
|
||||
let mut req = cli.head(uri.as_str());
|
||||
let mut req = cli.head(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
@ -72,10 +76,13 @@ pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Re
|
|||
}
|
||||
|
||||
pub async fn get(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new GET request
|
||||
let mut req = cli.get(uri.as_str());
|
||||
let mut req = cli.get(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
@ -100,10 +107,13 @@ pub async fn put(
|
|||
body: Value,
|
||||
opts: impl Into<Object>,
|
||||
) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new GET request
|
||||
let mut req = cli.put(uri.as_str());
|
||||
let mut req = cli.put(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
@ -130,10 +140,13 @@ pub async fn post(
|
|||
body: Value,
|
||||
opts: impl Into<Object>,
|
||||
) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new GET request
|
||||
let mut req = cli.post(uri.as_str());
|
||||
let mut req = cli.post(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
@ -160,10 +173,13 @@ pub async fn patch(
|
|||
body: Value,
|
||||
opts: impl Into<Object>,
|
||||
) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new GET request
|
||||
let mut req = cli.patch(uri.as_str());
|
||||
let mut req = cli.patch(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
@ -189,10 +205,13 @@ pub async fn delete(
|
|||
uri: Strand,
|
||||
opts: impl Into<Object>,
|
||||
) -> Result<Value, Error> {
|
||||
// Check if the URI is valid and allowed
|
||||
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
|
||||
ctx.check_allowed_net(&url)?;
|
||||
// Set a default client with no timeout
|
||||
let cli = Client::builder().build()?;
|
||||
// Start a new GET request
|
||||
let mut req = cli.delete(uri.as_str());
|
||||
let mut req = cli.delete(url);
|
||||
// Add the User-Agent header
|
||||
if cfg!(not(target_arch = "wasm32")) {
|
||||
req = req.header("User-Agent", "SurrealDB");
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::cf;
|
|||
use crate::ctx::Context;
|
||||
use crate::dbs::node::Timestamp;
|
||||
use crate::dbs::Attach;
|
||||
use crate::dbs::Capabilities;
|
||||
use crate::dbs::Executor;
|
||||
use crate::dbs::Notification;
|
||||
use crate::dbs::Options;
|
||||
|
@ -62,6 +63,8 @@ pub struct Datastore {
|
|||
notification_channel: Option<(Sender<Notification>, Receiver<Notification>)>,
|
||||
// Whether this datastore authentication is enabled. When disabled, anonymous actors have owner-level access.
|
||||
auth_enabled: bool,
|
||||
// Capabilities for this datastore
|
||||
capabilities: Capabilities,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
|
@ -260,6 +263,7 @@ impl Datastore {
|
|||
vso: Arc::new(Mutex::new(vs::Oracle::systime_counter())),
|
||||
notification_channel: None,
|
||||
auth_enabled: false,
|
||||
capabilities: Capabilities::default(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -326,6 +330,12 @@ impl Datastore {
|
|||
self
|
||||
}
|
||||
|
||||
/// Configure Datastore capabilities
|
||||
pub fn with_capabilities(mut self, caps: Capabilities) -> Self {
|
||||
self.capabilities = caps;
|
||||
self
|
||||
}
|
||||
|
||||
/// Creates a new datastore instance
|
||||
///
|
||||
/// Use this for clustered environments.
|
||||
|
@ -752,6 +762,7 @@ impl Datastore {
|
|||
let mut exe = Executor::new(self);
|
||||
// Create a default context
|
||||
let mut ctx = Context::default();
|
||||
ctx.add_capabilities(self.capabilities.clone());
|
||||
// Set the global query timeout
|
||||
if let Some(timeout) = self.query_timeout {
|
||||
ctx.add_timeout(timeout);
|
||||
|
@ -808,6 +819,8 @@ impl Datastore {
|
|||
let txn = Arc::new(Mutex::new(txn));
|
||||
// Create a default context
|
||||
let mut ctx = Context::default();
|
||||
// Set context capabilities
|
||||
ctx.add_capabilities(self.capabilities.clone());
|
||||
// Set the global query timeout
|
||||
if let Some(timeout) = self.query_timeout {
|
||||
ctx.add_timeout(timeout);
|
||||
|
|
|
@ -160,12 +160,16 @@ impl Function {
|
|||
// Process the function type
|
||||
match self {
|
||||
Self::Normal(s, x) => {
|
||||
// Check this function is allowed
|
||||
ctx.check_allowed_function(s)?;
|
||||
// Compute the function arguments
|
||||
let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
|
||||
// Run the normal function
|
||||
fnc::run(ctx, txn, doc, s, a).await
|
||||
}
|
||||
Self::Custom(s, x) => {
|
||||
// Check this function is allowed
|
||||
ctx.check_allowed_function(format!("fn::{s}").as_str())?;
|
||||
// Get the function definition
|
||||
let val = {
|
||||
// Claim transaction
|
||||
|
@ -198,6 +202,8 @@ impl Function {
|
|||
Self::Script(s, x) => {
|
||||
#[cfg(feature = "scripting")]
|
||||
{
|
||||
// Check if scripting is allowed
|
||||
ctx.check_allowed_scripting()?;
|
||||
// Compute the function arguments
|
||||
let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
|
||||
// Run the script function
|
||||
|
|
|
@ -10,7 +10,6 @@ use crate::err::Error;
|
|||
use crate::net::{self, client_ip::ClientIp};
|
||||
use crate::node;
|
||||
use clap::Args;
|
||||
use ipnet::IpNet;
|
||||
use opentelemetry::Context as TelemetryContext;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
|
@ -24,8 +23,35 @@ pub struct StartCommandArguments {
|
|||
#[arg(default_value = "memory")]
|
||||
#[arg(value_parser = super::validator::path_valid)]
|
||||
path: String,
|
||||
#[arg(help = "The logging level for the database server")]
|
||||
#[arg(env = "SURREAL_LOG", short = 'l', long = "log")]
|
||||
#[arg(default_value = "info")]
|
||||
#[arg(value_parser = CustomEnvFilterParser::new())]
|
||||
log: CustomEnvFilter,
|
||||
#[arg(help = "Whether to hide the startup banner")]
|
||||
#[arg(env = "SURREAL_NO_BANNER", long)]
|
||||
#[arg(default_value_t = false)]
|
||||
no_banner: bool,
|
||||
#[arg(help = "Encryption key to use for on-disk encryption")]
|
||||
#[arg(env = "SURREAL_KEY", short = 'k', long = "key")]
|
||||
#[arg(value_parser = super::validator::key_valid)]
|
||||
#[arg(hide = true)] // Not currently in use
|
||||
key: Option<String>,
|
||||
|
||||
#[arg(
|
||||
help = "The username for the initial database root user. Only if no other root user exists"
|
||||
help = "The interval at which to run node agent tick (including garbage collection)",
|
||||
help_heading = "Database"
|
||||
)]
|
||||
#[arg(env = "SURREAL_TICK_INTERVAL", long = "tick-interval", value_parser = super::validator::duration)]
|
||||
#[arg(default_value = "10s")]
|
||||
tick_interval: Duration,
|
||||
|
||||
//
|
||||
// Authentication
|
||||
//
|
||||
#[arg(
|
||||
help = "The username for the initial database root user. Only if no other root user exists",
|
||||
help_heading = "Authentication"
|
||||
)]
|
||||
#[arg(
|
||||
env = "SURREAL_USER",
|
||||
|
@ -36,7 +62,8 @@ pub struct StartCommandArguments {
|
|||
)]
|
||||
username: Option<String>,
|
||||
#[arg(
|
||||
help = "The password for the initial database root user. Only if no other root user exists"
|
||||
help = "The password for the initial database root user. Only if no other root user exists",
|
||||
help_heading = "Authentication"
|
||||
)]
|
||||
#[arg(
|
||||
env = "SURREAL_PASS",
|
||||
|
@ -46,10 +73,20 @@ pub struct StartCommandArguments {
|
|||
requires = "username"
|
||||
)]
|
||||
password: Option<String>,
|
||||
#[arg(help = "The allowed networks for master authentication")]
|
||||
#[arg(env = "SURREAL_ADDR", long = "addr")]
|
||||
#[arg(default_value = "127.0.0.1/32")]
|
||||
allowed_networks: Vec<IpNet>,
|
||||
|
||||
//
|
||||
// Datastore connection
|
||||
//
|
||||
#[command(next_help_heading = "Datastore connection")]
|
||||
#[command(flatten)]
|
||||
kvs: Option<StartCommandRemoteTlsOptions>,
|
||||
|
||||
//
|
||||
// HTTP Server
|
||||
//
|
||||
#[command(next_help_heading = "HTTP server")]
|
||||
#[command(flatten)]
|
||||
web: Option<StartCommandWebTlsOptions>,
|
||||
#[arg(help = "The method of detecting the client's IP address")]
|
||||
#[arg(env = "SURREAL_CLIENT_IP", long)]
|
||||
#[arg(default_value = "socket", value_enum)]
|
||||
|
@ -58,29 +95,13 @@ pub struct StartCommandArguments {
|
|||
#[arg(env = "SURREAL_BIND", short = 'b', long = "bind")]
|
||||
#[arg(default_value = "0.0.0.0:8000")]
|
||||
listen_addresses: Vec<SocketAddr>,
|
||||
#[arg(help = "The interval at which to run node agent tick (including garbage collection)")]
|
||||
#[arg(env = "SURREAL_TICK_INTERVAL", long = "tick-interval", value_parser = super::validator::duration)]
|
||||
#[arg(default_value = "10s")]
|
||||
tick_interval: Duration,
|
||||
|
||||
//
|
||||
// Database options
|
||||
//
|
||||
#[command(flatten)]
|
||||
#[command(next_help_heading = "Database")]
|
||||
dbs: StartCommandDbsOptions,
|
||||
#[arg(help = "Encryption key to use for on-disk encryption")]
|
||||
#[arg(env = "SURREAL_KEY", short = 'k', long = "key")]
|
||||
#[arg(value_parser = super::validator::key_valid)]
|
||||
key: Option<String>,
|
||||
#[command(flatten)]
|
||||
kvs: Option<StartCommandRemoteTlsOptions>,
|
||||
#[command(flatten)]
|
||||
web: Option<StartCommandWebTlsOptions>,
|
||||
#[arg(help = "The logging level for the database server")]
|
||||
#[arg(env = "SURREAL_LOG", short = 'l', long = "log")]
|
||||
#[arg(default_value = "info")]
|
||||
#[arg(value_parser = CustomEnvFilterParser::new())]
|
||||
log: CustomEnvFilter,
|
||||
#[arg(help = "Whether to hide the startup banner")]
|
||||
#[arg(env = "SURREAL_NO_BANNER", long)]
|
||||
#[arg(default_value_t = false)]
|
||||
no_banner: bool,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::collections::HashSet;
|
||||
#[cfg(feature = "has-storage")]
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
|
@ -5,6 +6,8 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use surrealdb::dbs::capabilities::{FuncTarget, NetTarget, Targets};
|
||||
|
||||
pub(crate) mod parser;
|
||||
|
||||
#[cfg(feature = "has-storage")]
|
||||
|
@ -76,3 +79,77 @@ pub(crate) fn key_valid(v: &str) -> Result<String, String> {
|
|||
pub(crate) fn duration(v: &str) -> Result<Duration, String> {
|
||||
surrealdb::sql::Duration::from_str(v).map(|d| d.0).map_err(|_| String::from("invalid duration"))
|
||||
}
|
||||
|
||||
pub(crate) fn net_targets(value: &str) -> Result<Targets<NetTarget>, String> {
|
||||
if ["*", ""].contains(&value) {
|
||||
return Ok(Targets::All);
|
||||
}
|
||||
|
||||
let mut result = HashSet::new();
|
||||
|
||||
for target in value.split(',').filter(|s| !s.is_empty()) {
|
||||
result.insert(NetTarget::from_str(target)?);
|
||||
}
|
||||
|
||||
Ok(Targets::Some(result))
|
||||
}
|
||||
|
||||
pub(crate) fn func_targets(value: &str) -> Result<Targets<FuncTarget>, String> {
|
||||
if ["*", ""].contains(&value) {
|
||||
return Ok(Targets::All);
|
||||
}
|
||||
|
||||
let mut result = HashSet::new();
|
||||
|
||||
for target in value.split(',').filter(|s| !s.is_empty()) {
|
||||
result.insert(FuncTarget::from_str(target)?);
|
||||
}
|
||||
|
||||
Ok(Targets::Some(result))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_func_targets() {
|
||||
assert_eq!(func_targets("*").unwrap(), Targets::<FuncTarget>::All);
|
||||
assert_eq!(func_targets("").unwrap(), Targets::<FuncTarget>::All);
|
||||
assert_eq!(
|
||||
func_targets("foo").unwrap(),
|
||||
Targets::<FuncTarget>::Some(vec!["foo".parse().unwrap()].into_iter().collect())
|
||||
);
|
||||
assert_eq!(
|
||||
func_targets("foo,bar").unwrap(),
|
||||
Targets::<FuncTarget>::Some(
|
||||
vec!["foo".parse().unwrap(), "bar".parse().unwrap()].into_iter().collect()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_net_targets() {
|
||||
assert_eq!(net_targets("*").unwrap(), Targets::<NetTarget>::All);
|
||||
assert_eq!(net_targets("").unwrap(), Targets::<NetTarget>::All);
|
||||
assert_eq!(
|
||||
net_targets("example.com").unwrap(),
|
||||
Targets::<NetTarget>::Some(vec!["example.com".parse().unwrap()].into_iter().collect())
|
||||
);
|
||||
assert_eq!(
|
||||
net_targets("127.0.0.1:80,[2001:db8::1]:443,2001:db8::1").unwrap(),
|
||||
Targets::<NetTarget>::Some(
|
||||
vec![
|
||||
"127.0.0.1:80".parse().unwrap(),
|
||||
"[2001:db8::1]:443".parse().unwrap(),
|
||||
"2001:db8::1".parse().unwrap()
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
)
|
||||
);
|
||||
|
||||
assert!(net_targets("127777.0.0.1").is_err());
|
||||
assert!(net_targets("127.0.0.1,127777.0.0.1").is_err());
|
||||
}
|
||||
}
|
||||
|
|
360
src/dbs/mod.rs
360
src/dbs/mod.rs
|
@ -1,8 +1,9 @@
|
|||
use crate::cli::CF;
|
||||
use crate::err::Error;
|
||||
use clap::Args;
|
||||
use clap::{ArgAction, Args};
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets};
|
||||
use surrealdb::kvs::Datastore;
|
||||
use surrealdb::opt::auth::Root;
|
||||
|
||||
|
@ -22,10 +23,175 @@ pub struct StartCommandDbsOptions {
|
|||
#[arg(env = "SURREAL_TRANSACTION_TIMEOUT", long)]
|
||||
#[arg(value_parser = super::cli::validator::duration)]
|
||||
transaction_timeout: Option<Duration>,
|
||||
#[arg(help = "Whether to enable authentication")]
|
||||
#[arg(help = "Whether to enable authentication", help_heading = "Authentication")]
|
||||
#[arg(env = "SURREAL_AUTH", long = "auth")]
|
||||
#[arg(default_value_t = false)]
|
||||
auth_enabled: bool,
|
||||
#[command(flatten)]
|
||||
#[command(next_help_heading = "Capabilities")]
|
||||
caps: DbsCapabilities,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct DbsCapabilities {
|
||||
//
|
||||
// Allow
|
||||
//
|
||||
#[arg(help = "Allow all capabilities")]
|
||||
#[arg(env = "SURREAL_CAPS_ALLOW_ALL", short = 'A', long, conflicts_with = "deny_all")]
|
||||
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
|
||||
#[arg(default_value_t = false, hide_default_value = true)]
|
||||
allow_all: bool,
|
||||
|
||||
#[cfg(feature = "scripting")]
|
||||
#[arg(help = "Allow execution of scripting functions")]
|
||||
#[arg(env = "SURREAL_CAPS_ALLOW_SCRIPT", long, conflicts_with = "allow_all")]
|
||||
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
|
||||
#[arg(default_value_t = true, hide_default_value = true)]
|
||||
allow_scripting: bool,
|
||||
|
||||
#[arg(
|
||||
help = "Allow execution of functions. Optionally, you can provide a comma-separated list of function names to allow",
|
||||
long_help = r#"Allow execution of functions. Optionally, you can provide a comma-separated list of function names to allow.
|
||||
Function names must be in the form <family>[::<name>]. For example:
|
||||
- 'http' or 'http::*' -> Include all functions in the 'http' family
|
||||
- 'http::get' -> Include only the 'get' function in the 'http' family
|
||||
"#
|
||||
)]
|
||||
#[arg(env = "SURREAL_CAPS_ALLOW_FUNC", long, conflicts_with = "allow_all")]
|
||||
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
|
||||
#[arg(default_value_os = "", default_missing_value_os = "", num_args = 0..)]
|
||||
#[arg(value_parser = super::cli::validator::func_targets)]
|
||||
allow_funcs: Option<Targets<FuncTarget>>,
|
||||
|
||||
#[arg(
|
||||
help = "Allow all outbound network access. Optionally, you can provide a comma-separated list of targets to allow",
|
||||
long_help = r#"Allow all outbound network access. Optionally, you can provide a comma-separated list of targets to allow.
|
||||
Targets must be in the form of <host>[:<port>], <ipv4|ipv6>[/<mask>]. For example:
|
||||
- 'surrealdb.com', '127.0.0.1' or 'fd00::1' -> Match outbound connections to these hosts on any port
|
||||
- 'surrealdb.com:80', '127.0.0.1:80' or 'fd00::1:80' -> Match outbound connections to these hosts on port 80
|
||||
- '10.0.0.0/8' or 'fd00::/8' -> Match outbound connections to any host in these networks
|
||||
"#
|
||||
)]
|
||||
#[arg(env = "SURREAL_CAPS_ALLOW_NET", long, conflicts_with = "allow_all")]
|
||||
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
|
||||
#[arg(default_value_os = "", default_missing_value_os = "", num_args = 0..)]
|
||||
#[arg(value_parser = super::cli::validator::net_targets)]
|
||||
allow_net: Option<Targets<NetTarget>>,
|
||||
|
||||
//
|
||||
// Deny
|
||||
//
|
||||
#[arg(help = "Deny all capabilities")]
|
||||
#[arg(env = "SURREAL_CAPS_DENY_ALL", short = 'D', long, conflicts_with = "allow_all")]
|
||||
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
|
||||
#[arg(default_value_t = false, hide_default_value = true)]
|
||||
deny_all: bool,
|
||||
|
||||
#[cfg(feature = "scripting")]
|
||||
#[arg(help = "Deny execution of scripting functions")]
|
||||
#[arg(env = "SURREAL_CAPS_DENY_SCRIPT", long, conflicts_with = "deny_all")]
|
||||
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
|
||||
#[arg(default_value_t = false, hide_default_value = true)]
|
||||
deny_scripting: bool,
|
||||
|
||||
#[arg(
|
||||
help = "Deny execution of functions. Optionally, you can provide a comma-separated list of function names to deny",
|
||||
long_help = r#"Deny execution of functions. Optionally, you can provide a comma-separated list of function names to deny.
|
||||
Function names must be in the form <family>[::<name>]. For example:
|
||||
- 'http' or 'http::*' -> Include all functions in the 'http' family
|
||||
- 'http::get' -> Include only the 'get' function in the 'http' family
|
||||
"#
|
||||
)]
|
||||
#[arg(env = "SURREAL_CAPS_DENY_FUNC", long, conflicts_with = "deny_all")]
|
||||
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
|
||||
#[arg(default_missing_value_os = "", num_args = 0..)]
|
||||
#[arg(value_parser = super::cli::validator::func_targets)]
|
||||
deny_funcs: Option<Targets<FuncTarget>>,
|
||||
|
||||
#[arg(
|
||||
help = "Deny all outbound network access. Optionally, you can provide a comma-separated list of targets to deny",
|
||||
long_help = r#"Deny all outbound network access. Optionally, you can provide a comma-separated list of targets to deny.
|
||||
Targets must be in the form of <host>[:<port>], <ipv4|ipv6>[/<mask>]. For example:
|
||||
- 'surrealdb.com', '127.0.0.1' or 'fd00::1' -> Match outbound connections to these hosts on any port
|
||||
- 'surrealdb.com:80', '127.0.0.1:80' or 'fd00::1:80' -> Match outbound connections to these hosts on port 80
|
||||
- '10.0.0.0/8' or 'fd00::/8' -> Match outbound connections to any host in these networks
|
||||
"#
|
||||
)]
|
||||
#[arg(env = "SURREAL_CAPS_DENY_NET", long, conflicts_with = "deny_all")]
|
||||
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
|
||||
#[arg(default_missing_value_os = "", num_args = 0..)]
|
||||
// If deny_all is true, disable this arg and assume a default of Targets::All
|
||||
#[arg(conflicts_with = "deny_all", default_value_if("deny_all", "true", ""))]
|
||||
#[arg(value_parser = super::cli::validator::net_targets)]
|
||||
deny_net: Option<Targets<NetTarget>>,
|
||||
}
|
||||
|
||||
impl DbsCapabilities {
|
||||
#[cfg(feature = "scripting")]
|
||||
fn get_scripting(&self) -> bool {
|
||||
(self.allow_all || self.allow_scripting) && !(self.deny_all || self.deny_scripting)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "scripting"))]
|
||||
fn get_scripting(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn get_allow_funcs(&self) -> Targets<FuncTarget> {
|
||||
if self.deny_all || matches!(self.deny_funcs, Some(Targets::All)) {
|
||||
return Targets::None;
|
||||
}
|
||||
|
||||
if self.allow_all {
|
||||
return Targets::All;
|
||||
}
|
||||
|
||||
// If allow_funcs was not provided and allow_all is false, then don't allow anything (Targets::None)
|
||||
self.allow_funcs.clone().unwrap_or(Targets::None)
|
||||
}
|
||||
|
||||
fn get_allow_net(&self) -> Targets<NetTarget> {
|
||||
if self.deny_all || matches!(self.deny_net, Some(Targets::All)) {
|
||||
return Targets::None;
|
||||
}
|
||||
|
||||
if self.allow_all {
|
||||
return Targets::All;
|
||||
}
|
||||
|
||||
// If allow_net was not provided and allow_all is false, then don't allow anything (Targets::None)
|
||||
self.allow_net.clone().unwrap_or(Targets::None)
|
||||
}
|
||||
|
||||
fn get_deny_funcs(&self) -> Targets<FuncTarget> {
|
||||
if self.deny_all {
|
||||
return Targets::All;
|
||||
}
|
||||
|
||||
// If deny_funcs was not provided and deny_all is false, then don't deny anything (Targets::None)
|
||||
self.deny_funcs.clone().unwrap_or(Targets::None)
|
||||
}
|
||||
|
||||
fn get_deny_net(&self) -> Targets<NetTarget> {
|
||||
if self.deny_all {
|
||||
return Targets::All;
|
||||
}
|
||||
|
||||
// If deny_net was not provided and deny_all is false, then don't deny anything (Targets::None)
|
||||
self.deny_net.clone().unwrap_or(Targets::None)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DbsCapabilities> for Capabilities {
|
||||
fn from(caps: DbsCapabilities) -> Self {
|
||||
Capabilities::default()
|
||||
.with_scripting(caps.get_scripting())
|
||||
.with_allow_funcs(caps.get_allow_funcs())
|
||||
.with_deny_funcs(caps.get_deny_funcs())
|
||||
.with_allow_net(caps.get_allow_net())
|
||||
.with_deny_net(caps.get_deny_net())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
|
@ -34,6 +200,7 @@ pub async fn init(
|
|||
query_timeout,
|
||||
transaction_timeout,
|
||||
auth_enabled,
|
||||
caps,
|
||||
}: StartCommandDbsOptions,
|
||||
) -> Result<(), Error> {
|
||||
// Get local copy of options
|
||||
|
@ -61,7 +228,9 @@ pub async fn init(
|
|||
.with_strict_mode(strict_mode)
|
||||
.with_query_timeout(query_timeout)
|
||||
.with_transaction_timeout(transaction_timeout)
|
||||
.with_auth_enabled(auth_enabled);
|
||||
.with_auth_enabled(auth_enabled)
|
||||
.with_capabilities(caps.into());
|
||||
|
||||
dbs.bootstrap().await?;
|
||||
|
||||
if let Some(user) = opt.user.as_ref() {
|
||||
|
@ -81,13 +250,17 @@ pub async fn init(
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::iam::verify::verify_creds;
|
||||
use surrealdb::kvs::Datastore;
|
||||
use test_log::test;
|
||||
use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
async fn test_setup_superuser() {
|
||||
let ds = Datastore::new("memory").await.unwrap();
|
||||
let creds = Root {
|
||||
|
@ -133,4 +306,183 @@ mod tests {
|
|||
.hash
|
||||
)
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn test_capabilities() {
|
||||
let server1 = {
|
||||
let s = MockServer::start().await;
|
||||
let get = Mock::given(method("GET"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("SUCCESS"))
|
||||
.expect(1);
|
||||
let head =
|
||||
Mock::given(method("HEAD")).respond_with(ResponseTemplate::new(200)).expect(1);
|
||||
|
||||
s.register(get).await;
|
||||
s.register(head).await;
|
||||
|
||||
s
|
||||
};
|
||||
|
||||
let server2 = {
|
||||
let s = MockServer::start().await;
|
||||
let get = Mock::given(method("GET"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("SUCCESS"))
|
||||
.expect(1);
|
||||
let head =
|
||||
Mock::given(method("HEAD")).respond_with(ResponseTemplate::new(200)).expect(0);
|
||||
|
||||
s.register(get).await;
|
||||
s.register(head).await;
|
||||
|
||||
s
|
||||
};
|
||||
|
||||
// (Capabilities, Query, Succeeds, Response Contains)
|
||||
let cases = vec![
|
||||
//
|
||||
// Functions and Networking are allowed
|
||||
//
|
||||
(
|
||||
Capabilities::default(),
|
||||
format!("RETURN http::get('{}')", server1.uri()),
|
||||
true,
|
||||
"SUCCESS".to_string(),
|
||||
),
|
||||
//
|
||||
// Scripting is allowed
|
||||
//
|
||||
(
|
||||
Capabilities::default(),
|
||||
"RETURN function() { return '1' }".to_string(),
|
||||
true,
|
||||
"1".to_string(),
|
||||
),
|
||||
//
|
||||
// Scripting is not allowed
|
||||
//
|
||||
(
|
||||
Capabilities::default().with_scripting(false),
|
||||
"RETURN function() { return '1' }".to_string(),
|
||||
false,
|
||||
"Scripting functions are not allowed".to_string(),
|
||||
),
|
||||
//
|
||||
// Some functions are not allowed
|
||||
//
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::*").unwrap()].into(),
|
||||
))
|
||||
.with_deny_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::get").unwrap()].into(),
|
||||
)),
|
||||
format!("RETURN http::get('{}')", server1.uri()),
|
||||
false,
|
||||
"Function 'http::get' is not allowed".to_string(),
|
||||
),
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::*").unwrap()].into(),
|
||||
))
|
||||
.with_deny_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::get").unwrap()].into(),
|
||||
)),
|
||||
format!("RETURN http::head('{}')", server1.uri()),
|
||||
true,
|
||||
"NONE".to_string(),
|
||||
),
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::*").unwrap()].into(),
|
||||
))
|
||||
.with_deny_funcs(Targets::<FuncTarget>::Some(
|
||||
[FuncTarget::from_str("http::get").unwrap()].into(),
|
||||
)),
|
||||
"RETURN string::len('a')".to_string(),
|
||||
false,
|
||||
"Function 'string::len' is not allowed".to_string(),
|
||||
),
|
||||
//
|
||||
// Some net targets are not allowed
|
||||
//
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::Some(
|
||||
[
|
||||
NetTarget::from_str(&server1.address().to_string()).unwrap(),
|
||||
NetTarget::from_str(&server2.address().to_string()).unwrap(),
|
||||
]
|
||||
.into(),
|
||||
))
|
||||
.with_deny_net(Targets::<NetTarget>::Some(
|
||||
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
|
||||
)),
|
||||
format!("RETURN http::get('{}')", server1.uri()),
|
||||
false,
|
||||
format!("Access to network target '{}/' is not allowed", server1.uri()),
|
||||
),
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::Some(
|
||||
[
|
||||
NetTarget::from_str(&server1.address().to_string()).unwrap(),
|
||||
NetTarget::from_str(&server2.address().to_string()).unwrap(),
|
||||
]
|
||||
.into(),
|
||||
))
|
||||
.with_deny_net(Targets::<NetTarget>::Some(
|
||||
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
|
||||
)),
|
||||
"RETURN http::get('http://1.1.1.1')".to_string(),
|
||||
false,
|
||||
"Access to network target 'http://1.1.1.1/' is not allowed".to_string(),
|
||||
),
|
||||
(
|
||||
Capabilities::default()
|
||||
.with_allow_net(Targets::<NetTarget>::Some(
|
||||
[
|
||||
NetTarget::from_str(&server1.address().to_string()).unwrap(),
|
||||
NetTarget::from_str(&server2.address().to_string()).unwrap(),
|
||||
]
|
||||
.into(),
|
||||
))
|
||||
.with_deny_net(Targets::<NetTarget>::Some(
|
||||
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
|
||||
)),
|
||||
format!("RETURN http::get('{}')", server2.uri()),
|
||||
true,
|
||||
"SUCCESS".to_string(),
|
||||
),
|
||||
];
|
||||
|
||||
for (idx, (caps, query, succeeds, contains)) in cases.into_iter().enumerate() {
|
||||
let ds = Datastore::new("memory").await.unwrap().with_capabilities(caps);
|
||||
|
||||
let sess = Session::owner();
|
||||
let res = ds.execute(&query, &sess, None).await;
|
||||
|
||||
let res = res.unwrap().remove(0).output();
|
||||
let res = if succeeds {
|
||||
assert!(res.is_ok(), "Unexpected error for test case {}: {:?}", idx, res);
|
||||
res.unwrap().to_string()
|
||||
} else {
|
||||
assert!(res.is_err(), "Unexpected success for test case {}: {:?}", idx, res);
|
||||
res.unwrap_err().to_string()
|
||||
};
|
||||
|
||||
assert!(
|
||||
res.contains(&contains),
|
||||
"Unexpected result for test case {}: expected to contain = `{}`, got `{}`",
|
||||
idx,
|
||||
contains,
|
||||
res
|
||||
);
|
||||
}
|
||||
|
||||
server1.verify().await;
|
||||
server2.verify().await;
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -54,16 +54,21 @@ impl Child {
|
|||
self.inner.as_mut().unwrap().try_wait()
|
||||
}
|
||||
|
||||
pub fn stdout(&self) -> String {
|
||||
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
|
||||
}
|
||||
|
||||
pub fn stderr(&self) -> String {
|
||||
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
|
||||
}
|
||||
|
||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||
/// returns successfully, Err otherwise.
|
||||
pub fn output(mut self) -> Result<String, String> {
|
||||
let status = self.inner.take().unwrap().wait().unwrap();
|
||||
|
||||
let mut buf =
|
||||
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file");
|
||||
buf.push_str(
|
||||
&std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file"),
|
||||
);
|
||||
let mut buf = self.stdout();
|
||||
buf.push_str(&self.stderr());
|
||||
|
||||
// Cleanup files after reading them
|
||||
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
|
||||
|
@ -101,8 +106,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
|
|||
}
|
||||
|
||||
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
|
||||
let stdout_path = tmp_file(format!("server-stdout-{}.log", rand::random::<u32>()).as_str());
|
||||
let stderr_path = tmp_file(format!("server-stderr-{}.log", rand::random::<u32>()).as_str());
|
||||
let stdout_path = tmp_file("server-stdout.log");
|
||||
let stderr_path = tmp_file("server-stderr.log");
|
||||
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
|
||||
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
|
||||
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
|
||||
|
@ -131,7 +136,7 @@ pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
|||
}
|
||||
|
||||
pub fn tmp_file(name: &str) -> String {
|
||||
let path = Path::new(env!("OUT_DIR")).join(name);
|
||||
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
|
@ -140,6 +145,7 @@ pub struct StartServerArguments {
|
|||
pub tls: bool,
|
||||
pub wait_is_ready: bool,
|
||||
pub tick_interval: time::Duration,
|
||||
pub args: String,
|
||||
}
|
||||
|
||||
impl Default for StartServerArguments {
|
||||
|
@ -149,6 +155,7 @@ impl Default for StartServerArguments {
|
|||
tls: false,
|
||||
wait_is_ready: true,
|
||||
tick_interval: time::Duration::new(1, 0),
|
||||
args: String::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -171,6 +178,7 @@ pub async fn start_server(
|
|||
tls,
|
||||
wait_is_ready,
|
||||
tick_interval,
|
||||
args,
|
||||
}: StartServerArguments,
|
||||
) -> Result<(String, Child), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
|
@ -178,7 +186,7 @@ pub async fn start_server(
|
|||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let mut extra_args = String::default();
|
||||
let mut extra_args = args.clone();
|
||||
if tls {
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
let crt_path = tmp_file("crt.crt");
|
||||
|
@ -212,7 +220,7 @@ pub async fn start_server(
|
|||
}
|
||||
|
||||
// Wait 5 seconds for the server to start
|
||||
let mut interval = time::interval(time::Duration::from_millis(500));
|
||||
let mut interval = time::interval(time::Duration::from_millis(1000));
|
||||
info!("Waiting for server to start...");
|
||||
for _i in 0..10 {
|
||||
interval.tick().await;
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue