#![allow(dead_code)] pub mod error; use crate::common::error::TestError; use futures_util::{SinkExt, StreamExt, TryStreamExt}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::error::Error; use std::fs::File; use std::path::Path; use std::process::{Command, Stdio}; use std::{env, fs}; use tokio::net::TcpStream; use tokio::time; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tracing::{debug, error, info}; pub const USER: &str = "root"; pub const PASS: &str = "root"; /// Child is a (maybe running) CLI process. It can be killed by dropping it pub struct Child { inner: Option, } impl Child { /// Send some thing to the child's stdin pub fn input(mut self, input: &str) -> Self { let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap(); use std::io::Write; stdin.write_all(input.as_bytes()).unwrap(); self } pub fn kill(mut self) -> Self { self.inner.as_mut().unwrap().kill().unwrap(); self } /// Read the child's stdout concatenated with its stderr. Returns Ok if the child /// returns successfully, Err otherwise. pub fn output(mut self) -> Result { let output = self.inner.take().unwrap().wait_with_output().unwrap(); let mut buf = String::from_utf8(output.stdout).unwrap(); buf.push_str(&String::from_utf8(output.stderr).unwrap()); if output.status.success() { Ok(buf) } else { Err(buf) } } } impl Drop for Child { fn drop(&mut self) { if let Some(inner) = self.inner.as_mut() { let _ = inner.kill(); } } } pub fn run_internal>( args: &str, current_dir: Option

, stdout: Stdio, stderr: Stdio, ) -> Child { let mut path = std::env::current_exe().unwrap(); assert!(path.pop()); if path.ends_with("deps") { assert!(path.pop()); } // Note: Cargo automatically builds this binary for integration tests. path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX)); let mut cmd = Command::new(path); if let Some(dir) = current_dir { cmd.current_dir(&dir); } cmd.env_clear(); cmd.stdin(Stdio::piped()); cmd.stdout(stdout); cmd.stderr(stderr); cmd.args(args.split_ascii_whitespace()); Child { inner: Some(cmd.spawn().unwrap()), } } /// Run the CLI with the given args pub fn run(args: &str) -> Child { run_internal::(args, None, Stdio::piped(), Stdio::piped()) } /// Run the CLI with the given args inside a temporary directory pub fn run_in_dir>(args: &str, current_dir: P) -> Child { run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped()) } pub fn tmp_file(name: &str) -> String { let path = Path::new(env!("OUT_DIR")).join(name); path.to_string_lossy().into_owned() } fn parse_server_stdio_from_var(var: &str) -> Result> { match env::var(var).as_deref() { Ok("inherit") => Ok(Stdio::inherit()), Ok("null") => Ok(Stdio::null()), Ok("piped") => Ok(Stdio::piped()), Ok(val) if val.starts_with("file://") => { Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?)) } Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()), _ => Ok(Stdio::null()), } } pub struct StartServerArguments { pub auth: bool, pub tls: bool, pub wait_is_ready: bool, pub tick_interval: time::Duration, } impl Default for StartServerArguments { fn default() -> Self { Self { auth: true, tls: false, wait_is_ready: true, tick_interval: time::Duration::new(1, 0), } } } pub async fn start_server_without_auth() -> Result<(String, Child), Box> { start_server(StartServerArguments { auth: false, ..Default::default() }) .await } pub async fn start_server_with_defaults() -> Result<(String, Child), Box> { start_server(StartServerArguments::default()).await } pub async fn start_server( StartServerArguments { auth, tls, wait_is_ready, tick_interval, }: StartServerArguments, ) -> Result<(String, Child), Box> { let mut rng = thread_rng(); let port: u16 = rng.gen_range(13000..14000); let addr = format!("127.0.0.1:{port}"); let mut extra_args = String::default(); if tls { // Test the crt/key args but the keys are self signed so don't actually connect. let crt_path = tmp_file("crt.crt"); let key_path = tmp_file("key.pem"); let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap(); fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap(); fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap(); extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str()); } if auth { extra_args.push_str(" --auth"); } if !tick_interval.is_zero() { let sec = tick_interval.as_secs(); extra_args.push_str(format!(" --tick-interval {sec}s").as_str()); } let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}"); info!("starting server with args: {start_args}"); // Configure where the logs go when running the test let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?; let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?; let server = run_internal::(&start_args, None, stdout, stderr); if !wait_is_ready { return Ok((addr, server)); } // Wait 5 seconds for the server to start let mut interval = time::interval(time::Duration::from_millis(500)); info!("Waiting for server to start..."); for _i in 0..10 { interval.tick().await; if run(&format!("isready --conn http://{addr}")).output().is_ok() { info!("Server ready!"); return Ok((addr, server)); } } let server_out = server.kill().output().err().unwrap(); error!("server output: {server_out}"); Err("server failed to start".into()) } type WsStream = WebSocketStream>; pub async fn connect_ws(addr: &str) -> Result> { let url = format!("ws://{}/rpc", addr); let (ws_stream, _) = connect_async(url).await?; Ok(ws_stream) } pub async fn ws_send_msg(socket: &mut WsStream, msg_req: String) -> Result<(), Box> { let now = time::Instant::now(); debug!("Sending message: {msg_req}"); tokio::select! { _ = time::sleep(time::Duration::from_millis(500)) => { return Err("timeout after 500ms waiting for the request to be sent".into()); } res = socket.send(Message::Text(msg_req)) => { debug!("Message sent in {:?}", now.elapsed()); if let Err(err) = res { return Err(format!("Error sending the message: {}", err).into()); } } } Ok(()) } pub async fn ws_recv_msg(socket: &mut WsStream) -> Result> { ws_recv_msg_with_fmt(socket, Format::Json).await } pub async fn ws_send_msg_and_wait_response( socket: &mut WsStream, msg_req: String, ) -> Result> { ws_send_msg(socket, msg_req).await?; ws_recv_msg_with_fmt(socket, Format::Json).await } pub enum Format { Json, Cbor, Pack, } pub async fn ws_recv_msg_with_fmt( socket: &mut WsStream, format: Format, ) -> Result> { let now = time::Instant::now(); debug!("Waiting for response..."); // Parse and return response let mut f = socket.try_filter(|msg| match format { Format::Json => futures_util::future::ready(msg.is_text()), Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()), }); tokio::select! { _ = time::sleep(time::Duration::from_millis(5000)) => { Err("timeout after 5s waiting for the response".into()) } res = f.select_next_some() => { debug!("Response received in {:?}", now.elapsed()); match format { Format::Json => Ok(serde_json::from_str(&res?.to_string())?), Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?), Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?), } } } } #[derive(Serialize, Deserialize)] struct SigninParams<'a> { user: &'a str, pass: &'a str, #[serde(skip_serializing_if = "Option::is_none")] ns: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] db: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] sc: Option<&'a str>, } #[derive(Serialize, Deserialize)] struct UseParams<'a> { #[serde(skip_serializing_if = "Option::is_none")] ns: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] db: Option<&'a str>, } pub async fn ws_signin( socket: &mut WsStream, user: &str, pass: &str, ns: Option<&str>, db: Option<&str>, sc: Option<&str>, ) -> Result> { let json = json!({ "id": "1", "method": "signin", "params": [ SigninParams { user, pass, ns, db, sc } ], }); ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; let msg = ws_recv_msg(socket).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => { Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned()) } _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } } pub async fn ws_query( socket: &mut WsStream, query: &str, ) -> Result, Box> { let json = json!({ "id": "1", "method": "query", "params": [query], }); ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; let msg = ws_recv_msg(socket).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj .get("result") .ok_or(TestError::AssertionError { message: "expected a result from the received object".to_string(), })? .as_array() .ok_or(TestError::AssertionError { message: "expected the result object to be an array for the received ws message" .to_string(), })? .to_owned()), _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } } pub async fn ws_use( socket: &mut WsStream, ns: Option<&str>, db: Option<&str>, ) -> Result> { let json = json!({ "id": "1", "method": "use", "params": [ ns, db ], }); ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?; let msg = ws_recv_msg(socket).await?; match msg.as_object() { Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) } Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => { Ok(obj.get("result").unwrap().to_owned()) } _ => { error!("{:?}", msg.as_object().unwrap().keys().collect::>()); Err(format!("unexpected response: {:?}", msg).into()) } } }