summaryrefslogtreecommitdiffstats
path: root/server/src/websocket/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'server/src/websocket/server.rs')
-rw-r--r--server/src/websocket/server.rs143
1 files changed, 95 insertions, 48 deletions
diff --git a/server/src/websocket/server.rs b/server/src/websocket/server.rs
index a26c8144..fc838c1f 100644
--- a/server/src/websocket/server.rs
+++ b/server/src/websocket/server.rs
@@ -12,6 +12,7 @@ use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::time::SystemTime;
+use strum::IntoEnumIterator;
use crate::api::comment::*;
use crate::api::community::*;
@@ -71,6 +72,13 @@ pub struct SessionInfo {
pub ip: IPAddr,
}
+#[derive(Eq, PartialEq, Hash, Debug, EnumIter, Copy, Clone)]
+pub enum RateLimitType {
+ Message,
+ Register,
+ Post,
+}
+
/// `ChatServer` manages chat rooms and responsible for coordinating chat
/// session.
pub struct ChatServer {
@@ -87,8 +95,8 @@ pub struct ChatServer {
/// sessions (IE clients)
user_rooms: HashMap<UserId, HashSet<ConnectionId>>,
- /// Rate limiting based on IP addr
- rate_limits: HashMap<IPAddr, RateLimitBucket>,
+ /// Rate limiting based on rate type and IP addr
+ rate_limit_buckets: HashMap<RateLimitType, HashMap<IPAddr, RateLimitBucket>>,
rng: ThreadRng,
db: Pool<ConnectionManager<PgConnection>>,
@@ -98,7 +106,7 @@ impl ChatServer {
pub fn startup(db: Pool<ConnectionManager<PgConnection>>) -> ChatServer {
ChatServer {
sessions: HashMap::new(),
- rate_limits: HashMap::new(),
+ rate_limit_buckets: HashMap::new(),
post_rooms: HashMap::new(),
community_rooms: HashMap::new(),
user_rooms: HashMap::new(),
@@ -259,60 +267,82 @@ impl ChatServer {
to_json_string(&user_operation, post)
}
- fn check_rate_limit_register(&mut self, id: usize) -> Result<(), Error> {
+ fn check_rate_limit_register(&mut self, id: usize, check_only: bool) -> Result<(), Error> {
self.check_rate_limit_full(
+ RateLimitType::Register,
id,
Settings::get().rate_limit.register,
Settings::get().rate_limit.register_per_second,
+ check_only,
)
}
- fn check_rate_limit_post(&mut self, id: usize) -> Result<(), Error> {
+ fn check_rate_limit_post(&mut self, id: usize, check_only: bool) -> Result<(), Error> {
self.check_rate_limit_full(
+ RateLimitType::Post,
id,
Settings::get().rate_limit.post,
Settings::get().rate_limit.post_per_second,
+ check_only,
)
}
- fn check_rate_limit_message(&mut self, id: usize) -> Result<(), Error> {
+ fn check_rate_limit_message(&mut self, id: usize, check_only: bool) -> Result<(), Error> {
self.check_rate_limit_full(
+ RateLimitType::Message,
id,
Settings::get().rate_limit.message,
Settings::get().rate_limit.message_per_second,
+ check_only,
)
}
#[allow(clippy::float_cmp)]
- fn check_rate_limit_full(&mut self, id: usize, rate: i32, per: i32) -> Result<(), Error> {
+ fn check_rate_limit_full(
+ &mut self,
+ type_: RateLimitType,
+ id: usize,
+ rate: i32,
+ per: i32,
+ check_only: bool,
+ ) -> Result<(), Error> {
if let Some(info) = self.sessions.get(&id) {
- if let Some(rate_limit) = self.rate_limits.get_mut(&info.ip) {
- // The initial value
- if rate_limit.allowance == -2f64 {
- rate_limit.allowance = rate as f64;
- };
-
- let current = SystemTime::now();
- let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
- rate_limit.last_checked = current;
- rate_limit.allowance += time_passed * (rate as f64 / per as f64);
- if rate_limit.allowance > rate as f64 {
- rate_limit.allowance = rate as f64;
- }
+ if let Some(bucket) = self.rate_limit_buckets.get_mut(&type_) {
+ if let Some(rate_limit) = bucket.get_mut(&info.ip) {
+ let current = SystemTime::now();
+ let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
+
+ // The initial value
+ if rate_limit.allowance == -2f64 {
+ rate_limit.allowance = rate as f64;
+ };
+
+ rate_limit.last_checked = current;
+ if !check_only {
+ rate_limit.allowance += time_passed * (rate as f64 / per as f64);
+ if rate_limit.allowance > rate as f64 {
+ rate_limit.allowance = rate as f64;
+ }
+ }
- if rate_limit.allowance < 1.0 {
- println!(
- "Rate limited IP: {}, time_passed: {}, allowance: {}",
- &info.ip, time_passed, rate_limit.allowance
- );
- Err(
- APIError {
- message: format!("Too many requests. {} per {} seconds", rate, per),
+ if rate_limit.allowance < 1.0 {
+ println!(
+ "Rate limited IP: {}, time_passed: {}, allowance: {}",
+ &info.ip, time_passed, rate_limit.allowance
+ );
+ Err(
+ APIError {
+ message: format!("Too many requests. {} per {} seconds", rate, per),
+ }
+ .into(),
+ )
+ } else {
+ if !check_only {
+ rate_limit.allowance -= 1.0;
}
- .into(),
- )
+ Ok(())
+ }
} else {
- rate_limit.allowance -= 1.0;
Ok(())
}
} else {
@@ -350,14 +380,24 @@ impl Handler<Connect> for ChatServer {
},
);
- if self.rate_limits.get(&msg.ip).is_none() {
- self.rate_limits.insert(
- msg.ip,
- RateLimitBucket {
- last_checked: SystemTime::now(),
- allowance: -2f64,
- },
- );
+ for rate_limit_type in RateLimitType::iter() {
+ if self.rate_limit_buckets.get(&rate_limit_type).is_none() {
+ self
+ .rate_limit_buckets
+ .insert(rate_limit_type, HashMap::new());
+ }
+
+ if let Some(bucket) = self.rate_limit_buckets.get_mut(&rate_limit_type) {
+ if bucket.get(&msg.ip).is_none() {
+ bucket.insert(
+ msg.ip.to_owned(),
+ RateLimitBucket {
+ last_checked: SystemTime::now(),
+ allowance: -2f64,
+ },
+ );
+ }
+ }
}
id
@@ -446,11 +486,18 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
// TODO: none of the chat messages are going to work if stuff is submitted via http api,
// need to move that handling elsewhere
+
+ // A DDOS check
+ chat.check_rate_limit_message(msg.id, false)?;
+
match user_operation {
UserOperation::Login => do_user_operation::<Login, LoginResponse>(user_operation, data, &conn),
UserOperation::Register => {
- chat.check_rate_limit_register(msg.id)?;
- do_user_operation::<Register, LoginResponse>(user_operation, data, &conn)
+ chat.check_rate_limit_register(msg.id, true)?;
+ let register: Register = serde_json::from_str(data)?;
+ let res = Oper::new(register).perform(&conn)?;
+ chat.check_rate_limit_register(msg.id, false)?;
+ to_json_string(&user_operation, &res)
}
UserOperation::GetUserDetails => {
do_user_operation::<GetUserDetails, GetUserDetailsResponse>(user_operation, data, &conn)
@@ -503,8 +550,11 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
do_user_operation::<ListCommunities, ListCommunitiesResponse>(user_operation, data, &conn)
}
UserOperation::CreateCommunity => {
- chat.check_rate_limit_register(msg.id)?;
- do_user_operation::<CreateCommunity, CommunityResponse>(user_operation, data, &conn)
+ chat.check_rate_limit_register(msg.id, true)?;
+ let create_community: CreateCommunity = serde_json::from_str(data)?;
+ let res = Oper::new(create_community).perform(&conn)?;
+ chat.check_rate_limit_register(msg.id, false)?;
+ to_json_string(&user_operation, &res)
}
UserOperation::EditCommunity => {
let edit_community: EditCommunity = serde_json::from_str(data)?;
@@ -566,14 +616,14 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
to_json_string(&user_operation, &res)
}
UserOperation::CreatePost => {
- chat.check_rate_limit_post(msg.id)?;
+ chat.check_rate_limit_post(msg.id, true)?;
let create_post: CreatePost = serde_json::from_str(data)?;
let res = Oper::new(create_post).perform(&conn)?;
+ chat.check_rate_limit_post(msg.id, false)?;
chat.post_sends(UserOperation::CreatePost, res, msg.id)
}
UserOperation::CreatePostLike => {
- chat.check_rate_limit_message(msg.id)?;
let create_post_like: CreatePostLike = serde_json::from_str(data)?;
let res = Oper::new(create_post_like).perform(&conn)?;
@@ -589,7 +639,6 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
do_user_operation::<SavePost, PostResponse>(user_operation, data, &conn)
}
UserOperation::CreateComment => {
- chat.check_rate_limit_message(msg.id)?;
let create_comment: CreateComment = serde_json::from_str(data)?;
let res = Oper::new(create_comment).perform(&conn)?;
@@ -605,7 +654,6 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
do_user_operation::<SaveComment, CommentResponse>(user_operation, data, &conn)
}
UserOperation::CreateCommentLike => {
- chat.check_rate_limit_message(msg.id)?;
let create_comment_like: CreateCommentLike = serde_json::from_str(data)?;
let res = Oper::new(create_comment_like).perform(&conn)?;
@@ -649,7 +697,6 @@ fn parse_json_message(chat: &mut ChatServer, msg: StandardMessage) -> Result<Str
do_user_operation::<PasswordChange, LoginResponse>(user_operation, data, &conn)
}
UserOperation::CreatePrivateMessage => {
- chat.check_rate_limit_message(msg.id)?;
let create_private_message: CreatePrivateMessage = serde_json::from_str(data)?;
let recipient_id = create_private_message.recipient_id;
let res = Oper::new(create_private_message).perform(&conn)?;