summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Matthias Beyer <mail@beyermatthias.de> 2021-12-02 19:16:16 +0100
committer Matthias Beyer <mail@beyermatthias.de> 2021-12-02 19:16:16 +0100
commit e833100d29b774ab30daf18dfb791f6aabf3085c (patch)
tree 34be7fdb58d84d0406430905ac8ae8e796d4a8b7
parent ab252275a75882ad3bc1d17c4c64b96062174962 (diff)
parent 262793dbb92de361b743e9d268457ea1f63e822f (diff)
Merge branch 'optimize-mass-download' HEAD master
-rw-r--r-- src/commands/source/download.rs 262
-rw-r--r-- src/commands/source/mod.rs (renamed from src/commands/source.rs) 161
-rw-r--r-- src/config/not_validated.rs 1
-rw-r--r-- src/main.rs 1
-rw-r--r-- src/util/progress.rs 14
5 files changed, 269 insertions, 170 deletions
diff --git a/src/commands/source/download.rs b/src/commands/source/download.rs
new file mode 100644
index 0000000..11843fa
--- /dev/null
+++ b/src/commands/source/download.rs
@@ -0,0 +1,262 @@
+//
+// Copyright (c) 2020-2021 science+computing ag and other contributors
+//
+// This program and the accompanying materials are made
+// available under the terms of the Eclipse Public License 2.0
+// which is available at https://www.eclipse.org/legal/epl-2.0/
+//
+// SPDX-License-Identifier: EPL-2.0
+//
+
+use std::convert::TryFrom;
+use std::path::PathBuf;
+use std::str::FromStr;
+use std::sync::Arc;
+
+use anyhow::Context;
+use anyhow::Error;
+use anyhow::Result;
+use anyhow::anyhow;
+use clap::ArgMatches;
+use log::{debug, trace};
+use tokio::io::AsyncWriteExt;
+use tokio::sync::Mutex;
+use tokio_stream::StreamExt;
+
+use crate::config::*;
+use crate::package::PackageName;
+use crate::package::PackageVersionConstraint;
+use crate::repository::Repository;
+use crate::source::*;
+use crate::util::progress::ProgressBars;
+
+const NUMBER_OF_MAX_CONCURRENT_DOWNLOADS: usize = 100; // permit count of the semaphore that every download job acquires before calling perform_download()
+
+/// Shared status state for the single download progress bar
+///
+/// A wrapper around the indicatif::ProgressBar that is used to synchronize status information from
+/// the individual download jobs to the progress bar that is used to display download progress to
+/// the user.
+///
+/// The problem this helper solves is that we only have one status bar for all downloads, and all
+/// download tasks must be able to increase the number of bytes received, for example, (that is
+/// displayed in the status message) but in a sync way.
+///
+/// Every counter update re-renders the bar message from the current counter values.
+#[derive(Clone)]
+struct ProgressWrapper {
+    /// Number of downloads registered via `inc_download_count()`
+    download_count: u64,
+    /// Number of downloads reported via `finish_one_download()`
+    finished_downloads: u64,
+    /// Bytes received so far, summed over all downloads
+    ///
+    /// Kept as `u64` (like `sum_bytes`) so that the counter does not depend on the pointer width
+    /// of the target: with `usize` it could overflow on 32-bit targets during a mass download.
+    current_bytes: u64,
+    /// Sum of the byte counts announced for the downloads via `inc_download_bytes()`
+    sum_bytes: u64,
+    /// The progress bar all status information is rendered to
+    bar: Arc<Mutex<indicatif::ProgressBar>>,
+}
+
+impl ProgressWrapper {
+    /// Wrap `bar`, starting with all counters at zero
+    fn new(bar: indicatif::ProgressBar) -> Self {
+        Self {
+            download_count: 0,
+            finished_downloads: 0,
+            current_bytes: 0,
+            sum_bytes: 0,
+            bar: Arc::new(Mutex::new(bar))
+        }
+    }
+
+    /// Register one more download: updates the message and grows the bar length by one
+    async fn inc_download_count(&mut self) {
+        self.download_count += 1;
+        self.set_message().await;
+        let bar = self.bar.lock().await;
+        // One bar step per registered download; `finish_one_download()` advances by one step.
+        bar.set_length(bar.length() + 1);
+    }
+
+    /// Add `bytes` to the number of bytes that are expected to be received overall
+    async fn inc_download_bytes(&mut self, bytes: u64) {
+        self.sum_bytes += bytes;
+        self.set_message().await;
+    }
+
+    /// Mark one download as finished: advances the bar by one and updates the message
+    async fn finish_one_download(&mut self) {
+        self.finished_downloads += 1;
+        self.bar.lock().await.inc(1);
+        self.set_message().await;
+    }
+
+    /// Add `len` to the number of bytes received so far
+    async fn add_bytes(&mut self, len: usize) {
+        // Widening only: `usize` is at most 64 bits wide on the targets Rust supports.
+        self.current_bytes += len as u64;
+        self.set_message().await;
+    }
+
+    /// Render the current counters into the message of the progress bar
+    async fn set_message(&self) {
+        let bar = self.bar.lock().await;
+        bar.set_message(format!("Downloading ({current_bytes}/{sum_bytes} bytes, {dlfinished}/{dlsum} downloads finished)",
+                current_bytes = self.current_bytes,
+                sum_bytes = self.sum_bytes,
+                dlfinished = self.finished_downloads,
+                dlsum = self.download_count));
+    }
+
+    /// Finish the progress bar with a message that reports the finished/registered downloads
+    async fn success(&self) {
+        let bar = self.bar.lock().await;
+        bar.finish_with_message(format!("Succeeded {}/{} downloads", self.finished_downloads, self.download_count));
+    }
+
+    /// Finish the progress bar with a message that reports a failure
+    async fn error(&self) {
+        let bar = self.bar.lock().await;
+        bar.finish_with_message(format!("At least one download of {} failed", self.download_count));
+    }
+}
+
+/// Download one source into its destination file
+///
+/// Sends a GET request for `source.url()` (following at most 10 redirects; `timeout`, if given,
+/// is the request timeout in seconds) and streams the response body into the file created by
+/// `source.create()`. The announced content length (0 if the server does not announce one) and
+/// every received chunk are reported to `progress`.
+///
+/// # Errors
+///
+/// Fails if the HTTP client or the request cannot be built, if the request fails, if the server
+/// answers with a client or server error status (4xx/5xx), if the destination file cannot be
+/// created, or if receiving or writing the body fails.
+async fn perform_download(source: &SourceEntry, progress: Arc<Mutex<ProgressWrapper>>, timeout: Option<u64>) -> Result<()> {
+    let client_builder = reqwest::Client::builder()
+        .redirect(reqwest::redirect::Policy::limited(10));
+
+    let client_builder = if let Some(to) = timeout {
+        client_builder.timeout(std::time::Duration::from_secs(to))
+    } else {
+        client_builder
+    };
+
+    let client = client_builder.build().context("Building HTTP client failed")?;
+
+    let request = client.get(source.url().as_ref())
+        .build()
+        .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?;
+
+    // `execute()` only fails on transport-level problems, a 4xx/5xx answer is still an `Ok`
+    // response. Turn such a status into an error, otherwise the body of the error page would be
+    // stored as if it were the source.
+    let response = client.execute(request)
+        .await
+        .and_then(reqwest::Response::error_for_status)
+        .with_context(|| anyhow!("Downloading '{}'", source.url()))?;
+
+    // Create the destination only now that there is a usable response: a file left behind by a
+    // failed request would make the next (non-forced) run fail because the source "exists".
+    trace!("Creating: {:?}", source);
+    let file = source.create().await.with_context(|| {
+        anyhow!(
+            "Creating source file destination: {}",
+            source.path().display()
+        )
+    })?;
+    let mut file = tokio::io::BufWriter::new(file);
+
+    progress.lock()
+        .await
+        .inc_download_bytes(response.content_length().unwrap_or(0))
+        .await;
+
+    let mut stream = response.bytes_stream();
+    while let Some(bytes) = stream.next().await {
+        let bytes = bytes?;
+        // Write the chunk and report it to the progress bar concurrently.
+        tokio::try_join!(
+            file.write_all(bytes.as_ref()),
+            async {
+                progress.lock()
+                    .await
+                    .add_bytes(bytes.len())
+                    .await;
+                Ok(())
+            }
+        )?;
+    }
+
+    // Flush explicitly: dropping the BufWriter would not report a failing final write.
+    file.flush()
+        .await
+        .map_err(Error::from)
+}
+
+
+/// Implementation of the 'source download' subcommand
+///
+/// Selects packages from `repo` (all of them, by name and optional version constraint, or by a
+/// regex on the name) and downloads every source of the selected packages. All download jobs
+/// report to one shared progress bar, and each job holds a permit of a semaphore with
+/// `NUMBER_OF_MAX_CONCURRENT_DOWNLOADS` permits while it transfers data.
+///
+/// A source that already exists is only replaced if the "force" flag is set; a missing source
+/// that is marked for manual download is an error. All jobs are driven to completion before the
+/// collected results are folded into the returned `Result`.
+pub async fn download(
+    matches: &ArgMatches,
+    config: &Configuration,
+    repo: Repository,
+    progressbars: ProgressBars,
+) -> Result<()> {
+    let force = matches.is_present("force");
+    let timeout = matches.value_of("timeout")
+        .map(u64::from_str)
+        .transpose()
+        .context("Parsing timeout argument to integer")?;
+    let source_cache = SourceCache::new(PathBuf::from(config.source_cache_root()));
+    let wanted_name = matches
+        .value_of("package_name")
+        .map(|name| PackageName::from(String::from(name)));
+    let wanted_version = matches
+        .value_of("package_version")
+        .map(PackageVersionConstraint::try_from)
+        .transpose()?;
+    let name_regex = matches.value_of("matching")
+        .map(crate::commands::util::mk_package_name_regex)
+        .transpose()?;
+
+    // One progress bar and one semaphore, shared by all download jobs.
+    let progressbar = Arc::new(Mutex::new(ProgressWrapper::new(progressbars.bar())));
+    let download_sema = Arc::new(tokio::sync::Semaphore::new(NUMBER_OF_MAX_CONCURRENT_DOWNLOADS));
+
+    let outcome = repo.packages()
+        .filter(|pkg| {
+            match (wanted_name.as_ref(), wanted_version.as_ref(), name_regex.as_ref()) {
+                (None, None, None) => true,
+                (Some(name), None, None) => pkg.name() == name,
+                (Some(name), Some(constraint), None) => pkg.name() == name && constraint.matches(pkg.version()),
+                (None, None, Some(regex)) => regex.is_match(pkg.name()),
+
+                _ => {
+                    panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.")
+                },
+            }
+        })
+        .flat_map(|pkg| {
+            source_cache.sources_for(pkg).into_iter().map(|source| {
+                let sema = Arc::clone(&download_sema);
+                let progress = Arc::clone(&progressbar);
+                async move {
+                    let already_present = source.path().exists();
+
+                    if !already_present && source.download_manually() {
+                        return Err(anyhow!(
+                            "Cannot download source that is marked for manual download"
+                        ))
+                        .context(anyhow!("Creating source: {}", source.path().display()))
+                        .context(anyhow!("Downloading source: {}", source.url()));
+                    }
+
+                    if already_present {
+                        if !force {
+                            return Err(anyhow!("Source exists: {}", source.path().display()));
+                        }
+                        // Forced re-download: get rid of the existing file first.
+                        source.remove_file().await?;
+                    }
+
+                    progress.lock().await.inc_download_count().await;
+                    {
+                        // The permit is held for the duration of the transfer and released at
+                        // the end of this scope.
+                        let _permit = sema.acquire_owned().await?;
+                        perform_download(&source, Arc::clone(&progress), timeout).await?;
+                    }
+                    progress.lock().await.finish_one_download().await;
+                    Ok(())
+                }
+            })
+        })
+        .collect::<futures::stream::FuturesUnordered<_>>()
+        .collect::<Vec<Result<()>>>()
+        .await
+        .into_iter()
+        .collect::<Result<()>>();
+
+    match &outcome {
+        Ok(()) => progressbar.lock().await.success().await,
+        Err(_) => progressbar.lock().await.error().await,
+    }
+
+    debug!("r = {:?}", outcome);
+    outcome
+}
+
diff --git a/src/commands/source.rs b/src/commands/source/mod.rs
index 6f52b2b..8d11099 100644
--- a/src/commands/source.rs
+++ b/src/commands/source/mod.rs
@@ -10,19 +10,17 @@
//! Implementation of the 'source' subcommand
+use std::convert::TryFrom;
use std::io::Write;
use std::path::PathBuf;
-use std::convert::TryFrom;
-use std::str::FromStr;
-use anyhow::anyhow;
use anyhow::Context;
use anyhow::Error;
use anyhow::Result;
+use anyhow::anyhow;
use clap::ArgMatches;
use colored::Colorize;
use log::{info, trace};
-use tokio::io::AsyncWriteExt;
use tokio_stream::StreamExt;
use crate::config::*;
@@ -33,6 +31,8 @@ use crate::repository::Repository;
use crate::source::*;
use crate::util::progress::ProgressBars;
+mod download;
+
/// Implementation of the "source" subcommand
pub async fn source(
matches: &ArgMatches,
@@ -44,7 +44,7 @@ pub async fn source(
Some(("verify", matches)) => verify(matches, config, repo, progressbars).await,
Some(("list-missing", matches)) => list_missing(matches, config, repo).await,
Some(("url", matches)) => url(matches, repo).await,
- Some(("download", matches)) => download(matches, config, repo, progressbars).await,
+ Some(("download", matches)) => crate::commands::source::download::download(matches, config, repo, progressbars).await,
Some(("of", matches)) => of(matches, config, repo).await,
Some((other, _)) => return Err(anyhow!("Unknown subcommand: {}", other)),
None => Err(anyhow!("No subcommand")),
@@ -218,157 +218,6 @@ pub async fn url(matches: &ArgMatches, repo: Repository) -> Result<()> {
})
}
-pub async fn download(
- matches: &ArgMatches,
- config: &Configuration,
- repo: Repository,
- progressbars: ProgressBars,
-) -> Result<()> {
- async fn perform_download(source: &SourceEntry, bar: &indicatif::ProgressBar, timeout: Option<u64>) -> Result<()> {
- trace!("Creating: {:?}", source);
- let file = source.create().await.with_context(|| {
- anyhow!(
- "Creating source file destination: {}",
- source.path().display()
- )
- })?;
-
- let mut file = tokio::io::BufWriter::new(file);
- let client_builder = reqwest::Client::builder()
- .redirect(reqwest::redirect::Policy::limited(10));
-
- let client_builder = if let Some(to) = timeout {
- client_builder.timeout(std::time::Duration::from_secs(to))
- } else {
- client_builder
- };
-
- let client = client_builder.build().context("Building HTTP client failed")?;
-
- let request = client.get(source.url().as_ref())
- .build()
- .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?;
-
- let response = match client.execute(request).await {
- Ok(resp) => resp,
- Err(e) => {
- bar.finish_with_message(format!("Failed: {}", source.url()));
- return Err(e).with_context(|| anyhow!("Downloading '{}'", source.url()))
- }
- };
-
- if let Some(len) = response.content_length() {
- bar.set_length(len);
- }
-
- let mut stream = reqwest::get(source.url().as_ref()).await?.bytes_stream();
- let mut bytes_written = 0;
- while let Some(bytes) = stream.next().await {
- let bytes = bytes?;
- file.write_all(bytes.as_ref()).await?;
- bytes_written += bytes.len();
-
- bar.inc(bytes.len() as u64);
- if let Some(len) = response.content_length() {
- bar.set_message(format!("Downloading {} ({}/{} bytes)", source.url(), bytes_written, len));
- } else {
- bar.set_message(format!("Downloading {} ({} bytes)", source.url(), bytes_written));
- }
- }
-
- file.flush()
- .await
- .map_err(Error::from)
- .map(|_| ())
- }
-
- let force = matches.is_present("force");
- let timeout = matches.value_of("timeout")
- .map(u64::from_str)
- .transpose()
- .context("Parsing timeout argument to integer")?;
- let cache = PathBuf::from(config.source_cache_root());
- let sc = SourceCache::new(cache);
- let pname = matches
- .value_of("package_name")
- .map(String::from)
- .map(PackageName::from);
- let pvers = matches
- .value_of("package_version")
- .map(PackageVersionConstraint::try_from)
- .transpose()?;
- let multi = {
- let mp = indicatif::MultiProgress::new();
- if progressbars.hide() {
- mp.set_draw_target(indicatif::ProgressDrawTarget::hidden());
- }
- mp
- };
-
- let matching_regexp = matches.value_of("matching")
- .map(crate::commands::util::mk_package_name_regex)
- .transpose()?;
-
- let r = repo
- .packages()
- .filter(|p| {
- match (pname.as_ref(), pvers.as_ref(), matching_regexp.as_ref()) {
- (None, None, None) => true,
- (Some(pname), None, None) => p.name() == pname,
- (Some(pname), Some(vers), None) => p.name() == pname && vers.matches(p.version()),
- (None, None, Some(regex)) => regex.is_match(p.name()),
-
- (_, _, _) => {
- panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.")
- },
- }
- })
- .map(|p| {
- sc.sources_for(p).into_iter().map(|source| {
- let bar = multi.add(progressbars.spinner());
- bar.set_message(format!("Downloading {}", source.url()));
- async move {
- let source_path_exists = source.path().exists();
- if !source_path_exists && source.download_manually() {
- return Err(anyhow!(
- "Cannot download source that is marked for manual download"
- ))
- .context(anyhow!("Creating source: {}", source.path().display()))
- .context(anyhow!("Downloading source: {}", source.url()))
- .map_err(Error::from);
- }
-
- if source_path_exists && !force {
- Err(anyhow!("Source exists: {}", source.path().display()))
- } else {
- if source_path_exists /* && force is implied by 'if' above*/ {
- if let Err(e) = source.remove_file().await {
- bar.finish_with_message(format!("Failed to remove existing file: {}", source.path().display()));
- return Err(e)
- }
- }
-
-
- if let Err(e) = perform_download(&source, &bar, timeout).await {
- bar.finish_with_message(format!("Failed: {}", source.url()));
- Err(e)
- } else {
- bar.finish_with_message(format!("Finished: {}", source.url()));
- Ok(())
- }
- }
- }
- })
- })
- .flatten()
- .collect::<futures::stream::FuturesUnordered<_>>()
- .collect::<Vec<Result<()>>>();
-
- let multibar_block = tokio::task::spawn_blocking(move || multi.join());
- let (r, _) = tokio::join!(r, multibar_block);
- r.into_iter().collect()
-}
-
async fn of(
matches: &ArgMatches,
config: &Configuration,
diff --git a/src/config/not_validated.rs b/src/config/not_validated.rs
index 4a81fbb..d41acd7 100644
--- a/src/config/not_validated.rs
+++ b/src/config/not_validated.rs
@@ -51,6 +51,7 @@ pub struct NotValidatedConfiguration {
/// The format of the spinners in the CLI
#[serde(default = "default_spinner_format")]
#[getset(get = "pub")]
+ #[allow(unused)]
spinner_format: String,
/// The format used to print a package
diff --git a/src/main.rs b/src/main.rs
index 2dbfa5d..a57eef5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -139,7 +139,6 @@ async fn main() -> Result<()> {
let hide_bars = cli.is_present("hide_bars") || crate::util::stdout_is_pipe();
let progressbars = ProgressBars::setup(
config.progress_format().clone(),
- config.spinner_format().clone(),
hide_bars,
);
diff --git a/src/util/progress.rs b/src/util/progress.rs
index e403989..57f9d67 100644
--- a/src/util/progress.rs
+++ b/src/util/progress.rs
@@ -14,17 +14,15 @@ use getset::CopyGetters;
#[derive(Clone, Debug, CopyGetters)]
pub struct ProgressBars {
bar_template: String,
- spinner_template: String,
#[getset(get_copy = "pub")]
hide: bool,
}
impl ProgressBars {
- pub fn setup(bar_template: String, spinner_template: String, hide: bool) -> Self {
+ pub fn setup(bar_template: String, hide: bool) -> Self {
ProgressBars {
bar_template,
- spinner_template,
hide,
}
}
@@ -38,14 +36,4 @@ impl ProgressBars {
b
}
}
-
- pub fn spinner(&self) -> ProgressBar {
- if self.hide {
- ProgressBar::hidden()
- } else {
- let bar = ProgressBar::new_spinner();
- bar.set_style(ProgressStyle::default_spinner().template(&self.spinner_template));
- bar
- }
- }
}