summaryrefslogtreecommitdiffstats
path: root/src/command/import.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/command/import.rs')
-rw-r--r--src/command/import.rs273
1 files changed, 86 insertions, 187 deletions
diff --git a/src/command/import.rs b/src/command/import.rs
index 9a0364da..53940abb 100644
--- a/src/command/import.rs
+++ b/src/command/import.rs
@@ -1,15 +1,11 @@
-use std::env;
-use std::path::PathBuf;
+use std::{env, path::PathBuf};
-use atuin_common::utils::uuid_v4;
-use chrono::{TimeZone, Utc};
-use directories::UserDirs;
use eyre::{eyre, Result};
use structopt::StructOpt;
-use atuin_client::history::History;
use atuin_client::import::{bash::Bash, zsh::Zsh};
-use atuin_client::{database::Database, import::resh::ReshEntry};
+use atuin_client::{database::Database, import::Importer};
+use atuin_client::{history::History, import::resh::Resh};
use indicatif::ProgressBar;
#[derive(StructOpt)]
@@ -39,6 +35,8 @@ pub enum Cmd {
Resh,
}
+const BATCH_SIZE: usize = 100;
+
impl Cmd {
pub async fn run(&self, db: &mut (impl Database + Send + Sync)) -> Result<()> {
println!(" Atuin ");
@@ -55,216 +53,117 @@ impl Cmd {
if shell.ends_with("/zsh") {
println!("Detected ZSH");
- import_zsh(db).await
+ import::<Zsh<_>, _>(db, BATCH_SIZE).await
} else {
println!("cannot import {} history", shell);
Ok(())
}
}
- Self::Zsh => import_zsh(db).await,
- Self::Bash => import_bash(db).await,
- Self::Resh => import_resh(db).await,
- }
- }
-}
-
-async fn import_resh(db: &mut (impl Database + Send + Sync)) -> Result<()> {
- let histpath = std::path::Path::new(std::env::var("HOME")?.as_str()).join(".resh_history.json");
-
- println!("Parsing .resh_history.json...");
- #[allow(clippy::filter_map)]
- let history = std::fs::read_to_string(histpath)?
- .split('\n')
- .map(str::trim)
- .map(|x| serde_json::from_str::<ReshEntry>(x))
- .filter_map(|x| match x {
- Ok(x) => Some(x),
- Err(e) => {
- if e.is_eof() {
- None
- } else {
- warn!("Invalid entry found in resh_history file: {}", e);
- None
- }
- }
- })
- .map(|x| {
- #[allow(clippy::cast_possible_truncation)]
- #[allow(clippy::cast_sign_loss)]
- let timestamp = {
- let secs = x.realtime_before.floor() as i64;
- let nanosecs = (x.realtime_before.fract() * 1_000_000_000_f64).round() as u32;
- Utc.timestamp(secs, nanosecs)
- };
- #[allow(clippy::cast_possible_truncation)]
- #[allow(clippy::cast_sign_loss)]
- let duration = {
- let secs = x.realtime_after.floor() as i64;
- let nanosecs = (x.realtime_after.fract() * 1_000_000_000_f64).round() as u32;
- let difference = Utc.timestamp(secs, nanosecs) - timestamp;
- difference.num_nanoseconds().unwrap_or(0)
- };
-
- History {
- id: uuid_v4(),
- timestamp,
- duration,
- exit: x.exit_code,
- command: x.cmd_line,
- cwd: x.pwd,
- session: uuid_v4(),
- hostname: x.host,
- }
- })
- .collect::<Vec<_>>();
- println!("Updating database...");
-
- let progress = ProgressBar::new(history.len() as u64);
-
- let buf_size = 100;
- let mut buf = Vec::<_>::with_capacity(buf_size);
-
- for i in history {
- buf.push(i);
-
- if buf.len() == buf_size {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
-
- buf.clear();
+ Self::Zsh => import::<Zsh<_>, _>(db, BATCH_SIZE).await,
+ Self::Bash => import::<Bash<_>, _>(db, BATCH_SIZE).await,
+ Self::Resh => import::<Resh, _>(db, BATCH_SIZE).await,
}
}
-
- if !buf.is_empty() {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
- }
- Ok(())
}
-async fn import_zsh(db: &mut (impl Database + Send + Sync)) -> Result<()> {
- // oh-my-zsh sets HISTFILE=~/.zhistory
- // zsh has no default value for this var, but uses ~/.zhistory.
- // we could maybe be smarter about this in the future :)
-
- let histpath = env::var("HISTFILE");
-
- let histpath = if let Ok(p) = histpath {
- let histpath = PathBuf::from(p);
-
- if !histpath.exists() {
- return Err(eyre!(
- "Could not find history file {:?}. try updating $HISTFILE",
- histpath
- ));
- }
-
- histpath
+async fn import<I: Importer + Send, DB: Database + Send + Sync>(
+ db: &mut DB,
+ buf_size: usize,
+) -> Result<()>
+where
+ I::IntoIter: Send,
+{
+ println!("Importing history from {}", I::NAME);
+
+ let histpath = get_histpath::<I>()?;
+ let contents = I::parse(histpath)?;
+
+ let iter = contents.into_iter();
+ let progress = if let (_, Some(upper_bound)) = iter.size_hint() {
+ ProgressBar::new(upper_bound as u64)
} else {
- let user_dirs = UserDirs::new().unwrap();
- let home_dir = user_dirs.home_dir();
-
- let mut candidates = [".zhistory", ".zsh_history"].iter();
- loop {
- match candidates.next() {
- Some(candidate) => {
- let histpath = home_dir.join(candidate);
- if histpath.exists() {
- break histpath;
- }
- }
- None => return Err(eyre!("Could not find history file. try setting $HISTFILE")),
- }
- }
+ ProgressBar::new_spinner()
};
- let zsh = Zsh::new(histpath)?;
-
- let progress = ProgressBar::new(zsh.loc);
-
- let buf_size = 100;
let mut buf = Vec::<History>::with_capacity(buf_size);
+ let mut iter = progress.wrap_iter(iter);
+ loop {
+ // fill until either no more entries
+ // or until the buffer is full
+ let done = fill_buf(&mut buf, &mut iter);
- for i in zsh
- .filter_map(Result::ok)
- .filter(|x| !x.command.trim().is_empty())
- {
- buf.push(i);
-
- if buf.len() == buf_size {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
+ // flush
+ db.save_bulk(&buf).await?;
- buf.clear();
+ if done {
+ break;
}
}
- if !buf.is_empty() {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
- }
-
- progress.finish();
println!("Import complete!");
Ok(())
}
-// TODO: don't just copy paste this lol
-async fn import_bash(db: &mut (impl Database + Send + Sync)) -> Result<()> {
- // oh-my-zsh sets HISTFILE=~/.zhistory
- // zsh has no default value for this var, but uses ~/.zhistory.
- // we could maybe be smarter about this in the future :)
-
- let histpath = env::var("HISTFILE");
-
- let histpath = if let Ok(p) = histpath {
- let histpath = PathBuf::from(p);
-
- if !histpath.exists() {
- return Err(eyre!(
- "Could not find history file {:?}. try updating $HISTFILE",
- histpath
- ));
- }
-
- histpath
+fn get_histpath<I: Importer>() -> Result<PathBuf> {
+ if let Ok(p) = env::var("HISTFILE") {
+ is_file(PathBuf::from(p))
} else {
- let user_dirs = UserDirs::new().unwrap();
- let home_dir = user_dirs.home_dir();
-
- home_dir.join(".bash_history")
- };
-
- let bash = Bash::new(histpath)?;
-
- let progress = ProgressBar::new(bash.loc);
-
- let buf_size = 100;
- let mut buf = Vec::<History>::with_capacity(buf_size);
+ is_file(I::histpath()?)
+ }
+}
- for i in bash
- .filter_map(Result::ok)
- .filter(|x| !x.command.trim().is_empty())
- {
- buf.push(i);
+fn is_file(p: PathBuf) -> Result<PathBuf> {
+ if p.is_file() {
+ Ok(p)
+ } else {
+ Err(eyre!(
+ "Could not find history file {:?}. Try setting $HISTFILE",
+ p
+ ))
+ }
+}
- if buf.len() == buf_size {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
+fn fill_buf<T, E>(buf: &mut Vec<T>, iter: &mut impl Iterator<Item = Result<T, E>>) -> bool {
+ buf.clear();
+ loop {
+ match iter.next() {
+ Some(Ok(t)) => buf.push(t),
+ Some(Err(_)) => (),
+ None => break true,
+ }
- buf.clear();
+ if buf.len() == buf.capacity() {
+ break false;
}
}
+}
- if !buf.is_empty() {
- db.save_bulk(&buf).await?;
- progress.inc(buf.len() as u64);
+#[cfg(test)]
+mod tests {
+ use super::fill_buf;
+
+ #[test]
+ fn test_fill_buf() {
+ let mut buf = Vec::with_capacity(4);
+ let mut iter = vec![
+ Ok(1),
+ Err(2),
+ Ok(3),
+ Ok(4),
+ Err(5),
+ Ok(6),
+ Ok(7),
+ Err(8),
+ Ok(9),
+ ]
+ .into_iter();
+
+ assert!(!fill_buf(&mut buf, &mut iter));
+ assert_eq!(buf, vec![1, 3, 4, 6]);
+
+ assert!(fill_buf(&mut buf, &mut iter));
+ assert_eq!(buf, vec![7, 9]);
}
-
- progress.finish();
- println!("Import complete!");
-
- Ok(())
}