summaryrefslogtreecommitdiffstats
path: root/src/remote/auth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/remote/auth.rs')
-rw-r--r--src/remote/auth.rs200
1 files changed, 200 insertions, 0 deletions
diff --git a/src/remote/auth.rs b/src/remote/auth.rs
new file mode 100644
index 00000000..8f9e9b46
--- /dev/null
+++ b/src/remote/auth.rs
@@ -0,0 +1,200 @@
+use self::diesel::prelude::*;
+use rocket::http::Status;
+use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket_contrib::databases::diesel;
+use sodiumoxide::crypto::pwhash::argon2id13;
+
+use rocket_contrib::json::Json;
+use uuid::Uuid;
+
+use super::models::{NewSession, NewUser, Session, User};
+use super::views::ApiResponse;
+use crate::schema::{sessions, users};
+
+use super::database::AtuinDbConn;
+
+#[derive(Debug)]
+pub enum KeyError {
+ Missing,
+ Invalid,
+}
+
+pub fn hash_str(secret: &str) -> String {
+ sodiumoxide::init().unwrap();
+ let hash = argon2id13::pwhash(
+ secret.as_bytes(),
+ argon2id13::OPSLIMIT_INTERACTIVE,
+ argon2id13::MEMLIMIT_INTERACTIVE,
+ )
+ .unwrap();
+ let texthash = std::str::from_utf8(&hash.0).unwrap().to_string();
+
+ // postgres hates null chars. don't do that to postgres
+ texthash.trim_end_matches('\u{0}').to_string()
+}
+
+pub fn verify_str(secret: &str, verify: &str) -> bool {
+ sodiumoxide::init().unwrap();
+
+ let mut padded = [0_u8; 128];
+ secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
+ padded[i] = *val;
+ });
+
+ match argon2id13::HashedPassword::from_slice(&padded) {
+ Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
+ None => false,
+ }
+}
+
+impl<'a, 'r> FromRequest<'a, 'r> for User {
+ type Error = KeyError;
+
+ fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
+ let session: Vec<_> = request.headers().get("authorization").collect();
+
+ if session.is_empty() {
+ return Outcome::Failure((Status::BadRequest, KeyError::Missing));
+ } else if session.len() > 1 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session: Vec<_> = session[0].split(' ').collect();
+
+ if session.len() != 2 {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ if session[0] != "Token" {
+ return Outcome::Failure((Status::BadRequest, KeyError::Invalid));
+ }
+
+ let session = session[1];
+
+ let db = request
+ .guard::<AtuinDbConn>()
+ .succeeded()
+ .expect("failed to load database");
+
+ let session = sessions::table
+ .filter(sessions::token.eq(session))
+ .first::<Session>(&*db);
+
+ if session.is_err() {
+ return Outcome::Failure((Status::Unauthorized, KeyError::Invalid));
+ }
+
+ let session = session.unwrap();
+
+ let user = users::table.find(session.user_id).first(&*db);
+
+ match user {
+ Ok(user) => Outcome::Success(user),
+ Err(_) => Outcome::Failure((Status::Unauthorized, KeyError::Invalid)),
+ }
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Register {
+ email: String,
+ password: String,
+}
+
+#[post("/register", data = "<register>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn register(conn: AtuinDbConn, register: Json<Register>) -> ApiResponse {
+ let hashed = hash_str(register.password.as_str());
+
+ let new_user = NewUser {
+ email: register.email.as_str(),
+ password: hashed.as_str(),
+ };
+
+ let user = diesel::insert_into(users::table)
+ .values(&new_user)
+ .get_result(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::BadRequest,
+ json: json!({
+ "status": "error",
+ "message": "failed to create user - is the email already in use?",
+ }),
+ };
+ }
+
+ let user: User = user.unwrap();
+ let token = Uuid::new_v4().to_simple().to_string();
+
+ let new_session = NewSession {
+ user_id: user.id,
+ token: token.as_str(),
+ };
+
+ match diesel::insert_into(sessions::table)
+ .values(&new_session)
+ .execute(&*conn)
+ {
+ Ok(_) => ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "message": "user created!", "session": token}),
+ },
+ Err(_) => ApiResponse {
+ status: Status::BadRequest,
+ json: json!({"status": "error", "message": "failed to create user"}),
+ },
+ }
+}
+
+#[derive(Deserialize)]
+pub struct Login {
+ email: String,
+ password: String,
+}
+
+#[post("/login", data = "<login>")]
+#[allow(clippy::clippy::needless_pass_by_value)]
+pub fn login(conn: AtuinDbConn, login: Json<Login>) -> ApiResponse {
+ let user = users::table
+ .filter(users::email.eq(login.email.as_str()))
+ .first(&*conn);
+
+ if user.is_err() {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let user: User = user.unwrap();
+
+ let session = sessions::table
+ .filter(sessions::user_id.eq(user.id))
+ .first(&*conn);
+
+ // a session should exist...
+ if session.is_err() {
+ return ApiResponse {
+ status: Status::InternalServerError,
+ json: json!({"status": "error", "message": "something went wrong"}),
+ };
+ }
+
+ let verified = verify_str(user.password.as_str(), login.password.as_str());
+
+ if !verified {
+ return ApiResponse {
+ status: Status::NotFound,
+ json: json!({"status": "error", "message": "user not found"}),
+ };
+ }
+
+ let session: Session = session.unwrap();
+
+ ApiResponse {
+ status: Status::Ok,
+ json: json!({"status": "ok", "token": session.token}),
+ }
+}