diff options
author | Sylvestre Ledru <sylvestre@debian.org> | 2024-03-10 22:44:12 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-10 22:44:12 +0100 |
commit | 80702d5391d1b49a682994c67912c36312502bde (patch) | |
tree | be94ef73a53f7352f024745c83199ce1c20b165f | |
parent | fe0c814bd594ee3480d6556cc69fb765117167de (diff) | |
parent | b233569b9ce1f47598cb231720df3400f5e77185 (diff) |
Merge pull request #6014 from BenWiederhake/dev-shuf-range-off-by-one
shuf: Fix off-by-one errors in range handling
-rw-r--r-- | src/uu/shuf/src/shuf.rs | 51 | ||||
-rw-r--r-- | tests/by-util/test_shuf.rs | 147 |
2 files changed, 174 insertions, 24 deletions
diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs index 9ee04826b..40028c2fb 100644 --- a/src/uu/shuf/src/shuf.rs +++ b/src/uu/shuf/src/shuf.rs @@ -12,6 +12,7 @@ use rand::{Rng, RngCore}; use std::collections::HashSet; use std::fs::File; use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write}; +use std::ops::RangeInclusive; use uucore::display::Quotable; use uucore::error::{FromIo, UResult, USimpleError, UUsageError}; use uucore::{format_usage, help_about, help_usage}; @@ -21,7 +22,7 @@ mod rand_read_adapter; enum Mode { Default(String), Echo(Vec<String>), - InputRange((usize, usize)), + InputRange(RangeInclusive<usize>), } static USAGE: &str = help_usage!("shuf.md"); @@ -119,8 +120,8 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { find_seps(&mut evec, options.sep); shuf_exec(&mut evec, options)?; } - Mode::InputRange((b, e)) => { - shuf_exec(&mut (b, e), options)?; + Mode::InputRange(mut range) => { + shuf_exec(&mut range, options)?; } Mode::Default(filename) => { let fdata = read_input_file(&filename)?; @@ -289,14 +290,13 @@ impl<'a> Shufable for Vec<&'a [u8]> { } } -impl Shufable for (usize, usize) { +impl Shufable for RangeInclusive<usize> { type Item = usize; fn is_empty(&self) -> bool { - // Note: This is an inclusive range, so equality means there is 1 element. - self.0 > self.1 + self.is_empty() } fn choose(&self, rng: &mut WrappedRng) -> usize { - rng.gen_range(self.0..self.1) + rng.gen_range(self.clone()) } type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b; fn partial_shuffle<'b>( @@ -304,7 +304,7 @@ impl Shufable for (usize, usize) { rng: &'b mut WrappedRng, amount: usize, ) -> Self::PartialShuffleIterator<'b> { - NonrepeatingIterator::new(self.0, self.1, rng, amount) + NonrepeatingIterator::new(self.clone(), rng, amount) } } @@ -314,8 +314,7 @@ enum NumberSet { } struct NonrepeatingIterator<'a> { - begin: usize, - end: usize, // exclusive + range: RangeInclusive<usize>, rng: &'a mut WrappedRng, remaining_count: usize, buf: NumberSet, @@ -323,19 +322,19 @@ struct NonrepeatingIterator<'a> { impl<'a> NonrepeatingIterator<'a> { fn new( - begin: usize, - end: usize, + range: RangeInclusive<usize>, rng: &'a mut WrappedRng, amount: usize, ) -> NonrepeatingIterator { - let capped_amount = if begin > end { + let capped_amount = if range.start() > range.end() { 0 + } else if *range.start() == 0 && *range.end() == std::usize::MAX { + amount } else { - amount.min(end - begin) + amount.min(range.end() - range.start() + 1) }; NonrepeatingIterator { - begin, - end, + range, rng, remaining_count: capped_amount, buf: NumberSet::AlreadyListed(HashSet::default()), @@ -343,11 +342,11 @@ impl<'a> NonrepeatingIterator<'a> { } fn produce(&mut self) -> usize { - debug_assert!(self.begin <= self.end); + debug_assert!(self.range.start() <= self.range.end()); match &mut self.buf { NumberSet::AlreadyListed(already_listed) => { let chosen = loop { - let guess = self.rng.gen_range(self.begin..self.end); + let guess = self.rng.gen_range(self.range.clone()); let newly_inserted = already_listed.insert(guess); if newly_inserted { break guess; @@ -356,9 +355,11 @@ impl<'a> NonrepeatingIterator<'a> { // Once a significant fraction of the interval has already been enumerated, // the number of attempts to find a number that hasn't been chosen yet increases. // Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values". - let range_size = self.end - self.begin; + let range_size = (self.range.end() - self.range.start()).saturating_add(1); if number_set_should_list_remaining(already_listed.len(), range_size) { - let mut remaining = (self.begin..self.end) + let mut remaining = self + .range + .clone() .filter(|n| !already_listed.contains(n)) .collect::<Vec<_>>(); assert!(remaining.len() >= self.remaining_count); @@ -381,7 +382,7 @@ impl<'a> Iterator for NonrepeatingIterator<'a> { type Item = usize; fn next(&mut self) -> Option<usize> { - if self.begin > self.end || self.remaining_count == 0 { + if self.range.is_empty() || self.remaining_count == 0 { return None; } self.remaining_count -= 1; @@ -462,7 +463,7 @@ fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> { Ok(()) } -fn parse_range(input_range: &str) -> Result<(usize, usize), String> { +fn parse_range(input_range: &str) -> Result<RangeInclusive<usize>, String> { if let Some((from, to)) = input_range.split_once('-') { let begin = from .parse::<usize>() @@ -470,7 +471,11 @@ fn parse_range(input_range: &str) -> Result<(usize, usize), String> { let end = to .parse::<usize>() .map_err(|_| format!("invalid input range: {}", to.quote()))?; - Ok((begin, end + 1)) + if begin <= end || begin == end + 1 { + Ok(begin..=end) + } else { + Err(format!("invalid input range: {}", input_range.quote())) + } } else { Err(format!("invalid input range: {}", input_range.quote())) } diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs index 7b0af7c94..8a991e435 100644 --- a/tests/by-util/test_shuf.rs +++ b/tests/by-util/test_shuf.rs @@ -139,6 +139,99 @@ fn test_very_large_range_offset() { } #[test] +fn test_range_repeat_no_overflow_1_max() { + let upper_bound = std::usize::MAX; + let result = new_ucmd!() + .arg("-rn1") + .arg(&format!("-i1-{upper_bound}")) + .succeeds(); + result.no_stderr(); + + let result_seq: Vec<usize> = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), 1, "Miscounted output length!"); +} + +#[test] +fn test_range_repeat_no_overflow_0_max_minus_1() { + let upper_bound = std::usize::MAX - 1; + let result = new_ucmd!() + .arg("-rn1") + .arg(&format!("-i0-{upper_bound}")) + .succeeds(); + result.no_stderr(); + + let result_seq: Vec<usize> = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), 1, "Miscounted output length!"); +} + +#[test] +fn test_range_permute_no_overflow_1_max() { + let upper_bound = std::usize::MAX; + let result = new_ucmd!() + .arg("-n1") + .arg(&format!("-i1-{upper_bound}")) + .succeeds(); + result.no_stderr(); + + let result_seq: Vec<usize> = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), 1, "Miscounted output length!"); +} + +#[test] +fn test_range_permute_no_overflow_0_max_minus_1() { + let upper_bound = std::usize::MAX - 1; + let result = new_ucmd!() + .arg("-n1") + .arg(&format!("-i0-{upper_bound}")) + .succeeds(); + result.no_stderr(); + + let result_seq: Vec<usize> = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), 1, "Miscounted output length!"); +} + +#[test] +fn test_range_permute_no_overflow_0_max() { + // NOTE: This is different from GNU shuf! + // GNU shuf accepts -i0-MAX-1 and -i1-MAX, but not -i0-MAX. + // This feels like a bug in GNU shuf. + let upper_bound = std::usize::MAX; + let result = new_ucmd!() + .arg("-n1") + .arg(&format!("-i0-{upper_bound}")) + .succeeds(); + result.no_stderr(); + + let result_seq: Vec<usize> = result + .stdout_str() + .split('\n') + .filter(|x| !x.is_empty()) + .map(|x| x.parse().unwrap()) + .collect(); + assert_eq!(result_seq.len(), 1, "Miscounted output length!"); +} + +#[test] fn test_very_high_range_full() { let input_seq = vec![ 2147483641, 2147483642, 2147483643, 2147483644, 2147483645, 2147483646, 2147483647, @@ -626,7 +719,6 @@ fn test_shuf_multiple_input_line_count() { } #[test] -#[ignore = "known issue"] fn test_shuf_repeat_empty_range() { new_ucmd!() .arg("-ri4-3") @@ -653,3 +745,56 @@ fn test_shuf_repeat_empty_input() { .no_stdout() .stderr_only("shuf: no lines to repeat\n"); } + +#[test] +fn test_range_one_elem() { + new_ucmd!() + .arg("-i5-5") + .succeeds() + .no_stderr() + .stdout_only("5\n"); +} + +#[test] +fn test_range_empty() { + new_ucmd!().arg("-i5-4").succeeds().no_output(); +} + +#[test] +fn test_range_empty_minus_one() { + new_ucmd!() + .arg("-i5-3") + .fails() + .no_stdout() + .stderr_only("shuf: invalid input range: '5-3'\n"); +} + +#[test] +fn test_range_repeat_one_elem() { + new_ucmd!() + .arg("-n1") + .arg("-ri5-5") + .succeeds() + .no_stderr() + .stdout_only("5\n"); +} + +#[test] +fn test_range_repeat_empty() { + new_ucmd!() + .arg("-n1") + .arg("-ri5-4") + .fails() + .no_stdout() + .stderr_only("shuf: no lines to repeat\n"); +} + +#[test] +fn test_range_repeat_empty_minus_one() { + new_ucmd!() + .arg("-n1") + .arg("-ri5-3") + .fails() + .no_stdout() + .stderr_only("shuf: invalid input range: '5-3'\n"); +} |