diff --git a/lib/src/api/engine/any/mod.rs b/lib/src/api/engine/any/mod.rs index 0245a3eb..66d6c5b9 100644 --- a/lib/src/api/engine/any/mod.rs +++ b/lib/src/api/engine/any/mod.rs @@ -95,10 +95,8 @@ use crate::api::opt::Endpoint; use crate::api::Connect; use crate::api::Result; use crate::api::Surreal; -use crate::opt::replace_tilde; -use path_clean::PathClean; +use crate::opt::path_to_string; use std::marker::PhantomData; -use std::path::Path; use std::sync::Arc; use std::sync::OnceLock; use url::Url; @@ -109,20 +107,30 @@ pub trait IntoEndpoint { fn into_endpoint(self) -> Result; } +fn split_url(url: &str) -> (&str, &str) { + match url.split_once("://") { + Some(parts) => parts, + None => match url.split_once(':') { + Some(parts) => parts, + None => (url, ""), + }, + } +} + impl IntoEndpoint for &str { fn into_endpoint(self) -> Result { let (url, path) = match self { "memory" | "mem://" => (Url::parse("mem://").unwrap(), "memory".to_owned()), - url if url.starts_with("ws") | url.starts_with("http") => { + url if url.starts_with("ws") | url.starts_with("http") | url.starts_with("tikv") => { (Url::parse(url).map_err(|_| Error::InvalidUrl(self.to_owned()))?, String::new()) } + _ => { - let (scheme, _) = self.split_once(':').unwrap_or((self, "")); - let path = replace_tilde(self); + let (scheme, path) = split_url(self); + let protocol = format!("{scheme}://"); ( - Url::parse(&format!("{scheme}://")) - .map_err(|_| Error::InvalidUrl(self.to_owned()))?, - Path::new(&path).clean().display().to_string(), + Url::parse(&protocol).map_err(|_| Error::InvalidUrl(self.to_owned()))?, + path_to_string(&protocol, path), ) } }; diff --git a/lib/src/api/engine/local/native.rs b/lib/src/api/engine/local/native.rs index f56df676..a44b9447 100644 --- a/lib/src/api/engine/local/native.rs +++ b/lib/src/api/engine/local/native.rs @@ -102,7 +102,12 @@ pub(crate) fn router( _ => None, }; - let kvs = match Datastore::new(&address.path).await { + let endpoint = match address.url.scheme() { + "tikv" => address.url.as_str(), + _ => &address.path, + }; + + let kvs = match Datastore::new(endpoint).await { Ok(kvs) => { // If a root user is specified, setup the initial datastore credentials if let Some(root) = configured_root { diff --git a/lib/src/api/opt/config.rs b/lib/src/api/opt/config.rs index d67c8e18..967c77ef 100644 --- a/lib/src/api/opt/config.rs +++ b/lib/src/api/opt/config.rs @@ -2,7 +2,7 @@ use crate::{dbs::Capabilities, iam::Level}; use std::time::Duration; /// Configuration for server connection, including: strictness, notifications, query_timeout, transaction_timeout -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct Config { pub(crate) strict: bool, pub(crate) notifications: bool, diff --git a/lib/src/api/opt/endpoint/mod.rs b/lib/src/api/opt/endpoint/mod.rs index 6099d621..727c82e1 100644 --- a/lib/src/api/opt/endpoint/mod.rs +++ b/lib/src/api/opt/endpoint/mod.rs @@ -39,17 +39,54 @@ pub trait IntoEndpoint { fn into_endpoint(self) -> Result; } -pub(crate) fn replace_tilde(path: &str) -> String { - let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned()); - path.replacen("://~", &format!("://{home}"), 1).replacen(":~", &format!(":{home}"), 1) +fn replace_tilde(path: &str) -> String { + if path.starts_with("~/") { + let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned()); + path.replacen("~/", &format!("{home}/"), 1) + } else if path.starts_with("~\\") { + let home = std::env::var("HOMEPATH").unwrap_or_else(|_| ".".to_owned()); + path.replacen("~\\", &format!("{home}\\"), 1) + } else { + path.to_owned() + } } #[allow(dead_code)] -fn path_to_string(protocol: &str, path: impl AsRef) -> String { +pub(crate) fn path_to_string(protocol: &str, path: impl AsRef) -> String { use path_clean::PathClean; use std::path::Path; - let path = format!("{protocol}{}", path.as_ref().display()); + let path = path.as_ref().display().to_string(); let expanded = replace_tilde(&path); - Path::new(&expanded).clean().display().to_string() + let cleaned = Path::new(&expanded).clean(); + format!("{protocol}{}", cleaned.display()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_to_string() { + let paths = [ + // Unix-like paths + "path/to/db", + "/path/to/db", + // Windows paths + "path\\to\\db", + "\\path\\to\\db", + "c:path\\to\\db", + "c:\\path\\to\\db", + ]; + + let scheme = "scheme://"; + + for path in paths { + let expanded = replace_tilde(path); + assert_eq!(expanded, path, "failed to replace `{path}`"); + + let converted = path_to_string(scheme, path); + assert_eq!(converted, format!("{scheme}{path}"), "failed to convert `{path}`"); + } + } } diff --git a/lib/src/api/opt/tls.rs b/lib/src/api/opt/tls.rs index fde1f1a7..ca3f4a9e 100644 --- a/lib/src/api/opt/tls.rs +++ b/lib/src/api/opt/tls.rs @@ -1,7 +1,7 @@ /// TLS Configuration #[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls"))))] -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Tls { /// Native TLS configuration #[cfg(feature = "native-tls")] diff --git a/lib/tests/api.rs b/lib/tests/api.rs index 3e0f8f10..51322eaf 100644 --- a/lib/tests/api.rs +++ b/lib/tests/api.rs @@ -96,6 +96,12 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + surrealdb::engine::any::connect("ws://127.0.0.1:8000").await.unwrap(); + } + include!("api/mod.rs"); } @@ -119,6 +125,12 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + surrealdb::engine::any::connect("http://127.0.0.1:8000").await.unwrap(); + } + include!("api/mod.rs"); include!("api/backup.rs"); } @@ -149,7 +161,14 @@ mod api_integration { #[tokio::test] async fn memory_allowed_as_address() { init_logger(); - any::connect("memory").await.unwrap(); + surrealdb::engine::any::connect("memory").await.unwrap(); + } + + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + surrealdb::engine::any::connect("mem://").await.unwrap(); + surrealdb::engine::any::connect("memory").await.unwrap(); } #[tokio::test] @@ -240,6 +259,14 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + let path = Ulid::new(); + surrealdb::engine::any::connect(format!("file://{path}.db")).await.unwrap(); + surrealdb::engine::any::connect(format!("file:///tmp/{path}.db")).await.unwrap(); + } + include!("api/mod.rs"); include!("api/backup.rs"); } @@ -268,6 +295,14 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + let path = Ulid::new(); + surrealdb::engine::any::connect(format!("rocksdb://{path}.db")).await.unwrap(); + surrealdb::engine::any::connect(format!("rocksdb:///tmp/{path}.db")).await.unwrap(); + } + include!("api/mod.rs"); include!("api/backup.rs"); } @@ -296,6 +331,14 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + let path = Ulid::new(); + surrealdb::engine::any::connect(format!("speedb://{path}.db")).await.unwrap(); + surrealdb::engine::any::connect(format!("speedb:///tmp/{path}.db")).await.unwrap(); + } + include!("api/mod.rs"); include!("api/backup.rs"); } @@ -323,6 +366,12 @@ mod api_integration { db } + #[tokio::test] + async fn any_engine_can_connect() { + init_logger(); + surrealdb::engine::any::connect("tikv://127.0.0.1:2379").await.unwrap(); + } + include!("api/mod.rs"); include!("api/backup.rs"); } @@ -345,7 +394,11 @@ mod api_integration { .user(root) .tick_interval(TICK_INTERVAL) .capabilities(Capabilities::all()); - let db = Surreal::new::(("/etc/foundationdb/fdb.cluster", config)).await.unwrap(); + let path = "/etc/foundationdb/fdb.cluster"; + surrealdb::engine::any::connect((format!("fdb://{path}"), config.clone())) + .await + .unwrap(); + let db = Surreal::new::((path, config)).await.unwrap(); db.signin(root).await.unwrap(); db }