import sys import struct import socket import errno import select import os import fcntl from sshuttle.helpers import b, log, debug1, debug2, debug3, Fatal MAX_CHANNEL = 65535 LATENCY_BUFFER_SIZE = 32768 SHUT_RD = 0 SHUT_WR = 1 SHUT_RDWR = 2 HDR_LEN = 8 CMD_EXIT = 0x4200 CMD_PING = 0x4201 CMD_PONG = 0x4202 CMD_TCP_CONNECT = 0x4203 CMD_TCP_STOP_SENDING = 0x4204 CMD_TCP_EOF = 0x4205 CMD_TCP_DATA = 0x4206 CMD_ROUTES = 0x4207 CMD_HOST_REQ = 0x4208 CMD_HOST_LIST = 0x4209 CMD_DNS_REQ = 0x420a CMD_DNS_RESPONSE = 0x420b CMD_UDP_OPEN = 0x420c CMD_UDP_DATA = 0x420d CMD_UDP_CLOSE = 0x420e cmd_to_name = { CMD_EXIT: 'EXIT', CMD_PING: 'PING', CMD_PONG: 'PONG', CMD_TCP_CONNECT: 'TCP_CONNECT', CMD_TCP_STOP_SENDING: 'TCP_STOP_SENDING', CMD_TCP_EOF: 'TCP_EOF', CMD_TCP_DATA: 'TCP_DATA', CMD_ROUTES: 'ROUTES', CMD_HOST_REQ: 'HOST_REQ', CMD_HOST_LIST: 'HOST_LIST', CMD_DNS_REQ: 'DNS_REQ', CMD_DNS_RESPONSE: 'DNS_RESPONSE', CMD_UDP_OPEN: 'UDP_OPEN', CMD_UDP_DATA: 'UDP_DATA', CMD_UDP_CLOSE: 'UDP_CLOSE', } NET_ERRS = [errno.ECONNREFUSED, errno.ETIMEDOUT, errno.EHOSTUNREACH, errno.ENETUNREACH, errno.EHOSTDOWN, errno.ENETDOWN, errno.ENETUNREACH, errno.ECONNABORTED, errno.ECONNRESET] def _add(socks, elem): if elem not in socks: socks.append(elem) def _fds(socks): out = [] for i in socks: try: out.append(i.fileno()) except AttributeError: out.append(i) out.sort() return out def _nb_clean(func, *args): try: return func(*args) except OSError: _, e = sys.exc_info()[:2] if e.errno not in (errno.EWOULDBLOCK, errno.EAGAIN): raise else: debug3('%s: err was: %s' % (func.__name__, e)) return None def _try_peername(sock): try: pn = sock.getpeername() if pn: return '%s:%s' % (pn[0], pn[1]) except socket.error: _, e = sys.exc_info()[:2] if e.args[0] == errno.EINVAL: pass elif e.args[0] not in (errno.ENOTCONN, errno.ENOTSOCK): raise except AttributeError: pass return 'unknown' _swcount = 0 class SockWrapper: def __init__(self, rsock, wsock, connect_to=None, peername=None): global _swcount _swcount += 1 debug3('creating new SockWrapper (%d now exist)' % _swcount) self.exc = None self.rsock = rsock self.wsock = wsock self.shut_read = self.shut_write = False self.buf = [] self.connect_to = connect_to self.peername = peername or _try_peername(self.rsock) self.try_connect() def __del__(self): global _swcount _swcount -= 1 debug1('%r: deleting (%d remain)' % (self, _swcount)) if self.exc: debug1('%r: error was: %s' % (self, self.exc)) def __repr__(self): if self.rsock == self.wsock: fds = '#%d' % self.rsock.fileno() else: fds = '#%d,%d' % (self.rsock.fileno(), self.wsock.fileno()) return 'SW%s:%s' % (fds, self.peername) def seterr(self, e): if not self.exc: self.exc = e self.nowrite() self.noread() def try_connect(self): if self.connect_to and self.shut_write: self.noread() self.connect_to = None if not self.connect_to: return # already connected self.rsock.setblocking(False) debug3('%r: trying connect to %r' % (self, self.connect_to)) try: self.rsock.connect(self.connect_to) # connected successfully (Linux) self.connect_to = None except socket.error: _, e = sys.exc_info()[:2] debug3('%r: connect result: %s' % (self, e)) if e.args[0] == errno.EINVAL: # this is what happens when you call connect() on a socket # that is now connected but returned EINPROGRESS last time, # on BSD, on python pre-2.5.1. We need to use getsockopt() # to get the "real" error. Later pythons do this # automatically, so this code won't run. realerr = self.rsock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) e = socket.error(realerr, os.strerror(realerr)) debug3('%r: fixed connect result: %s' % (self, e)) if e.args[0] in [errno.EINPROGRESS, errno.EALREADY]: pass # not connected yet elif e.args[0] == 0: # connected successfully (weird Linux bug?) # Sometimes Linux seems to return EINVAL when it isn't # invalid. This *may* be caused by a race condition # between connect() and getsockopt(SO_ERROR) (ie. it # finishes connecting in between the two, so there is no # longer an error). However, I'm not sure of that. # # I did get at least one report that the problem went away # when we added this, however. self.connect_to = None elif e.args[0] == errno.EISCONN: # connected successfully (BSD) self.connect_to = None elif e.args[0] in NET_ERRS + [errno.EACCES, errno.EPERM]: # a "normal" kind of error self.connect_to = None self.seterr(e) else: raise # error we've never heard of?! barf completely. def noread(self): if not self.shut_read: debug2('%r: done reading' % self) self.shut_read = True def nowrite(self): if not self.shut_write: debug2('%r: done writing' % self) self.shut_write = True try: self.wsock.shutdown(SHUT_WR) except socket.error: _, e = sys.exc_info()[:2] self.seterr('nowrite: %s' % e) @staticmethod def too_full(): return False # fullness is determined by the socket's select() state def uwrite(self, buf): if self.connect_to: return 0 # still connecting self.wsock.setblocking(False) try: return _nb_clean(os.write, self.wsock.fileno(), buf) except OSError: _, e = sys.exc_info()[:2] if e.errno == errno.EPIPE: debug1('%r: uwrite: got EPIPE' % self) self.nowrite() return 0 else: # unexpected error... stream is dead self.seterr('uwrite: %s' % e) return 0 def write(self, buf): assert buf return self.uwrite(buf) def uread(self): if self.connect_to: return None # still connecting if self.shut_read: return self.rsock.setblocking(False) try: return _nb_clean(os.read, self.rsock.fileno(), 65536) except OSError: _, e = sys.exc_info()[:2] self.seterr('uread: %s' % e) return b('') # unexpected error... we'll call it EOF def fill(self): if self.buf: return rb = self.uread() if rb: self.buf.append(rb) if rb == b(''): # empty string means EOF; None means temporarily empty self.noread() def copy_to(self, outwrap): if self.buf and self.buf[0]: wrote = outwrap.write(self.buf[0]) self.buf[0] = self.buf[0][wrote:] while self.buf and not self.buf[0]: self.buf.pop(0) if not self.buf and self.shut_read: outwrap.nowrite() class Handler: def __init__(self, socks=None, callback=None): self.ok = True self.socks = socks or [] if callback: self.callback = callback def pre_select(self, r, w, x): for i in self.socks: _add(r, i) def callback(self, sock): log('--no callback defined-- %r' % self) (r, _, _) = select.select(self.socks, [], [], 0) for s in r: v = s.recv(4096) if not v: log('--closed-- %r' % self) self.socks = [] self.ok = False class Proxy(Handler): def __init__(self, wrap1, wrap2): Handler.__init__(self, [wrap1.rsock, wrap1.wsock, wrap2.rsock, wrap2.wsock]) self.wrap1 = wrap1 self.wrap2 = wrap2 def pre_select(self, r, w, x): if self.wrap1.shut_write: self.wrap2.noread() if self.wrap2.shut_write: self.wrap1.noread() if self.wrap1.connect_to: _add(w, self.wrap1.rsock) elif self.wrap1.buf: if not self.wrap2.too_full(): _add(w, self.wrap2.wsock) elif not self.wrap1.shut_read: _add(r, self.wrap1.rsock) if self.wrap2.connect_to: _add(w, self.wrap2.rsock) elif self.wrap2.buf: if not self.wrap1.too_full(): _add(w, self.wrap1.wsock) elif not self.wrap2.shut_read: _add(r, self.wrap2.rsock) def callback(self, sock): self.wrap1.try_connect() self.wrap2.try_connect() self.wrap1.fill() self.wrap2.fill() self.wrap1.copy_to(self.wrap2) self.wrap2.copy_to(self.wrap1) if self.wrap1.buf and self.wrap2.shut_write: self.wrap1.buf = [] self.wrap1.noread() if self.wrap2.buf and self.wrap1.shut_write: self.wrap2.buf = [] self.wrap2.noread() if (self.wrap1.shut_read and self.wrap2.shut_read and not self.wrap1.buf and not self.wrap2.buf): self.ok = False self.wrap1.nowrite() self.wrap2.nowrite() class Mux(Handler): def __init__(self, rfile, wfile): Handler.__init__(self, [rfile, wfile]) self.rfile = rfile self.wfile = wfile self.new_channel = self.got_dns_req = self.got_routes = None self.got_udp_open = self.got_udp_data = self.got_udp_close = None self.got_host_req = self.got_host_list = None self.channels = {} self.chani = 0 self.want = 0 self.inbuf = b('') self.outbuf = [] self.fullness = 0 self.too_full = False self.send(0, CMD_PING, b('chicken')) def next_channel(self): # channel 0 is special, so we never allocate it for _ in range(1024): self.chani += 1 if self.chani > MAX_CHANNEL: self.chani = 1 if not self.channels.get(self.chani): return self.chani def amount_queued(self): total = 0 for byte in self.outbuf: total += len(byte) return total def check_fullness(self): if self.fullness > LATENCY_BUFFER_SIZE: if not self.too_full: self.send(0, CMD_PING, b('rttest')) self.too_full = True def send(self, channel, cmd, data): assert isinstance(data, bytes) assert len(data) <= 65535 p = struct.pack('!ccHHH', b('S'), b('S'), channel, cmd, len(data)) \ + data self.outbuf.append(p) debug2(' > channel=%d cmd=%s len=%d (fullness=%d)' % (channel, cmd_to_name.get(cmd, hex(cmd)), len(data), self.fullness)) self.fullness += len(data) def got_packet(self, channel, cmd, data): debug2('< channel=%d cmd=%s len=%d' % (channel, cmd_to_name.get(cmd, hex(cmd)), len(data))) if cmd == CMD_PING: self.send(0, CMD_PONG, data) elif cmd == CMD_PONG: debug2('received PING response') self.too_full = False self.fullness = 0 elif cmd == CMD_EXIT: self.ok = False elif cmd == CMD_TCP_CONNECT: assert not self.channels.get(channel) if self.new_channel: self.new_channel(channel, data) elif cmd == CMD_DNS_REQ: assert not self.channels.get(channel) if self.got_dns_req: self.got_dns_req(channel, data) elif cmd == CMD_UDP_OPEN: assert not self.channels.get(channel) if self.got_udp_open: self.got_udp_open(channel, data) elif cmd == CMD_ROUTES: if self.got_routes: self.got_routes(data) else: raise Exception('got CMD_ROUTES without got_routes?') elif cmd == CMD_HOST_REQ: if self.got_host_req: self.got_host_req(data) else: raise Exception('got CMD_HOST_REQ without got_host_req?') elif cmd == CMD_HOST_LIST: if self.got_host_list: self.got_host_list(data) else: raise Exception('got CMD_HOST_LIST without got_host_list?') else: callback = self.channels.get(channel) if not callback: log('warning: closed channel %d got cmd=%s len=%d' % (channel, cmd_to_name.get(cmd, hex(cmd)), len(data))) else: callback(cmd, data) def flush(self): try: os.set_blocking(self.wfile.fileno(), False) except AttributeError: # python < 3.5 flags = fcntl.fcntl(self.wfile.fileno(), fcntl.F_GETFL) flags |= os.O_NONBLOCK fcntl.fcntl(self.wfile.fileno(), fcntl.F_SETFL, flags) if self.outbuf and self.outbuf[0]: wrote = _nb_clean(os.write, self.wfile.fileno(), self.outbuf[0]) debug2('mux wrote: %r/%d' % (wrote, len(self.outbuf[0]))) if wrote: self.outbuf[0] = self.outbuf[0][wrote:] while self.outbuf and not self.outbuf[0]: self.outbuf[0:1] = [] def fill(self): try: os.set_blocking(self.rfile.fileno(), False) except AttributeError: # python < 3.5 flags = fcntl.fcntl(self.rfile.fileno(), fcntl.F_GETFL) flags |= os.O_NONBLOCK fcntl.fcntl(self.rfile.fileno(), fcntl.F_SETFL, flags) try: # If LATENCY_BUFFER_SIZE is inappropriately large, we will # get a MemoryError here. Read no more than 1MiB. read = _nb_clean(os.read, self.rfile.fileno(), min(1048576, LATENCY_BUFFER_SIZE)) except OSError: _, e = sys.exc_info()[:2] raise Fatal('other end: %r' % e) # log('<<< %r' % b) if read == b(''): # EOF self.ok = False if read: self.inbuf += read def handle(self): self.fill() while 1: if len(self.inbuf) >= (self.want or HDR_LEN): (s1, s2, channel, cmd, datalen) = \ struct.unpack('!ccHHH', self.inbuf[:HDR_LEN]) assert s1 == b('S') assert s2 == b('S') self.want = datalen + HDR_LEN if self.want and len(self.inbuf) >= self.want: data = self.inbuf[HDR_LEN:self.want] self.inbuf = self.inbuf[self.want:] self.want = 0 self.got_packet(channel, cmd, data) else: break def pre_select(self, r, w, x): _add(r, self.rfile) if self.outbuf: _add(w, self.wfile) def callback(self, sock): (r, w, _) = select.select([self.rfile], [self.wfile], [], 0) if self.rfile in r: self.handle() if self.outbuf and self.wfile in w: self.flush() class MuxWrapper(SockWrapper): def __init__(self, mux, channel): SockWrapper.__init__(self, mux.rfile, mux.wfile) self.mux = mux self.channel = channel self.mux.channels[channel] = self.got_packet self.socks = [] debug2('new channel: %d' % channel) def __del__(self): self.nowrite() SockWrapper.__del__(self) def __repr__(self): return 'SW%r:Mux#%d' % (self.peername, self.channel) def noread(self): if not self.shut_read: self.mux.send(self.channel, CMD_TCP_STOP_SENDING, b('')) self.setnoread() def setnoread(self): if not self.shut_read: debug2('%r: done reading' % self) self.shut_read = True self.maybe_close() def nowrite(self): if not self.shut_write: self.mux.send(self.channel, CMD_TCP_EOF, b('')) self.setnowrite() def setnowrite(self): if not self.shut_write: debug2('%r: done writing' % self) self.shut_write = True self.maybe_close() def maybe_close(self): if self.shut_read and self.shut_write: debug2('%r: closing connection' % self) # remove the mux's reference to us. The python garbage collector # will then be able to reap our object. self.mux.channels[self.channel] = None def too_full(self): return self.mux.too_full def uwrite(self, buf): if self.mux.too_full: return 0 # too much already enqueued if len(buf) > 2048: buf = buf[:2048] self.mux.send(self.channel, CMD_TCP_DATA, buf) return len(buf) def uread(self): if self.shut_read: return b('') # EOF else: return None # no data available right now def got_packet(self, cmd, data): if cmd == CMD_TCP_EOF: # Remote side already knows the status - set flag but don't notify self.setnoread() elif cmd == CMD_TCP_STOP_SENDING: # Remote side already knows the status - set flag but don't notify self.setnowrite() elif cmd == CMD_TCP_DATA: self.buf.append(data) else: raise Exception('unknown command %d (%d bytes)' % (cmd, len(data))) def connect_dst(family, ip, port): debug2('Connecting to %s:%d' % (ip, port)) outsock = socket.socket(family) return SockWrapper(outsock, outsock, connect_to=(ip, port), peername='%s:%d' % (ip, port)) def runonce(handlers, mux): r = [] w = [] x = [] to_remove = [s for s in handlers if not s.ok] for h in to_remove: handlers.remove(h) for s in handlers: s.pre_select(r, w, x) debug2('Waiting: %d r=%r w=%r x=%r (fullness=%d/%d)' % (len(handlers), _fds(r), _fds(w), _fds(x), mux.fullness, mux.too_full)) (r, w, x) = select.select(r, w, x) debug2(' Ready: %d r=%r w=%r x=%r' % (len(handlers), _fds(r), _fds(w), _fds(x))) ready = r + w + x did = {} for h in handlers: for s in h.socks: if s in ready: h.callback(s) did[s] = 1 for s in ready: if s not in did: raise Fatal('socket %r was not used by any handler' % s)