use std::collections::HashMap; use async_trait::async_trait; use chrono::{Datelike, TimeZone}; use chronoutil::RelativeDuration; use sqlx::{postgres::PgPoolOptions, Result}; use sqlx::Row; use tracing::{debug, instrument, warn}; use super::{ calendar::{TimePeriod, TimePeriodInfo}, models::{History, NewHistory, NewSession, NewUser, Session, User}, }; use crate::settings::Settings; use crate::settings::HISTORY_PAGE_SIZE; use atuin_common::utils::get_days_from_month; #[async_trait] pub trait Database { async fn get_session(&self, token: &str) -> Result; async fn get_session_user(&self, token: &str) -> Result; async fn add_session(&self, session: &NewSession) -> Result<()>; async fn get_user(&self, username: &str) -> Result; async fn get_user_session(&self, u: &User) -> Result; async fn add_user(&self, user: &NewUser) -> Result; async fn count_history(&self, user: &User) -> Result; async fn count_history_cached(&self, user: &User) -> Result; async fn delete_history(&self, user: &User, id: String) -> Result<()>; async fn deleted_history(&self, user: &User) -> Result>; async fn count_history_range( &self, user: &User, start: chrono::NaiveDateTime, end: chrono::NaiveDateTime, ) -> Result; async fn count_history_day(&self, user: &User, date: chrono::NaiveDate) -> Result; async fn count_history_month(&self, user: &User, date: chrono::NaiveDate) -> Result; async fn count_history_year(&self, user: &User, year: i32) -> Result; async fn list_history( &self, user: &User, created_after: chrono::NaiveDateTime, since: chrono::NaiveDateTime, host: &str, ) -> Result>; async fn add_history(&self, history: &[NewHistory]) -> Result<()>; async fn oldest_history(&self, user: &User) -> Result; async fn calendar( &self, user: &User, period: TimePeriod, year: u64, month: u64, ) -> Result>; } #[derive(Clone)] pub struct Postgres { pool: sqlx::Pool, settings: Settings, } impl Postgres { pub async fn new(settings: Settings) -> Result { let pool = PgPoolOptions::new() .max_connections(100) .connect(settings.db_uri.as_str()) .await?; sqlx::migrate!("./migrations").run(&pool).await?; Ok(Self { pool, settings }) } } #[async_trait] impl Database for Postgres { #[instrument(skip_all)] async fn get_session(&self, token: &str) -> Result { sqlx::query_as::<_, Session>("select id, user_id, token from sessions where token = $1") .bind(token) .fetch_one(&self.pool) .await } #[instrument(skip_all)] async fn get_user(&self, username: &str) -> Result { sqlx::query_as::<_, User>( "select id, username, email, password from users where username = $1", ) .bind(username) .fetch_one(&self.pool) .await } #[instrument(skip_all)] async fn get_session_user(&self, token: &str) -> Result { sqlx::query_as::<_, User>( "select users.id, users.username, users.email, users.password from users inner join sessions on users.id = sessions.user_id and sessions.token = $1", ) .bind(token) .fetch_one(&self.pool) .await } #[instrument(skip_all)] async fn count_history(&self, user: &User) -> Result { // The cache is new, and the user might not yet have a cache value. // They will have one as soon as they post up some new history, but handle that // edge case. let res: (i64,) = sqlx::query_as( "select count(1) from history where user_id = $1", ) .bind(user.id) .fetch_one(&self.pool) .await?; Ok(res.0) } #[instrument(skip_all)] async fn count_history_cached(&self, user: &User) -> Result { let res: (i32,) = sqlx::query_as( "select total from total_history_count_user where user_id = $1", ) .bind(user.id) .fetch_one(&self.pool) .await?; Ok(res.0 as i64) } async fn delete_history(&self, user: &User, id: String) -> Result<()> { sqlx::query( "update history set deleted_at = $3 where user_id = $1 and client_id = $2 and deleted_at is null", // don't just keep setting it ) .bind(user.id) .bind(id) .bind(chrono::Utc::now().naive_utc()) .fetch_all(&self.pool) .await?; Ok(()) } #[instrument(skip_all)] async fn deleted_history(&self, user: &User) -> Result> { // The cache is new, and the user might not yet have a cache value. // They will have one as soon as they post up some new history, but handle that // edge case. let res = sqlx::query( "select client_id from history where user_id = $1 and deleted_at is not null", ) .bind(user.id) .fetch_all(&self.pool) .await?; let res = res .iter() .map(|row| row.get::("client_id")) .collect(); Ok(res) } #[instrument(skip_all)] async fn count_history_range( &self, user: &User, start: chrono::NaiveDateTime, end: chrono::NaiveDateTime, ) -> Result { let res: (i64,) = sqlx::query_as( "select count(1) from history where user_id = $1 and timestamp >= $2::date and timestamp < $3::date", ) .bind(user.id) .bind(start) .bind(end) .fetch_one(&self.pool) .await?; Ok(res.0) } // Count the history for a given year #[instrument(skip_all)] async fn count_history_year(&self, user: &User, year: i32) -> Result { let start = chrono::Utc.ymd(year, 1, 1).and_hms_nano(0, 0, 0, 0); let end = start + RelativeDuration::years(1); let res = self .count_history_range(user, start.naive_utc(), end.naive_utc()) .await?; Ok(res) } // Count the history for a given month #[instrument(skip_all)] async fn count_history_month(&self, user: &User, month: chrono::NaiveDate) -> Result { let start = chrono::Utc .ymd(month.year(), month.month(), 1) .and_hms_nano(0, 0, 0, 0); // ofc... let end = if month.month() < 12 { chrono::Utc .ymd(month.year(), month.month() + 1, 1) .and_hms_nano(0, 0, 0, 0) } else { chrono::Utc .ymd(month.year() + 1, 1, 1) .and_hms_nano(0, 0, 0, 0) }; debug!("start: {}, end: {}", start, end); let res = self .count_history_range(user, start.naive_utc(), end.naive_utc()) .await?; Ok(res) } // Count the history for a given day #[instrument(skip_all)] async fn count_history_day(&self, user: &User, day: chrono::NaiveDate) -> Result { let start = chrono::Utc .ymd(day.year(), day.month(), day.day()) .and_hms_nano(0, 0, 0, 0); let end = chrono::Utc .ymd(day.year(), day.month(), day.day() + 1) .and_hms_nano(0, 0, 0, 0); let res = self .count_history_range(user, start.naive_utc(), end.naive_utc()) .await?; Ok(res) } #[instrument(skip_all)] async fn list_history( &self, user: &User, created_after: chrono::NaiveDateTime, since: chrono::NaiveDateTime, host: &str, ) -> Result> { let res = sqlx::query_as::<_, History>( "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 and hostname != $2 and created_at >= $3 and timestamp >= $4 order by timestamp asc limit $5", ) .bind(user.id) .bind(host) .bind(created_after) .bind(since) .bind(HISTORY_PAGE_SIZE) .fetch_all(&self.pool) .await?; Ok(res) } #[instrument(skip_all)] async fn add_history(&self, history: &[NewHistory]) -> Result<()> { let mut tx = self.pool.begin().await?; for i in history { let client_id: &str = &i.client_id; let hostname: &str = &i.hostname; let data: &str = &i.data; if data.len() > self.settings.max_history_length && self.settings.max_history_length != 0 { // Don't return an error here. We want to insert as much of the // history list as we can, so log the error and continue going. warn!( "history too long, got length {}, max {}", data.len(), self.settings.max_history_length ); continue; } sqlx::query( "insert into history (client_id, user_id, hostname, timestamp, data) values ($1, $2, $3, $4, $5) on conflict do nothing ", ) .bind(client_id) .bind(i.user_id) .bind(hostname) .bind(i.timestamp) .bind(data) .execute(&mut tx) .await?; } tx.commit().await?; Ok(()) } #[instrument(skip_all)] async fn add_user(&self, user: &NewUser) -> Result { let email: &str = &user.email; let username: &str = &user.username; let password: &str = &user.password; let res: (i64,) = sqlx::query_as( "insert into users (username, email, password) values($1, $2, $3) returning id", ) .bind(username) .bind(email) .bind(password) .fetch_one(&self.pool) .await?; Ok(res.0) } #[instrument(skip_all)] async fn add_session(&self, session: &NewSession) -> Result<()> { let token: &str = &session.token; sqlx::query( "insert into sessions (user_id, token) values($1, $2)", ) .bind(session.user_id) .bind(token) .execute(&self.pool) .await?; Ok(()) } #[instrument(skip_all)] async fn get_user_session(&self, u: &User) -> Result { sqlx::query_as::<_, Session>("select id, user_id, token from sessions where user_id = $1") .bind(u.id) .fetch_one(&self.pool) .await } #[instrument(skip_all)] async fn oldest_history(&self, user: &User) -> Result { let res = sqlx::query_as::<_, History>( "select id, client_id, user_id, hostname, timestamp, data, created_at from history where user_id = $1 order by timestamp asc limit 1", ) .bind(user.id) .fetch_one(&self.pool) .await?; Ok(res) } #[instrument(skip_all)] async fn calendar( &self, user: &User, period: TimePeriod, year: u64, month: u64, ) -> Result> { // TODO: Support different timezones. Right now we assume UTC and // everything is stored as such. But it _should_ be possible to // interpret the stored date with a different TZ match period { TimePeriod::YEAR => { let mut ret = HashMap::new(); // First we need to work out how far back to calculate. Get the // oldest history item let oldest = self.oldest_history(user).await?.timestamp.year(); let current_year = chrono::Utc::now().year(); // All the years we need to get data for // The upper bound is exclusive, so include current +1 let years = oldest..current_year + 1; for year in years { let count = self.count_history_year(user, year).await?; ret.insert( year as u64, TimePeriodInfo { count: count as u64, hash: "".to_string(), }, ); } Ok(ret) } TimePeriod::MONTH => { let mut ret = HashMap::new(); for month in 1..13 { let count = self .count_history_month( user, chrono::Utc.ymd(year as i32, month, 1).naive_utc(), ) .await?; ret.insert( month as u64, TimePeriodInfo { count: count as u64, hash: "".to_string(), }, ); } Ok(ret) } TimePeriod::DAY => { let mut ret = HashMap::new(); for day in 1..get_days_from_month(year as i32, month as u32) { let count = self .count_history_day( user, chrono::Utc .ymd(year as i32, month as u32, day as u32) .naive_utc(), ) .await?; ret.insert( day as u64, TimePeriodInfo { count: count as u64, hash: "".to_string(), }, ); } Ok(ret) } } } }