diff options
Diffstat (limited to 'crates/common/batcher/src/driver.rs')
-rw-r--r-- | crates/common/batcher/src/driver.rs | 260 |
1 files changed, 260 insertions, 0 deletions
diff --git a/crates/common/batcher/src/driver.rs b/crates/common/batcher/src/driver.rs new file mode 100644 index 00000000..0f7c0ed5 --- /dev/null +++ b/crates/common/batcher/src/driver.rs @@ -0,0 +1,260 @@ +use crate::batchable::Batchable; +use crate::batcher::Batcher; +use crate::batcher::BatcherOutput; +use chrono::{DateTime, Utc}; +use std::collections::BTreeSet; +use std::time::Duration; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::{Receiver, Sender}; + +/// Input message to the BatchDriver's input channel. +#[derive(Debug)] +pub enum BatchDriverInput<B: Batchable> { + /// Message representing a new item to batch. + Event(B), + /// Message representing that the batching should finish and that + /// any remaining batches should be immediately closed and sent to the output. + Flush, +} + +/// Output message from the BatchDriver's output channel. +#[derive(Debug)] +pub enum BatchDriverOutput<B: Batchable> { + /// Message representing a batch of items. + Batch(Vec<B>), + /// Message representing that batching has finished. + Flush, +} + +/// The central API for using the batching algorithm. +/// Send items in, get batches out. +#[derive(Debug)] +pub struct BatchDriver<B: Batchable> { + batcher: Batcher<B>, + input: Receiver<BatchDriverInput<B>>, + output: Sender<BatchDriverOutput<B>>, + timers: BTreeSet<DateTime<Utc>>, +} + +enum TimeTo { + Unbounded, + Future(std::time::Duration), + Past(DateTime<Utc>), +} + +impl<B: Batchable> BatchDriver<B> { + /// Define the batching process and channels to interact with it. + pub fn new( + batcher: Batcher<B>, + input: Receiver<BatchDriverInput<B>>, + output: Sender<BatchDriverOutput<B>>, + ) -> BatchDriver<B> { + BatchDriver { + batcher, + input, + output, + timers: BTreeSet::new(), + } + } + + /// Start the batching - runs until receiving a Flush message + pub async fn run(mut self) -> Result<(), SendError<BatchDriverOutput<B>>> { + loop { + let message = match self.time_to_next_timer() { + TimeTo::Unbounded => self.recv(None), + TimeTo::Future(timeout) => self.recv(Some(timeout)), + TimeTo::Past(timer) => { + self.timers.remove(&timer); + self.time(Utc::now()).await?; + continue; + } + }; + + match message.await { + Err(_) => continue, // timer timeout expired + Ok(None) => break, // input channel closed + Ok(Some(BatchDriverInput::Flush)) => break, // we've been told to stop + Ok(Some(BatchDriverInput::Event(event))) => self.event(event).await?, + }; + } + + self.flush().await + } + + async fn recv( + &mut self, + timeout: Option<Duration>, + ) -> Result<Option<BatchDriverInput<B>>, tokio::time::error::Elapsed> { + match timeout { + None => Ok(self.input.recv().await), + Some(timeout) => tokio::time::timeout(timeout, self.input.recv()).await, + } + } + + fn time_to_next_timer(&self) -> TimeTo { + match self.timers.iter().next() { + None => TimeTo::Unbounded, + Some(timer) => { + let signed_duration = timer.signed_duration_since(Utc::now()); + match signed_duration.to_std() { + Ok(d) => TimeTo::Future(d), + Err(_) => TimeTo::Past(*timer), + } + } + } + } + + async fn event(&mut self, event: B) -> Result<(), SendError<BatchDriverOutput<B>>> { + for action in self.batcher.event(Utc::now(), event) { + match action { + BatcherOutput::Batch(batch) => { + self.output.send(BatchDriverOutput::Batch(batch)).await?; + } + BatcherOutput::Timer(t) => { + self.timers.insert(t); + } + }; + } + + Ok(()) + } + + async fn time(&mut self, timer: DateTime<Utc>) -> Result<(), SendError<BatchDriverOutput<B>>> { + for batch in self.batcher.time(timer) { + self.output.send(BatchDriverOutput::Batch(batch)).await?; + } + + Ok(()) + } + + async fn flush(self) -> Result<(), SendError<BatchDriverOutput<B>>> { + for batch in self.batcher.flush() { + self.output.send(BatchDriverOutput::Batch(batch)).await?; + } + + self.output.send(BatchDriverOutput::Flush).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::batchable::Batchable; + use crate::batcher::Batcher; + use crate::config::BatchConfigBuilder; + use crate::driver::BatchDriver; + use chrono::{DateTime, Utc}; + use std::time::Duration; + use tokio::sync::mpsc::error::SendError; + use tokio::sync::mpsc::{channel, Receiver, Sender}; + use tokio::time::timeout; + + #[tokio::test] + async fn flush_empty() -> Result<(), SendError<BatchDriverInput<TestBatchEvent>>> { + let (input_send, mut output_recv) = spawn_driver(); + input_send.send(BatchDriverInput::Flush).await?; + assert_recv_flush(&mut output_recv).await; + Ok(()) + } + + #[tokio::test] + async fn flush_one_batch() -> Result<(), SendError<BatchDriverInput<TestBatchEvent>>> { + let (input_send, mut output_recv) = spawn_driver(); + + let event1 = TestBatchEvent::new(1, Utc::now()); + input_send.send(BatchDriverInput::Event(event1)).await?; + input_send.send(BatchDriverInput::Flush).await?; + + assert_recv_batch(&mut output_recv, vec![event1]).await; + assert_recv_flush(&mut output_recv).await; + + Ok(()) + } + + #[tokio::test] + async fn two_batches_with_timer() -> Result<(), SendError<BatchDriverInput<TestBatchEvent>>> { + let (input_send, mut output_recv) = spawn_driver(); + + let event1 = TestBatchEvent::new(1, Utc::now()); + input_send.send(BatchDriverInput::Event(event1)).await?; + + assert_recv_batch(&mut output_recv, vec![event1]).await; + + let event2 = TestBatchEvent::new(2, Utc::now()); + input_send.send(BatchDriverInput::Event(event2)).await?; + + assert_recv_batch(&mut output_recv, vec![event2]).await; + + Ok(()) + } + + async fn assert_recv_batch( + output_recv: &mut Receiver<BatchDriverOutput<TestBatchEvent>>, + expected: Vec<TestBatchEvent>, + ) { + match timeout(Duration::from_secs(10), output_recv.recv()).await { + Ok(Some(BatchDriverOutput::Batch(batch))) => assert_batch(batch, expected), + other => panic!("Failed to receive batch: {:?}", other), + } + } + + fn assert_batch(batch: Vec<TestBatchEvent>, expected: Vec<TestBatchEvent>) { + assert_eq!(batch.len(), expected.len()); + + for event in &batch { + if !expected.contains(event) { + panic!("Failed to find: {:?}", event); + } + } + } + + async fn assert_recv_flush(output_recv: &mut Receiver<BatchDriverOutput<TestBatchEvent>>) { + match timeout(Duration::from_secs(10), output_recv.recv()).await { + Ok(Some(BatchDriverOutput::Flush)) => {} + other => panic!("Failed to receive flush: {:?}", other), + } + } + + fn spawn_driver() -> ( + Sender<BatchDriverInput<TestBatchEvent>>, + Receiver<BatchDriverOutput<TestBatchEvent>>, + ) { + let (input_send, input_recv) = channel(1); + let (output_send, output_recv) = channel(1); + let config = BatchConfigBuilder::new() + .event_jitter(50) + .delivery_jitter(20) + .message_leap_limit(0) + .build(); + let batcher = Batcher::new(config); + + let driver = BatchDriver::new(batcher, input_recv, output_send); + tokio::spawn(driver.run()); + + (input_send, output_recv) + } + + #[derive(Debug, Copy, Clone, Eq, PartialEq)] + struct TestBatchEvent { + key: u64, + event_time: DateTime<Utc>, + } + + impl TestBatchEvent { + fn new(key: u64, event_time: DateTime<Utc>) -> TestBatchEvent { + TestBatchEvent { key, event_time } + } + } + + impl Batchable for TestBatchEvent { + type Key = u64; + + fn key(&self) -> Self::Key { + self.key + } + + fn event_time(&self) -> DateTime<Utc> { + self.event_time + } + } +} |