diff options
author | Conrad Ludgate <conrad.ludgate@truelayer.com> | 2022-04-12 23:06:19 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-12 23:06:19 +0100 |
commit | a95018cc9039851e707973bc19faf907132ae4f3 (patch) | |
tree | e135f1da64c5d020f336d437f83a333298861ca0 /atuin-server | |
parent | 3b7ed7caffdbedfd30b022b8e2b3f93a2b6a494a (diff) |
goodbye warp, hello axum (#296)
Diffstat (limited to 'atuin-server')
-rw-r--r-- | atuin-server/Cargo.toml | 3 | ||||
-rw-r--r-- | atuin-server/src/handlers/history.rs | 53 | ||||
-rw-r--r-- | atuin-server/src/handlers/mod.rs | 2 | ||||
-rw-r--r-- | atuin-server/src/handlers/user.rs | 72 | ||||
-rw-r--r-- | atuin-server/src/lib.rs | 20 | ||||
-rw-r--r-- | atuin-server/src/models.rs | 22 | ||||
-rw-r--r-- | atuin-server/src/router.rs | 171 |
7 files changed, 138 insertions, 205 deletions
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml index e1acc97b..16a9fa0b 100644 --- a/atuin-server/Cargo.toml +++ b/atuin-server/Cargo.toml @@ -25,6 +25,7 @@ base64 = "0.13.0" rand = "0.8.4" rust-crypto = "^0.2" tokio = { version = "1", features = ["full"] } -warp = "0.3" sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "uuid", "chrono", "postgres" ] } async-trait = "0.1.49" +axum = "0.5" +http = "0.2" diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs index 06715381..546e5a29 100644 --- a/atuin-server/src/handlers/history.rs +++ b/atuin-server/src/handlers/history.rs @@ -1,26 +1,27 @@ -use warp::{http::StatusCode, Reply}; +use axum::extract::Query; +use axum::{Extension, Json}; +use http::StatusCode; -use crate::database::Database; +use crate::database::{Database, Postgres}; use crate::models::{NewHistory, User}; use atuin_common::api::*; + pub async fn count( user: User, - db: impl Database + Clone + Send + Sync, -) -> JSONResult<ErrorResponseStatus<'static>> { - db.count_history(&user).await.map_or( - reply_error( - ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ), - |count| reply_json(CountResponse { count }), - ) + db: Extension<Postgres>, +) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> { + match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + } } pub async fn list( - req: SyncHistoryRequest<'_>, + req: Query<SyncHistoryRequest>, user: User, - db: impl Database + Clone + Send + Sync, -) -> JSONResult<ErrorResponseStatus<'static>> { + db: Extension<Postgres>, +) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { let history = db .list_history( &user, @@ -32,10 +33,8 @@ pub async fn list( if let Err(e) = history { error!("failed to load history: {}", e); - return reply_error( - ErrorResponse::reply("failed to load history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); } let history: Vec<String> = history @@ -50,14 +49,14 @@ pub async fn list( user.id ); - reply_json(SyncHistoryResponse { history }) + Ok(Json(SyncHistoryResponse { history })) } pub async fn add( - req: Vec<AddHistoryRequest<'_, String>>, + Json(req): Json<Vec<AddHistoryRequest>>, user: User, - db: impl Database + Clone + Send + Sync, -) -> ReplyResult<impl Reply, ErrorResponseStatus<'_>> { + db: Extension<Postgres>, +) -> Result<(), ErrorResponseStatus<'static>> { debug!("request to add {} history items", req.len()); let history: Vec<NewHistory> = req @@ -67,18 +66,16 @@ pub async fn add( user_id: user.id, hostname: h.hostname, timestamp: h.timestamp.naive_utc(), - data: h.data.into(), + data: h.data, }) .collect(); if let Err(e) = db.add_history(&history).await { error!("failed to add history: {}", e); - return reply_error( - ErrorResponse::reply("failed to add history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR), - ); + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); }; - reply(warp::reply()) + Ok(()) } diff --git a/atuin-server/src/handlers/mod.rs b/atuin-server/src/handlers/mod.rs index 3c20538c..83c2d0c3 100644 --- a/atuin-server/src/handlers/mod.rs +++ b/atuin-server/src/handlers/mod.rs @@ -1,6 +1,6 @@ pub mod history; pub mod user; -pub const fn index() -> &'static str { +pub async fn index() -> &'static str { "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" } diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs index 8144adab..1bcfce2f 100644 --- a/atuin-server/src/handlers/user.rs +++ b/atuin-server/src/handlers/user.rs @@ -2,11 +2,13 @@ use std::borrow::Borrow; use atuin_common::api::*; use atuin_common::utils::hash_secret; +use axum::extract::Path; +use axum::{Extension, Json}; +use http::StatusCode; use sodiumoxide::crypto::pwhash::argon2id13; use uuid::Uuid; -use warp::http::StatusCode; -use crate::database::Database; +use crate::database::{Database, Postgres}; use crate::models::{NewSession, NewUser}; use crate::settings::Settings; @@ -25,31 +27,29 @@ pub fn verify_str(secret: &str, verify: &str) -> bool { } pub async fn get( - username: impl AsRef<str>, - db: impl Database + Clone + Send + Sync, -) -> JSONResult<ErrorResponseStatus<'static>> { + Path(username): Path<String>, + db: Extension<Postgres>, +) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { let user = match db.get_user(username.as_ref()).await { Ok(user) => user, Err(e) => { debug!("user not found: {}", e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; - reply_json(UserResponse { - username: user.username.into(), - }) + Ok(Json(UserResponse { + username: user.username, + })) } pub async fn register( - register: RegisterRequest<'_>, - settings: Settings, - db: impl Database + Clone + Send + Sync, -) -> JSONResult<ErrorResponseStatus<'static>> { + Json(register): Json<RegisterRequest>, + settings: Extension<Settings>, + db: Extension<Postgres>, +) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { if !settings.open_registration { - return reply_error( + return Err( ErrorResponse::reply("this server is not open for registrations") .with_status(StatusCode::BAD_REQUEST), ); @@ -60,15 +60,15 @@ pub async fn register( let new_user = NewUser { email: register.email, username: register.username, - password: hashed.into(), + password: hashed, }; let user_id = match db.add_user(&new_user).await { Ok(id) => id, Err(e) => { error!("failed to add user: {}", e); - return reply_error( - ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST), + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) ); } }; @@ -81,31 +81,25 @@ pub async fn register( }; match db.add_session(&new_session).await { - Ok(_) => reply_json(RegisterResponse { - session: token.into(), - }), + Ok(_) => Ok(Json(RegisterResponse { session: token })), Err(e) => { error!("failed to add session: {}", e); - reply_error( - ErrorResponse::reply("failed to register user") - .with_status(StatusCode::BAD_REQUEST), - ) + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) } } } pub async fn login( - login: LoginRequest<'_>, - db: impl Database + Clone + Send + Sync, -) -> JSONResult<ErrorResponseStatus<'_>> { + login: Json<LoginRequest>, + db: Extension<Postgres>, +) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { let user = match db.get_user(login.username.borrow()).await { Ok(u) => u, Err(e) => { error!("failed to get user {}: {}", login.username.clone(), e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; @@ -114,21 +108,17 @@ pub async fn login( Err(e) => { error!("failed to get session for {}: {}", login.username, e); - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } }; let verified = verify_str(user.password.as_str(), login.password.borrow()); if !verified { - return reply_error( - ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND), - ); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); } - reply_json(LoginResponse { - session: session.token.into(), - }) + Ok(Json(LoginResponse { + session: session.token, + })) } diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index e4858811..ca0aa11c 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -1,8 +1,10 @@ #![forbid(unsafe_code)] -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; -use eyre::Result; +use axum::Server; +use database::Postgres; +use eyre::{Context, Result}; use crate::settings::Settings; @@ -19,14 +21,18 @@ pub mod models; pub mod router; pub mod settings; -pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> { - // routes to run: - // index, register, add_history, login, get_user, sync_count, sync_list +pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> { let host = host.parse::<IpAddr>()?; - let r = router::router(settings).await?; + let postgres = Postgres::new(settings.db_uri.as_str()) + .await + .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?; - warp::serve(r).run((host, port)).await; + let r = router::router(postgres, settings); + + Server::bind(&SocketAddr::new(host, port)) + .serve(r.into_make_service()) + .await?; Ok(()) } diff --git a/atuin-server/src/models.rs b/atuin-server/src/models.rs index d493153a..ee84f58a 100644 --- a/atuin-server/src/models.rs +++ b/atuin-server/src/models.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use chrono::prelude::*; #[derive(sqlx::FromRow)] @@ -15,13 +13,13 @@ pub struct History { pub created_at: NaiveDateTime, } -pub struct NewHistory<'a> { - pub client_id: Cow<'a, str>, +pub struct NewHistory { + pub client_id: String, pub user_id: i64, - pub hostname: Cow<'a, str>, + pub hostname: String, pub timestamp: chrono::NaiveDateTime, - pub data: Cow<'a, str>, + pub data: String, } #[derive(sqlx::FromRow)] @@ -39,13 +37,13 @@ pub struct Session { pub token: String, } -pub struct NewUser<'a> { - pub username: Cow<'a, str>, - pub email: Cow<'a, str>, - pub password: Cow<'a, str>, +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, } -pub struct NewSession<'a> { +pub struct NewSession { pub user_id: i64, - pub token: Cow<'a, str>, + pub token: String, } diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs index f7e142a0..6ca47229 100644 --- a/atuin-server/src/router.rs +++ b/atuin-server/src/router.rs @@ -1,9 +1,12 @@ -use std::convert::Infallible; - +use async_trait::async_trait; +use axum::{ + extract::{FromRequest, RequestParts}, + handler::Handler, + response::IntoResponse, + routing::{get, post}, + Extension, Router, +}; use eyre::Result; -use warp::{hyper::StatusCode, Filter}; - -use atuin_common::api::SyncHistoryRequest; use super::{ database::{Database, Postgres}, @@ -11,119 +14,57 @@ use super::{ }; use crate::{models::User, settings::Settings}; -fn with_settings( - settings: Settings, -) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone { - warp::any().map(move || settings.clone()) -} - -fn with_db( - db: impl Database + Clone + Send + Sync, -) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone { - warp::any().map(move || db.clone()) -} - -fn with_user( - postgres: Postgres, -) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone { - warp::header::<String>("authorization").and_then(move |header: String| { - // async closures are still buggy :( - let postgres = postgres.clone(); - - async move { - let header: Vec<&str> = header.split(' ').collect(); - - let token = if header.len() == 2 { - if header[0] != "Token" { - return Err(warp::reject()); - } - - header[1] - } else { - return Err(warp::reject()); - }; - - let user = postgres - .get_session_user(token) - .await - .map_err(|_| warp::reject())?; - - Ok(user) +#[async_trait] +impl<B> FromRequest<B> for User +where + B: Send, +{ + type Rejection = http::StatusCode; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + let postgres = req + .extensions() + .get::<Postgres>() + .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?; + + let auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .ok_or(http::StatusCode::FORBIDDEN)?; + let auth_header = auth_header + .to_str() + .map_err(|_| http::StatusCode::FORBIDDEN)?; + let (typ, token) = auth_header + .split_once(' ') + .ok_or(http::StatusCode::FORBIDDEN)?; + + if typ != "Token" { + return Err(http::StatusCode::FORBIDDEN); } - }) -} - -pub async fn router( - settings: &Settings, -) -> Result<impl Filter<Extract = impl warp::Reply, Error = Infallible> + Clone> { - let postgres = Postgres::new(settings.db_uri.as_str()).await?; - let index = warp::get().and(warp::path::end()).map(handlers::index); - - let count = warp::get() - .and(warp::path("sync")) - .and(warp::path("count")) - .and(warp::path::end()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::count) - .boxed(); - let sync = warp::get() - .and(warp::path("sync")) - .and(warp::path("history")) - .and(warp::query::<SyncHistoryRequest>()) - .and(warp::path::end()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::list) - .boxed(); + let user = postgres + .get_session_user(token) + .await + .map_err(|_| http::StatusCode::FORBIDDEN)?; - let add_history = warp::post() - .and(warp::path("history")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_user(postgres.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::history::add) - .boxed(); - - let user = warp::get() - .and(warp::path("user")) - .and(warp::path::param::<String>()) - .and(warp::path::end()) - .and(with_db(postgres.clone())) - .and_then(handlers::user::get) - .boxed(); - - let register = warp::post() - .and(warp::path("register")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_settings(settings.clone())) - .and(with_db(postgres.clone())) - .and_then(handlers::user::register) - .boxed(); - - let login = warp::post() - .and(warp::path("login")) - .and(warp::path::end()) - .and(warp::body::json()) - .and(with_db(postgres)) - .and_then(handlers::user::login) - .boxed(); + Ok(user) + } +} - let r = warp::any() - .and( - index - .or(count) - .or(sync) - .or(add_history) - .or(user) - .or(register) - .or(login) - .or(warp::any().map(|| warp::reply::with_status("☕", StatusCode::IM_A_TEAPOT))), - ) - .with(warp::filters::log::log("atuin::api")); +async fn teapot() -> impl IntoResponse { + (http::StatusCode::IM_A_TEAPOT, "☕") +} - Ok(r) +pub fn router(postgres: Postgres, settings: Settings) -> Router { + Router::new() + .route("/", get(handlers::index)) + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/history", post(handlers::history::add)) + .route("/user/:username", get(handlers::user::get)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .fallback(teapot.into_service()) + .layer(Extension(postgres)) + .layer(Extension(settings)) } |