diff options
author | Haw Loeung <haw.loeung@canonical.com> | 2020-02-03 10:37:02 +1100 |
---|---|---|
committer | Brian May <brian@linuxpenguins.xyz> | 2020-02-04 07:41:29 +1100 |
commit | 13db89916afdcea30fc3e275ec2654d8037f5934 (patch) | |
tree | 32719afc2b227d2a6ee7e7e5087f662b3261db02 | |
parent | 84076f29fac33d06aaa4298a0ffb7b5468a995a4 (diff) |
Added nft_chain_exists() and fixed nft to use that
-rw-r--r-- | sshuttle/linux.py | 26 | ||||
-rw-r--r-- | sshuttle/methods/nft.py | 6 |
2 files changed, 28 insertions, 4 deletions
diff --git a/sshuttle/linux.py b/sshuttle/linux.py index 2ff59c4..76ae7e0 100644 --- a/sshuttle/linux.py +++ b/sshuttle/linux.py @@ -68,6 +68,32 @@ def nft(family, table, action, *args): raise Fatal('%r returned %d' % (argv, rv)) +def nft_chain_exists(family, table, name): + if family == socket.AF_INET: + fam = 'ip' + elif family == socket.AF_INET6: + fam = 'ip6' + else: + raise Exception('Unsupported family "%s"' % family_to_string(family)) + argv = ['nft', 'list', 'chain', fam, table, name] + debug1('>> %s\n' % ' '.join(argv)) + env = { + 'PATH': os.environ['PATH'], + 'LC_ALL': "C", + } + try: + table_exists = False + output = ssubprocess.check_output(argv, env=env, + stderr=ssubprocess.STDOUT) + for line in output.decode('ASCII').split('\n'): + if line.startswith('table %s %s ' % (fam, table)): + table_exists = True + if table_exists and ('chain %s {' % name) in line: + return True + except ssubprocess.CalledProcessError: + return False + + def nft_get_handle(expression, chain): cmd = 'nft' argv = [cmd, 'list', expression, '-a'] diff --git a/sshuttle/methods/nft.py b/sshuttle/methods/nft.py index 59266af..40f6f3b 100644 --- a/sshuttle/methods/nft.py +++ b/sshuttle/methods/nft.py @@ -1,7 +1,7 @@ import socket from sshuttle.firewall import subnet_weight from sshuttle.helpers import Fatal, log -from sshuttle.linux import nft, nft_get_handle, nonfatal +from sshuttle.linux import nft, nft_get_handle, nft_chain_exists, nonfatal from sshuttle.methods import BaseMethod @@ -28,10 +28,8 @@ class Method(BaseMethod): for chain in ['prerouting', 'postrouting', 'output']: rules = '{{ type nat hook {} priority -100; policy accept; }}' \ .format(chain) - try: + if not nft_chain_exists(family, table, chain): _nft('add chain', chain, rules) - except Fatal: - log('Chain {} already exists, ignoring\n'.format(chain)) chain = 'sshuttle-%s' % port |