diff options
author | cyqsimon <28627918+cyqsimon@users.noreply.github.com> | 2023-10-21 22:14:46 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-21 22:14:46 +0800 |
commit | 6fa77d29917452f682fefa5df4db72c3d697661e (patch) | |
tree | 4f25976c53821c0de0f0fa9e5e93aecf17eb5482 /src/os/shared.rs | |
parent | 89e1140bea689c5790029bd02aea74b6b5ac0bea (diff) |
Refactor `OsInputOutput` (combine interfaces & frames into single Vec) (#310)
* Refactor `OsInputOutput` (combine interfaces & frames into single Vec)
* Add note on handling a separate failure case
* Reduce code duplication
Diffstat (limited to 'src/os/shared.rs')
-rw-r--r-- | src/os/shared.rs | 230 |
1 files changed, 103 insertions, 127 deletions
diff --git a/src/os/shared.rs b/src/os/shared.rs index 5383db3..ed97878 100644 --- a/src/os/shared.rs +++ b/src/os/shared.rs @@ -4,11 +4,13 @@ use std::{ time, }; +use anyhow::{anyhow, bail}; use crossterm::event::{read, Event}; +use itertools::Itertools; use pnet::datalink::{self, Channel::Ethernet, Config, DataLinkReceiver, NetworkInterface}; use tokio::runtime::Runtime; -use crate::{network::dns, os::errors::GetInterfaceError, OsInputOutput}; +use crate::{mt_log, network::dns, os::errors::GetInterfaceError, OsInputOutput}; #[cfg(target_os = "linux")] use crate::os::linux::get_open_sockets; @@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> { } fn create_write_to_stdout() -> Box<dyn FnMut(String) + Send> { + let mut stdout = io::stdout(); Box::new({ - let mut stdout = io::stdout(); move |output: String| { writeln!(stdout, "{}", output).unwrap(); } }) } -#[derive(Debug)] -pub struct UserErrors { - permission: Option<String>, - other: Option<String>, -} - -pub fn collect_errors<'a, I>(network_frames: I) -> String -where - I: Iterator< - Item = ( - &'a NetworkInterface, - Result<Box<dyn DataLinkReceiver>, GetInterfaceError>, - ), - >, -{ - let errors = network_frames.fold( - UserErrors { - permission: None, - other: None, - }, - |acc, (_, elem)| { - if let Some(iface_error) = elem.err() { - match iface_error { - GetInterfaceError::PermissionError(interface_name) => { - if let Some(prev_interface) = acc.permission { - return UserErrors { - permission: Some(format!("{prev_interface}, {interface_name}")), - ..acc - }; - } else { - return UserErrors { - permission: Some(interface_name), - ..acc - }; - } - } - error => { - if let Some(prev_errors) = acc.other { - return UserErrors { - other: Some(format!("{prev_errors} \n {error}")), - ..acc - }; - } else { - return UserErrors { - other: Some(format!("{error}")), - ..acc - }; - } - } - }; - } - acc - }, - ); - if let Some(interface_name) = errors.permission { - if let Some(other_errors) = errors.other { - format!( - "\n\n{interface_name}: {} \nAdditional Errors: \n {other_errors}", - eperm_message(), - ) - } else { - format!("\n\n{interface_name}: {}", eperm_message()) - } - } else { - let other_errors = errors - .other - .expect("asked to collect errors but found no errors"); - format!("\n\n {other_errors}") - } -} - pub fn get_input( interface_name: Option<&str>, resolve: bool, dns_server: Option<Ipv4Addr>, ) -> anyhow::Result<OsInputOutput> { - let network_interfaces = if let Some(name) = interface_name { - match get_interface(name) { - Some(interface) => vec![interface], - None => { - anyhow::bail!("Cannot find interface {name}"); - // the homebrew formula relies on this wording, please be careful when changing - } - } - } else { - datalink::interfaces() - }; - - #[cfg(target_os = "windows")] - let network_frames = network_interfaces - .iter() - .filter(|iface| !iface.ips.is_empty()) - .map(|iface| (iface, get_datalink_channel(iface))); - #[cfg(not(target_os = "windows"))] - let network_frames = network_interfaces - .iter() - .filter(|iface| iface.is_up() && !iface.ips.is_empty()) - .map(|iface| (iface, get_datalink_channel(iface))); - - let (available_network_frames, network_interfaces) = { - let network_frames = network_frames.clone(); - let mut available_network_frames = Vec::new(); - let mut available_interfaces: Vec<NetworkInterface> = Vec::new(); - for (iface, rx) in network_frames.filter_map(|(iface, channel)| { - if let Ok(rx) = channel { - Some((iface, rx)) + // get the user's requested interface, if any + // IDEA: allow requesting multiple interfaces + let requested_interfaces = interface_name + .map(|name| get_interface(name).ok_or_else(|| anyhow!("Cannot find interface {name}"))) + .transpose()? + .map(|interface| vec![interface]); + + // take the user's requested interfaces (or all interfaces), and filter for up ones + let available_interfaces = requested_interfaces + .unwrap_or_else(datalink::interfaces) + .into_iter() + .filter(|interface| { + // see https://github.com/libpnet/libpnet/issues/564 + let keep = if cfg!(target_os = "windows") { + !interface.ips.is_empty() } else { - None + interface.is_up() && !interface.ips.is_empty() + }; + if !keep { + mt_log!(debug, "{} is down. Skipping it.", interface.name); } - }) { - available_interfaces.push(iface.clone()); - available_network_frames.push(rx); - } - (available_network_frames, available_interfaces) - }; + keep + }) + .collect_vec(); - if available_network_frames.is_empty() { - let all_errors = collect_errors(network_frames.clone()); - if !all_errors.is_empty() { - anyhow::bail!(all_errors); - } + // bail if no interfaces are up + if available_interfaces.is_empty() { + bail!("Failed to find any network interface to listen on."); + } - anyhow::bail!("Failed to find any network interface to listen on."); + // try to get a frame receiver for each interface + let interfaces_with_frames_res = available_interfaces + .into_iter() + .map(|interface| { + let frames_res = get_datalink_channel(&interface); + (interface, frames_res) + }) + .collect_vec(); + + // warn for all frame receivers we failed to acquire + interfaces_with_frames_res + .iter() + .filter_map(|(interface, frames_res)| frames_res.as_ref().err().map(|err| (interface, err))) + .for_each(|(interface, err)| { + mt_log!( + warn, + "Failed to acquire a frame receiver for {}: {err}", + interface.name + ) + }); + + // bail if all of them fail + // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle + // that failure mode separately, which we already have + if interfaces_with_frames_res + .iter() + .all(|(_, frames)| frames.is_err()) + { + let (permission_err_interfaces, other_errs) = interfaces_with_frames_res.iter().fold( + (vec![], vec![]), + |(mut perms, mut others), (_, res)| { + match res { + Ok(_) => (), + Err(GetInterfaceError::PermissionError(interface)) => { + perms.push(interface.as_str()) + } + Err(GetInterfaceError::OtherError(err)) => others.push(err.as_str()), + } + (perms, others) + }, + ); + + let err_msg = match (permission_err_interfaces.is_empty(), other_errs.is_empty()) { + (false, false) => format!( + "\n\n{}: {}\nAdditional errors:\n{}", + permission_err_interfaces.join(", "), + eperm_message(), + other_errs.join("\n") + ), + (false, true) => format!( + "\n\n{}: {}", + permission_err_interfaces.join(", "), + eperm_message() + ), + (true, false) => format!("\n\n{}", other_errs.join("\n")), + (true, true) => unreachable!("Found no errors in error handling code path."), + }; + bail!(err_msg); } - let keyboard_events = Box::new(TerminalEvents); - let write_to_stdout = create_write_to_stdout(); + // filter out interfaces for which we failed to acquire a frame receiver + let interfaces_with_frames = interfaces_with_frames_res + .into_iter() + .filter_map(|(interface, res)| res.ok().map(|frames| (interface, frames))) + .collect(); + let dns_client = if resolve { let runtime = Runtime::new()?; - let resolver = match runtime.block_on(dns::Resolver::new(dns_server)) { - Ok(resolver) => resolver, - Err(err) => anyhow::bail!( - "Could not initialize the DNS resolver. Are you offline?\n\nReason: {err:?}" - ), - }; + let resolver = runtime + .block_on(dns::Resolver::new(dns_server)) + .map_err(|err| { + anyhow!("Could not initialize the DNS resolver. Are you offline?\n\nReason: {err}") + })?; let dns_client = dns::Client::new(resolver, runtime)?; Some(dns_client) } else { None }; + let write_to_stdout = create_write_to_stdout(); + Ok(OsInputOutput { - network_interfaces, - network_frames: available_network_frames, + interfaces_with_frames, get_open_sockets, - terminal_events: keyboard_events, + terminal_events: Box::new(TerminalEvents), dns_client, write_to_stdout, }) |