diff --git a/src/cli/mod.rs b/src/cli/mod.rs index e783a27b..564421c1 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -7,12 +7,16 @@ mod isready; mod ml; mod sql; mod start; +#[cfg(test)] +mod test; mod upgrade; mod validate; pub(crate) mod validator; mod version; +mod version_client; -use crate::cnf::LOGO; +use crate::cli::version_client::VersionClient; +use crate::cnf::{LOGO, PKG_VERSION}; use crate::env::RELEASE; use backup::BackupCommandArguments; use clap::{Parser, Subcommand}; @@ -21,9 +25,12 @@ use export::ExportCommandArguments; use import::ImportCommandArguments; use isready::IsReadyCommandArguments; use ml::MlCommand; +use semver::Version; use sql::SqlCommandArguments; use start::StartCommandArguments; +use std::ops::Deref; use std::process::ExitCode; +use std::time::Duration; use upgrade::UpgradeCommandArguments; use validate::ValidateCommandArguments; use version::VersionCommandArguments; @@ -48,6 +55,10 @@ We would love it if you could star the repository (https://github.com/surrealdb/ struct Cli { #[command(subcommand)] command: Commands, + #[arg(help = "Whether to allow web check for client version upgrades at start")] + #[arg(env = "SURREAL_ONLINE_VERSION_CHECK", long)] + #[arg(default_value_t = true)] + online_version_check: bool, } #[allow(clippy::large_enum_variant)] @@ -88,7 +99,23 @@ pub async fn init() -> ExitCode { .unwrap(); // Parse the CLI arguments let args = Cli::parse(); - // Run the respective command + // After parsing arguments, we check the version online + if args.online_version_check { + let client = version_client::new(Some(Duration::from_millis(500))).unwrap(); + if let Err(opt_version) = check_upgrade(&client, PKG_VERSION.deref()).await { + match opt_version { + None => { + warn!("A new version of SurrealDB may be available."); + } + Some(new_version) => { + warn!("A new version of SurrealDB is available: {}", new_version); + } + } + // TODO ansi_term crate? + warn!("You can upgrade using the {} command", "surreal upgrade"); + } + } + // After version warning we can run the respective command let output = match args.command { Commands::Start(args) => start::init(args).await, Commands::Backup(args) => backup::init(args).await, @@ -125,3 +152,25 @@ pub async fn init() -> ExitCode { ExitCode::SUCCESS } } + +/// Check if there is a newer version +/// Ok = No upgrade needed +/// Err = Upgrade needed, returns the new version if it is available +async fn check_upgrade( + client: &C, + pkg_version: &str, +) -> Result<(), Option> { + if let Ok(version) = client.fetch("latest").await { + // Request was successful, compare against current + let old_version = upgrade::parse_version(pkg_version).unwrap(); + let new_version = upgrade::parse_version(&version).unwrap(); + if old_version < new_version { + return Err(Some(new_version)); + } + } else { + // Request failed, check against date + // TODO: We don't have an "expiry" set per-version, so this is a todo + // It would return Err(None) if the version is too old + } + Ok(()) +} diff --git a/src/cli/test.rs b/src/cli/test.rs new file mode 100644 index 00000000..4a3d380f --- /dev/null +++ b/src/cli/test.rs @@ -0,0 +1,23 @@ +use crate::cli::check_upgrade; +use crate::cli::version_client::MapVersionClient; +use crate::err::Error; +use std::collections::BTreeMap; + +#[test_log::test(tokio::test)] +pub async fn test_version_upgrade() { + let mut client = MapVersionClient { + fetch_mock: BTreeMap::new(), + }; + client + .fetch_mock + .insert("latest".to_string(), || -> Result { Ok("1.0.0".to_string()) }); + check_upgrade(&client, "1.0.0") + .await + .expect("Expected the versions to be the same and not require an upgrade"); + check_upgrade(&client, "0.9.0") + .await + .expect_err("Expected the versions to be different and require an upgrade"); + check_upgrade(&client, "1.1.0") + .await + .expect("Expected the versions to be illogical, and not require and upgrade"); +} diff --git a/src/cli/upgrade.rs b/src/cli/upgrade.rs index 2f89e719..474bc838 100644 --- a/src/cli/upgrade.rs +++ b/src/cli/upgrade.rs @@ -1,3 +1,5 @@ +use crate::cli::version_client; +use crate::cli::version_client::VersionClient; use crate::cnf::PKG_VERSION; use crate::err::Error; use clap::Args; @@ -10,7 +12,10 @@ use std::path::Path; use std::process::Command; use surrealdb::env::{arch, os}; -const ROOT: &str = "https://download.surrealdb.com"; +pub(crate) const ROOT: &str = "https://download.surrealdb.com"; +const BETA: &str = "beta"; +const LATEST: &str = "latest"; +const NIGHTLY: &str = "nightly"; #[derive(Args, Debug)] pub struct UpgradeCommandArguments { @@ -31,27 +36,26 @@ pub struct UpgradeCommandArguments { impl UpgradeCommandArguments { /// Get the version string to download based on the user preference async fn version(&self) -> Result, Error> { - let nightly = "nightly"; - let beta = "beta"; // Convert the version to lowercase, if supplied let version = self.version.as_deref().map(str::to_ascii_lowercase); + let client = version_client::new(None)?; - if self.nightly || version.as_deref() == Some(nightly) { - Ok(Cow::Borrowed(nightly)) - } else if self.beta || version.as_deref() == Some(beta) { - fetch(beta).await + if self.nightly || version.as_deref() == Some(NIGHTLY) { + Ok(Cow::Borrowed(NIGHTLY)) + } else if self.beta || version.as_deref() == Some(BETA) { + client.fetch(BETA).await } else if let Some(version) = version { // Parse the version string to make sure it's valid, return an error if not let version = parse_version(&version)?; // Return the version, ensuring it's prefixed by `v` Ok(Cow::Owned(format!("v{version}"))) } else { - fetch("latest").await + client.fetch(LATEST).await } } } -fn parse_version(input: &str) -> Result { +pub(crate) fn parse_version(input: &str) -> Result { // Remove the `v` prefix, if supplied let version = input.strip_prefix('v').unwrap_or(input); // Parse the version @@ -76,17 +80,6 @@ fn parse_version(input: &str) -> Result { } } -async fn fetch(version: &str) -> Result, Error> { - let response = reqwest::get(format!("{ROOT}/{version}.txt")).await?; - if !response.status().is_success() { - return Err(Error::Io(IoError::new( - ErrorKind::Other, - format!("received status {} when fetching version", response.status()), - ))); - } - Ok(Cow::Owned(response.text().await?.trim().to_owned())) -} - pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> { // Initialize opentelemetry and logging crate::telemetry::builder().with_log_level("error").init(); diff --git a/src/cli/version_client.rs b/src/cli/version_client.rs new file mode 100644 index 00000000..25c55f51 --- /dev/null +++ b/src/cli/version_client.rs @@ -0,0 +1,57 @@ +#[allow(unused_imports)] +// This is used in format! macro +use crate::cli::upgrade::ROOT; +use crate::err::Error; +use reqwest::Client; +use std::borrow::Cow; +#[cfg(test)] +use std::collections::BTreeMap; +use std::io::Error as IoError; +use std::io::ErrorKind; +use std::time::Duration; + +pub(crate) trait VersionClient { + async fn fetch(&self, version: &str) -> Result, Error>; +} + +pub(crate) struct ReqwestVersionClient { + client: Client, +} + +pub(crate) fn new(timeout: Option) -> Result { + let mut client = Client::builder(); + if let Some(timeout) = timeout { + client = client.timeout(timeout); + } + let client = client.build()?; + Ok(ReqwestVersionClient { + client, + }) +} + +impl VersionClient for ReqwestVersionClient { + async fn fetch(&self, version: &str) -> Result, Error> { + let request = self.client.get(format!("{ROOT}/{version}.txt")).build().unwrap(); + let response = self.client.execute(request).await?; + if !response.status().is_success() { + return Err(Error::Io(IoError::new( + ErrorKind::Other, + format!("received status {} when fetching version", response.status()), + ))); + } + Ok(Cow::Owned(response.text().await?.trim().to_owned())) + } +} + +#[cfg(test)] +pub(crate) struct MapVersionClient { + pub(crate) fetch_mock: BTreeMap Result>, +} + +#[cfg(test)] +impl VersionClient for MapVersionClient { + async fn fetch(&self, version: &str) -> Result, Error> { + let found = self.fetch_mock.get(version).unwrap(); + found().map(Cow::Owned) + } +}