#![allow(dead_code)] pub mod error; use futures_util::{SinkExt, StreamExt, TryStreamExt}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::error::Error; use std::fs::File; use std::path::Path; use std::process::{Command, Stdio}; use std::{env, fs}; use tokio::net::TcpStream; use tokio::time; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tracing::{error, info}; use crate::common::error::TestError; pub const USER: &str = "root"; pub const PASS: &str = "root"; /// Child is a (maybe running) CLI process. It can be killed by dropping it pub struct Child { inner: Option, } impl Child { /// Send some thing to the child's stdin pub fn input(mut self, input: &str) -> Self { let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap(); use std::io::Write; stdin.write_all(input.as_bytes()).unwrap(); self } pub fn kill(mut self) -> Self { self.inner.as_mut().unwrap().kill().unwrap(); self } /// Read the child's stdout concatenated with its stderr. Returns Ok if the child /// returns successfully, Err otherwise. pub fn output(mut self) -> Result { let output = self.inner.take().unwrap().wait_with_output().unwrap(); let mut buf = String::from_utf8(output.stdout).unwrap(); buf.push_str(&String::from_utf8(output.stderr).unwrap()); if output.status.success() { Ok(buf) } else { Err(buf) } } } impl Drop for Child { fn drop(&mut self) { if let Some(inner) = self.inner.as_mut() { let _ = inner.kill(); } } } pub fn run_internal>( args: &str, current_dir: Option

, stdout: Stdio, stderr: Stdio, ) -> Child { let mut path = std::env::current_exe().unwrap(); assert!(path.pop()); if path.ends_with("deps") { assert!(path.pop()); } // Note: Cargo automatically builds this binary for integration tests. path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX)); let mut cmd = Command::new(path); if let Some(dir) = current_dir { cmd.current_dir(&dir); } cmd.env_clear(); cmd.stdin(Stdio::piped()); cmd.stdout(stdout); cmd.stderr(stderr); cmd.args(args.split_ascii_whitespace()); Child { inner: Some(cmd.spawn().unwrap()), } } /// Run the CLI with the given args pub fn run(args: &str) -> Child { run_internal::(args, None, Stdio::piped(), Stdio::piped()) } /// Run the CLI with the given args inside a temporary directory pub fn run_in_dir>(args: &str, current_dir: P) -> Child { run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped()) } pub fn tmp_file(name: &str) -> String { let path = Path::new(env!("OUT_DIR")).join(name); path.to_string_lossy().into_owned() } fn parse_server_stdio_from_var(var: &str) -> Result> { match env::var(var).as_deref() { Ok("inherit") => Ok(Stdio::inherit()), Ok("null") => Ok(Stdio::null()), Ok("piped") => Ok(Stdio::piped()), Ok(val) if val.starts_with("file://") => { Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?)) } Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()), _ => Ok(Stdio::null()), } } pub struct StartServerArguments { pub auth: bool, pub tls: bool, pub wait_is_ready: bool, pub tick_interval: time::Duration, } impl Default for StartServerArguments { fn default() -> Self { Self { auth: true, tls: false, wait_is_ready: true, tick_interval: time::Duration::new(1, 0), } } } pub async fn start_server_without_auth() -> Result<(String, Child), Box> { start_server(StartServerArguments { auth: false, ..Default::default() }) .await } pub async fn start_server_with_defaults() -> Result<(String, Child), Box> { start_server(StartServerArguments::default()).await } pub async fn start_server( StartServerArguments { auth, tls, wait_is_ready, tick_interval, }: StartServerArguments, ) -> Result<(String, Child), Box> { let mut rng = thread_rng(); let port: u16 = rng.gen_range(13000..14000); let addr = format!("127.0.0.1:{port}"); let mut extra_args = String::default(); if tls { // Test the crt/key args but the keys are self signed so don't actually connect. let crt_path = tmp_file("crt.crt"); let key_path = tmp_file("key.pem"); let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap(); fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap(); fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap(); extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str()); } if auth { extra_args.push_str(" --auth"); } if !tick_interval.is_zero() { let sec = tick_interval.as_secs(); extra_args.push_str(format!(" --tick-interval {sec}s").as_str()); } let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}"); info!("starting server with args: {start_args}"); // Configure where the logs go when running the test let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?; let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?; let server = run_internal::(&start_args, None, stdout, stderr); if !wait_is_ready { return Ok((addr, server)); } // Wait 5 seconds for the server to start let mut interval = time::interval(time::Duration::from_millis(500)); info!("Waiting for server to start..."); for _i in 0..10 { interval.tick().await; if run(&format!("isready --conn http://{addr}")).output().is_ok() { info!("Server ready!"); return Ok((addr, server)); } } let server_out = server.kill().output().err().unwrap(); error!("server output: {server_out}"); Err("server failed to start".into()) } type WsStream = WebSocketStream>; pub async fn connect_ws(addr: &str) -> Result> { let url = format!("ws://{}/rpc", addr); let (ws_stream, _) = connect_async(url).await?; Ok(ws_stream) } pub async fn ws_send_msg( socket: &mut WsStream, msg_req: String, ) -> Result> { // Use JSON format by default ws_send_msg_with_fmt(socket, msg_req, Format::Json).await } pub async fn ws_recv_msg(socket: &mut WsStream) -> Result> { ws_recv_msg_with_fmt(socket, Format::Json).await } pub enum Format { Json, Cbor, Pack, } pub async fn ws_recv_msg_with_fmt( socket: &mut WsStream, format: Format, ) -> Result> { // Parse and return response let mut f = socket.try_filter(|msg| match format { Format::Json => futures_util::future::ready(msg.is_text()), Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()), }); let msg: serde_json::Value = tokio::select! { _ = time::sleep(time::Duration::from_millis(2000)) => { return Err(TestError::NetworkError{message: "timeout waiting for the response".to_string()}.into()); } msg = f.select_next_some() => { serde_json::from_str(&msg?.to_string())? } }; Ok(serde_json::from_str(&msg.to_string())?) } pub async fn ws_send_msg_with_fmt( socket: &mut WsStream, msg_req: String, response_format: Format, ) -> Result> { tokio::select! { _ = time::sleep(time::Duration::from_millis(500)) => { return Err("timeout waiting for the request to be sent".into()); } res = socket.send(Message::Text(msg_req)) => { if let Err(err) = res { return Err(format!("Error sending the message: {}", err).into()); } } } let mut f = socket.try_filter(|msg| match response_format { Format::Json => futures_util::future::ready(msg.is_text()), Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()), }); tokio::select! { _ = time::sleep(time::Duration::from_millis(2000)) => { Err("timeout waiting for the response".into()) } res = f.select_next_some() => { match response_format { Format::Json => Ok(serde_json::from_str(&res?.to_string())?), Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?), Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?), } } } } #[derive(Serialize, Deserialize)] struct SigninParams<'a> { user: &'a str, pass: &'a str, #[serde(skip_serializing_if = "Option::is_none")] ns: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] db: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] sc: Option<&'a str>, } #[derive(Serialize, Deserialize)] struct UseParams<'a> { #[serde(skip_serializing_if = "Option::is_none")] ns: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] db: Option<&'a str>, } pub async fn ws_signin( socket: &mut WsStream, user: &str, pass: &str, ns: Option<&str>, db: Option<&str>, sc: Option<&str>, ) -> Result> { let json = json!({ "id": "1", "method": "signin", "params": [ SigninParams { user, pass, ns, db, sc } ], }); let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => { Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned()) } _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } } pub async fn ws_query( socket: &mut WsStream, query: &str, ) -> Result, Box> { let json = json!({ "id": "1", "method": "query", "params": [query], }); let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => { Ok(obj.get("result").unwrap().as_array().unwrap().to_owned()) } _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } } pub async fn ws_use( socket: &mut WsStream, ns: Option<&str>, db: Option<&str>, ) -> Result> { let json = json!({ "id": "1", "method": "use", "params": [ ns, db ], }); let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => { Ok(obj.get("result").unwrap().to_owned()) } _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } }