diff options
author | Dessalines <tyhou13@gmx.com> | 2020-04-19 18:08:25 -0400 |
---|---|---|
committer | Dessalines <tyhou13@gmx.com> | 2020-04-19 18:08:25 -0400 |
commit | f300c67a4d9674eef05d180a787cc8352092903d (patch) | |
tree | 49d076d128d065403f5690f92900bdd0679f2d66 /server/src/api/user.rs | |
parent | be6a7876b49e8f963506f0b05e12495f119afc10 (diff) |
Adding websocket notification system.
- HTTP and APUB clients can now send live updating messages to websocket
clients
- Rate limiting now affects both HTTP and websockets
- Rate limiting / Websocket logic is now moved into the API Perform
functions.
- TODO This broke getting current online users, but that will have to
wait for the perform trait to be made async.
- Fixes #446
Diffstat (limited to 'server/src/api/user.rs')
-rw-r--r-- | server/src/api/user.rs | 325 |
1 files changed, 297 insertions, 28 deletions
diff --git a/server/src/api/user.rs b/server/src/api/user.rs index 40e09969..31a0a4e7 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -1,10 +1,5 @@ use super::*; -use crate::settings::Settings; -use crate::{generate_random_string, send_email}; use bcrypt::verify; -use diesel::PgConnection; -use log::error; -use std::str::FromStr; #[derive(Serialize, Deserialize, Debug)] pub struct Login { @@ -89,7 +84,7 @@ pub struct AddAdmin { auth: String, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct AddAdminResponse { admins: Vec<UserView>, } @@ -103,7 +98,7 @@ pub struct BanUser { auth: String, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct BanUserResponse { user: UserView, banned: bool, @@ -205,9 +200,23 @@ pub struct UserJoinResponse { } impl Perform<LoginResponse> for Oper<Login> { - fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<LoginResponse, Error> { let data: &Login = &self.data; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Fetch that username / email let user: User_ = match User_::find_by_email_or_username(&conn, &data.username_or_email) { Ok(user) => user, @@ -226,9 +235,23 @@ impl Perform<LoginResponse> for Oper<Login> { } impl Perform<LoginResponse> for Oper<Register> { - fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<LoginResponse, Error> { let data: &Register = &self.data; + if let Some(rl) = &rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_register(&rl.ip, true)?; + } + + let conn = pool.get()?; + // Make sure site has open registration if let Ok(site) = SiteView::read(&conn) { if !site.open_registration { @@ -332,6 +355,13 @@ impl Perform<LoginResponse> for Oper<Register> { }; } + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_register(&rl.ip, false)?; + } + // Return the jwt Ok(LoginResponse { jwt: inserted_user.jwt(), @@ -340,7 +370,12 @@ impl Perform<LoginResponse> for Oper<Register> { } impl Perform<LoginResponse> for Oper<SaveUserSettings> { - fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<LoginResponse, Error> { let data: &SaveUserSettings = &self.data; let claims = match Claims::decode(&data.auth) { @@ -350,6 +385,15 @@ impl Perform<LoginResponse> for Oper<SaveUserSettings> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let read_user = User_::read(&conn, user_id)?; let email = match &data.email { @@ -428,9 +472,23 @@ impl Perform<LoginResponse> for Oper<SaveUserSettings> { } impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> { - fn perform(&self, conn: &PgConnection) -> Result<GetUserDetailsResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<GetUserDetailsResponse, Error> { let data: &GetUserDetails = &self.data; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let user_claims: Option<Claims> = match &data.auth { Some(auth) => match Claims::decode(&auth) { Ok(claims) => Some(claims.claims), @@ -525,7 +583,12 @@ impl Perform<GetUserDetailsResponse> for Oper<GetUserDetails> { } impl Perform<AddAdminResponse> for Oper<AddAdmin> { - fn perform(&self, conn: &PgConnection) -> Result<AddAdminResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<AddAdminResponse, Error> { let data: &AddAdmin = &self.data; let claims = match Claims::decode(&data.auth) { @@ -535,6 +598,15 @@ impl Perform<AddAdminResponse> for Oper<AddAdmin> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Make sure user is an admin if !UserView::read(&conn, user_id)?.admin { return Err(APIError::err("not_an_admin").into()); @@ -583,12 +655,27 @@ impl Perform<AddAdminResponse> for Oper<AddAdmin> { let creator_user = admins.remove(creator_index); admins.insert(0, creator_user); - Ok(AddAdminResponse { admins }) + let res = AddAdminResponse { admins }; + + if let Some(ws) = websocket_info { + ws.chatserver.do_send(SendAllMessage { + op: UserOperation::AddAdmin, + response: res.clone(), + my_id: ws.id, + }); + } + + Ok(res) } } impl Perform<BanUserResponse> for Oper<BanUser> { - fn perform(&self, conn: &PgConnection) -> Result<BanUserResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<BanUserResponse, Error> { let data: &BanUser = &self.data; let claims = match Claims::decode(&data.auth) { @@ -598,6 +685,15 @@ impl Perform<BanUserResponse> for Oper<BanUser> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Make sure user is an admin if !UserView::read(&conn, user_id)?.admin { return Err(APIError::err("not_an_admin").into()); @@ -649,15 +745,30 @@ impl Perform<BanUserResponse> for Oper<BanUser> { let user_view = UserView::read(&conn, data.user_id)?; - Ok(BanUserResponse { + let res = BanUserResponse { user: user_view, banned: data.ban, - }) + }; + + if let Some(ws) = websocket_info { + ws.chatserver.do_send(SendAllMessage { + op: UserOperation::BanUser, + response: res.clone(), + my_id: ws.id, + }); + } + + Ok(res) } } impl Perform<GetRepliesResponse> for Oper<GetReplies> { - fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<GetRepliesResponse, Error> { let data: &GetReplies = &self.data; let claims = match Claims::decode(&data.auth) { @@ -669,6 +780,15 @@ impl Perform<GetRepliesResponse> for Oper<GetReplies> { let sort = SortType::from_str(&data.sort)?; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let replies = ReplyQueryBuilder::create(&conn, user_id) .sort(&sort) .unread_only(data.unread_only) @@ -681,7 +801,12 @@ impl Perform<GetRepliesResponse> for Oper<GetReplies> { } impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> { - fn perform(&self, conn: &PgConnection) -> Result<GetUserMentionsResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<GetUserMentionsResponse, Error> { let data: &GetUserMentions = &self.data; let claims = match Claims::decode(&data.auth) { @@ -693,6 +818,15 @@ impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> { let sort = SortType::from_str(&data.sort)?; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let mentions = UserMentionQueryBuilder::create(&conn, user_id) .sort(&sort) .unread_only(data.unread_only) @@ -705,7 +839,12 @@ impl Perform<GetUserMentionsResponse> for Oper<GetUserMentions> { } impl Perform<UserMentionResponse> for Oper<EditUserMention> { - fn perform(&self, conn: &PgConnection) -> Result<UserMentionResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<UserMentionResponse, Error> { let data: &EditUserMention = &self.data; let claims = match Claims::decode(&data.auth) { @@ -715,6 +854,15 @@ impl Perform<UserMentionResponse> for Oper<EditUserMention> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let user_mention = UserMention::read(&conn, data.user_mention_id)?; let user_mention_form = UserMentionForm { @@ -738,7 +886,12 @@ impl Perform<UserMentionResponse> for Oper<EditUserMention> { } impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> { - fn perform(&self, conn: &PgConnection) -> Result<GetRepliesResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<GetRepliesResponse, Error> { let data: &MarkAllAsRead = &self.data; let claims = match Claims::decode(&data.auth) { @@ -748,6 +901,15 @@ impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let replies = ReplyQueryBuilder::create(&conn, user_id) .unread_only(true) .page(1) @@ -822,7 +984,12 @@ impl Perform<GetRepliesResponse> for Oper<MarkAllAsRead> { } impl Perform<LoginResponse> for Oper<DeleteAccount> { - fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<LoginResponse, Error> { let data: &DeleteAccount = &self.data; let claims = match Claims::decode(&data.auth) { @@ -832,6 +999,15 @@ impl Perform<LoginResponse> for Oper<DeleteAccount> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let user: User_ = User_::read(&conn, user_id)?; // Verify the password @@ -903,9 +1079,23 @@ impl Perform<LoginResponse> for Oper<DeleteAccount> { } impl Perform<PasswordResetResponse> for Oper<PasswordReset> { - fn perform(&self, conn: &PgConnection) -> Result<PasswordResetResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<PasswordResetResponse, Error> { let data: &PasswordReset = &self.data; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Fetch that email let user: User_ = match User_::find_by_email(&conn, &data.email) { Ok(user) => user, @@ -934,9 +1124,23 @@ impl Perform<PasswordResetResponse> for Oper<PasswordReset> { } impl Perform<LoginResponse> for Oper<PasswordChange> { - fn perform(&self, conn: &PgConnection) -> Result<LoginResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<LoginResponse, Error> { let data: &PasswordChange = &self.data; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Fetch the user_id from the token let user_id = PasswordResetRequest::read_from_token(&conn, &data.token)?.user_id; @@ -959,7 +1163,12 @@ impl Perform<LoginResponse> for Oper<PasswordChange> { } impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> { - fn perform(&self, conn: &PgConnection) -> Result<PrivateMessageResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<PrivateMessageResponse, Error> { let data: &CreatePrivateMessage = &self.data; let claims = match Claims::decode(&data.auth) { @@ -971,6 +1180,15 @@ impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> { let hostname = &format!("https://{}", Settings::get().hostname); + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + // Check for a site ban if UserView::read(&conn, user_id)?.banned { return Err(APIError::err("site_ban").into()); @@ -1016,12 +1234,28 @@ impl Perform<PrivateMessageResponse> for Oper<CreatePrivateMessage> { let message = PrivateMessageView::read(&conn, inserted_private_message.id)?; - Ok(PrivateMessageResponse { message }) + let res = PrivateMessageResponse { message }; + + if let Some(ws) = websocket_info { + ws.chatserver.do_send(SendUserRoomMessage { + op: UserOperation::CreatePrivateMessage, + response: res.clone(), + recipient_id: recipient_user.id, + my_id: ws.id, + }); + } + + Ok(res) } } impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> { - fn perform(&self, conn: &PgConnection) -> Result<PrivateMessageResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<PrivateMessageResponse, Error> { let data: &EditPrivateMessage = &self.data; let claims = match Claims::decode(&data.auth) { @@ -1031,6 +1265,15 @@ impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let orig_private_message = PrivateMessage::read(&conn, data.edit_id)?; // Check for a site ban @@ -1076,7 +1319,12 @@ impl Perform<PrivateMessageResponse> for Oper<EditPrivateMessage> { } impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> { - fn perform(&self, conn: &PgConnection) -> Result<PrivateMessagesResponse, Error> { + fn perform( + &self, + pool: Pool<ConnectionManager<PgConnection>>, + _websocket_info: Option<WebsocketInfo>, + rate_limit_info: Option<RateLimitInfo>, + ) -> Result<PrivateMessagesResponse, Error> { let data: &GetPrivateMessages = &self.data; let claims = match Claims::decode(&data.auth) { @@ -1086,6 +1334,15 @@ impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> { let user_id = claims.id; + if let Some(rl) = rate_limit_info { + rl.rate_limiter + .lock() + .unwrap() + .check_rate_limit_message(&rl.ip, false)?; + } + + let conn = pool.get()?; + let messages = PrivateMessageQueryBuilder::create(&conn, user_id) .page(data.page) .limit(data.limit) @@ -1097,7 +1354,12 @@ impl Perform<PrivateMessagesResponse> for Oper<GetPrivateMessages> { } impl Perform<UserJoinResponse> for Oper<UserJoin> { - fn perform(&self, _conn: &PgConnection) -> Result<UserJoinResponse, Error> { + fn perform( + &self, + _pool: Pool<ConnectionManager<PgConnection>>, + websocket_info: Option<WebsocketInfo>, + _rate_limit_info: Option<RateLimitInfo>, + ) -> Result<UserJoinResponse, Error> { let data: &UserJoin = &self.data; let claims = match Claims::decode(&data.auth) { @@ -1106,6 +1368,13 @@ impl Perform<UserJoinResponse> for Oper<UserJoin> { }; let user_id = claims.id; + + if let Some(ws) = websocket_info { + if let Some(id) = ws.id { + ws.chatserver.do_send(JoinUserRoom { user_id, id }); + } + } + Ok(UserJoinResponse { user_id }) } } |