use crate::error::Error; use crate::SimpleCache; use flate2::read::DeflateDecoder; use flate2::write::DeflateEncoder; use flate2::Compression; use parking_lot::RwLock; use parking_lot::RwLockReadGuard; use parking_lot::RwLockWriteGuard; use rmp_serde; use serde::de::DeserializeOwned; use serde::*; use serde_json; use std::borrow::Borrow; use std::collections::hash_map::Entry; use std::fs::File; use std::hash::Hash; use std::io::BufReader; use std::io::BufWriter; use std::marker::PhantomData; use std::path::PathBuf; use tempfile::NamedTempFile; type FxHashMap = std::collections::HashMap; struct Inner { data: Option>>, writes: usize, next_autosave: usize, } pub struct TempCache> { path: PathBuf, data: RwLock>, _ty: PhantomData, pub cache_only: bool, sem: tokio::sync::Semaphore, } impl TempCache { pub fn new(path: impl Into) -> Result { let path = path.into().with_extension("rmpz"); let data = if path.exists() { None } else { Some(FxHashMap::default()) }; Ok(Self { path, data: RwLock::new(Inner { data, writes: 0, next_autosave: 10, }), _ty: PhantomData, cache_only: false, sem: tokio::sync::Semaphore::new(32), }) } #[inline] pub fn set(&self, key: impl Into, value: impl Borrow) -> Result<(), Error> { self.set_(key.into(), value.borrow()) } fn lock_for_write(&self) -> Result>, Error> { let mut inner = self.data.write(); if inner.data.is_none() { inner.data = Some(self.load_data()?); } Ok(inner) } fn lock_for_read(&self) -> Result>, Error> { loop { let inner = self.data.read(); if inner.data.is_some() { return Ok(inner); } drop(inner); let _ = self.lock_for_write()?; } } fn load_data(&self) -> Result>, Error> { let mut f = BufReader::new(File::open(&self.path)?); Ok(rmp_serde::from_read(&mut f).map_err(|e| { eprintln!("File {} is broken: {}", self.path.display(), e); e })?) } pub fn set_(&self, key: K, value: &T) -> Result<(), Error> { let mut e = DeflateEncoder::new(Vec::new(), Compression::best()); rmp_serde::encode::write_named(&mut e, value)?; let compr = e.finish()?; debug_assert!(Self::ungz(&compr).is_ok()); // sanity check let mut w = self.lock_for_write()?; let compr = compr.into_boxed_slice(); match w.data.as_mut().unwrap().entry(key) { Entry::Vacant(e) => { e.insert(compr); }, Entry::Occupied(mut e) => { if e.get() == &compr { return Ok(()); } e.insert(compr); }, } w.writes += 1; if w.writes >= w.next_autosave { w.writes = 0; w.next_autosave *= 2; drop(w); // unlock writes let d = self.lock_for_read()?; self.save_unlocked(&d)?; } Ok(()) } pub fn delete(&self, key: &Q) -> Result<(), Error> where K: Borrow, Q: Eq + Hash + ?Sized { let mut d = self.lock_for_write()?; if d.data.as_mut().unwrap().remove(key).is_some() { d.writes += 1; } Ok(()) } pub fn get(&self, key: &Q) -> Result, Error> where K: Borrow, Q: Eq + Hash + std::fmt::Display + ?Sized { let kw = self.lock_for_read()?; Ok(match kw.data.as_ref().unwrap().get(key) { Some(gz) => Some(Self::ungz(gz).map_err(|e| { eprintln!("ungz of {} failed in {}", key, self.path.display()); drop(kw); let _ = self.delete(key); e })?), None => None, }) } fn ungz(data: &[u8]) -> Result { let ungz = DeflateDecoder::new(data); Ok(rmp_serde::decode::from_read(ungz)?) } pub fn save(&self) -> Result<(), Error> { let mut data = self.data.write(); if data.writes > 0 { self.save_unlocked(&data)?; data.data = None; // Flush mem } Ok(()) } fn save_unlocked(&self, d: &Inner) -> Result<(), Error> { if let Some(data) = d.data.as_ref() { let tmp_path = NamedTempFile::new_in(self.path.parent().expect("tmp"))?; let mut file = BufWriter::new(File::create(&tmp_path)?); rmp_serde::encode::write(&mut file, data)?; tmp_path.persist(&self.path).map_err(|e| e.error)?; } Ok(()) } #[inline] pub async fn get_json(&self, key: &Q, url: impl AsRef, on_miss: impl FnOnce(B) -> Option) -> Result, Error> where B: for<'a> Deserialize<'a>, K: Borrow + for<'a> From<&'a Q>, Q: Eq + Hash + std::fmt::Display + ?Sized { if let Some(res) = self.get(key)? { return Ok(Some(res)); } if self.cache_only { return Ok(None); } let _s = self.sem.acquire().await; let data = Box::pin(SimpleCache::fetch(url.as_ref())).await?; match serde_json::from_slice(&data) { Ok(res) => { let res = on_miss(res); if let Some(ref res) = res { self.set(key, res)? } Ok(res) }, Err(parse) => Err(Error::Parse(parse, data)), } } } impl Drop for TempCache { fn drop(&mut self) { let d = self.data.read(); if d.writes > 0 { if let Err(err) = self.save_unlocked(&d) { eprintln!("Temp db save failed: {}", err); } } } } #[test] fn kvtest() { let tmp: TempCache<(String, String)> = TempCache::new("/tmp/rmptest.bin").unwrap(); tmp.set("hello", &("world".to_string(), "etc".to_string())).unwrap(); let res = tmp.get("hello").unwrap().unwrap(); drop(tmp); assert_eq!(res, ("world".to_string(), "etc".to_string())); let tmp2: TempCache<(String, String)> = TempCache::new("/tmp/rmptest.bin").unwrap(); let res2 = tmp2.get("hello").unwrap().unwrap(); assert_eq!(res2, ("world".to_string(), "etc".to_string())); }