diff options
Diffstat (limited to 'server/src/rate_limit/rate_limiter.rs')
-rw-r--r-- | server/src/rate_limit/rate_limiter.rs | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/server/src/rate_limit/rate_limiter.rs b/server/src/rate_limit/rate_limiter.rs new file mode 100644 index 00000000..6b01a75b --- /dev/null +++ b/server/src/rate_limit/rate_limiter.rs @@ -0,0 +1,131 @@ +use super::*; + +#[derive(Debug, Clone)] +pub struct RateLimitBucket { + last_checked: SystemTime, + allowance: f64, +} + +#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone)] +pub enum RateLimitType { + Message, + Register, + Post, +} + +/// Rate limiting based on rate type and IP addr +#[derive(Debug, Clone)] +pub struct RateLimiter { + pub buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>, +} + +impl Default for RateLimiter { + fn default() -> Self { + Self { + buckets: HashMap::new(), + } + } +} + +impl RateLimiter { + fn insert_ip(&mut self, ip: &str) { + for rate_limit_type in RateLimitType::iter() { + if self.buckets.get(&rate_limit_type).is_none() { + self.buckets.insert(rate_limit_type, HashMap::new()); + } + + if let Some(bucket) = self.buckets.get_mut(&rate_limit_type) { + if bucket.get(ip).is_none() { + bucket.insert( + ip.to_string(), + RateLimitBucket { + last_checked: SystemTime::now(), + allowance: -2f64, + }, + ); + } + } + } + } + + pub fn check_rate_limit_register(&mut self, ip: &str, check_only: bool) -> Result<(), Error> { + self.check_rate_limit_full( + RateLimitType::Register, + ip, + Settings::get().rate_limit.register, + Settings::get().rate_limit.register_per_second, + check_only, + ) + } + + pub fn check_rate_limit_post(&mut self, ip: &str, check_only: bool) -> Result<(), Error> { + self.check_rate_limit_full( + RateLimitType::Post, + ip, + Settings::get().rate_limit.post, + Settings::get().rate_limit.post_per_second, + check_only, + ) + } + + pub fn check_rate_limit_message(&mut self, ip: &str, check_only: bool) -> Result<(), Error> { + self.check_rate_limit_full( + RateLimitType::Message, + ip, + Settings::get().rate_limit.message, + Settings::get().rate_limit.message_per_second, + check_only, + ) + } + + #[allow(clippy::float_cmp)] + fn check_rate_limit_full( + &mut self, + type_: RateLimitType, + ip: &str, + rate: i32, + per: i32, + check_only: bool, + ) -> Result<(), Error> { + self.insert_ip(ip); + if let Some(bucket) = self.buckets.get_mut(&type_) { + if let Some(rate_limit) = bucket.get_mut(ip) { + let current = SystemTime::now(); + let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64; + + // The initial value + if rate_limit.allowance == -2f64 { + rate_limit.allowance = rate as f64; + }; + + rate_limit.last_checked = current; + rate_limit.allowance += time_passed * (rate as f64 / per as f64); + if !check_only && rate_limit.allowance > rate as f64 { + rate_limit.allowance = rate as f64; + } + + if rate_limit.allowance < 1.0 { + warn!( + "Rate limited IP: {}, time_passed: {}, allowance: {}", + ip, time_passed, rate_limit.allowance + ); + Err( + APIError { + message: format!("Too many requests. {} per {} seconds", rate, per), + } + .into(), + ) + } else { + if !check_only { + rate_limit.allowance -= 1.0; + } + Ok(()) + } + } else { + Ok(()) + } + } else { + Ok(()) + } + } +} |