diff options
-rw-r--r-- | src/commands/source/download.rs | 262 | ||||
-rw-r--r-- | src/commands/source/mod.rs (renamed from src/commands/source.rs) | 161 | ||||
-rw-r--r-- | src/config/not_validated.rs | 1 | ||||
-rw-r--r-- | src/main.rs | 1 | ||||
-rw-r--r-- | src/util/progress.rs | 14 |
5 files changed, 269 insertions, 170 deletions
diff --git a/src/commands/source/download.rs b/src/commands/source/download.rs new file mode 100644 index 0000000..11843fa --- /dev/null +++ b/src/commands/source/download.rs @@ -0,0 +1,262 @@ +// +// Copyright (c) 2020-2021 science+computing ag and other contributors +// +// This program and the accompanying materials are made +// available under the terms of the Eclipse Public License 2.0 +// which is available at https://www.eclipse.org/legal/epl-2.0/ +// +// SPDX-License-Identifier: EPL-2.0 +// + +use std::convert::TryFrom; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::Context; +use anyhow::Error; +use anyhow::Result; +use anyhow::anyhow; +use clap::ArgMatches; +use log::{debug, trace}; +use tokio::io::AsyncWriteExt; +use tokio::sync::Mutex; +use tokio_stream::StreamExt; + +use crate::config::*; +use crate::package::PackageName; +use crate::package::PackageVersionConstraint; +use crate::repository::Repository; +use crate::source::*; +use crate::util::progress::ProgressBars; + +const NUMBER_OF_MAX_CONCURRENT_DOWNLOADS: usize = 100; + +/// A wrapper around the indicatif::ProgressBar +/// +/// A wrapper around the indicatif::ProgressBar that is used to synchronize status information from +/// the individual download jobs to the progress bar that is used to display download progress to +/// the user. +/// +/// The problem this helper solves is that we only have one status bar for all downloads, and all +/// download tasks must be able to increase the number of bytes received, for example, (that is +/// displayed in the status message) but in a sync way. +#[derive(Clone)] +struct ProgressWrapper { + download_count: u64, + finished_downloads: u64, + current_bytes: usize, + sum_bytes: u64, + bar: Arc<Mutex<indicatif::ProgressBar>>, +} + +impl ProgressWrapper { + fn new(bar: indicatif::ProgressBar) -> Self { + Self { + download_count: 0, + finished_downloads: 0, + current_bytes: 0, + sum_bytes: 0, + bar: Arc::new(Mutex::new(bar)) + } + } + + async fn inc_download_count(&mut self) { + self.download_count += 1; + self.set_message().await; + let bar = self.bar.lock().await; + bar.set_length(bar.length() + 1); + } + + async fn inc_download_bytes(&mut self, bytes: u64) { + self.sum_bytes += bytes; + self.set_message().await; + } + + async fn finish_one_download(&mut self) { + self.finished_downloads += 1; + self.bar.lock().await.inc(1); + self.set_message().await; + } + + async fn add_bytes(&mut self, len: usize) { + self.current_bytes += len; + self.set_message().await; + } + + async fn set_message(&self) { + let bar = self.bar.lock().await; + bar.set_message(format!("Downloading ({current_bytes}/{sum_bytes} bytes, {dlfinished}/{dlsum} downloads finished)", + current_bytes = self.current_bytes, + sum_bytes = self.sum_bytes, + dlfinished = self.finished_downloads, + dlsum = self.download_count)); + } + + async fn success(&self) { + let bar = self.bar.lock().await; + bar.finish_with_message(format!("Succeeded {}/{} downloads", self.finished_downloads, self.download_count)); + } + + async fn error(&self) { + let bar = self.bar.lock().await; + bar.finish_with_message(format!("At least one download of {} failed", self.download_count)); + } +} + +async fn perform_download(source: &SourceEntry, progress: Arc<Mutex<ProgressWrapper>>, timeout: Option<u64>) -> Result<()> { + trace!("Creating: {:?}", source); + let file = source.create().await.with_context(|| { + anyhow!( + "Creating source file destination: {}", + source.path().display() + ) + })?; + + let mut file = tokio::io::BufWriter::new(file); + let client_builder = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::limited(10)); + + let client_builder = if let Some(to) = timeout { + client_builder.timeout(std::time::Duration::from_secs(to)) + } else { + client_builder + }; + + let client = client_builder.build().context("Building HTTP client failed")?; + + let request = client.get(source.url().as_ref()) + .build() + .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?; + + let response = match client.execute(request).await { + Ok(resp) => resp, + Err(e) => { + return Err(e).with_context(|| anyhow!("Downloading '{}'", source.url())) + } + }; + + progress.lock() + .await + .inc_download_bytes(response.content_length().unwrap_or(0)) + .await; + + let mut stream = response.bytes_stream(); + while let Some(bytes) = stream.next().await { + let bytes = bytes?; + tokio::try_join!( + file.write_all(bytes.as_ref()), + async { + progress.lock() + .await + .add_bytes(bytes.len()) + .await; + Ok(()) + } + )?; + } + + file.flush() + .await + .map_err(Error::from) + .map(|_| ()) +} + + +// Implementation of the 'source download' subcommand +pub async fn download( + matches: &ArgMatches, + config: &Configuration, + repo: Repository, + progressbars: ProgressBars, +) -> Result<()> { + let force = matches.is_present("force"); + let timeout = matches.value_of("timeout") + .map(u64::from_str) + .transpose() + .context("Parsing timeout argument to integer")?; + let cache = PathBuf::from(config.source_cache_root()); + let sc = SourceCache::new(cache); + let pname = matches + .value_of("package_name") + .map(String::from) + .map(PackageName::from); + let pvers = matches + .value_of("package_version") + .map(PackageVersionConstraint::try_from) + .transpose()?; + + let matching_regexp = matches.value_of("matching") + .map(crate::commands::util::mk_package_name_regex) + .transpose()?; + + let progressbar = Arc::new(Mutex::new(ProgressWrapper::new(progressbars.bar()))); + + let download_sema = Arc::new(tokio::sync::Semaphore::new(NUMBER_OF_MAX_CONCURRENT_DOWNLOADS)); + + let r = repo.packages() + .filter(|p| { + match (pname.as_ref(), pvers.as_ref(), matching_regexp.as_ref()) { + (None, None, None) => true, + (Some(pname), None, None) => p.name() == pname, + (Some(pname), Some(vers), None) => p.name() == pname && vers.matches(p.version()), + (None, None, Some(regex)) => regex.is_match(p.name()), + + (_, _, _) => { + panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.") + }, + } + }) + .map(|p| { + sc.sources_for(p).into_iter().map(|source| { + let download_sema = download_sema.clone(); + let progressbar = progressbar.clone(); + async move { + let source_path_exists = source.path().exists(); + if !source_path_exists && source.download_manually() { + return Err(anyhow!( + "Cannot download source that is marked for manual download" + )) + .context(anyhow!("Creating source: {}", source.path().display())) + .context(anyhow!("Downloading source: {}", source.url())) + .map_err(Error::from); + } + + if source_path_exists && !force { + Err(anyhow!("Source exists: {}", source.path().display())) + } else { + if source_path_exists /* && force is implied by 'if' above*/ { + if let Err(e) = source.remove_file().await { + return Err(e) + } + } + + progressbar.lock().await.inc_download_count().await; + { + let permit = download_sema.acquire_owned().await?; + perform_download(&source, progressbar.clone(), timeout).await?; + drop(permit); + } + progressbar.lock().await.finish_one_download().await; + Ok(()) + } + } + }) + }) + .flatten() + .collect::<futures::stream::FuturesUnordered<_>>() + .collect::<Vec<Result<()>>>() + .await + .into_iter() + .collect::<Result<()>>(); + + if r.is_err() { + progressbar.lock().await.error().await; + } else { + progressbar.lock().await.success().await; + } + + debug!("r = {:?}", r); + r +} + diff --git a/src/commands/source.rs b/src/commands/source/mod.rs index 6f52b2b..8d11099 100644 --- a/src/commands/source.rs +++ b/src/commands/source/mod.rs @@ -10,19 +10,17 @@ //! Implementation of the 'source' subcommand +use std::convert::TryFrom; use std::io::Write; use std::path::PathBuf; -use std::convert::TryFrom; -use std::str::FromStr; -use anyhow::anyhow; use anyhow::Context; use anyhow::Error; use anyhow::Result; +use anyhow::anyhow; use clap::ArgMatches; use colored::Colorize; use log::{info, trace}; -use tokio::io::AsyncWriteExt; use tokio_stream::StreamExt; use crate::config::*; @@ -33,6 +31,8 @@ use crate::repository::Repository; use crate::source::*; use crate::util::progress::ProgressBars; +mod download; + /// Implementation of the "source" subcommand pub async fn source( matches: &ArgMatches, @@ -44,7 +44,7 @@ pub async fn source( Some(("verify", matches)) => verify(matches, config, repo, progressbars).await, Some(("list-missing", matches)) => list_missing(matches, config, repo).await, Some(("url", matches)) => url(matches, repo).await, - Some(("download", matches)) => download(matches, config, repo, progressbars).await, + Some(("download", matches)) => crate::commands::source::download::download(matches, config, repo, progressbars).await, Some(("of", matches)) => of(matches, config, repo).await, Some((other, _)) => return Err(anyhow!("Unknown subcommand: {}", other)), None => Err(anyhow!("No subcommand")), @@ -218,157 +218,6 @@ pub async fn url(matches: &ArgMatches, repo: Repository) -> Result<()> { }) } -pub async fn download( - matches: &ArgMatches, - config: &Configuration, - repo: Repository, - progressbars: ProgressBars, -) -> Result<()> { - async fn perform_download(source: &SourceEntry, bar: &indicatif::ProgressBar, timeout: Option<u64>) -> Result<()> { - trace!("Creating: {:?}", source); - let file = source.create().await.with_context(|| { - anyhow!( - "Creating source file destination: {}", - source.path().display() - ) - })?; - - let mut file = tokio::io::BufWriter::new(file); - let client_builder = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::limited(10)); - - let client_builder = if let Some(to) = timeout { - client_builder.timeout(std::time::Duration::from_secs(to)) - } else { - client_builder - }; - - let client = client_builder.build().context("Building HTTP client failed")?; - - let request = client.get(source.url().as_ref()) - .build() - .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?; - - let response = match client.execute(request).await { - Ok(resp) => resp, - Err(e) => { - bar.finish_with_message(format!("Failed: {}", source.url())); - return Err(e).with_context(|| anyhow!("Downloading '{}'", source.url())) - } - }; - - if let Some(len) = response.content_length() { - bar.set_length(len); - } - - let mut stream = reqwest::get(source.url().as_ref()).await?.bytes_stream(); - let mut bytes_written = 0; - while let Some(bytes) = stream.next().await { - let bytes = bytes?; - file.write_all(bytes.as_ref()).await?; - bytes_written += bytes.len(); - - bar.inc(bytes.len() as u64); - if let Some(len) = response.content_length() { - bar.set_message(format!("Downloading {} ({}/{} bytes)", source.url(), bytes_written, len)); - } else { - bar.set_message(format!("Downloading {} ({} bytes)", source.url(), bytes_written)); - } - } - - file.flush() - .await - .map_err(Error::from) - .map(|_| ()) - } - - let force = matches.is_present("force"); - let timeout = matches.value_of("timeout") - .map(u64::from_str) - .transpose() - .context("Parsing timeout argument to integer")?; - let cache = PathBuf::from(config.source_cache_root()); - let sc = SourceCache::new(cache); - let pname = matches - .value_of("package_name") - .map(String::from) - .map(PackageName::from); - let pvers = matches - .value_of("package_version") - .map(PackageVersionConstraint::try_from) - .transpose()?; - let multi = { - let mp = indicatif::MultiProgress::new(); - if progressbars.hide() { - mp.set_draw_target(indicatif::ProgressDrawTarget::hidden()); - } - mp - }; - - let matching_regexp = matches.value_of("matching") - .map(crate::commands::util::mk_package_name_regex) - .transpose()?; - - let r = repo - .packages() - .filter(|p| { - match (pname.as_ref(), pvers.as_ref(), matching_regexp.as_ref()) { - (None, None, None) => true, - (Some(pname), None, None) => p.name() == pname, - (Some(pname), Some(vers), None) => p.name() == pname && vers.matches(p.version()), - (None, None, Some(regex)) => regex.is_match(p.name()), - - (_, _, _) => { - panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.") - }, - } - }) - .map(|p| { - sc.sources_for(p).into_iter().map(|source| { - let bar = multi.add(progressbars.spinner()); - bar.set_message(format!("Downloading {}", source.url())); - async move { - let source_path_exists = source.path().exists(); - if !source_path_exists && source.download_manually() { - return Err(anyhow!( - "Cannot download source that is marked for manual download" - )) - .context(anyhow!("Creating source: {}", source.path().display())) - .context(anyhow!("Downloading source: {}", source.url())) - .map_err(Error::from); - } - - if source_path_exists && !force { - Err(anyhow!("Source exists: {}", source.path().display())) - } else { - if source_path_exists /* && force is implied by 'if' above*/ { - if let Err(e) = source.remove_file().await { - bar.finish_with_message(format!("Failed to remove existing file: {}", source.path().display())); - return Err(e) - } - } - - - if let Err(e) = perform_download(&source, &bar, timeout).await { - bar.finish_with_message(format!("Failed: {}", source.url())); - Err(e) - } else { - bar.finish_with_message(format!("Finished: {}", source.url())); - Ok(()) - } - } - } - }) - }) - .flatten() - .collect::<futures::stream::FuturesUnordered<_>>() - .collect::<Vec<Result<()>>>(); - - let multibar_block = tokio::task::spawn_blocking(move || multi.join()); - let (r, _) = tokio::join!(r, multibar_block); - r.into_iter().collect() -} - async fn of( matches: &ArgMatches, config: &Configuration, diff --git a/src/config/not_validated.rs b/src/config/not_validated.rs index 4a81fbb..d41acd7 100644 --- a/src/config/not_validated.rs +++ b/src/config/not_validated.rs @@ -51,6 +51,7 @@ pub struct NotValidatedConfiguration { /// The format of the spinners in the CLI #[serde(default = "default_spinner_format")] #[getset(get = "pub")] + #[allow(unused)] spinner_format: String, /// The format used to print a package diff --git a/src/main.rs b/src/main.rs index 2dbfa5d..a57eef5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -139,7 +139,6 @@ async fn main() -> Result<()> { let hide_bars = cli.is_present("hide_bars") || crate::util::stdout_is_pipe(); let progressbars = ProgressBars::setup( config.progress_format().clone(), - config.spinner_format().clone(), hide_bars, ); diff --git a/src/util/progress.rs b/src/util/progress.rs index e403989..57f9d67 100644 --- a/src/util/progress.rs +++ b/src/util/progress.rs @@ -14,17 +14,15 @@ use getset::CopyGetters; #[derive(Clone, Debug, CopyGetters)] pub struct ProgressBars { bar_template: String, - spinner_template: String, #[getset(get_copy = "pub")] hide: bool, } impl ProgressBars { - pub fn setup(bar_template: String, spinner_template: String, hide: bool) -> Self { + pub fn setup(bar_template: String, hide: bool) -> Self { ProgressBars { bar_template, - spinner_template, hide, } } @@ -38,14 +36,4 @@ impl ProgressBars { b } } - - pub fn spinner(&self) -> ProgressBar { - if self.hide { - ProgressBar::hidden() - } else { - let bar = ProgressBar::new_spinner(); - bar.set_style(ProgressStyle::default_spinner().template(&self.spinner_template)); - bar - } - } } |