summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Sam Tay <sam.chong.tay@gmail.com> 2020-06-18 14:50:02 -0700
committer: Sam Tay <sam.chong.tay@gmail.com> 2020-06-18 14:50:02 -0700
commit: 412676f8c99a93ee879c9c127a58b32dae50cdfa (patch)
tree: ccf7316192626c3c7929973d9d64cbef042657d9
parent: 5f88657a75c4443ba93936e0f14bb3be0435fd41 (diff)
Allow searching multiple SE sites at once
-rw-r--r-- src/cli.rs | 15
-rw-r--r-- src/config.rs | 6
-rw-r--r-- src/error.rs | 2
-rw-r--r-- src/main.rs | 4
-rw-r--r-- src/stackexchange.rs | 53
5 files changed, 58 insertions, 22 deletions
diff --git a/src/cli.rs b/src/cli.rs
index 4ca00d0..7d946d0 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -17,6 +17,7 @@ pub struct Opts {
pub fn get_opts() -> Result<Opts> {
let config = config::user_config()?;
let limit = &config.limit.to_string();
+ let sites = &config.sites.join(";");
let matches = App::new("so")
.setting(AppSettings::ColoredHelp)
.version(clap::crate_version!())
@@ -43,11 +44,11 @@ pub fn get_opts() -> Result<Opts> {
Arg::with_name("site")
.long("site")
.short("s")
- .multiple(false) // TODO sites plural
+ .multiple(true)
.number_of_values(1)
.takes_value(true)
- .default_value(&config.site)
- .help("StackExchange site code to search"), // TODO sites plural
+ .default_value(sites)
+ .help("StackExchange site code to search"),
)
.arg(
Arg::with_name("limit")
@@ -92,7 +93,13 @@ pub fn get_opts() -> Result<Opts> {
config: Config {
// these unwraps are safe via clap default values & validators
limit: matches.value_of("limit").unwrap().parse::<u16>().unwrap(),
- site: matches.value_of("site").unwrap().to_string(), // TODO values_of
+ sites: matches
+ .values_of("site")
+ .unwrap()
+ .map(|s| s.split(';'))
+ .flatten()
+ .map(String::from)
+ .collect(),
api_key: matches
.value_of("set-api-key")
.map(String::from)
diff --git a/src/config.rs b/src/config.rs
index af82885..c86e0ad 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -7,12 +7,12 @@ use std::path::PathBuf;
use crate::error::{Error, Result};
use crate::utils;
-#[derive(Deserialize, Serialize, Debug)]
+#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Config {
pub api_key: Option<String>,
pub limit: u16,
pub lucky: bool,
- pub site: String,
+ pub sites: Vec<String>,
}
// TODO make a friender config file, like the colors.toml below
@@ -22,7 +22,7 @@ impl Default for Config {
api_key: None,
limit: 20,
lucky: true,
- site: String::from("stackoverflow"),
+ sites: vec![String::from("stackoverflow")],
}
}
}
diff --git a/src/error.rs b/src/error.rs
index 86fb55f..d104594 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -16,6 +16,8 @@ pub enum Error {
SerdeYaml(#[from] serde_yaml::Error),
#[error("IO error: {0}")]
IO(#[from] std::io::Error),
+ #[error("Futures Join error : {0}")]
+ JoinError(#[from] tokio::task::JoinError),
#[error("File `{}` is malformed; try removing it", .0.display())]
MalformedFile(PathBuf),
#[error("Lacking {0:?} permissions on `{}`", .1.display())]
diff --git a/src/main.rs b/src/main.rs
index e7c0414..c900765 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -36,7 +36,7 @@ async fn main() -> Result<(), Error> {
async fn run(skin: &mut MadSkin) -> Result<(), Error> {
let opts = cli::get_opts()?;
let config = opts.config;
- let site = &config.site;
+ let sites = &config.sites;
let lucky = config.lucky;
let mut ls = LocalStorage::new()?;
@@ -62,7 +62,7 @@ async fn run(skin: &mut MadSkin) -> Result<(), Error> {
return Ok(());
}
- if !ls.validate_site(site).await? {
+ if let Some(site) = ls.find_invalid_site(sites).await? {
print_error!(skin, "$0 is not a valid StackExchange site.\n\n", site)?;
// TODO should only use inline for single lines; use termimad::text stuff
print_notice!(
diff --git a/src/stackexchange.rs b/src/stackexchange.rs
index f86c8c0..77e8647 100644
--- a/src/stackexchange.rs
+++ b/src/stackexchange.rs
@@ -1,3 +1,4 @@
+use futures::stream::StreamExt;
use reqwest::Client;
use reqwest::Url;
use serde::{Deserialize, Serialize};
@@ -21,9 +22,13 @@ const SE_FILTER: &str = ".DND5X2VHHUH8HyJzpjo)5NvdHI3w6auG";
/// Pagesize when fetching all SE sites. Should be good for many years...
const SE_SITES_PAGESIZE: u16 = 10000;
+/// Limit on concurrent requests (gets passed to `buffer_unordered`)
+const CONCURRENT_REQUESTS_LIMIT: usize = 8;
+
/// This structure allows interacting with parts of the StackExchange
/// API, using the `Config` struct to determine certain API settings and options.
// TODO should my se structs have &str instead of String?
+#[derive(Clone)]
pub struct StackExchange {
client: Client,
config: Config,
@@ -32,7 +37,7 @@ pub struct StackExchange {
/// This structure allows interacting with locally cached StackExchange metadata.
pub struct LocalStorage {
- sites: Option<Vec<Site>>, // TODO this should be a hashmap!
+ sites: Option<Vec<Site>>,
filename: PathBuf,
}
@@ -85,7 +90,6 @@ impl StackExchange {
}
}
- // TODO also return a future with the rest of the questions
/// Search query at stack exchange and get the top answer body
pub async fn search_lucky(&self) -> Result<String> {
Ok(self
@@ -106,11 +110,29 @@ impl StackExchange {
self.search_advanced(self.config.limit).await
}
- /// Search against the search/advanced endpoint with a given query.
- /// Only fetches questions that have at least one answer.
- /// TODO async
- /// TODO parallel requests over multiple sites
+ /// Parallel searches against the search/advanced endpoint across all configured sites
async fn search_advanced(&self, limit: u16) -> Result<Vec<Question>> {
+ let results = futures::stream::iter(self.config.sites.clone())
+ .map(|site| {
+ let clone = self.clone();
+ tokio::spawn(async move {
+ let clone = &clone;
+ clone.search_advanced_site(&site, limit).await
+ })
+ })
+ .buffer_unordered(CONCURRENT_REQUESTS_LIMIT)
+ .collect::<Vec<_>>()
+ .await;
+ results
+ .into_iter()
+ .map(|r| r.map_err(Error::from).and_then(|x| x))
+ .collect::<Result<Vec<Vec<_>>>>()
+ .map(|v| v.into_iter().flatten().collect())
+ }
+
+ /// Search against the site's search/advanced endpoint with a given query.
+ /// Only fetches questions that have at least one answer.
+ async fn search_advanced_site(&self, site: &str, limit: u16) -> Result<Vec<Question>> {
Ok(self
.client
.get(stackexchange_url("search/advanced"))
@@ -119,6 +141,7 @@ impl StackExchange {
.query(&[
("q", self.query.as_str()),
("pagesize", &limit.to_string()),
+ ("site", site),
("page", "1"),
("answers", "1"),
("order", "desc"),
@@ -140,10 +163,9 @@ impl StackExchange {
fn get_default_opts(&self) -> HashMap<&str, &str> {
let mut params = HashMap::new();
- params.insert("site", self.config.site.as_str());
- params.insert("filter", &SE_FILTER);
+ params.insert("filter", SE_FILTER);
if let Some(key) = &self.config.api_key {
- params.insert("key", key.as_str());
+ params.insert("key", &key);
}
params
}
@@ -162,7 +184,6 @@ impl LocalStorage {
// TODO inform user if we are downloading
pub async fn sites(&mut self) -> Result<&Vec<Site>> {
- // Stop once Option ~ Some or Result ~ Err
if self.sites.is_none() && !self.fetch_local_sites()? {
self.fetch_remote_sites().await?;
}
@@ -177,12 +198,18 @@ impl LocalStorage {
self.fetch_remote_sites().await
}
- pub async fn validate_site(&mut self, site_code: &str) -> Result<bool> {
- Ok(self
+ // TODO is this HM worth it? Probably only will ever have < 10 site codes to search...
+ pub async fn find_invalid_site<'a, 'b>(
+ &'b mut self,
+ site_codes: &'a [String],
+ ) -> Result<Option<&'a String>> {
+ let hm: HashMap<&str, ()> = self
.sites()
.await?
.iter()
- .any(|site| site.api_site_parameter == *site_code))
+ .map(|site| (site.api_site_parameter.as_str(), ()))
+ .collect();
+ Ok(site_codes.iter().find(|s| !hm.contains_key(&s.as_str())))
}
fn fetch_local_sites(&mut self) -> Result<bool> {