summaryrefslogtreecommitdiffstats
path: root/slb/src/main.rs
blob: a7d8f2203499e2b7908c6c2b356f35a13df923de (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
//! `slb` main executable

use std::collections::hash_map::DefaultHasher;
use std::convert::TryInto;
use std::hash::{Hash, Hasher};
use std::io::{self, BufRead, BufReader, Write};
use std::process::{Command, Stdio};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::mpsc::sync_channel;
use std::sync::Arc;
use std::thread;

use bstr::io::BufReadExt;
use memchr::memchr;
use structopt::StructOpt;

/// Performs streaming load balancing on stdin, handing off input
/// to child processes based on a hash of the first word on each line.
///
/// E.g., suppose we have a file with contents like
///
/// ```
/// key1 a b c d
/// key2 e f g h
/// key1 a b
/// ```
///
/// The key is all bytes leading up to the first space, or all bytes
/// on a line if there are no spaces. Suppose `hash(key1) == 1` and
/// `hash(key2) == 2`. For a machine with 2 cores, `slb` will have
/// two processes, and the zeroth one will receive as stdin
///
/// ```
/// key2 e f g h
/// ```
///
/// since `2 % 2 == 0` and all `key1` lines will be fed into the
/// process at index 1.
///
/// These processes are expected to perform some kind of aggregation
/// and print at the end of their execution. For instance, suppose we invoke
/// `slb 'awk -f catter.awk'` where `catter.awk` is just defined to be
/// `{key = $1; $1 = ""; a[key] += $0}END{for (k in a) print k,a[k]}`,
/// which just concatenates the keyed values. Then the output might be
///
/// ```
/// key1  a b c d a b
/// key2  e f g h
/// ```
#[derive(Debug, StructOpt)]
#[structopt(name = "slb", about = "Performs streaming load balancing.")]
struct Opt {
    /// The first positional argument determines the child processes
    /// that get launched. It is required.
    ///
    /// Multiple instances of this same process are created with the same
    /// command-line string. Text lines from the stdin of `slb` are fed
    /// into these processes, and stdout is shared between this parent
    /// process and its children.
    cmd: String,

    /// Queue size for mpsc queues used for load balancing to inputs.
    ///
    /// Defaults to 16 * 1024 entries per child process when unset.
    #[structopt(short, long)]
    queuesize: Option<usize>,

    /// Buffer size for reading input.
    ///
    /// Defaults to 16 * 1024 bytes when unset.
    #[structopt(short, long)]
    bufsize: Option<usize>,
    // arguments to consider:
    // #[structopt(short,long)]
    // sort-like KEYDEF -k --key
    // nproc -j --jobs
    // buffer input size (buffer full stdin reads, do line parsing
    // ourselves)
    // queue buffer size for mpsc queues
    /// Print debug information to stderr.
    ///
    /// Absent means false. NOTE(review): as an `Option<bool>` rather than a
    /// plain flag, this presumably requires an explicit value on the command
    /// line (e.g. `--verbose true`) — confirm against structopt's handling.
    #[structopt(short, long)]
    verbose: Option<bool>,
}

/// Entry point: spawn one worker per logical core, read stdin in blocks,
/// route each line to a worker by the hash of its first word, then stream
/// each worker's stdout to our stdout after its input is exhausted.
fn main() {
    let opt = Opt::from_args();
    let verbose = opt.verbose.unwrap_or(false);

    // Launch one worker process per logical core, each running the same
    // user-supplied shell command with piped stdin/stdout.
    let children: Vec<_> = (0..rayon::current_num_threads())
        .map(|i| {
            Command::new("/bin/bash")
                .arg("-c")
                .arg(&opt.cmd)
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .spawn()
                .unwrap_or_else(|err| panic!("error spawn child {}: {}", i, err))
        })
        .collect();

    // One bounded channel of key-routed lines per child process.
    let (txs, rxs): (Vec<_>, Vec<_>) = (0..children.len())
        .map(|_| sync_channel(opt.queuesize.unwrap_or(16 * 1024)))
        .unzip();

    // Bounded handoff of (bytes, cumulative line-end offsets) blocks from
    // the stdin reader thread to the line-routing writer thread.
    let reader_queue_max_size = 16;
    let (read_tx, read_rx) = sync_channel(reader_queue_max_size);

    // Approximate occupancy of the reader->writer queue, tracked only to
    // report average utilization when verbose output is requested.
    let reader_queue_size = Arc::new(AtomicU32::new(0));
    let reader_rqs = Arc::clone(&reader_queue_size);
    let reader = thread::spawn(move || {
        // Lock stdin once for the lifetime of the thread instead of
        // re-acquiring the lock on every read_until call.
        let stdin = io::stdin();
        let mut stdin = stdin.lock();
        // Loop-invariant: the target block size never changes.
        let bufsize = opt.bufsize.unwrap_or(16 * 1024);
        let mut done = false;
        let mut total_queue_len = 0;
        let mut num_enqueues = 0;
        while !done {
            let mut buf = Vec::with_capacity(bufsize);
            // Cumulative end offset into `buf` of each line read so far.
            let mut lines = Vec::with_capacity(bufsize / 8);
            // Keep reading while the buffer still has room for ~5x the
            // average line size so far in this block, to minimize
            // reallocations:
            //   buf.len() + 5 * avg <= bufsize, avg == buf.len() / lines.len()
            // Multiplying through by lines.len() with all terms kept
            // non-negative avoids the unsigned underflow (panic in debug,
            // wraparound in release) that the subtraction form
            // `bufsize * lines.len() - 5 * buf.len()` hits on long lines.
            while buf.len() * lines.len() + 5 * buf.len() <= bufsize * lines.len() {
                let bytes = stdin.read_until(b'\n', &mut buf).expect("read");
                if bytes == 0 {
                    done = true;
                    break;
                }
                // Strip the newline terminator, but only when present: the
                // final line of input may end at EOF without one, and an
                // unconditional pop would eat a data byte there.
                if buf.last() == Some(&b'\n') {
                    buf.pop();
                }
                lines.push(buf.len());
            }
            read_tx.send((buf, lines)).expect("send");
            total_queue_len += reader_rqs.fetch_add(1, Ordering::Relaxed);
            num_enqueues += 1;
        }
        // read_tx dropped here, hanging up the channel so the writer exits.
        drop(read_tx);
        if verbose {
            // num_enqueues >= 1: the outer loop always sends at least once.
            eprintln!(
                "avg queue len, rounding up {} (max {})",
                (total_queue_len + num_enqueues - 1) / num_enqueues,
                reader_queue_max_size
            );
        }
    });

    let writer_rqs = Arc::clone(&reader_queue_size);
    let writer = thread::spawn(move || {
        // Split each block back into lines and route every line to the
        // child selected by hashing its key (the first space-delimited word).
        while let Ok((buf, lines)) = read_rx.recv() {
            writer_rqs.fetch_sub(1, Ordering::Relaxed);
            let mut start = 0;
            for end in lines {
                let line = &buf[start..end];
                let key = hash_key(line);
                let send_ix = key % txs.len();
                start = end;
                txs[send_ix].send(line.to_vec()).expect("send");
            }
        }
        // txs dropped here, hanging up the send channels so feeders exit.
        drop(txs)
    });
    // rxs valid until txs dropped, since they loop until err

    // One feeder thread per child: forward routed lines into the child's
    // stdin; once that channel hangs up, close stdin so the child can
    // aggregate and print, then copy the child's stdout to our stdout.
    let handles: Vec<_> = children
        .into_iter()
        .zip(rxs.into_iter())
        .map(|(mut child, rx)| {
            thread::spawn(move || {
                let mut child_stdin = child.stdin.take().expect("child stdin");
                while let Ok(line) = rx.recv() {
                    child_stdin.write_all(&line).expect("write line");
                    // write_all, not write: a bare write may perform a
                    // short/zero-byte write and silently drop the newline.
                    child_stdin.write_all(b"\n").expect("write newline");
                }
                // Closing stdin signals EOF to the child.
                drop(child_stdin);
                let child_stdout = child.stdout.take().expect("child_stdout");
                let child_stdout = BufReader::new(child_stdout);
                let stdout = io::stdout();
                let mut handle = stdout.lock();
                child_stdout
                    .for_byte_line_with_terminator(move |line: &[u8]| {
                        handle.write_all(line).map(|_| true)
                    })
                    .expect("write");
            })
        })
        .collect();

    reader.join().expect("reader join");
    writer.join().expect("writer join");
    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("join"));
}

/// Hashes a line's key: the bytes up to (but excluding) the first space,
/// or the entire line when it contains no space.
///
/// The result is only ever used modulo the worker count, so the u64 hash
/// is truncated with a plain cast. The previous
/// `try_into().expect("u64 to usize")` would panic on 32-bit targets for
/// every hash value above `u32::MAX` — essentially all of them, since
/// `DefaultHasher` output is uniform over u64.
fn hash_key(bytes: &[u8]) -> usize {
    let end = memchr(b' ', bytes).unwrap_or(bytes.len());
    // consider faster hasher?
    let mut hasher = DefaultHasher::default();
    bytes[..end].hash(&mut hasher);
    // Truncating cast: intentional, see doc comment above.
    hasher.finish() as usize
}