[security] Introduce the Datastore capabilities (#2489)

This commit is contained in:
Salvador Girones Gil 2023-08-23 21:26:31 +02:00 committed by GitHub
parent 5945146459
commit b5b6f6f1d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 4522 additions and 3188 deletions

2
Cargo.lock generated
View file

@ -5109,6 +5109,7 @@ dependencies = [
"tracing-subscriber",
"urlencoding",
"uuid",
"wiremock",
]
[[package]]
@ -5141,6 +5142,7 @@ dependencies = [
"geo 0.25.1",
"indexmap 1.9.3",
"indxdb",
"ipnet",
"jsonwebtoken",
"lexicmp",
"lru",

View file

@ -87,6 +87,7 @@ test-log = { version = "0.2.12", features = ["trace"] }
tokio-stream = { version = "0.1", features = ["net"] }
tokio-tungstenite = { version = "0.18.0" }
tonic = "0.8.3"
wiremock = "0.5.19"
[package.metadata.deb]
maintainer-scripts = "pkg/deb/"

View file

@ -24,20 +24,20 @@ args = ["clippy", "--all-targets", "--all-features", "--", "-D", "warnings"]
[tasks.ci-cli-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG = "cli_integration=debug" }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "cli_integration", "--", "--nocapture"]
env = { RUST_LOG={ value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http,scripting", "--workspace", "--test", "cli_integration", "--", "cli_integration", "--nocapture"]
[tasks.ci-http-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG = "http_integration=debug" }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "http_integration", "--", "--nocapture"]
env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
[tasks.ci-ws-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG = "ws_integration=debug" }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "--nocapture"]
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
[tasks.ci-workspace-coverage]
category = "CI - INTEGRATION TESTS"

View file

@ -69,25 +69,25 @@ args = ["bench", "--package", "surrealdb", "--no-default-features", "--features"
[tasks.serve]
category = "LOCAL USAGE"
command = "cargo"
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start"]
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start", "${@}"]
# SQL
[tasks.sql]
category = "LOCAL USAGE"
command = "cargo"
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--conn", "ws://0.0.0.0:8000", "--multi", "--pretty"]
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--conn", "ws://0.0.0.0:8000", "--multi", "--pretty", "${@}"]
# Quick
[tasks.quick]
category = "LOCAL USAGE"
command = "cargo"
args = ["build"]
args = ["build", "${@}"]
# Build
[tasks.build]
category = "LOCAL USAGE"
command = "cargo"
args = ["build", "--release"]
args = ["build", "--release", "${@}"]
# Default
[tasks.default]

View file

@ -76,6 +76,7 @@ fuzzy-matcher = "0.3.7"
geo = { version = "0.25.1", features = ["use-serde"] }
indexmap = { version = "1.9.3", features = ["serde"] }
indxdb = { version = "0.3.0", optional = true }
ipnet = "2.8.0"
js = { version = "0.4.0-beta.3", package = "rquickjs", features = ["array-buffer", "bindgen", "classes", "futures", "loader", "macro", "parallel", "properties","rust-alloc"], optional = true }
jsonwebtoken = "8.3.0"
lexicmp = "0.1.0"

View file

@ -1,16 +1,20 @@
use crate::ctx::canceller::Canceller;
use crate::ctx::reason::Reason;
use crate::dbs::Notification;
use crate::dbs::capabilities::{FuncTarget, NetTarget};
use crate::dbs::{Capabilities, Notification};
use crate::err::Error;
use crate::idx::planner::QueryPlanner;
use crate::sql::value::Value;
use channel::Sender;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use trice::Instant;
use url::Url;
impl<'a> From<Value> for Cow<'a, Value> {
fn from(v: Value) -> Cow<'a, Value> {
@ -36,6 +40,8 @@ pub struct Context<'a> {
notifications: Option<Sender<Notification>>,
// An optional query planner
query_planner: Option<&'a QueryPlanner<'a>>,
// Capabilities
capabilities: Arc<Capabilities>,
}
impl<'a> Default for Context<'a> {
@ -65,6 +71,7 @@ impl<'a> Context<'a> {
cancelled: Arc::new(AtomicBool::new(false)),
notifications: None,
query_planner: None,
capabilities: Arc::new(Capabilities::default()),
}
}
@ -77,6 +84,7 @@ impl<'a> Context<'a> {
cancelled: Arc::new(AtomicBool::new(false)),
notifications: parent.notifications.clone(),
query_planner: parent.query_planner,
capabilities: parent.capabilities.clone(),
}
}
@ -190,4 +198,54 @@ impl<'a> Context<'a> {
.collect(),
)
}
//
// Capabilities
//
/// Set the capabilities for this context
pub fn add_capabilities(&mut self, caps: Capabilities) {
self.capabilities = Arc::new(caps);
}
/// Get the capabilities for this context
pub fn get_capabilities(&self) -> Arc<Capabilities> {
self.capabilities.clone()
}
/// Check if scripting is allowed
pub fn check_allowed_scripting(&self) -> Result<(), Error> {
if !self.capabilities.is_allowed_scripting() {
return Err(Error::ScriptingNotAllowed);
}
Ok(())
}
/// Check if a function is allowed
pub fn check_allowed_function(&self, target: &str) -> Result<(), Error> {
let func_target = FuncTarget::from_str(target).map_err(|_| Error::InvalidFunction {
name: target.to_string(),
message: "Invalid function name".to_string(),
})?;
if !self.capabilities.is_allowed_func(&func_target) {
return Err(Error::FunctionNotAllowed(target.to_string()));
}
Ok(())
}
/// Check if a network target is allowed
pub fn check_allowed_net(&self, target: &Url) -> Result<(), Error> {
match target.host() {
Some(host)
if self.capabilities.is_allowed_net(&NetTarget::Host(
host.to_owned(),
target.port_or_known_default(),
)) =>
{
Ok(())
}
_ => Err(Error::NetTargetNotAllowed(target.to_string())),
}
}
}

479
lib/src/dbs/capabilities.rs Normal file
View file

@ -0,0 +1,479 @@
use std::hash::Hash;
use std::net::IpAddr;
use std::{collections::HashSet, sync::Arc};
use ipnet::IpNet;
use url::Url;
pub trait Target {
fn matches(&self, elem: &Self) -> bool;
}
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct FuncTarget(pub String, pub Option<String>);
impl Target for FuncTarget {
fn matches(&self, elem: &Self) -> bool {
match self {
Self(family, Some(name)) => {
family == &elem.0 && (elem.1.as_ref().is_some_and(|n| n == name))
}
Self(family, None) => family == &elem.0,
}
}
}
impl std::str::FromStr for FuncTarget {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// 'family::*' is treated as 'family'. They both match all functions in the family.
let s = s.replace("::*", "");
let target = match s.split_once("::") {
Some((family, name)) => Self(family.to_string(), Some(name.to_string())),
_ => Self(s.to_string(), None),
};
Ok(target)
}
}
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum NetTarget {
Host(url::Host<String>, Option<u16>),
IPNet(ipnet::IpNet),
}
impl Target for NetTarget {
fn matches(&self, elem: &Self) -> bool {
match self {
// If self contains a host and port, the elem must match both the host and port
Self::Host(host, Some(port)) => match elem {
Self::Host(_host, Some(_port)) => host == _host && port == _port,
_ => false,
},
// If self contains a host but no port, the elem must match the host only
Self::Host(host, None) => match elem {
Self::Host(_host, _) => host == _host,
_ => false,
},
// If self is an IPNet, it can match both an IPNet or a Host elem that contains an IPAddr
Self::IPNet(ipnet) => match elem {
Self::IPNet(_ipnet) => ipnet.contains(_ipnet),
Self::Host(host, _) => match host {
url::Host::Ipv4(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
url::Host::Ipv6(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
_ => false,
},
},
}
}
}
impl std::str::FromStr for NetTarget {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// If it's a valid IPNet, return it
if let Ok(ipnet) = s.parse::<IpNet>() {
return Ok(NetTarget::IPNet(ipnet));
}
// If it's a valid IPAddr, return it as an IPNet
if let Ok(ipnet) = s.parse::<IpAddr>() {
return Ok(NetTarget::IPNet(IpNet::from(ipnet)));
}
// Parse the host and port parts from a string in the form of 'host' or 'host:port'
if let Ok(url) = Url::parse(format!("http://{s}").as_str()) {
if let Some(host) = url.host() {
// Url::parse will return port=None if the provided port was 80 (given we are using the http scheme). Get the original port from the string.
if let Some(Ok(port)) = s.split(':').last().map(|p| p.parse::<u16>()) {
return Ok(NetTarget::Host(host.to_owned(), Some(port)));
} else {
return Ok(NetTarget::Host(host.to_owned(), None));
}
}
}
Err(format!(
"The provided network target `{s}` is not a valid host, ip address or ip network"
))
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Targets<T: Target + Hash + Eq + PartialEq> {
None,
Some(HashSet<T>),
All,
}
impl<T: Target + Hash + Eq + PartialEq + std::fmt::Debug> Targets<T> {
fn matches(&self, elem: &T) -> bool {
match self {
Self::None => false,
Self::All => true,
Self::Some(targets) => targets.iter().any(|t| t.matches(elem)),
}
}
}
/// Capabilities are used to limit what a user can do to the system.
///
/// Capabilities are split into 3 categories:
/// - Scripting: Whether or not the user can execute scripts
/// - Functions: Whether or not the user can execute certain functions
/// - Network: Whether or not the user can access certain network addresses
///
/// Capabilities are configured globally. By default, capabilities are configured as:
/// - Scripting: true
/// - Functions: All functions are allowed
/// - Network: All network addresses are allowed
///
/// The capabilities are defined using allow/deny lists for fine-grained control.
///
/// Examples:
/// - Allow all functions: `--allow-funcs`
/// - Allow all functions except `http.*`: `--allow-funcs --deny-funcs 'http.*'`
/// - Allow all network addresses except AWS metadata endpoint: `--allow-net --deny-net='169.254.169.254'`
///
#[derive(Debug, Clone)]
pub struct Capabilities {
scripting: bool,
allow_funcs: Arc<Targets<FuncTarget>>,
deny_funcs: Arc<Targets<FuncTarget>>,
allow_net: Arc<Targets<NetTarget>>,
deny_net: Arc<Targets<NetTarget>>,
}
impl Default for Capabilities {
// By default, enable all capabilities
fn default() -> Self {
Self {
scripting: true,
allow_funcs: Arc::new(Targets::All),
deny_funcs: Arc::new(Targets::None),
allow_net: Arc::new(Targets::All),
deny_net: Arc::new(Targets::None),
}
}
}
impl Capabilities {
pub fn with_scripting(mut self, scripting: bool) -> Self {
self.scripting = scripting;
self
}
pub fn with_allow_funcs(mut self, allow_funcs: Targets<FuncTarget>) -> Self {
self.allow_funcs = Arc::new(allow_funcs);
self
}
pub fn with_deny_funcs(mut self, deny_funcs: Targets<FuncTarget>) -> Self {
self.deny_funcs = Arc::new(deny_funcs);
self
}
pub fn with_allow_net(mut self, allow_net: Targets<NetTarget>) -> Self {
self.allow_net = Arc::new(allow_net);
self
}
pub fn with_deny_net(mut self, deny_net: Targets<NetTarget>) -> Self {
self.deny_net = Arc::new(deny_net);
self
}
pub fn is_allowed_scripting(&self) -> bool {
self.scripting
}
pub fn is_allowed_func(&self, target: &FuncTarget) -> bool {
self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
}
pub fn is_allowed_net(&self, target: &NetTarget) -> bool {
self.allow_net.matches(target) && !self.deny_net.matches(target)
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use test_log::test;
use super::*;
#[test]
fn test_func_target() {
assert!(FuncTarget::from_str("test")
.unwrap()
.matches(&FuncTarget::from_str("test").unwrap()));
assert!(!FuncTarget::from_str("test")
.unwrap()
.matches(&FuncTarget::from_str("test2").unwrap()));
assert!(!FuncTarget::from_str("test::")
.unwrap()
.matches(&FuncTarget::from_str("test").unwrap()));
assert!(FuncTarget::from_str("test::*")
.unwrap()
.matches(&FuncTarget::from_str("test::name").unwrap()));
assert!(!FuncTarget::from_str("test::*")
.unwrap()
.matches(&FuncTarget::from_str("test2::name").unwrap()));
assert!(FuncTarget::from_str("test::name")
.unwrap()
.matches(&FuncTarget::from_str("test::name").unwrap()));
assert!(!FuncTarget::from_str("test::name")
.unwrap()
.matches(&FuncTarget::from_str("test::name2").unwrap()));
}
#[test]
fn test_net_target() {
// IPNet IPv4
assert!(NetTarget::from_str("10.0.0.0/8")
.unwrap()
.matches(&NetTarget::from_str("10.0.1.0/24").unwrap()));
assert!(NetTarget::from_str("10.0.0.0/8")
.unwrap()
.matches(&NetTarget::from_str("10.0.1.2").unwrap()));
assert!(!NetTarget::from_str("10.0.0.0/8")
.unwrap()
.matches(&NetTarget::from_str("20.0.1.0/24").unwrap()));
assert!(!NetTarget::from_str("10.0.0.0/8")
.unwrap()
.matches(&NetTarget::from_str("20.0.1.0").unwrap()));
// IPNet IPv6
assert!(NetTarget::from_str("2001:db8::1")
.unwrap()
.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
assert!(NetTarget::from_str("2001:db8::/32")
.unwrap()
.matches(&NetTarget::from_str("2001:db8::1").unwrap()));
assert!(NetTarget::from_str("2001:db8::/32")
.unwrap()
.matches(&NetTarget::from_str("2001:db8:abcd:12::/64").unwrap()));
assert!(!NetTarget::from_str("2001:db8::/32")
.unwrap()
.matches(&NetTarget::from_str("2001:db9::1").unwrap()));
assert!(!NetTarget::from_str("2001:db8::/32")
.unwrap()
.matches(&NetTarget::from_str("2001:db9:abcd:12::1/64").unwrap()));
// Host domain with and without port
assert!(NetTarget::from_str("example.com")
.unwrap()
.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(NetTarget::from_str("example.com")
.unwrap()
.matches(&NetTarget::from_str("example.com:80").unwrap()));
assert!(!NetTarget::from_str("example.com")
.unwrap()
.matches(&NetTarget::from_str("www.example.com").unwrap()));
assert!(!NetTarget::from_str("example.com")
.unwrap()
.matches(&NetTarget::from_str("www.example.com:80").unwrap()));
assert!(NetTarget::from_str("example.com:80")
.unwrap()
.matches(&NetTarget::from_str("example.com:80").unwrap()));
assert!(!NetTarget::from_str("example.com:80")
.unwrap()
.matches(&NetTarget::from_str("example.com:443").unwrap()));
assert!(!NetTarget::from_str("example.com:80")
.unwrap()
.matches(&NetTarget::from_str("example.com").unwrap()));
// Host IPv4 with and without port
assert!(
NetTarget::from_str("127.0.0.1")
.unwrap()
.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
"Host IPv4 without port matches itself"
);
assert!(
NetTarget::from_str("127.0.0.1")
.unwrap()
.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
"Host IPv4 without port matches Host IPv4 with port"
);
assert!(
NetTarget::from_str("10.0.0.0/8")
.unwrap()
.matches(&NetTarget::from_str("10.0.0.1:80").unwrap()),
"IPv4 network matches Host IPv4 with port"
);
assert!(
NetTarget::from_str("127.0.0.1:80")
.unwrap()
.matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
"Host IPv4 with port matches itself"
);
assert!(
!NetTarget::from_str("127.0.0.1:80")
.unwrap()
.matches(&NetTarget::from_str("127.0.0.1").unwrap()),
"Host IPv4 with port does not match Host IPv4 without port"
);
assert!(
!NetTarget::from_str("127.0.0.1:80")
.unwrap()
.matches(&NetTarget::from_str("127.0.0.1:443").unwrap()),
"Host IPv4 with port does not match Host IPv4 with different port"
);
// Host IPv6 with and without port
assert!(
NetTarget::from_str("[2001:db8::1]")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
"Host IPv6 without port matches itself"
);
assert!(
NetTarget::from_str("[2001:db8::1]")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
"Host IPv6 without port matches Host IPv6 with port"
);
assert!(
NetTarget::from_str("2001:db8::1")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
"IPv6 addr matches Host IPv6 with port"
);
assert!(
NetTarget::from_str("2001:db8::/64")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
"IPv6 network matches Host IPv6 with port"
);
assert!(
NetTarget::from_str("[2001:db8::1]:80")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
"Host IPv6 with port matches itself"
);
assert!(
!NetTarget::from_str("[2001:db8::1]:80")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
"Host IPv6 with port does not match Host IPv6 without port"
);
assert!(
!NetTarget::from_str("[2001:db8::1]:80")
.unwrap()
.matches(&NetTarget::from_str("[2001:db8::1]:443").unwrap()),
"Host IPv6 with port does not match Host IPv6 with different port"
);
// Test invalid targets
assert!(NetTarget::from_str("exam^ple.com").is_err());
assert!(NetTarget::from_str("example.com:80:80").is_err());
assert!(NetTarget::from_str("11111.3.4.5").is_err());
assert!(NetTarget::from_str("2001:db8::1/129").is_err());
assert!(NetTarget::from_str("[2001:db8::1").is_err());
}
#[test]
fn test_targets() {
assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(Targets::<FuncTarget>::All.matches(&FuncTarget::from_str("http::get").unwrap()));
assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(!Targets::<FuncTarget>::None.matches(&FuncTarget::from_str("http::get").unwrap()));
assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
.matches(&NetTarget::from_str("www.example.com").unwrap()));
assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
.matches(&FuncTarget::from_str("http::get").unwrap()));
assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
.matches(&FuncTarget::from_str("http::post").unwrap()));
}
#[test]
fn test_capabilities() {
// When scripting is allowed
{
let caps = Capabilities::default().with_scripting(true);
assert!(caps.is_allowed_scripting());
}
// When scripting is denied
{
let caps = Capabilities::default().with_scripting(false);
assert!(!caps.is_allowed_scripting());
}
// When all nets are allowed
{
let caps = Capabilities::default()
.with_allow_net(Targets::<NetTarget>::All)
.with_deny_net(Targets::<NetTarget>::None);
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
}
// When all nets are allowed and denied at the same time
{
let caps = Capabilities::default()
.with_allow_net(Targets::<NetTarget>::All)
.with_deny_net(Targets::<NetTarget>::All);
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
}
// When some nets are allowed and some are denied, deny overrides the allow rules
{
let caps = Capabilities::default()
.with_allow_net(Targets::<NetTarget>::Some(
[NetTarget::from_str("example.com").unwrap()].into(),
))
.with_deny_net(Targets::<NetTarget>::Some(
[NetTarget::from_str("example.com:80").unwrap()].into(),
));
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com").unwrap()));
assert!(caps.is_allowed_net(&NetTarget::from_str("example.com:443").unwrap()));
assert!(!caps.is_allowed_net(&NetTarget::from_str("example.com:80").unwrap()));
}
// When all funcs are allowed
{
let caps = Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::All)
.with_deny_funcs(Targets::<FuncTarget>::None);
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
}
// When all funcs are allowed and denied at the same time
{
let caps = Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::All)
.with_deny_funcs(Targets::<FuncTarget>::All);
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
}
// When some funcs are allowed and some are denied, deny overrides the allow rules
{
let caps = Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::*").unwrap()].into(),
))
.with_deny_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::post").unwrap()].into(),
));
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::get").unwrap()));
assert!(caps.is_allowed_func(&FuncTarget::from_str("http::put").unwrap()));
assert!(!caps.is_allowed_func(&FuncTarget::from_str("http::post").unwrap()));
}
}
}

View file

@ -19,12 +19,14 @@ pub use self::options::*;
pub use self::response::*;
pub use self::session::*;
pub(crate) use self::capabilities::Capabilities;
pub(crate) use self::executor::*;
pub(crate) use self::iterator::*;
pub(crate) use self::statement::*;
pub(crate) use self::transaction::*;
pub(crate) use self::variables::*;
pub mod capabilities;
pub mod node;
mod processor;

View file

@ -1,3 +1,4 @@
use super::capabilities::Capabilities;
use crate::cnf;
use crate::dbs::Notification;
use crate::err::Error;
@ -46,6 +47,8 @@ pub struct Options {
pub futures: bool,
/// The channel over which we send notifications
pub sender: Option<Sender<Notification>>,
/// Datastore capabilities
pub capabilities: Arc<Capabilities>,
}
impl Default for Options {
@ -74,6 +77,7 @@ impl Options {
auth_enabled: true,
sender: None,
auth: Arc::new(Auth::default()),
capabilities: Arc::new(Capabilities::default()),
}
}
@ -208,6 +212,12 @@ impl Options {
self
}
/// Create a new Options object with the given Capabilities
pub fn with_capabilities(mut self, capabilities: Arc<Capabilities>) -> Self {
self.capabilities = capabilities;
self
}
// --------------------------------------------------
/// Create a new Options object for a subquery
@ -215,6 +225,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
perms,
@ -227,6 +238,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
force,
@ -239,6 +251,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
strict,
@ -251,6 +264,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
fields,
@ -263,6 +277,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
events,
@ -275,6 +290,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
tables,
@ -287,6 +303,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
indexes,
@ -299,6 +316,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
futures,
@ -311,6 +329,7 @@ impl Options {
Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
fields: !import,
@ -324,6 +343,7 @@ impl Options {
pub fn new_with_sender(&self, sender: Sender<Notification>) -> Self {
Self {
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
sender: Some(sender),
@ -351,6 +371,7 @@ impl Options {
Ok(Self {
sender: self.sender.clone(),
auth: self.auth.clone(),
capabilities: self.capabilities.clone(),
ns: self.ns.clone(),
db: self.db.clone(),
dive,

View file

@ -193,6 +193,10 @@ pub enum Error {
message: String,
},
/// The URL is invalid
#[error("The URL `{0}` is invalid")]
InvalidUrl(String),
/// The query timedout
#[error("The query was not executed because it exceeded the timeout")]
QueryTimedout,
@ -597,6 +601,21 @@ pub enum Error {
/// Represents an underlying IAM error
#[error("IAM error: {0}")]
IamError(#[from] IamError),
//
// Capabilities
//
/// Scripting is not allowed
#[error("Scripting functions are not allowed")]
ScriptingNotAllowed,
/// Function is not allowed
#[error("Function '{0}' is not allowed to be executed")]
FunctionNotAllowed(String),
/// Network target is not allowed
#[error("Access to network target '{0}' is not allowed")]
NetTargetNotAllowed(String),
}
impl From<Error> for String {

View file

@ -395,6 +395,7 @@ impl<'js> Request<'js> {
let url_str = url.to_string()?;
let url = Url::parse(&url_str)
.map_err(|e| Exception::throw_type(&ctx, &format!("failed to parse url: {e}")))?;
if !url.username().is_empty() || !url.password().map(str::is_empty).unwrap_or(true) {
// url cannot contain non empty username and passwords
return Err(Exception::throw_type(&ctx, "Url contained credentials."));

View file

@ -1,12 +1,15 @@
//! Contains the actual fetch function.
use crate::fnc::script::fetch::{
body::{Body, BodyData, BodyKind},
classes::{self, Request, RequestInit, Response, ResponseInit, ResponseType},
RequestError,
use crate::fnc::script::{
fetch::{
body::{Body, BodyData, BodyKind},
classes::{self, Request, RequestInit, Response, ResponseInit, ResponseType},
RequestError,
},
modules::surrealdb::query::{QueryContext, QUERY_DATA_PROP_NAME},
};
use futures::TryStreamExt;
use js::{function::Opt, Class, Ctx, Exception, Result, Value};
use js::{class::OwnedBorrow, function::Opt, Class, Ctx, Exception, Result, Value};
use reqwest::{
header::{HeaderValue, CONTENT_TYPE},
redirect, Body as ReqBody,
@ -27,6 +30,19 @@ pub async fn fetch<'js>(
let url = js_req.url;
// Check if the url is allowed to be fetched.
if ctx.globals().contains_key(QUERY_DATA_PROP_NAME)? {
let query_ctx =
ctx.globals().get::<_, OwnedBorrow<'js, QueryContext<'js>>>(QUERY_DATA_PROP_NAME)?;
query_ctx
.context
.check_allowed_net(&url)
.map_err(|e| Exception::throw_message(&ctx, &e.to_string()))?;
} else {
#[cfg(debug_assertions)]
panic!("Trying to fetch a URL but no QueryContext is present. QueryContext is required for checking if the URL is allowed to be fetched.")
}
let req = reqwest::Request::new(js_req.init.method, url.clone());
// SurrealDB Implementation keeps all javascript parts inside the context::with scope so this
@ -118,146 +134,3 @@ pub async fn fetch<'js>(
};
Ok(response)
}
#[cfg(test)]
mod test {
use crate::fnc::script::fetch::test::create_test_context;
#[tokio::test]
async fn test_fetch_get() {
use js::{promise::Promise, CatchResultExt};
use wiremock::{
matchers::{header, method, path},
Mock, MockServer, ResponseTemplate,
};
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
.expect(1)
.mount(&server)
.await;
let server_ref = &server;
create_test_context!(ctx => {
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
ctx.eval::<Promise<()>,_>(r#"
(async () => {
let res = await fetch(SERVER_URL + '/hello',{
headers: {
"some-header": "some-value",
}
});
assert.seq(res.status,200);
let body = await res.text();
assert.seq(body,'some body once told me');
})()
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
})
.await;
server.verify().await;
}
#[tokio::test]
async fn test_fetch_put() {
use js::{promise::Promise, CatchResultExt};
use wiremock::{
matchers::{body_string, header, method, path},
Mock, MockServer, ResponseTemplate,
};
let server = MockServer::start().await;
Mock::given(method("PUT"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.and(body_string("some text"))
.respond_with(ResponseTemplate::new(201).set_body_string("some body once told me"))
.expect(1)
.mount(&server)
.await;
let server_ref = &server;
create_test_context!(ctx => {
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
ctx.eval::<Promise<()>,_>(r#"
(async () => {
let res = await fetch(SERVER_URL + '/hello',{
method: "PuT",
headers: {
"some-header": "some-value",
},
body: "some text",
});
assert.seq(res.status,201);
assert(res.ok);
let body = await res.text();
assert.seq(body,'some body once told me');
})()
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
})
.await;
server.verify().await;
}
#[tokio::test]
async fn test_fetch_error() {
use js::{promise::Promise, CatchResultExt};
use wiremock::{
matchers::{body_string, header, method, path},
Mock, MockServer, ResponseTemplate,
};
let server = MockServer::start().await;
Mock::given(method("PROPPATCH"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.and(body_string("some text"))
.respond_with(ResponseTemplate::new(500).set_body_json(serde_json::json!({
"foo": "bar",
"baz": 2,
})))
.expect(1)
.mount(&server)
.await;
let server_ref = &server;
create_test_context!(ctx => {
ctx.globals().set("SERVER_URL",server_ref.uri()).unwrap();
ctx.eval::<Promise<()>,_>(r#"
(async () => {
let req = new Request(SERVER_URL + '/hello',{
method: "PROPPATCH",
headers: {
"some-header": "some-value",
},
body: "some text",
})
let res = await fetch(req);
assert.seq(res.status,500);
assert(!res.ok);
let body = await res.json();
assert(body.foo !== undefined, "body.foo not defined");
assert(body.baz !== undefined, "body.foo not defined");
assert.seq(body.foo, "bar");
assert.seq(body.baz, 2);
})()
"#).catch(&ctx).unwrap().await.catch(&ctx).unwrap()
})
.await;
server.verify().await;
}
}

View file

@ -16,3 +16,6 @@ mod fetch;
mod fetch_stub;
#[cfg(not(feature = "http"))]
use self::fetch_stub as fetch;
#[cfg(test)]
mod tests;

View file

@ -0,0 +1,203 @@
use std::str::FromStr;
use wiremock::{
matchers::{body_string, header, method, path},
Mock, MockServer, ResponseTemplate,
};
use crate::{
dbs::{
capabilities::{NetTarget, Targets},
Capabilities, Session,
},
kvs::Datastore,
};
#[tokio::test]
async fn test_fetch_get() {
// Prepare mock server
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
.expect(1)
.mount(&server)
.await;
// Execute test
let ds = Datastore::new("memory").await.unwrap();
let sess = Session::owner();
let sql = format!(
r#"
RETURN function() {{
let res = await fetch('{}/hello',{{
headers: {{
"some-header": "some-value",
}}
}});
let body = await res.text();
return {{ status: res.status, body: body }};
}}
"#,
server.uri()
);
let res = ds.execute(&sql, &sess, None).await;
let res = res.unwrap().remove(0).output().unwrap();
server.verify().await;
assert_eq!(
res.to_string(),
"{ body: 'some body once told me', status: 200f }",
"Unexpected result: {:?}",
res
);
}
#[tokio::test]
async fn test_fetch_put() {
// Prepare mock server
let server = MockServer::start().await;
Mock::given(method("PUT"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.and(body_string("some text"))
.respond_with(ResponseTemplate::new(201).set_body_string("some body once told me"))
.expect(1)
.mount(&server)
.await;
// Execute test
let ds = Datastore::new("memory").await.unwrap();
let sess = Session::owner();
let sql = format!(
r#"
RETURN function() {{
let res = await fetch('{}/hello',{{
method: "PuT",
headers: {{
"some-header": "some-value",
}},
body: "some text",
}});
let body = await res.text();
return {{ status: res.status, body: body }};
}}
"#,
server.uri()
);
let res = ds.execute(&sql, &sess, None).await;
let res = res.unwrap().remove(0).output().unwrap();
server.verify().await;
assert_eq!(
res.to_string(),
"{ body: 'some body once told me', status: 201f }",
"Unexpected result: {:?}",
res
);
}
#[tokio::test]
async fn test_fetch_error() {
// Prepare mock server
let server = MockServer::start().await;
Mock::given(method("PROPPATCH"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.and(body_string("some text"))
.respond_with(ResponseTemplate::new(500).set_body_json(serde_json::json!({
"foo": "bar",
"baz": 2,
})))
.expect(1)
.mount(&server)
.await;
// Execute test
let ds = Datastore::new("memory").await.unwrap();
let sess = Session::owner();
let sql = format!(
r#"
RETURN function() {{
let res = await fetch('{}/hello',{{
method: "PROPPATCH",
headers: {{
"some-header": "some-value",
}},
body: "some text",
}});
let body = await res.text();
return {{ status: res.status, body: body }};
}}
"#,
server.uri()
);
let res = ds.execute(&sql, &sess, None).await;
let res = res.unwrap().remove(0).output().unwrap();
server.verify().await;
assert_eq!(
res.to_string(),
"{ body: '{\"foo\":\"bar\",\"baz\":2}', status: 500f }",
"Unexpected result: {:?}",
res
);
}
#[tokio::test]
async fn test_fetch_denied() {
// Prepare mock server
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hello"))
.and(header("some-header", "some-value"))
.respond_with(ResponseTemplate::new(200).set_body_string("some body once told me"))
.expect(0)
.mount(&server)
.await;
// Execute test
let ds = Datastore::new("memory").await.unwrap().with_capabilities(
Capabilities::default().with_deny_net(Targets::Some(
[NetTarget::from_str(&server.address().to_string()).unwrap()].into(),
)),
);
let sess = Session::owner();
let sql = format!(
r#"
RETURN function() {{
let res = await fetch('{}/hello',{{
headers: {{
"some-header": "some-value",
}}
}});
let body = await res.text();
return {{ status: res.status, body: body }};
}}
"#,
server.uri()
);
let res = ds.execute(&sql, &sess, None).await;
let res = res.unwrap().remove(0).output().unwrap_err();
server.verify().await;
assert!(
res.to_string()
.contains(&format!("Access to network target '{}/hello' is not allowed", server.uri())),
"Unexpected result: {:?}",
res
);
}

View file

@ -0,0 +1 @@
mod fetch;

View file

@ -6,6 +6,7 @@ use crate::sql::value::Value;
use crate::sql::{json, Bytes};
use reqwest::header::CONTENT_TYPE;
use reqwest::{Client, RequestBuilder, Response};
use url::Url;
pub(crate) fn uri_is_valid(uri: &str) -> bool {
reqwest::Url::parse(uri).is_ok()
@ -46,10 +47,13 @@ async fn decode_response(res: Response) -> Result<Value, Error> {
}
pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new HEAD request
let mut req = cli.head(uri.as_str());
let mut req = cli.head(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");
@ -72,10 +76,13 @@ pub async fn head(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Re
}
pub async fn get(ctx: &Context<'_>, uri: Strand, opts: impl Into<Object>) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new GET request
let mut req = cli.get(uri.as_str());
let mut req = cli.get(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");
@ -100,10 +107,13 @@ pub async fn put(
body: Value,
opts: impl Into<Object>,
) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new GET request
let mut req = cli.put(uri.as_str());
let mut req = cli.put(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");
@ -130,10 +140,13 @@ pub async fn post(
body: Value,
opts: impl Into<Object>,
) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new GET request
let mut req = cli.post(uri.as_str());
let mut req = cli.post(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");
@ -160,10 +173,13 @@ pub async fn patch(
body: Value,
opts: impl Into<Object>,
) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new GET request
let mut req = cli.patch(uri.as_str());
let mut req = cli.patch(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");
@ -189,10 +205,13 @@ pub async fn delete(
uri: Strand,
opts: impl Into<Object>,
) -> Result<Value, Error> {
// Check if the URI is valid and allowed
let url = Url::parse(&uri).map_err(|_| Error::InvalidUrl(uri.to_string()))?;
ctx.check_allowed_net(&url)?;
// Set a default client with no timeout
let cli = Client::builder().build()?;
// Start a new GET request
let mut req = cli.delete(uri.as_str());
let mut req = cli.delete(url);
// Add the User-Agent header
if cfg!(not(target_arch = "wasm32")) {
req = req.header("User-Agent", "SurrealDB");

View file

@ -3,6 +3,7 @@ use crate::cf;
use crate::ctx::Context;
use crate::dbs::node::Timestamp;
use crate::dbs::Attach;
use crate::dbs::Capabilities;
use crate::dbs::Executor;
use crate::dbs::Notification;
use crate::dbs::Options;
@ -62,6 +63,8 @@ pub struct Datastore {
notification_channel: Option<(Sender<Notification>, Receiver<Notification>)>,
// Whether this datastore authentication is enabled. When disabled, anonymous actors have owner-level access.
auth_enabled: bool,
// Capabilities for this datastore
capabilities: Capabilities,
}
#[allow(clippy::large_enum_variant)]
@ -260,6 +263,7 @@ impl Datastore {
vso: Arc::new(Mutex::new(vs::Oracle::systime_counter())),
notification_channel: None,
auth_enabled: false,
capabilities: Capabilities::default(),
})
}
@ -326,6 +330,12 @@ impl Datastore {
self
}
/// Configure Datastore capabilities
pub fn with_capabilities(mut self, caps: Capabilities) -> Self {
self.capabilities = caps;
self
}
/// Creates a new datastore instance
///
/// Use this for clustered environments.
@ -752,6 +762,7 @@ impl Datastore {
let mut exe = Executor::new(self);
// Create a default context
let mut ctx = Context::default();
ctx.add_capabilities(self.capabilities.clone());
// Set the global query timeout
if let Some(timeout) = self.query_timeout {
ctx.add_timeout(timeout);
@ -808,6 +819,8 @@ impl Datastore {
let txn = Arc::new(Mutex::new(txn));
// Create a default context
let mut ctx = Context::default();
// Set context capabilities
ctx.add_capabilities(self.capabilities.clone());
// Set the global query timeout
if let Some(timeout) = self.query_timeout {
ctx.add_timeout(timeout);

View file

@ -160,12 +160,16 @@ impl Function {
// Process the function type
match self {
Self::Normal(s, x) => {
// Check this function is allowed
ctx.check_allowed_function(s)?;
// Compute the function arguments
let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
// Run the normal function
fnc::run(ctx, txn, doc, s, a).await
}
Self::Custom(s, x) => {
// Check this function is allowed
ctx.check_allowed_function(format!("fn::{s}").as_str())?;
// Get the function definition
let val = {
// Claim transaction
@ -198,6 +202,8 @@ impl Function {
Self::Script(s, x) => {
#[cfg(feature = "scripting")]
{
// Check if scripting is allowed
ctx.check_allowed_scripting()?;
// Compute the function arguments
let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
// Run the script function

View file

@ -10,7 +10,6 @@ use crate::err::Error;
use crate::net::{self, client_ip::ClientIp};
use crate::node;
use clap::Args;
use ipnet::IpNet;
use opentelemetry::Context as TelemetryContext;
use std::net::SocketAddr;
use std::path::PathBuf;
@ -24,8 +23,35 @@ pub struct StartCommandArguments {
#[arg(default_value = "memory")]
#[arg(value_parser = super::validator::path_valid)]
path: String,
#[arg(help = "The logging level for the database server")]
#[arg(env = "SURREAL_LOG", short = 'l', long = "log")]
#[arg(default_value = "info")]
#[arg(value_parser = CustomEnvFilterParser::new())]
log: CustomEnvFilter,
#[arg(help = "Whether to hide the startup banner")]
#[arg(env = "SURREAL_NO_BANNER", long)]
#[arg(default_value_t = false)]
no_banner: bool,
#[arg(help = "Encryption key to use for on-disk encryption")]
#[arg(env = "SURREAL_KEY", short = 'k', long = "key")]
#[arg(value_parser = super::validator::key_valid)]
#[arg(hide = true)] // Not currently in use
key: Option<String>,
#[arg(
help = "The username for the initial database root user. Only if no other root user exists"
help = "The interval at which to run node agent tick (including garbage collection)",
help_heading = "Database"
)]
#[arg(env = "SURREAL_TICK_INTERVAL", long = "tick-interval", value_parser = super::validator::duration)]
#[arg(default_value = "10s")]
tick_interval: Duration,
//
// Authentication
//
#[arg(
help = "The username for the initial database root user. Only if no other root user exists",
help_heading = "Authentication"
)]
#[arg(
env = "SURREAL_USER",
@ -36,7 +62,8 @@ pub struct StartCommandArguments {
)]
username: Option<String>,
#[arg(
help = "The password for the initial database root user. Only if no other root user exists"
help = "The password for the initial database root user. Only if no other root user exists",
help_heading = "Authentication"
)]
#[arg(
env = "SURREAL_PASS",
@ -46,10 +73,20 @@ pub struct StartCommandArguments {
requires = "username"
)]
password: Option<String>,
#[arg(help = "The allowed networks for master authentication")]
#[arg(env = "SURREAL_ADDR", long = "addr")]
#[arg(default_value = "127.0.0.1/32")]
allowed_networks: Vec<IpNet>,
//
// Datastore connection
//
#[command(next_help_heading = "Datastore connection")]
#[command(flatten)]
kvs: Option<StartCommandRemoteTlsOptions>,
//
// HTTP Server
//
#[command(next_help_heading = "HTTP server")]
#[command(flatten)]
web: Option<StartCommandWebTlsOptions>,
#[arg(help = "The method of detecting the client's IP address")]
#[arg(env = "SURREAL_CLIENT_IP", long)]
#[arg(default_value = "socket", value_enum)]
@ -58,29 +95,13 @@ pub struct StartCommandArguments {
#[arg(env = "SURREAL_BIND", short = 'b', long = "bind")]
#[arg(default_value = "0.0.0.0:8000")]
listen_addresses: Vec<SocketAddr>,
#[arg(help = "The interval at which to run node agent tick (including garbage collection)")]
#[arg(env = "SURREAL_TICK_INTERVAL", long = "tick-interval", value_parser = super::validator::duration)]
#[arg(default_value = "10s")]
tick_interval: Duration,
//
// Database options
//
#[command(flatten)]
#[command(next_help_heading = "Database")]
dbs: StartCommandDbsOptions,
#[arg(help = "Encryption key to use for on-disk encryption")]
#[arg(env = "SURREAL_KEY", short = 'k', long = "key")]
#[arg(value_parser = super::validator::key_valid)]
key: Option<String>,
#[command(flatten)]
kvs: Option<StartCommandRemoteTlsOptions>,
#[command(flatten)]
web: Option<StartCommandWebTlsOptions>,
#[arg(help = "The logging level for the database server")]
#[arg(env = "SURREAL_LOG", short = 'l', long = "log")]
#[arg(default_value = "info")]
#[arg(value_parser = CustomEnvFilterParser::new())]
log: CustomEnvFilter,
#[arg(help = "Whether to hide the startup banner")]
#[arg(env = "SURREAL_NO_BANNER", long)]
#[arg(default_value_t = false)]
no_banner: bool,
}
#[derive(Args, Debug)]

View file

@ -1,3 +1,4 @@
use std::collections::HashSet;
#[cfg(feature = "has-storage")]
use std::{
path::{Path, PathBuf},
@ -5,6 +6,8 @@ use std::{
time::Duration,
};
use surrealdb::dbs::capabilities::{FuncTarget, NetTarget, Targets};
pub(crate) mod parser;
#[cfg(feature = "has-storage")]
@ -76,3 +79,77 @@ pub(crate) fn key_valid(v: &str) -> Result<String, String> {
pub(crate) fn duration(v: &str) -> Result<Duration, String> {
surrealdb::sql::Duration::from_str(v).map(|d| d.0).map_err(|_| String::from("invalid duration"))
}
pub(crate) fn net_targets(value: &str) -> Result<Targets<NetTarget>, String> {
if ["*", ""].contains(&value) {
return Ok(Targets::All);
}
let mut result = HashSet::new();
for target in value.split(',').filter(|s| !s.is_empty()) {
result.insert(NetTarget::from_str(target)?);
}
Ok(Targets::Some(result))
}
pub(crate) fn func_targets(value: &str) -> Result<Targets<FuncTarget>, String> {
if ["*", ""].contains(&value) {
return Ok(Targets::All);
}
let mut result = HashSet::new();
for target in value.split(',').filter(|s| !s.is_empty()) {
result.insert(FuncTarget::from_str(target)?);
}
Ok(Targets::Some(result))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_func_targets() {
assert_eq!(func_targets("*").unwrap(), Targets::<FuncTarget>::All);
assert_eq!(func_targets("").unwrap(), Targets::<FuncTarget>::All);
assert_eq!(
func_targets("foo").unwrap(),
Targets::<FuncTarget>::Some(vec!["foo".parse().unwrap()].into_iter().collect())
);
assert_eq!(
func_targets("foo,bar").unwrap(),
Targets::<FuncTarget>::Some(
vec!["foo".parse().unwrap(), "bar".parse().unwrap()].into_iter().collect()
)
);
}
#[test]
fn test_net_targets() {
assert_eq!(net_targets("*").unwrap(), Targets::<NetTarget>::All);
assert_eq!(net_targets("").unwrap(), Targets::<NetTarget>::All);
assert_eq!(
net_targets("example.com").unwrap(),
Targets::<NetTarget>::Some(vec!["example.com".parse().unwrap()].into_iter().collect())
);
assert_eq!(
net_targets("127.0.0.1:80,[2001:db8::1]:443,2001:db8::1").unwrap(),
Targets::<NetTarget>::Some(
vec![
"127.0.0.1:80".parse().unwrap(),
"[2001:db8::1]:443".parse().unwrap(),
"2001:db8::1".parse().unwrap()
]
.into_iter()
.collect()
)
);
assert!(net_targets("127777.0.0.1").is_err());
assert!(net_targets("127.0.0.1,127777.0.0.1").is_err());
}
}

View file

@ -1,8 +1,9 @@
use crate::cli::CF;
use crate::err::Error;
use clap::Args;
use clap::{ArgAction, Args};
use std::sync::OnceLock;
use std::time::Duration;
use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets};
use surrealdb::kvs::Datastore;
use surrealdb::opt::auth::Root;
@ -22,10 +23,175 @@ pub struct StartCommandDbsOptions {
#[arg(env = "SURREAL_TRANSACTION_TIMEOUT", long)]
#[arg(value_parser = super::cli::validator::duration)]
transaction_timeout: Option<Duration>,
#[arg(help = "Whether to enable authentication")]
#[arg(help = "Whether to enable authentication", help_heading = "Authentication")]
#[arg(env = "SURREAL_AUTH", long = "auth")]
#[arg(default_value_t = false)]
auth_enabled: bool,
#[command(flatten)]
#[command(next_help_heading = "Capabilities")]
caps: DbsCapabilities,
}
#[derive(Args, Debug)]
struct DbsCapabilities {
//
// Allow
//
#[arg(help = "Allow all capabilities")]
#[arg(env = "SURREAL_CAPS_ALLOW_ALL", short = 'A', long, conflicts_with = "deny_all")]
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
#[arg(default_value_t = false, hide_default_value = true)]
allow_all: bool,
#[cfg(feature = "scripting")]
#[arg(help = "Allow execution of scripting functions")]
#[arg(env = "SURREAL_CAPS_ALLOW_SCRIPT", long, conflicts_with = "allow_all")]
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
#[arg(default_value_t = true, hide_default_value = true)]
allow_scripting: bool,
#[arg(
help = "Allow execution of functions. Optionally, you can provide a comma-separated list of function names to allow",
long_help = r#"Allow execution of functions. Optionally, you can provide a comma-separated list of function names to allow.
Function names must be in the form <family>[::<name>]. For example:
- 'http' or 'http::*' -> Include all functions in the 'http' family
- 'http::get' -> Include only the 'get' function in the 'http' family
"#
)]
#[arg(env = "SURREAL_CAPS_ALLOW_FUNC", long, conflicts_with = "allow_all")]
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
#[arg(default_value_os = "", default_missing_value_os = "", num_args = 0..)]
#[arg(value_parser = super::cli::validator::func_targets)]
allow_funcs: Option<Targets<FuncTarget>>,
#[arg(
help = "Allow all outbound network access. Optionally, you can provide a comma-separated list of targets to allow",
long_help = r#"Allow all outbound network access. Optionally, you can provide a comma-separated list of targets to allow.
Targets must be in the form of <host>[:<port>], <ipv4|ipv6>[/<mask>]. For example:
- 'surrealdb.com', '127.0.0.1' or 'fd00::1' -> Match outbound connections to these hosts on any port
- 'surrealdb.com:80', '127.0.0.1:80' or 'fd00::1:80' -> Match outbound connections to these hosts on port 80
- '10.0.0.0/8' or 'fd00::/8' -> Match outbound connections to any host in these networks
"#
)]
#[arg(env = "SURREAL_CAPS_ALLOW_NET", long, conflicts_with = "allow_all")]
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
#[arg(default_value_os = "", default_missing_value_os = "", num_args = 0..)]
#[arg(value_parser = super::cli::validator::net_targets)]
allow_net: Option<Targets<NetTarget>>,
//
// Deny
//
#[arg(help = "Deny all capabilities")]
#[arg(env = "SURREAL_CAPS_DENY_ALL", short = 'D', long, conflicts_with = "allow_all")]
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
#[arg(default_value_t = false, hide_default_value = true)]
deny_all: bool,
#[cfg(feature = "scripting")]
#[arg(help = "Deny execution of scripting functions")]
#[arg(env = "SURREAL_CAPS_DENY_SCRIPT", long, conflicts_with = "deny_all")]
#[arg(default_missing_value_os = "true", action = ArgAction::Set, num_args = 0..)]
#[arg(default_value_t = false, hide_default_value = true)]
deny_scripting: bool,
#[arg(
help = "Deny execution of functions. Optionally, you can provide a comma-separated list of function names to deny",
long_help = r#"Deny execution of functions. Optionally, you can provide a comma-separated list of function names to deny.
Function names must be in the form <family>[::<name>]. For example:
- 'http' or 'http::*' -> Include all functions in the 'http' family
- 'http::get' -> Include only the 'get' function in the 'http' family
"#
)]
#[arg(env = "SURREAL_CAPS_DENY_FUNC", long, conflicts_with = "deny_all")]
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
#[arg(default_missing_value_os = "", num_args = 0..)]
#[arg(value_parser = super::cli::validator::func_targets)]
deny_funcs: Option<Targets<FuncTarget>>,
#[arg(
help = "Deny all outbound network access. Optionally, you can provide a comma-separated list of targets to deny",
long_help = r#"Deny all outbound network access. Optionally, you can provide a comma-separated list of targets to deny.
Targets must be in the form of <host>[:<port>], <ipv4|ipv6>[/<mask>]. For example:
- 'surrealdb.com', '127.0.0.1' or 'fd00::1' -> Match outbound connections to these hosts on any port
- 'surrealdb.com:80', '127.0.0.1:80' or 'fd00::1:80' -> Match outbound connections to these hosts on port 80
- '10.0.0.0/8' or 'fd00::/8' -> Match outbound connections to any host in these networks
"#
)]
#[arg(env = "SURREAL_CAPS_DENY_NET", long, conflicts_with = "deny_all")]
// If the arg is provided without value, then assume it's "", which gets parsed into Targets::All
#[arg(default_missing_value_os = "", num_args = 0..)]
// If deny_all is true, disable this arg and assume a default of Targets::All
#[arg(conflicts_with = "deny_all", default_value_if("deny_all", "true", ""))]
#[arg(value_parser = super::cli::validator::net_targets)]
deny_net: Option<Targets<NetTarget>>,
}
impl DbsCapabilities {
#[cfg(feature = "scripting")]
fn get_scripting(&self) -> bool {
(self.allow_all || self.allow_scripting) && !(self.deny_all || self.deny_scripting)
}
#[cfg(not(feature = "scripting"))]
fn get_scripting(&self) -> bool {
false
}
fn get_allow_funcs(&self) -> Targets<FuncTarget> {
if self.deny_all || matches!(self.deny_funcs, Some(Targets::All)) {
return Targets::None;
}
if self.allow_all {
return Targets::All;
}
// If allow_funcs was not provided and allow_all is false, then don't allow anything (Targets::None)
self.allow_funcs.clone().unwrap_or(Targets::None)
}
fn get_allow_net(&self) -> Targets<NetTarget> {
if self.deny_all || matches!(self.deny_net, Some(Targets::All)) {
return Targets::None;
}
if self.allow_all {
return Targets::All;
}
// If allow_net was not provided and allow_all is false, then don't allow anything (Targets::None)
self.allow_net.clone().unwrap_or(Targets::None)
}
fn get_deny_funcs(&self) -> Targets<FuncTarget> {
if self.deny_all {
return Targets::All;
}
// If deny_funcs was not provided and deny_all is false, then don't deny anything (Targets::None)
self.deny_funcs.clone().unwrap_or(Targets::None)
}
fn get_deny_net(&self) -> Targets<NetTarget> {
if self.deny_all {
return Targets::All;
}
// If deny_net was not provided and deny_all is false, then don't deny anything (Targets::None)
self.deny_net.clone().unwrap_or(Targets::None)
}
}
impl From<DbsCapabilities> for Capabilities {
fn from(caps: DbsCapabilities) -> Self {
Capabilities::default()
.with_scripting(caps.get_scripting())
.with_allow_funcs(caps.get_allow_funcs())
.with_deny_funcs(caps.get_deny_funcs())
.with_allow_net(caps.get_allow_net())
.with_deny_net(caps.get_deny_net())
}
}
pub async fn init(
@ -34,6 +200,7 @@ pub async fn init(
query_timeout,
transaction_timeout,
auth_enabled,
caps,
}: StartCommandDbsOptions,
) -> Result<(), Error> {
// Get local copy of options
@ -61,7 +228,9 @@ pub async fn init(
.with_strict_mode(strict_mode)
.with_query_timeout(query_timeout)
.with_transaction_timeout(transaction_timeout)
.with_auth_enabled(auth_enabled);
.with_auth_enabled(auth_enabled)
.with_capabilities(caps.into());
dbs.bootstrap().await?;
if let Some(user) = opt.user.as_ref() {
@ -81,13 +250,17 @@ pub async fn init(
#[cfg(test)]
mod tests {
use std::str::FromStr;
use surrealdb::dbs::Session;
use surrealdb::iam::verify::verify_creds;
use surrealdb::kvs::Datastore;
use test_log::test;
use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
use super::*;
#[tokio::test]
#[test(tokio::test)]
async fn test_setup_superuser() {
let ds = Datastore::new("memory").await.unwrap();
let creds = Root {
@ -133,4 +306,183 @@ mod tests {
.hash
)
}
#[test(tokio::test)]
async fn test_capabilities() {
let server1 = {
let s = MockServer::start().await;
let get = Mock::given(method("GET"))
.respond_with(ResponseTemplate::new(200).set_body_string("SUCCESS"))
.expect(1);
let head =
Mock::given(method("HEAD")).respond_with(ResponseTemplate::new(200)).expect(1);
s.register(get).await;
s.register(head).await;
s
};
let server2 = {
let s = MockServer::start().await;
let get = Mock::given(method("GET"))
.respond_with(ResponseTemplate::new(200).set_body_string("SUCCESS"))
.expect(1);
let head =
Mock::given(method("HEAD")).respond_with(ResponseTemplate::new(200)).expect(0);
s.register(get).await;
s.register(head).await;
s
};
// (Capabilities, Query, Succeeds, Response Contains)
let cases = vec![
//
// Functions and Networking are allowed
//
(
Capabilities::default(),
format!("RETURN http::get('{}')", server1.uri()),
true,
"SUCCESS".to_string(),
),
//
// Scripting is allowed
//
(
Capabilities::default(),
"RETURN function() { return '1' }".to_string(),
true,
"1".to_string(),
),
//
// Scripting is not allowed
//
(
Capabilities::default().with_scripting(false),
"RETURN function() { return '1' }".to_string(),
false,
"Scripting functions are not allowed".to_string(),
),
//
// Some functions are not allowed
//
(
Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::*").unwrap()].into(),
))
.with_deny_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::get").unwrap()].into(),
)),
format!("RETURN http::get('{}')", server1.uri()),
false,
"Function 'http::get' is not allowed".to_string(),
),
(
Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::*").unwrap()].into(),
))
.with_deny_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::get").unwrap()].into(),
)),
format!("RETURN http::head('{}')", server1.uri()),
true,
"NONE".to_string(),
),
(
Capabilities::default()
.with_allow_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::*").unwrap()].into(),
))
.with_deny_funcs(Targets::<FuncTarget>::Some(
[FuncTarget::from_str("http::get").unwrap()].into(),
)),
"RETURN string::len('a')".to_string(),
false,
"Function 'string::len' is not allowed".to_string(),
),
//
// Some net targets are not allowed
//
(
Capabilities::default()
.with_allow_net(Targets::<NetTarget>::Some(
[
NetTarget::from_str(&server1.address().to_string()).unwrap(),
NetTarget::from_str(&server2.address().to_string()).unwrap(),
]
.into(),
))
.with_deny_net(Targets::<NetTarget>::Some(
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
)),
format!("RETURN http::get('{}')", server1.uri()),
false,
format!("Access to network target '{}/' is not allowed", server1.uri()),
),
(
Capabilities::default()
.with_allow_net(Targets::<NetTarget>::Some(
[
NetTarget::from_str(&server1.address().to_string()).unwrap(),
NetTarget::from_str(&server2.address().to_string()).unwrap(),
]
.into(),
))
.with_deny_net(Targets::<NetTarget>::Some(
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
)),
"RETURN http::get('http://1.1.1.1')".to_string(),
false,
"Access to network target 'http://1.1.1.1/' is not allowed".to_string(),
),
(
Capabilities::default()
.with_allow_net(Targets::<NetTarget>::Some(
[
NetTarget::from_str(&server1.address().to_string()).unwrap(),
NetTarget::from_str(&server2.address().to_string()).unwrap(),
]
.into(),
))
.with_deny_net(Targets::<NetTarget>::Some(
[NetTarget::from_str(&server1.address().to_string()).unwrap()].into(),
)),
format!("RETURN http::get('{}')", server2.uri()),
true,
"SUCCESS".to_string(),
),
];
for (idx, (caps, query, succeeds, contains)) in cases.into_iter().enumerate() {
let ds = Datastore::new("memory").await.unwrap().with_capabilities(caps);
let sess = Session::owner();
let res = ds.execute(&query, &sess, None).await;
let res = res.unwrap().remove(0).output();
let res = if succeeds {
assert!(res.is_ok(), "Unexpected error for test case {}: {:?}", idx, res);
res.unwrap().to_string()
} else {
assert!(res.is_err(), "Unexpected success for test case {}: {:?}", idx, res);
res.unwrap_err().to_string()
};
assert!(
res.contains(&contains),
"Unexpected result for test case {}: expected to contain = `{}`, got `{}`",
idx,
contains,
res
);
}
server1.verify().await;
server2.verify().await;
}
}

File diff suppressed because it is too large Load diff

View file

@ -54,16 +54,21 @@ impl Child {
self.inner.as_mut().unwrap().try_wait()
}
pub fn stdout(&self) -> String {
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
}
pub fn stderr(&self) -> String {
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
}
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
let mut buf =
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file");
buf.push_str(
&std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file"),
);
let mut buf = self.stdout();
buf.push_str(&self.stderr());
// Cleanup files after reading them
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
@ -101,8 +106,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
}
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
let stdout_path = tmp_file(format!("server-stdout-{}.log", rand::random::<u32>()).as_str());
let stderr_path = tmp_file(format!("server-stderr-{}.log", rand::random::<u32>()).as_str());
let stdout_path = tmp_file("server-stdout.log");
let stderr_path = tmp_file("server-stderr.log");
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
@ -131,7 +136,7 @@ pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
}
pub fn tmp_file(name: &str) -> String {
let path = Path::new(env!("OUT_DIR")).join(name);
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
path.to_string_lossy().into_owned()
}
@ -140,6 +145,7 @@ pub struct StartServerArguments {
pub tls: bool,
pub wait_is_ready: bool,
pub tick_interval: time::Duration,
pub args: String,
}
impl Default for StartServerArguments {
@ -149,6 +155,7 @@ impl Default for StartServerArguments {
tls: false,
wait_is_ready: true,
tick_interval: time::Duration::new(1, 0),
args: String::default(),
}
}
}
@ -171,6 +178,7 @@ pub async fn start_server(
tls,
wait_is_ready,
tick_interval,
args,
}: StartServerArguments,
) -> Result<(String, Child), Box<dyn Error>> {
let mut rng = thread_rng();
@ -178,7 +186,7 @@ pub async fn start_server(
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let mut extra_args = String::default();
let mut extra_args = args.clone();
if tls {
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
@ -212,7 +220,7 @@ pub async fn start_server(
}
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(500));
let mut interval = time::interval(time::Duration::from_millis(1000));
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff