path: root/slb/src/main.rs
//! `slb` main executable

use std::collections::hash_map::DefaultHasher;
use std::convert::TryInto;
use std::hash::{Hash, Hasher};
use std::io::{self, BufRead, BufReader, Write};
use std::iter;
use std::process::{Command, Stdio};
use std::sync::mpsc::sync_channel;
use std::sync::{Arc, Mutex};
use std::thread;

use bstr::io::BufReadExt;
use memchr::memchr;
use structopt::StructOpt;

/// Performs streaming load balancing on stdin, handing off input
/// to child processes based on a hash of the first word on each line.
///
/// E.g., suppose we have a file with contents like
///
/// ```
/// key1 a b c d
/// key2 e f g h
/// key1 a b
/// ```
///
/// The key is all bytes leading up to the first space, or all bytes
/// on a line if there are no spaces. Suppose `hash(key1) == 1` and
/// `hash(key2) == 2`. For a machine with 2 cores, `slb` will have
/// two processes, and the zeroth one will receive as stdin
///
/// ```
/// key2 e f g h
/// ```
///
/// since `2 % 2 == 0` and all `key1` lines will be fed into the
/// process at index 1.
///
/// These processes are expected to perform some kind of aggregation
/// and print at the end of their execution. For instance, suppose we invoke
/// `slb 'awk -f catter.awk'` where `catter.awk` is just defined to be
/// `{key = $1; $1 = ""; a[key] += $0}END{for (k in a) print k,a[k]}`,
/// which just concatenates the keyed values. Then the output might be
///
/// ```
/// key1  a b c d a b
/// key2  e f g h
/// ```
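///
/// A complete invocation might then look like the following, where
/// `input.txt` is only an illustrative name for the keyed input file:
///
/// ```
/// slb 'awk -f catter.awk' < input.txt
/// ```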
#[derive(Debug, StructOpt)]
#[structopt(name = "slb", about = "Performs streaming load balancing.")]
struct Opt {
    /// The first positional argument determines the child processes
    /// that get launched. It is required.
    ///
    /// Multiple instances of this same process are created with the same
    /// command-line string. Text lines from the stdin of `slb` are fed
    /// into these processes, and their stdout is forwarded to the stdout
    /// of this parent process.
    cmd: String,

    /// Queue size for the mpsc queues that load balance input lines to the
    /// child processes. Defaults to 16384.
    #[structopt(short, long)]
    queuesize: Option<usize>,

    /// Buffer size in bytes for reading stdin. Defaults to 16384.
    #[structopt(short, long)]
    bufsize: Option<usize>,
    // arguments to consider:
    // #[structopt(short,long)]
    // sort-like KEYDEF -k --key
    // nproc -j --jobs
    // buffer input size (buffer full stdin reads, do line parsing
    // ourselves)
    // queue buffer size for mpsc queues
}

fn main() {
    let opt = Opt::from_args();

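    // Launch one child per logical core (rayon's default global thread count),
    // each running the user-supplied command string under `bash -c`.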
    let children: Vec<_> = (0..rayon::current_num_threads())
        .map(|i| {
            Command::new("/bin/bash")
                .arg("-c")
                .arg(&opt.cmd)
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .spawn()
                .unwrap_or_else(|err| panic!("error spawning child {}: {}", i, err))
        })
        .collect();

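    // One bounded channel per child: the stdin-reading task below hashes each
    // line's key to pick a channel, and a per-child thread drains that channel
    // into the child's stdin.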
    let (txs, rxs): (Vec<_>, Vec<_>) = (0..children.len())
        .map(|_| sync_channel(opt.queuesize.unwrap_or(16 * 1024)))
        .unzip();

    rayon::spawn(move || {
        // txs captured by value here
        let mut done = false;
        let txs_ref = &txs;
        // use bstr here --> collect multiple slices into hashes line by line,
        // but also try chunking (for_byte_line, collect buffers)
        iter::from_fn(move || {
            if done {
                return None;
            }
            let bufsize = opt.bufsize.unwrap_or(16 * 1024);
            let mut buf = Vec::with_capacity(bufsize);
            let mut lines = Vec::with_capacity(bufsize / 8);
            // keep reading until the buffer is within 5x the average line size
            // (over the current block) of `bufsize`, to minimize reallocations:
            //     buf.len() < bufsize - 5 * avg, where avg == buf.len() / lines.len()
            // multiplying through by lines.len() and rearranging so the usize
            // arithmetic cannot underflow:
            //     buf.len() * lines.len() + 5 * buf.len() <= bufsize * lines.len()
            while buf.len() * lines.len() + 5 * buf.len() <= bufsize * lines.len() {
                let bytes = io::stdin()
                    .lock()
                    .read_until(b'\n', &mut buf)
                    .expect("read");
                if bytes == 0 {
                    done = true;
                    break;
                }
                // pop the trailing newline (if present) so `buf` stores lines
                // back-to-back; `lines` records each line's end offset
                if buf.last() == Some(&b'\n') {
                    buf.pop();
                }
                lines.push(buf.len());
            }
            Some((buf, lines))
        })
        .flat_map(|(buf, lines)| {
            let mut start = 0;
            lines.into_iter().map(move |end| {
                let line = &buf[start..end];
                let key = hash_key(line);
                let send_ix = key % txs_ref.len();
                start = end;
                (send_ix, line.to_vec())
            })
        })
        .for_each(|(send_ix, line)| {
            txs_ref[send_ix].send(line).expect("send");
        });
        // txs dropped here, hanging up the send channel
        drop(txs)
    });
    // rxs valid until txs dropped, since they loop until err

    // instead of a lock here we could just use a channel to pipe stdout lines
    let lock = Arc::new(Mutex::new(()));
    let handles: Vec<_> = children
        .into_iter()
        .zip(rxs.into_iter())
        .map(|(mut child, rx)| {
            let lock = Arc::clone(&lock);
            thread::spawn(move || {
                let mut child_stdin = child.stdin.take().expect("child stdin");
                while let Ok(line) = rx.recv() {
                    child_stdin.write_all(&line).expect("write line");
                    child_stdin.write_all(b"\n").expect("write newline");
                }
                // bind the guard to a name so the lock is held while this child's
                // output is drained; `let _ = ...` would drop it immediately and
                // allow children's outputs to interleave
                let _guard = lock.lock().expect("lock");
                drop(child_stdin);
                let stdout = child.stdout.take().expect("child_stdout");
                let stdout = BufReader::new(stdout);
                stdout
                    .for_byte_line_with_terminator(|line: &[u8]| {
                        let stdout = io::stdout();
                        let mut handle = stdout.lock();
                        handle.write_all(line).map(|_| true)
                    })
                    .expect("write");
            })
        })
        .collect();
    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("join"));
}

fn hash_key(bytes: &[u8]) -> usize {
    let end = memchr(b' ', bytes).unwrap_or(bytes.len());
    // consider faster hasher?
    let mut hasher = DefaultHasher::default();
    bytes[..end].hash(&mut hasher);
    hasher.finish().try_into().expect("u64 to usize")
}
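
// A minimal sanity-check sketch of `hash_key`'s documented behavior: the key
// is everything up to the first space, or the whole line if there is none.
#[cfg(test)]
mod tests {
    use super::hash_key;

    #[test]
    fn hash_key_uses_bytes_up_to_first_space() {
        // same key, different payloads: hashes must agree
        assert_eq!(hash_key(b"key1 a b c d"), hash_key(b"key1 x"));
        // a line with no space uses the entire line as the key
        assert_eq!(hash_key(b"key1"), hash_key(b"key1 trailing words"));
        // distinct keys should hash differently (with overwhelming probability)
        assert_ne!(hash_key(b"key1 a"), hash_key(b"key2 a"));
    }
}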