From ae8af71886edd8f5cdcd812a839267ee973c3864 Mon Sep 17 00:00:00 2001 From: Scott Kuhl Date: Fri, 7 Jan 2022 11:35:37 -0500 Subject: Gracefully exit if firewall process receives Ctrl+C/SIGINT. Typically sshuttle exits by having the main sshuttle client process terminated. This closes file descriptors which the firewall process then sees and uses as a cue to cleanup the firewall rules. The firewall process ignored SIGINT/SIGTERM signals and used setsid() to prevent Ctrl+C from sending signals to the firewall process. This patch makes the firewall process accept SIGINT/SIGTERM signals and then in turn sends a SIGINT signal to the main sshuttle client process which then triggers a regular shutdown as described above. This allows a user to manually send a SIGINT/SIGTERM to either sshuttle process and have it exit gracefully. It also is needed if setsid() fails (known to occur if sudo's use_pty option is used) and then the Ctrl+C SIGINT signal goes to the firewall process. The PID of the sshuttle client process is sent to the firewall process. Using os.getppid() in the firewall process doesn't correctly return the sshuttle client PID. --- sshuttle/client.py | 4 ++-- sshuttle/firewall.py | 31 +++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) (limited to 'sshuttle') diff --git a/sshuttle/client.py b/sshuttle/client.py index f9bf04b..abc3836 100644 --- a/sshuttle/client.py +++ b/sshuttle/client.py @@ -298,8 +298,8 @@ class FirewallClient: else: user = b'%d' % self.user - self.pfile.write(b'GO %d %s %s\n' % - (udp, user, bytes(self.tmark, 'ascii'))) + self.pfile.write(b'GO %d %s %s %d\n' % + (udp, user, bytes(self.tmark, 'ascii'), os.getpid())) self.pfile.flush() line = self.pfile.readline() diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py index d3806cd..fb9471d 100644 --- a/sshuttle/firewall.py +++ b/sshuttle/firewall.py @@ -13,7 +13,7 @@ from sshuttle.helpers import debug1, debug2, Fatal from sshuttle.methods import get_auto_method, get_method HOSTSFILE = '/etc/hosts' - +sshuttle_pid = None def rewrite_etc_hosts(hostmap, port): BAKFILE = '%s.sbak' % HOSTSFILE @@ -55,6 +55,23 @@ def restore_etc_hosts(hostmap, port): debug2('undoing /etc/hosts changes.') rewrite_etc_hosts({}, port) +def firewall_exit(signum, frame): + # The typical sshuttle exit is that the main sshuttle process + # exits, closes file descriptors it uses, and the firewall process + # notices that it can't read from stdin anymore and exits + # (cleaning up firewall rules). + # + # However, in some cases, Ctrl+C might get sent to the firewall + # process. This might caused if someone manually tries to kill the + # firewall process, or if sshuttle was started using sudo's use_pty option + # and they try to exit by pressing Ctrl+C. Here, we forward the + # Ctrl+C/SIGINT to the main sshuttle process which should trigger + # the typical exit process as described above. + global sshuttle_pid + if sshuttle_pid: + debug1("Relaying SIGINT to sshuttle process %d\n" % sshuttle_pid) + os.kill(sshuttle_pid, signal.SIGINT) + # Isolate function that needs to be replaced for tests def setup_daemon(): @@ -65,8 +82,8 @@ def setup_daemon(): # disappears; we still have to clean up. signal.signal(signal.SIGHUP, signal.SIG_IGN) signal.signal(signal.SIGPIPE, signal.SIG_IGN) - signal.signal(signal.SIGTERM, signal.SIG_IGN) - signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, firewall_exit) + signal.signal(signal.SIGINT, firewall_exit) # ctrl-c shouldn't be passed along to me. When the main sshuttle dies, # I'll die automatically. @@ -230,12 +247,14 @@ def main(method_name, syslog): raise Fatal('expected GO but got %r' % line) _, _, args = line.partition(" ") - udp, user, tmark = args.strip().split(" ", 2) + global sshuttle_pid + udp, user, tmark, sshuttle_pid = args.strip().split(" ", 3) udp = bool(int(udp)) + sshuttle_pid = int(sshuttle_pid) if user == '-': user = None - debug2('Got udp: %r, user: %r, tmark: %s' % - (udp, user, tmark)) + debug2('Got udp: %r, user: %r, tmark: %s, sshuttle_pid: %d' % + (udp, user, tmark, sshuttle_pid)) subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6] nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6] -- cgit v1.2.3