summaryrefslogtreecommitdiffstats
path: root/sshuttle
diff options
context:
space:
mode:
authorScott Kuhl <kuhl@mtu.edu>2022-01-07 11:35:37 -0500
committerScott Kuhl <kuhl@mtu.edu>2022-01-07 11:52:39 -0500
commitae8af71886edd8f5cdcd812a839267ee973c3864 (patch)
treef5150dd49cf2be40115b01bd0ee268f5e954b45f /sshuttle
parentae1faa7fa1ea5c35117ab3a275ba7cf8e96d7a12 (diff)
Gracefully exit if firewall process receives Ctrl+C/SIGINT.
Typically sshuttle exits by having the main sshuttle client process terminated. This closes file descriptors which the firewall process then sees and uses as a cue to cleanup the firewall rules. The firewall process ignored SIGINT/SIGTERM signals and used setsid() to prevent Ctrl+C from sending signals to the firewall process. This patch makes the firewall process accept SIGINT/SIGTERM signals and then in turn sends a SIGINT signal to the main sshuttle client process which then triggers a regular shutdown as described above. This allows a user to manually send a SIGINT/SIGTERM to either sshuttle process and have it exit gracefully. It also is needed if setsid() fails (known to occur if sudo's use_pty option is used) and then the Ctrl+C SIGINT signal goes to the firewall process. The PID of the sshuttle client process is sent to the firewall process. Using os.getppid() in the firewall process doesn't correctly return the sshuttle client PID.
Diffstat (limited to 'sshuttle')
-rw-r--r--sshuttle/client.py4
-rw-r--r--sshuttle/firewall.py31
2 files changed, 27 insertions, 8 deletions
diff --git a/sshuttle/client.py b/sshuttle/client.py
index f9bf04b..abc3836 100644
--- a/sshuttle/client.py
+++ b/sshuttle/client.py
@@ -298,8 +298,8 @@ class FirewallClient:
else:
user = b'%d' % self.user
- self.pfile.write(b'GO %d %s %s\n' %
- (udp, user, bytes(self.tmark, 'ascii')))
+ self.pfile.write(b'GO %d %s %s %d\n' %
+ (udp, user, bytes(self.tmark, 'ascii'), os.getpid()))
self.pfile.flush()
line = self.pfile.readline()
diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py
index d3806cd..fb9471d 100644
--- a/sshuttle/firewall.py
+++ b/sshuttle/firewall.py
@@ -13,7 +13,7 @@ from sshuttle.helpers import debug1, debug2, Fatal
from sshuttle.methods import get_auto_method, get_method
HOSTSFILE = '/etc/hosts'
-
+sshuttle_pid = None
def rewrite_etc_hosts(hostmap, port):
BAKFILE = '%s.sbak' % HOSTSFILE
@@ -55,6 +55,23 @@ def restore_etc_hosts(hostmap, port):
debug2('undoing /etc/hosts changes.')
rewrite_etc_hosts({}, port)
+def firewall_exit(signum, frame):
+ # The typical sshuttle exit is that the main sshuttle process
+ # exits, closes file descriptors it uses, and the firewall process
+ # notices that it can't read from stdin anymore and exits
+ # (cleaning up firewall rules).
+ #
+ # However, in some cases, Ctrl+C might get sent to the firewall
+ # process. This might caused if someone manually tries to kill the
+ # firewall process, or if sshuttle was started using sudo's use_pty option
+ # and they try to exit by pressing Ctrl+C. Here, we forward the
+ # Ctrl+C/SIGINT to the main sshuttle process which should trigger
+ # the typical exit process as described above.
+ global sshuttle_pid
+ if sshuttle_pid:
+ debug1("Relaying SIGINT to sshuttle process %d\n" % sshuttle_pid)
+ os.kill(sshuttle_pid, signal.SIGINT)
+
# Isolate function that needs to be replaced for tests
def setup_daemon():
@@ -65,8 +82,8 @@ def setup_daemon():
# disappears; we still have to clean up.
signal.signal(signal.SIGHUP, signal.SIG_IGN)
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
- signal.signal(signal.SIGTERM, signal.SIG_IGN)
- signal.signal(signal.SIGINT, signal.SIG_IGN)
+ signal.signal(signal.SIGTERM, firewall_exit)
+ signal.signal(signal.SIGINT, firewall_exit)
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies,
# I'll die automatically.
@@ -230,12 +247,14 @@ def main(method_name, syslog):
raise Fatal('expected GO but got %r' % line)
_, _, args = line.partition(" ")
- udp, user, tmark = args.strip().split(" ", 2)
+ global sshuttle_pid
+ udp, user, tmark, sshuttle_pid = args.strip().split(" ", 3)
udp = bool(int(udp))
+ sshuttle_pid = int(sshuttle_pid)
if user == '-':
user = None
- debug2('Got udp: %r, user: %r, tmark: %s' %
- (udp, user, tmark))
+ debug2('Got udp: %r, user: %r, tmark: %s, sshuttle_pid: %d' %
+ (udp, user, tmark, sshuttle_pid))
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]