summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHaw Loeung <haw.loeung@canonical.com>2020-02-03 10:37:02 +1100
committerBrian May <brian@linuxpenguins.xyz>2020-02-04 07:41:29 +1100
commit13db89916afdcea30fc3e275ec2654d8037f5934 (patch)
tree32719afc2b227d2a6ee7e7e5087f662b3261db02
parent84076f29fac33d06aaa4298a0ffb7b5468a995a4 (diff)
Added nft_chain_exists() and fixed nft to use that
-rw-r--r--sshuttle/linux.py26
-rw-r--r--sshuttle/methods/nft.py6
2 files changed, 28 insertions, 4 deletions
diff --git a/sshuttle/linux.py b/sshuttle/linux.py
index 2ff59c4..76ae7e0 100644
--- a/sshuttle/linux.py
+++ b/sshuttle/linux.py
@@ -68,6 +68,32 @@ def nft(family, table, action, *args):
raise Fatal('%r returned %d' % (argv, rv))
+def nft_chain_exists(family, table, name):
+ if family == socket.AF_INET:
+ fam = 'ip'
+ elif family == socket.AF_INET6:
+ fam = 'ip6'
+ else:
+ raise Exception('Unsupported family "%s"' % family_to_string(family))
+ argv = ['nft', 'list', 'chain', fam, table, name]
+ debug1('>> %s\n' % ' '.join(argv))
+ env = {
+ 'PATH': os.environ['PATH'],
+ 'LC_ALL': "C",
+ }
+ try:
+ table_exists = False
+ output = ssubprocess.check_output(argv, env=env,
+ stderr=ssubprocess.STDOUT)
+ for line in output.decode('ASCII').split('\n'):
+ if line.startswith('table %s %s ' % (fam, table)):
+ table_exists = True
+ if table_exists and ('chain %s {' % name) in line:
+ return True
+ except ssubprocess.CalledProcessError:
+ return False
+
+
def nft_get_handle(expression, chain):
cmd = 'nft'
argv = [cmd, 'list', expression, '-a']
diff --git a/sshuttle/methods/nft.py b/sshuttle/methods/nft.py
index 59266af..40f6f3b 100644
--- a/sshuttle/methods/nft.py
+++ b/sshuttle/methods/nft.py
@@ -1,7 +1,7 @@
import socket
from sshuttle.firewall import subnet_weight
from sshuttle.helpers import Fatal, log
-from sshuttle.linux import nft, nft_get_handle, nonfatal
+from sshuttle.linux import nft, nft_get_handle, nft_chain_exists, nonfatal
from sshuttle.methods import BaseMethod
@@ -28,10 +28,8 @@ class Method(BaseMethod):
for chain in ['prerouting', 'postrouting', 'output']:
rules = '{{ type nat hook {} priority -100; policy accept; }}' \
.format(chain)
- try:
+ if not nft_chain_exists(family, table, chain):
_nft('add chain', chain, rules)
- except Fatal:
- log('Chain {} already exists, ignoring\n'.format(chain))
chain = 'sshuttle-%s' % port