Support switching namespaces and databases separately ()

This commit is contained in:
Rushmore Mushambi 2023-05-05 20:12:19 +02:00 committed by GitHub
parent 0c752b43e9
commit 107e5b5dba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 120 additions and 64 deletions
lib
src/api
engine
local
remote/http
method
tests/api
src/cli

View file

@ -378,14 +378,19 @@ async fn router(
match method {
Method::Use => {
let (ns, db) = match &mut params[..] {
match &mut params[..] {
[Value::Strand(Strand(ns)), Value::Strand(Strand(db))] => {
(mem::take(ns), mem::take(db))
session.ns = Some(mem::take(ns));
session.db = Some(mem::take(db));
}
[Value::Strand(Strand(ns)), Value::None] => {
session.ns = Some(mem::take(ns));
}
[Value::None, Value::Strand(Strand(db))] => {
session.db = Some(mem::take(db));
}
_ => unreachable!(),
};
session.ns = Some(ns);
session.db = Some(db);
}
Ok(DbResponse::Other(Value::None))
}
Method::Signin | Method::Signup | Method::Authenticate | Method::Invalidate => {

View file

@ -354,34 +354,47 @@ async fn router(
match method {
Method::Use => {
let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone());
let (ns, db) = match &mut params[..] {
[Value::Strand(Strand(ns)), Value::Strand(Strand(db))] => {
(mem::take(ns), mem::take(db))
(Some(mem::take(ns)), Some(mem::take(db)))
}
[Value::Strand(Strand(ns)), Value::None] => (Some(mem::take(ns)), None),
[Value::None, Value::Strand(Strand(db))] => (None, Some(mem::take(db))),
_ => unreachable!(),
};
let ns = match HeaderValue::try_from(&ns) {
Ok(ns) => ns,
Err(_) => {
return Err(Error::InvalidNsName(ns).into());
}
let ns = match ns {
Some(ns) => match HeaderValue::try_from(&ns) {
Ok(ns) => {
request = request.header("NS", &ns);
Some(ns)
}
Err(_) => {
return Err(Error::InvalidNsName(ns).into());
}
},
None => None,
};
let db = match HeaderValue::try_from(&db) {
Ok(db) => db,
Err(_) => {
return Err(Error::InvalidDbName(db).into());
}
let db = match db {
Some(db) => match HeaderValue::try_from(&db) {
Ok(db) => {
request = request.header("DB", &db);
Some(db)
}
Err(_) => {
return Err(Error::InvalidDbName(db).into());
}
},
None => None,
};
let request = client
.post(path)
.headers(headers.clone())
.header("NS", &ns)
.header("DB", &db)
.auth(auth)
.body("RETURN true");
request = request.auth(auth).body("RETURN true");
take(true, request).await?;
headers.insert("NS", ns);
headers.insert("DB", db);
if let Some(ns) = ns {
headers.insert("NS", ns);
}
if let Some(db) = db {
headers.insert("DB", db);
}
Ok(DbResponse::Other(Value::None))
}
Method::Signin => {

View file

@ -23,6 +23,7 @@ mod signin;
mod signup;
mod unset;
mod update;
mod use_db;
mod use_ns;
mod version;
@ -58,8 +59,8 @@ pub use signin::Signin;
pub use signup::Signup;
pub use unset::Unset;
pub use update::Update;
pub use use_db::UseDb;
pub use use_ns::UseNs;
pub use use_ns::UseNsDb;
pub use version::Version;
use crate::api::conn::Method;
@ -74,6 +75,7 @@ use crate::api::ExtractRouter;
use crate::api::Surreal;
use crate::sql::to_value;
use crate::sql::Uuid;
use crate::sql::Value;
use once_cell::sync::OnceCell;
use serde::Serialize;
use std::marker::PhantomData;
@ -252,7 +254,7 @@ where
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// db.use_ns("namespace").use_db("database").await?;
/// db.use_ns("namespace").await?;
/// # Ok(())
/// # }
/// ```
@ -263,6 +265,26 @@ where
}
}
/// Switch to a specific database
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// db.use_db("database").await?;
/// # Ok(())
/// # }
/// ```
pub fn use_db(&self, db: impl Into<String>) -> UseDb<C> {
UseDb {
router: self.router.extract(),
ns: Value::None,
db: db.into(),
}
}
/// Assigns a value as a parameter for this connection
///
/// # Examples

View file

@ -1,32 +1,32 @@
use crate::api::conn::Method;
use std::future::Future;
use std::pin::Pin;
use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::Connection;
use crate::api::Result;
use crate::api::conn::Router;
use std::future::IntoFuture;
use crate::sql::Value;
use std::future::Future;
use std::future::IntoFuture;
use std::pin::Pin;
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct UseDb<'r, C: Connection> {
pub(super) router: Result<&'r Router<C>>,
pub(super) db: String,
pub(super) router: Result<&'r Router<C>>,
pub(super) ns: Value,
pub(super) db: String,
}
impl<'r, Client> IntoFuture for UseDb<'r, Client>
where
Client: Connection,
Client: Connection,
{
type Output = Result<()>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
type Output = Result<()>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let mut conn = Client::new(Method::Use);
conn.execute_unit(self.router?, Param::new(vec![Value::None, self.db.into()]))
.await
})
}
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let mut conn = Client::new(Method::Use);
conn.execute_unit(self.router?, Param::new(vec![self.ns, self.db.into()])).await
})
}
}

View file

@ -1,8 +1,10 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::method::UseDb;
use crate::api::Connection;
use crate::api::Result;
use crate::sql::Value;
use std::future::Future;
use std::future::IntoFuture;
use std::pin::Pin;
@ -15,30 +17,21 @@ pub struct UseNs<'r, C: Connection> {
pub(super) ns: String,
}
/// A use NS and DB future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct UseNsDb<'r, C: Connection> {
pub(super) router: Result<&'r Router<C>>,
pub(super) ns: String,
pub(super) db: String,
}
impl<'r, C> UseNs<'r, C>
where
C: Connection,
{
/// Switch to a specific database
pub fn use_db(self, db: impl Into<String>) -> UseNsDb<'r, C> {
UseNsDb {
pub fn use_db(self, db: impl Into<String>) -> UseDb<'r, C> {
UseDb {
ns: self.ns.into(),
db: db.into(),
ns: self.ns,
router: self.router,
}
}
}
impl<'r, Client> IntoFuture for UseNsDb<'r, Client>
impl<'r, Client> IntoFuture for UseNs<'r, Client>
where
Client: Connection,
{
@ -48,7 +41,7 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let mut conn = Client::new(Method::Use);
conn.execute_unit(self.router?, Param::new(vec![self.ns.into(), self.db.into()])).await
conn.execute_unit(self.router?, Param::new(vec![self.ns.into(), Value::None])).await
})
}
}

View file

@ -10,15 +10,23 @@ async fn connect() {
async fn yuse() {
let db = new_db().await;
let item = Ulid::new().to_string();
let error = db.create::<Vec<()>>(item.as_str()).await.unwrap_err();
match error {
match db.create(Resource::from(item.as_str())).await.unwrap_err() {
// Local engines return this error
Error::Db(DbError::NsEmpty) => {}
// Remote engines return this error
Error::Api(ApiError::Query(error)) if error.contains("Specify a namespace to use") => {}
error => panic!("{:?}", error),
}
db.use_ns(NS).use_db(item).await.unwrap();
db.use_ns(NS).await.unwrap();
match db.create(Resource::from(item.as_str())).await.unwrap_err() {
// Local engines return this error
Error::Db(DbError::DbEmpty) => {}
// Remote engines return this error
Error::Api(ApiError::Query(error)) if error.contains("Specify a database to use") => {}
error => panic!("{:?}", error),
}
db.use_db(item.as_str()).await.unwrap();
db.create(Resource::from(item)).await.unwrap();
}
#[tokio::test]

View file

@ -47,13 +47,28 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
// Loop over each command-line input
loop {
// Use namespace / database if specified
if let (Some(namespace), Some(database)) = (&ns, &db) {
match client.use_ns(namespace).use_db(database).await {
match (&ns, &db) {
(Some(namespace), Some(database)) => {
match client.use_ns(namespace).use_db(database).await {
Ok(()) => {
prompt = format!("{namespace}/{database}> ");
}
Err(error) => eprintln!("{error}"),
}
}
(Some(namespace), None) => match client.use_ns(namespace).await {
Ok(()) => {
prompt = format!("{namespace}/{database}> ");
prompt = format!("{namespace}> ");
}
Err(error) => eprintln!("{error}"),
}
},
(None, Some(database)) => match client.use_db(database).await {
Ok(()) => {
prompt = format!("/{database}> ");
}
Err(error) => eprintln!("{error}"),
},
(None, None) => {}
}
// Prompt the user to input SQL
let readline = rl.readline(&prompt);