From fd6e774fab0f984f5ae437967e2c5ad6f096c026 Mon Sep 17 00:00:00 2001 From: Matthew James Briggs Date: Thu, 2 Nov 2023 10:57:50 -0700 Subject: [PATCH 1/3] migrator: refactor test code This refactor of test code will make it easier to switch to async since the tokio::fs::read_dir API does not match the std::fs::read_dir API. --- sources/api/migration/migrator/src/test.rs | 115 ++++++++++----------- 1 file changed, 55 insertions(+), 60 deletions(-) diff --git a/sources/api/migration/migrator/src/test.rs b/sources/api/migration/migrator/src/test.rs index 0433bf924f5..366c4130bc7 100644 --- a/sources/api/migration/migrator/src/test.rs +++ b/sources/api/migration/migrator/src/test.rs @@ -4,8 +4,7 @@ use crate::args::Args; use crate::run; use chrono::{DateTime, Utc}; use semver::Version; -use std::fs; -use std::fs::{DirEntry, File}; +use std::fs::{self, File}; use std::io::Write; use std::path::{Path, PathBuf}; use tempfile::TempDir; @@ -181,23 +180,14 @@ fn create_test_repo(test_type: TestType) -> TestRepo { .timestamp_version(one) .timestamp_expires(long_ago); - fs::read_dir(tuf_indir) - .unwrap() - .filter(|dir_entry_result| { - if let Ok(dir_entry) = dir_entry_result { - return dir_entry.path().is_file(); - } - false - }) - .for_each(|dir_entry_result| { - let dir_entry = dir_entry_result.unwrap(); - editor - .add_target( - dir_entry.file_name().to_str().unwrap(), - tough::schema::Target::from_path(dir_entry.path()).unwrap(), - ) - .unwrap(); - }); + list_dir_files(tuf_indir).into_iter().for_each(|path| { + editor + .add_target( + path.file_name().unwrap().to_str().unwrap(), + tough::schema::Target::from_path(&path).unwrap(), + ) + .unwrap(); + }); let signed_repo = editor .sign(&[Box::new(tough::key_source::LocalKeySource { path: pem() })]) .unwrap(); @@ -224,31 +214,26 @@ fn assert_directory_structure_with_failed_migration( from: &Version, to: &Version, ) -> PathBuf { - let dir_entries: Vec = fs::read_dir(dir) - .unwrap() - .map(|item| item.unwrap()) - .collect(); - + let paths = list_dir_entries(dir); let from_ver = format!("v{}", from); let from_ver_unique_prefix = format!("v{}_", from); let to_ver_unique_prefix = format!("v{}_", to); - assert_eq!(dir_entries.len(), 8); - assert_dir_entry_exists(&dir_entries, "current"); - assert_dir_entry_exists(&dir_entries, "result.txt"); - assert_dir_entry_exists(&dir_entries, "v0"); - assert_dir_entry_exists(&dir_entries, "v0.99"); - assert_dir_entry_exists(&dir_entries, &from_ver); - assert_dir_starting_with_exists(&dir_entries, &from_ver_unique_prefix); + assert_eq!(paths.len(), 8); + assert_dir_entry_exists(&paths, "current"); + assert_dir_entry_exists(&paths, "result.txt"); + assert_dir_entry_exists(&paths, "v0"); + assert_dir_entry_exists(&paths, "v0.99"); + assert_dir_entry_exists(&paths, &from_ver); + assert_dir_starting_with_exists(&paths, &from_ver_unique_prefix); // There are two datastores that start with the target version followed by an underscore. This // is because the datastore we intended to promote (target_datastore) and one intermediate // datastore are expected to be left behind for debugging after a migration failure. - let left_behind_count = dir_entries + let left_behind_count = paths .iter() .filter_map(|entry| { entry - .path() .file_name() .unwrap() .to_str() @@ -265,54 +250,47 @@ fn assert_directory_structure_with_failed_migration( left_behind_count ); - let symlink = dir_entries + let symlink = paths .iter() - .find(|entry| entry.path().file_name().unwrap().to_str().unwrap() == "current") - .unwrap() - .path(); + .find(|entry| entry.file_name().unwrap().to_str().unwrap() == "current") + .unwrap(); symlink.canonicalize().unwrap() } /// Asserts that the expected directories and files are in the datastore directory after a /// successful migration. Returns the absolute path that the `current` symlink is pointing to. fn assert_directory_structure(dir: &Path) -> PathBuf { - let dir_entries: Vec = fs::read_dir(dir) - .unwrap() - .map(|item| item.unwrap()) - .collect(); - - assert_eq!(dir_entries.len(), 8); - assert_dir_entry_exists(&dir_entries, "current"); - assert_dir_entry_exists(&dir_entries, "result.txt"); - assert_dir_entry_exists(&dir_entries, "v0"); - assert_dir_entry_exists(&dir_entries, "v0.99"); - assert_dir_entry_exists(&dir_entries, "v0.99.0"); - assert_dir_entry_exists(&dir_entries, "v0.99.1"); - assert_dir_starting_with_exists(&dir_entries, "v0.99.0_"); - assert_dir_starting_with_exists(&dir_entries, "v0.99.1_"); - - let symlink = dir_entries + let paths = list_dir_entries(dir); + assert_eq!(paths.len(), 8); + assert_dir_entry_exists(&paths, "current"); + assert_dir_entry_exists(&paths, "result.txt"); + assert_dir_entry_exists(&paths, "v0"); + assert_dir_entry_exists(&paths, "v0.99"); + assert_dir_entry_exists(&paths, "v0.99.0"); + assert_dir_entry_exists(&paths, "v0.99.1"); + assert_dir_starting_with_exists(&paths, "v0.99.0_"); + assert_dir_starting_with_exists(&paths, "v0.99.1_"); + + let symlink = paths .iter() - .find(|entry| entry.path().file_name().unwrap().to_str().unwrap() == "current") - .unwrap() - .path(); + .find(|entry| entry.file_name().unwrap().to_str().unwrap() == "current") + .unwrap(); symlink.canonicalize().unwrap() } -fn assert_dir_entry_exists(dir_entries: &[DirEntry], exact_name: &str) { +fn assert_dir_entry_exists(dir_entries: &[PathBuf], exact_name: &str) { assert!( dir_entries .iter() - .any(|entry| entry.path().file_name().unwrap().to_str().unwrap() == exact_name), + .any(|entry| entry.file_name().unwrap().to_str().unwrap() == exact_name), "'{}' not found", exact_name ); } -fn assert_dir_starting_with_exists(dir_entries: &[DirEntry], starts_with: &str) { +fn assert_dir_starting_with_exists(dir_entries: &[PathBuf], starts_with: &str) { assert!( dir_entries.iter().any(|entry| entry - .path() .file_name() .unwrap() .to_str() @@ -323,6 +301,23 @@ fn assert_dir_starting_with_exists(dir_entries: &[DirEntry], starts_with: &str) ); } +fn list_dir_entries(dir: impl AsRef) -> Vec { + fs::read_dir(dir) + .unwrap() + .map(|dir_entry_result| { + let dir_entry = dir_entry_result.unwrap(); + dir_entry.path() + }) + .collect() +} + +fn list_dir_files(dir: impl AsRef) -> Vec { + list_dir_entries(dir) + .into_iter() + .filter(|path| path.is_file()) + .collect() +} + /// Tests the migrator program end-to-end using the `run` function. Creates a TUF repo in a /// tempdir which includes a `manifest.json` with a couple of migrations: /// ``` From 1d9ca9d58a961374507cf36307656fcefc36cc2b Mon Sep 17 00:00:00 2001 From: Matthew James Briggs Date: Thu, 2 Nov 2023 10:15:14 -0700 Subject: [PATCH 2/3] migrator: update tough to async version --- sources/Cargo.lock | 100 +++++++++++- sources/api/migration/migrator/Cargo.toml | 7 +- sources/api/migration/migrator/src/main.rs | 176 ++++++++++++++------- sources/api/migration/migrator/src/test.rs | 87 +++++----- 4 files changed, 264 insertions(+), 106 deletions(-) diff --git a/sources/Cargo.lock b/sources/Cargo.lock index dd948d80593..135dde924ea 100644 --- a/sources/Cargo.lock +++ b/sources/Cargo.lock @@ -491,6 +491,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "async-recursion" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -1862,6 +1873,7 @@ checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1884,6 +1896,17 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.28" @@ -2626,7 +2649,10 @@ name = "migrator" version = "0.1.0" dependencies = [ "bottlerocket-release", + "bytes", "chrono", + "futures", + "futures-core", "generate-readme", "log", "lz4", @@ -2638,7 +2664,9 @@ dependencies = [ "snafu", "storewolf", "tempfile", - "tough", + "tokio", + "tokio-util", + "tough 0.15.0", "update_metadata", "url", ] @@ -3044,6 +3072,16 @@ dependencies = [ "base64 0.13.1", ] +[[package]] +name = "pem" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3163d2912b7c3b52d651a055f2c7eec9ba5cd22d26ef75b8dd3a59980b185923" +dependencies = [ + "base64 0.21.4", + "serde", +] + [[package]] name = "pentacle" version = "1.0.0" @@ -3415,10 +3453,12 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -4318,6 +4358,7 @@ checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", @@ -4381,7 +4422,7 @@ dependencies = [ "log", "olpc-cjson", "path-absolutize", - "pem", + "pem 1.1.1", "percent-encoding", "reqwest", "ring", @@ -4395,6 +4436,40 @@ dependencies = [ "walkdir", ] +[[package]] +name = "tough" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16dc5f42fc7ce7cb51eebc7a6ef91f4d69a6d41bb13f34a09674ec47e454d9b" +dependencies = [ + "async-recursion", + "async-trait", + "bytes", + "chrono", + "dyn-clone", + "futures", + "futures-core", + "globset", + "hex", + "log", + "olpc-cjson", + "pem 3.0.2", + "percent-encoding", + "reqwest", + "ring", + "serde", + "serde_json", + "serde_plain", + "snafu", + "tempfile", + "tokio", + "tokio-util", + "typed-path", + "untrusted", + "url", + "walkdir", +] + [[package]] name = "tower" version = "0.4.13" @@ -4480,6 +4555,12 @@ dependencies = [ "utf-8", ] +[[package]] +name = "typed-path" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb9d13b8242894ff21f9990082b90a6410a43dcc6029ac4227a1467853ba781" + [[package]] name = "typenum" version = "1.17.0" @@ -4572,7 +4653,7 @@ dependencies = [ "snafu", "tempfile", "toml 0.5.11", - "tough", + "tough 0.14.0", "update_metadata", "url", ] @@ -4730,6 +4811,19 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.64" diff --git a/sources/api/migration/migrator/Cargo.toml b/sources/api/migration/migrator/Cargo.toml index 5363587d861..32a5295689e 100644 --- a/sources/api/migration/migrator/Cargo.toml +++ b/sources/api/migration/migrator/Cargo.toml @@ -11,6 +11,9 @@ exclude = ["README.md"] [dependencies] bottlerocket-release = { path = "../../../bottlerocket-release", version = "0.1" } +bytes = "1" +futures = "0.3" +futures-core = "0.3" log = "0.4" lz4 = "1" nix = "0.26" @@ -19,7 +22,9 @@ rand = { version = "0.8", default-features = false, features = ["std", "std_rng" semver = "1" simplelog = "0.12" snafu = "0.7" -tough = "0.14" +tokio = { version = "~1.32", default-features = false, features = ["fs", "macros", "rt-multi-thread"] } # LTS +tokio-util = { version = "0.7", features = ["compat", "io-util"] } +tough = { version = "0.15", features = ["http"] } update_metadata = { path = "../../../updater/update_metadata", version = "0.1" } url = "2" diff --git a/sources/api/migration/migrator/src/main.rs b/sources/api/migration/migrator/src/main.rs index b6f685d6745..a3f09c8ad63 100644 --- a/sources/api/migration/migrator/src/main.rs +++ b/sources/api/migration/migrator/src/main.rs @@ -24,6 +24,7 @@ extern crate log; use args::Args; use direction::Direction; use error::Result; +use futures::{StreamExt, TryStreamExt}; use nix::{dir::Dir, fcntl::OFlag, sys::stat::Mode, unistd::fsync}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use semver::Version; @@ -31,11 +32,15 @@ use simplelog::{Config as LogConfig, SimpleLogger}; use snafu::{ensure, OptionExt, ResultExt}; use std::convert::TryInto; use std::env; -use std::fs::{self, File}; +use std::io::ErrorKind; use std::os::unix::fs::symlink; use std::os::unix::io::AsRawFd; use std::path::{Path, PathBuf}; use std::process; +use tokio::fs; +use tokio::runtime::Handle; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tokio_util::io::SyncIoBridge; use tough::{ExpirationEnforcement, FilesystemTransport, RepositoryLoader}; use update_metadata::Manifest; use url::Url; @@ -49,20 +54,21 @@ mod test; // Returning a Result from main makes it print a Debug representation of the error, but with Snafu // we have nice Display representations of the error, so we wrap "main" (run) and print any error. // https://github.com/shepmaster/snafu/issues/110 -fn main() { +#[tokio::main] +async fn main() { let args = Args::from_env(env::args()); // SimpleLogger will send errors to stderr and anything less to stdout. if let Err(e) = SimpleLogger::init(args.log_level, LogConfig::default()) { eprintln!("{}", e); process::exit(1); } - if let Err(e) = run(&args) { + if let Err(e) = run(&args).await { eprintln!("{}", e); process::exit(1); } } -fn get_current_version

(datastore_dir: P) -> Result +async fn get_current_version

(datastore_dir: P) -> Result where P: AsRef, { @@ -70,12 +76,21 @@ where // Find the current patch version link, which contains our full version number let current = datastore_dir.join("current"); - let major = datastore_dir - .join(fs::read_link(¤t).context(error::LinkReadSnafu { link: current })?); - let minor = - datastore_dir.join(fs::read_link(&major).context(error::LinkReadSnafu { link: major })?); - let patch = - datastore_dir.join(fs::read_link(&minor).context(error::LinkReadSnafu { link: minor })?); + let major = datastore_dir.join( + fs::read_link(¤t) + .await + .context(error::LinkReadSnafu { link: current })?, + ); + let minor = datastore_dir.join( + fs::read_link(&major) + .await + .context(error::LinkReadSnafu { link: major })?, + ); + let patch = datastore_dir.join( + fs::read_link(&minor) + .await + .context(error::LinkReadSnafu { link: minor })?, + ); // Pull out the basename of the path, which contains the version let version_os_str = patch @@ -93,7 +108,7 @@ where Version::parse(version_str).context(error::InvalidDataStoreVersionSnafu { path: &patch }) } -pub(crate) fn run(args: &Args) -> Result<()> { +pub(crate) async fn run(args: &Args) -> Result<()> { // Get the directory we're working in. let datastore_dir = args .datastore_path @@ -102,7 +117,7 @@ pub(crate) fn run(args: &Args) -> Result<()> { path: &args.datastore_path, })?; - let current_version = get_current_version(datastore_dir)?; + let current_version = get_current_version(datastore_dir).await?; let direction = Direction::from_versions(¤t_version, &args.migrate_to_version) .unwrap_or_else(|| { info!( @@ -127,9 +142,11 @@ pub(crate) fn run(args: &Args) -> Result<()> { })?; // open a reader to the root.json file - let root_file = File::open(&args.root_path).with_context(|_| error::OpenRootSnafu { - path: args.root_path.clone(), - })?; + let root_bytes = fs::read(&args.root_path) + .await + .with_context(|_| error::OpenRootSnafu { + path: args.root_path.clone(), + })?; // We will load the locally cached TUF repository to obtain the manifest. The Repository is // loaded using a `TempDir` for its internal Datastore (this is the default). Part of using a @@ -143,7 +160,7 @@ pub(crate) fn run(args: &Args) -> Result<()> { // Failure to load the TUF repo at the expected location is a serious issue because updog should // always create a TUF repo that contains at least the manifest, even if there are no migrations. - let repo = RepositoryLoader::new(root_file, metadata_base_url, targets_base_url) + let repo = RepositoryLoader::new(&root_bytes, metadata_base_url, targets_base_url) .transport(FilesystemTransport) // The threats TUF mitigates are more than the threats we are attempting to mitigate // here by caching signatures for migrations locally and using them after a reboot but @@ -154,8 +171,9 @@ pub(crate) fn run(args: &Args) -> Result<()> { // if the targets expired between updog downloading them and now. .expiration_enforcement(ExpirationEnforcement::Unsafe) .load() + .await .context(error::RepoLoadSnafu)?; - let manifest = load_manifest(&repo)?; + let manifest = load_manifest(repo.clone()).await?; let migrations = update_metadata::find_migrations(¤t_version, &args.migrate_to_version, &manifest) .context(error::FindMigrationsSnafu)?; @@ -165,7 +183,7 @@ pub(crate) fn run(args: &Args) -> Result<()> { // change, we can just link to the last version rather than making a copy. // (Note: we link to the fully resolved directory, args.datastore_path, so we don't // have a chain of symlinks that could go past the maximum depth.) - flip_to_new_version(&args.migrate_to_version, &args.datastore_path)?; + flip_to_new_version(&args.migrate_to_version, &args.datastore_path).await?; } else { let copy_path = run_migrations( &repo, @@ -173,8 +191,9 @@ pub(crate) fn run(args: &Args) -> Result<()> { &migrations, &args.datastore_path, &args.migrate_to_version, - )?; - flip_to_new_version(&args.migrate_to_version, copy_path)?; + ) + .await?; + flip_to_new_version(&args.migrate_to_version, copy_path).await?; } Ok(()) } @@ -220,7 +239,7 @@ where /// /// The given data store is used as a starting point; each migration is given the output of the /// previous migration, and the final output becomes the new data store. -fn run_migrations( +async fn run_migrations( repository: &tough::Repository, direction: Direction, migrations: &[S], @@ -249,44 +268,58 @@ where .context(error::TargetNameSnafu { target: migration })?; // get the migration from the repo - let lz4_bytes = repository + let lz4_byte_stream = repository .read_target(&migration) + .await .context(error::LoadMigrationSnafu { migration: migration.raw(), })? .context(error::MigrationNotFoundSnafu { migration: migration.raw(), - })?; + })? + .map(|entry| { + let annotated: std::result::Result = entry; + annotated.map_err(|tough_error| std::io::Error::new(ErrorKind::Other, tough_error)) + }); + + // Convert the stream to a blocking Read object. + let lz4_async_read = lz4_byte_stream.into_async_read().compat(); + let lz4_bytes = SyncIoBridge::new(lz4_async_read); // Add an LZ4 decoder so the bytes will be deflated on read let mut reader = lz4::Decoder::new(lz4_bytes).context(error::Lz4DecodeSnafu { migration: migration.raw(), })?; - // Create a sealed command with pentacle, so we can run the verified bytes from memory - let mut command = - pentacle::SealedCommand::new(&mut reader).context(error::SealMigrationSnafu)?; - - // Point each migration in the right direction, and at the given data store. - command.arg(direction.to_string()); - command.args(&[ + let mut command_args = vec![ + direction.to_string(), "--source-datastore".to_string(), source_datastore.display().to_string(), - ]); + ]; // Create a new output location for this migration. target_datastore = new_datastore_location(source_datastore, new_version)?; - command.args(&[ - "--target-datastore".to_string(), - target_datastore.display().to_string(), - ]); + command_args.push("--target-datastore".to_string()); + command_args.push(target_datastore.display().to_string()); info!("Running migration '{}'", migration.raw()); - debug!("Migration command: {:?}", command); - let output = command.output().context(error::StartMigrationSnafu)?; + // Run this blocking IO in a thread so it doesn't block the scheduler. + let rt = Handle::current(); + let task = rt.spawn_blocking(move || { + // Create a sealed command with pentacle, so we can run the verified bytes from memory + let mut command = + pentacle::SealedCommand::new(&mut reader).context(error::SealMigrationSnafu)?; + command.args(command_args); + + debug!("Migration command: {:?}", command); + let output = command.output().context(error::StartMigrationSnafu)?; + Ok(output) + }); + + let output = task.await.expect("TODO - snafu error for this")?; if !output.stdout.is_empty() { debug!( "Migration stdout: {}", @@ -310,7 +343,7 @@ where // If an intermediate datastore exists from a previous loop, delete it. if let Some(path) = &intermediate_datastore { - delete_intermediate_datastore(path); + delete_intermediate_datastore(path).await; } // Remember the location of the target_datastore to delete it in the next loop iteration @@ -323,11 +356,11 @@ where } // Try to delete an intermediate datastore if it exists. If it fails to delete, print an error. -fn delete_intermediate_datastore(path: &PathBuf) { +async fn delete_intermediate_datastore(path: &PathBuf) { // Even if we fail to remove an intermediate data store, we don't want to fail the upgrade - // just let someone know for later cleanup. trace!("Removing intermediate data store at {}", path.display()); - if let Err(e) = fs::remove_dir_all(path) { + if let Err(e) = fs::remove_dir_all(path).await { error!( "Failed to remove intermediate data store at '{}': {}", path.display(), @@ -344,7 +377,7 @@ fn delete_intermediate_datastore(path: &PathBuf) { /// * pointing the major version to the minor version /// * pointing the 'current' link to the major version /// * fsyncing the directory to disk -fn flip_to_new_version

(version: &Version, to_datastore: P) -> Result<()> +async fn flip_to_new_version

(version: &Version, to_datastore: P) -> Result<()> where P: AsRef, { @@ -426,9 +459,11 @@ where symlink(to_target, &temp_link).context(error::LinkCreateSnafu { path: &temp_link })?; // Atomically swap the link into place, so that the patch version link points to the new data // store copy. - fs::rename(&temp_link, &patch_version_link).context(error::LinkSwapSnafu { - link: &patch_version_link, - })?; + fs::rename(&temp_link, &patch_version_link) + .await + .context(error::LinkSwapSnafu { + link: &patch_version_link, + })?; // =^..^= =^..^= =^..^= =^..^= @@ -443,9 +478,11 @@ where symlink(patch_target, &temp_link).context(error::LinkCreateSnafu { path: &temp_link })?; // Atomically swap the link into place, so that the minor version link points to the new patch // version. - fs::rename(&temp_link, &minor_version_link).context(error::LinkSwapSnafu { - link: &minor_version_link, - })?; + fs::rename(&temp_link, &minor_version_link) + .await + .context(error::LinkSwapSnafu { + link: &minor_version_link, + })?; // =^..^= =^..^= =^..^= =^..^= @@ -460,9 +497,11 @@ where symlink(minor_target, &temp_link).context(error::LinkCreateSnafu { path: &temp_link })?; // Atomically swap the link into place, so that the major version link points to the new minor // version. - fs::rename(&temp_link, &major_version_link).context(error::LinkSwapSnafu { - link: &major_version_link, - })?; + fs::rename(&temp_link, &major_version_link) + .await + .context(error::LinkSwapSnafu { + link: &major_version_link, + })?; // =^..^= =^..^= =^..^= =^..^= @@ -476,9 +515,11 @@ where // This will point at, for example, /path/to/datastore/v1 symlink(major_target, &temp_link).context(error::LinkCreateSnafu { path: &temp_link })?; // Atomically swap the link into place, so that 'current' points to the new major version. - fs::rename(&temp_link, ¤t_version_link).context(error::LinkSwapSnafu { - link: ¤t_version_link, - })?; + fs::rename(&temp_link, ¤t_version_link) + .await + .context(error::LinkSwapSnafu { + link: ¤t_version_link, + })?; // =^..^= =^..^= =^..^= =^..^= @@ -496,16 +537,29 @@ where Ok(()) } -fn load_manifest(repository: &tough::Repository) -> Result { +async fn load_manifest(repository: tough::Repository) -> Result { let target = "manifest.json"; let target = target .try_into() .context(error::TargetNameSnafu { target })?; - Manifest::from_json( - repository - .read_target(&target) - .context(error::ManifestLoadSnafu)? - .context(error::ManifestNotFoundSnafu)?, - ) - .context(error::ManifestParseSnafu) + + let stream = repository + .read_target(&target) + .await + .context(error::ManifestLoadSnafu)? + .context(error::ManifestNotFoundSnafu)? + .map(|entry| { + let annotated: std::result::Result = entry; + annotated.map_err(|tough_error| std::io::Error::new(ErrorKind::Other, tough_error)) + }); + + // Convert the stream to a blocking Read object. + let async_read = stream.into_async_read().compat(); + let reader = SyncIoBridge::new(async_read); + + // Run this blocking Read object in a thread so it doesn't block the scheduler. + let rt = Handle::current(); + let task = + rt.spawn_blocking(move || Manifest::from_json(reader).context(error::ManifestParseSnafu)); + task.await.expect("TODO - create snafu join handle error") } diff --git a/sources/api/migration/migrator/src/test.rs b/sources/api/migration/migrator/src/test.rs index 366c4130bc7..1fceae28b96 100644 --- a/sources/api/migration/migrator/src/test.rs +++ b/sources/api/migration/migrator/src/test.rs @@ -4,10 +4,10 @@ use crate::args::Args; use crate::run; use chrono::{DateTime, Utc}; use semver::Version; -use std::fs::{self, File}; use std::io::Write; use std::path::{Path, PathBuf}; use tempfile::TempDir; +use tokio::fs; /// Provides the path to a folder where test data files reside. fn test_data() -> PathBuf { @@ -115,7 +115,8 @@ struct TestRepo { /// LZ4 compresses `source` bytes to a new file at `destination`. fn compress(source: &[u8], destination: &Path) { - let output_file = File::create(destination).unwrap(); + // It is easier to use blocking IO here and in test code is fine as long as it works. + let output_file = std::fs::File::create(destination).unwrap(); let mut encoder = lz4::EncoderBuilder::new() .level(4) .build(output_file) @@ -127,7 +128,7 @@ fn compress(source: &[u8], destination: &Path) { /// Creates a test repository with a couple of versions defined in the manifest and a couple of /// migrations. See the test description for for more info. -fn create_test_repo(test_type: TestType) -> TestRepo { +async fn create_test_repo(test_type: TestType) -> TestRepo { // This is where the signed TUF repo will exist when we are done. It is the // root directory of the `TestRepo` we will return when we are done. let test_repo_dir = TempDir::new().unwrap(); @@ -165,7 +166,7 @@ fn create_test_repo(test_type: TestType) -> TestRepo { } // Create and sign the TUF repository. - let mut editor = tough::editor::RepositoryEditor::new(root()).unwrap(); + let mut editor = tough::editor::RepositoryEditor::new(root()).await.unwrap(); let long_ago: DateTime = DateTime::parse_from_rfc3339("1970-01-01T00:00:00Z") .unwrap() .into(); @@ -180,16 +181,17 @@ fn create_test_repo(test_type: TestType) -> TestRepo { .timestamp_version(one) .timestamp_expires(long_ago); - list_dir_files(tuf_indir).into_iter().for_each(|path| { + for path in list_dir_files(tuf_indir).await { editor .add_target( path.file_name().unwrap().to_str().unwrap(), - tough::schema::Target::from_path(&path).unwrap(), + tough::schema::Target::from_path(&path).await.unwrap(), ) .unwrap(); - }); + } let signed_repo = editor .sign(&[Box::new(tough::key_source::LocalKeySource { path: pem() })]) + .await .unwrap(); signed_repo .link_targets( @@ -197,8 +199,9 @@ fn create_test_repo(test_type: TestType) -> TestRepo { &targets_path, tough::editor::signed::PathExists::Fail, ) + .await .unwrap(); - signed_repo.write(&metadata_path).unwrap(); + signed_repo.write(&metadata_path).await.unwrap(); TestRepo { _tuf_dir: test_repo_dir, @@ -209,12 +212,12 @@ fn create_test_repo(test_type: TestType) -> TestRepo { /// Asserts that the expected directories and files are in the datastore directory after a /// failed migration. Returns the absolute path that the `current` symlink is pointing to. -fn assert_directory_structure_with_failed_migration( +async fn assert_directory_structure_with_failed_migration( dir: &Path, from: &Version, to: &Version, ) -> PathBuf { - let paths = list_dir_entries(dir); + let paths = list_dir_entries(dir).await; let from_ver = format!("v{}", from); let from_ver_unique_prefix = format!("v{}_", from); let to_ver_unique_prefix = format!("v{}_", to); @@ -259,8 +262,8 @@ fn assert_directory_structure_with_failed_migration( /// Asserts that the expected directories and files are in the datastore directory after a /// successful migration. Returns the absolute path that the `current` symlink is pointing to. -fn assert_directory_structure(dir: &Path) -> PathBuf { - let paths = list_dir_entries(dir); +async fn assert_directory_structure(dir: &Path) -> PathBuf { + let paths = list_dir_entries(dir).await; assert_eq!(paths.len(), 8); assert_dir_entry_exists(&paths, "current"); assert_dir_entry_exists(&paths, "result.txt"); @@ -301,18 +304,18 @@ fn assert_dir_starting_with_exists(dir_entries: &[PathBuf], starts_with: &str) { ); } -fn list_dir_entries(dir: impl AsRef) -> Vec { - fs::read_dir(dir) - .unwrap() - .map(|dir_entry_result| { - let dir_entry = dir_entry_result.unwrap(); - dir_entry.path() - }) - .collect() +async fn list_dir_entries(dir: impl AsRef) -> Vec { + let mut paths = Vec::new(); + let mut read_dir = fs::read_dir(dir).await.unwrap(); + while let Some(entry) = read_dir.next_entry().await.unwrap() { + paths.push(entry.path()) + } + paths } -fn list_dir_files(dir: impl AsRef) -> Vec { +async fn list_dir_files(dir: impl AsRef) -> Vec { list_dir_entries(dir) + .await .into_iter() .filter(|path| path.is_file()) .collect() @@ -332,12 +335,12 @@ fn list_dir_files(dir: impl AsRef) -> Vec { /// (i.e. since migrations run in the context of the datastore directory, `result.txt` is /// written one directory above the datastore.) We can then inspect the contents of `result.txt` /// to see that the expected migrations ran in the correct order. -#[test] -fn migrate_forward() { +#[tokio::test] +async fn migrate_forward() { let from_version = Version::parse("0.99.0").unwrap(); let to_version = Version::parse("0.99.1").unwrap(); let test_datastore = TestDatastore::new(from_version); - let test_repo = create_test_repo(TestType::Success); + let test_repo = create_test_repo(TestType::Success).await; let args = Args { datastore_path: test_datastore.datastore.clone(), log_level: log::LevelFilter::Info, @@ -346,7 +349,7 @@ fn migrate_forward() { root_path: root(), metadata_directory: test_repo.metadata_path.clone(), }; - run(&args).unwrap(); + run(&args).await.unwrap(); // the migrations should write to a file named result.txt. let output_file = test_datastore.tmp.path().join("result.txt"); let contents = std::fs::read_to_string(&output_file).unwrap(); @@ -366,7 +369,7 @@ fn migrate_forward() { assert_eq!(got, want); // Check the directory. - let current = assert_directory_structure(test_datastore.tmp.path()); + let current = assert_directory_structure(test_datastore.tmp.path()).await; // We have successfully migrated so current should be pointing to a directory that starts with // v0.99.1. @@ -380,12 +383,12 @@ fn migrate_forward() { /// This test ensures that migrations run when migrating from a newer to an older version. /// See `migrate_forward` for a description of how these tests work. -#[test] -fn migrate_backward() { +#[tokio::test] +async fn migrate_backward() { let from_version = Version::parse("0.99.1").unwrap(); let to_version = Version::parse("0.99.0").unwrap(); let test_datastore = TestDatastore::new(from_version); - let test_repo = create_test_repo(TestType::Success); + let test_repo = create_test_repo(TestType::Success).await; let args = Args { datastore_path: test_datastore.datastore.clone(), log_level: log::LevelFilter::Info, @@ -394,7 +397,7 @@ fn migrate_backward() { root_path: root(), metadata_directory: test_repo.metadata_path.clone(), }; - run(&args).unwrap(); + run(&args).await.unwrap(); let output_file = test_datastore.tmp.path().join("result.txt"); let contents = std::fs::read_to_string(&output_file).unwrap(); let lines: Vec<&str> = contents.split('\n').collect(); @@ -413,7 +416,7 @@ fn migrate_backward() { assert_eq!(got, want); // Check the directory. - let current = assert_directory_structure(test_datastore.tmp.path()); + let current = assert_directory_structure(test_datastore.tmp.path()).await; // We have successfully migrated so current should be pointing to a directory that starts with // v0.99.0. @@ -425,12 +428,12 @@ fn migrate_backward() { .starts_with("v0.99.0")); } -#[test] -fn migrate_forward_with_failed_migration() { +#[tokio::test] +async fn migrate_forward_with_failed_migration() { let from_version = Version::parse("0.99.0").unwrap(); let to_version = Version::parse("0.99.1").unwrap(); let test_datastore = TestDatastore::new(from_version.clone()); - let test_repo = create_test_repo(TestType::ForwardFailure); + let test_repo = create_test_repo(TestType::ForwardFailure).await; let args = Args { datastore_path: test_datastore.datastore.clone(), log_level: log::LevelFilter::Info, @@ -439,7 +442,7 @@ fn migrate_forward_with_failed_migration() { root_path: root(), metadata_directory: test_repo.metadata_path.clone(), }; - let result = run(&args); + let result = run(&args).await; assert!(result.is_err()); // the migrations should write to a file named result.txt. @@ -465,7 +468,8 @@ fn migrate_forward_with_failed_migration() { test_datastore.tmp.path(), &from_version, &to_version, - ); + ) + .await; // We have not successfully migrated to v0.99.1 so we should still be pointing at the "from" // version. @@ -477,12 +481,12 @@ fn migrate_forward_with_failed_migration() { .starts_with("v0.99.0")); } -#[test] -fn migrate_backward_with_failed_migration() { +#[tokio::test] +async fn migrate_backward_with_failed_migration() { let from_version = Version::parse("0.99.1").unwrap(); let to_version = Version::parse("0.99.0").unwrap(); let test_datastore = TestDatastore::new(from_version.clone()); - let test_repo = create_test_repo(TestType::BackwardFailure); + let test_repo = create_test_repo(TestType::BackwardFailure).await; let args = Args { datastore_path: test_datastore.datastore.clone(), log_level: log::LevelFilter::Info, @@ -491,7 +495,7 @@ fn migrate_backward_with_failed_migration() { root_path: root(), metadata_directory: test_repo.metadata_path.clone(), }; - let result = run(&args); + let result = run(&args).await; assert!(result.is_err()); let output_file = test_datastore.tmp.path().join("result.txt"); @@ -516,7 +520,8 @@ fn migrate_backward_with_failed_migration() { test_datastore.tmp.path(), &from_version, &to_version, - ); + ) + .await; // We have not successfully migrated to v0.99.0 so we should still be pointing at the "from" // version. From 3fae5483acefa97c3fa89e5b510853966f63324c Mon Sep 17 00:00:00 2001 From: Matthew James Briggs Date: Thu, 2 Nov 2023 14:21:51 -0700 Subject: [PATCH 3/3] updog: update to tough async --- sources/Cargo.lock | 66 ++-------- sources/updater/updog/Cargo.toml | 8 +- sources/updater/updog/src/main.rs | 161 +++++++++++++++---------- sources/updater/updog/src/transport.rs | 35 ++++-- 4 files changed, 141 insertions(+), 129 deletions(-) diff --git a/sources/Cargo.lock b/sources/Cargo.lock index 135dde924ea..4138810feba 100644 --- a/sources/Cargo.lock +++ b/sources/Cargo.lock @@ -2666,7 +2666,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util", - "tough 0.15.0", + "tough", "update_metadata", "url", ] @@ -3045,33 +3045,6 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" -[[package]] -name = "path-absolutize" -version = "3.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4af381fe79fa195b4909485d99f73a80792331df0625188e707854f0b3383f5" -dependencies = [ - "path-dedot", -] - -[[package]] -name = "path-dedot" -version = "3.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07ba0ad7e047712414213ff67533e6dd477af0a4e1d14fb52343e53d30ea9397" -dependencies = [ - "once_cell", -] - -[[package]] -name = "pem" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" -dependencies = [ - "base64 0.13.1", -] - [[package]] name = "pem" version = "3.0.2" @@ -4409,33 +4382,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "tough" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eda3efa9005cf9c1966984c3b9a44c3f37b7ed2c95ba338d6ad51bba70e989a0" -dependencies = [ - "chrono", - "dyn-clone", - "globset", - "hex", - "log", - "olpc-cjson", - "path-absolutize", - "pem 1.1.1", - "percent-encoding", - "reqwest", - "ring", - "serde", - "serde_json", - "serde_plain", - "snafu", - "tempfile", - "untrusted", - "url", - "walkdir", -] - [[package]] name = "tough" version = "0.15.0" @@ -4453,7 +4399,7 @@ dependencies = [ "hex", "log", "olpc-cjson", - "pem 3.0.2", + "pem", "percent-encoding", "reqwest", "ring", @@ -4638,8 +4584,12 @@ name = "updog" version = "0.1.0" dependencies = [ "argh", + "async-trait", "bottlerocket-release", + "bytes", "chrono", + "futures", + "futures-core", "log", "lz4", "models", @@ -4652,8 +4602,10 @@ dependencies = [ "simplelog", "snafu", "tempfile", + "tokio", + "tokio-util", "toml 0.5.11", - "tough 0.14.0", + "tough", "update_metadata", "url", ] diff --git a/sources/updater/updog/Cargo.toml b/sources/updater/updog/Cargo.toml index 9b9a966dc6a..98b938cc45e 100644 --- a/sources/updater/updog/Cargo.toml +++ b/sources/updater/updog/Cargo.toml @@ -9,9 +9,13 @@ publish = false exclude = ["README.md"] [dependencies] +async-trait = "0.1" argh = "0.1" bottlerocket-release = { path = "../../bottlerocket-release", version = "0.1" } +bytes = "1" chrono = { version = "0.4", default-features = false, features = ["std", "clock"] } +futures = "0.3" +futures-core = "0.3" log = "0.4" lz4 = "1" semver = "1" @@ -21,8 +25,10 @@ serde_plain = "1" signpost = { path = "../signpost", version = "0.1" } simplelog = "0.12" snafu = "0.7" +tokio = { version = "~1.32", default-features = false, features = ["fs", "macros", "process", "rt-multi-thread"] } # LTS +tokio-util = { version = "0.7", features = ["compat", "io-util"] } toml = "0.5" -tough = { version = "0.14", features = ["http"] } +tough = { version = "0.15", features = ["http"] } update_metadata = { path = "../update_metadata", version = "0.1" } url = "2" signal-hook = "0.3" diff --git a/sources/updater/updog/src/main.rs b/sources/updater/updog/src/main.rs index f76f63dd6e9..dbbc291ceae 100644 --- a/sources/updater/updog/src/main.rs +++ b/sources/updater/updog/src/main.rs @@ -4,7 +4,7 @@ mod error; mod transport; use crate::error::Result; -use crate::transport::{HttpQueryTransport, QueryParams}; +use crate::transport::{reader_from_stream, HttpQueryTransport, QueryParams}; use bottlerocket_release::BottlerocketRelease; use chrono::Utc; use log::debug; @@ -17,12 +17,12 @@ use signpost::State; use simplelog::{Config as LogConfig, LevelFilter, SimpleLogger}; use snafu::{ErrorCompat, OptionExt, ResultExt}; use std::convert::{TryFrom, TryInto}; -use std::fs::{self, File, OpenOptions}; -use std::io; use std::path::Path; -use std::process; +use std::process::ExitCode; use std::str::FromStr; use std::thread; +use tokio::runtime::Handle; +use tokio::{fs, process}; use tough::{Repository, RepositoryLoader}; use update_metadata::{find_migrations, Manifest, Update}; use url::Url; @@ -108,21 +108,29 @@ GLOBAL OPTIONS: std::process::exit(1) } -fn load_config() -> Result { +async fn load_config() -> Result { let path = "/etc/updog.toml"; - let s = fs::read_to_string(path).context(error::ConfigReadSnafu { path })?; + let s = fs::read_to_string(path) + .await + .context(error::ConfigReadSnafu { path })?; let config: Config = toml::from_str(&s).context(error::ConfigParseSnafu { path })?; Ok(config) } -fn load_repository(transport: HttpQueryTransport, config: &Config) -> Result { - fs::create_dir_all(METADATA_PATH).context(error::CreateMetadataCacheSnafu { - path: METADATA_PATH, - })?; - RepositoryLoader::new( - File::open(TRUSTED_ROOT_PATH).context(error::OpenRootSnafu { +async fn load_repository(transport: HttpQueryTransport, config: &Config) -> Result { + fs::create_dir_all(METADATA_PATH) + .await + .context(error::CreateMetadataCacheSnafu { + path: METADATA_PATH, + })?; + let root_bytes = fs::read(TRUSTED_ROOT_PATH) + .await + .context(error::OpenRootSnafu { path: TRUSTED_ROOT_PATH, - })?, + })?; + + RepositoryLoader::new( + &root_bytes, Url::parse(&config.metadata_base_url).context(error::UrlParseSnafu { url: &config.metadata_base_url, })?, @@ -132,6 +140,7 @@ fn load_repository(transport: HttpQueryTransport, config: &Config) -> Result( Ok(None) } -fn write_target_to_disk>( +async fn write_target_to_disk>( repository: &Repository, target: &str, disk_path: P, @@ -218,31 +227,38 @@ fn write_target_to_disk>( let target = target .try_into() .context(error::TargetNameSnafu { target })?; - let reader = repository + let stream = repository .read_target(&target) + .await .context(error::MetadataSnafu)? .context(error::TargetNotFoundSnafu { target: target.raw(), })?; - // Note: the file extension for the compression type we're using should be removed in - // retrieve_migrations below. - let mut reader = lz4::Decoder::new(reader).context(error::Lz4DecodeSnafu { - target: target.raw(), - })?; - let mut f = OpenOptions::new() - .write(true) - .create(true) - .open(disk_path.as_ref()) - .context(error::OpenPartitionSnafu { - path: disk_path.as_ref(), + let reader = reader_from_stream(stream); + + // Run blocking IO without blocking the scheduler. + let disk_path = disk_path.as_ref().to_path_buf(); + let rt = Handle::current(); + let task = rt.spawn_blocking(move || { + // Note: the file extension for the compression type we're using should be removed in + // retrieve_migrations below. + let mut reader = lz4::Decoder::new(reader).context(error::Lz4DecodeSnafu { + target: target.raw(), })?; - io::copy(&mut reader, &mut f).context(error::WriteUpdateSnafu)?; - Ok(()) + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .open(&disk_path) + .context(error::OpenPartitionSnafu { path: disk_path })?; + std::io::copy(&mut reader, &mut f).context(error::WriteUpdateSnafu)?; + Ok(()) + }); + task.await.expect("TODO - snafu error for this") } /// Store required migrations for an update in persistent storage. All intermediate migrations /// between the current version and the target version must be retrieved. -fn retrieve_migrations( +async fn retrieve_migrations( repository: &Repository, query_params: &mut QueryParams, manifest: &Manifest, @@ -257,7 +273,9 @@ fn retrieve_migrations( let dir = Path::new(MIGRATION_PATH); if !dir.exists() { - fs::create_dir(dir).context(error::DirCreateSnafu { path: &dir })?; + fs::create_dir(dir) + .await + .context(error::DirCreateSnafu { path: &dir })?; } // find the list of migrations in the manifest based on our from and to versions. @@ -268,13 +286,14 @@ fn retrieve_migrations( targets.push("manifest.json".to_owned()); repository .cache(METADATA_PATH, MIGRATION_PATH, Some(&targets), true) + .await .context(error::RepoCacheMigrationsSnafu)?; // Set a query parameter listing the required migrations query_params.add("migrations", targets.join(",")); Ok(()) } -fn update_image(update: &Update, repository: &Repository) -> Result<()> { +async fn update_image(update: &Update, repository: &Repository) -> Result<()> { let mut gpt_state = State::load().context(error::PartitionTableReadSnafu)?; gpt_state.clear_inactive(); // Write out the clearing of the inactive partition immediately, because we're about to @@ -285,9 +304,9 @@ fn update_image(update: &Update, repository: &Repository) -> Result<()> { let inactive = gpt_state.inactive_set(); // TODO Do we want to recover the inactive side on an error? - write_target_to_disk(repository, &update.images.root, &inactive.root)?; - write_target_to_disk(repository, &update.images.boot, &inactive.boot)?; - write_target_to_disk(repository, &update.images.hash, &inactive.hash)?; + write_target_to_disk(repository, &update.images.root, &inactive.root).await?; + write_target_to_disk(repository, &update.images.boot, &inactive.boot).await?; + write_target_to_disk(repository, &update.images.hash, &inactive.hash).await?; gpt_state.mark_inactive_valid(); gpt_state.write().context(error::PartitionTableWriteSnafu)?; @@ -441,11 +460,11 @@ fn output(json: bool, object: T, string: &str) -> Result<()> { Ok(()) } -fn initiate_reboot() -> Result<()> { +async fn initiate_reboot() -> Result<()> { // Set up signal handler for termination signals let mut signals = Signals::new([SIGTERM]).context(error::SignalSnafu)?; let signals_handle = signals.handle(); - thread::spawn(move || { + let _ = thread::spawn(move || { for _sig in signals.forever() { // Ignore termination signals in case updog gets terminated // before getting to exit normally by itself after invoking @@ -455,6 +474,7 @@ fn initiate_reboot() -> Result<()> { if let Err(err) = process::Command::new("shutdown") .arg("-r") .status() + .await .context(error::RebootFailureSnafu) { // Kill the signal handling thread @@ -489,7 +509,7 @@ fn set_https_proxy_environment_variables( } #[allow(clippy::too_many_lines)] -fn main_inner() -> Result<()> { +async fn main_inner() -> Result<()> { // Parse and store the arguments passed to the program let arguments = parse_args(std::env::args()); @@ -499,7 +519,7 @@ fn main_inner() -> Result<()> { let command = serde_plain::from_str::(&arguments.subcommand).unwrap_or_else(|_| usage()); - let config = load_config()?; + let config = load_config().await?; set_https_proxy_environment_variables(&config.https_proxy, &config.no_proxy); let current_release = BottlerocketRelease::new().context(error::ReleaseVersionSnafu)?; let variant = arguments.variant.unwrap_or(current_release.variant_id); @@ -508,8 +528,8 @@ fn main_inner() -> Result<()> { // the transport's HTTP calls. let mut query_params = transport.query_params(); set_common_query_params(&mut query_params, ¤t_release.version_id, &config); - let repository = load_repository(transport, &config)?; - let manifest = load_manifest(&repository)?; + let repository = load_repository(transport, &config).await?; + let manifest = load_manifest(&repository).await?; let ignore_waves = arguments.ignore_waves || config.ignore_waves; match command { Command::CheckUpdate | Command::Whats => { @@ -554,12 +574,13 @@ fn main_inner() -> Result<()> { &manifest, u, ¤t_release.version_id, - )?; - update_image(u, &repository)?; + ) + .await?; + update_image(u, &repository).await?; if command == Command::Update { update_flags()?; if arguments.reboot { - initiate_reboot()?; + initiate_reboot().await?; } } output( @@ -574,7 +595,7 @@ fn main_inner() -> Result<()> { Command::UpdateApply => { update_flags()?; if arguments.reboot { - initiate_reboot()?; + initiate_reboot().await?; } } Command::UpdateRevert => { @@ -588,23 +609,28 @@ fn main_inner() -> Result<()> { Ok(()) } -fn load_manifest(repository: &tough::Repository) -> Result { +async fn load_manifest(repository: &tough::Repository) -> Result { let target = "manifest.json"; let target = target .try_into() .context(error::TargetNameSnafu { target })?; - Manifest::from_json( - repository - .read_target(&target) - .context(error::ManifestLoadSnafu)? - .context(error::ManifestNotFoundSnafu)?, - ) - .context(error::ManifestParseSnafu) + let stream = repository + .read_target(&target) + .await + .context(error::ManifestLoadSnafu)? + .context(error::ManifestNotFoundSnafu)?; + let reader = reader_from_stream(stream); + + // Run blocking IO on a thread. + let rt = Handle::current(); + let task = rt.spawn_blocking(|| Manifest::from_json(reader).context(error::ManifestParseSnafu)); + task.await.expect("TODO - snafu error for this") } -fn main() -> ! { - std::process::exit(match main_inner() { - Ok(()) => 0, +#[tokio::main] +async fn main() -> ExitCode { + match main_inner().await { + Ok(()) => ExitCode::SUCCESS, Err(err) => { eprintln!("{err}"); if let Some(var) = std::env::var_os("RUST_BACKTRACE") { @@ -614,9 +640,9 @@ fn main() -> ! { } } } - 1 + ExitCode::from(1) } - }) + } } #[cfg(test)] @@ -635,7 +661,8 @@ mod tests { // - the image:datastore mappings exist // - there is a mapping between 1.11.0 and 1.0 let path = "tests/data/example.json"; - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); assert!( !manifest.updates.is_empty(), "Failed to parse update manifest" @@ -659,7 +686,8 @@ mod tests { // A basic manifest with a single update, no migrations, and two // image:datastore mappings let path = "tests/data/example_2.json"; - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); assert!(!manifest.updates.is_empty()); } @@ -670,7 +698,8 @@ mod tests { // - version: 1.25.0 // - max_version: 1.20.0 let path = "tests/data/regret.json"; - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); let config = Config { metadata_base_url: String::from("foo"), targets_base_url: String::from("bar"), @@ -704,7 +733,8 @@ mod tests { // A manifest with two updates, both less than 0.1.3. // Use a architecture specific JSON payload, otherwise updog will ignore the update let path = format!("tests/data/example_3_{TARGET_ARCH}.json"); - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); let config = Config { metadata_base_url: String::from("foo"), targets_base_url: String::from("bar"), @@ -742,7 +772,8 @@ mod tests { // upgrading from the version 1.10.0 results in updating to 1.15.0 // instead of 1.13.0 (lower), 1.25.0 (too high), or 1.16.0 (wrong arch). let path = format!("tests/data/multiple_{TARGET_ARCH}.json"); - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); let config = Config { metadata_base_url: String::from("foo"), targets_base_url: String::from("bar"), @@ -784,7 +815,8 @@ mod tests { // a downgrade to 1.13.0, instead of 1.15.0 like it would be in the // above test, test_multiple. let path = format!("tests/data/multiple_{TARGET_ARCH}.json"); - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); let config = Config { metadata_base_url: String::from("foo"), targets_base_url: String::from("bar"), @@ -841,7 +873,8 @@ mod tests { fn serialize_metadata() { // A basic manifest with a single update let path = "tests/data/example_2.json"; - let manifest: Manifest = serde_json::from_reader(File::open(path).unwrap()).unwrap(); + let manifest: Manifest = + serde_json::from_reader(std::fs::File::open(path).unwrap()).unwrap(); assert!(serde_json::to_string_pretty(&manifest) .context(error::UpdateSerializeSnafu) .is_ok()); diff --git a/sources/updater/updog/src/transport.rs b/sources/updater/updog/src/transport.rs index 1edbfa99251..9c7c9857628 100644 --- a/sources/updater/updog/src/transport.rs +++ b/sources/updater/updog/src/transport.rs @@ -1,6 +1,14 @@ -use std::sync::{Arc, RwLock}; - +use async_trait::async_trait; +use bytes::Bytes; +use futures::StreamExt; +use futures::TryStreamExt; +use futures_core::Stream; use log::error; +use std::io::{ErrorKind, Read}; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tokio_util::io::SyncIoBridge; use tough::{HttpTransport, Transport, TransportError}; use url::Url; @@ -67,11 +75,24 @@ impl HttpQueryTransport { } } +pub(crate) type TransportStream = Pin> + Send>>; + +#[async_trait] impl Transport for HttpQueryTransport { - fn fetch( - &self, - url: Url, - ) -> std::result::Result, TransportError> { - self.inner.fetch(self.parameters.add_params_to_url(url)) + /// Send a GET request to the URL. The returned `TransportStream` will retry as necessary per + /// the `ClientSettings`. + async fn fetch(&self, url: Url) -> Result { + self.inner + .fetch(self.parameters.add_params_to_url(url)) + .await } } + +pub(crate) fn reader_from_stream(stream: S) -> impl Read +where + S: Stream> + Send + Unpin, +{ + let mapped_err = stream.map(|next| next.map_err(|e| std::io::Error::new(ErrorKind::Other, e))); + let async_read = mapped_err.into_async_read().compat(); + SyncIoBridge::new(async_read) +}