From 34888827f8a06de835cbe5833a06914f28cce514 Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Tue, 20 Apr 2021 17:07:11 +0100 Subject: Switch to Warp + SQLx, use async, switch to Rust stable (#36) * Switch to warp + sql, use async and stable rust * Update CI to use stable --- src/api.rs | 42 +++++++- src/command/history.rs | 8 +- src/command/login.rs | 7 +- src/command/mod.rs | 8 +- src/command/search.rs | 110 ++++++++++++++------ src/command/server.rs | 6 +- src/command/sync.rs | 4 +- src/local/api_client.rs | 87 ++++++++-------- src/local/database.rs | 8 +- src/local/import.rs | 7 +- src/local/sync.rs | 36 ++++--- src/main.rs | 43 ++++---- src/remote/auth.rs | 220 ---------------------------------------- src/remote/database.rs | 22 ---- src/remote/mod.rs | 5 - src/remote/models.rs | 60 ----------- src/remote/server.rs | 61 ----------- src/remote/views.rs | 185 ---------------------------------- src/schema.rs | 30 ------ src/server/auth.rs | 222 +++++++++++++++++++++++++++++++++++++++++ src/server/database.rs | 202 +++++++++++++++++++++++++++++++++++++ src/server/handlers/history.rs | 89 +++++++++++++++++ src/server/handlers/mod.rs | 6 ++ src/server/handlers/user.rs | 140 ++++++++++++++++++++++++++ src/server/mod.rs | 23 +++++ src/server/models.rs | 49 +++++++++ src/server/router.rs | 121 ++++++++++++++++++++++ src/settings.rs | 2 +- src/shell/atuin.zsh | 1 + 29 files changed, 1085 insertions(+), 719 deletions(-) delete mode 100644 src/remote/auth.rs delete mode 100644 src/remote/database.rs delete mode 100644 src/remote/mod.rs delete mode 100644 src/remote/models.rs delete mode 100644 src/remote/server.rs delete mode 100644 src/remote/views.rs delete mode 100644 src/schema.rs create mode 100644 src/server/auth.rs create mode 100644 src/server/database.rs create mode 100644 src/server/handlers/history.rs create mode 100644 src/server/handlers/mod.rs create mode 100644 src/server/handlers/user.rs create mode 100644 src/server/mod.rs create mode 100644 src/server/models.rs create mode 100644 src/server/router.rs (limited to 'src') diff --git a/src/api.rs b/src/api.rs index 90977404..82ee6604 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,8 +1,9 @@ use chrono::Utc; -// This is shared between the client and the server, and has the data structures -// representing the requests/responses for each method. -// TODO: Properly define responses rather than using json! +#[derive(Debug, Serialize, Deserialize)] +pub struct UserResponse { + pub username: String, +} #[derive(Debug, Serialize, Deserialize)] pub struct RegisterRequest { @@ -11,12 +12,22 @@ pub struct RegisterRequest { pub password: String, } +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterResponse { + pub session: String, +} + #[derive(Debug, Serialize, Deserialize)] pub struct LoginRequest { pub username: String, pub password: String, } +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub session: String, +} + #[derive(Debug, Serialize, Deserialize)] pub struct AddHistoryRequest { pub id: String, @@ -31,6 +42,29 @@ pub struct CountResponse { } #[derive(Debug, Serialize, Deserialize)] -pub struct ListHistoryResponse { +pub struct SyncHistoryRequest { + pub sync_ts: chrono::DateTime, + pub history_ts: chrono::DateTime, + pub host: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryResponse { pub history: Vec, } + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse { + pub reason: String, +} + +impl ErrorResponse { + pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply { + warp::reply::with_status( + warp::reply::json(&ErrorResponse { + reason: String::from(reason), + }), + status, + ) + } +} diff --git a/src/command/history.rs b/src/command/history.rs index 3b4a717c..627efae4 100644 --- a/src/command/history.rs +++ b/src/command/history.rs @@ -53,7 +53,7 @@ fn print_list(h: &[History]) { } impl Cmd { - pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> { + pub async fn run(&self, settings: &Settings, db: &mut (impl Database + Send)) -> Result<()> { match self { Self::Start { command: words } => { let command = words.join(" "); @@ -69,6 +69,10 @@ impl Cmd { } Self::End { id, exit } => { + if id.trim() == "" { + return Ok(()); + } + let mut h = db.load(id)?; h.exit = *exit; h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos(); @@ -82,7 +86,7 @@ impl Cmd { } Ok(Fork::Child) => { debug!("running periodic background sync"); - sync::sync(settings, false, db)?; + sync::sync(settings, false, db).await?; } Err(_) => println!("Fork failed"), } diff --git a/src/command/login.rs b/src/command/login.rs index 4f58b77f..636ac0d3 100644 --- a/src/command/login.rs +++ b/src/command/login.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::fs::File; use std::io::prelude::*; -use eyre::Result; +use eyre::{eyre, Result}; use structopt::StructOpt; use crate::settings::Settings; @@ -28,8 +28,13 @@ impl Cmd { let url = format!("{}/login", settings.local.sync_address); let client = reqwest::blocking::Client::new(); + let resp = client.post(url).json(&map).send()?; + if resp.status() != reqwest::StatusCode::OK { + return Err(eyre!("invalid login details")); + } + let session = resp.json::>()?; let session = session["session"].clone(); diff --git a/src/command/mod.rs b/src/command/mod.rs index eeb11a87..cd857e9f 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -63,16 +63,16 @@ pub fn uuid_v4() -> String { } impl AtuinCmd { - pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> { + pub async fn run(self, db: &mut T, settings: &Settings) -> Result<()> { match self { - Self::History(history) => history.run(settings, db), + Self::History(history) => history.run(settings, db).await, Self::Import(import) => import.run(db), - Self::Server(server) => server.run(settings), + Self::Server(server) => server.run(settings).await, Self::Stats(stats) => stats.run(db, settings), Self::Init => init::init(), Self::Search { query } => search::run(&query, db), - Self::Sync { force } => sync::run(settings, force, db), + Self::Sync { force } => sync::run(settings, force, db).await, Self::Login(l) => l.run(settings), Self::Register(r) => register::run( settings, diff --git a/src/command/search.rs b/src/command/search.rs index b9f3987c..d7b477da 100644 --- a/src/command/search.rs +++ b/src/command/search.rs @@ -1,6 +1,8 @@ use eyre::Result; use itertools::Itertools; use std::io::stdout; +use std::time::Duration; + use termion::{event::Key, input::MouseTerminal, raw::IntoRawMode, screen::AlternateScreen}; use tui::{ backend::TermionBackend, @@ -26,6 +28,78 @@ struct State { results_state: ListState, } +#[allow(clippy::clippy::cast_sign_loss)] +impl State { + fn durations(&self) -> Vec { + self.results + .iter() + .map(|h| { + let duration = + Duration::from_millis(std::cmp::max(h.duration, 0) as u64 / 1_000_000); + let duration = humantime::format_duration(duration).to_string(); + let duration: Vec<&str> = duration.split(' ').collect(); + + duration[0].to_string() + }) + .collect() + } + + fn render_results( + &mut self, + f: &mut tui::Frame, + r: tui::layout::Rect, + ) { + let durations = self.durations(); + let max_length = durations + .iter() + .fold(0, |largest, i| std::cmp::max(largest, i.len())); + + let results: Vec = self + .results + .iter() + .enumerate() + .map(|(i, m)| { + let command = m.command.to_string().replace("\n", " ").replace("\t", " "); + + let mut command = Span::raw(command); + + let mut duration = durations[i].clone(); + + while duration.len() < max_length { + duration.push(' '); + } + + let duration = Span::styled( + duration, + Style::default().fg(if m.exit == 0 || m.duration == -1 { + Color::Green + } else { + Color::Red + }), + ); + + if let Some(selected) = self.results_state.selected() { + if selected == i { + command.style = + Style::default().fg(Color::Red).add_modifier(Modifier::BOLD); + } + } + + let spans = Spans::from(vec![duration, Span::raw(" "), command]); + + ListItem::new(spans) + }) + .collect(); + + let results = List::new(results) + .block(Block::default().borders(Borders::ALL).title("History")) + .start_corner(Corner::BottomLeft) + .highlight_symbol(">> "); + + f.render_stateful_widget(results, r, &mut self.results_state); + } +} + fn query_results(app: &mut State, db: &mut impl Database) { let results = match app.input.as_str() { "" => db.list(), @@ -48,7 +122,11 @@ fn key_handler(input: Key, db: &mut impl Database, app: &mut State) -> Option { let i = app.results_state.selected().unwrap_or(0); - return Some(app.results.get(i).unwrap().command.clone()); + return Some( + app.results + .get(i) + .map_or("".to_string(), |h| h.command.clone()), + ); } Key::Char(c) => { app.input.push(c); @@ -163,32 +241,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result { let help = Text::from(Spans::from(help)); let help = Paragraph::new(help); - let input = Paragraph::new(app.input.as_ref()) - .block(Block::default().borders(Borders::ALL).title("Search")); - - let results: Vec = app - .results - .iter() - .enumerate() - .map(|(i, m)| { - let mut content = - Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " ")); - - if let Some(selected) = app.results_state.selected() { - if selected == i { - content.style = - Style::default().fg(Color::Red).add_modifier(Modifier::BOLD); - } - } - - ListItem::new(content) - }) - .collect(); - - let results = List::new(results) - .block(Block::default().borders(Borders::ALL).title("History")) - .start_corner(Corner::BottomLeft) - .highlight_symbol(">> "); + let input = Paragraph::new(app.input.clone()) + .block(Block::default().borders(Borders::ALL).title("Query")); let stats = Paragraph::new(Text::from(Span::raw(format!( "history count: {}", @@ -199,8 +253,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result { f.render_widget(title, top_left_chunks[0]); f.render_widget(help, top_left_chunks[1]); + app.render_results(f, chunks[1]); f.render_widget(stats, top_right_chunks[0]); - f.render_stateful_widget(results, chunks[1], &mut app.results_state); f.render_widget(input, chunks[2]); f.set_cursor( diff --git a/src/command/server.rs b/src/command/server.rs index bf757948..a7835092 100644 --- a/src/command/server.rs +++ b/src/command/server.rs @@ -1,7 +1,7 @@ use eyre::Result; use structopt::StructOpt; -use crate::remote::server; +use crate::server; use crate::settings::Settings; #[derive(StructOpt)] @@ -20,7 +20,7 @@ pub enum Cmd { } impl Cmd { - pub fn run(&self, settings: &Settings) -> Result<()> { + pub async fn run(&self, settings: &Settings) -> Result<()> { match self { Self::Start { host, port } => { let host = host.as_ref().map_or( @@ -29,7 +29,7 @@ impl Cmd { ); let port = port.map_or(settings.server.port, |p| p); - server::launch(settings, host, port) + server::launch(settings, host, port).await } } } diff --git a/src/command/sync.rs b/src/command/sync.rs index facbe578..88217b3c 100644 --- a/src/command/sync.rs +++ b/src/command/sync.rs @@ -4,8 +4,8 @@ use crate::local::database::Database; use crate::local::sync; use crate::settings::Settings; -pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { - sync::sync(settings, force, db)?; +pub async fn run(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> { + sync::sync(settings, force, db).await?; println!( "Sync complete! {} items in database, force: {}", db.history_count()?, diff --git a/src/local/api_client.rs b/src/local/api_client.rs index 434c07ba..1b64a295 100644 --- a/src/local/api_client.rs +++ b/src/local/api_client.rs @@ -1,93 +1,94 @@ use chrono::Utc; use eyre::Result; -use reqwest::header::AUTHORIZATION; +use reqwest::header::{HeaderMap, AUTHORIZATION}; +use reqwest::Url; +use sodiumoxide::crypto::secretbox; -use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse}; -use crate::local::encryption::{decrypt, load_key}; +use crate::api::{AddHistoryRequest, CountResponse, SyncHistoryResponse}; +use crate::local::encryption::decrypt; use crate::local::history::History; -use crate::settings::Settings; use crate::utils::hash_str; pub struct Client<'a> { - settings: &'a Settings, + sync_addr: &'a str, + token: &'a str, + key: secretbox::Key, + client: reqwest::Client, } impl<'a> Client<'a> { - pub const fn new(settings: &'a Settings) -> Self { - Client { settings } + pub fn new(sync_addr: &'a str, token: &'a str, key: secretbox::Key) -> Self { + Client { + sync_addr, + token, + key, + client: reqwest::Client::new(), + } } - pub fn count(&self) -> Result { - let url = format!("{}/sync/count", self.settings.local.sync_address); - let client = reqwest::blocking::Client::new(); + pub async fn count(&self) -> Result { + let url = format!("{}/sync/count", self.sync_addr); + let url = Url::parse(url.as_str())?; + let token = format!("Token {}", self.token); + let token = token.parse()?; - let resp = client - .get(url) - .header( - AUTHORIZATION, - format!("Token {}", self.settings.local.session_token), - ) - .send()?; + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, token); + + let resp = self.client.get(url).headers(headers).send().await?; - let count = resp.json::()?; + let count = resp.json::().await?; Ok(count.count) } - pub fn get_history( + pub async fn get_history( &self, sync_ts: chrono::DateTime, history_ts: chrono::DateTime, host: Option, ) -> Result> { - let key = load_key(self.settings)?; - let host = match host { None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())), Some(h) => h, }; - // this allows for syncing between users on the same machine let url = format!( "{}/sync/history?sync_ts={}&history_ts={}&host={}", - self.settings.local.sync_address, - sync_ts.to_rfc3339(), - history_ts.to_rfc3339(), + self.sync_addr, + urlencoding::encode(sync_ts.to_rfc3339().as_str()), + urlencoding::encode(history_ts.to_rfc3339().as_str()), host, ); - let client = reqwest::blocking::Client::new(); - let resp = client + let resp = self + .client .get(url) - .header( - AUTHORIZATION, - format!("Token {}", self.settings.local.session_token), - ) - .send()?; + .header(AUTHORIZATION, format!("Token {}", self.token)) + .send() + .await?; - let history = resp.json::()?; + let history = resp.json::().await?; let history = history .history .iter() .map(|h| serde_json::from_str(h).expect("invalid base64")) - .map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key")) + .map(|h| decrypt(&h, &self.key).expect("failed to decrypt history! check your key")) .collect(); Ok(history) } - pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { - let client = reqwest::blocking::Client::new(); + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = format!("{}/history", self.sync_addr); + let url = Url::parse(url.as_str())?; - let url = format!("{}/history", self.settings.local.sync_address); - client + self.client .post(url) .json(history) - .header( - AUTHORIZATION, - format!("Token {}", self.settings.local.session_token), - ) - .send()?; + .header(AUTHORIZATION, format!("Token {}", self.token)) + .send() + .await?; Ok(()) } diff --git a/src/local/database.rs b/src/local/database.rs index 977f11cc..abc22bb8 100644 --- a/src/local/database.rs +++ b/src/local/database.rs @@ -215,9 +215,9 @@ impl Database for Sqlite { } fn before(&self, timestamp: chrono::DateTime, count: i64) -> Result> { - let mut stmt = self.conn.prepare( - "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?", - )?; + let mut stmt = self + .conn + .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?; let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { history_from_sqlite_row(None, row) @@ -236,7 +236,7 @@ impl Database for Sqlite { fn prefix_search(&self, query: &str) -> Result> { self.query( - "select * from history where command like ?1 || '%' order by timestamp asc", + "select * from history where command like ?1 || '%' order by timestamp asc limit 1000", &[query], ) } diff --git a/src/local/import.rs b/src/local/import.rs index d0f679c9..3b0b2a69 100644 --- a/src/local/import.rs +++ b/src/local/import.rs @@ -7,6 +7,7 @@ use std::{fs::File, path::Path}; use chrono::prelude::*; use chrono::Utc; use eyre::{eyre, Result}; +use itertools::Itertools; use super::history::History; @@ -42,8 +43,8 @@ impl Zsh { fn parse_extended(line: &str, counter: i64) -> History { let line = line.replacen(": ", "", 2); - let (time, duration) = line.split_once(':').unwrap(); - let (duration, command) = duration.split_once(';').unwrap(); + let (time, duration) = line.splitn(2, ':').collect_tuple().unwrap(); + let (duration, command) = duration.splitn(2, ';').collect_tuple().unwrap(); let time = time .parse::() @@ -60,7 +61,7 @@ fn parse_extended(line: &str, counter: i64) -> History { time, command.trim_end().to_string(), String::from("unknown"), - -1, + 0, // assume 0, we have no way of knowing :( duration, None, None, diff --git a/src/local/sync.rs b/src/local/sync.rs index c22d2f27..e0feb759 100644 --- a/src/local/sync.rs +++ b/src/local/sync.rs @@ -20,12 +20,12 @@ use crate::{api::AddHistoryRequest, utils::hash_str}; // Check if remote has things we don't, and if so, download them. // Returns (num downloaded, total local) -fn sync_download( +async fn sync_download( force: bool, - client: &api_client::Client, - db: &mut impl Database, + client: &api_client::Client<'_>, + db: &mut (impl Database + Send), ) -> Result<(i64, i64)> { - let remote_count = client.count()?; + let remote_count = client.count().await?; let initial_local = db.history_count()?; let mut local_count = initial_local; @@ -41,7 +41,9 @@ fn sync_download( let host = if force { Some(String::from("")) } else { None }; while remote_count > local_count { - let page = client.get_history(last_sync, last_timestamp, host.clone())?; + let page = client + .get_history(last_sync, last_timestamp, host.clone()) + .await?; if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() { break; @@ -71,13 +73,13 @@ fn sync_download( } // Check if we have things remote doesn't, and if so, upload them -fn sync_upload( +async fn sync_upload( settings: &Settings, _force: bool, - client: &api_client::Client, - db: &mut impl Database, + client: &api_client::Client<'_>, + db: &mut (impl Database + Send), ) -> Result<()> { - let initial_remote_count = client.count()?; + let initial_remote_count = client.count().await?; let mut remote_count = initial_remote_count; let local_count = db.history_count()?; @@ -111,21 +113,25 @@ fn sync_upload( } // anything left over outside of the 100 block size - client.post_history(&buffer)?; + client.post_history(&buffer).await?; cursor = buffer.last().unwrap().timestamp; - remote_count = client.count()?; + remote_count = client.count().await?; } Ok(()) } -pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { - let client = api_client::Client::new(settings); +pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> { + let client = api_client::Client::new( + settings.local.sync_address.as_str(), + settings.local.session_token.as_str(), + load_key(settings)?, + ); - sync_upload(settings, force, &client, db)?; + sync_upload(settings, force, &client, db).await?; - let download = sync_download(force, &client, db)?; + let download = sync_download(force, &client, db).await?; debug!("sync downloaded {}", download.0); diff --git a/src/main.rs b/src/main.rs index 94c7366d..0045a943 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,32 +1,19 @@ -#![feature(proc_macro_hygiene)] -#![feature(decl_macro)] #![warn(clippy::pedantic, clippy::nursery)] #![allow(clippy::use_self)] // not 100% reliable use std::path::PathBuf; use eyre::{eyre, Result}; +use fern::colors::{Color, ColoredLevelConfig}; use human_panic::setup_panic; use structopt::{clap::AppSettings, StructOpt}; #[macro_use] extern crate log; -#[macro_use] -extern crate rocket; - #[macro_use] extern crate serde_derive; -#[macro_use] -extern crate diesel; - -#[macro_use] -extern crate diesel_migrations; - -#[macro_use] -extern crate rocket_contrib; - use command::AtuinCmd; use local::database::Sqlite; use settings::Settings; @@ -34,12 +21,10 @@ use settings::Settings; mod api; mod command; mod local; -mod remote; +mod server; mod settings; mod utils; -pub mod schema; - #[derive(StructOpt)] #[structopt( author = "Ellie Huxtable ", @@ -56,7 +41,7 @@ struct Atuin { } impl Atuin { - fn run(self, settings: &Settings) -> Result<()> { + async fn run(self, settings: &Settings) -> Result<()> { let db_path = if let Some(db_path) = self.db { let path = db_path .to_str() @@ -69,26 +54,32 @@ impl Atuin { let mut db = Sqlite::new(db_path)?; - self.atuin.run(&mut db, settings) + self.atuin.run(&mut db, settings).await } } -fn main() -> Result<()> { - setup_panic!(); - let settings = Settings::new()?; +#[tokio::main] +async fn main() -> Result<()> { + let colors = ColoredLevelConfig::new() + .warn(Color::Yellow) + .error(Color::Red); fern::Dispatch::new() - .format(|out, message, record| { + .format(move |out, message, record| { out.finish(format_args!( "{} [{}] {}", - chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"), - record.level(), + chrono::Local::now().to_rfc3339(), + colors.color(record.level()), message )) }) .level(log::LevelFilter::Info) + .level_for("sqlx", log::LevelFilter::Warn) .chain(std::io::stdout()) .apply()?; - Atuin::from_args().run(&settings) + let settings = Settings::new()?; + setup_panic!(); + + Atuin::from_args().run(&settings).await } diff --git a/src/remote/auth.rs b/src/remote/auth.rs deleted file mode 100644 index cf61b077..00000000 --- a/src/remote/auth.rs +++ /dev/null @@ -1,220 +0,0 @@ -use self::diesel::prelude::*; -use eyre::Result; -use rocket::http::Status; -use rocket::request::{self, FromRequest, Outcome, Request}; -use rocket::State; -use rocket_contrib::databases::diesel; -use sodiumoxide::crypto::pwhash::argon2id13; - -use rocket_contrib::json::Json; -use uuid::Uuid; - -use super::models::{NewSession, NewUser, Session, User}; -use super::views::ApiResponse; - -use crate::api::{LoginRequest, RegisterRequest}; -use crate::schema::{sessions, users}; -use crate::settings::Settings; -use crate::utils::hash_secret; - -use super::database::AtuinDbConn; - -#[derive(Debug)] -pub enum KeyError { - Missing, - Invalid, -} - -pub fn verify_str(secret: &str, verify: &str) -> bool { - sodiumoxide::init().unwrap(); - - let mut padded = [0_u8; 128]; - secret.as_bytes().iter().enumerate().for_each(|(i, val)| { - padded[i] = *val; - }); - - match argon2id13::HashedPassword::from_slice(&padded) { - Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), - None => false, - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for User { - type Error = KeyError; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let session: Vec<_> = request.headers().get("authorization").collect(); - - if session.is_empty() { - return Outcome::Failure((Status::BadRequest, KeyError::Missing)); - } else if session.len() > 1 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session: Vec<_> = session[0].split(' ').collect(); - - if session.len() != 2 { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - if session[0] != "Token" { - return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); - } - - let session = session[1]; - - let db = request - .guard::() - .succeeded() - .expect("failed to load database"); - - let session = sessions::table - .filter(sessions::token.eq(session)) - .first::(&*db); - - if session.is_err() { - return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); - } - - let session = session.unwrap(); - - let user = users::table.find(session.user_id).first(&*db); - - match user { - Ok(user) => Outcome::Success(user), - Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), - } - } -} - -#[get("/user/")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { - use crate::schema::users::dsl::{username, users}; - - let user: Result = users - .select(username) - .filter(username.eq(user)) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - json: json!({ - "message": "could not find user", - }), - status: Status::NotFound, - }; - } - - let user = user.unwrap(); - - ApiResponse { - json: json!({ "username": user.as_str() }), - status: Status::Ok, - } -} - -#[post("/register", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn register( - conn: AtuinDbConn, - register: Json, - settings: State, -) -> ApiResponse { - if !settings.server.open_registration { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "registrations are not open" - }), - }; - } - - let hashed = hash_secret(register.password.as_str()); - - let new_user = NewUser { - email: register.email.as_str(), - username: register.username.as_str(), - password: hashed.as_str(), - }; - - let user = diesel::insert_into(users::table) - .values(&new_user) - .get_result(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::BadRequest, - json: json!({ - "message": "failed to create user - username or email in use?", - }), - }; - } - - let user: User = user.unwrap(); - let token = Uuid::new_v4().to_simple().to_string(); - - let new_session = NewSession { - user_id: user.id, - token: token.as_str(), - }; - - match diesel::insert_into(sessions::table) - .values(&new_session) - .execute(&*conn) - { - Ok(_) => ApiResponse { - status: Status::Ok, - json: json!({"message": "user created!", "session": token}), - }, - Err(_) => ApiResponse { - status: Status::BadRequest, - json: json!({ "message": "failed to create user"}), - }, - } -} - -#[post("/login", data = "")] -#[allow(clippy::clippy::needless_pass_by_value)] -pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { - let user = users::table - .filter(users::username.eq(login.username.as_str())) - .first(&*conn); - - if user.is_err() { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let user: User = user.unwrap(); - - let session = sessions::table - .filter(sessions::user_id.eq(user.id)) - .first(&*conn); - - // a session should exist... - if session.is_err() { - return ApiResponse { - status: Status::InternalServerError, - json: json!({"message": "something went wrong"}), - }; - } - - let verified = verify_str(user.password.as_str(), login.password.as_str()); - - if !verified { - return ApiResponse { - status: Status::NotFound, - json: json!({"message": "user not found"}), - }; - } - - let session: Session = session.unwrap(); - - ApiResponse { - status: Status::Ok, - json: json!({"session": session.token}), - } -} diff --git a/src/remote/database.rs b/src/remote/database.rs deleted file mode 100644 index 03973ca1..00000000 --- a/src/remote/database.rs +++ /dev/null @@ -1,22 +0,0 @@ -use diesel::pg::PgConnection; -use diesel::prelude::*; -use eyre::{eyre, Result}; - -use crate::settings::Settings; - -#[database("atuin")] -pub struct AtuinDbConn(diesel::PgConnection); - -// TODO: connection pooling -pub fn establish_connection(settings: &Settings) -> Result { - if settings.server.db_uri == "default_uri" { - Err(eyre!( - "Please configure your database! Set db_uri in config.toml" - )) - } else { - let database_url = &settings.server.db_uri; - let conn = PgConnection::establish(database_url)?; - - Ok(conn) - } -} diff --git a/src/remote/mod.rs b/src/remote/mod.rs deleted file mode 100644 index 7147b88e..00000000 --- a/src/remote/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod auth; -pub mod database; -pub mod models; -pub mod server; -pub mod views; diff --git a/src/remote/models.rs b/src/remote/models.rs deleted file mode 100644 index 7f6f7766..00000000 --- a/src/remote/models.rs +++ /dev/null @@ -1,60 +0,0 @@ -use chrono::prelude::*; - -use crate::schema::{history, sessions, users}; - -#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)] -#[table_name = "history"] -#[belongs_to(User)] -pub struct History { - pub id: i64, - pub client_id: String, // a client generated ID - pub user_id: i64, - pub hostname: String, - pub timestamp: NaiveDateTime, - - pub data: String, - - pub created_at: NaiveDateTime, -} - -#[derive(Identifiable, Queryable, Associations)] -pub struct User { - pub id: i64, - pub username: String, - pub email: String, - pub password: String, -} - -#[derive(Queryable, Identifiable, Associations)] -#[belongs_to(User)] -pub struct Session { - pub id: i64, - pub user_id: i64, - pub token: String, -} - -#[derive(Insertable)] -#[table_name = "history"] -pub struct NewHistory<'a> { - pub client_id: &'a str, - pub user_id: i64, - pub hostname: String, - pub timestamp: chrono::NaiveDateTime, - - pub data: &'a str, -} - -#[derive(Insertable)] -#[table_name = "users"] -pub struct NewUser<'a> { - pub username: &'a str, - pub email: &'a str, - pub password: &'a str, -} - -#[derive(Insertable)] -#[table_name = "sessions"] -pub struct NewSession<'a> { - pub user_id: i64, - pub token: &'a str, -} diff --git a/src/remote/server.rs b/src/remote/server.rs deleted file mode 100644 index ee481ca4..00000000 --- a/src/remote/server.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::collections::HashMap; - -use crate::remote::database::establish_connection; -use crate::settings::Settings; - -use super::database::AtuinDbConn; - -use eyre::Result; -use rocket::config::{Config, Environment, LoggingLevel, Value}; - -// a bunch of these imports are generated by macros, it's easier to wildcard -#[allow(clippy::clippy::wildcard_imports)] -use super::views::*; - -#[allow(clippy::clippy::wildcard_imports)] -use super::auth::*; - -embed_migrations!("migrations"); - -pub fn launch(settings: &Settings, host: String, port: u16) -> Result<()> { - let settings: Settings = settings.clone(); // clone so rocket can manage it - - let mut database_config = HashMap::new(); - let mut databases = HashMap::new(); - - database_config.insert("url", Value::from(settings.server.db_uri.clone())); - databases.insert("atuin", Value::from(database_config)); - - let connection = establish_connection(&settings)?; - - embedded_migrations::run(&connection).expect("failed to run migrations"); - - let config = Config::build(Environment::Production) - .address(host) - .log_level(LoggingLevel::Normal) - .port(port) - .extra("databases", databases) - .finalize() - .unwrap(); - - let app = rocket::custom(config); - - app.mount( - "/", - routes![ - index, - register, - add_history, - login, - get_user, - sync_count, - sync_list - ], - ) - .manage(settings) - .attach(AtuinDbConn::fairing()) - .register(catchers![internal_error, bad_request]) - .launch(); - - Ok(()) -} diff --git a/src/remote/views.rs b/src/remote/views.rs deleted file mode 100644 index 08dff13e..00000000 --- a/src/remote/views.rs +++ /dev/null @@ -1,185 +0,0 @@ -use chrono::Utc; -use rocket::http::uri::Uri; -use rocket::http::RawStr; -use rocket::http::{ContentType, Status}; -use rocket::request::FromFormValue; -use rocket::request::Request; -use rocket::response; -use rocket::response::{Responder, Response}; -use rocket_contrib::databases::diesel; -use rocket_contrib::json::{Json, JsonValue}; - -use self::diesel::prelude::*; - -use crate::api::AddHistoryRequest; -use crate::schema::history; -use crate::settings::HISTORY_PAGE_SIZE; - -use super::database::AtuinDbConn; -use super::models::{History, NewHistory, User}; - -#[derive(Debug)] -pub struct ApiResponse { - pub json: JsonValue, - pub status: Status, -} - -impl<'r> Responder<'r> for ApiResponse { - fn respond_to(self, req: &Request) -> response::Result<'r> { - Response::build_from(self.json.respond_to(req).unwrap()) - .status(self.status) - .header(ContentType::JSON) - .ok() - } -} - -#[get("/")] -pub const fn index() -> &'static str { - "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" -} - -#[catch(500)] -pub fn internal_error(_req: &Request) -> ApiResponse { - ApiResponse { - status: Status::InternalServerError, - json: json!({"status": "error", "message": "an internal server error has occured"}), - } -} - -#[catch(400)] -pub fn bad_request(_req: &Request) -> ApiResponse { - ApiResponse { - status: Status::InternalServerError, - json: json!({"status": "error", "message": "bad request. don't do that."}), - } -} - -#[post("/history", data = "")] -#[allow( - clippy::clippy::cast_sign_loss, - clippy::cast_possible_truncation, - clippy::clippy::needless_pass_by_value -)] -pub fn add_history( - conn: AtuinDbConn, - user: User, - add_history: Json>, -) -> ApiResponse { - let new_history: Vec = add_history - .iter() - .map(|h| NewHistory { - client_id: h.id.as_str(), - hostname: h.hostname.to_string(), - user_id: user.id, - timestamp: h.timestamp.naive_utc(), - data: h.data.as_str(), - }) - .collect(); - - match diesel::insert_into(history::table) - .values(&new_history) - .on_conflict_do_nothing() - .execute(&*conn) - { - Ok(_) => ApiResponse { - status: Status::Ok, - json: json!({"status": "ok", "message": "history added"}), - }, - Err(_) => ApiResponse { - status: Status::BadRequest, - json: json!({"status": "error", "message": "failed to add history"}), - }, - } -} - -#[get("/sync/count")] -#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] -pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse { - use crate::schema::history::dsl::*; - - // we need to return the number of history items we have for this user - // in the future I'd like to use something like a merkel tree to calculate - // which day specifically needs syncing - let count = history - .filter(user_id.eq(user.id)) - .count() - .first::(&*conn); - - if count.is_err() { - error!("failed to count: {}", count.err().unwrap()); - - return ApiResponse { - json: json!({"message": "internal server error"}), - status: Status::InternalServerError, - }; - } - - ApiResponse { - status: Status::Ok, - json: json!({"count": count.ok()}), - } -} - -pub struct UtcDateTime(chrono::DateTime); - -impl<'v> FromFormValue<'v> for UtcDateTime { - type Error = &'v RawStr; - - fn from_form_value(form_value: &'v RawStr) -> Result { - let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?; - let time = time.to_string(); - - match chrono::DateTime::parse_from_rfc3339(time.as_str()) { - Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))), - Err(e) => { - error!("failed to parse time {}, got: {}", time, e); - Err(form_value) - } - } - } -} - -// Request a list of all history items added to the DB after a given timestamp. -// Provide the current hostname, so that we don't send the client data that -// originated from them -#[get("/sync/history?&&")] -#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)] -pub fn sync_list( - conn: AtuinDbConn, - user: User, - sync_ts: UtcDateTime, - history_ts: UtcDateTime, - host: String, -) -> ApiResponse { - use crate::schema::history::dsl::*; - - // we need to return the number of history items we have for this user - // in the future I'd like to use something like a merkel tree to calculate - // which day specifically needs syncing - // TODO: Allow for configuring the page size, both from params, and setting - // the max in config. 100 is fine for now. - let h = history - .filter(user_id.eq(user.id)) - .filter(hostname.ne(host)) - .filter(created_at.ge(sync_ts.0.naive_utc())) - .filter(timestamp.ge(history_ts.0.naive_utc())) - .order(timestamp.asc()) - .limit(HISTORY_PAGE_SIZE) - .load::(&*conn); - - if let Err(e) = h { - error!("failed to load history: {}", e); - - return ApiResponse { - json: json!({"message": "internal server error"}), - status: Status::InternalServerError, - }; - } - - let user_data: Vec = h.unwrap().iter().map(|i| i.data.to_string()).collect(); - - ApiResponse { - status: Status::Ok, - json: json!({ "history": user_data }), - } -} diff --git a/src/schema.rs b/src/schema.rs deleted file mode 100644 index 84bf5bab..00000000 --- a/src/schema.rs +++ /dev/null @@ -1,30 +0,0 @@ -table! { - history (id) { - id -> Int8, - client_id -> Text, - user_id -> Int8, - hostname -> Text, - timestamp -> Timestamp, - data -> Varchar, - created_at -> Timestamp, - } -} - -table! { - sessions (id) { - id -> Int8, - user_id -> Int8, - token -> Varchar, - } -} - -table! { - users (id) { - id -> Int8, - username -> Varchar, - email -> Varchar, - password -> Varchar, - } -} - -allow_tables_to_appear_in_same_query!(history, sessions, users,); diff --git a/src/server/auth.rs b/src/server/auth.rs new file mode 100644 index 00000000..52a73108 --- /dev/null +++ b/src/server/auth.rs @@ -0,0 +1,222 @@ +/* +use self::diesel::prelude::*; +use eyre::Result; +use rocket::http::Status; +use rocket::request::{self, FromRequest, Outcome, Request}; +use rocket::State; +use rocket_contrib::databases::diesel; +use sodiumoxide::crypto::pwhash::argon2id13; + +use rocket_contrib::json::Json; +use uuid::Uuid; + +use super::models::{NewSession, NewUser, Session, User}; +use super::views::ApiResponse; + +use crate::api::{LoginRequest, RegisterRequest}; +use crate::schema::{sessions, users}; +use crate::settings::Settings; +use crate::utils::hash_secret; + +use super::database::AtuinDbConn; + +#[derive(Debug)] +pub enum KeyError { + Missing, + Invalid, +} + +pub fn verify_str(secret: &str, verify: &str) -> bool { + sodiumoxide::init().unwrap(); + + let mut padded = [0_u8; 128]; + secret.as_bytes().iter().enumerate().for_each(|(i, val)| { + padded[i] = *val; + }); + + match argon2id13::HashedPassword::from_slice(&padded) { + Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), + None => false, + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for User { + type Error = KeyError; + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + let session: Vec<_> = request.headers().get("authorization").collect(); + + if session.is_empty() { + return Outcome::Failure((Status::BadRequest, KeyError::Missing)); + } else if session.len() > 1 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session: Vec<_> = session[0].split(' ').collect(); + + if session.len() != 2 { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + if session[0] != "Token" { + return Outcome::Failure((Status::BadRequest, KeyError::Invalid)); + } + + let session = session[1]; + + let db = request + .guard::() + .succeeded() + .expect("failed to load database"); + + let session = sessions::table + .filter(sessions::token.eq(session)) + .first::(&*db); + + if session.is_err() { + return Outcome::Failure((Status::Unauthorized, KeyError::Invalid)); + } + + let session = session.unwrap(); + + let user = users::table.find(session.user_id).first(&*db); + + match user { + Ok(user) => Outcome::Success(user), + Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)), + } + } +} + +#[get("/user/")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn get_user(user: String, conn: AtuinDbConn) -> ApiResponse { + use crate::schema::users::dsl::{username, users}; + + let user: Result = users + .select(username) + .filter(username.eq(user)) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + json: json!({ + "message": "could not find user", + }), + status: Status::NotFound, + }; + } + + let user = user.unwrap(); + + ApiResponse { + json: json!({ "username": user.as_str() }), + status: Status::Ok, + } +} + +#[post("/register", data = "")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn register( + conn: AtuinDbConn, + register: Json, + settings: State, +) -> ApiResponse { + if !settings.server.open_registration { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "registrations are not open" + }), + }; + } + + let hashed = hash_secret(register.password.as_str()); + + let new_user = NewUser { + email: register.email.as_str(), + username: register.username.as_str(), + password: hashed.as_str(), + }; + + let user = diesel::insert_into(users::table) + .values(&new_user) + .get_result(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::BadRequest, + json: json!({ + "message": "failed to create user - username or email in use?", + }), + }; + } + + let user: User = user.unwrap(); + let token = Uuid::new_v4().to_simple().to_string(); + + let new_session = NewSession { + user_id: user.id, + token: token.as_str(), + }; + + match diesel::insert_into(sessions::table) + .values(&new_session) + .execute(&*conn) + { + Ok(_) => ApiResponse { + status: Status::Ok, + json: json!({"message": "user created!", "session": token}), + }, + Err(_) => ApiResponse { + status: Status::BadRequest, + json: json!({ "message": "failed to create user"}), + }, + } +} + +#[post("/login", data = "")] +#[allow(clippy::clippy::needless_pass_by_value)] +pub fn login(conn: AtuinDbConn, login: Json) -> ApiResponse { + let user = users::table + .filter(users::username.eq(login.username.as_str())) + .first(&*conn); + + if user.is_err() { + return ApiResponse { + status: Status::NotFound, + json: json!({"message": "user not found"}), + }; + } + + let user: User = user.unwrap(); + + let session = sessions::table + .filter(sessions::user_id.eq(user.id)) + .first(&*conn); + + // a session should exist... + if session.is_err() { + return ApiResponse { + status: Status::InternalServerError, + json: json!({"message": "something went wrong"}), + }; + } + + let verified = verify_str(user.password.as_str(), login.password.as_str()); + + if !verified { + return ApiResponse { + status: Status::NotFound, + json: json!({"message": "user not found"}), + }; + } + + let session: Session = session.unwrap(); + + ApiResponse { + status: Status::Ok, + json: json!({"session": session.token}), + } +} +*/ diff --git a/src/server/database.rs b/src/server/database.rs new file mode 100644 index 00000000..5945baaf --- /dev/null +++ b/src/server/database.rs @@ -0,0 +1,202 @@ +use async_trait::async_trait; + +use eyre::{eyre, Result}; +use sqlx::postgres::PgPoolOptions; + +use crate::settings::HISTORY_PAGE_SIZE; + +use super::models::{History, NewHistory, NewSession, NewUser, Session, User}; + +#[async_trait] +pub trait Database { + async fn get_session(&self, token: &str) -> Result; + async fn get_session_user(&self, token: &str) -> Result; + async fn add_session(&self, session: &NewSession) -> Result<()>; + + async fn get_user(&self, username: String) -> Result; + async fn get_user_session(&self, u: &User) -> Result; + async fn add_user(&self, user: NewUser) -> Result; + + async fn count_history(&self, user: &User) -> Result; + async fn list_history( + &self, + user: &User, + created_since: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: String, + ) -> Result>; + async fn add_history(&self, history: &[NewHistory]) -> Result<()>; +} + +#[derive(Clone)] +pub struct Postgres { + pool: sqlx::Pool, +} + +impl Postgres { + pub async fn new(uri: &str) -> Result { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(uri) + .await?; + + Ok(Self { pool }) + } +} + +#[async_trait] +impl Database for Postgres { + async fn get_session(&self, token: &str) -> Result { + let res: Option = + sqlx::query_as::<_, Session>("select * from sessions where token = $1") + .bind(token) + .fetch_optional(&self.pool) + .await?; + + if let Some(s) = res { + Ok(s) + } else { + Err(eyre!("could not find session")) + } + } + + async fn get_user(&self, username: String) -> Result { + let res: Option = + sqlx::query_as::<_, User>("select * from users where username = $1") + .bind(username) + .fetch_optional(&self.pool) + .await?; + + if let Some(u) = res { + Ok(u) + } else { + Err(eyre!("could not find user")) + } + } + + async fn get_session_user(&self, token: &str) -> Result { + let res: Option = sqlx::query_as::<_, User>( + "select * from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_optional(&self.pool) + .await?; + + if let Some(u) = res { + Ok(u) + } else { + Err(eyre!("could not find user")) + } + } + + async fn count_history(&self, user: &User) -> Result { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + async fn list_history( + &self, + user: &User, + created_since: chrono::NaiveDateTime, + since: chrono::NaiveDateTime, + host: String, + ) -> Result> { + let res = sqlx::query_as::<_, History>( + "select * from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(created_since) + .bind(since) + .bind(HISTORY_PAGE_SIZE) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn add_history(&self, history: &[NewHistory]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(i.client_id) + .bind(i.user_id) + .bind(i.hostname) + .bind(i.timestamp) + .bind(i.data) + .execute(&mut tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn add_user(&self, user: NewUser) -> Result { + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(user.username.as_str()) + .bind(user.email.as_str()) + .bind(user.password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + async fn add_session(&self, session: &NewSession) -> Result<()> { + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(session.token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn get_user_session(&self, u: &User) -> Result { + let res: Option = + sqlx::query_as::<_, Session>("select * from sessions where user_id = $1") + .bind(u.id) + .fetch_optional(&self.pool) + .await?; + + if let Some(s) = res { + Ok(s) + } else { + Err(eyre!("could not find session")) + } + } +} diff --git a/src/server/handlers/history.rs b/src/server/handlers/history.rs new file mode 100644 index 00000000..4fd6f03f --- /dev/null +++ b/src/server/handlers/history.rs @@ -0,0 +1,89 @@ +use std::convert::Infallible; + +use warp::{http::StatusCode, reply::json}; + +use crate::api::{ + AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse, +}; +use crate::server::database::Database; +use crate::server::models::{NewHistory, User}; + +pub async fn count( + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + db.count_history(&user).await.map_or( + Ok(Box::new(ErrorResponse::reply( + "failed to query history count", + StatusCode::INTERNAL_SERVER_ERROR, + ))), + |count| Ok(Box::new(json(&CountResponse { count }))), + ) +} + +pub async fn list( + req: SyncHistoryRequest, + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + let history = db + .list_history( + &user, + req.sync_ts.naive_utc(), + req.history_ts.naive_utc(), + req.host, + ) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + let resp = + ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR); + let resp = Box::new(resp); + return Ok(resp); + } + + let history: Vec = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + Ok(Box::new(json(&SyncHistoryResponse { history }))) +} + +pub async fn add( + req: Vec, + user: User, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + debug!("request to add {} history items", req.len()); + + let history: Vec = req + .iter() + .map(|h| NewHistory { + client_id: h.id.as_str(), + user_id: user.id, + hostname: h.hostname.as_str(), + timestamp: h.timestamp.naive_utc(), + data: h.data.as_str(), + }) + .collect(); + + if let Err(e) = db.add_history(&history).await { + error!("failed to add history: {}", e); + + return Ok(Box::new(ErrorResponse::reply( + "failed to add history", + StatusCode::INTERNAL_SERVER_ERROR, + ))); + }; + + Ok(Box::new(warp::reply())) +} diff --git a/src/server/handlers/mod.rs b/src/server/handlers/mod.rs new file mode 100644 index 00000000..3c20538c --- /dev/null +++ b/src/server/handlers/mod.rs @@ -0,0 +1,6 @@ +pub mod history; +pub mod user; + +pub const fn index() -> &'static str { + "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett" +} diff --git a/src/server/handlers/user.rs b/src/server/handlers/user.rs new file mode 100644 index 00000000..782d7dbd --- /dev/null +++ b/src/server/handlers/user.rs @@ -0,0 +1,140 @@ +use std::convert::Infallible; + +use sodiumoxide::crypto::pwhash::argon2id13; +use uuid::Uuid; +use warp::http::StatusCode; +use warp::reply::json; + +use crate::api::{ + ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse, +}; +use crate::server::database::Database; +use crate::server::models::{NewSession, NewUser}; +use crate::settings::Settings; +use crate::utils::hash_secret; + +pub fn verify_str(secret: &str, verify: &str) -> bool { + sodiumoxide::init().unwrap(); + + let mut padded = [0_u8; 128]; + secret.as_bytes().iter().enumerate().for_each(|(i, val)| { + padded[i] = *val; + }); + + match argon2id13::HashedPassword::from_slice(&padded) { + Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()), + None => false, + } +} + +pub async fn get( + username: String, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + let user = match db.get_user(username).await { + Ok(user) => user, + Err(e) => { + debug!("user not found: {}", e); + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + Ok(Box::new(warp::reply::json(&UserResponse { + username: user.username, + }))) +} + +pub async fn register( + register: RegisterRequest, + settings: Settings, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + if !settings.server.open_registration { + return Ok(Box::new(ErrorResponse::reply( + "this server is not open for registrations", + StatusCode::BAD_REQUEST, + ))); + } + + let hashed = hash_secret(register.password.as_str()); + + let new_user = NewUser { + email: register.email, + username: register.username, + password: hashed, + }; + + let user_id = match db.add_user(new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Ok(Box::new(ErrorResponse::reply( + "failed to add user", + StatusCode::BAD_REQUEST, + ))); + } + }; + + let token = Uuid::new_v4().to_simple().to_string(); + + let new_session = NewSession { + user_id, + token: token.as_str(), + }; + + match db.add_session(&new_session).await { + Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))), + Err(e) => { + error!("failed to add session: {}", e); + Ok(Box::new(ErrorResponse::reply( + "failed to register user", + StatusCode::BAD_REQUEST, + ))) + } + } +} + +pub async fn login( + login: LoginRequest, + db: impl Database + Clone + Send + Sync, +) -> Result, Infallible> { + let user = match db.get_user(login.username.clone()).await { + Ok(u) => u, + Err(e) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(e) => { + error!("failed to get session for {}: {}", login.username, e); + + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.as_str()); + + if !verified { + return Ok(Box::new(ErrorResponse::reply( + "user not found", + StatusCode::NOT_FOUND, + ))); + } + + Ok(Box::new(warp::reply::json(&LoginResponse { + session: session.token, + }))) +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 00000000..d5e083df --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,23 @@ +use std::net::IpAddr; + +use eyre::Result; + +use crate::settings::Settings; + +pub mod auth; +pub mod database; +pub mod handlers; +pub mod models; +pub mod router; + +pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> { + // routes to run: + // index, register, add_history, login, get_user, sync_count, sync_list + let host = host.parse::()?; + + let r = router::router(settings).await?; + + warp::serve(r).run((host, port)).await; + + Ok(()) +} diff --git a/src/server/models.rs b/src/server/models.rs new file mode 100644 index 00000000..fbf1897e --- /dev/null +++ b/src/server/models.rs @@ -0,0 +1,49 @@ +use chrono::prelude::*; + +#[derive(sqlx::FromRow)] +pub struct History { + pub id: i64, + pub client_id: String, // a client generated ID + pub user_id: i64, + pub hostname: String, + pub timestamp: NaiveDateTime, + + pub data: String, + + pub created_at: NaiveDateTime, +} + +pub struct NewHistory<'a> { + pub client_id: &'a str, + pub user_id: i64, + pub hostname: &'a str, + pub timestamp: chrono::NaiveDateTime, + + pub data: &'a str, +} + +#[derive(sqlx::FromRow)] +pub struct User { + pub id: i64, + pub username: String, + pub email: String, + pub password: String, +} + +#[derive(sqlx::FromRow)] +pub struct Session { +