summaryrefslogtreecommitdiffstats
path: root/src/common/bitpacker.rs
blob: 22a21bcd5a2da0e8705b7fe46ba59672997402a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
use std::io;
use std::ops::Deref;

/// Packs integers of arbitrary bit width into a little-endian bit stream.
///
/// Bits are staged in a 64-bit word and handed to the writer eight bytes at
/// a time; `flush` / `close` emit whatever partial word is left.
pub(crate) struct BitPacker {
    /// Staged bits that have not been written out yet, lowest bit first.
    mini_buffer: u64,
    /// Number of valid low bits in `mini_buffer`; always below 64 between calls.
    mini_buffer_written: usize,
}

impl BitPacker {
    /// Creates an empty packer with nothing staged.
    pub fn new() -> BitPacker {
        BitPacker {
            mini_buffer: 0u64,
            mini_buffer_written: 0,
        }
    }

    /// Appends the `num_bits` low bits of `val` to the stream.
    ///
    /// `val` is not masked: any bit set at or above position `num_bits` is
    /// OR-ed into the stream as well and will corrupt neighbouring values, so
    /// callers must pass values that fit in `num_bits` (at most 64).
    ///
    /// # Errors
    ///
    /// Returns any error reported by `output` while a full word is written.
    pub fn write<TWrite: io::Write>(
        &mut self,
        val: u64,
        num_bits: u8,
        output: &mut TWrite,
    ) -> io::Result<()> {
        let num_bits = usize::from(num_bits);
        if self.mini_buffer_written + num_bits > 64 {
            // The value straddles the staging word: its low part completes the
            // current word, its high part seeds the next one.
            self.mini_buffer |= val.wrapping_shl(self.mini_buffer_written as u32);
            output.write_u64::<LittleEndian>(self.mini_buffer)?;
            self.mini_buffer = val.wrapping_shr((64 - self.mini_buffer_written) as u32);
            self.mini_buffer_written = self.mini_buffer_written + num_bits - 64;
        } else {
            self.mini_buffer |= val << self.mini_buffer_written;
            self.mini_buffer_written += num_bits;
            if self.mini_buffer_written == 64 {
                // Exactly full: emit the word and start from a clean state.
                output.write_u64::<LittleEndian>(self.mini_buffer)?;
                self.mini_buffer_written = 0;
                self.mini_buffer = 0u64;
            }
        }
        Ok(())
    }

    /// Writes the staged bits, rounded up to whole bytes, and resets the packer.
    ///
    /// Unused high bits of the last byte are zero. After this call the packer
    /// is empty and may be reused for a new, byte-aligned run of values.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `output`.
    pub fn flush<TWrite: io::Write>(&mut self, output: &mut TWrite) -> io::Result<()> {
        if self.mini_buffer_written > 0 {
            let num_bytes = (self.mini_buffer_written + 7) / 8;
            let mut arr: [u8; 8] = [0u8; 8];
            LittleEndian::write_u64(&mut arr, self.mini_buffer);
            output.write_all(&arr[..num_bytes])?;
            self.mini_buffer_written = 0;
            // Clear the emitted bits as well; otherwise the next `write` would
            // OR new values onto stale content.
            self.mini_buffer = 0u64;
        }
        Ok(())
    }

    /// Flushes the staged bits and appends seven zero bytes of padding.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `output`.
    pub fn close<TWrite: io::Write>(&mut self, output: &mut TWrite) -> io::Result<()> {
        self.flush(output)?;
        // Padding the write file to simplify reads: a reader can then always
        // load eight bytes starting at the byte that holds a value's first bit.
        output.write_all(&[0u8; 7])?;
        Ok(())
    }
}

/// Random-access reader over a little-endian bit stream in which every value
/// occupies the same number of bits.
#[derive(Clone)]
pub struct BitUnpacker<Data>
where
    Data: Deref<Target = [u8]>,
{
    /// Width of every packed value, in bits.
    num_bits: u64,
    /// Mask with the `num_bits` low bits set.
    mask: u64,
    /// Packed bytes, including the trailing padding expected by `get`.
    data: Data,
}

impl<Data> BitUnpacker<Data>
where
    Data: Deref<Target = [u8]>,
{
    /// Wraps `data`, interpreting it as a sequence of `num_bits`-wide values.
    ///
    /// `num_bits` must be at most 64.
    pub fn new(data: Data, num_bits: u8) -> BitUnpacker<Data> {
        // `1 << 64` would overflow, so the full-width mask is special-cased.
        let mask: u64 = if num_bits == 64 {
            !0u64
        } else {
            (1u64 << num_bits) - 1u64
        };
        BitUnpacker {
            num_bits: u64::from(num_bits),
            mask,
            data,
        }
    }

    /// Returns the value stored at position `idx`.
    ///
    /// A zero bit width always yields `0` without touching the data.
    ///
    /// # Panics
    ///
    /// Panics if fewer than eight bytes are available from the byte holding
    /// the value's first bit (slice indexing and `byteorder`'s 8-byte read both
    /// panic on short input), which is why the data is expected to carry
    /// seven bytes of trailing padding.
    pub fn get(&self, idx: u64) -> u64 {
        if self.num_bits == 0 {
            return 0u64;
        }
        let data: &[u8] = &*self.data;
        let num_bits = self.num_bits;
        let addr_in_bits = idx * num_bits;
        let addr = addr_in_bits >> 3;
        let bit_shift = addr_in_bits & 7;
        debug_assert!(
            addr + 8 <= data.len() as u64,
            "The fast field field should have been padded with 7 bytes."
        );
        let addr = addr as usize;
        let word: u64 = LittleEndian::read_u64(&data[addr..]);
        let mut val = word >> bit_shift;
        if bit_shift + num_bits > 64 {
            // Widths of 57..=63 bits at an unaligned offset spill into a ninth
            // byte that the 8-byte load did not cover. Here `bit_shift` is at
            // least 1, so the shift amount below stays under 64. That byte
            // holds real value bits, so it lies before the padding.
            let spill = u64::from(data[addr + 8]);
            val |= spill << (64 - bit_shift);
        }
        val & self.mask
    }
}

#[cfg(test)]
mod test {
    use super::{BitPacker, BitUnpacker};

    /// Packs `len` values of `num_bits` bits each, checks the encoded length,
    /// and returns an unpacker over the bytes together with the values written.
    fn create_fastfield_bitpacker(len: usize, num_bits: u8) -> (BitUnpacker<Vec<u8>>, Vec<u64>) {
        let mut data = Vec::new();
        let mut bitpacker = BitPacker::new();
        // `1 << 64` would overflow, so the full-width case is handled apart.
        let max_val: u64 = if num_bits == 64 {
            !0u64
        } else {
            (1u64 << num_bits) - 1u64
        };
        // Masking (rather than `i % max_val`) cycles through every value up to
        // and including the all-ones pattern, and yields 0 for a zero width.
        let vals: Vec<u64> = (0u64..len as u64).map(|i| i & max_val).collect();
        for &val in &vals {
            bitpacker.write(val, num_bits, &mut data).unwrap();
        }
        bitpacker.close(&mut data).unwrap();
        // Payload rounded up to whole bytes, plus the 7 padding bytes from `close`.
        assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8 + 7);
        let bitunpacker = BitUnpacker::new(data, num_bits);
        (bitunpacker, vals)
    }

    /// Round-trips `len` values of width `num_bits` through packer and unpacker.
    fn test_bitpacker_util(len: usize, num_bits: u8) {
        let (bitunpacker, vals) = create_fastfield_bitpacker(len, num_bits);
        for (i, val) in vals.iter().enumerate() {
            assert_eq!(bitunpacker.get(i as u64), *val);
        }
    }

    #[test]
    fn test_bitpacker() {
        test_bitpacker_util(10, 3);
        test_bitpacker_util(10, 0);
        test_bitpacker_util(10, 1);
        test_bitpacker_util(6, 14);
        test_bitpacker_util(1000, 14);
        test_bitpacker_util(10, 64);
    }
}