From 7c199ff586ee1834036ccde3a40bd2426354be13 Mon Sep 17 00:00:00 2001 From: Rushmore Mushambi Date: Fri, 30 Dec 2022 23:27:19 +0200 Subject: [PATCH] Use new client library in CLI `sql` command (#1561) --- src/cli/mod.rs | 12 +++-- src/cli/sql.rs | 136 +++++++++++++++++++++++++++++-------------------- src/err/mod.rs | 10 +++- 3 files changed, 98 insertions(+), 60 deletions(-) diff --git a/src/cli/mod.rs b/src/cli/mod.rs index c6fa9294..f910b09e 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -27,6 +27,10 @@ We would love it if you could star the repository (https://github.com/surrealdb/ ---------- "; +fn split_endpoint(v: &str) -> (&str, &str) { + v.split_once("://").unwrap_or_default() +} + fn file_valid(v: &str) -> Result<(), String> { match v { v if !v.is_empty() => Ok(()), @@ -54,9 +58,9 @@ fn path_valid(v: &str) -> Result<(), String> { } fn conn_valid(v: &str) -> Result<(), String> { - match v { - v if v.starts_with("http://") => Ok(()), - v if v.starts_with("https://") => Ok(()), + let scheme = split_endpoint(v).0; + match scheme { + "http" | "https" | "ws" | "wss" | "fdb" | "mem" | "rocksdb" | "file" | "tikv" => Ok(()), _ => Err(String::from( "\ Provide a valid database connection string\ @@ -459,7 +463,7 @@ pub fn init() { let matches = setup.get_matches(); let output = match matches.subcommand() { - Some(("sql", m)) => sql::init(m), + Some(("sql", m)) => futures::executor::block_on(sql::init(m)), Some(("start", m)) => start::init(m), Some(("backup", m)) => backup::init(m), Some(("import", m)) => import::init(m), diff --git a/src/cli/sql.rs b/src/cli/sql.rs index 43015c52..f23a6707 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -1,50 +1,62 @@ -use crate::cnf::SERVER_AGENT; use crate::err::Error; -use reqwest::blocking::Client; -use reqwest::blocking::Response; -use reqwest::header::ACCEPT; -use reqwest::header::USER_AGENT; use rustyline::error::ReadlineError; use rustyline::Editor; use serde_json::Value; +use surrealdb::engines::any::connect; +use surrealdb::error::Api as ApiError; +use surrealdb::opt::auth::Root; +use surrealdb::sql; +use surrealdb::sql::statements::SetStatement; +use surrealdb::sql::Statement; +use surrealdb::Error as SurrealError; +use surrealdb::Response; -pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { +pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { // Set the default logging level - crate::cli::log::init(3); + crate::cli::log::init(0); // Parse all other cli arguments - let user = matches.value_of("user").unwrap(); - let pass = matches.value_of("pass").unwrap(); - let conn = matches.value_of("conn").unwrap(); - let ns = matches.value_of("ns"); - let db = matches.value_of("db"); + let username = matches.value_of("user").unwrap(); + let password = matches.value_of("pass").unwrap(); + let endpoint = matches.value_of("conn").unwrap(); + let mut ns = matches.value_of("ns").map(str::to_string); + let mut db = matches.value_of("db").map(str::to_string); // If we should pretty-print responses let pretty = matches.is_present("pretty"); - // Set the correct import URL - let conn = format!("{conn}/sql"); // Make a new remote request - let res = Client::new() - .post(conn) - .header(USER_AGENT, SERVER_AGENT) - .header(ACCEPT, "application/json") - .basic_auth(user, Some(pass)); - // Add NS header if specified - let res = match ns { - Some(ns) => res.header("NS", ns), - None => res, - }; - // Add DB header if specified - let res = match db { - Some(db) => res.header("DB", db), - None => res, + let client = connect(endpoint).await?; + // Sign in to the server if the specified dabatabase engine supports it + let root = Root { + username, + password, }; + if let Err(error) = client.signin(root).await { + match error { + // Authentication not supported by this engine, we can safely continue + SurrealError::Api(ApiError::AuthNotSupported) => {} + error => { + return Err(error.into()); + } + } + } // Create a new terminal REPL let mut rl = Editor::<()>::new().unwrap(); // Load the command-line history let _ = rl.load_history("history.txt"); + // Configure the prompt + let mut prompt = "> ".to_owned(); // Loop over each command-line input loop { + // Use namespace / database if specified + if let (Some(namespace), Some(database)) = (&ns, &db) { + match client.use_ns(namespace).use_db(database).await { + Ok(()) => { + prompt = format!("{namespace}/{database}> "); + } + Err(error) => eprintln!("{error}"), + } + } // Prompt the user to input SQL - let readline = rl.readline("> "); + let readline = rl.readline(&prompt); // Check the user input match readline { // The user typed a query @@ -55,14 +67,38 @@ pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { } // Add the entry to the history rl.add_history_entry(line.as_str()); - // Clone the request infallibly - let res = res.try_clone().unwrap(); // Complete the request - let res = res.body(line).send(); - // Get the request response - match process(pretty, res) { - Ok(v) => println!("{v}"), - Err(e) => eprintln!("{e}"), + match sql::parse(&line) { + Ok(query) => { + for statement in query.iter() { + match statement { + Statement::Use(stmt) => { + if let Some(namespace) = &stmt.ns { + ns = Some(namespace.clone()); + } + if let Some(database) = &stmt.db { + db = Some(database.clone()); + } + } + Statement::Set(SetStatement { + name, + what, + }) => { + if let Err(error) = client.set(name, what).await { + eprintln!("{error}"); + } + } + _ => {} + } + } + let res = client.query(query).await; + // Get the request response + match process(pretty, res) { + Ok(v) => println!("{v}"), + Err(e) => eprintln!("{e}"), + } + } + Err(error) => eprintln!("{error}"), } } // The user types CTRL-C @@ -86,28 +122,20 @@ pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { Ok(()) } -fn process(pretty: bool, res: reqwest::Result) -> Result { +fn process(pretty: bool, res: surrealdb::Result) -> Result { // Catch any errors - let res = res?; - // Process the TEXT response - let res = res.text()?; + let values: Vec = res?.take(0)?; + let value = Value::Array(values); // Check if we should prettify match pretty { // Don't prettify the response - false => Ok(res), + false => Ok(value.to_string()), // Yes prettify the response - true => match res.is_empty() { - // This was an empty response - true => Ok(res), - // Let's make this response pretty - false => { - // Parse the JSON response - let res: Value = serde_json::from_str(&res)?; - // Pretty the JSON response - let res = serde_json::to_string_pretty(&res)?; - // Everything processed OK - Ok(res) - } - }, + true => { + // Pretty the JSON response + let res = serde_json::to_string_pretty(&value)?; + // Everything processed OK + Ok(res) + } } } diff --git a/src/err/mod.rs b/src/err/mod.rs index 71dfd55d..78ebddf1 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -6,7 +6,7 @@ use serde_json::error::Error as JsonError; use serde_pack::encode::Error as PackError; use std::io::Error as IoError; use std::string::FromUtf8Error as Utf8Error; -use surrealdb::err::Error as DbError; +use surrealdb::Error as SurrealError; use thiserror::Error; #[derive(Error, Debug)] @@ -30,7 +30,7 @@ pub enum Error { InvalidStorage, #[error("There was a problem with the database: {0}")] - Db(#[from] DbError), + Db(#[from] SurrealError), #[error("Couldn't open the specified file: {0}")] Io(#[from] IoError), @@ -73,3 +73,9 @@ impl From for Error { Error::InvalidAuth } } + +impl From for Error { + fn from(error: surrealdb::error::Db) -> Error { + Error::Db(error.into()) + } +}