diff options
Diffstat (limited to 'atuin-server-postgres/src/lib.rs')
-rw-r--r-- | atuin-server-postgres/src/lib.rs | 98 |
1 files changed, 53 insertions, 45 deletions
diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs index f22e6bee3..c1de4d509 100644 --- a/atuin-server-postgres/src/lib.rs +++ b/atuin-server-postgres/src/lib.rs @@ -1,7 +1,7 @@ use std::ops::Range; use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::{Database, DbError, DbResult}; use futures_util::TryStreamExt; @@ -11,6 +11,7 @@ use sqlx::Row; use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; use tracing::instrument; +use uuid::Uuid; use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; mod wrappers; @@ -361,16 +362,16 @@ impl Database for Postgres { let id = atuin_common::utils::uuid_v7(); sqlx::query( - "insert into records - (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id) + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) on conflict do nothing ", ) .bind(id) .bind(i.id) - .bind(i.host) - .bind(i.parent) + .bind(i.host.id) + .bind(i.idx as i64) .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time .bind(&i.version) .bind(&i.tag) @@ -393,62 +394,69 @@ impl Database for Postgres { user: &User, host: HostId, tag: String, - start: Option<RecordId>, + start: Option<RecordIdx>, count: u64, ) -> DbResult<Vec<Record<EncryptedData>>> { tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); - let mut ret = Vec::with_capacity(count as usize); - let mut parent = start; - - // yeah let's do something better - for _ in 0..count { - // a very much not ideal query. but it's simple at least? - // we are basically using postgres as a kv store here, so... maybe consider using an actual - // kv store? - let record: Result<DbRecord, DbError> = sqlx::query_as( - "select client_id, host, parent, timestamp, version, tag, data, cek from records + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store where user_id = $1 and tag = $2 and host = $3 - and parent is not distinct from $4", - ) - .bind(user.id) - .bind(tag.clone()) - .bind(host) - .bind(parent) - .fetch_one(&self.pool) - .await - .map_err(fix_error); - - match record { - Ok(record) => { - let record: Record<EncryptedData> = record.into(); - ret.push(record.clone()); - - parent = Some(record.id); - } - Err(DbError::NotFound) => { - tracing::debug!("hit tail of store: {:?}/{}", host, tag); - return Ok(ret); - } - Err(e) => return Err(e), + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); + + let ret = match records { + Ok(records) => { + let records: Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records } - } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; Ok(ret) } - async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> { - const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;"; + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - let res = sqlx::query_as(TAIL_RECORDS_SQL) + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) .bind(user.id) - .fetch(&self.pool) - .try_collect() + .fetch_all(&self.pool) .await .map_err(fix_error)?; - Ok(res) + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) } } |