summaryrefslogtreecommitdiffstats
path: root/atuin-server/src/handlers
diff options
context:
space:
mode:
authorErwin Kroon <123574+ekroon@users.noreply.github.com>2023-02-15 09:54:09 +0100
committerGitHub <noreply@github.com>2023-02-15 08:54:09 +0000
commitdcfad9a90d4b18465c437811698c879c4c5a9fcc (patch)
tree327990925add8c97cc02c61d5607628921e81b96 /atuin-server/src/handlers
parent7e7dd63966cd8eae7ec61f7e419e0b72a45ff580 (diff)
Add support for generic database in AppState (#711)
Diffstat (limited to 'atuin-server/src/handlers')
-rw-r--r--atuin-server/src/handlers/history.rs24
-rw-r--r--atuin-server/src/handlers/user.rs18
2 files changed, 21 insertions, 21 deletions
diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs
index 9ee13e16..7cf18323 100644
--- a/atuin-server/src/handlers/history.rs
+++ b/atuin-server/src/handlers/history.rs
@@ -18,11 +18,11 @@ use crate::{
use atuin_common::api::*;
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn count(
+pub async fn count<DB: Database>(
user: User,
- state: State<AppState>,
+ state: State<AppState<DB>>,
) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.postgres;
+ let db = &state.0.database;
match db.count_history_cached(&user).await {
// By default read out the cached value
Ok(count) => Ok(Json(CountResponse { count })),
@@ -38,12 +38,12 @@ pub async fn count(
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn list(
+pub async fn list<DB: Database>(
req: Query<SyncHistoryRequest>,
user: User,
- state: State<AppState>,
+ state: State<AppState<DB>>,
) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.postgres;
+ let db = &state.0.database;
let history = db
.list_history(
&user,
@@ -75,9 +75,9 @@ pub async fn list(
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn add(
+pub async fn add<DB: Database>(
user: User,
- state: State<AppState>,
+ state: State<AppState<DB>>,
Json(req): Json<Vec<AddHistoryRequest>>,
) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len());
@@ -93,7 +93,7 @@ pub async fn add(
})
.collect();
- let db = &state.0.postgres;
+ let db = &state.0.database;
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);
@@ -105,18 +105,18 @@ pub async fn add(
}
#[instrument(skip_all, fields(user.id = user.id))]
-pub async fn calendar(
+pub async fn calendar<DB: Database>(
Path(focus): Path<String>,
Query(params): Query<HashMap<String, u64>>,
user: User,
- state: State<AppState>,
+ state: State<AppState<DB>>,
) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> {
let focus = focus.as_str();
let year = params.get("year").unwrap_or(&0);
let month = params.get("month").unwrap_or(&1);
- let db = &state.0.postgres;
+ let db = &state.0.database;
let focus = match focus {
"year" => db
.calendar(&user, TimePeriod::YEAR, *year, *month)
diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs
index 761724c5..677e7c65 100644
--- a/atuin-server/src/handlers/user.rs
+++ b/atuin-server/src/handlers/user.rs
@@ -34,11 +34,11 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
}
#[instrument(skip_all, fields(user.username = username.as_str()))]
-pub async fn get(
+pub async fn get<DB: Database>(
Path(username): Path<String>,
- state: State<AppState>,
+ state: State<AppState<DB>>,
) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.postgres;
+ let db = &state.0.database;
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(sqlx::Error::RowNotFound) => {
@@ -58,9 +58,9 @@ pub async fn get(
}
#[instrument(skip_all)]
-pub async fn register(
+pub async fn register<DB: Database>(
settings: Extension<Settings>,
- state: State<AppState>,
+ state: State<AppState<DB>>,
Json(register): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration {
@@ -78,7 +78,7 @@ pub async fn register(
password: hashed,
};
- let db = &state.0.postgres;
+ let db = &state.0.database;
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
@@ -107,11 +107,11 @@ pub async fn register(
}
#[instrument(skip_all, fields(user.username = login.username.as_str()))]
-pub async fn login(
- state: State<AppState>,
+pub async fn login<DB: Database>(
+ state: State<AppState<DB>>,
login: Json<LoginRequest>,
) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
- let db = &state.0.postgres;
+ let db = &state.0.database;
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(sqlx::Error::RowNotFound) => {