use crate::cli::abstraction::{ AuthArguments, DatabaseConnectionArguments, DatabaseSelectionOptionalArguments, }; use crate::cnf::PKG_VERSION; use crate::err::Error; use clap::Args; use rustyline::error::ReadlineError; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Completer, Editor, Helper, Highlighter, Hinter}; use serde::Serialize; use serde_json::ser::PrettyFormatter; use surrealdb::dbs::Capabilities; use surrealdb::engine::any::connect; use surrealdb::opt::auth::Root; use surrealdb::opt::Config; use surrealdb::sql::{self, Statement, Value}; use surrealdb::Response; #[derive(Args, Debug)] pub struct SqlCommandArguments { #[command(flatten)] conn: DatabaseConnectionArguments, #[command(flatten)] auth: AuthArguments, #[command(flatten)] sel: Option, /// Whether database responses should be pretty printed #[arg(long)] pretty: bool, /// Whether to emit results in JSON #[arg(long)] json: bool, /// Whether omitting semicolon causes a newline #[arg(long)] multi: bool, /// Whether to show welcome message #[arg(long, env = "SURREAL_HIDE_WELCOME")] hide_welcome: bool, } pub async fn init( SqlCommandArguments { auth: AuthArguments { username, password, }, conn: DatabaseConnectionArguments { endpoint, }, sel, pretty, json, multi, hide_welcome, .. }: SqlCommandArguments, ) -> Result<(), Error> { // Initialize opentelemetry and logging crate::telemetry::builder().with_log_level("warn").init(); // Default datastore configuration for local engines let config = Config::new().capabilities(Capabilities::all()); let client = if let Some((username, password)) = username.zip(password) { let root = Root { username: &username, password: &password, }; // Connect to the database engine with authentication // // * For local engines, here we enable authentication and in the signin below we actually authenticate. // * For remote engines, we connect to the endpoint and then signin. #[cfg(feature = "has-storage")] let address = (endpoint, config.user(root)); #[cfg(not(feature = "has-storage"))] let address = endpoint; let client = connect(address).await?; // Sign in to the server client.signin(root).await?; client } else { connect((endpoint, config)).await? }; // Create a new terminal REPL let mut rl = Editor::new().unwrap(); // Set custom input validation rl.set_helper(Some(InputValidator { multi, })); // Load the command-line history let _ = rl.load_history("history.txt"); // Configure the prompt let mut prompt = "> ".to_owned(); // Keep track of current namespace/database. if let Some(DatabaseSelectionOptionalArguments { namespace, database, }) = sel { let is_not_empty = |s: &&str| !s.is_empty(); let namespace = namespace.as_deref().map(str::trim).filter(is_not_empty); let database = database.as_deref().map(str::trim).filter(is_not_empty); match (namespace, database) { (Some(namespace), Some(database)) => { client.use_ns(namespace).use_db(database).await?; prompt = format!("{namespace}/{database}> "); } (Some(namespace), None) => { client.use_ns(namespace).await?; prompt = format!("{namespace}> "); } _ => {} } }; if !hide_welcome { let hints = vec![ (true, "Different statements within a query should be separated by a (;) semicolon."), (!multi, "To create a multi-line query, end your lines with a (\\) backslash, and press enter."), (true, "To exit, send a SIGTERM or press CTRL+C") ] .iter() .filter(|(show, _)| *show) .map(|(_, hint)| format!("# - {hint}")) .collect::>() .join("\n"); eprintln!( " # # Welcome to the SurrealDB SQL shell # # How to use this shell: {hints} # # Consult https://surrealdb.com/docs/cli/sql for further instructions # # SurrealDB version: {} # ", *PKG_VERSION ); } // Loop over each command-line input loop { // Prompt the user to input SQL and check the input. let line = match rl.readline(&prompt) { // The user typed a query Ok(line) => { // Filter out all new lines let line = filter_line_continuations(&line); // Add the entry to the history if let Err(e) = rl.add_history_entry(line.as_str()) { eprintln!("{e}"); } line } // The user typed CTRL-C or CTRL-D Err(ReadlineError::Interrupted | ReadlineError::Eof) => { break; } // There was en error Err(e) => { eprintln!("Error: {e:?}"); break; } }; // Complete the request match sql::parse(&line) { Ok(query) => { let mut namespace = None; let mut database = None; let mut vars = Vec::new(); // Capture `use` and `set/let` statements from the query for statement in query.iter() { match statement { Statement::Use(stmt) => { if let Some(ns) = &stmt.ns { namespace = Some(ns.clone()); } if let Some(db) = &stmt.db { database = Some(db.clone()); } } Statement::Set(stmt) => { vars.push((stmt.name.clone(), stmt.what.clone())); } _ => {} } } // Extract the namespace and database from the current prompt let (prompt_ns, prompt_db) = split_prompt(&prompt); // The namespace should be set before the database can be set if namespace.is_none() && prompt_ns.is_empty() && database.is_some() { eprintln!( "There was a problem with the database: Specify a namespace to use\n" ); continue; } // Run the query provided let res = client.query(query).await; match process(pretty, json, res) { Ok(v) => { println!("{v}\n"); } Err(e) => { eprintln!("{e}\n"); continue; } } // Persist the variables extracted from the query for (key, value) in vars { let _ = client.set(key, value).await; } // Process the last `use` statements, if any if namespace.is_some() || database.is_some() { // Use the namespace provided in the query if any, otherwise use the one in the prompt let namespace = namespace.as_deref().unwrap_or(prompt_ns); // Use the database provided in the query if any, otherwise use the one in the prompt let database = database.as_deref().unwrap_or(prompt_db); // If the database is empty we should only use the namespace if database.is_empty() { if client.use_ns(namespace).await.is_ok() { prompt = format!("{namespace}> "); } } // Otherwise we should use both the namespace and database else if client.use_ns(namespace).use_db(database).await.is_ok() { prompt = format!("{namespace}/{database}> "); } } } Err(e) => { eprintln!("{e}\n"); } } } // Save the inputs to the history let _ = rl.save_history("history.txt"); // Everything OK Ok(()) } fn process(pretty: bool, json: bool, res: surrealdb::Result) -> Result { // Check query response for an error let mut response = res?; // Get the number of statements the query contained let num_statements = response.num_statements(); // Prepare a single value from the query response let mut output = Vec::::with_capacity(num_statements); for index in 0..num_statements { let result = response.take(index).unwrap_or_else(|e| e.to_string().into()); output.push(result); } // Check if we should emit JSON and/or prettify Ok(match (json, pretty) { // Don't prettify the SurrealQL response (false, false) => Value::from(output).to_string(), // Yes prettify the SurrealQL response (false, true) => output .iter() .enumerate() .map(|(i, v)| format!("-- Query {:?}\n{v:#}", i + 1)) .collect::>() .join("\n"), // Don't pretty print the JSON response (true, false) => serde_json::to_string(&Value::from(output).into_json()).unwrap(), // Yes prettify the JSON response (true, true) => output .iter() .enumerate() .map(|(i, v)| { let mut buf = Vec::new(); let mut serializer = serde_json::Serializer::with_formatter( &mut buf, PrettyFormatter::with_indent(b"\t"), ); v.clone().into_json().serialize(&mut serializer).unwrap(); let v = String::from_utf8(buf).unwrap(); format!("-- Query {:?}\n{v:#}", i + 1) }) .collect::>() .join("\n"), }) } #[derive(Completer, Helper, Highlighter, Hinter)] struct InputValidator { /// If omitting semicolon causes newline. multi: bool, } #[allow(clippy::if_same_then_else)] impl Validator for InputValidator { fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result { use ValidationResult::{Incomplete, Invalid, Valid}; // Filter out all new line characters let input = filter_line_continuations(ctx.input()); // Trim all whitespace from the user input let input = input.trim(); // Process the input to check if we can send the query let result = if self.multi && !input.ends_with(';') { Incomplete // The line doesn't end with a ; and we are in multi mode } else if self.multi && input.is_empty() { Incomplete // The line was empty and we are in multi mode } else if input.ends_with('\\') { Incomplete // The line ends with a backslash } else if let Err(e) = sql::parse(input) { Invalid(Some(format!(" --< {e}"))) } else { Valid(None) }; // Validation complete Ok(result) } } fn filter_line_continuations(line: &str) -> String { line.replace("\\\n", "").replace("\\\r\n", "") } fn split_prompt(prompt: &str) -> (&str, &str) { let selection = prompt.split_once('>').unwrap().0; selection.split_once('/').unwrap_or((selection, "")) }