summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-03-21 20:04:39 +0000
committerGitHub <noreply@github.com>2021-03-21 20:04:39 +0000
commitc9579cb9ca2a6a165d10f128e0af1dfd372e0c03 (patch)
tree1d4feecb422aae3cde1cc7cad54ccc73b2dae410 /src
parent716c7722cda29bf612508bb96f51822a86e0f69e (diff)
Implement server (#23)
* Add initial database and server setup * Set up all routes, auth, etc * Implement sessions, password auth, hashing with argon2, and history storage
Diffstat (limited to 'src')
-rw-r--r--src/command/mod.rs2
-rw-r--r--src/command/server.rs5
-rw-r--r--src/main.rs24
-rw-r--r--src/remote/auth.rs200
-rw-r--r--src/remote/database.rs14
-rw-r--r--src/remote/mod.rs4
-rw-r--r--src/remote/models.rs56
-rw-r--r--src/remote/server.rs46
-rw-r--r--src/remote/views.rs89
-rw-r--r--src/schema.rs28
-rw-r--r--src/settings.rs13
11 files changed, 471 insertions, 10 deletions
diff --git a/src/command/mod.rs b/src/command/mod.rs
index 3ebb92e0..a5ea0228 100644
--- a/src/command/mod.rs
+++ b/src/command/mod.rs
@@ -49,7 +49,7 @@ impl AtuinCmd {
match self {
Self::History(history) => history.run(db),
Self::Import(import) => import.run(db),
- Self::Server(server) => server.run(),
+ Self::Server(server) => server.run(settings),
Self::Stats(stats) => stats.run(db, settings),
Self::Init => init::init(),
Self::Search { query } => search::run(&query, db),
diff --git a/src/command/server.rs b/src/command/server.rs
index 1ddc73e7..9d9bcb3a 100644
--- a/src/command/server.rs
+++ b/src/command/server.rs
@@ -2,6 +2,7 @@ use eyre::Result;
use structopt::StructOpt;
use crate::remote::server;
+use crate::settings::Settings;
#[derive(StructOpt)]
pub enum Cmd {
@@ -10,8 +11,8 @@ pub enum Cmd {
#[allow(clippy::unused_self)] // I'll use it later
impl Cmd {
- pub fn run(&self) -> Result<()> {
- server::launch();
+ pub fn run(&self, settings: &Settings) -> Result<()> {
+ server::launch(settings);
Ok(())
}
}
diff --git a/src/main.rs b/src/main.rs
index d47866f4..3c4a05e4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -17,6 +17,15 @@ extern crate rocket;
#[macro_use]
extern crate serde_derive;
+#[macro_use]
+extern crate diesel;
+
+#[macro_use]
+extern crate diesel_migrations;
+
+#[macro_use]
+extern crate rocket_contrib;
+
use command::AtuinCmd;
use local::database::Sqlite;
use settings::Settings;
@@ -26,6 +35,8 @@ mod local;
mod remote;
mod settings;
+pub mod schema;
+
#[derive(StructOpt)]
#[structopt(
author = "Ellie Huxtable <e@elm.sh>",
@@ -61,7 +72,18 @@ impl Atuin {
}
fn main() -> Result<()> {
- pretty_env_logger::init();
+ fern::Dispatch::new()
+ .format(|out, message, record| {
+ out.finish(format_args!(
+ "{} [{}] {}",
+ chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
+ record.level(),
+ message
+ ))
+ })
+ .level(log::LevelFilter::Info)
+ .chain(std::io::stdout())
+ .apply()?;
Atuin::from_args().run()
}
diff --git a/src/remote/auth.rs b/src/remote/auth.rs
new file mode 100644
index 00000000..8f9e9b46
--- /dev/null
+++ b/src/remote/auth.rs
@@ -0,0 +1,200 @@
+use self::diesel::prelude::*;
+use rocket::http::Status;
+use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket_contrib::databases::diesel;
+use sodiumoxide::crypto::pwhash::argon2id13;
+
+use rocket_contrib::json::Json;
+use uuid::Uuid;
+
+use super::models::{NewSession, NewUser, Session, User};
+use super::views::ApiResponse;
+use crate::schema::{sessions, users};
+
+use super::database::AtuinDbConn;
+
+#[derive(Debug)]
+pub enum KeyError {
+ Missing,
+ Invalid,
+}
+
+pub fn hash_str(secret: &str) -> String {
+ sodiumoxide::init().unwrap();
+ let hash = argon2id13::pwhash(
+ secret.as_bytes(),
+ argon2id13::OPSLIMIT_INTERACTIVE,
+ argon2id13::MEMLIMIT_INTERACTIVE,
+ )
+ .unwrap();
+ let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
+
+ // postgres hates null chars. don't do that to postgres
+ texthash.trim_end_matches('\u{0}').to_string()
+}
+
+pub fn verify_str(secret: &str, verify: &str) -> bool {
+ sodiumoxide::init().unwrap();
+
+ let mut padded = [0_u8; 128];
+ secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
+ padded[i] = *val;
+ });
+
+ match argon2id13::HashedPassword::from_slice(&padded) {
+ Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
+ None => false,
+ }
+}
+
+impl<'a, 'r> FromRequest<'a, 'r> for User {
+ type Error = KeyError;
+
+ fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
+ let session: Vec<_> = request.headers().get("authorization").collect();
+
+ if session.is_empty() {
+ return Outcome::Failure((Status::BadRequest, KeyError::Missing));
+ } else if session.len() > 1 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session: Vec<_> = session[0].split(' ').collect();
+
+ if session.len() != 2 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ if session[0] != "Token" {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session = session[1];
+
+ let db = request
+ .guard::<AtuinDbConn>()
+ .succeeded()
+ .expect("failed to load database");
+
+ let session = sessions::table
+ .filter(sessions::token.eq(session))
+ .first::<Session>(&*db);
+
+ if session.is_err() {
+ return Outcome::Failure((Status::Unauthorized, KeyError::Invalid));
+ }
+
+ let session = session.unwrap();
+
+ let user = users::table.find(session.user_id).first(&*db);
+
+ match user {
+ Ok(user) => Outcome::Success(user),
+ Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)),
+ }
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Register {
+ email: String,
+ password: String,
+}
+
+#[post("/register", data = "<register>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
+ let hashed = hash_str(register.password.as_str());
+
+ let new_user = NewUser {
+ email: register.email.as_str(),
+ password: hashed.as_str(),
+ };
+
+ let user = diesel::insert_into(users::table)
+ .values(&new_user)
+ .get_result(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::BadRequest,
+ json: json!({
+ "status": "error",
+ "message": "failed to create user - is the email already in use?",
+ }),
+ };
+ }
+
+ let user: User = user.unwrap();
+ let token = Uuid::new_v4().to_simple().to_string();
+
+ let new_session = NewSession {
+ user_id: user.id,
+ token: token.as_str(),
+ };
+
+ match diesel::insert_into(sessions::table)
+ .values(&new_session)
+ .execute(&*conn)
+ {
+ Ok(_) => ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "message": "user created!", "session": token}),
+ },
+ Err(_) => ApiResponse {
+ status: Status::BadRequest,
+ json: json!({"status": "error", "message": "failed to create user"}),
+ },
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Login {
+ email: String,
+ password: String,
+}
+
+#[post("/login", data = "<login>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
+ let user = users::table
+ .filter(users::email.eq(login.email.as_str()))
+ .first(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let user: User = user.unwrap();
+
+ let session = sessions::table
+ .filter(sessions::user_id.eq(user.id))
+ .first(&*conn);
+
+ // a session should exist...
+ if session.is_err() {
+ return ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "something went wrong"}),
+ };
+ }
+
+ let verified = verify_str(user.password.as_str(), login.password.as_str());
+
+ if !verified {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let session: Session = session.unwrap();
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "token": session.token}),
+ }
+}
diff --git a/src/remote/database.rs b/src/remote/database.rs
new file mode 100644
index 00000000..4f386def
--- /dev/null
+++ b/src/remote/database.rs
@@ -0,0 +1,14 @@
+use diesel::pg::PgConnection;
+use diesel::prelude::*;
+
+use crate::settings::Settings;
+
+#[database("atuin")]
+pub struct AtuinDbConn(diesel::PgConnection);
+
+// TODO: connection pooling
+pub fn establish_connection(settings: &Settings) -> PgConnection {
+ let database_url = &settings.remote.db.url;
+ PgConnection::establish(database_url)
+ .unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
+}
diff --git a/src/remote/mod.rs b/src/remote/mod.rs
index 74f47ad3..7147b88e 100644
--- a/src/remote/mod.rs
+++ b/src/remote/mod.rs
@@ -1 +1,5 @@
+pub mod auth;
+pub mod database;
+pub mod models;
pub mod server;
+pub mod views;
diff --git a/src/remote/models.rs b/src/remote/models.rs
new file mode 100644
index 00000000..058b2f0b
--- /dev/null
+++ b/src/remote/models.rs
@@ -0,0 +1,56 @@
+use chrono::naive::NaiveDateTime;
+
+use crate::schema::{history, sessions, users};
+
+#[derive(Identifiable, Queryable, Associations)]
+#[table_name = "history"]
+#[belongs_to(User)]
+pub struct History {
+ pub id: i64,
+ pub client_id: String,
+ pub user_id: i64,
+ pub mac: String,
+ pub timestamp: NaiveDateTime,
+
+ pub data: String,
+}
+
+#[derive(Identifiable, Queryable, Associations)]
+pub struct User {
+ pub id: i64,
+ pub email: String,
+ pub password: String,
+}
+
+#[derive(Queryable, Identifiable, Associations)]
+#[belongs_to(User)]
+pub struct Session {
+ pub id: i64,
+ pub user_id: i64,
+ pub token: String,
+}
+
+#[derive(Insertable)]
+#[table_name = "history"]
+pub struct NewHistory<'a> {
+ pub client_id: &'a str,
+ pub user_id: i64,
+ pub mac: &'a str,
+ pub timestamp: NaiveDateTime,
+
+ pub data: &'a str,
+}
+
+#[derive(Insertable)]
+#[table_name = "users"]
+pub struct NewUser<'a> {
+ pub email: &'a str,
+ pub password: &'a str,
+}
+
+#[derive(Insertable)]
+#[table_name = "sessions"]
+pub struct NewSession<'a> {
+ pub user_id: i64,
+ pub token: &'a str,
+}
diff --git a/src/remote/server.rs b/src/remote/server.rs
index bc1dc2bd..4409f646 100644
--- a/src/remote/server.rs
+++ b/src/remote/server.rs
@@ -1,8 +1,42 @@
-#[get("/")]
-const fn index() -> &'static str {
- "Hello, world!"
-}
+use rocket::config::{Config, Environment, LoggingLevel, Value};
+
+use std::collections::HashMap;
+
+use crate::remote::database::establish_connection;
+use crate::settings::Settings;
+
+use super::database::AtuinDbConn;
+
+// a bunch of these imports are generated by macros, it's easier to wildcard
+#[allow(clippy::clippy::wildcard_imports)]
+use super::views::*;
+
+#[allow(clippy::clippy::wildcard_imports)]
+use super::auth::*;
+
+embed_migrations!("migrations");
+
+pub fn launch(settings: &Settings) {
+ let mut database_config = HashMap::new();
+ let mut databases = HashMap::new();
+
+ database_config.insert("url", Value::from(settings.remote.db.url.clone()));
+ databases.insert("atuin", Value::from(database_config));
+
+ let connection = establish_connection(settings);
+ embedded_migrations::run(&connection).expect("failed to run migrations");
+
+ let config = Config::build(Environment::Production)
+ .address("0.0.0.0")
+ .log_level(LoggingLevel::Normal)
+ .port(8080)
+ .extra("databases", databases)
+ .finalize()
+ .unwrap();
-pub fn launch() {
- rocket::ignite().mount("/", routes![index]).launch();
+ let app = rocket::custom(config);
+ app.mount("/", routes![index, register, add_history, login])
+ .attach(AtuinDbConn::fairing())
+ .register(catchers![internal_error, bad_request])
+ .launch();
}
diff --git a/src/remote/views.rs b/src/remote/views.rs
new file mode 100644
index 00000000..2af3f369
--- /dev/null
+++ b/src/remote/views.rs
@@ -0,0 +1,89 @@
+use self::diesel::prelude::*;
+use rocket::http::{ContentType, Status};
+use rocket::request::Request;
+use rocket::response;
+use rocket::response::{Responder, Response};
+use rocket_contrib::databases::diesel;
+use rocket_contrib::json::{Json, JsonValue};
+
+use super::database::AtuinDbConn;
+use super::models::{NewHistory, User};
+use crate::schema::history;
+
+#[derive(Debug)]
+pub struct ApiResponse {
+ pub json: JsonValue,
+ pub status: Status,
+}
+
+impl<'r> Responder<'r> for ApiResponse {
+ fn respond_to(self, req: &Request) -> response::Result<'r> {
+ Response::build_from(self.json.respond_to(req).unwrap())
+ .status(self.status)
+ .header(ContentType::JSON)
+ .ok()
+ }
+}
+
+#[get("/")]
+pub const fn index() -> &'static str {
+ "\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
+}
+
+#[catch(500)]
+pub fn internal_error(_req: &Request) -> ApiResponse {
+ ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "an internal server error has occured"}),
+ }
+}
+
+#[catch(400)]
+pub fn bad_request(_req: &Request) -> ApiResponse {
+ ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "bad request. don't do that."}),
+ }
+}
+
+#[derive(Deserialize)]
+pub struct AddHistory {
+ id: String,
+ timestamp: i64,
+ data: String,
+ mac: String,
+}
+
+#[post("/history", data = "<add_history>")]
+#[allow(
+ clippy::clippy::cast_sign_loss,
+ clippy::cast_possible_truncation,
+ clippy::clippy::needless_pass_by_value
+)]
+pub fn add_history(conn: AtuinDbConn, user: User, add_history: Json<AddHistory>) -> ApiResponse {
+ let secs: i64 = add_history.timestamp / 1_000_000_000;
+ let nanosecs: u32 = (add_history.timestamp - (secs * 1_000_000_000)) as u32;
+ let datetime = chrono::NaiveDateTime::from_timestamp(secs, nanosecs);
+
+ let new_history = NewHistory {
+ client_id: add_history.id.as_str(),
+ user_id: user.id,
+ mac: add_history.mac.as_str(),
+ timestamp: datetime,
+ data: add_history.data.as_str(),
+ };
+
+ match diesel::insert_into(history::table)
+ .values(&new_history)
+ .execute(&*conn)
+ {
+ Ok(_) => ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "message": "history added", "id": new_history.client_id}),
+ },
+ Err(_) => ApiResponse {
+ status: Status::BadRequest,
+ json: json!({"status": "error", "message": "failed to add history"}),
+ },
+ }
+}
diff --git a/src/schema.rs b/src/schema.rs
new file mode 100644
index 00000000..efa9ddcc
--- /dev/null
+++ b/src/schema.rs
@@ -0,0 +1,28 @@
+table! {
+ history (id) {
+ id -> Int8,
+ client_id -> Text,
+ user_id -> Int8,
+ mac -> Varchar,
+ timestamp -> Timestamp,
+ data -> Varchar,
+ }
+}
+
+table! {
+ sessions (id) {
+ id -> Int8,
+ user_id -> Int8,
+ token -> Varchar,
+ }
+}
+
+table! {
+ users (id) {
+ id -> Int8,
+ email -> Varchar,
+ password -> Varchar,
+ }
+}
+
+allow_tables_to_appear_in_same_query!(history, sessions, users,);
diff --git a/src/settings.rs b/src/settings.rs
index a4c9f8da..6f29afd2 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -11,6 +11,11 @@ pub struct LocalDatabase {
}
#[derive(Debug, Deserialize)]
+pub struct RemoteDatabase {
+ pub url: String,
+}
+
+#[derive(Debug, Deserialize)]
pub struct Local {
pub server_address: String,
pub dialect: String,
@@ -18,8 +23,14 @@ pub struct Local {
}
#[derive(Debug, Deserialize)]
+pub struct Remote {
+ pub db: RemoteDatabase,
+}
+
+#[derive(Debug, Deserialize)]
pub struct Settings {
pub local: Local,
+ pub remote: Remote,
}
impl Settings {
@@ -49,6 +60,8 @@ impl Settings {
s.set_default("local.dialect", "us")?;
s.set_default("local.db.path", db_path.to_str())?;
+ s.set_default("remote.db.url", "please set a postgres url")?;
+
if config_file.exists() {
s.merge(File::with_name(config_file.to_str().unwrap()))?;
}