diff options
Diffstat (limited to 'server/src/api/site.rs')
-rw-r--r-- | server/src/api/site.rs | 424 |
1 files changed, 227 insertions, 197 deletions
diff --git a/server/src/api/site.rs b/server/src/api/site.rs index faee30cb..f45561a8 100644 --- a/server/src/api/site.rs +++ b/server/src/api/site.rs @@ -2,6 +2,7 @@ use super::user::Register; use crate::{ api::{APIError, Oper, Perform}, apub::fetcher::search_by_apub_id, + blocking, db::{ category::*, comment_view::*, @@ -22,12 +23,9 @@ use crate::{ slur_check, slurs_vec_to_str, websocket::{server::SendAllMessage, UserOperation, WebsocketInfo}, + DbPool, + LemmyError, }; -use diesel::{ - r2d2::{ConnectionManager, Pool}, - PgConnection, -}; -use failure::Error; use log::{debug, info}; use serde::{Deserialize, Serialize}; use std::str::FromStr; @@ -139,87 +137,79 @@ pub struct SaveSiteConfig { auth: String, } +#[async_trait::async_trait(?Send)] impl Perform for Oper<ListCategories> { type Response = ListCategoriesResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<ListCategoriesResponse, Error> { + ) -> Result<ListCategoriesResponse, LemmyError> { let _data: &ListCategories = &self.data; - let conn = pool.get()?; - - let categories: Vec<Category> = Category::list_all(&conn)?; + let categories = blocking(pool, move |conn| Category::list_all(conn)).await??; // Return the jwt Ok(ListCategoriesResponse { categories }) } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<GetModlog> { type Response = GetModlogResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<GetModlogResponse, Error> { + ) -> Result<GetModlogResponse, LemmyError> { let data: &GetModlog = &self.data; - let conn = pool.get()?; - - let removed_posts = ModRemovePostView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; - let locked_posts = ModLockPostView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; - let stickied_posts = ModStickyPostView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; - let removed_comments = ModRemoveCommentView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; - let banned_from_community = ModBanFromCommunityView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; - let added_to_community = ModAddCommunityView::list( - &conn, - data.community_id, - data.mod_user_id, - data.page, - data.limit, - )?; + let community_id = data.community_id; + let mod_user_id = data.mod_user_id; + let page = data.page; + let limit = data.limit; + let removed_posts = blocking(pool, move |conn| { + ModRemovePostView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; + + let locked_posts = blocking(pool, move |conn| { + ModLockPostView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; + + let stickied_posts = blocking(pool, move |conn| { + ModStickyPostView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; + + let removed_comments = blocking(pool, move |conn| { + ModRemoveCommentView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; + + let banned_from_community = blocking(pool, move |conn| { + ModBanFromCommunityView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; + + let added_to_community = blocking(pool, move |conn| { + ModAddCommunityView::list(conn, community_id, mod_user_id, page, limit) + }) + .await??; // These arrays are only for the full modlog, when a community isn't given let (removed_communities, banned, added) = if data.community_id.is_none() { - ( - ModRemoveCommunityView::list(&conn, data.mod_user_id, data.page, data.limit)?, - ModBanView::list(&conn, data.mod_user_id, data.page, data.limit)?, - ModAddView::list(&conn, data.mod_user_id, data.page, data.limit)?, - ) + blocking(pool, move |conn| { + Ok(( + ModRemoveCommunityView::list(conn, mod_user_id, page, limit)?, + ModBanView::list(conn, mod_user_id, page, limit)?, + ModAddView::list(conn, mod_user_id, page, limit)?, + )) as Result<_, LemmyError> + }) + .await?? } else { (Vec::new(), Vec::new(), Vec::new()) }; @@ -239,14 +229,15 @@ impl Perform for Oper<GetModlog> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<CreateSite> { type Response = SiteResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<SiteResponse, Error> { + ) -> Result<SiteResponse, LemmyError> { let data: &CreateSite = &self.data; let claims = match Claims::decode(&data.auth) { @@ -266,10 +257,9 @@ impl Perform for Oper<CreateSite> { let user_id = claims.id; - let conn = pool.get()?; - // Make sure user is an admin - if !UserView::read(&conn, user_id)?.admin { + let user = blocking(pool, move |conn| UserView::read(conn, user_id)).await??; + if !user.admin { return Err(APIError::err("not_an_admin").into()); } @@ -283,24 +273,25 @@ impl Perform for Oper<CreateSite> { updated: None, }; - match Site::create(&conn, &site_form) { - Ok(site) => site, - Err(_e) => return Err(APIError::err("site_already_exists").into()), - }; + let create_site = move |conn: &'_ _| Site::create(conn, &site_form); + if blocking(pool, create_site).await?.is_err() { + return Err(APIError::err("site_already_exists").into()); + } - let site_view = SiteView::read(&conn)?; + let site_view = blocking(pool, move |conn| SiteView::read(conn)).await??; Ok(SiteResponse { site: site_view }) } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<EditSite> { type Response = SiteResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, websocket_info: Option<WebsocketInfo>, - ) -> Result<SiteResponse, Error> { + ) -> Result<SiteResponse, LemmyError> { let data: &EditSite = &self.data; let claims = match Claims::decode(&data.auth) { @@ -320,14 +311,13 @@ impl Perform for Oper<EditSite> { let user_id = claims.id; - let conn = pool.get()?; - // Make sure user is an admin - if !UserView::read(&conn, user_id)?.admin { + let user = blocking(pool, move |conn| UserView::read(conn, user_id)).await??; + if !user.admin { return Err(APIError::err("not_an_admin").into()); } - let found_site = Site::read(&conn, 1)?; + let found_site = blocking(pool, move |conn| Site::read(conn, 1)).await??; let site_form = SiteForm { name: data.name.to_owned(), @@ -339,12 +329,12 @@ impl Perform for Oper<EditSite> { enable_nsfw: data.enable_nsfw, }; - match Site::update(&conn, 1, &site_form) { - Ok(site) => site, - Err(_e) => return Err(APIError::err("couldnt_update_site").into()), - }; + let update_site = move |conn: &'_ _| Site::update(conn, 1, &site_form); + if blocking(pool, update_site).await?.is_err() { + return Err(APIError::err("couldnt_update_site").into()); + } - let site_view = SiteView::read(&conn)?; + let site_view = blocking(pool, move |conn| SiteView::read(conn)).await??; let res = SiteResponse { site: site_view }; @@ -360,21 +350,21 @@ impl Perform for Oper<EditSite> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<GetSite> { type Response = GetSiteResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, websocket_info: Option<WebsocketInfo>, - ) -> Result<GetSiteResponse, Error> { + ) -> Result<GetSiteResponse, LemmyError> { let _data: &GetSite = &self.data; - let conn = pool.get()?; - // TODO refactor this a little - let site_view = if let Ok(_site) = Site::read(&conn, 1) { - Some(SiteView::read(&conn)?) + let res = blocking(pool, move |conn| Site::read(conn, 1)).await?; + let site_view = if res.is_ok() { + Some(blocking(pool, move |conn| SiteView::read(conn)).await??) } else if let Some(setup) = Settings::get().setup.as_ref() { let register = Register { username: setup.admin_username.to_owned(), @@ -384,7 +374,9 @@ impl Perform for Oper<GetSite> { admin: true, show_nsfw: true, }; - let login_response = Oper::new(register).perform(pool.clone(), websocket_info.clone())?; + let login_response = Oper::new(register, self.client.clone()) + .perform(pool, websocket_info.clone()) + .await?; info!("Admin {} created", setup.admin_username); let create_site = CreateSite { @@ -395,14 +387,16 @@ impl Perform for Oper<GetSite> { enable_nsfw: true, auth: login_response.jwt, }; - Oper::new(create_site).perform(pool, websocket_info.clone())?; + Oper::new(create_site, self.client.clone()) + .perform(pool, websocket_info.clone()) + .await?; info!("Site {} created", setup.site_name); - Some(SiteView::read(&conn)?) + Some(blocking(pool, move |conn| SiteView::read(conn)).await??) } else { None }; - let mut admins = UserView::admins(&conn)?; + let mut admins = blocking(pool, move |conn| UserView::admins(conn)).await??; // Make sure the site creator is the top admin if let Some(site_view) = site_view.to_owned() { @@ -415,7 +409,7 @@ impl Perform for Oper<GetSite> { } } - let banned = UserView::banned(&conn)?; + let banned = blocking(pool, move |conn| UserView::banned(conn)).await??; let online = if let Some(_ws) = websocket_info { // TODO @@ -437,21 +431,20 @@ impl Perform for Oper<GetSite> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<Search> { type Response = SearchResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<SearchResponse, Error> { + ) -> Result<SearchResponse, LemmyError> { let data: &Search = &self.data; dbg!(&data); - let conn = pool.get()?; - - match search_by_apub_id(&data.q, &conn) { + match search_by_apub_id(&data.q, &self.client, pool).await { Ok(r) => return Ok(r), Err(e) => debug!("Failed to resolve search query as activitypub ID: {}", e), } @@ -467,7 +460,6 @@ impl Perform for Oper<Search> { None => None, }; - let sort = SortType::from_str(&data.sort)?; let type_ = SearchType::from_str(&data.type_)?; let mut posts = Vec::new(); @@ -477,85 +469,126 @@ impl Perform for Oper<Search> { // TODO no clean / non-nsfw searching rn + let q = data.q.to_owned(); + let page = data.page; + let limit = data.limit; + let sort = SortType::from_str(&data.sort)?; + let community_id = data.community_id; match type_ { SearchType::Posts => { - posts = PostQueryBuilder::create(&conn) - .sort(&sort) - .show_nsfw(true) - .for_community_id(data.community_id) - .search_term(data.q.to_owned()) - .my_user_id(user_id) - .page(data.page) - .limit(data.limit) - .list()?; + posts = blocking(pool, move |conn| { + PostQueryBuilder::create(conn) + .sort(&sort) + .show_nsfw(true) + .for_community_id(community_id) + .search_term(q) + .my_user_id(user_id) + .page(page) + .limit(limit) + .list() + }) + .await??; } SearchType::Comments => { - comments = CommentQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .my_user_id(user_id) - .page(data.page) - .limit(data.limit) - .list()?; + comments = blocking(pool, move |conn| { + CommentQueryBuilder::create(&conn) + .sort(&sort) + .search_term(q) + .my_user_id(user_id) + .page(page) + .limit(limit) + .list() + }) + .await??; } SearchType::Communities => { - communities = CommunityQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .page(data.page) - .limit(data.limit) - .list()?; + communities = blocking(pool, move |conn| { + CommunityQueryBuilder::create(conn) + .sort(&sort) + .search_term(q) + .page(page) + .limit(limit) + .list() + }) + .await??; } SearchType::Users => { - users = UserQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .page(data.page) - .limit(data.limit) - .list()?; + users = blocking(pool, move |conn| { + UserQueryBuilder::create(conn) + .sort(&sort) + .search_term(q) + .page(page) + .limit(limit) + .list() + }) + .await??; } SearchType::All => { - posts = PostQueryBuilder::create(&conn) - .sort(&sort) - .show_nsfw(true) - .for_community_id(data.community_id) - .search_term(data.q.to_owned()) - .my_user_id(user_id) - .page(data.page) - .limit(data.limit) - .list()?; - - comments = CommentQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .my_user_id(user_id) - .page(data.page) - .limit(data.limit) - .list()?; - - communities = CommunityQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .page(data.page) - .limit(data.limit) - .list()?; - - users = UserQueryBuilder::create(&conn) - .sort(&sort) - .search_term(data.q.to_owned()) - .page(data.page) - .limit(data.limit) - .list()?; + posts = blocking(pool, move |conn| { + PostQueryBuilder::create(conn) + .sort(&sort) + .show_nsfw(true) + .for_community_id(community_id) + .search_term(q) + .my_user_id(user_id) + .page(page) + .limit(limit) + .list() + }) + .await??; + + let q = data.q.to_owned(); + let sort = SortType::from_str(&data.sort)?; + + comments = blocking(pool, move |conn| { + CommentQueryBuilder::create(conn) + .sort(&sort) + .search_term(q) + .my_user_id(user_id) + .page(page) + .limit(limit) + .list() + }) + .await??; + + let q = data.q.to_owned(); + let sort = SortType::from_str(&data.sort)?; + + communities = blocking(pool, move |conn| { + CommunityQueryBuilder::create(conn) + .sort(&sort) + .search_term(q) + .page(page) + .limit(limit) + .list() + }) + .await??; + + let q = data.q.to_owned(); + let sort = SortType::from_str(&data.sort)?; + + users = blocking(pool, move |conn| { + UserQueryBuilder::create(conn) + .sort(&sort) + .search_term(q) + .page(page) + .limit(limit) + .list() + }) + .await??; } SearchType::Url => { - posts = PostQueryBuilder::create(&conn) - .sort(&sort) - .show_nsfw(true) - .for_community_id(data.community_id) - .url_search(data.q.to_owned()) - .page(data.page) - .limit(data.limit) - .list()?; + posts = blocking(pool, move |conn| { + PostQueryBuilder::create(conn) + .sort(&sort) + .show_nsfw(true) + .for_community_id(community_id) + .url_search(q) + .page(page) + .limit(limit) + .list() + }) + .await??; } }; @@ -570,14 +603,15 @@ impl Perform for Oper<Search> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<TransferSite> { type Response = GetSiteResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<GetSiteResponse, Error> { + ) -> Result<GetSiteResponse, LemmyError> { let data: &TransferSite = &self.data; let claims = match Claims::decode(&data.auth) { @@ -587,9 +621,7 @@ impl Perform for Oper<TransferSite> { let user_id = claims.id; - let conn = pool.get()?; - - let read_site = Site::read(&conn, 1)?; + let read_site = blocking(pool, move |conn| Site::read(conn, 1)).await??; // Make sure user is the creator if read_site.creator_id != user_id { @@ -606,9 +638,9 @@ impl Perform for Oper<TransferSite> { enable_nsfw: read_site.enable_nsfw, }; - match Site::update(&conn, 1, &site_form) { - Ok(site) => site, - Err(_e) => return Err(APIError::err("couldnt_update_site").into()), + let update_site = move |conn: &'_ _| Site::update(conn, 1, &site_form); + if blocking(pool, update_site).await?.is_err() { + return Err(APIError::err("couldnt_update_site").into()); }; // Mod tables @@ -618,11 +650,11 @@ impl Perform for Oper<TransferSite> { removed: Some(false), }; - ModAdd::create(&conn, &form)?; + blocking(pool, move |conn| ModAdd::create(conn, &form)).await??; - let site_view = SiteView::read(&conn)?; + let site_view = blocking(pool, move |conn| SiteView::read(conn)).await??; - let mut admins = UserView::admins(&conn)?; + let mut admins = blocking(pool, move |conn| UserView::admins(conn)).await??; let creator_index = admins .iter() .position(|r| r.id == site_view.creator_id) @@ -630,7 +662,7 @@ impl Perform for Oper<TransferSite> { let creator_user = admins.remove(creator_index); admins.insert(0, creator_user); - let banned = UserView::banned(&conn)?; + let banned = blocking(pool, move |conn| UserView::banned(conn)).await??; Ok(GetSiteResponse { site: Some(site_view), @@ -641,14 +673,15 @@ impl Perform for Oper<TransferSite> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<GetSiteConfig> { type Response = GetSiteConfigResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<GetSiteConfigResponse, Error> { + ) -> Result<GetSiteConfigResponse, LemmyError> { let data: &GetSiteConfig = &self.data; let claims = match Claims::decode(&data.auth) { @@ -658,10 +691,8 @@ impl Perform for Oper<GetSiteConfig> { let user_id = claims.id; - let conn = pool.get()?; - // Only let admins read this - let admins = UserView::admins(&conn)?; + let admins = blocking(pool, move |conn| UserView::admins(conn)).await??; let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect(); if !admin_ids.contains(&user_id) { @@ -674,14 +705,15 @@ impl Perform for Oper<GetSiteConfig> { } } +#[async_trait::async_trait(?Send)] impl Perform for Oper<SaveSiteConfig> { type Response = GetSiteConfigResponse; - fn perform( + async fn perform( &self, - pool: Pool<ConnectionManager<PgConnection>>, + pool: &DbPool, _websocket_info: Option<WebsocketInfo>, - ) -> Result<GetSiteConfigResponse, Error> { + ) -> Result<GetSiteConfigResponse, LemmyError> { let data: &SaveSiteConfig = &self.data; let claims = match Claims::decode(&data.auth) { @@ -691,10 +723,8 @@ impl Perform for Oper<SaveSiteConfig> { let user_id = claims.id; - let conn = pool.get()?; - // Only let admins read this - let admins = UserView::admins(&conn)?; + let admins = blocking(pool, move |conn| UserView::admins(conn)).await??; let admin_ids: Vec<i32> = admins.into_iter().map(|m| m.id).collect(); if !admin_ids.contains(&user_id) { |