summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Sylvestre Ledru <sylvestre@debian.org> 2024-03-10 22:44:12 +0100
committer GitHub <noreply@github.com> 2024-03-10 22:44:12 +0100
commit 80702d5391d1b49a682994c67912c36312502bde (patch)
tree be94ef73a53f7352f024745c83199ce1c20b165f
parent fe0c814bd594ee3480d6556cc69fb765117167de (diff)
parent b233569b9ce1f47598cb231720df3400f5e77185 (diff)
Merge pull request #6014 from BenWiederhake/dev-shuf-range-off-by-one
shuf: Fix off-by-one errors in range handling
-rw-r--r-- src/uu/shuf/src/shuf.rs 51
-rw-r--r-- tests/by-util/test_shuf.rs 147
2 files changed, 174 insertions, 24 deletions
diff --git a/src/uu/shuf/src/shuf.rs b/src/uu/shuf/src/shuf.rs
index 9ee04826b..40028c2fb 100644
--- a/src/uu/shuf/src/shuf.rs
+++ b/src/uu/shuf/src/shuf.rs
@@ -12,6 +12,7 @@ use rand::{Rng, RngCore};
use std::collections::HashSet;
use std::fs::File;
use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write};
+use std::ops::RangeInclusive;
use uucore::display::Quotable;
use uucore::error::{FromIo, UResult, USimpleError, UUsageError};
use uucore::{format_usage, help_about, help_usage};
@@ -21,7 +22,7 @@ mod rand_read_adapter;
enum Mode {
Default(String),
Echo(Vec<String>),
- InputRange((usize, usize)),
+ InputRange(RangeInclusive<usize>),
}
static USAGE: &str = help_usage!("shuf.md");
@@ -119,8 +120,8 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> {
find_seps(&mut evec, options.sep);
shuf_exec(&mut evec, options)?;
}
- Mode::InputRange((b, e)) => {
- shuf_exec(&mut (b, e), options)?;
+ Mode::InputRange(mut range) => {
+ shuf_exec(&mut range, options)?;
}
Mode::Default(filename) => {
let fdata = read_input_file(&filename)?;
@@ -289,14 +290,13 @@ impl<'a> Shufable for Vec<&'a [u8]> {
}
}
-impl Shufable for (usize, usize) {
+impl Shufable for RangeInclusive<usize> {
type Item = usize;
fn is_empty(&self) -> bool {
- // Note: This is an inclusive range, so equality means there is 1 element.
- self.0 > self.1
+ self.is_empty()
}
fn choose(&self, rng: &mut WrappedRng) -> usize {
- rng.gen_range(self.0..self.1)
+ rng.gen_range(self.clone())
}
type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b;
fn partial_shuffle<'b>(
@@ -304,7 +304,7 @@ impl Shufable for (usize, usize) {
rng: &'b mut WrappedRng,
amount: usize,
) -> Self::PartialShuffleIterator<'b> {
- NonrepeatingIterator::new(self.0, self.1, rng, amount)
+ NonrepeatingIterator::new(self.clone(), rng, amount)
}
}
@@ -314,8 +314,7 @@ enum NumberSet {
}
struct NonrepeatingIterator<'a> {
- begin: usize,
- end: usize, // exclusive
+ range: RangeInclusive<usize>,
rng: &'a mut WrappedRng,
remaining_count: usize,
buf: NumberSet,
@@ -323,19 +322,19 @@ struct NonrepeatingIterator<'a> {
impl<'a> NonrepeatingIterator<'a> {
fn new(
- begin: usize,
- end: usize,
+ range: RangeInclusive<usize>,
rng: &'a mut WrappedRng,
amount: usize,
) -> NonrepeatingIterator {
- let capped_amount = if begin > end {
+ let capped_amount = if range.start() > range.end() {
0
+ } else if *range.start() == 0 && *range.end() == std::usize::MAX {
+ amount
} else {
- amount.min(end - begin)
+ amount.min(range.end() - range.start() + 1)
};
NonrepeatingIterator {
- begin,
- end,
+ range,
rng,
remaining_count: capped_amount,
buf: NumberSet::AlreadyListed(HashSet::default()),
@@ -343,11 +342,11 @@ impl<'a> NonrepeatingIterator<'a> {
}
fn produce(&mut self) -> usize {
- debug_assert!(self.begin <= self.end);
+ debug_assert!(self.range.start() <= self.range.end());
match &mut self.buf {
NumberSet::AlreadyListed(already_listed) => {
let chosen = loop {
- let guess = self.rng.gen_range(self.begin..self.end);
+ let guess = self.rng.gen_range(self.range.clone());
let newly_inserted = already_listed.insert(guess);
if newly_inserted {
break guess;
@@ -356,9 +355,11 @@ impl<'a> NonrepeatingIterator<'a> {
// Once a significant fraction of the interval has already been enumerated,
// the number of attempts to find a number that hasn't been chosen yet increases.
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
- let range_size = self.end - self.begin;
+ let range_size = (self.range.end() - self.range.start()).saturating_add(1);
if number_set_should_list_remaining(already_listed.len(), range_size) {
- let mut remaining = (self.begin..self.end)
+ let mut remaining = self
+ .range
+ .clone()
.filter(|n| !already_listed.contains(n))
.collect::<Vec<_>>();
assert!(remaining.len() >= self.remaining_count);
@@ -381,7 +382,7 @@ impl<'a> Iterator for NonrepeatingIterator<'a> {
type Item = usize;
fn next(&mut self) -> Option<usize> {
- if self.begin > self.end || self.remaining_count == 0 {
+ if self.range.is_empty() || self.remaining_count == 0 {
return None;
}
self.remaining_count -= 1;
@@ -462,7 +463,7 @@ fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> {
Ok(())
}
-fn parse_range(input_range: &str) -> Result<(usize, usize), String> {
+fn parse_range(input_range: &str) -> Result<RangeInclusive<usize>, String> {
if let Some((from, to)) = input_range.split_once('-') {
let begin = from
.parse::<usize>()
@@ -470,7 +471,11 @@ fn parse_range(input_range: &str) -> Result<(usize, usize), String> {
let end = to
.parse::<usize>()
.map_err(|_| format!("invalid input range: {}", to.quote()))?;
- Ok((begin, end + 1))
+ if begin <= end || begin == end + 1 {
+ Ok(begin..=end)
+ } else {
+ Err(format!("invalid input range: {}", input_range.quote()))
+ }
} else {
Err(format!("invalid input range: {}", input_range.quote()))
}
diff --git a/tests/by-util/test_shuf.rs b/tests/by-util/test_shuf.rs
index 7b0af7c94..8a991e435 100644
--- a/tests/by-util/test_shuf.rs
+++ b/tests/by-util/test_shuf.rs
@@ -139,6 +139,99 @@ fn test_very_large_range_offset() {
}
#[test]
+fn test_range_repeat_no_overflow_1_max() {
+ let upper_bound = std::usize::MAX;
+ let result = new_ucmd!()
+ .arg("-rn1")
+ .arg(&format!("-i1-{upper_bound}"))
+ .succeeds();
+ result.no_stderr();
+
+ let result_seq: Vec<usize> = result
+ .stdout_str()
+ .split('\n')
+ .filter(|x| !x.is_empty())
+ .map(|x| x.parse().unwrap())
+ .collect();
+ assert_eq!(result_seq.len(), 1, "Miscounted output length!");
+}
+
+#[test]
+fn test_range_repeat_no_overflow_0_max_minus_1() {
+ let upper_bound = std::usize::MAX - 1;
+ let result = new_ucmd!()
+ .arg("-rn1")
+ .arg(&format!("-i0-{upper_bound}"))
+ .succeeds();
+ result.no_stderr();
+
+ let result_seq: Vec<usize> = result
+ .stdout_str()
+ .split('\n')
+ .filter(|x| !x.is_empty())
+ .map(|x| x.parse().unwrap())
+ .collect();
+ assert_eq!(result_seq.len(), 1, "Miscounted output length!");
+}
+
+#[test]
+fn test_range_permute_no_overflow_1_max() {
+ let upper_bound = std::usize::MAX;
+ let result = new_ucmd!()
+ .arg("-n1")
+ .arg(&format!("-i1-{upper_bound}"))
+ .succeeds();
+ result.no_stderr();
+
+ let result_seq: Vec<usize> = result
+ .stdout_str()
+ .split('\n')
+ .filter(|x| !x.is_empty())
+ .map(|x| x.parse().unwrap())
+ .collect();
+ assert_eq!(result_seq.len(), 1, "Miscounted output length!");
+}
+
+#[test]
+fn test_range_permute_no_overflow_0_max_minus_1() {
+ let upper_bound = std::usize::MAX - 1;
+ let result = new_ucmd!()
+ .arg("-n1")
+ .arg(&format!("-i0-{upper_bound}"))
+ .succeeds();
+ result.no_stderr();
+
+ let result_seq: Vec<usize> = result
+ .stdout_str()
+ .split('\n')
+ .filter(|x| !x.is_empty())
+ .map(|x| x.parse().unwrap())
+ .collect();
+ assert_eq!(result_seq.len(), 1, "Miscounted output length!");
+}
+
+#[test]
+fn test_range_permute_no_overflow_0_max() {
+ // NOTE: This is different from GNU shuf!
+ // GNU shuf accepts -i0-MAX-1 and -i1-MAX, but not -i0-MAX.
+ // This feels like a bug in GNU shuf.
+ let upper_bound = std::usize::MAX;
+ let result = new_ucmd!()
+ .arg("-n1")
+ .arg(&format!("-i0-{upper_bound}"))
+ .succeeds();
+ result.no_stderr();
+
+ let result_seq: Vec<usize> = result
+ .stdout_str()
+ .split('\n')
+ .filter(|x| !x.is_empty())
+ .map(|x| x.parse().unwrap())
+ .collect();
+ assert_eq!(result_seq.len(), 1, "Miscounted output length!");
+}
+
+#[test]
fn test_very_high_range_full() {
let input_seq = vec![
2147483641, 2147483642, 2147483643, 2147483644, 2147483645, 2147483646, 2147483647,
@@ -626,7 +719,6 @@ fn test_shuf_multiple_input_line_count() {
}
#[test]
-#[ignore = "known issue"]
fn test_shuf_repeat_empty_range() {
new_ucmd!()
.arg("-ri4-3")
@@ -653,3 +745,56 @@ fn test_shuf_repeat_empty_input() {
.no_stdout()
.stderr_only("shuf: no lines to repeat\n");
}
+
+#[test]
+fn test_range_one_elem() {
+ new_ucmd!()
+ .arg("-i5-5")
+ .succeeds()
+ .no_stderr()
+ .stdout_only("5\n");
+}
+
+#[test]
+fn test_range_empty() {
+ new_ucmd!().arg("-i5-4").succeeds().no_output();
+}
+
+#[test]
+fn test_range_empty_minus_one() {
+ new_ucmd!()
+ .arg("-i5-3")
+ .fails()
+ .no_stdout()
+ .stderr_only("shuf: invalid input range: '5-3'\n");
+}
+
+#[test]
+fn test_range_repeat_one_elem() {
+ new_ucmd!()
+ .arg("-n1")
+ .arg("-ri5-5")
+ .succeeds()
+ .no_stderr()
+ .stdout_only("5\n");
+}
+
+#[test]
+fn test_range_repeat_empty() {
+ new_ucmd!()
+ .arg("-n1")
+ .arg("-ri5-4")
+ .fails()
+ .no_stdout()
+ .stderr_only("shuf: no lines to repeat\n");
+}
+
+#[test]
+fn test_range_repeat_empty_minus_one() {
+ new_ucmd!()
+ .arg("-n1")
+ .arg("-ri5-3")
+ .fails()
+ .no_stdout()
+ .stderr_only("shuf: invalid input range: '5-3'\n");
+}