diff options
author | cyqsimon <28627918+cyqsimon@users.noreply.github.com> | 2023-10-21 22:14:46 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-21 22:14:46 +0800 |
commit | 6fa77d29917452f682fefa5df4db72c3d697661e (patch) | |
tree | 4f25976c53821c0de0f0fa9e5e93aecf17eb5482 | |
parent | 89e1140bea689c5790029bd02aea74b6b5ac0bea (diff) |
Refactor `OsInputOutput` (combine interfaces & frames into single Vec) (#310)
* Refactor `OsInputOutput` (combine interfaces & frames into single Vec)
* Add note on handling a separate failure case
* Reduce code duplication
-rw-r--r-- | src/main.rs | 6 | ||||
-rw-r--r-- | src/os/shared.rs | 230 | ||||
-rw-r--r-- | src/tests/cases/test_utils.rs | 11 | ||||
-rw-r--r-- | src/tests/cases/ui.rs | 22 | ||||
-rw-r--r-- | src/tests/fakes/fake_input.rs | 7 |
5 files changed, 131 insertions, 145 deletions
diff --git a/src/main.rs b/src/main.rs index 87c5cf6..ade4b9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -93,8 +93,7 @@ pub struct OpenSockets { } pub struct OsInputOutput { - pub network_interfaces: Vec<NetworkInterface>, - pub network_frames: Vec<Box<dyn DataLinkReceiver>>, + pub interfaces_with_frames: Vec<(NetworkInterface, Box<dyn DataLinkReceiver>)>, pub get_open_sockets: fn() -> OpenSockets, pub terminal_events: Box<dyn Iterator<Item = Event> + Send>, pub dns_client: Option<dns::Client>, @@ -281,9 +280,8 @@ where active_threads.push(terminal_event_handler); let sniffer_threads = os_input - .network_interfaces + .interfaces_with_frames .into_iter() - .zip(os_input.network_frames) .map(|(iface, frames)| { let name = format!("sniffing_handler_{}", iface.name); let running = running.clone(); diff --git a/src/os/shared.rs b/src/os/shared.rs index 5383db3..ed97878 100644 --- a/src/os/shared.rs +++ b/src/os/shared.rs @@ -4,11 +4,13 @@ use std::{ time, }; +use anyhow::{anyhow, bail}; use crossterm::event::{read, Event}; +use itertools::Itertools; use pnet::datalink::{self, Channel::Ethernet, Config, DataLinkReceiver, NetworkInterface}; use tokio::runtime::Runtime; -use crate::{network::dns, os::errors::GetInterfaceError, OsInputOutput}; +use crate::{mt_log, network::dns, os::errors::GetInterfaceError, OsInputOutput}; #[cfg(target_os = "linux")] use crate::os::linux::get_open_sockets; @@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> { } fn create_write_to_stdout() -> Box<dyn FnMut(String) + Send> { + let mut stdout = io::stdout(); Box::new({ - let mut stdout = io::stdout(); move |output: String| { writeln!(stdout, "{}", output).unwrap(); } }) } -#[derive(Debug)] -pub struct UserErrors { - permission: Option<String>, - other: Option<String>, -} - -pub fn collect_errors<'a, I>(network_frames: I) -> String -where - I: Iterator< - Item = ( - &'a NetworkInterface, - Result<Box<dyn DataLinkReceiver>, GetInterfaceError>, - ), - >, -{ - let errors = network_frames.fold( - UserErrors { - permission: None, - other: None, - }, - |acc, (_, elem)| { - if let Some(iface_error) = elem.err() { - match iface_error { - GetInterfaceError::PermissionError(interface_name) => { - if let Some(prev_interface) = acc.permission { - return UserErrors { - permission: Some(format!("{prev_interface}, {interface_name}")), - ..acc - }; - } else { - return UserErrors { - permission: Some(interface_name), - ..acc - }; - } - } - error => { - if let Some(prev_errors) = acc.other { - return UserErrors { - other: Some(format!("{prev_errors} \n {error}")), - ..acc - }; - } else { - return UserErrors { - other: Some(format!("{error}")), - ..acc - }; - } - } - }; - } - acc - }, - ); - if let Some(interface_name) = errors.permission { - if let Some(other_errors) = errors.other { - format!( - "\n\n{interface_name}: {} \nAdditional Errors: \n {other_errors}", - eperm_message(), - ) - } else { - format!("\n\n{interface_name}: {}", eperm_message()) - } - } else { - let other_errors = errors - .other - .expect("asked to collect errors but found no errors"); - format!("\n\n {other_errors}") - } -} - pub fn get_input( interface_name: Option<&str>, resolve: bool, dns_server: Option<Ipv4Addr>, ) -> anyhow::Result<OsInputOutput> { - let network_interfaces = if let Some(name) = interface_name { - match get_interface(name) { - Some(interface) => vec![interface], - None => { - anyhow::bail!("Cannot find interface {name}"); - // the homebrew formula relies on this wording, please be careful when changing - } - } - } else { - datalink::interfaces() - }; - - #[cfg(target_os = "windows")] - let network_frames = network_interfaces - .iter() - .filter(|iface| !iface.ips.is_empty()) - .map(|iface| (iface, get_datalink_channel(iface))); - #[cfg(not(target_os = "windows"))] - let network_frames = network_interfaces - .iter() - .filter(|iface| iface.is_up() && !iface.ips.is_empty()) - .map(|iface| (iface, get_datalink_channel(iface))); - - let (available_network_frames, network_interfaces) = { - let network_frames = network_frames.clone(); - let mut available_network_frames = Vec::new(); - let mut available_interfaces: Vec<NetworkInterface> = Vec::new(); - for (iface, rx) in network_frames.filter_map(|(iface, channel)| { - if let Ok(rx) = channel { - Some((iface, rx)) + // get the user's requested interface, if any + // IDEA: allow requesting multiple interfaces + let requested_interfaces = interface_name + .map(|name| get_interface(name).ok_or_else(|| anyhow!("Cannot find interface {name}"))) + .transpose()? + .map(|interface| vec![interface]); + + // take the user's requested interfaces (or all interfaces), and filter for up ones + let available_interfaces = requested_interfaces + .unwrap_or_else(datalink::interfaces) + .into_iter() + .filter(|interface| { + // see https://github.com/libpnet/libpnet/issues/564 + let keep = if cfg!(target_os = "windows") { + !interface.ips.is_empty() } else { - None + interface.is_up() && !interface.ips.is_empty() + }; + if !keep { + mt_log!(debug, "{} is down. Skipping it.", interface.name); } - }) { - available_interfaces.push(iface.clone()); - available_network_frames.push(rx); - } - (available_network_frames, available_interfaces) - }; + keep + }) + .collect_vec(); - if available_network_frames.is_empty() { - let all_errors = collect_errors(network_frames.clone()); - if !all_errors.is_empty() { - anyhow::bail!(all_errors); - } + // bail if no interfaces are up + if available_interfaces.is_empty() { + bail!("Failed to find any network interface to listen on."); + } - anyhow::bail!("Failed to find any network interface to listen on."); + // try to get a frame receiver for each interface + let interfaces_with_frames_res = available_interfaces + .into_iter() + .map(|interface| { + let frames_res = get_datalink_channel(&interface); + (interface, frames_res) + }) + .collect_vec(); + + // warn for all frame receivers we failed to acquire + interfaces_with_frames_res + .iter() + .filter_map(|(interface, frames_res)| frames_res.as_ref().err().map(|err| (interface, err))) + .for_each(|(interface, err)| { + mt_log!( + warn, + "Failed to acquire a frame receiver for {}: {err}", + interface.name + ) + }); + + // bail if all of them fail + // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle + // that failure mode separately, which we already have + if interfaces_with_frames_res + .iter() + .all(|(_, frames)| frames.is_err()) + { + let (permission_err_interfaces, other_errs) = interfaces_with_frames_res.iter().fold( + (vec![], vec![]), + |(mut perms, mut others), (_, res)| { + match res { + Ok(_) => (), + Err(GetInterfaceError::PermissionError(interface)) => { + perms.push(interface.as_str()) + } + Err(GetInterfaceError::OtherError(err)) => others.push(err.as_str()), + } + (perms, others) + }, + ); + + let err_msg = match (permission_err_interfaces.is_empty(), other_errs.is_empty()) { + (false, false) => format!( + "\n\n{}: {}\nAdditional errors:\n{}", + permission_err_interfaces.join(", "), + eperm_message(), + other_errs.join("\n") + ), + (false, true) => format!( + "\n\n{}: {}", + permission_err_interfaces.join(", "), + eperm_message() + ), + (true, false) => format!("\n\n{}", other_errs.join("\n")), + (true, true) => unreachable!("Found no errors in error handling code path."), + }; + bail!(err_msg); } - let keyboard_events = Box::new(TerminalEvents); - let write_to_stdout = create_write_to_stdout(); + // filter out interfaces for which we failed to acquire a frame receiver + let interfaces_with_frames = interfaces_with_frames_res + .into_iter() + .filter_map(|(interface, res)| res.ok().map(|frames| (interface, frames))) + .collect(); + let dns_client = if resolve { let runtime = Runtime::new()?; - let resolver = match runtime.block_on(dns::Resolver::new(dns_server)) { - Ok(resolver) => resolver, - Err(err) => anyhow::bail!( - "Could not initialize the DNS resolver. Are you offline?\n\nReason: {err:?}" - ), - }; + let resolver = runtime + .block_on(dns::Resolver::new(dns_server)) + .map_err(|err| { + anyhow!("Could not initialize the DNS resolver. Are you offline?\n\nReason: {err}") + })?; let dns_client = dns::Client::new(resolver, runtime)?; Some(dns_client) } else { None }; + let write_to_stdout = create_write_to_stdout(); + Ok(OsInputOutput { - network_interfaces, - network_frames: available_network_frames, + interfaces_with_frames, get_open_sockets, - terminal_events: keyboard_events, + terminal_events: Box::new(TerminalEvents), dns_client, write_to_stdout, }) diff --git a/src/tests/cases/test_utils.rs b/src/tests/cases/test_utils.rs index bed4642..58ef211 100644 --- a/src/tests/cases/test_utils.rs +++ b/src/tests/cases/test_utils.rs @@ -14,8 +14,8 @@ use rstest::fixture; use crate::{ network::dns::Client, tests::fakes::{ - create_fake_dns_client, get_interfaces, get_open_sockets, NetworkFrames, TerminalEvent, - TerminalEvents, TestBackend, + create_fake_dns_client, get_interfaces_with_frames, get_open_sockets, NetworkFrames, + TerminalEvent, TerminalEvents, TestBackend, }, Opt, OsInputOutput, }; @@ -248,11 +248,13 @@ pub fn os_input_output_dns( } pub fn os_input_output_factory( - network_frames: Vec<Box<dyn DataLinkReceiver>>, + network_frames: impl IntoIterator<Item = Box<dyn DataLinkReceiver>>, stdout: Option<Arc<Mutex<Vec<u8>>>>, dns_client: Option<Client>, keyboard_events: Box<dyn Iterator<Item = Event> + Send>, ) -> OsInputOutput { + let interfaces_with_frames = get_interfaces_with_frames(network_frames); + let write_to_stdout: Box<dyn FnMut(String) + Send> = match stdout { Some(stdout) => Box::new({ move |output: String| { @@ -264,8 +266,7 @@ pub fn os_input_output_factory( }; OsInputOutput { - network_interfaces: get_interfaces(), - network_frames, + interfaces_with_frames, get_open_sockets, terminal_events: keyboard_events, dns_client, diff --git a/src/tests/cases/ui.rs b/src/tests/cases/ui.rs index 2209284..70d7e78 100644 --- a/src/tests/cases/ui.rs +++ b/src/tests/cases/ui.rs @@ -17,7 +17,8 @@ use crate::{ sleep_and_quit_events, sleep_resize_and_quit_events, test_backend_factory, }, fakes::{ - create_fake_dns_client, get_interfaces, get_open_sockets, NetworkFrames, TerminalEvents, + create_fake_dns_client, get_interfaces_with_frames, get_open_sockets, NetworkFrames, + TerminalEvents, }, }, Opt, OsInputOutput, @@ -640,6 +641,8 @@ fn sustained_traffic_from_multiple_processes_bi_directional_total( fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50); + let interfaces_with_frames = get_interfaces_with_frames(network_frames); + let mut ips_to_hostnames = HashMap::new(); ips_to_hostnames.insert( IpAddr::V4("1.1.1.1".parse().unwrap()), @@ -657,8 +660,7 @@ fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let write_to_stdout = Box::new(move |_output: String| {}); let os_input = OsInputOutput { - network_interfaces: get_interfaces(), - network_frames, + interfaces_with_frames, get_open_sockets, terminal_events: sleep_and_quit_events(3), dns_client, @@ -678,6 +680,8 @@ fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) { fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50); + let interfaces_with_frames = get_interfaces_with_frames(network_frames); + let mut ips_to_hostnames = HashMap::new(); ips_to_hostnames.insert( IpAddr::V4("1.1.1.1".parse().unwrap()), @@ -695,8 +699,7 @@ fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let write_to_stdout = Box::new(move |_output: String| {}); let os_input = OsInputOutput { - network_interfaces: get_interfaces(), - network_frames, + interfaces_with_frames, get_open_sockets, terminal_events: sleep_and_quit_events(3), dns_client, @@ -716,6 +719,8 @@ fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) { fn no_resolve_mode(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50); + let interfaces_with_frames = get_interfaces_with_frames(network_frames); + let mut ips_to_hostnames = HashMap::new(); ips_to_hostnames.insert( IpAddr::V4("1.1.1.1".parse().unwrap()), @@ -733,8 +738,7 @@ fn no_resolve_mode(network_frames: Vec<Box<dyn DataLinkReceiver>>) { let write_to_stdout = Box::new(move |_output: String| {}); let os_input = OsInputOutput { - network_interfaces: get_interfaces(), - network_frames, + interfaces_with_frames, get_open_sockets, terminal_events: sleep_and_quit_events(3), dns_client, @@ -759,6 +763,7 @@ fn traffic_with_winch_event() { 12345, b"I am a fake tcp packet", ))]) as Box<dyn DataLinkReceiver>]; + let interfaces_with_frames = get_interfaces_with_frames(network_frames); let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50); @@ -766,8 +771,7 @@ fn traffic_with_winch_event() { let write_to_stdout = Box::new(move |_output: String| {}); let os_input = OsInputOutput { - network_interfaces: get_interfaces(), - network_frames, + interfaces_with_frames, get_open_sockets, terminal_events: sleep_resize_and_quit_events(2), dns_client, diff --git a/src/tests/fakes/fake_input.rs b/src/tests/fakes/fake_input.rs index f267b48..146a85e 100644 --- a/src/tests/fakes/fake_input.rs +++ b/src/tests/fakes/fake_input.rs @@ -7,6 +7,7 @@ use std::{ use async_trait::async_trait; use crossterm::event::Event; use ipnetwork::IpNetwork; +use itertools::Itertools; use pnet::datalink::{DataLinkReceiver, NetworkInterface}; use tokio::runtime::Runtime; @@ -159,6 +160,12 @@ pub fn get_interfaces() -> Vec<NetworkInterface> { }] } +pub fn get_interfaces_with_frames( + frames: impl IntoIterator<Item = Box<dyn DataLinkReceiver>>, +) -> Vec<(NetworkInterface, Box<dyn DataLinkReceiver>)> { + get_interfaces().into_iter().zip_eq(frames).collect() +} + pub fn create_fake_dns_client(ips_to_hosts: HashMap<IpAddr, String>) -> Option<dns::Client> { let runtime = Runtime::new().unwrap(); let dns_client = dns::Client::new(FakeResolver(ips_to_hosts), runtime).unwrap(); |