summaryrefslogtreecommitdiffstats
path: root/src/postings/block_search.rs
blob: 04da35b931020211770737808cfd9f62e56ba016 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
use postings::compression::AlignedBuffer;

/// This module defines the logic used to search for a doc in a given
/// block (at most 128 docs).
///
/// Searching within a block is a hotspot when running intersections,
/// so it was worth defining it in its own module.

#[cfg(target_arch = "x86_64")]
mod sse2 {
    use postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE};
    use std::arch::x86_64::{
        __m128i, _mm_add_epi32, _mm_cmplt_epi32, _mm_cvtsi128_si32, _mm_load_si128,
        _mm_set1_epi32, _mm_setzero_si128, _mm_shuffle_epi32, _mm_sub_epi32,
    };

    /// `_mm_shuffle_epi32` control selecting lanes `[2, 3, 0, 1]`,
    /// i.e. swapping the two 64-bit halves of the vector.
    const MASK1: i32 = 78;
    /// `_mm_shuffle_epi32` control selecting lanes `[1, 0, 3, 2]`,
    /// i.e. swapping each pair of neighbouring 32-bit lanes.
    const MASK2: i32 = 177;

    /// Number of 32-bit lanes held by a single `__m128i`.
    const LANES_PER_VECTOR: usize = 4;
    /// Number of `__m128i` vectors consumed by one round of the unrolled loop.
    const VECTORS_PER_ROUND: usize = 4;

    /// Performs an exhaustive linear search over the whole block.
    ///
    /// There is no early exit here: the function simply counts the number of
    /// elements that are `< target` and returns that count.
    ///
    /// Note that `_mm_cmplt_epi32` is a *signed* comparison and that `target`
    /// is reinterpreted as an `i32`, so the count is only the unsigned answer
    /// as long as the block values and the target are below `2^31`.
    pub(crate) fn linear_search_sse2_128(arr: &AlignedBuffer, target: u32) -> usize {
        let num_rounds = COMPRESSION_BLOCK_SIZE / (LANES_PER_VECTOR * VECTORS_PER_ROUND);
        // SAFETY: `_mm_load_si128` requires 128-bit aligned addresses, and the
        // loop reads `COMPRESSION_BLOCK_SIZE` consecutive `u32` starting at
        // `arr`. This relies on `AlignedBuffer` (declared in
        // `postings::compression`) being 16-byte aligned and wrapping a
        // `[u32; COMPRESSION_BLOCK_SIZE]` -- presumably guaranteed by its
        // definition; confirm there before changing either side.
        unsafe {
            let base = arr as *const AlignedBuffer as *const __m128i;
            let needle = _mm_set1_epi32(target as i32);
            let mut lane_counts = _mm_setzero_si128();
            for round in 0..num_rounds {
                let chunk = base.add(round * VECTORS_PER_ROUND);
                // Each comparison yields `-1` in the lanes where `value < target`
                // and `0` elsewhere.
                let lt0 = _mm_cmplt_epi32(_mm_load_si128(chunk), needle);
                let lt1 = _mm_cmplt_epi32(_mm_load_si128(chunk.add(1)), needle);
                let lt2 = _mm_cmplt_epi32(_mm_load_si128(chunk.add(2)), needle);
                let lt3 = _mm_cmplt_epi32(_mm_load_si128(chunk.add(3)), needle);
                let negated_hits = _mm_add_epi32(_mm_add_epi32(lt0, lt1), _mm_add_epi32(lt2, lt3));
                // Subtracting the (negative) sum increments the per-lane counters.
                lane_counts = _mm_sub_epi32(lane_counts, negated_hits);
            }
            // Horizontal sum: after these two steps every lane holds the total.
            let halves_folded = _mm_add_epi32(lane_counts, _mm_shuffle_epi32(lane_counts, MASK1));
            let total = _mm_add_epi32(halves_folded, _mm_shuffle_epi32(halves_folded, MASK2));
            _mm_cvtsi128_si32(total) as usize
        }
    }

    #[cfg(test)]
    mod test {
        use super::linear_search_sse2_128;
        use postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE};

        #[test]
        fn test_linear_search_sse2_128_u32() {
            let mut docs = [0u32; COMPRESSION_BLOCK_SIZE];
            for idx in 0usize..128usize {
                docs[idx] = ((idx as u32) * 2 + 1) << 18;
            }
            // Strictly greater than elements 0..=64, smaller than element 65.
            let needle = docs[64] + 1;
            let found = linear_search_sse2_128(&AlignedBuffer(docs), needle);
            assert_eq!(found, 65);
        }
    }
}

/// Counts the elements of `arr` that are strictly smaller than `target`.
///
/// This linear search browses exhaustively through the array rather than
/// stopping at the first element `>= target`: an early exit would be very
/// difficult for the branch predictor to predict.
///
/// Coupled with `exponential_search`, this function is likely
/// to be called with the same `len` over and over.
///
/// On a sorted slice the returned count is also the index of the first
/// element that is `>= target`.
fn linear_search(arr: &[u32], target: u32) -> usize {
    arr.iter().filter(|&&value| value < target).count()
}

/// Narrows down the range of `arr` in which `target` has to be searched.
///
/// The slice is probed at the exponentially growing positions
/// `1, 3, 7, 15, 31, 63`. Probing stops at the first probe that falls
/// outside of the slice, or at the first probed value that is strictly
/// greater than `target`.
///
/// Returns `(begin, end)` where `begin` is the last probed position whose
/// value was `<= target` (or `0` if there was none), and `end` is the
/// position of the probe that exceeded `target` (or `arr.len()` if no probe
/// did).
fn exponential_search(arr: &[u32], target: u32) -> (usize, usize) {
    const PIVOTS: [usize; 6] = [1, 3, 7, 15, 31, 63];
    let mut lower = 0;
    for &pivot in &PIVOTS {
        match arr.get(pivot) {
            // The probe is past the end of the slice: stop probing.
            None => break,
            Some(&probed) if probed > target => return (lower, pivot),
            Some(_) => lower = pivot,
        }
    }
    (lower, arr.len())
}

/// Counts the elements of the sorted slice `block_docs` that are `< target`.
///
/// An exponential search first restricts the search to a small window,
/// which is then scanned exhaustively by `linear_search`. The window's
/// starting offset is added back so that the result is an index into
/// `block_docs`.
fn galloping(block_docs: &[u32], target: u32) -> usize {
    let (lower, upper) = exponential_search(block_docs, target);
    let window = &block_docs[lower..upper];
    lower + linear_search(window, target)
}

/// Strategy used to search for a specific document within a given block.
///
/// Tantivy may rely on SIMD instructions to perform that search; the
/// variant records which implementation `search_in_block` dispatches to.
/// `BlockSearcher::default()` picks `SSE2` when the running CPU reports
/// SSE2 support, and `Scalar` otherwise.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockSearcher {
    /// Exhaustive SSE2 linear search, used for completely full blocks.
    /// Only exists on `x86_64` targets.
    #[cfg(target_arch = "x86_64")]
    SSE2,
    /// Portable implementation: exponential search followed by a linear search.
    Scalar,
}

impl BlockSearcher {
    /// Returns the index of the first element of the block that is greater
    /// than or equal to `target`.
    ///
    /// The result should be equivalent to
    /// ```ignore
    /// block[..]
    ///     .iter()
    ///     .take_while(|&&val| val < target)
    ///     .count()
    /// ```
    ///
    /// The `start` argument is only a hint that the answer is not smaller
    /// than `start`. The implementation may or may not use it for
    /// optimization.
    ///
    /// # Assumptions
    ///
    /// - `len` is greater than `start`.
    /// - The block is sorted.
    /// - `target` is greater than or equal to `arr[start]`.
    /// - `target` is smaller than or equal to the last element of the block.
    ///
    /// # Implementation
    ///
    /// The scalar implementation starts with an exponential search, and
    /// then runs a linear search over the resulting subarray.
    ///
    /// If SSE2 instructions are available in the `(platform, running CPU)`,
    /// a different implementation is used whenever the block is full
    /// (`len == COMPRESSION_BLOCK_SIZE`): an exhaustive linear search over
    /// the whole block, ignoring `start`. It is surprisingly faster, most
    /// likely because of the lack of branches.
    pub(crate) fn search_in_block(
        self,
        block_docs: &AlignedBuffer,
        len: usize,
        start: usize,
        target: u32,
    ) -> usize {
        #[cfg(target_arch = "x86_64")]
        {
            use postings::compression::COMPRESSION_BLOCK_SIZE;
            // The SIMD routine always scans a whole block, so it can only
            // be used when every slot of the block holds a real value.
            let is_full_block = len == COMPRESSION_BLOCK_SIZE;
            if is_full_block && self == BlockSearcher::SSE2 {
                return sse2::linear_search_sse2_128(block_docs, target);
            }
        }
        // Scalar path: search only `[start, len)` and shift the result back.
        let candidates = &block_docs.0[start..len];
        start + galloping(candidates, target)
    }
}

impl Default for BlockSearcher {
    /// Picks the fastest searcher supported by the running CPU.
    ///
    /// On `x86_64`, SSE2 support is detected at runtime and `SSE2` is
    /// returned when it is available. In every other case the portable
    /// `Scalar` searcher is returned.
    fn default() -> Self {
        #[cfg(target_arch = "x86_64")]
        let searcher = if is_x86_feature_detected!("sse2") {
            BlockSearcher::SSE2
        } else {
            BlockSearcher::Scalar
        };
        #[cfg(not(target_arch = "x86_64"))]
        let searcher = BlockSearcher::Scalar;
        searcher
    }
}

#[cfg(test)]
mod tests {
    use super::exponential_search;
    use super::linear_search;
    use super::BlockSearcher;
    use postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE};

    /// `linear_search` on a sorted array must return the position of the
    /// first element that is `>= target`.
    #[test]
    fn test_linear_search() {
        let len: usize = 50;
        let arr: Vec<u32> = (0..len).map(|el| 1u32 + (el as u32) * 2).collect();
        for target in 1..*arr.last().unwrap() {
            let res = linear_search(&arr[..], target);
            if res > 0 {
                assert!(arr[res - 1] < target);
            }
            if res < len {
                assert!(arr[res] >= target);
            }
        }
    }

    #[test]
    fn test_exponential_search() {
        assert_eq!(exponential_search(&[1, 2], 0), (0, 1));
        assert_eq!(exponential_search(&[1, 2], 1), (0, 1));
        assert_eq!(
            exponential_search(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 7),
            (3, 7)
        );
    }

    /// Checks `search_in_block` against the trivial implementation, for every
    /// valid `start` hint strictly below the expected answer.
    fn util_test_search_in_block(block_searcher: BlockSearcher, block: &[u32], target: u32) {
        let cursor = search_in_block_trivial_but_slow(block, target);
        // A completely full block is allowed: it is the only case in which
        // the SSE2 searcher actually takes its SIMD code path.
        assert!(block.len() <= COMPRESSION_BLOCK_SIZE);
        // Slots beyond `block.len()` are padding and must never be read by
        // the scalar path.
        let mut output_buffer = [u32::max_value(); COMPRESSION_BLOCK_SIZE];
        output_buffer[..block.len()].copy_from_slice(block);
        for i in 0..cursor {
            assert_eq!(
                block_searcher.search_in_block(
                    &AlignedBuffer(output_buffer),
                    block.len(),
                    i,
                    target
                ),
                cursor
            );
        }
    }

    /// Runs `util_test_search_in_block` for every value of the block, and for
    /// the value right before each element but the first.
    fn util_test_search_in_block_all(block_searcher: BlockSearcher, block: &[u32]) {
        use std::collections::HashSet;
        let mut targets = HashSet::new();
        for (i, val) in block.iter().cloned().enumerate() {
            if i > 0 {
                targets.insert(val - 1);
            }
            targets.insert(val);
        }
        for target in targets {
            util_test_search_in_block(block_searcher, block, target);
        }
    }

    /// Reference implementation `search_in_block` is compared against.
    fn search_in_block_trivial_but_slow(block: &[u32], target: u32) -> usize {
        block.iter().take_while(|&&val| val < target).count()
    }

    /// Exercises the searcher on blocks of every length from 1 up to and
    /// including a full block of `COMPRESSION_BLOCK_SIZE` elements.
    fn test_search_in_block_util(block_searcher: BlockSearcher) {
        for len in 1u32..=(COMPRESSION_BLOCK_SIZE as u32) {
            let v: Vec<u32> = (0..len).map(|i| i * 2).collect();
            util_test_search_in_block_all(block_searcher, &v[..]);
        }
    }

    #[test]
    fn test_search_in_block_scalar() {
        test_search_in_block_util(BlockSearcher::Scalar);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_search_in_block_sse2() {
        test_search_in_block_util(BlockSearcher::SSE2);
    }
}