Refactor function target and net target parsing and matching. (#4150)

This commit is contained in:
Mees Delzenne 2024-06-12 12:00:51 +02:00 committed by GitHub
parent e1123ae6d6
commit 2184e80f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 125 additions and 93 deletions

View file

@ -42,7 +42,6 @@ pub static INSECURE_FORWARD_RECORD_ACCESS_ERRORS: Lazy<bool> =
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -1,6 +1,5 @@
use crate::ctx::canceller::Canceller;
use crate::ctx::reason::Reason;
use crate::dbs::capabilities::FuncTarget;
#[cfg(feature = "http")]
use crate::dbs::capabilities::NetTarget;
use crate::dbs::{Capabilities, Notification, Transaction};
@ -18,13 +17,11 @@ use std::fmt::{self, Debug};
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
))]
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
@ -68,7 +65,6 @@ pub struct Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -104,7 +100,6 @@ impl<'a> Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -125,7 +120,6 @@ impl<'a> Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -154,7 +148,6 @@ impl<'a> Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -180,7 +173,6 @@ impl<'a> Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -318,7 +310,6 @@ impl<'a> Context<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -380,12 +371,7 @@ impl<'a> Context<'a> {
/// Check if a function is allowed
pub fn check_allowed_function(&self, target: &str) -> Result<(), Error> {
let func_target = FuncTarget::from_str(target).map_err(|_| Error::InvalidFunction {
name: target.to_string(),
message: "Invalid function name".to_string(),
})?;
if !self.capabilities.allows_function(&func_target) {
if !self.capabilities.allows_function_name(target) {
return Err(Error::FunctionNotAllowed(target.to_string()));
}
Ok(())

View file

@ -1,3 +1,4 @@
use std::fmt;
use std::hash::Hash;
use std::net::IpAddr;
use std::{collections::HashSet, sync::Arc};
@ -5,16 +6,16 @@ use std::{collections::HashSet, sync::Arc};
use ipnet::IpNet;
use url::Url;
pub trait Target {
fn matches(&self, elem: &Self) -> bool;
pub trait Target<Item: ?Sized = Self> {
fn matches(&self, elem: &Item) -> bool;
}
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub struct FuncTarget(pub String, pub Option<String>);
impl std::fmt::Display for FuncTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl fmt::Display for FuncTarget {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.1 {
Some(name) => write!(f, "{}:{}", self.0, name),
None => write!(f, "{}::*", self.0),
@ -23,7 +24,7 @@ impl std::fmt::Display for FuncTarget {
}
impl Target for FuncTarget {
fn matches(&self, elem: &Self) -> bool {
fn matches(&self, elem: &FuncTarget) -> bool {
match self {
Self(family, Some(name)) => {
family == &elem.0 && (elem.1.as_ref().is_some_and(|n| n == name))
@ -33,18 +34,75 @@ impl Target for FuncTarget {
}
}
impl Target<str> for FuncTarget {
fn matches(&self, elem: &str) -> bool {
if let Some(x) = self.1.as_ref() {
let Some((f, r)) = elem.split_once("::") else {
return false;
};
f == self.0 && r == x
} else {
let f = elem.split_once("::").map(|(f, _)| f).unwrap_or(elem);
f == self.0
}
}
}
#[derive(Debug, Clone)]
pub enum ParseFuncTargetError {
InvalidWildcardFamily,
InvalidName,
}
impl std::error::Error for ParseFuncTargetError {}
impl fmt::Display for ParseFuncTargetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
ParseFuncTargetError::InvalidName => {
write!(f, "invalid function target name")
}
ParseFuncTargetError::InvalidWildcardFamily => {
write!(
f,
"invalid function target wildcard family, only first part of function can be wildcarded"
)
}
}
}
}
impl std::str::FromStr for FuncTarget {
type Err = String;
type Err = ParseFuncTargetError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// 'family::*' is treated as 'family'. They both match all functions in the family.
let s = s.replace("::*", "");
let s = s.trim();
let target = match s.split_once("::") {
Some((family, name)) => Self(family.to_string(), Some(name.to_string())),
_ => Self(s.to_string(), None),
};
Ok(target)
if s.is_empty() {
return Err(ParseFuncTargetError::InvalidName);
}
if let Some(family) = s.strip_suffix("::*") {
if family.contains("::") {
return Err(ParseFuncTargetError::InvalidWildcardFamily);
}
if !family.bytes().all(|x| x.is_ascii_alphanumeric()) {
return Err(ParseFuncTargetError::InvalidName);
}
return Ok(FuncTarget(family.to_string(), None));
}
if !s.bytes().all(|x| x.is_ascii_alphanumeric() || x == b':') {
return Err(ParseFuncTargetError::InvalidName);
}
if let Some((first, rest)) = s.split_once("::") {
Ok(FuncTarget(first.to_string(), Some(rest.to_string())))
} else {
Ok(FuncTarget(s.to_string(), None))
}
}
}
@ -56,8 +114,8 @@ pub enum NetTarget {
}
// impl display
impl std::fmt::Display for NetTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl fmt::Display for NetTarget {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Host(host, Some(port)) => write!(f, "{}:{}", host, port),
Self::Host(host, None) => write!(f, "{}", host),
@ -92,8 +150,18 @@ impl Target for NetTarget {
}
}
#[derive(Debug)]
pub struct ParseNetTargetError;
impl std::error::Error for ParseNetTargetError {}
impl fmt::Display for ParseNetTargetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "The provided network target is not a valid host, ip address or ip network")
}
}
impl std::str::FromStr for NetTarget {
type Err = String;
type Err = ParseNetTargetError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// If it's a valid IPNet, return it
@ -118,22 +186,24 @@ impl std::str::FromStr for NetTarget {
}
}
Err(format!(
"The provided network target `{s}` is not a valid host, ip address or ip network"
))
Err(ParseNetTargetError)
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Targets<T: Target + Hash + Eq + PartialEq> {
pub enum Targets<T: Hash + Eq + PartialEq> {
None,
Some(HashSet<T>),
All,
}
impl<T: Target + Hash + Eq + PartialEq + std::fmt::Debug + std::fmt::Display> Targets<T> {
fn matches(&self, elem: &T) -> bool {
impl<T: Hash + Eq + PartialEq + fmt::Debug + fmt::Display> Targets<T> {
fn matches<S>(&self, elem: &S) -> bool
where
S: ?Sized,
T: Target<S>,
{
match self {
Self::None => false,
Self::All => true,
@ -142,8 +212,8 @@ impl<T: Target + Hash + Eq + PartialEq + std::fmt::Debug + std::fmt::Display> Ta
}
}
impl<T: Target + Hash + Eq + PartialEq + std::fmt::Display> std::fmt::Display for Targets<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl<T: Target + Hash + Eq + PartialEq + fmt::Display> fmt::Display for Targets<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => write!(f, "none"),
Self::All => write!(f, "all"),
@ -169,8 +239,8 @@ pub struct Capabilities {
deny_net: Arc<Targets<NetTarget>>,
}
impl std::fmt::Display for Capabilities {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl fmt::Display for Capabilities {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"scripting={}, guest_access={}, live_query_notifications={}, allow_funcs={}, deny_funcs={}, allow_net={}, deny_net={}",
@ -255,10 +325,17 @@ impl Capabilities {
self.live_query_notifications
}
// function is public API so we can't remove it, but you should prefer allows_function_name
pub fn allows_function(&self, target: &FuncTarget) -> bool {
self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
}
// doc hidden so we don't extend api in the library.
#[doc(hidden)]
pub fn allows_function_name(&self, target: &str) -> bool {
self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
}
pub fn allows_network_target(&self, target: &NetTarget) -> bool {
self.allow_net.matches(target) && !self.deny_net.matches(target)
}
@ -271,32 +348,26 @@ mod tests {
use super::*;
#[test]
fn test_invalid_func_target() {
FuncTarget::from_str("te::*st").unwrap_err();
FuncTarget::from_str("\0::st").unwrap_err();
FuncTarget::from_str("").unwrap_err();
FuncTarget::from_str("❤️").unwrap_err();
}
#[test]
fn test_func_target() {
assert!(FuncTarget::from_str("test")
.unwrap()
.matches(&FuncTarget::from_str("test").unwrap()));
assert!(!FuncTarget::from_str("test")
.unwrap()
.matches(&FuncTarget::from_str("test2").unwrap()));
assert!(FuncTarget::from_str("test").unwrap().matches("test"));
assert!(!FuncTarget::from_str("test").unwrap().matches("test2"));
assert!(!FuncTarget::from_str("test::")
.unwrap()
.matches(&FuncTarget::from_str("test").unwrap()));
assert!(!FuncTarget::from_str("test::").unwrap().matches("test"));
assert!(FuncTarget::from_str("test::*")
.unwrap()
.matches(&FuncTarget::from_str("test::name").unwrap()));
assert!(!FuncTarget::from_str("test::*")
.unwrap()
.matches(&FuncTarget::from_str("test2::name").unwrap()));
assert!(FuncTarget::from_str("test::*").unwrap().matches("test::name"));
assert!(!FuncTarget::from_str("test::*").unwrap().matches("test2::name"));
assert!(FuncTarget::from_str("test::name")
.unwrap()
.matches(&FuncTarget::from_str("test::name").unwrap()));
assert!(!FuncTarget::from_str("test::name")
.unwrap()
.matches(&FuncTarget::from_str("test::name2").unwrap()));
assert!(FuncTarget::from_str("test::name").unwrap().matches("test::name"));
assert!(!FuncTarget::from_str("test::name").unwrap().matches("test::name2"));
}
#[test]
@ -448,17 +519,17 @@ mod tests {
#[test]
fn test_targets() {
assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(Targets::<FuncTarget>::All.matches(&FuncTarget::from_str("http::get").unwrap()));
assert!(Targets::<FuncTarget>::All.matches("http::get"));
assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(!Targets::<FuncTarget>::None.matches(&FuncTarget::from_str("http::get").unwrap()));
assert!(!Targets::<FuncTarget>::None.matches("http::get"));
assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
.matches(&NetTarget::from_str("example.com").unwrap()));
assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
.matches(&NetTarget::from_str("www.example.com").unwrap()));
assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
.matches(&FuncTarget::from_str("http::get").unwrap()));
.matches("http::get"));
assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
.matches(&FuncTarget::from_str("http::post").unwrap()));
.matches("http::post"));
}
#[test]

View file

@ -297,7 +297,6 @@ impl Iterator {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -4,7 +4,6 @@ use crate::dbs::plan::Explanation;
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -22,7 +21,6 @@ pub(super) enum Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -37,7 +35,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -51,7 +48,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -80,7 +76,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -101,7 +96,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -118,7 +112,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -135,7 +128,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -151,7 +143,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -170,7 +161,6 @@ impl Results {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -238,7 +238,6 @@ impl<'a> Statement<'a> {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -52,7 +52,6 @@ impl From<Vec<Value>> for MemoryCollector {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -14,7 +14,6 @@ use bincode::Error as BincodeError;
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -1113,7 +1112,6 @@ impl From<reqwest::Error> for Error {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -3,7 +3,6 @@ use std::fmt;
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -95,7 +94,6 @@ pub struct Datastore {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -355,7 +353,6 @@ impl Datastore {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -410,7 +407,6 @@ impl Datastore {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -1136,7 +1132,6 @@ impl Datastore {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -147,7 +147,6 @@ pub(crate) fn router(
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -2,7 +2,6 @@ use crate::{dbs::Capabilities, iam::Level};
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -29,7 +28,6 @@ pub struct Config {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
@ -126,7 +124,6 @@ impl Config {
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-file",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",

View file

@ -85,7 +85,7 @@ pub(crate) fn net_targets(value: &str) -> Result<Targets<NetTarget>, String> {
let mut result = HashSet::new();
for target in value.split(',').filter(|s| !s.is_empty()) {
result.insert(NetTarget::from_str(target)?);
result.insert(NetTarget::from_str(target).map_err(|e| e.to_string())?);
}
Ok(Targets::Some(result))
@ -99,7 +99,7 @@ pub(crate) fn func_targets(value: &str) -> Result<Targets<FuncTarget>, String> {
let mut result = HashSet::new();
for target in value.split(',').filter(|s| !s.is_empty()) {
result.insert(FuncTarget::from_str(target)?);
result.insert(FuncTarget::from_str(target).map_err(|e| e.to_string())?);
}
Ok(Targets::Some(result))