diff options
author | Ellie Huxtable <e@elm.sh> | 2021-03-21 20:04:39 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-21 20:04:39 +0000 |
commit | c9579cb9ca2a6a165d10f128e0af1dfd372e0c03 (patch) | |
tree | 1d4feecb422aae3cde1cc7cad54ccc73b2dae410 /src | |
parent | 716c7722cda29bf612508bb96f51822a86e0f69e (diff) |
Implement server (#23)
* Add initial database and server setup
* Set up all routes, auth, etc
* Implement sessions, password auth, hashing with argon2, and history storage
Diffstat (limited to 'src')
-rw-r--r-- | src/command/mod.rs | 2 | ||||
-rw-r--r-- | src/command/server.rs | 5 | ||||
-rw-r--r-- | src/main.rs | 24 | ||||
-rw-r--r-- | src/remote/auth.rs | 200 | ||||
-rw-r--r-- | src/remote/database.rs | 14 | ||||
-rw-r--r-- | src/remote/mod.rs | 4 | ||||
-rw-r--r-- | src/remote/models.rs | 56 | ||||
-rw-r--r-- | src/remote/server.rs | 46 | ||||
-rw-r--r-- | src/remote/views.rs | 89 | ||||
-rw-r--r-- | src/schema.rs | 28 | ||||
-rw-r--r-- | src/settings.rs | 13 |
11 files changed, 471 insertions, 10 deletions
diff --git a/src/command/mod.rs b/src/command/mod.rs index 3ebb92e0..a5ea0228 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -49,7 +49,7 @@ impl AtuinCmd { match self { Self::History(history) => history.run(db), Self::Import(import) => import.run(db), - Self::Server(server) => server.run(), + Self::Server(server) => server.run(settings), Self::Stats(stats) => stats.run(db, settings), Self::Init => init::init(), Self::Search { query } => search::run(&query, db), diff --git a/src/command/server.rs b/src/command/server.rs index 1ddc73e7..9d9bcb3a 100644 --- a/src/command/server.rs +++ b/src/command/server.rs @@ -2,6 +2,7 @@ use eyre::Result; use structopt::StructOpt; use crate::remote::server; +use crate::settings::Settings; #[derive(StructOpt)] pub enum Cmd { @@ -10,8 +11,8 @@ pub enum Cmd { #[allow(clippy::unused_self)] // I'll use it later impl Cmd { - pub fn run(&self) -> Result<()> { - server::launch(); + pub fn run(&self, settings: &Settings) -> Result<()> { + server::launch(settings); Ok(()) } } diff --git a/src/main.rs b/src/main.rs index d47866f4..3c4a05e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,15 @@ extern crate rocket; #[macro_use] extern crate serde_derive; +#[macro_use] +extern crate diesel; + +#[macro_use] +extern crate diesel_migrations; + +#[macro_use] +extern crate rocket_contrib; + use command::AtuinCmd; use local::database::Sqlite; use settings::Settings; @@ -26,6 +35,8 @@ mod local; mod remote; mod settings; +pub mod schema; + #[derive(StructOpt)] #[structopt( author = "Ellie Huxtable <e@elm.sh>", @@ -61,7 +72,18 @@ impl Atuin { } fn main() -> Result<()> { - pretty_env_logger::init(); + fern::Dispatch::new() + .format(|out, message, record| { + out.finish(format_args!( + "{} [{}] {}", + chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"), + record.level(), + message + )) + }) + .level(log::LevelFilter::Info) + .chain(std::io::stdout()) + .apply()?; Atuin::from_args().run() } diff --git a/src/remote/auth.rs b/src/remote/auth.rs new file mode 100644 index 00000000..8f9e9b46 --- /dev/null +++ b/src/remote/auth.rs @@ -0,0 +1,200 @@ +use self::diesel::prelude::*; +use rocket::http::Status; +use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket_contrib::databases::diesel; +use sodiumoxide::crypto::pwhash::argon2id13; + +use rocket_contrib::json::Json; +use uuid::Uuid; + +use super::models::{NewSession, NewUser, Session, User}; +use super::views::ApiResponse; +use crate::schema::{sessions, users}; + +use super::database::AtuinDbConn; + +#[derive(Debug)] +pub enum KeyError { + Missing, + Invalid, +} + +pub fn hash_str(secret: &str) -> String { + sodiumoxide::init().unwrap(); + let hash = argon2id13::pwhash( + secret.as_bytes(), + argon2id13::OPSLIMIT_INTERACTIVE, + argon2id13::MEMLIMIT_INTERACTIVE, + ) + .unwrap(); + let texthash = std::str::from_utf8(&hash.0).unwrap().to_string(); + + // postgres hates null chars. don't do that to postgres + texthash.trim_end_matches('\u{0}').to_string() +} + +pub fn verify_str(secret: &str, verify: &str) -> bool { + sodiumoxide::init().unwrap(); + + let mut padded = [0_u8; 128]; + secret.as_bytes().iter().enumerate().for_each(|(i, val)| { + padded[i] = *val; + }); + + match argon2id13::HashedPassword::from_slice(&padded) { + Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), + None => false, + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for User { + type Error = KeyError; + + fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> { + let session: Vec<_> = request.headers().get("authorization").collect(); + + if session.is_empty() { + return Outcome::Failure((Status::BadRequest, KeyError::Missing)); + } else if session.len() > 1 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session: Vec<_> = session[0].split(' ').collect(); + + if session.len() != 2 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + if session[0] != "Token" { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session = session[1]; + + let db = request + .guard::<AtuinDbConn>() + .succeeded() + .expect("failed to load database"); + + let session = sessions::table + .filter(sessions::token.eq(session)) + .first::<Session>(&*db); + + if session.is_err() { + return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); + } + + let session = session.unwrap(); + + let user = users::table.find(session.user_id).first(&*db); + + match user { + Ok(user) => Outcome::Success(user), + Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), + } + } +} + +#[derive(Deserialize)] +pub struct Register { + email: String, + password: String, +} + +#[post("/register", data = "<register>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse { + let hashed = hash_str(register.password.as_str()); + + let new_user = NewUser { + email: register.email.as_str(), + password: hashed.as_str(), + }; + + let user = diesel::insert_into(users::table) + .values(&new_user) + .get_result(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "status": "error", + "message": "failed to create user - is the email already in use?", + }), + }; + } + + let user: User = user.unwrap(); + let token = Uuid::new_v4().to_simple().to_string(); + + let new_session = NewSession { + user_id: user.id, + token: token.as_str(), + }; + + match diesel::insert_into(sessions::table) + .values(&new_session) + .execute(&*conn) + { + Ok(_) => ApiResponse { + status: Status::Ok, + json: json!({"status": "ok", "message": "user created!", "session": token}), + }, + Err(_) => ApiResponse { + status: Status::BadRequest, + json: json!({"status": "error", "message": "failed to create user"}), + }, + } +} + +#[derive(Deserialize)] +pub struct Login { + email: String, + password: String, +} + +#[post("/login", data = "<login>")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse { + let user = users::table + .filter(users::email.eq(login.email.as_str())) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::NotFound, + json: json!({"status": "error", "message": "user not found"}), + }; + } + + let user: User = user.unwrap(); + + let session = sessions::table + .filter(sessions::user_id.eq(user.id)) + .first(&*conn); + + // a session should exist... + if session.is_err() { + return ApiResponse { + status: Status::InternalServerError, + json: json!({"status": "error", "message": "something went wrong"}), + }; + } + + let verified = verify_str(user.password.as_str(), login.password.as_str()); + + if !verified { + return ApiResponse { + status: Status::NotFound, + json: json!({"status": "error", "message": "user not found"}), + }; + } + + let session: Session = session.unwrap(); + + ApiResponse { + status: Status::Ok, + json: json!({"status": "ok", "token": session.token}), + } +} diff --git a/src/remote/database.rs b/src/remote/database.rs new file mode 100644 index 00000000..4f386def --- /dev/null +++ b/src/remote/database.rs @@ -0,0 +1,14 @@ +use diesel::pg::PgConnection; +use diesel::prelude::*; + +use crate::settings::Settings; + +#[database("atuin")] +pub struct AtuinDbConn(diesel::PgConnection); + +// TODO: connection pooling +pub fn establish_connection(settings: &Settings) -> PgConnection { + let database_url = &settings.remote.db.url; + PgConnection::establish(database_url) + .unwrap_or_else(|_| panic!("Error connecting to {}", database_url)) +} diff --git a/src/remote/mod.rs b/src/remote/mod.rs index 74f47ad3..7147b88e 100644 --- a/src/remote/mod.rs +++ b/src/remote/mod.rs @@ -1 +1,5 @@ +pub mod auth; +pub mod database; +pub mod models; pub mod server; +pub mod views; diff --git a/src/remote/models.rs b/src/remote/models.rs new file mode 100644 index 00000000..058b2f0b --- /dev/null +++ b/src/remote/models.rs @@ -0,0 +1,56 @@ +use chrono::naive::NaiveDateTime; + +use crate::schema::{history, sessions, users}; + +#[derive(Identifiable, Queryable, Associations)] +#[table_name = "history"] +#[belongs_to(User)] +pub struct History { + pub id: i64, + pub client_id: String, + pub user_id: i64, + pub mac: String, + pub timestamp: NaiveDateTime, + + pub data: String, +} + +#[derive(Identifiable, Queryable, Associations)] +pub struct User { + pub id: i64, + pub email: String, + pub password: String, +} + +#[derive(Queryable, Identifiable, Associations)] +#[belongs_to(User)] +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, +} + +#[derive(Insertable)] +#[table_name = "history"] +pub struct NewHistory<'a> { + pub client_id: &'a str, + pub user_id: i64, + pub mac: &'a str, + pub timestamp: NaiveDateTime, + + pub data: &'a str, +} + +#[derive(Insertable)] +#[table_name = "users"] +pub struct NewUser<'a> { + pub email: &'a str, + pub password: &'a str, +} + +#[derive(Insertable)] +#[table_name = "sessions"] +pub struct NewSession<'a> { + pub user_id: i64, + pub token: &'a str, +} diff --git a/src/remote/server.rs b/src/remote/server.rs index bc1dc2bd..4409f646 100644 --- a/src/remote/server.rs +++ b/src/remote/server.rs @@ -1,8 +1,42 @@ -#[get("/")] -const fn index() -> &'static str { - "Hello, world!" -} +use rocket::config::{Config, Environment, LoggingLevel, Value}; + +use std::collections::HashMap; + +use crate::remote::database::establish_connection; +use crate::settings::Settings; + +use super::database::AtuinDbConn; + +// a bunch of these imports are generated by macros, it's easier to wildcard +#[allow(clippy::clippy::wildcard_imports)] +use super::views::*; + +#[allow(clippy::clippy::wildcard_imports)] +use super::auth::*; + +embed_migrations!("migrations"); + +pub fn launch(settings: &Settings) { + let mut database_config = HashMap::new(); + let mut databases = HashMap::new(); + + database_config.insert("url", Value::from(settings.remote.db.url.clone())); + databases.insert("atuin", Value::from(database_config)); + + let connection = establish_connection(settings); + embedded_migrations::run(&connection).expect("failed to run migrations"); + + let config = Config::build(Environment::Production) + .address("0.0.0.0") + .log_level(LoggingLevel::Normal) + .port(8080) + .extra("databases", databases) + .finalize() + .unwrap(); -pub fn launch() { - rocket::ignite().mount("/", routes![index]).launch(); + let app = rocket::custom(config); + app.mount("/", routes![index, register, add_history, login]) + .attach(AtuinDbConn::fairing()) + .register(catchers![internal_error, bad_request]) + .launch(); } diff --git a/src/remote/views.rs b/src/remote/views.rs new file mode 100644 index 00000000..2af3f369 --- /dev/null +++ b/src/remote/views.rs @@ -0,0 +1,89 @@ +use self::diesel::prelude::*; +use rocket::http::{ContentType, Status}; +use rocket::request::Request; +use rocket::response; +use rocket::response::{Responder, Response}; +use rocket_contrib::databases::diesel; +use rocket_contrib::json::{Json, JsonValue}; + +use super::database::AtuinDbConn; +use super::models::{NewHistory, User}; +use crate::schema::history; + +#[derive(Debug)] +pub struct ApiResponse { + pub json: JsonValue, + pub status: Status, +} + +impl<'r> Responder<'r> for ApiResponse { + fn respond_to(self, req: &Request) -> response::Result<'r> { + Response::build_from(self.json.respond_to(req).unwrap()) + .status(self.status) + .header(ContentType::JSON) + .ok() + } +} + +#[get("/")] +pub const fn index() -> &'static str { + "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" +} + +#[catch(500)] +pub fn internal_error(_req: &Request) -> ApiResponse { + ApiResponse { + status: Status::InternalServerError, + json: json!({"status": "error", "message": "an internal server error has occured"}), + } +} + +#[catch(400)] +pub fn bad_request(_req: &Request) -> ApiResponse { + ApiResponse { + status: Status::InternalServerError, + json: json!({"status": "error", "message": "bad request. don't do that."}), + } +} + +#[derive(Deserialize)] +pub struct AddHistory { + id: String, + timestamp: i64, + data: String, + mac: String, +} + +#[post("/history", data = "<add_history>")] +#[allow( + clippy::clippy::cast_sign_loss, + clippy::cast_possible_truncation, + clippy::clippy::needless_pass_by_value +)] +pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse { + let secs: i64 = add_history.timestamp / 1_000_000_000; + let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32; + let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs); + + let new_history = NewHistory { + client_id: add_history.id.as_str(), + user_id: user.id, + mac: add_history.mac.as_str(), + timestamp: datetime, + data: add_history.data.as_str(), + }; + + match diesel::insert_into(history::table) + .values(&new_history) + .execute(&*conn) + { + Ok(_) => ApiResponse { + status: Status::Ok, + json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}), + }, + Err(_) => ApiResponse { + status: Status::BadRequest, + json: json!({"status": "error", "message": "failed to add history"}), + }, + } +} diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 00000000..efa9ddcc --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,28 @@ +table! { + history (id) { + id -> Int8, + client_id -> Text, + user_id -> Int8, + mac -> Varchar, + timestamp -> Timestamp, + data -> Varchar, + } +} + +table! { + sessions (id) { + id -> Int8, + user_id -> Int8, + token -> Varchar, + } +} + +table! { + users (id) { + id -> Int8, + email -> Varchar, + password -> Varchar, + } +} + +allow_tables_to_appear_in_same_query!(history, sessions, users,); diff --git a/src/settings.rs b/src/settings.rs index a4c9f8da..6f29afd2 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -11,6 +11,11 @@ pub struct LocalDatabase { } #[derive(Debug, Deserialize)] +pub struct RemoteDatabase { + pub url: String, +} + +#[derive(Debug, Deserialize)] pub struct Local { pub server_address: String, pub dialect: String, @@ -18,8 +23,14 @@ pub struct Local { } #[derive(Debug, Deserialize)] +pub struct Remote { + pub db: RemoteDatabase, +} + +#[derive(Debug, Deserialize)] pub struct Settings { pub local: Local, + pub remote: Remote, } impl Settings { @@ -49,6 +60,8 @@ impl Settings { s.set_default("local.dialect", "us")?; s.set_default("local.db.path", db_path.to_str())?; + s.set_default("remote.db.url", "please set a postgres url")?; + if config_file.exists() { s.merge(File::with_name(config_file.to_str().unwrap()))?; } |