diff options
author | Erwin Kroon <123574+ekroon@users.noreply.github.com> | 2023-02-15 09:54:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-15 08:54:09 +0000 |
commit | dcfad9a90d4b18465c437811698c879c4c5a9fcc (patch) | |
tree | 327990925add8c97cc02c61d5607628921e81b96 /atuin-server/src/handlers | |
parent | 7e7dd63966cd8eae7ec61f7e419e0b72a45ff580 (diff) |
Add support for generic database in AppState (#711)
Diffstat (limited to 'atuin-server/src/handlers')
-rw-r--r-- | atuin-server/src/handlers/history.rs | 24 | ||||
-rw-r--r-- | atuin-server/src/handlers/user.rs | 18 |
2 files changed, 21 insertions, 21 deletions
diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 9ee13e16..7cf18323 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -18,11 +18,11 @@ use crate::{ use atuin_common::api::*; #[instrument(skip_all, fields(user.id = user.id))] -pub async fn count( +pub async fn count<DB: Database>( user: User, - state: State<AppState>, + state: State<AppState<DB>>, ) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; match db.count_history_cached(&user).await { // By default read out the cached value Ok(count) => Ok(Json(CountResponse { count })), @@ -38,12 +38,12 @@ pub async fn count( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn list( +pub async fn list<DB: Database>( req: Query<SyncHistoryRequest>, user: User, - state: State<AppState>, + state: State<AppState<DB>>, ) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let history = db .list_history( &user, @@ -75,9 +75,9 @@ pub async fn list( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn add( +pub async fn add<DB: Database>( user: User, - state: State<AppState>, + state: State<AppState<DB>>, Json(req): Json<Vec<AddHistoryRequest>>, ) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); @@ -93,7 +93,7 @@ pub async fn add( }) .collect(); - let db = &state.0.postgres; + let db = &state.0.database; if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); @@ -105,18 +105,18 @@ pub async fn add( } #[instrument(skip_all, fields(user.id = user.id))] -pub async fn calendar( +pub async fn calendar<DB: Database>( Path(focus): Path<String>, Query(params): Query<HashMap<String, u64>>, user: User, - state: State<AppState>, + state: State<AppState<DB>>, ) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { let focus = focus.as_str(); let year = params.get("year").unwrap_or(&0); let month = params.get("month").unwrap_or(&1); - let db = &state.0.postgres; + let db = &state.0.database; let focus = match focus { "year" => db .calendar(&user, TimePeriod::YEAR, *year, *month) diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 761724c5..677e7c65 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { } #[instrument(skip_all, fields(user.username = username.as_str()))] -pub async fn get( +pub async fn get<DB: Database>( Path(username): Path<String>, - state: State<AppState>, + state: State<AppState<DB>>, ) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(sqlx::Error::RowNotFound) => { @@ -58,9 +58,9 @@ pub async fn get( } #[instrument(skip_all)] -pub async fn register( +pub async fn register<DB: Database>( settings: Extension<Settings>, - state: State<AppState>, + state: State<AppState<DB>>, Json(register): Json<RegisterRequest>, ) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { if !settings.open_registration { @@ -78,7 +78,7 @@ pub async fn register( password: hashed, }; - let db = &state.0.postgres; + let db = &state.0.database; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { @@ -107,11 +107,11 @@ pub async fn register( } #[instrument(skip_all, fields(user.username = login.username.as_str()))] -pub async fn login( - state: State<AppState>, +pub async fn login<DB: Database>( + state: State<AppState<DB>>, login: Json<LoginRequest>, ) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { - let db = &state.0.postgres; + let db = &state.0.database; let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(sqlx::Error::RowNotFound) => { |