Fix local engine endpoints (#2813)

Co-authored-by: Rushmore Mushambi <rushmore@surrealdb.com>
This commit is contained in:
Djole 2023-10-10 08:01:21 +02:00 committed by GitHub
parent 9594683129
commit 9d9fde2db8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 20 deletions

View file

@ -95,10 +95,8 @@ use crate::api::opt::Endpoint;
use crate::api::Connect; use crate::api::Connect;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::opt::replace_tilde; use crate::opt::path_to_string;
use path_clean::PathClean;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
use url::Url; use url::Url;
@ -109,20 +107,30 @@ pub trait IntoEndpoint {
fn into_endpoint(self) -> Result<Endpoint>; fn into_endpoint(self) -> Result<Endpoint>;
} }
fn split_url(url: &str) -> (&str, &str) {
match url.split_once("://") {
Some(parts) => parts,
None => match url.split_once(':') {
Some(parts) => parts,
None => (url, ""),
},
}
}
impl IntoEndpoint for &str { impl IntoEndpoint for &str {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let (url, path) = match self { let (url, path) = match self {
"memory" | "mem://" => (Url::parse("mem://").unwrap(), "memory".to_owned()), "memory" | "mem://" => (Url::parse("mem://").unwrap(), "memory".to_owned()),
url if url.starts_with("ws") | url.starts_with("http") => { url if url.starts_with("ws") | url.starts_with("http") | url.starts_with("tikv") => {
(Url::parse(url).map_err(|_| Error::InvalidUrl(self.to_owned()))?, String::new()) (Url::parse(url).map_err(|_| Error::InvalidUrl(self.to_owned()))?, String::new())
} }
_ => { _ => {
let (scheme, _) = self.split_once(':').unwrap_or((self, "")); let (scheme, path) = split_url(self);
let path = replace_tilde(self); let protocol = format!("{scheme}://");
( (
Url::parse(&format!("{scheme}://")) Url::parse(&protocol).map_err(|_| Error::InvalidUrl(self.to_owned()))?,
.map_err(|_| Error::InvalidUrl(self.to_owned()))?, path_to_string(&protocol, path),
Path::new(&path).clean().display().to_string(),
) )
} }
}; };

View file

@ -102,7 +102,12 @@ pub(crate) fn router(
_ => None, _ => None,
}; };
let kvs = match Datastore::new(&address.path).await { let endpoint = match address.url.scheme() {
"tikv" => address.url.as_str(),
_ => &address.path,
};
let kvs = match Datastore::new(endpoint).await {
Ok(kvs) => { Ok(kvs) => {
// If a root user is specified, setup the initial datastore credentials // If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root { if let Some(root) = configured_root {

View file

@ -2,7 +2,7 @@ use crate::{dbs::Capabilities, iam::Level};
use std::time::Duration; use std::time::Duration;
/// Configuration for server connection, including: strictness, notifications, query_timeout, transaction_timeout /// Configuration for server connection, including: strictness, notifications, query_timeout, transaction_timeout
#[derive(Debug, Default)] #[derive(Debug, Clone, Default)]
pub struct Config { pub struct Config {
pub(crate) strict: bool, pub(crate) strict: bool,
pub(crate) notifications: bool, pub(crate) notifications: bool,

View file

@ -39,17 +39,54 @@ pub trait IntoEndpoint<Scheme> {
fn into_endpoint(self) -> Result<Endpoint>; fn into_endpoint(self) -> Result<Endpoint>;
} }
pub(crate) fn replace_tilde(path: &str) -> String { fn replace_tilde(path: &str) -> String {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned()); if path.starts_with("~/") {
path.replacen("://~", &format!("://{home}"), 1).replacen(":~", &format!(":{home}"), 1) let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
path.replacen("~/", &format!("{home}/"), 1)
} else if path.starts_with("~\\") {
let home = std::env::var("HOMEPATH").unwrap_or_else(|_| ".".to_owned());
path.replacen("~\\", &format!("{home}\\"), 1)
} else {
path.to_owned()
}
} }
#[allow(dead_code)] #[allow(dead_code)]
fn path_to_string(protocol: &str, path: impl AsRef<std::path::Path>) -> String { pub(crate) fn path_to_string(protocol: &str, path: impl AsRef<std::path::Path>) -> String {
use path_clean::PathClean; use path_clean::PathClean;
use std::path::Path; use std::path::Path;
let path = format!("{protocol}{}", path.as_ref().display()); let path = path.as_ref().display().to_string();
let expanded = replace_tilde(&path); let expanded = replace_tilde(&path);
Path::new(&expanded).clean().display().to_string() let cleaned = Path::new(&expanded).clean();
format!("{protocol}{}", cleaned.display())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_path_to_string() {
let paths = [
// Unix-like paths
"path/to/db",
"/path/to/db",
// Windows paths
"path\\to\\db",
"\\path\\to\\db",
"c:path\\to\\db",
"c:\\path\\to\\db",
];
let scheme = "scheme://";
for path in paths {
let expanded = replace_tilde(path);
assert_eq!(expanded, path, "failed to replace `{path}`");
let converted = path_to_string(scheme, path);
assert_eq!(converted, format!("{scheme}{path}"), "failed to convert `{path}`");
}
}
} }

View file

@ -1,7 +1,7 @@
/// TLS Configuration /// TLS Configuration
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls"))))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls"))))]
#[derive(Debug)] #[derive(Debug, Clone)]
pub enum Tls { pub enum Tls {
/// Native TLS configuration /// Native TLS configuration
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]

View file

@ -96,6 +96,12 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("ws://127.0.0.1:8000").await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
} }
@ -119,6 +125,12 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("http://127.0.0.1:8000").await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
include!("api/backup.rs"); include!("api/backup.rs");
} }
@ -149,7 +161,14 @@ mod api_integration {
#[tokio::test] #[tokio::test]
async fn memory_allowed_as_address() { async fn memory_allowed_as_address() {
init_logger(); init_logger();
any::connect("memory").await.unwrap(); surrealdb::engine::any::connect("memory").await.unwrap();
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("mem://").await.unwrap();
surrealdb::engine::any::connect("memory").await.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -240,6 +259,14 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("file://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("file:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
include!("api/backup.rs"); include!("api/backup.rs");
} }
@ -268,6 +295,14 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("rocksdb://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("rocksdb:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
include!("api/backup.rs"); include!("api/backup.rs");
} }
@ -296,6 +331,14 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("speedb://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("speedb:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
include!("api/backup.rs"); include!("api/backup.rs");
} }
@ -323,6 +366,12 @@ mod api_integration {
db db
} }
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("tikv://127.0.0.1:2379").await.unwrap();
}
include!("api/mod.rs"); include!("api/mod.rs");
include!("api/backup.rs"); include!("api/backup.rs");
} }
@ -345,7 +394,11 @@ mod api_integration {
.user(root) .user(root)
.tick_interval(TICK_INTERVAL) .tick_interval(TICK_INTERVAL)
.capabilities(Capabilities::all()); .capabilities(Capabilities::all());
let db = Surreal::new::<FDb>(("/etc/foundationdb/fdb.cluster", config)).await.unwrap(); let path = "/etc/foundationdb/fdb.cluster";
surrealdb::engine::any::connect((format!("fdb://{path}"), config.clone()))
.await
.unwrap();
let db = Surreal::new::<FDb>((path, config)).await.unwrap();
db.signin(root).await.unwrap(); db.signin(root).await.unwrap();
db db
} }