summaryrefslogtreecommitdiffstats
path: root/src/os/shared.rs
diff options
context:
space:
mode:
authorcyqsimon <28627918+cyqsimon@users.noreply.github.com>2023-10-21 22:14:46 +0800
committerGitHub <noreply@github.com>2023-10-21 22:14:46 +0800
commit6fa77d29917452f682fefa5df4db72c3d697661e (patch)
tree4f25976c53821c0de0f0fa9e5e93aecf17eb5482 /src/os/shared.rs
parent89e1140bea689c5790029bd02aea74b6b5ac0bea (diff)
Refactor `OsInputOutput` (combine interfaces & frames into single Vec) (#310)
* Refactor `OsInputOutput` (combine interfaces & frames into single Vec) * Add note on handling a separate failure case * Reduce code duplication
Diffstat (limited to 'src/os/shared.rs')
-rw-r--r--src/os/shared.rs230
1 files changed, 103 insertions, 127 deletions
diff --git a/src/os/shared.rs b/src/os/shared.rs
index 5383db3..ed97878 100644
--- a/src/os/shared.rs
+++ b/src/os/shared.rs
@@ -4,11 +4,13 @@ use std::{
time,
};
+use anyhow::{anyhow, bail};
use crossterm::event::{read, Event};
+use itertools::Itertools;
use pnet::datalink::{self, Channel::Ethernet, Config, DataLinkReceiver, NetworkInterface};
use tokio::runtime::Runtime;
-use crate::{network::dns, os::errors::GetInterfaceError, OsInputOutput};
+use crate::{mt_log, network::dns, os::errors::GetInterfaceError, OsInputOutput};
#[cfg(target_os = "linux")]
use crate::os::linux::get_open_sockets;
@@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> {
}
fn create_write_to_stdout() -> Box<dyn FnMut(String) + Send> {
+ let mut stdout = io::stdout();
Box::new({
- let mut stdout = io::stdout();
move |output: String| {
writeln!(stdout, "{}", output).unwrap();
}
})
}
-#[derive(Debug)]
-pub struct UserErrors {
- permission: Option<String>,
- other: Option<String>,
-}
-
-pub fn collect_errors<'a, I>(network_frames: I) -> String
-where
- I: Iterator<
- Item = (
- &'a NetworkInterface,
- Result<Box<dyn DataLinkReceiver>, GetInterfaceError>,
- ),
- >,
-{
- let errors = network_frames.fold(
- UserErrors {
- permission: None,
- other: None,
- },
- |acc, (_, elem)| {
- if let Some(iface_error) = elem.err() {
- match iface_error {
- GetInterfaceError::PermissionError(interface_name) => {
- if let Some(prev_interface) = acc.permission {
- return UserErrors {
- permission: Some(format!("{prev_interface}, {interface_name}")),
- ..acc
- };
- } else {
- return UserErrors {
- permission: Some(interface_name),
- ..acc
- };
- }
- }
- error => {
- if let Some(prev_errors) = acc.other {
- return UserErrors {
- other: Some(format!("{prev_errors} \n {error}")),
- ..acc
- };
- } else {
- return UserErrors {
- other: Some(format!("{error}")),
- ..acc
- };
- }
- }
- };
- }
- acc
- },
- );
- if let Some(interface_name) = errors.permission {
- if let Some(other_errors) = errors.other {
- format!(
- "\n\n{interface_name}: {} \nAdditional Errors: \n {other_errors}",
- eperm_message(),
- )
- } else {
- format!("\n\n{interface_name}: {}", eperm_message())
- }
- } else {
- let other_errors = errors
- .other
- .expect("asked to collect errors but found no errors");
- format!("\n\n {other_errors}")
- }
-}
-
pub fn get_input(
interface_name: Option<&str>,
resolve: bool,
dns_server: Option<Ipv4Addr>,
) -> anyhow::Result<OsInputOutput> {
- let network_interfaces = if let Some(name) = interface_name {
- match get_interface(name) {
- Some(interface) => vec![interface],
- None => {
- anyhow::bail!("Cannot find interface {name}");
- // the homebrew formula relies on this wording, please be careful when changing
- }
- }
- } else {
- datalink::interfaces()
- };
-
- #[cfg(target_os = "windows")]
- let network_frames = network_interfaces
- .iter()
- .filter(|iface| !iface.ips.is_empty())
- .map(|iface| (iface, get_datalink_channel(iface)));
- #[cfg(not(target_os = "windows"))]
- let network_frames = network_interfaces
- .iter()
- .filter(|iface| iface.is_up() && !iface.ips.is_empty())
- .map(|iface| (iface, get_datalink_channel(iface)));
-
- let (available_network_frames, network_interfaces) = {
- let network_frames = network_frames.clone();
- let mut available_network_frames = Vec::new();
- let mut available_interfaces: Vec<NetworkInterface> = Vec::new();
- for (iface, rx) in network_frames.filter_map(|(iface, channel)| {
- if let Ok(rx) = channel {
- Some((iface, rx))
+ // get the user's requested interface, if any
+ // IDEA: allow requesting multiple interfaces
+ let requested_interfaces = interface_name
+ .map(|name| get_interface(name).ok_or_else(|| anyhow!("Cannot find interface {name}")))
+ .transpose()?
+ .map(|interface| vec![interface]);
+
+ // take the user's requested interfaces (or all interfaces), and filter for up ones
+ let available_interfaces = requested_interfaces
+ .unwrap_or_else(datalink::interfaces)
+ .into_iter()
+ .filter(|interface| {
+ // see https://github.com/libpnet/libpnet/issues/564
+ let keep = if cfg!(target_os = "windows") {
+ !interface.ips.is_empty()
} else {
- None
+ interface.is_up() && !interface.ips.is_empty()
+ };
+ if !keep {
+ mt_log!(debug, "{} is down. Skipping it.", interface.name);
}
- }) {
- available_interfaces.push(iface.clone());
- available_network_frames.push(rx);
- }
- (available_network_frames, available_interfaces)
- };
+ keep
+ })
+ .collect_vec();
- if available_network_frames.is_empty() {
- let all_errors = collect_errors(network_frames.clone());
- if !all_errors.is_empty() {
- anyhow::bail!(all_errors);
- }
+ // bail if no interfaces are up
+ if available_interfaces.is_empty() {
+ bail!("Failed to find any network interface to listen on.");
+ }
- anyhow::bail!("Failed to find any network interface to listen on.");
+ // try to get a frame receiver for each interface
+ let interfaces_with_frames_res = available_interfaces
+ .into_iter()
+ .map(|interface| {
+ let frames_res = get_datalink_channel(&interface);
+ (interface, frames_res)
+ })
+ .collect_vec();
+
+ // warn for all frame receivers we failed to acquire
+ interfaces_with_frames_res
+ .iter()
+ .filter_map(|(interface, frames_res)| frames_res.as_ref().err().map(|err| (interface, err)))
+ .for_each(|(interface, err)| {
+ mt_log!(
+ warn,
+ "Failed to acquire a frame receiver for {}: {err}",
+ interface.name
+ )
+ });
+
+ // bail if all of them fail
+ // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle
+ // that failure mode separately, which we already have
+ if interfaces_with_frames_res
+ .iter()
+ .all(|(_, frames)| frames.is_err())
+ {
+ let (permission_err_interfaces, other_errs) = interfaces_with_frames_res.iter().fold(
+ (vec![], vec![]),
+ |(mut perms, mut others), (_, res)| {
+ match res {
+ Ok(_) => (),
+ Err(GetInterfaceError::PermissionError(interface)) => {
+ perms.push(interface.as_str())
+ }
+ Err(GetInterfaceError::OtherError(err)) => others.push(err.as_str()),
+ }
+ (perms, others)
+ },
+ );
+
+ let err_msg = match (permission_err_interfaces.is_empty(), other_errs.is_empty()) {
+ (false, false) => format!(
+ "\n\n{}: {}\nAdditional errors:\n{}",
+ permission_err_interfaces.join(", "),
+ eperm_message(),
+ other_errs.join("\n")
+ ),
+ (false, true) => format!(
+ "\n\n{}: {}",
+ permission_err_interfaces.join(", "),
+ eperm_message()
+ ),
+ (true, false) => format!("\n\n{}", other_errs.join("\n")),
+ (true, true) => unreachable!("Found no errors in error handling code path."),
+ };
+ bail!(err_msg);
}
- let keyboard_events = Box::new(TerminalEvents);
- let write_to_stdout = create_write_to_stdout();
+ // filter out interfaces for which we failed to acquire a frame receiver
+ let interfaces_with_frames = interfaces_with_frames_res
+ .into_iter()
+ .filter_map(|(interface, res)| res.ok().map(|frames| (interface, frames)))
+ .collect();
+
let dns_client = if resolve {
let runtime = Runtime::new()?;
- let resolver = match runtime.block_on(dns::Resolver::new(dns_server)) {
- Ok(resolver) => resolver,
- Err(err) => anyhow::bail!(
- "Could not initialize the DNS resolver. Are you offline?\n\nReason: {err:?}"
- ),
- };
+ let resolver = runtime
+ .block_on(dns::Resolver::new(dns_server))
+ .map_err(|err| {
+ anyhow!("Could not initialize the DNS resolver. Are you offline?\n\nReason: {err}")
+ })?;
let dns_client = dns::Client::new(resolver, runtime)?;
Some(dns_client)
} else {
None
};
+ let write_to_stdout = create_write_to_stdout();
+
Ok(OsInputOutput {
- network_interfaces,
- network_frames: available_network_frames,
+ interfaces_with_frames,
get_open_sockets,
- terminal_events: keyboard_events,
+ terminal_events: Box::new(TerminalEvents),
dns_client,
write_to_stdout,
})