summaryrefslogtreecommitdiffstats
path: root/atuin-server
diff options
context:
space:
mode:
authorConrad Ludgate <conrad.ludgate@truelayer.com>2022-04-12 23:06:19 +0100
committerGitHub <noreply@github.com>2022-04-12 23:06:19 +0100
commita95018cc9039851e707973bc19faf907132ae4f3 (patch)
treee135f1da64c5d020f336d437f83a333298861ca0 /atuin-server
parent3b7ed7caffdbedfd30b022b8e2b3f93a2b6a494a (diff)
goodbye warp, hello axum (#296)
Diffstat (limited to 'atuin-server')
-rw-r--r--atuin-server/Cargo.toml3
-rw-r--r--atuin-server/src/handlers/history.rs53
-rw-r--r--atuin-server/src/handlers/mod.rs2
-rw-r--r--atuin-server/src/handlers/user.rs72
-rw-r--r--atuin-server/src/lib.rs20
-rw-r--r--atuin-server/src/models.rs22
-rw-r--r--atuin-server/src/router.rs171
7 files changed, 138 insertions, 205 deletions
diff --git a/atuin-server/Cargo.toml b/atuin-server/Cargo.toml
index e1acc97b..16a9fa0b 100644
--- a/atuin-server/Cargo.toml
+++ b/atuin-server/Cargo.toml
@@ -25,6 +25,7 @@ base64 = "0.13.0"
rand = "0.8.4"
rust-crypto = "^0.2"
tokio = { version = "1", features = ["full"] }
-warp = "0.3"
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "uuid", "chrono", "postgres" ] }
async-trait = "0.1.49"
+axum = "0.5"
+http = "0.2"
diff --git a/atuin-server/src/handlers/history.rs b/atuin-server/src/handlers/history.rs
index 06715381..546e5a29 100644
--- a/atuin-server/src/handlers/history.rs
+++ b/atuin-server/src/handlers/history.rs
@@ -1,26 +1,27 @@
-use warp::{http::StatusCode, Reply};
+use axum::extract::Query;
+use axum::{Extension, Json};
+use http::StatusCode;
-use crate::database::Database;
+use crate::database::{Database, Postgres};
use crate::models::{NewHistory, User};
use atuin_common::api::*;
+
pub async fn count(
user: User,
- db: impl Database + Clone + Send + Sync,
-) -> JSONResult<ErrorResponseStatus<'static>> {
- db.count_history(&user).await.map_or(
- reply_error(
- ErrorResponse::reply("failed to query history count")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR),
- ),
- |count| reply_json(CountResponse { count }),
- )
+ db: Extension<Postgres>,
+) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> {
+ match db.count_history(&user).await {
+ Ok(count) => Ok(Json(CountResponse { count })),
+ Err(_) => Err(ErrorResponse::reply("failed to query history count")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR)),
+ }
}
pub async fn list(
- req: SyncHistoryRequest<'_>,
+ req: Query<SyncHistoryRequest>,
user: User,
- db: impl Database + Clone + Send + Sync,
-) -> JSONResult<ErrorResponseStatus<'static>> {
+ db: Extension<Postgres>,
+) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> {
let history = db
.list_history(
&user,
@@ -32,10 +33,8 @@ pub async fn list(
if let Err(e) = history {
error!("failed to load history: {}", e);
- return reply_error(
- ErrorResponse::reply("failed to load history")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR),
- );
+ return Err(ErrorResponse::reply("failed to load history")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
let history: Vec<String> = history
@@ -50,14 +49,14 @@ pub async fn list(
user.id
);
- reply_json(SyncHistoryResponse { history })
+ Ok(Json(SyncHistoryResponse { history }))
}
pub async fn add(
- req: Vec<AddHistoryRequest<'_, String>>,
+ Json(req): Json<Vec<AddHistoryRequest>>,
user: User,
- db: impl Database + Clone + Send + Sync,
-) -> ReplyResult<impl Reply, ErrorResponseStatus<'_>> {
+ db: Extension<Postgres>,
+) -> Result<(), ErrorResponseStatus<'static>> {
debug!("request to add {} history items", req.len());
let history: Vec<NewHistory> = req
@@ -67,18 +66,16 @@ pub async fn add(
user_id: user.id,
hostname: h.hostname,
timestamp: h.timestamp.naive_utc(),
- data: h.data.into(),
+ data: h.data,
})
.collect();
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);
- return reply_error(
- ErrorResponse::reply("failed to add history")
- .with_status(StatusCode::INTERNAL_SERVER_ERROR),
- );
+ return Err(ErrorResponse::reply("failed to add history")
+ .with_status(StatusCode::INTERNAL_SERVER_ERROR));
};
- reply(warp::reply())
+ Ok(())
}
diff --git a/atuin-server/src/handlers/mod.rs b/atuin-server/src/handlers/mod.rs
index 3c20538c..83c2d0c3 100644
--- a/atuin-server/src/handlers/mod.rs
+++ b/atuin-server/src/handlers/mod.rs
@@ -1,6 +1,6 @@
pub mod history;
pub mod user;
-pub const fn index() -> &'static str {
+pub async fn index() -> &'static str {
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
}
diff --git a/atuin-server/src/handlers/user.rs b/atuin-server/src/handlers/user.rs
index 8144adab..1bcfce2f 100644
--- a/atuin-server/src/handlers/user.rs
+++ b/atuin-server/src/handlers/user.rs
@@ -2,11 +2,13 @@ use std::borrow::Borrow;
use atuin_common::api::*;
use atuin_common::utils::hash_secret;
+use axum::extract::Path;
+use axum::{Extension, Json};
+use http::StatusCode;
use sodiumoxide::crypto::pwhash::argon2id13;
use uuid::Uuid;
-use warp::http::StatusCode;
-use crate::database::Database;
+use crate::database::{Database, Postgres};
use crate::models::{NewSession, NewUser};
use crate::settings::Settings;
@@ -25,31 +27,29 @@ pub fn verify_str(secret: &str, verify: &str) -> bool {
}
pub async fn get(
- username: impl AsRef<str>,
- db: impl Database + Clone + Send + Sync,
-) -> JSONResult<ErrorResponseStatus<'static>> {
+ Path(username): Path<String>,
+ db: Extension<Postgres>,
+) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> {
let user = match db.get_user(username.as_ref()).await {
Ok(user) => user,
Err(e) => {
debug!("user not found: {}", e);
- return reply_error(
- ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
- );
+ return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
- reply_json(UserResponse {
- username: user.username.into(),
- })
+ Ok(Json(UserResponse {
+ username: user.username,
+ }))
}
pub async fn register(
- register: RegisterRequest<'_>,
- settings: Settings,
- db: impl Database + Clone + Send + Sync,
-) -> JSONResult<ErrorResponseStatus<'static>> {
+ Json(register): Json<RegisterRequest>,
+ settings: Extension<Settings>,
+ db: Extension<Postgres>,
+) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> {
if !settings.open_registration {
- return reply_error(
+ return Err(
ErrorResponse::reply("this server is not open for registrations")
.with_status(StatusCode::BAD_REQUEST),
);
@@ -60,15 +60,15 @@ pub async fn register(
let new_user = NewUser {
email: register.email,
username: register.username,
- password: hashed.into(),
+ password: hashed,
};
let user_id = match db.add_user(&new_user).await {
Ok(id) => id,
Err(e) => {
error!("failed to add user: {}", e);
- return reply_error(
- ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST),
+ return Err(
+ ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST)
);
}
};
@@ -81,31 +81,25 @@ pub async fn register(
};
match db.add_session(&new_session).await {
- Ok(_) => reply_json(RegisterResponse {
- session: token.into(),
- }),
+ Ok(_) => Ok(Json(RegisterResponse { session: token })),
Err(e) => {
error!("failed to add session: {}", e);
- reply_error(
- ErrorResponse::reply("failed to register user")
- .with_status(StatusCode::BAD_REQUEST),
- )
+ Err(ErrorResponse::reply("failed to register user")
+ .with_status(StatusCode::BAD_REQUEST))
}
}
}
pub async fn login(
- login: LoginRequest<'_>,
- db: impl Database + Clone + Send + Sync,
-) -> JSONResult<ErrorResponseStatus<'_>> {
+ login: Json<LoginRequest>,
+ db: Extension<Postgres>,
+) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> {
let user = match db.get_user(login.username.borrow()).await {
Ok(u) => u,
Err(e) => {
error!("failed to get user {}: {}", login.username.clone(), e);
- return reply_error(
- ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
- );
+ return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
@@ -114,21 +108,17 @@ pub async fn login(
Err(e) => {
error!("failed to get session for {}: {}", login.username, e);
- return reply_error(
- ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
- );
+ return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
};
let verified = verify_str(user.password.as_str(), login.password.borrow());
if !verified {
- return reply_error(
- ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND),
- );
+ return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND));
}
- reply_json(LoginResponse {
- session: session.token.into(),
- })
+ Ok(Json(LoginResponse {
+ session: session.token,
+ }))
}
diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs
index e4858811..ca0aa11c 100644
--- a/atuin-server/src/lib.rs
+++ b/atuin-server/src/lib.rs
@@ -1,8 +1,10 @@
#![forbid(unsafe_code)]
-use std::net::IpAddr;
+use std::net::{IpAddr, SocketAddr};
-use eyre::Result;
+use axum::Server;
+use database::Postgres;
+use eyre::{Context, Result};
use crate::settings::Settings;
@@ -19,14 +21,18 @@ pub mod models;
pub mod router;
pub mod settings;
-pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
- // routes to run:
- // index, register, add_history, login, get_user, sync_count, sync_list
+pub async fn launch(settings: Settings, host: String, port: u16) -> Result<()> {
let host = host.parse::<IpAddr>()?;
- let r = router::router(settings).await?;
+ let postgres = Postgres::new(settings.db_uri.as_str())
+ .await
+ .wrap_err_with(|| format!("failed to connect to db: {}", settings.db_uri))?;
- warp::serve(r).run((host, port)).await;
+ let r = router::router(postgres, settings);
+
+ Server::bind(&SocketAddr::new(host, port))
+ .serve(r.into_make_service())
+ .await?;
Ok(())
}
diff --git a/atuin-server/src/models.rs b/atuin-server/src/models.rs
index d493153a..ee84f58a 100644
--- a/atuin-server/src/models.rs
+++ b/atuin-server/src/models.rs
@@ -1,5 +1,3 @@
-use std::borrow::Cow;
-
use chrono::prelude::*;
#[derive(sqlx::FromRow)]
@@ -15,13 +13,13 @@ pub struct History {
pub created_at: NaiveDateTime,
}
-pub struct NewHistory<'a> {
- pub client_id: Cow<'a, str>,
+pub struct NewHistory {
+ pub client_id: String,
pub user_id: i64,
- pub hostname: Cow<'a, str>,
+ pub hostname: String,
pub timestamp: chrono::NaiveDateTime,
- pub data: Cow<'a, str>,
+ pub data: String,
}
#[derive(sqlx::FromRow)]
@@ -39,13 +37,13 @@ pub struct Session {
pub token: String,
}
-pub struct NewUser<'a> {
- pub username: Cow<'a, str>,
- pub email: Cow<'a, str>,
- pub password: Cow<'a, str>,
+pub struct NewUser {
+ pub username: String,
+ pub email: String,
+ pub password: String,
}
-pub struct NewSession<'a> {
+pub struct NewSession {
pub user_id: i64,
- pub token: Cow<'a, str>,
+ pub token: String,
}
diff --git a/atuin-server/src/router.rs b/atuin-server/src/router.rs
index f7e142a0..6ca47229 100644
--- a/atuin-server/src/router.rs
+++ b/atuin-server/src/router.rs
@@ -1,9 +1,12 @@
-use std::convert::Infallible;
-
+use async_trait::async_trait;
+use axum::{
+ extract::{FromRequest, RequestParts},
+ handler::Handler,
+ response::IntoResponse,
+ routing::{get, post},
+ Extension, Router,
+};
use eyre::Result;
-use warp::{hyper::StatusCode, Filter};
-
-use atuin_common::api::SyncHistoryRequest;
use super::{
database::{Database, Postgres},
@@ -11,119 +14,57 @@ use super::{
};
use crate::{models::User, settings::Settings};
-fn with_settings(
- settings: Settings,
-) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone {
- warp::any().map(move || settings.clone())
-}
-
-fn with_db(
- db: impl Database + Clone + Send + Sync,
-) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone {
- warp::any().map(move || db.clone())
-}
-
-fn with_user(
- postgres: Postgres,
-) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone {
- warp::header::<String>("authorization").and_then(move |header: String| {
- // async closures are still buggy :(
- let postgres = postgres.clone();
-
- async move {
- let header: Vec<&str> = header.split(' ').collect();
-
- let token = if header.len() == 2 {
- if header[0] != "Token" {
- return Err(warp::reject());
- }
-
- header[1]
- } else {
- return Err(warp::reject());
- };
-
- let user = postgres
- .get_session_user(token)
- .await
- .map_err(|_| warp::reject())?;
-
- Ok(user)
+#[async_trait]
+impl<B> FromRequest<B> for User
+where
+ B: Send,
+{
+ type Rejection = http::StatusCode;
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let postgres = req
+ .extensions()
+ .get::<Postgres>()
+ .ok_or(http::StatusCode::INTERNAL_SERVER_ERROR)?;
+
+ let auth_header = req
+ .headers()
+ .get(http::header::AUTHORIZATION)
+ .ok_or(http::StatusCode::FORBIDDEN)?;
+ let auth_header = auth_header
+ .to_str()
+ .map_err(|_| http::StatusCode::FORBIDDEN)?;
+ let (typ, token) = auth_header
+ .split_once(' ')
+ .ok_or(http::StatusCode::FORBIDDEN)?;
+
+ if typ != "Token" {
+ return Err(http::StatusCode::FORBIDDEN);
}
- })
-}
-
-pub async fn router(
- settings: &Settings,
-) -> Result<impl Filter<Extract = impl warp::Reply, Error = Infallible> + Clone> {
- let postgres = Postgres::new(settings.db_uri.as_str()).await?;
- let index = warp::get().and(warp::path::end()).map(handlers::index);
-
- let count = warp::get()
- .and(warp::path("sync"))
- .and(warp::path("count"))
- .and(warp::path::end())
- .and(with_user(postgres.clone()))
- .and(with_db(postgres.clone()))
- .and_then(handlers::history::count)
- .boxed();
- let sync = warp::get()
- .and(warp::path("sync"))
- .and(warp::path("history"))
- .and(warp::query::<SyncHistoryRequest>())
- .and(warp::path::end())
- .and(with_user(postgres.clone()))
- .and(with_db(postgres.clone()))
- .and_then(handlers::history::list)
- .boxed();
+ let user = postgres
+ .get_session_user(token)
+ .await
+ .map_err(|_| http::StatusCode::FORBIDDEN)?;
- let add_history = warp::post()
- .and(warp::path("history"))
- .and(warp::path::end())
- .and(warp::body::json())
- .and(with_user(postgres.clone()))
- .and(with_db(postgres.clone()))
- .and_then(handlers::history::add)
- .boxed();
-
- let user = warp::get()
- .and(warp::path("user"))
- .and(warp::path::param::<String>())
- .and(warp::path::end())
- .and(with_db(postgres.clone()))
- .and_then(handlers::user::get)
- .boxed();
-
- let register = warp::post()
- .and(warp::path("register"))
- .and(warp::path::end())
- .and(warp::body::json())
- .and(with_settings(settings.clone()))
- .and(with_db(postgres.clone()))
- .and_then(handlers::user::register)
- .boxed();
-
- let login = warp::post()
- .and(warp::path("login"))
- .and(warp::path::end())
- .and(warp::body::json())
- .and(with_db(postgres))
- .and_then(handlers::user::login)
- .boxed();
+ Ok(user)
+ }
+}
- let r = warp::any()
- .and(
- index
- .or(count)
- .or(sync)
- .or(add_history)
- .or(user)
- .or(register)
- .or(login)
- .or(warp::any().map(|| warp::reply::with_status("☕", StatusCode::IM_A_TEAPOT))),
- )
- .with(warp::filters::log::log("atuin::api"));
+async fn teapot() -> impl IntoResponse {
+ (http::StatusCode::IM_A_TEAPOT, "☕")
+}
- Ok(r)
+pub fn router(postgres: Postgres, settings: Settings) -> Router {
+ Router::new()
+ .route("/", get(handlers::index))
+ .route("/sync/count", get(handlers::history::count))
+ .route("/sync/history", get(handlers::history::list))
+ .route("/history", post(handlers::history::add))
+ .route("/user/:username", get(handlers::user::get))
+ .route("/register", post(handlers::user::register))
+ .route("/login", post(handlers::user::login))
+ .fallback(teapot.into_service())
+ .layer(Extension(postgres))
+ .layer(Extension(settings))
}