Feature 1827 - query CLI line continuation (#1968)

This commit is contained in:
Finn Bear 2023-05-12 12:47:41 -07:00 committed by GitHub
parent db345a2ce7
commit 875f92415b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 55 deletions

12
Cargo.lock generated
View file

@ -3565,6 +3565,7 @@ dependencies = [
"memchr", "memchr",
"nix", "nix",
"radix_trie", "radix_trie",
"rustyline-derive",
"scopeguard", "scopeguard",
"unicode-segmentation", "unicode-segmentation",
"unicode-width", "unicode-width",
@ -3572,6 +3573,17 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "rustyline-derive"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8218eaf5d960e3c478a1b0f129fa888dd3d8d22eb3de097e9af14c1ab4438024"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.13" version = "1.0.13"

View file

@ -43,7 +43,7 @@ opentelemetry = { version = "0.18", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.16", features = ["blocking"] } reqwest = { version = "0.11.16", features = ["blocking"] }
rustyline = "11.0.0" rustyline = { version = "11.0.0", features = ["derive"] }
serde = { version = "1.0.160", features = ["derive"] } serde = { version = "1.0.160", features = ["derive"] }
serde_cbor = "0.11.2" serde_cbor = "0.11.2"
serde_pack = { version = "1.1.1", package = "rmp-serde" } serde_pack = { version = "1.1.1", package = "rmp-serde" }

View file

@ -34,7 +34,7 @@ serve:
.PHONY: sql .PHONY: sql
sql: sql:
cargo run $(DEV_FEATURES) -- sql --conn ws://0.0.0.0:8000 --user root --pass root --ns test --db test --pretty cargo run $(DEV_FEATURES) -- sql --conn ws://0.0.0.0:8000 --user root --pass root --ns test --db test --multi --pretty
.PHONY: quick .PHONY: quick
quick: quick:

View file

@ -472,6 +472,13 @@ pub fn init() -> ExitCode {
.default_value("root") .default_value("root")
.help("Database authentication username to use when connecting"), .help("Database authentication username to use when connecting"),
) )
.arg(
Arg::new("multi")
.long("multi")
.required(false)
.takes_value(false)
.help("Whether omitting semicolon causes a newline"),
)
.arg( .arg(
Arg::new("pass") Arg::new("pass")
.short('p') .short('p')

View file

@ -1,14 +1,12 @@
use crate::err::Error; use crate::err::Error;
use rustyline::error::ReadlineError; use rustyline::error::ReadlineError;
use rustyline::DefaultEditor; use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{Completer, Editor, Helper, Highlighter, Hinter};
use surrealdb::engine::any::connect; use surrealdb::engine::any::connect;
use surrealdb::error::Api as ApiError; use surrealdb::error::Api as ApiError;
use surrealdb::opt::auth::Root; use surrealdb::opt::auth::Root;
use surrealdb::sql; use surrealdb::sql::{self, Statement, Value};
use surrealdb::sql::Statement; use surrealdb::{Error as SurrealError, Response};
use surrealdb::sql::Value;
use surrealdb::Error as SurrealError;
use surrealdb::Response;
#[tokio::main] #[tokio::main]
pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> { pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
@ -39,7 +37,11 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
} }
} }
// Create a new terminal REPL // Create a new terminal REPL
let mut rl = DefaultEditor::new().unwrap(); let mut rl = Editor::new().unwrap();
// Set custom input validation
rl.set_helper(Some(InputValidator {
multi: matches.is_present("multi"),
}));
// Load the command-line history // Load the command-line history
let _ = rl.load_history("history.txt"); let _ = rl.load_history("history.txt");
// Configure the prompt // Configure the prompt
@ -70,12 +72,12 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
}, },
(None, None) => {} (None, None) => {}
} }
// Prompt the user to input SQL
let readline = rl.readline(&prompt); // Prompt the user to input SQL and check the input.
// Check the user input let line = match rl.readline(&prompt) {
match readline {
// The user typed a query // The user typed a query
Ok(line) => { Ok(line) => {
let line = filter_line_continuations(&line);
// Ignore all empty lines // Ignore all empty lines
if line.is_empty() { if line.is_empty() {
continue; continue;
@ -84,49 +86,10 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
if let Err(e) = rl.add_history_entry(line.as_str()) { if let Err(e) = rl.add_history_entry(line.as_str()) {
eprintln!("{e}"); eprintln!("{e}");
} }
// Complete the request line
match sql::parse(&line) {
Ok(query) => {
for statement in query.iter() {
match statement {
Statement::Use(stmt) => {
if let Some(namespace) = &stmt.ns {
ns = Some(namespace.clone());
}
if let Some(database) = &stmt.db {
db = Some(database.clone());
}
}
Statement::Set(stmt) => {
if let Err(e) = client.set(&stmt.name, &stmt.what).await {
eprintln!("{e}\n");
}
}
_ => {}
}
}
let res = client.query(query).await;
// Get the request response
match process(pretty, res) {
Ok(v) => {
println!("{v}\n");
}
Err(e) => {
eprintln!("{e}\n");
}
}
}
Err(e) => {
eprintln!("{e}\n");
}
}
} }
// The user types CTRL-C // The user typed CTRL-C or CTRL-D
Err(ReadlineError::Interrupted) => { Err(ReadlineError::Interrupted | ReadlineError::Eof) => {
break;
}
// The user typed CTRL-D
Err(ReadlineError::Eof) => {
break; break;
} }
// There was en error // There was en error
@ -134,6 +97,43 @@ pub async fn init(matches: &clap::ArgMatches) -> Result<(), Error> {
eprintln!("Error: {e:?}"); eprintln!("Error: {e:?}");
break; break;
} }
};
// Complete the request
match sql::parse(&line) {
Ok(query) => {
for statement in query.iter() {
match statement {
Statement::Use(stmt) => {
if let Some(namespace) = &stmt.ns {
ns = Some(namespace.clone());
}
if let Some(database) = &stmt.db {
db = Some(database.clone());
}
}
Statement::Set(stmt) => {
if let Err(e) = client.set(&stmt.name, &stmt.what).await {
eprintln!("{e}\n");
}
}
_ => {}
}
}
let res = client.query(query).await;
// Get the request response
match process(pretty, res) {
Ok(v) => {
println!("{v}\n");
}
Err(e) => {
eprintln!("{e}\n");
}
}
}
Err(e) => {
eprintln!("{e}\n");
}
} }
} }
// Save the inputs to the history // Save the inputs to the history
@ -165,3 +165,31 @@ fn process(pretty: bool, res: surrealdb::Result<Response>) -> Result<String, Err
true => format!("{value:#}"), true => format!("{value:#}"),
}) })
} }
#[derive(Completer, Helper, Highlighter, Hinter)]
struct InputValidator {
/// If omitting semicolon causes newline.
multi: bool,
}
impl Validator for InputValidator {
fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result<ValidationResult> {
use ValidationResult::{Incomplete, Invalid, Valid};
let input = filter_line_continuations(ctx.input());
let result = if (self.multi && !input.trim().ends_with(';'))
|| input.ends_with('\\')
|| input.is_empty()
{
Incomplete
} else if let Err(e) = sql::parse(&input) {
Invalid(Some(format!(" --< {e}")))
} else {
Valid(None)
};
Ok(result)
}
}
fn filter_line_continuations(line: &str) -> String {
line.replace("\\\n", "").replace("\\\r\n", "")
}