From 875f92415b9da5306f6db4348e260d3535a6ef4d Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Fri, 12 May 2023 12:47:41 -0700 Subject: [PATCH] Feature 1827 - query CLI line continuation (#1968) --- Cargo.lock | 12 +++++ Cargo.toml | 2 +- Makefile | 2 +- src/cli/mod.rs | 7 +++ src/cli/sql.rs | 134 ++++++++++++++++++++++++++++++------------------- 5 files changed, 102 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fc1d9dbf..f1163cbf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3565,6 +3565,7 @@ dependencies = [ "memchr", "nix", "radix_trie", + "rustyline-derive", "scopeguard", "unicode-segmentation", "unicode-width", @@ -3572,6 +3573,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustyline-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8218eaf5d960e3c478a1b0f129fa888dd3d8d22eb3de097e9af14c1ab4438024" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ryu" version = "1.0.13" diff --git a/Cargo.toml b/Cargo.toml index b7fd358c..d3178ed2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ opentelemetry = { version = "0.18", features = ["rt-tokio"] } opentelemetry-otlp = "0.11.0" rand = "0.8.5" reqwest = { version = "0.11.16", features = ["blocking"] } -rustyline = "11.0.0" +rustyline = { version = "11.0.0", features = ["derive"] } serde = { version = "1.0.160", features = ["derive"] } serde_cbor = "0.11.2" serde_pack = { version = "1.1.1", package = "rmp-serde" } diff --git a/Makefile b/Makefile index 04030fc5..fb9a69c8 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ serve: .PHONY: sql sql: - cargo run $(DEV_FEATURES) -- sql --conn ws://0.0.0.0:8000 --user root --pass root --ns test --db test --pretty + cargo run $(DEV_FEATURES) -- sql --conn ws://0.0.0.0:8000 --user root --pass root --ns test --db test --multi --pretty .PHONY: quick quick: diff --git a/src/cli/mod.rs b/src/cli/mod.rs index e3775301..9054233b 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -472,6 +472,13 @@ pub fn init() -> ExitCode { .default_value("root") .help("Database authentication username to use when connecting"), ) + .arg( + Arg::new("multi") + .long("multi") + .required(false) + .takes_value(false) + .help("Whether omitting semicolon causes a newline"), + ) .arg( Arg::new("pass") .short('p') diff --git a/src/cli/sql.rs b/src/cli/sql.rs index 126e0726..cae2042d 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -1,14 +1,12 @@ use crate::err::Error; use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; +use rustyline::validate::{ValidationContext, ValidationResult, Validator}; +use rustyline::{Completer, Editor, Helper, Highlighter, Hinter}; use surrealdb::engine::any::connect; use surrealdb::error::Api as ApiError; use surrealdb::opt::auth::Root; -use surrealdb::sql; -use surrealdb::sql::Statement; -use surrealdb::sql::Value; -use surrealdb::Error as SurrealError; -use surrealdb::Response; +use surrealdb::sql::{self, Statement, Value}; +use surrealdb::{Error as SurrealError, Response}; #[tokio::main] pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { @@ -39,7 +37,11 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { } } // Create a new terminal REPL - let mut rl = DefaultEditor::new().unwrap(); + let mut rl = Editor::new().unwrap(); + // Set custom input validation + rl.set_helper(Some(InputValidator { + multi: matches.is_present("multi"), + })); // Load the command-line history let _ = rl.load_history("history.txt"); // Configure the prompt @@ -70,12 +72,12 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { }, (None, None) => {} } - // Prompt the user to input SQL - let readline = rl.readline(&prompt); - // Check the user input - match readline { + + // Prompt the user to input SQL and check the input. + let line = match rl.readline(&prompt) { // The user typed a query Ok(line) => { + let line = filter_line_continuations(&line); // Ignore all empty lines if line.is_empty() { continue; @@ -84,49 +86,10 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { if let Err(e) = rl.add_history_entry(line.as_str()) { eprintln!("{e}"); } - // Complete the request - match sql::parse(&line) { - Ok(query) => { - for statement in query.iter() { - match statement { - Statement::Use(stmt) => { - if let Some(namespace) = &stmt.ns { - ns = Some(namespace.clone()); - } - if let Some(database) = &stmt.db { - db = Some(database.clone()); - } - } - Statement::Set(stmt) => { - if let Err(e) = client.set(&stmt.name, &stmt.what).await { - eprintln!("{e}\n"); - } - } - _ => {} - } - } - let res = client.query(query).await; - // Get the request response - match process(pretty, res) { - Ok(v) => { - println!("{v}\n"); - } - Err(e) => { - eprintln!("{e}\n"); - } - } - } - Err(e) => { - eprintln!("{e}\n"); - } - } + line } - // The user types CTRL-C - Err(ReadlineError::Interrupted) => { - break; - } - // The user typed CTRL-D - Err(ReadlineError::Eof) => { + // The user typed CTRL-C or CTRL-D + Err(ReadlineError::Interrupted | ReadlineError::Eof) => { break; } // There was en error @@ -134,6 +97,43 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { eprintln!("Error: {e:?}"); break; } + }; + + // Complete the request + match sql::parse(&line) { + Ok(query) => { + for statement in query.iter() { + match statement { + Statement::Use(stmt) => { + if let Some(namespace) = &stmt.ns { + ns = Some(namespace.clone()); + } + if let Some(database) = &stmt.db { + db = Some(database.clone()); + } + } + Statement::Set(stmt) => { + if let Err(e) = client.set(&stmt.name, &stmt.what).await { + eprintln!("{e}\n"); + } + } + _ => {} + } + } + let res = client.query(query).await; + // Get the request response + match process(pretty, res) { + Ok(v) => { + println!("{v}\n"); + } + Err(e) => { + eprintln!("{e}\n"); + } + } + } + Err(e) => { + eprintln!("{e}\n"); + } } } // Save the inputs to the history @@ -165,3 +165,31 @@ fn process(pretty: bool, res: surrealdb::Result) -> Result format!("{value:#}"), }) } + +#[derive(Completer, Helper, Highlighter, Hinter)] +struct InputValidator { + /// If omitting semicolon causes newline. + multi: bool, +} + +impl Validator for InputValidator { + fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result { + use ValidationResult::{Incomplete, Invalid, Valid}; + let input = filter_line_continuations(ctx.input()); + let result = if (self.multi && !input.trim().ends_with(';')) + || input.ends_with('\\') + || input.is_empty() + { + Incomplete + } else if let Err(e) = sql::parse(&input) { + Invalid(Some(format!(" --< {e}"))) + } else { + Valid(None) + }; + Ok(result) + } +} + +fn filter_line_continuations(line: &str) -> String { + line.replace("\\\n", "").replace("\\\r\n", "") +}