summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBrian May <brian@linuxpenguins.xyz>2015-12-15 11:40:55 +1100
committerBrian May <brian@linuxpenguins.xyz>2015-12-15 11:40:55 +1100
commit90654b4fb95f608ae955cee5e3befd0041ac1c4a (patch)
tree9afee6683beafe746ab6ffd36b4d4a252a23e9d2
parent6b4e36c5280f6886f413b43d76d15a5c74b897bc (diff)
Simplify selection of features
-rw-r--r--sshuttle/client.py72
-rw-r--r--sshuttle/firewall.py15
-rw-r--r--sshuttle/methods/__init__.py10
-rw-r--r--sshuttle/methods/nat.py13
-rw-r--r--sshuttle/methods/pf.py43
-rw-r--r--sshuttle/methods/tproxy.py33
-rw-r--r--sshuttle/tests/test_methods_nat.py16
-rw-r--r--sshuttle/tests/test_methods_pf.py16
-rw-r--r--sshuttle/tests/test_methods_tproxy.py19
9 files changed, 151 insertions, 86 deletions
diff --git a/sshuttle/client.py b/sshuttle/client.py
index 0dfd404..f906512 100644
--- a/sshuttle/client.py
+++ b/sshuttle/client.py
@@ -14,7 +14,7 @@ import platform
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
resolvconf_nameservers
-from sshuttle.methods import get_method
+from sshuttle.methods import get_method, Features
_extra_fd = os.open('/dev/null', os.O_RDONLY)
@@ -505,19 +505,44 @@ def main(listenip_v6, listenip_v4,
fw = FirewallClient(method_name)
- features = fw.method.get_supported_features()
+ # Get family specific subnet lists
+ if dns:
+ nslist += resolvconf_nameservers()
+
+ subnets = subnets_include + subnets_exclude # we don't care here
+ subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
+ nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
+ subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
+ nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
+
+ # Check features available
+ avail = fw.method.get_supported_features()
+ required = Features()
+
if listenip_v6 == "auto":
- if features.ipv6:
+ if avail.ipv6:
listenip_v6 = ('::1', 0)
else:
listenip_v6 = None
+ required.ipv6 = len(subnets_v6) > 0 or len(nslist_v6) > 0
+ required.udp = avail.udp
+ required.dns = len(nslist) > 0
+
+ fw.method.assert_features(required)
+
+ if required.ipv6 and listenip_v6 is None:
+ raise Fatal("IPv6 required but not listening.")
+
+ # display features enabled
+ debug1("IPv6 enabled: %r\n" % required.ipv6)
+ debug1("UDP enabled: %r\n" % required.udp)
+ debug1("DNS enabled: %r\n" % required.dns)
+
+ # bind to required ports
if listenip_v4 == "auto":
listenip_v4 = ('127.0.0.1', 0)
- udp = features.udp
- debug1("UDP enabled: %r\n" % udp)
-
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
# if both ports given, no need to search for a spare port
ports = [0, ]
@@ -536,7 +561,7 @@ def main(listenip_v6, listenip_v4,
tcp_listener = MultiListener()
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if udp:
+ if required.udp:
udp_listener = MultiListener(socket.SOCK_DGRAM)
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else:
@@ -584,10 +609,7 @@ def main(listenip_v6, listenip_v4,
udp_listener.print_listening("UDP redirector")
bound = False
- if dns or nslist:
- if dns:
- nslist += resolvconf_nameservers()
- dns = True
+ if required.dns:
# search for spare port for DNS
debug2('Binding DNS:')
ports = range(12300, 9000, -1)
@@ -628,17 +650,41 @@ def main(listenip_v6, listenip_v4,
dnsport_v4 = 0
dns_listener = None
- fw.method.check_settings(udp, dns)
+ # Last minute sanity checks.
+ # These should never fail.
+ # If these do fail, something is broken above.
+ if len(subnets_v6) > 0:
+ assert required.ipv6
+ if redirectport_v6 == 0:
+ raise Fatal("IPv6 subnets defined but not listening")
+
+ if len(nslist_v6) > 0:
+ assert required.dns
+ assert required.ipv6
+ if dnsport_v6 == 0:
+ raise Fatal("IPv6 ns servers defined but not listening")
+
+ if len(subnets_v4) > 0:
+ if redirectport_v4 == 0:
+ raise Fatal("IPv4 subnets defined but not listening")
+
+ if len(nslist_v4) > 0:
+ if dnsport_v4 == 0:
+ raise Fatal("IPv4 ns servers defined but not listening")
+
+ # setup method specific stuff on listeners
fw.method.setup_tcp_listener(tcp_listener)
if udp_listener:
fw.method.setup_udp_listener(udp_listener)
if dns_listener:
fw.method.setup_udp_listener(dns_listener)
+ # start the firewall
fw.setup(subnets_include, subnets_exclude, nslist,
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
- udp)
+ required.udp)
+ # start the client process
try:
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
python, latency_control, dns_listener,
diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py
index 7d5ece0..8d7c011 100644
--- a/sshuttle/firewall.py
+++ b/sshuttle/firewall.py
@@ -178,26 +178,23 @@ def main(method_name, syslog):
try:
debug1('firewall manager: setting up.\n')
- nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
- if port_v6 > 0:
+ nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
+
+ if len(subnets_v6) > 0 or len(nslist_v6) > 0:
debug2('firewall manager: setting up IPv6.\n')
method.setup_firewall(
port_v6, dnsport_v6, nslist_v6,
socket.AF_INET6, subnets_v6, udp)
- elif len(subnets_v6) > 0:
- debug1("IPv6 subnets defined but IPv6 disabled\n")
- nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
- if port_v4 > 0:
+ nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
+
+ if len(subnets_v4) > 0 or len(nslist_v4) > 0:
debug2('firewall manager: setting up IPv4.\n')
method.setup_firewall(
port_v4, dnsport_v4, nslist_v4,
socket.AF_INET, subnets_v4, udp)
- elif len(subnets_v4) > 0:
- debug1('firewall manager: '
- 'IPv4 subnets defined but IPv4 disabled\n')
stdout.write('STARTED\n')
diff --git a/sshuttle/methods/__init__.py b/sshuttle/methods/__init__.py
index 999e6b9..34699da 100644
--- a/sshuttle/methods/__init__.py
+++ b/sshuttle/methods/__init__.py
@@ -62,9 +62,13 @@ class BaseMethod(object):
def setup_udp_listener(self, udp_listener):
pass
- def check_settings(self, udp, dns):
- if udp:
- Fatal("UDP support not supported with method %s.\n" % self.name)
+ def assert_features(self, features):
+ avail = self.get_supported_features()
+ for key in ["udp", "dns", "ipv6"]:
+ if getattr(features, key) and not getattr(avail, key):
+ raise Fatal(
+ "Feature %s not supported with method %s.\n" %
+ (key, self.name))
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
raise NotImplementedError()
diff --git a/sshuttle/methods/nat.py b/sshuttle/methods/nat.py
index 9936ce8..c5afc03 100644
--- a/sshuttle/methods/nat.py
+++ b/sshuttle/methods/nat.py
@@ -55,13 +55,12 @@ class Method(BaseMethod):
'-p', 'tcp',
'--to-ports', str(port))
- if dnsport:
- for f, ip in [i for i in nslist if i[0] == family]:
- _ipt_ttl('-A', chain, '-j', 'REDIRECT',
- '--dest', '%s/32' % ip,
- '-p', 'udp',
- '--dport', '53',
- '--to-ports', str(dnsport))
+ for f, ip in [i for i in nslist if i[0] == family]:
+ _ipt_ttl('-A', chain, '-j', 'REDIRECT',
+ '--dest', '%s/32' % ip,
+ '-p', 'udp',
+ '--dport', '53',
+ '--to-ports', str(dnsport))
def restore_firewall(self, port, family, udp):
# only ipv4 supported with NAT
diff --git a/sshuttle/methods/pf.py b/sshuttle/methods/pf.py
index 276571a..07d014c 100644
--- a/sshuttle/methods/pf.py
+++ b/sshuttle/methods/pf.py
@@ -181,27 +181,28 @@ class Method(BaseMethod):
if udp:
raise Exception("UDP not supported by pf method_name")
- includes = []
- # If a given subnet is both included and excluded, list the
- # exclusion first; the table will ignore the second, opposite
- # definition
- for f, swidth, sexclude, snet in sorted(
- subnets, key=lambda s: (s[1], s[2]), reverse=True):
- includes.append(b"%s%s/%d" %
- (b"!" if sexclude else b"",
- snet.encode("ASCII"),
- swidth))
-
- tables.append(
- b'table <forward_subnets> {%s}' % b','.join(includes))
- translating_rules.append(
- b'rdr pass on lo0 proto tcp '
- b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
- filtering_rules.append(
- b'pass out route-to lo0 inet proto tcp '
- b'to <forward_subnets> keep state')
-
- if dnsport:
+ if len(subnets) > 0:
+ includes = []
+ # If a given subnet is both included and excluded, list the
+ # exclusion first; the table will ignore the second, opposite
+ # definition
+ for f, swidth, sexclude, snet in sorted(
+ subnets, key=lambda s: (s[1], s[2]), reverse=True):
+ includes.append(b"%s%s/%d" %
+ (b"!" if sexclude else b"",
+ snet.encode("ASCII"),
+ swidth))
+
+ tables.append(
+ b'table <forward_subnets> {%s}' % b','.join(includes))
+ translating_rules.append(
+ b'rdr pass on lo0 proto tcp '
+ b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
+ filtering_rules.append(
+ b'pass out route-to lo0 inet proto tcp '
+ b'to <forward_subnets> keep state')
+
+ if len(nslist) > 0:
tables.append(
b'table <dns_servers> {%s}' %
b','.join([ns[1].encode("ASCII") for ns in nslist]))
diff --git a/sshuttle/methods/tproxy.py b/sshuttle/methods/tproxy.py
index 367b481..03353b8 100644
--- a/sshuttle/methods/tproxy.py
+++ b/sshuttle/methods/tproxy.py
@@ -59,6 +59,7 @@ if recvmsg == "python":
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port)
break
+ print("xxxxx", srcip, dstip)
return (srcip, dstip, data)
elif recvmsg == "socket_ext":
def recv_udp(listener, bufsize):
@@ -187,16 +188,15 @@ class Method(BaseMethod):
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp')
- if dnsport:
- for f, ip in [i for i in nslist if i[0] == family]:
- _ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
- '--dest', '%s/32' % ip,
- '-m', 'udp', '-p', 'udp', '--dport', '53')
- _ipt('-A', tproxy_chain, '-j', 'TPROXY',
- '--tproxy-mark', '0x1/0x1',
- '--dest', '%s/32' % ip,
- '-m', 'udp', '-p', 'udp', '--dport', '53',
- '--on-port', str(dnsport))
+ for f, ip in [i for i in nslist if i[0] == family]:
+ _ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
+ '--dest', '%s/32' % ip,
+ '-m', 'udp', '-p', 'udp', '--dport', '53')
+ _ipt('-A', tproxy_chain, '-j', 'TPROXY',
+ '--tproxy-mark', '0x1/0x1',
+ '--dest', '%s/32' % ip,
+ '-m', 'udp', '-p', 'udp', '--dport', '53',
+ '--on-port', str(dnsport))
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
@@ -267,16 +267,3 @@ class Method(BaseMethod):
if ipt_chain_exists(family, table, divert_chain):
_ipt('-F', divert_chain)
_ipt('-X', divert_chain)
-
- def check_settings(self, udp, dns):
- if udp and recvmsg is None:
- raise Fatal("tproxy UDP support requires recvmsg function.\n")
-
- if dns and recvmsg is None:
- raise Fatal("tproxy DNS support requires recvmsg function.\n")
-
- if udp:
- debug1("tproxy UDP support enabled.\n")
-
- if dns:
- debug1("tproxy DNS support enabled.\n")
diff --git a/sshuttle/tests/test_methods_nat.py b/sshuttle/tests/test_methods_nat.py
index 0f93675..2144e25 100644
--- a/sshuttle/tests/test_methods_nat.py
+++ b/sshuttle/tests/test_methods_nat.py
@@ -3,6 +3,7 @@ from mock import Mock, patch, call
import socket
import struct
+from sshuttle.helpers import Fatal
from sshuttle.methods import get_method
@@ -11,6 +12,7 @@ def test_get_supported_features():
features = method.get_supported_features()
assert not features.ipv6
assert not features.udp
+ assert features.dns
def test_get_tcp_dstip():
@@ -52,10 +54,18 @@ def test_setup_udp_listener():
assert listener.mock_calls == []
-def test_check_settings():
+def test_assert_features():
method = get_method('nat')
- method.check_settings(True, True)
- method.check_settings(False, True)
+ features = method.get_supported_features()
+ method.assert_features(features)
+
+ features.udp = True
+ with pytest.raises(Fatal):
+ method.assert_features(features)
+
+ features.ipv6 = True
+ with pytest.raises(Fatal):
+ method.assert_features(features)
def test_firewall_command():
diff --git a/sshuttle/tests/test_methods_pf.py b/sshuttle/tests/test_methods_pf.py
index 94e38d0..5d67396 100644
--- a/sshuttle/tests/test_methods_pf.py
+++ b/sshuttle/tests/test_methods_pf.py
@@ -3,6 +3,7 @@ from mock import Mock, patch, call, ANY
import socket
from sshuttle.methods import get_method
+from sshuttle.helpers import Fatal
def test_get_supported_features():
@@ -10,6 +11,7 @@ def test_get_supported_features():
features = method.get_supported_features()
assert not features.ipv6
assert not features.udp
+ assert features.dns
@patch('sshuttle.helpers.verbose', new=3)
@@ -68,10 +70,18 @@ def test_setup_udp_listener():
assert listener.mock_calls == []
-def test_check_settings():
+def test_assert_features():
method = get_method('pf')
- method.check_settings(True, True)
- method.check_settings(False, True)
+ features = method.get_supported_features()
+ method.assert_features(features)
+
+ features.udp = True
+ with pytest.raises(Fatal):
+ method.assert_features(features)
+
+ features.ipv6 = True
+ with pytest.raises(Fatal):
+ method.assert_features(features)
@patch('sshuttle.methods.pf.sys.stdout')
diff --git a/sshuttle/tests/test_methods_tproxy.py b/sshuttle/tests/test_methods_tproxy.py
index 344f5d8..3401958 100644
--- a/sshuttle/tests/test_methods_tproxy.py
+++ b/sshuttle/tests/test_methods_tproxy.py
@@ -3,11 +3,22 @@ from mock import Mock, patch, call
from sshuttle.methods import get_method
-def test_get_supported_features():
+@patch("sshuttle.methods.tproxy.recvmsg")
+def test_get_supported_features_recvmsg(mock_recvmsg):
method = get_method('tproxy')
features = method.get_supported_features()
assert features.ipv6
assert features.udp
+ assert features.dns
+
+
+@patch("sshuttle.methods.tproxy.recvmsg", None)
+def test_get_supported_features_norecvmsg():
+ method = get_method('tproxy')
+ features = method.get_supported_features()
+ assert features.ipv6
+ assert not features.udp
+ assert not features.dns
def test_get_tcp_dstip():
@@ -66,10 +77,10 @@ def test_setup_udp_listener():
]
-def test_check_settings():
+def test_assert_features():
method = get_method('tproxy')
- method.check_settings(True, True)
- method.check_settings(False, True)
+ features = method.get_supported_features()
+ method.assert_features(features)
def test_firewall_command():