summaryrefslogtreecommitdiffstats
path: root/server/src/rate_limit/rate_limiter.rs
blob: 20a617c2fe3097d4617750ed8cd0460df3c87f9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use super::IPAddr;
use crate::{api::APIError, LemmyError};
use log::debug;
use std::{collections::HashMap, time::SystemTime};
use strum::IntoEnumIterator;

/// Per-IP state tracked for a single rate limit type.
#[derive(Debug, Clone)]
pub struct RateLimitBucket {
  /// Set when the bucket is created and refreshed on every rate limit check.
  last_checked: SystemTime,
  /// Number of requests currently available. New buckets hold the sentinel
  /// `-2.0` until the first check replaces it with the configured rate.
  allowance: f64,
}

/// The categories that are rate limited separately; each one keys its own
/// map of per-IP buckets in `RateLimiter`.
///
/// `EnumIter` lets the limiter walk every variant when registering an IP, and
/// `AsRefStr` supplies the variant name used in log and error messages.
#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone, AsRefStr)]
pub enum RateLimitType {
  Message,
  Register,
  Post,
}

/// Rate limiter state, keyed first by rate limit type and then by IP address.
#[derive(Debug, Clone)]
pub struct RateLimiter {
  /// For every rate limit type, the bucket belonging to each IP seen so far.
  pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
}

impl Default for RateLimiter {
  /// Builds a limiter that does not track any rate limit type or IP yet.
  fn default() -> Self {
    let buckets = HashMap::new();
    Self { buckets }
  }
}

impl RateLimiter {
  /// Makes sure every [`RateLimitType`] has a bucket for `ip`.
  ///
  /// Buckets that already exist are left untouched. New buckets start with
  /// the sentinel allowance `-2.0`, which `check_rate_limit_full` swaps for
  /// the configured rate the first time the bucket is checked.
  fn insert_ip(&mut self, ip: &str) {
    for rate_limit_type in RateLimitType::iter() {
      let bucket = self
        .buckets
        .entry(rate_limit_type)
        .or_insert_with(HashMap::new);

      // Probe with the borrowed `&str` first so that the owned key is only
      // allocated for IPs that have not been seen before.
      if !bucket.contains_key(ip) {
        bucket.insert(
          ip.to_string(),
          RateLimitBucket {
            last_checked: SystemTime::now(),
            allowance: -2f64,
          },
        );
      }
    }
  }

  /// Checks whether `ip` may perform one more action of kind `type_`.
  ///
  /// The bucket for (`type_`, `ip`) is refilled at `rate / per` requests per
  /// second for the time elapsed since it was last checked. Fractional
  /// seconds are counted, so clients that call more often than once per
  /// second still accrue allowance.
  ///
  /// * `rate` - number of requests allowed per `per` seconds; also the
  ///   initial allowance of a fresh bucket.
  /// * `per` - length of the rate window in seconds.
  /// * `check_only` - when `true`, a permitted request is not deducted from
  ///   the allowance and the allowance is not capped at `rate`.
  ///
  /// # Errors
  ///
  /// Returns an error when less than one request is available in the bucket,
  /// or when the system clock reports a time earlier than the bucket's
  /// previous check.
  // The `-2.0` sentinel is only ever assigned, never computed, so comparing
  // it with `==` is exact.
  #[allow(clippy::float_cmp)]
  pub(super) fn check_rate_limit_full(
    &mut self,
    type_: RateLimitType,
    ip: &str,
    rate: i32,
    per: i32,
    check_only: bool,
  ) -> Result<(), LemmyError> {
    self.insert_ip(ip);

    // `insert_ip` has just created the bucket; should it be absent anyway,
    // the request is let through.
    let rate_limit = match self
      .buckets
      .get_mut(&type_)
      .and_then(|bucket| bucket.get_mut(ip))
    {
      Some(rate_limit) => rate_limit,
      None => return Ok(()),
    };

    let max_allowance = f64::from(rate);
    let current = SystemTime::now();
    // Keep the sub-second part: `last_checked` is reset on every call, so
    // truncating to whole seconds would discard all elapsed time for clients
    // calling more than once per second and their bucket would never refill.
    let time_passed = current
      .duration_since(rate_limit.last_checked)?
      .as_secs_f64();

    // The initial value
    if rate_limit.allowance == -2f64 {
      rate_limit.allowance = max_allowance;
    }

    rate_limit.last_checked = current;
    rate_limit.allowance += time_passed * (max_allowance / f64::from(per));
    if !check_only && rate_limit.allowance > max_allowance {
      rate_limit.allowance = max_allowance;
    }

    if rate_limit.allowance < 1.0 {
      debug!(
        "Rate limited type: {}, IP: {}, time_passed: {}, allowance: {}",
        type_.as_ref(),
        ip,
        time_passed,
        rate_limit.allowance
      );
      return Err(
        APIError {
          message: format!(
            "Too many requests. type: {}, IP: {}, {} per {} seconds",
            type_.as_ref(),
            ip,
            rate,
            per
          ),
        }
        .into(),
      );
    }

    if !check_only {
      rate_limit.allowance -= 1.0;
    }
    Ok(())
  }
}