Add version check for web with short timeout (#3599)

This commit is contained in:
Przemyslaw Hugh Kaznowski 2024-03-04 11:12:59 +00:00 committed by GitHub
parent e06cd111cf
commit 4471433a78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 144 additions and 22 deletions

View file

@ -7,12 +7,16 @@ mod isready;
mod ml;
mod sql;
mod start;
#[cfg(test)]
mod test;
mod upgrade;
mod validate;
pub(crate) mod validator;
mod version;
mod version_client;
use crate::cnf::LOGO;
use crate::cli::version_client::VersionClient;
use crate::cnf::{LOGO, PKG_VERSION};
use crate::env::RELEASE;
use backup::BackupCommandArguments;
use clap::{Parser, Subcommand};
@ -21,9 +25,12 @@ use export::ExportCommandArguments;
use import::ImportCommandArguments;
use isready::IsReadyCommandArguments;
use ml::MlCommand;
use semver::Version;
use sql::SqlCommandArguments;
use start::StartCommandArguments;
use std::ops::Deref;
use std::process::ExitCode;
use std::time::Duration;
use upgrade::UpgradeCommandArguments;
use validate::ValidateCommandArguments;
use version::VersionCommandArguments;
@ -48,6 +55,10 @@ We would love it if you could star the repository (https://github.com/surrealdb/
struct Cli {
#[command(subcommand)]
command: Commands,
#[arg(help = "Whether to allow web check for client version upgrades at start")]
#[arg(env = "SURREAL_ONLINE_VERSION_CHECK", long)]
#[arg(default_value_t = true)]
online_version_check: bool,
}
#[allow(clippy::large_enum_variant)]
@ -88,7 +99,23 @@ pub async fn init() -> ExitCode {
.unwrap();
// Parse the CLI arguments
let args = Cli::parse();
// Run the respective command
// After parsing arguments, we check the version online
if args.online_version_check {
let client = version_client::new(Some(Duration::from_millis(500))).unwrap();
if let Err(opt_version) = check_upgrade(&client, PKG_VERSION.deref()).await {
match opt_version {
None => {
warn!("A new version of SurrealDB may be available.");
}
Some(new_version) => {
warn!("A new version of SurrealDB is available: {}", new_version);
}
}
// TODO ansi_term crate?
warn!("You can upgrade using the {} command", "surreal upgrade");
}
}
// After version warning we can run the respective command
let output = match args.command {
Commands::Start(args) => start::init(args).await,
Commands::Backup(args) => backup::init(args).await,
@ -125,3 +152,25 @@ pub async fn init() -> ExitCode {
ExitCode::SUCCESS
}
}
/// Check if there is a newer version
/// Ok = No upgrade needed
/// Err = Upgrade needed, returns the new version if it is available
async fn check_upgrade<C: VersionClient>(
client: &C,
pkg_version: &str,
) -> Result<(), Option<Version>> {
if let Ok(version) = client.fetch("latest").await {
// Request was successful, compare against current
let old_version = upgrade::parse_version(pkg_version).unwrap();
let new_version = upgrade::parse_version(&version).unwrap();
if old_version < new_version {
return Err(Some(new_version));
}
} else {
// Request failed, check against date
// TODO: We don't have an "expiry" set per-version, so this is a todo
// It would return Err(None) if the version is too old
}
Ok(())
}

23
src/cli/test.rs Normal file
View file

@ -0,0 +1,23 @@
use crate::cli::check_upgrade;
use crate::cli::version_client::MapVersionClient;
use crate::err::Error;
use std::collections::BTreeMap;
#[test_log::test(tokio::test)]
pub async fn test_version_upgrade() {
let mut client = MapVersionClient {
fetch_mock: BTreeMap::new(),
};
client
.fetch_mock
.insert("latest".to_string(), || -> Result<String, Error> { Ok("1.0.0".to_string()) });
check_upgrade(&client, "1.0.0")
.await
.expect("Expected the versions to be the same and not require an upgrade");
check_upgrade(&client, "0.9.0")
.await
.expect_err("Expected the versions to be different and require an upgrade");
check_upgrade(&client, "1.1.0")
.await
.expect("Expected the versions to be illogical, and not require and upgrade");
}

View file

@ -1,3 +1,5 @@
use crate::cli::version_client;
use crate::cli::version_client::VersionClient;
use crate::cnf::PKG_VERSION;
use crate::err::Error;
use clap::Args;
@ -10,7 +12,10 @@ use std::path::Path;
use std::process::Command;
use surrealdb::env::{arch, os};
const ROOT: &str = "https://download.surrealdb.com";
pub(crate) const ROOT: &str = "https://download.surrealdb.com";
const BETA: &str = "beta";
const LATEST: &str = "latest";
const NIGHTLY: &str = "nightly";
#[derive(Args, Debug)]
pub struct UpgradeCommandArguments {
@ -31,27 +36,26 @@ pub struct UpgradeCommandArguments {
impl UpgradeCommandArguments {
/// Get the version string to download based on the user preference
async fn version(&self) -> Result<Cow<'_, str>, Error> {
let nightly = "nightly";
let beta = "beta";
// Convert the version to lowercase, if supplied
let version = self.version.as_deref().map(str::to_ascii_lowercase);
let client = version_client::new(None)?;
if self.nightly || version.as_deref() == Some(nightly) {
Ok(Cow::Borrowed(nightly))
} else if self.beta || version.as_deref() == Some(beta) {
fetch(beta).await
if self.nightly || version.as_deref() == Some(NIGHTLY) {
Ok(Cow::Borrowed(NIGHTLY))
} else if self.beta || version.as_deref() == Some(BETA) {
client.fetch(BETA).await
} else if let Some(version) = version {
// Parse the version string to make sure it's valid, return an error if not
let version = parse_version(&version)?;
// Return the version, ensuring it's prefixed by `v`
Ok(Cow::Owned(format!("v{version}")))
} else {
fetch("latest").await
client.fetch(LATEST).await
}
}
}
fn parse_version(input: &str) -> Result<Version, Error> {
pub(crate) fn parse_version(input: &str) -> Result<Version, Error> {
// Remove the `v` prefix, if supplied
let version = input.strip_prefix('v').unwrap_or(input);
// Parse the version
@ -76,17 +80,6 @@ fn parse_version(input: &str) -> Result<Version, Error> {
}
}
async fn fetch(version: &str) -> Result<Cow<'_, str>, Error> {
let response = reqwest::get(format!("{ROOT}/{version}.txt")).await?;
if !response.status().is_success() {
return Err(Error::Io(IoError::new(
ErrorKind::Other,
format!("received status {} when fetching version", response.status()),
)));
}
Ok(Cow::Owned(response.text().await?.trim().to_owned()))
}
pub async fn init(args: UpgradeCommandArguments) -> Result<(), Error> {
// Initialize opentelemetry and logging
crate::telemetry::builder().with_log_level("error").init();

57
src/cli/version_client.rs Normal file
View file

@ -0,0 +1,57 @@
#[allow(unused_imports)]
// This is used in format! macro
use crate::cli::upgrade::ROOT;
use crate::err::Error;
use reqwest::Client;
use std::borrow::Cow;
#[cfg(test)]
use std::collections::BTreeMap;
use std::io::Error as IoError;
use std::io::ErrorKind;
use std::time::Duration;
pub(crate) trait VersionClient {
async fn fetch(&self, version: &str) -> Result<Cow<'static, str>, Error>;
}
pub(crate) struct ReqwestVersionClient {
client: Client,
}
pub(crate) fn new(timeout: Option<Duration>) -> Result<ReqwestVersionClient, Error> {
let mut client = Client::builder();
if let Some(timeout) = timeout {
client = client.timeout(timeout);
}
let client = client.build()?;
Ok(ReqwestVersionClient {
client,
})
}
impl VersionClient for ReqwestVersionClient {
async fn fetch(&self, version: &str) -> Result<Cow<'static, str>, Error> {
let request = self.client.get(format!("{ROOT}/{version}.txt")).build().unwrap();
let response = self.client.execute(request).await?;
if !response.status().is_success() {
return Err(Error::Io(IoError::new(
ErrorKind::Other,
format!("received status {} when fetching version", response.status()),
)));
}
Ok(Cow::Owned(response.text().await?.trim().to_owned()))
}
}
#[cfg(test)]
pub(crate) struct MapVersionClient {
pub(crate) fetch_mock: BTreeMap<String, fn() -> Result<String, Error>>,
}
#[cfg(test)]
impl VersionClient for MapVersionClient {
async fn fetch(&self, version: &str) -> Result<Cow<'static, str>, Error> {
let found = self.fetch_mock.get(version).unwrap();
found().map(Cow::Owned)
}
}