diff options
Diffstat (limited to 'sshuttle/options.py')
-rw-r--r-- | sshuttle/options.py | 122 |
1 files changed, 54 insertions, 68 deletions
diff --git a/sshuttle/options.py b/sshuttle/options.py index d97d7ae..659c014 100644 --- a/sshuttle/options.py +++ b/sshuttle/options.py @@ -4,41 +4,8 @@ from argparse import ArgumentParser, Action, ArgumentTypeError as Fatal from sshuttle import __version__ -# 1.2.3.4/5 or just 1.2.3.4 -def parse_subnet4(s): - m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s) - if not m: - raise Fatal('%r is not a valid IP subnet format' % s) - (a, b, c, d, width) = m.groups() - (a, b, c, d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0)) - if width is None: - width = 32 - else: - width = int(width) - if a > 255 or b > 255 or c > 255 or d > 255: - raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d)) - if width > 32: - raise Fatal('*/%d is greater than the maximum of 32' % width) - return(socket.AF_INET, '%d.%d.%d.%d' % (a, b, c, d), width) - - -# 1:2::3/64 or just 1:2::3 -def parse_subnet6(s): - m = re.match(r'(?:([a-fA-F\d:]+))?(?:/(\d+))?$', s) - if not m: - raise Fatal('%r is not a valid IP subnet format' % s) - (net, width) = m.groups() - if width is None: - width = 128 - else: - width = int(width) - if width > 128: - raise Fatal('*/%d is greater than the maximum of 128' % width) - return(socket.AF_INET6, net, width) - - # Subnet file, supporting empty lines and hash-started comment lines -def parse_subnet_file(s): +def parse_subnetport_file(s): try: handle = open(s, 'r') except OSError: @@ -52,47 +19,66 @@ def parse_subnet_file(s): continue if line[0] == '#': continue - subnets.append(parse_subnet(line)) + subnets.append(parse_subnetport(line)) return subnets -# 1.2.3.4/5 or just 1.2.3.4 -# 1:2::3/64 or just 1:2::3 -def parse_subnet(subnet_str): - if ':' in subnet_str: - return parse_subnet6(subnet_str) +# 1.2.3.4/5:678, 1.2.3.4:567, 1.2.3.4/16 or just 1.2.3.4 +# [1:2::3/64]:456, [1:2::3]:456, 1:2::3/64 or just 1:2::3 +# example.com:123 or just example.com +def parse_subnetport(s): + if s.count(':') > 1: + rx = r'(?:\[?([\w\:]+)(?:/(\d+))?]?)(?::(\d+)(?:-(\d+))?)?$' else: - return parse_subnet4(subnet_str) + rx = r'([\w\.]+)(?:/(\d+))?(?::(\d+)(?:-(\d+))?)?$' + + m = re.match(rx, s) + if not m: + raise Fatal('%r is not a valid address/mask:port format' % s) + + addr, width, fport, lport = m.groups() + try: + addrinfo = socket.getaddrinfo(addr, 0, 0, socket.SOCK_STREAM) + except socket.gaierror: + raise Fatal('Unable to resolve address: %s' % addr) + + family, _, _, _, addr = min(addrinfo) + max_width = 32 if family == socket.AF_INET else 128 + width = int(width or max_width) + if not 0 <= width <= max_width: + raise Fatal('width %d is not between 0 and %d' % (width, max_width)) + + return (family, addr[0], width, int(fport or 0), int(lport or fport or 0)) # 1.2.3.4:567 or just 1.2.3.4 or just 567 -def parse_ipport4(s): +# [1:2::3]:456 or [1:2::3] or just [::]:567 +# example.com:123 or just example.com +def parse_ipport(s): s = str(s) - m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s) + if s.isdigit(): + rx = r'()(\d+)$' + elif ']' in s: + rx = r'(?:\[([^]]+)])(?::(\d+))?$' + else: + rx = r'([\w\.]+)(?::(\d+))?$' + + m = re.match(rx, s) if not m: raise Fatal('%r is not a valid IP:port format' % s) - (a, b, c, d, port) = m.groups() - (a, b, c, d, port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0), - int(port or 0)) - if a > 255 or b > 255 or c > 255 or d > 255: - raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d)) - if port > 65535: - raise Fatal('*:%d is greater than the maximum of 65535' % port) - if a is None: - a = b = c = d = 0 - return ('%d.%d.%d.%d' % (a, b, c, d), port) + ip, port = m.groups() + ip = ip or '0.0.0.0' + port = int(port or 0) -# [1:2::3]:456 or [1:2::3] or 456 -def parse_ipport6(s): - s = str(s) - m = re.match(r'(?:\[([^]]*)])?(?::)?(?:(\d+))?$', s) - if not m: - raise Fatal('%s is not a valid IP:port format' % s) - (ip, port) = m.groups() - (ip, port) = (ip or '::', int(port or 0)) - return (ip, port) + try: + addrinfo = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM) + except socket.gaierror: + raise Fatal('%r is not a valid IP:port format' % s) + + family, _, _, _, addr = min(addrinfo) + return (family,) + addr[:2] def parse_list(list): @@ -116,9 +102,9 @@ parser = ArgumentParser( ) parser.add_argument( "subnets", - metavar="IP/MASK [IP/MASK...]", + metavar="IP/MASK[:PORT[-PORT]]...", nargs="*", - type=parse_subnet, + type=parse_subnetport, help=""" capture and forward traffic to these subnets (whitespace separated) """ @@ -185,10 +171,10 @@ parser.add_argument( ) parser.add_argument( "-x", "--exclude", - metavar="IP/MASK", + metavar="IP/MASK[:PORT[-PORT]]", action="append", default=[], - type=parse_subnet, + type=parse_subnetport, help=""" exclude this subnet (can be used more than once) """ @@ -198,7 +184,7 @@ parser.add_argument( metavar="PATH", action=Concat, dest="exclude", - type=parse_subnet_file, + type=parse_subnetport_file, help=""" exclude the subnets in a file (whitespace separated) """ @@ -271,7 +257,7 @@ parser.add_argument( action=Concat, dest="subnets_file", default=[], - type=parse_subnet_file, + type=parse_subnetport_file, help=""" file where the subnets are stored, instead of on the command line """ |