diff options
author | Avery Pennarun <apenwarr@gmail.com> | 2010-05-01 23:14:42 -0400 |
---|---|---|
committer | Avery Pennarun <apenwarr@gmail.com> | 2010-05-01 23:14:42 -0400 |
commit | 5f0bfb5d9e4135332ab7c398e40037545f63ae18 (patch) | |
tree | 942b412e84e57e9385198ebfff94852aaf10bc1b | |
parent | 9f514d7a1507018b337b56f376fd2e9022d75641 (diff) |
Basic implementation of a multiplex protocol - client side only.
Currently the 'server' is just a pipe to run 'hd' (hexdump) for looking at
the client-side results. Lame, but true.
-rw-r--r-- | client.py | 37 | ||||
-rwxr-xr-x | main.py | 10 | ||||
-rw-r--r-- | ssh.py | 24 | ||||
-rw-r--r-- | ssnet.py | 161 |
4 files changed, 196 insertions, 36 deletions
@@ -1,13 +1,13 @@ import struct, socket, select, subprocess, errno -from ssnet import SockWrapper, Handler, Proxy +import ssnet, ssh +from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from helpers import * - def original_dst(sock): SO_ORIGINAL_DST = 80 SOCKADDR_MIN = 16 sockaddr_in = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, SOCKADDR_MIN) - (proto, port, a,b,c,d) = struct.unpack('!hhBBBB', sockaddr_in[:8]) + (proto, port, a,b,c,d) = struct.unpack('!HHBBBB', sockaddr_in[:8]) assert(socket.htons(proto) == socket.AF_INET) ip = '%d.%d.%d.%d' % (a,b,c,d) return (ip,port) @@ -21,8 +21,17 @@ def iptables_setup(port, subnets): raise Exception('%r returned %d' % (argv, rv)) -def _main(listener, remotename, subnets): +def _main(listener, listenport, use_server, remotename, subnets): handlers = [] + if use_server: + (serverproc, serversock) = ssh.connect(remotename) + mux = Mux(serversock) + handlers.append(mux) + + # we definitely want to do this *after* starting ssh, or we might end + # up intercepting the ssh connection! + iptables_setup(listenport, subnets) + def onaccept(): sock,srcip = listener.accept() dstip = original_dst(sock) @@ -31,10 +40,16 @@ def _main(listener, remotename, subnets): log("-- ignored: that's my address!\n") sock.close() return - outsock = socket.socket() - outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) - outsock.connect(dstip) - handlers.append(Proxy(SockWrapper(sock), SockWrapper(outsock))) + if use_server: + chan = mux.next_channel() + mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip) + outwrap = MuxWrapper(mux, chan) + else: + outsock = socket.socket() + outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) + outsock.connect(dstip) + outwrap = SockWrapper(outsock) + handlers.append(Proxy(SockWrapper(sock), outwrap)) handlers.append(Handler([listener], onaccept)) while 1: @@ -54,7 +69,7 @@ def _main(listener, remotename, subnets): s.callback() -def main(listenip, remotename, subnets): +def main(listenip, use_server, remotename, subnets): log('Starting sshuttle proxy.\n') listener = socket.socket() listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -81,9 +96,7 @@ def main(listenip, remotename, subnets): listenip = listener.getsockname() log('Listening on %r.\n' % (listenip,)) - iptables_setup(listenip[1], subnets) - try: - return _main(listener, remotename, subnets) + return _main(listener, listenip[1], use_server, remotename, subnets) finally: iptables_setup(listenip[1], []) @@ -1,5 +1,5 @@ #!/usr/bin/env python -import sys, re +import sys, os, re import options, client, iptables @@ -50,6 +50,7 @@ sshuttle --server -- l,listen= transproxy to this ip address and port number [default=0] r,remote= ssh hostname (and optional username) of remote sshuttle server +noserver don't use a separate server process (mostly for debugging) server [internal use only] iptables [internal use only] """ @@ -57,7 +58,9 @@ o = options.Options('sshuttle', optspec) (opt, flags, extra) = o.parse(sys.argv[1:]) if opt.server: - o.fatal('server mode not implemented yet') + #o.fatal('server mode not implemented yet') + os.dup2(2,1) + os.execvp('hd', ['hd']) sys.exit(1) elif opt.iptables: if len(extra) < 1: @@ -67,9 +70,10 @@ elif opt.iptables: else: if len(extra) < 1: o.fatal('at least one subnet expected') - remotename = extra[0] + remotename = opt.remote if remotename == '' or remotename == '-': remotename = None sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), + not opt.noserver, remotename, parse_subnets(extra))) @@ -1,14 +1,13 @@ -import os, re, subprocess +import sys, os, re, subprocess, socket -def connect(rhost, subcmd): - assert(not re.search(r'[^\w-]', subcmd)) +def connect(rhost): main_exe = sys.argv[0] nicedir = os.path.split(os.path.abspath(main_exe))[0] nicedir = re.sub(r':', "_", nicedir) if rhost == '-': rhost = None if not rhost: - argv = ['sshuttle', subcmd] + argv = ['sshuttle', '--server'] else: # WARNING: shell quoting security holes are possible here, so we # have to be super careful. We have to use 'sh -c' because @@ -19,14 +18,21 @@ def connect(rhost, subcmd): # stuff here. escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir) cmd = r""" - sh -c PATH=%s:'$PATH sshuttle %s' - """ % (escapedir, subcmd) - argv = ['ssh', rhost, '--', cmd.strip()] + sh -c PATH=%s:'$PATH sshuttle --server' + """ % (escapedir,) + argv = ['ssh', '-v', rhost, '--', cmd.strip()] + print repr(argv) + (s1,s2) = socket.socketpair() def setup(): # runs in the child process + s2.close() if not rhost: os.environ['PATH'] = ':'.join([nicedir, os.environ.get('PATH', '')]) os.setsid() - return subprocess.Popen(argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - preexec_fn=setup) + s1a,s1b = os.dup(s1.fileno()), os.dup(s1.fileno()) + s1.close() + p = subprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup) + os.close(s1a) + os.close(s1b) + return p, s2 @@ -1,6 +1,17 @@ -import socket, errno, select +import struct, socket, errno, select from helpers import * +HDR_LEN = 8 + +CMD_EXIT = 0x4200 +CMD_PING = 0x4201 +CMD_PONG = 0x4202 +CMD_CONNECT = 0x4203 +CMD_CLOSE = 0x4204 +CMD_EOF = 0x4205 +CMD_DATA = 0x4206 + + def _nb_clean(func, *args): try: return func(*args) @@ -34,29 +45,33 @@ class SockWrapper: log('%r: done writing\n' % self) self.shut_write = True self.sock.shutdown(socket.SHUT_WR) + + def uwrite(self, buf): + self.sock.setblocking(False) + return _nb_clean(self.sock.send, buf) def write(self, buf): assert(buf) - self.sock.setblocking(False) - return _nb_clean(self.sock.send, buf) + return self.uwrite(buf) - def fill(self): + def uread(self): if self.shut_read: return self.sock.setblocking(False) - rb = _nb_clean(self.sock.recv, 65536) + return _nb_clean(self.sock.recv, 65536) + + def fill(self): + if self.buf: + return + rb = self.uread() if rb: self.buf.append(rb) if rb == '': # empty string means EOF; None means temporarily empty self.noread() - def maybe_fill(self): - if not self.buf: - self.fill() - def copy_to(self, outwrap): if self.buf and self.buf[0]: - wrote = outwrap.sock.send(self.buf[0]) + wrote = outwrap.write(self.buf[0]) self.buf[0] = self.buf[0][wrote:] while self.buf and not self.buf[0]: self.buf.pop(0) @@ -102,8 +117,8 @@ class Proxy(Handler): r.add(self.wrap2.sock) def callback(self): - self.wrap1.maybe_fill() - self.wrap2.maybe_fill() + self.wrap1.fill() + self.wrap2.fill() self.wrap1.copy_to(self.wrap2) self.wrap2.copy_to(self.wrap1) if (self.wrap1.shut_read and self.wrap2.shut_read and @@ -111,3 +126,125 @@ class Proxy(Handler): self.ok = False +class Mux(Handler): + def __init__(self, sock): + Handler.__init__(self, [sock]) + self.sock = sock + self.channels = {} + self.chani = 0 + self.want = 0 + self.inbuf = '' + self.outbuf = [] + self.send(0, CMD_PING, 'chicken') + + def next_channel(self): + # channel 0 is special, so we never allocate it + for timeout in xrange(1024): + self.chani += 1 + if self.chani > 65535: + self.chani = 1 + if not self.channels.get(self.chani): + return self.chani + + def send(self, channel, cmd, data): + data = str(data) + assert(len(data) <= 65535) + p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data + self.outbuf.append(p) + log('Mux: send queue is %d/%d\n' + % (len(self.outbuf), sum(len(b) for b in self.outbuf))) + + def got_packet(self, channel, cmd, data): + log('--got-packet--\n') + if cmd == CMD_PING: + self.mux.send(0, CMD_PONG, data) + elif cmd == CMD_EXIT: + self.ok = False + else: + c = self.channels[channel] + c.got_packet(cmd, data) + + def flush(self): + self.sock.setblocking(False) + if self.outbuf and self.outbuf[0]: + wrote = _nb_clean(self.sock.send, self.outbuf[0]) + if wrote: + self.outbuf[0] = self.outbuf[0][wrote:] + while self.outbuf and not self.outbuf[0]: + self.outbuf.pop() + + def fill(self): + self.sock.setblocking(False) + b = _nb_clean(self.sock.recv, 32768) + if b == '': # EOF + ok = False + if b: + self.inbuf += b + + def handle(self): + log('inbuf is: %r\n' % self.inbuf) + if len(self.inbuf) >= (self.want or HDR_LEN): + (s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH', + self.inbuf[:HDR_LEN]) + assert(s1 == 'S') + assert(s2 == 'S') + self.want = datalen + HDR_LEN + if self.want and len(self.inbuf) >= self.want: + data = self.inbuf[HDR_LEN:self.want] + self.inbuf = self.inbuf[self.want:] + self.got_packet(channel, cmd, data) + else: + self.fill() + + def pre_select(self, r, w, x): + if self.inbuf < (self.want or HDR_LEN): + r.add(self.sock) + if self.outbuf: + w.add(self.sock) + + def callback(self): + (r,w,x) = select.select([self.sock], [self.sock], [], 0) + if self.sock in r: + self.handle() + if self.outbuf and self.sock in w: + self.flush() + + +class MuxWrapper(SockWrapper): + def __init__(self, mux, channel): + SockWrapper.__init__(self, mux.sock) + self.mux = mux + self.channel = channel + self.mux.channels[channel] = self + log('Created MuxWrapper on channel %d\n' % channel) + + def noread(self): + if not self.shut_read: + self.shut_read = True + + def nowrite(self): + if not self.shut_write: + self.shut_write = True + self.mux.send(self.channel, CMD_EOF, '') + + def uwrite(self, buf): + self.mux.send(self.channel, CMD_DATA, buf) + return len(buf) + + def uread(self): + if self.shut_read: + return '' # EOF + else: + return None # no data available right now + + def got_packet(self, cmd, data): + if cmd == CMD_CLOSE: + self.noread() + self.nowrite() + elif cmd == CMD_EOF: + self.noread() + elif cmd == CMD_DATA: + self.buf.append(data) + else: + raise Exception('unknown command %d (%d bytes)' + % (cmd, len(data))) |