Fix local engine endpoints (#2813)

Co-authored-by: Rushmore Mushambi <rushmore@surrealdb.com>
This commit is contained in:
Djole 2023-10-10 08:01:21 +02:00 committed by GitHub
parent 9594683129
commit 9d9fde2db8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 20 deletions

View file

@ -95,10 +95,8 @@ use crate::api::opt::Endpoint;
use crate::api::Connect;
use crate::api::Result;
use crate::api::Surreal;
use crate::opt::replace_tilde;
use path_clean::PathClean;
use crate::opt::path_to_string;
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use std::sync::OnceLock;
use url::Url;
@ -109,20 +107,30 @@ pub trait IntoEndpoint {
fn into_endpoint(self) -> Result<Endpoint>;
}
fn split_url(url: &str) -> (&str, &str) {
match url.split_once("://") {
Some(parts) => parts,
None => match url.split_once(':') {
Some(parts) => parts,
None => (url, ""),
},
}
}
impl IntoEndpoint for &str {
fn into_endpoint(self) -> Result<Endpoint> {
let (url, path) = match self {
"memory" | "mem://" => (Url::parse("mem://").unwrap(), "memory".to_owned()),
url if url.starts_with("ws") | url.starts_with("http") => {
url if url.starts_with("ws") | url.starts_with("http") | url.starts_with("tikv") => {
(Url::parse(url).map_err(|_| Error::InvalidUrl(self.to_owned()))?, String::new())
}
_ => {
let (scheme, _) = self.split_once(':').unwrap_or((self, ""));
let path = replace_tilde(self);
let (scheme, path) = split_url(self);
let protocol = format!("{scheme}://");
(
Url::parse(&format!("{scheme}://"))
.map_err(|_| Error::InvalidUrl(self.to_owned()))?,
Path::new(&path).clean().display().to_string(),
Url::parse(&protocol).map_err(|_| Error::InvalidUrl(self.to_owned()))?,
path_to_string(&protocol, path),
)
}
};

View file

@ -102,7 +102,12 @@ pub(crate) fn router(
_ => None,
};
let kvs = match Datastore::new(&address.path).await {
let endpoint = match address.url.scheme() {
"tikv" => address.url.as_str(),
_ => &address.path,
};
let kvs = match Datastore::new(endpoint).await {
Ok(kvs) => {
// If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root {

View file

@ -2,7 +2,7 @@ use crate::{dbs::Capabilities, iam::Level};
use std::time::Duration;
/// Configuration for server connection, including: strictness, notifications, query_timeout, transaction_timeout
#[derive(Debug, Default)]
#[derive(Debug, Clone, Default)]
pub struct Config {
pub(crate) strict: bool,
pub(crate) notifications: bool,

View file

@ -39,17 +39,54 @@ pub trait IntoEndpoint<Scheme> {
fn into_endpoint(self) -> Result<Endpoint>;
}
pub(crate) fn replace_tilde(path: &str) -> String {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
path.replacen("://~", &format!("://{home}"), 1).replacen(":~", &format!(":{home}"), 1)
fn replace_tilde(path: &str) -> String {
if path.starts_with("~/") {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
path.replacen("~/", &format!("{home}/"), 1)
} else if path.starts_with("~\\") {
let home = std::env::var("HOMEPATH").unwrap_or_else(|_| ".".to_owned());
path.replacen("~\\", &format!("{home}\\"), 1)
} else {
path.to_owned()
}
}
#[allow(dead_code)]
fn path_to_string(protocol: &str, path: impl AsRef<std::path::Path>) -> String {
pub(crate) fn path_to_string(protocol: &str, path: impl AsRef<std::path::Path>) -> String {
use path_clean::PathClean;
use std::path::Path;
let path = format!("{protocol}{}", path.as_ref().display());
let path = path.as_ref().display().to_string();
let expanded = replace_tilde(&path);
Path::new(&expanded).clean().display().to_string()
let cleaned = Path::new(&expanded).clean();
format!("{protocol}{}", cleaned.display())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_path_to_string() {
let paths = [
// Unix-like paths
"path/to/db",
"/path/to/db",
// Windows paths
"path\\to\\db",
"\\path\\to\\db",
"c:path\\to\\db",
"c:\\path\\to\\db",
];
let scheme = "scheme://";
for path in paths {
let expanded = replace_tilde(path);
assert_eq!(expanded, path, "failed to replace `{path}`");
let converted = path_to_string(scheme, path);
assert_eq!(converted, format!("{scheme}{path}"), "failed to convert `{path}`");
}
}
}

View file

@ -1,7 +1,7 @@
/// TLS Configuration
#[cfg(any(feature = "native-tls", feature = "rustls"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls"))))]
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Tls {
/// Native TLS configuration
#[cfg(feature = "native-tls")]

View file

@ -96,6 +96,12 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("ws://127.0.0.1:8000").await.unwrap();
}
include!("api/mod.rs");
}
@ -119,6 +125,12 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("http://127.0.0.1:8000").await.unwrap();
}
include!("api/mod.rs");
include!("api/backup.rs");
}
@ -149,7 +161,14 @@ mod api_integration {
#[tokio::test]
async fn memory_allowed_as_address() {
init_logger();
any::connect("memory").await.unwrap();
surrealdb::engine::any::connect("memory").await.unwrap();
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("mem://").await.unwrap();
surrealdb::engine::any::connect("memory").await.unwrap();
}
#[tokio::test]
@ -240,6 +259,14 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("file://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("file:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs");
include!("api/backup.rs");
}
@ -268,6 +295,14 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("rocksdb://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("rocksdb:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs");
include!("api/backup.rs");
}
@ -296,6 +331,14 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
let path = Ulid::new();
surrealdb::engine::any::connect(format!("speedb://{path}.db")).await.unwrap();
surrealdb::engine::any::connect(format!("speedb:///tmp/{path}.db")).await.unwrap();
}
include!("api/mod.rs");
include!("api/backup.rs");
}
@ -323,6 +366,12 @@ mod api_integration {
db
}
#[tokio::test]
async fn any_engine_can_connect() {
init_logger();
surrealdb::engine::any::connect("tikv://127.0.0.1:2379").await.unwrap();
}
include!("api/mod.rs");
include!("api/backup.rs");
}
@ -345,7 +394,11 @@ mod api_integration {
.user(root)
.tick_interval(TICK_INTERVAL)
.capabilities(Capabilities::all());
let db = Surreal::new::<FDb>(("/etc/foundationdb/fdb.cluster", config)).await.unwrap();
let path = "/etc/foundationdb/fdb.cluster";
surrealdb::engine::any::connect((format!("fdb://{path}"), config.clone()))
.await
.unwrap();
let db = Surreal::new::<FDb>((path, config)).await.unwrap();
db.signin(root).await.unwrap();
db
}