summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorcyqsimon <28627918+cyqsimon@users.noreply.github.com>2023-10-21 22:14:46 +0800
committerGitHub <noreply@github.com>2023-10-21 22:14:46 +0800
commit6fa77d29917452f682fefa5df4db72c3d697661e (patch)
tree4f25976c53821c0de0f0fa9e5e93aecf17eb5482
parent89e1140bea689c5790029bd02aea74b6b5ac0bea (diff)
Refactor `OsInputOutput` (combine interfaces & frames into single Vec) (#310)
* Refactor `OsInputOutput` (combine interfaces & frames into single Vec) * Add note on handling a separate failure case * Reduce code duplication
-rw-r--r--src/main.rs6
-rw-r--r--src/os/shared.rs230
-rw-r--r--src/tests/cases/test_utils.rs11
-rw-r--r--src/tests/cases/ui.rs22
-rw-r--r--src/tests/fakes/fake_input.rs7
5 files changed, 131 insertions, 145 deletions
diff --git a/src/main.rs b/src/main.rs
index 87c5cf6..ade4b9e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -93,8 +93,7 @@ pub struct OpenSockets {
}
pub struct OsInputOutput {
- pub network_interfaces: Vec<NetworkInterface>,
- pub network_frames: Vec<Box<dyn DataLinkReceiver>>,
+ pub interfaces_with_frames: Vec<(NetworkInterface, Box<dyn DataLinkReceiver>)>,
pub get_open_sockets: fn() -> OpenSockets,
pub terminal_events: Box<dyn Iterator<Item = Event> + Send>,
pub dns_client: Option<dns::Client>,
@@ -281,9 +280,8 @@ where
active_threads.push(terminal_event_handler);
let sniffer_threads = os_input
- .network_interfaces
+ .interfaces_with_frames
.into_iter()
- .zip(os_input.network_frames)
.map(|(iface, frames)| {
let name = format!("sniffing_handler_{}", iface.name);
let running = running.clone();
diff --git a/src/os/shared.rs b/src/os/shared.rs
index 5383db3..ed97878 100644
--- a/src/os/shared.rs
+++ b/src/os/shared.rs
@@ -4,11 +4,13 @@ use std::{
time,
};
+use anyhow::{anyhow, bail};
use crossterm::event::{read, Event};
+use itertools::Itertools;
use pnet::datalink::{self, Channel::Ethernet, Config, DataLinkReceiver, NetworkInterface};
use tokio::runtime::Runtime;
-use crate::{network::dns, os::errors::GetInterfaceError, OsInputOutput};
+use crate::{mt_log, network::dns, os::errors::GetInterfaceError, OsInputOutput};
#[cfg(target_os = "linux")]
use crate::os::linux::get_open_sockets;
@@ -63,160 +65,134 @@ fn get_interface(interface_name: &str) -> Option<NetworkInterface> {
}
fn create_write_to_stdout() -> Box<dyn FnMut(String) + Send> {
+ let mut stdout = io::stdout();
Box::new({
- let mut stdout = io::stdout();
move |output: String| {
writeln!(stdout, "{}", output).unwrap();
}
})
}
-#[derive(Debug)]
-pub struct UserErrors {
- permission: Option<String>,
- other: Option<String>,
-}
-
-pub fn collect_errors<'a, I>(network_frames: I) -> String
-where
- I: Iterator<
- Item = (
- &'a NetworkInterface,
- Result<Box<dyn DataLinkReceiver>, GetInterfaceError>,
- ),
- >,
-{
- let errors = network_frames.fold(
- UserErrors {
- permission: None,
- other: None,
- },
- |acc, (_, elem)| {
- if let Some(iface_error) = elem.err() {
- match iface_error {
- GetInterfaceError::PermissionError(interface_name) => {
- if let Some(prev_interface) = acc.permission {
- return UserErrors {
- permission: Some(format!("{prev_interface}, {interface_name}")),
- ..acc
- };
- } else {
- return UserErrors {
- permission: Some(interface_name),
- ..acc
- };
- }
- }
- error => {
- if let Some(prev_errors) = acc.other {
- return UserErrors {
- other: Some(format!("{prev_errors} \n {error}")),
- ..acc
- };
- } else {
- return UserErrors {
- other: Some(format!("{error}")),
- ..acc
- };
- }
- }
- };
- }
- acc
- },
- );
- if let Some(interface_name) = errors.permission {
- if let Some(other_errors) = errors.other {
- format!(
- "\n\n{interface_name}: {} \nAdditional Errors: \n {other_errors}",
- eperm_message(),
- )
- } else {
- format!("\n\n{interface_name}: {}", eperm_message())
- }
- } else {
- let other_errors = errors
- .other
- .expect("asked to collect errors but found no errors");
- format!("\n\n {other_errors}")
- }
-}
-
pub fn get_input(
interface_name: Option<&str>,
resolve: bool,
dns_server: Option<Ipv4Addr>,
) -> anyhow::Result<OsInputOutput> {
- let network_interfaces = if let Some(name) = interface_name {
- match get_interface(name) {
- Some(interface) => vec![interface],
- None => {
- anyhow::bail!("Cannot find interface {name}");
- // the homebrew formula relies on this wording, please be careful when changing
- }
- }
- } else {
- datalink::interfaces()
- };
-
- #[cfg(target_os = "windows")]
- let network_frames = network_interfaces
- .iter()
- .filter(|iface| !iface.ips.is_empty())
- .map(|iface| (iface, get_datalink_channel(iface)));
- #[cfg(not(target_os = "windows"))]
- let network_frames = network_interfaces
- .iter()
- .filter(|iface| iface.is_up() && !iface.ips.is_empty())
- .map(|iface| (iface, get_datalink_channel(iface)));
-
- let (available_network_frames, network_interfaces) = {
- let network_frames = network_frames.clone();
- let mut available_network_frames = Vec::new();
- let mut available_interfaces: Vec<NetworkInterface> = Vec::new();
- for (iface, rx) in network_frames.filter_map(|(iface, channel)| {
- if let Ok(rx) = channel {
- Some((iface, rx))
+ // get the user's requested interface, if any
+ // IDEA: allow requesting multiple interfaces
+ let requested_interfaces = interface_name
+ .map(|name| get_interface(name).ok_or_else(|| anyhow!("Cannot find interface {name}")))
+ .transpose()?
+ .map(|interface| vec![interface]);
+
+ // take the user's requested interfaces (or all interfaces), and filter for up ones
+ let available_interfaces = requested_interfaces
+ .unwrap_or_else(datalink::interfaces)
+ .into_iter()
+ .filter(|interface| {
+ // see https://github.com/libpnet/libpnet/issues/564
+ let keep = if cfg!(target_os = "windows") {
+ !interface.ips.is_empty()
} else {
- None
+ interface.is_up() && !interface.ips.is_empty()
+ };
+ if !keep {
+ mt_log!(debug, "{} is down. Skipping it.", interface.name);
}
- }) {
- available_interfaces.push(iface.clone());
- available_network_frames.push(rx);
- }
- (available_network_frames, available_interfaces)
- };
+ keep
+ })
+ .collect_vec();
- if available_network_frames.is_empty() {
- let all_errors = collect_errors(network_frames.clone());
- if !all_errors.is_empty() {
- anyhow::bail!(all_errors);
- }
+ // bail if no interfaces are up
+ if available_interfaces.is_empty() {
+ bail!("Failed to find any network interface to listen on.");
+ }
- anyhow::bail!("Failed to find any network interface to listen on.");
+ // try to get a frame receiver for each interface
+ let interfaces_with_frames_res = available_interfaces
+ .into_iter()
+ .map(|interface| {
+ let frames_res = get_datalink_channel(&interface);
+ (interface, frames_res)
+ })
+ .collect_vec();
+
+ // warn for all frame receivers we failed to acquire
+ interfaces_with_frames_res
+ .iter()
+ .filter_map(|(interface, frames_res)| frames_res.as_ref().err().map(|err| (interface, err)))
+ .for_each(|(interface, err)| {
+ mt_log!(
+ warn,
+ "Failed to acquire a frame receiver for {}: {err}",
+ interface.name
+ )
+ });
+
+ // bail if all of them fail
+ // note that `Iterator::all` returns `true` for an empty iterator, so it is important to handle
+ // that failure mode separately, which we already have
+ if interfaces_with_frames_res
+ .iter()
+ .all(|(_, frames)| frames.is_err())
+ {
+ let (permission_err_interfaces, other_errs) = interfaces_with_frames_res.iter().fold(
+ (vec![], vec![]),
+ |(mut perms, mut others), (_, res)| {
+ match res {
+ Ok(_) => (),
+ Err(GetInterfaceError::PermissionError(interface)) => {
+ perms.push(interface.as_str())
+ }
+ Err(GetInterfaceError::OtherError(err)) => others.push(err.as_str()),
+ }
+ (perms, others)
+ },
+ );
+
+ let err_msg = match (permission_err_interfaces.is_empty(), other_errs.is_empty()) {
+ (false, false) => format!(
+ "\n\n{}: {}\nAdditional errors:\n{}",
+ permission_err_interfaces.join(", "),
+ eperm_message(),
+ other_errs.join("\n")
+ ),
+ (false, true) => format!(
+ "\n\n{}: {}",
+ permission_err_interfaces.join(", "),
+ eperm_message()
+ ),
+ (true, false) => format!("\n\n{}", other_errs.join("\n")),
+ (true, true) => unreachable!("Found no errors in error handling code path."),
+ };
+ bail!(err_msg);
}
- let keyboard_events = Box::new(TerminalEvents);
- let write_to_stdout = create_write_to_stdout();
+ // filter out interfaces for which we failed to acquire a frame receiver
+ let interfaces_with_frames = interfaces_with_frames_res
+ .into_iter()
+ .filter_map(|(interface, res)| res.ok().map(|frames| (interface, frames)))
+ .collect();
+
let dns_client = if resolve {
let runtime = Runtime::new()?;
- let resolver = match runtime.block_on(dns::Resolver::new(dns_server)) {
- Ok(resolver) => resolver,
- Err(err) => anyhow::bail!(
- "Could not initialize the DNS resolver. Are you offline?\n\nReason: {err:?}"
- ),
- };
+ let resolver = runtime
+ .block_on(dns::Resolver::new(dns_server))
+ .map_err(|err| {
+ anyhow!("Could not initialize the DNS resolver. Are you offline?\n\nReason: {err}")
+ })?;
let dns_client = dns::Client::new(resolver, runtime)?;
Some(dns_client)
} else {
None
};
+ let write_to_stdout = create_write_to_stdout();
+
Ok(OsInputOutput {
- network_interfaces,
- network_frames: available_network_frames,
+ interfaces_with_frames,
get_open_sockets,
- terminal_events: keyboard_events,
+ terminal_events: Box::new(TerminalEvents),
dns_client,
write_to_stdout,
})
diff --git a/src/tests/cases/test_utils.rs b/src/tests/cases/test_utils.rs
index bed4642..58ef211 100644
--- a/src/tests/cases/test_utils.rs
+++ b/src/tests/cases/test_utils.rs
@@ -14,8 +14,8 @@ use rstest::fixture;
use crate::{
network::dns::Client,
tests::fakes::{
- create_fake_dns_client, get_interfaces, get_open_sockets, NetworkFrames, TerminalEvent,
- TerminalEvents, TestBackend,
+ create_fake_dns_client, get_interfaces_with_frames, get_open_sockets, NetworkFrames,
+ TerminalEvent, TerminalEvents, TestBackend,
},
Opt, OsInputOutput,
};
@@ -248,11 +248,13 @@ pub fn os_input_output_dns(
}
pub fn os_input_output_factory(
- network_frames: Vec<Box<dyn DataLinkReceiver>>,
+ network_frames: impl IntoIterator<Item = Box<dyn DataLinkReceiver>>,
stdout: Option<Arc<Mutex<Vec<u8>>>>,
dns_client: Option<Client>,
keyboard_events: Box<dyn Iterator<Item = Event> + Send>,
) -> OsInputOutput {
+ let interfaces_with_frames = get_interfaces_with_frames(network_frames);
+
let write_to_stdout: Box<dyn FnMut(String) + Send> = match stdout {
Some(stdout) => Box::new({
move |output: String| {
@@ -264,8 +266,7 @@ pub fn os_input_output_factory(
};
OsInputOutput {
- network_interfaces: get_interfaces(),
- network_frames,
+ interfaces_with_frames,
get_open_sockets,
terminal_events: keyboard_events,
dns_client,
diff --git a/src/tests/cases/ui.rs b/src/tests/cases/ui.rs
index 2209284..70d7e78 100644
--- a/src/tests/cases/ui.rs
+++ b/src/tests/cases/ui.rs
@@ -17,7 +17,8 @@ use crate::{
sleep_and_quit_events, sleep_resize_and_quit_events, test_backend_factory,
},
fakes::{
- create_fake_dns_client, get_interfaces, get_open_sockets, NetworkFrames, TerminalEvents,
+ create_fake_dns_client, get_interfaces_with_frames, get_open_sockets, NetworkFrames,
+ TerminalEvents,
},
},
Opt, OsInputOutput,
@@ -640,6 +641,8 @@ fn sustained_traffic_from_multiple_processes_bi_directional_total(
fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50);
+ let interfaces_with_frames = get_interfaces_with_frames(network_frames);
+
let mut ips_to_hostnames = HashMap::new();
ips_to_hostnames.insert(
IpAddr::V4("1.1.1.1".parse().unwrap()),
@@ -657,8 +660,7 @@ fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let write_to_stdout = Box::new(move |_output: String| {});
let os_input = OsInputOutput {
- network_interfaces: get_interfaces(),
- network_frames,
+ interfaces_with_frames,
get_open_sockets,
terminal_events: sleep_and_quit_events(3),
dns_client,
@@ -678,6 +680,8 @@ fn traffic_with_host_names(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50);
+ let interfaces_with_frames = get_interfaces_with_frames(network_frames);
+
let mut ips_to_hostnames = HashMap::new();
ips_to_hostnames.insert(
IpAddr::V4("1.1.1.1".parse().unwrap()),
@@ -695,8 +699,7 @@ fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let write_to_stdout = Box::new(move |_output: String| {});
let os_input = OsInputOutput {
- network_interfaces: get_interfaces(),
- network_frames,
+ interfaces_with_frames,
get_open_sockets,
terminal_events: sleep_and_quit_events(3),
dns_client,
@@ -716,6 +719,8 @@ fn truncate_long_hostnames(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
fn no_resolve_mode(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50);
+ let interfaces_with_frames = get_interfaces_with_frames(network_frames);
+
let mut ips_to_hostnames = HashMap::new();
ips_to_hostnames.insert(
IpAddr::V4("1.1.1.1".parse().unwrap()),
@@ -733,8 +738,7 @@ fn no_resolve_mode(network_frames: Vec<Box<dyn DataLinkReceiver>>) {
let write_to_stdout = Box::new(move |_output: String| {});
let os_input = OsInputOutput {
- network_interfaces: get_interfaces(),
- network_frames,
+ interfaces_with_frames,
get_open_sockets,
terminal_events: sleep_and_quit_events(3),
dns_client,
@@ -759,6 +763,7 @@ fn traffic_with_winch_event() {
12345,
b"I am a fake tcp packet",
))]) as Box<dyn DataLinkReceiver>];
+ let interfaces_with_frames = get_interfaces_with_frames(network_frames);
let (terminal_events, terminal_draw_events, backend) = test_backend_factory(190, 50);
@@ -766,8 +771,7 @@ fn traffic_with_winch_event() {
let write_to_stdout = Box::new(move |_output: String| {});
let os_input = OsInputOutput {
- network_interfaces: get_interfaces(),
- network_frames,
+ interfaces_with_frames,
get_open_sockets,
terminal_events: sleep_resize_and_quit_events(2),
dns_client,
diff --git a/src/tests/fakes/fake_input.rs b/src/tests/fakes/fake_input.rs
index f267b48..146a85e 100644
--- a/src/tests/fakes/fake_input.rs
+++ b/src/tests/fakes/fake_input.rs
@@ -7,6 +7,7 @@ use std::{
use async_trait::async_trait;
use crossterm::event::Event;
use ipnetwork::IpNetwork;
+use itertools::Itertools;
use pnet::datalink::{DataLinkReceiver, NetworkInterface};
use tokio::runtime::Runtime;
@@ -159,6 +160,12 @@ pub fn get_interfaces() -> Vec<NetworkInterface> {
}]
}
+pub fn get_interfaces_with_frames(
+ frames: impl IntoIterator<Item = Box<dyn DataLinkReceiver>>,
+) -> Vec<(NetworkInterface, Box<dyn DataLinkReceiver>)> {
+ get_interfaces().into_iter().zip_eq(frames).collect()
+}
+
pub fn create_fake_dns_client(ips_to_hosts: HashMap<IpAddr, String>) -> Option<dns::Client> {
let runtime = Runtime::new().unwrap();
let dns_client = dns::Client::new(FakeResolver(ips_to_hosts), runtime).unwrap();