diff options
Diffstat (limited to 'sshuttle/hostwatch.py')
-rw-r--r-- | sshuttle/hostwatch.py | 79 |
1 files changed, 61 insertions, 18 deletions
diff --git a/sshuttle/hostwatch.py b/sshuttle/hostwatch.py index 7e4d3c5..a016f4f 100644 --- a/sshuttle/hostwatch.py +++ b/sshuttle/hostwatch.py @@ -15,6 +15,8 @@ POLL_TIME = 60 * 15 NETSTAT_POLL_TIME = 30 CACHEFILE = os.path.expanduser('~/.sshuttle.hosts') +# Have we already failed to write CACHEFILE? +CACHE_WRITE_FAILED = False hostnames = {} queue = {} @@ -31,7 +33,10 @@ def _is_ip(s): def write_host_cache(): + """If possible, write our hosts file to disk so future connections + can reuse the hosts that we already found.""" tmpname = '%s.%d.tmp' % (CACHEFILE, os.getpid()) + global CACHE_WRITE_FAILED try: f = open(tmpname, 'wb') for name, ip in sorted(hostnames.items()): @@ -39,7 +44,15 @@ def write_host_cache(): f.close() os.chmod(tmpname, 384) # 600 in octal, 'rw-------' os.rename(tmpname, CACHEFILE) - finally: + CACHE_WRITE_FAILED = False + except (OSError, IOError): + # Write message if we haven't yet or if we get a failure after + # a previous success. + if not CACHE_WRITE_FAILED: + log("Failed to write host cache to temporary file " + "%s and rename it to %s" % (tmpname, CACHEFILE)) + CACHE_WRITE_FAILED = True + try: os.unlink(tmpname) except BaseException: @@ -47,25 +60,34 @@ def write_host_cache(): def read_host_cache(): + """If possible, read the cache file from disk to populate hosts that + were found in a previous sshuttle run.""" try: f = open(CACHEFILE) - except IOError: + except (OSError, IOError): _, e = sys.exc_info()[:2] if e.errno == errno.ENOENT: return else: - raise + log("Failed to read existing host cache file %s on remote host" + % CACHEFILE) + return for line in f: words = line.strip().split(',') if len(words) == 2: (name, ip) = words name = re.sub(r'[^-\w\.]', '-', name).strip() + # Remove characters that shouldn't be in IP ip = re.sub(r'[^0-9.]', '', ip).strip() if name and ip: found_host(name, ip) def found_host(name, ip): + """The provided name maps to the given IP. Add the host to the + hostnames list, send the host to the sshuttle client via + stdout, and write the host to the cache file. + """ hostname = re.sub(r'\..*', '', name) hostname = re.sub(r'[^-\w\.]', '_', hostname) if (ip.startswith('127.') or ip.startswith('255.') or @@ -84,29 +106,37 @@ def found_host(name, ip): def _check_etc_hosts(): - debug2(' > hosts') - for line in open('/etc/hosts'): - line = re.sub(r'#.*', '', line) - words = line.strip().split() - if not words: - continue - ip = words[0] - names = words[1:] - if _is_ip(ip): - debug3('< %s %r' % (ip, names)) - for n in names: - check_host(n) - found_host(n, ip) + """If possible, read /etc/hosts to find hosts.""" + filename = '/etc/hosts' + debug2(' > Reading %s on remote host' % filename) + try: + for line in open(filename): + line = re.sub(r'#.*', '', line) # remove comments + words = line.strip().split() + if not words: + continue + ip = words[0] + if _is_ip(ip): + names = words[1:] + debug3('< %s %r' % (ip, names)) + for n in names: + check_host(n) + found_host(n, ip) + except (OSError, IOError): + debug1("Failed to read %s on remote host" % filename) def _check_revdns(ip): + """Use reverse DNS to try to get hostnames from an IP addresses.""" debug2(' > rev: %s' % ip) try: r = socket.gethostbyaddr(ip) debug3('< %s' % r[0]) check_host(r[0]) found_host(r[0], ip) - except (socket.herror, UnicodeError): + except (OSError, socket.error, UnicodeError): + # This case is expected to occur regularly. + # debug3('< %s gethostbyaddr failed on remote host' % ip) pass @@ -134,7 +164,14 @@ def _check_netstat(): log('%r failed: %r' % (argv, e)) return + # The same IPs may appear multiple times. Consolidate them so the + # debug message doesn't print the same IP repeatedly. + ip_list = [] for ip in re.findall(r'\d+\.\d+\.\d+\.\d+', content): + if ip not in ip_list: + ip_list.append(ip) + + for ip in sorted(ip_list): debug3('< %s' % ip) check_host(ip) @@ -179,13 +216,19 @@ def hw_main(seed_hosts, auto_hosts): while 1: now = time.time() + # For each item in the queue for t, last_polled in list(queue.items()): (op, args) = t if not _stdin_still_ok(0): break + + # Determine if we need to run. maxtime = POLL_TIME + # netstat runs more often than other jobs if op == _check_netstat: maxtime = NETSTAT_POLL_TIME + + # Check if this jobs needs to run. if now - last_polled > maxtime: queue[t] = time.time() op(*args) @@ -195,5 +238,5 @@ def hw_main(seed_hosts, auto_hosts): break # FIXME: use a smarter timeout based on oldest last_polled - if not _stdin_still_ok(1): + if not _stdin_still_ok(1): # sleeps for up to 1 second break |