diff options
author | Avery Pennarun <apenwarr@gmail.com> | 2010-05-02 00:52:06 -0400 |
---|---|---|
committer | Avery Pennarun <apenwarr@gmail.com> | 2010-05-02 00:52:06 -0400 |
commit | 915a96b0ec1f2862431d296e116d687bc79f0ba3 (patch) | |
tree | 30ca301bba0814f91287fa4d4581ff721982acdd | |
parent | d435c41bdbf2ffd889cefe6710b8fc126304ef38 (diff) |
We now have a server that works... some of the time.
There still seem to be some weird timing and/or closing-related bugs, since
I can't load the eqldata project correctly unless I use --noserver.
-rw-r--r-- | client.py | 26 | ||||
-rw-r--r-- | helpers.py | 4 | ||||
-rwxr-xr-x | main.py | 7 | ||||
-rw-r--r-- | server.py | 43 | ||||
-rw-r--r-- | ssnet.py | 104 |
5 files changed, 146 insertions, 38 deletions
@@ -1,5 +1,5 @@ import struct, socket, select, subprocess, errno -import ssnet, ssh +import ssnet, ssh, helpers from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from helpers import * @@ -24,10 +24,21 @@ def iptables_setup(port, subnets): def _main(listener, listenport, use_server, remotename, subnets): handlers = [] if use_server: + helpers.logprefix = 'c : ' (serverproc, serversock) = ssh.connect(remotename) mux = Mux(serversock, serversock) handlers.append(mux) + expected = 'SSHUTTLE0001' + initstring = serversock.recv(len(expected)) + if initstring != expected: + raise Exception('expected server init string %r; got %r' + % (expected, initstring)) + + rv = serverproc.poll() + if rv: + raise Exception('server died with error code %d' % rv) + # we definitely want to do this *after* starting ssh, or we might end # up intercepting the ssh connection! iptables_setup(listenport, subnets) @@ -45,21 +56,24 @@ def _main(listener, listenport, use_server, remotename, subnets): mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip) outwrap = MuxWrapper(mux, chan) else: - outsock = socket.socket() - outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) - outsock.connect(dstip) - outwrap = SockWrapper(outsock, outsock) + outwrap = ssnet.connect_dst(dstip[0], dstip[1]) handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Handler([listener], onaccept)) while 1: + if use_server: + rv = serverproc.poll() + if rv: + raise Exception('server died with error code %d' % rv) + r = set() w = set() x = set() handlers = filter(lambda s: s.ok, handlers) for s in handlers: s.pre_select(r,w,x) - log('\nWaiting: %d[%d,%d,%d]...\n' + log('\n') + log('Waiting: %d[%d,%d,%d]...\n' % (len(handlers), len(r), len(w), len(x))) (r,w,x) = select.select(r,w,x) log('r=%r w=%r x=%r\n' % (r,w,x)) @@ -1,6 +1,8 @@ import sys, os +logprefix = '' + def log(s): sys.stdout.flush() - sys.stderr.write(s) + sys.stderr.write(logprefix + s) sys.stderr.flush() @@ -1,6 +1,6 @@ #!/usr/bin/env python import sys, os, re -import options, client, iptables +import options, client, iptables, server # list of: @@ -58,10 +58,7 @@ o = options.Options('sshuttle', optspec) (opt, flags, extra) = o.parse(sys.argv[1:]) if opt.server: - #o.fatal('server mode not implemented yet') - os.dup2(2,1) - os.execvp('hd', ['hd']) - sys.exit(1) + sys.exit(server.main()) elif opt.iptables: if len(extra) < 1: o.fatal('at least one argument expected') diff --git a/server.py b/server.py new file mode 100644 index 0000000..499aa4c --- /dev/null +++ b/server.py @@ -0,0 +1,43 @@ +import struct, socket, select +import ssnet, helpers +from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper +from helpers import * + + +def main(): + # synchronization header + sys.stdout.write('SSHUTTLE0001') + sys.stdout.flush() + + helpers.logprefix = ' s: ' + handlers = [] + mux = Mux(socket.fromfd(sys.stdin.fileno(), + socket.AF_INET, socket.SOCK_STREAM), + socket.fromfd(sys.stdout.fileno(), + socket.AF_INET, socket.SOCK_STREAM)) + handlers.append(mux) + + def new_channel(channel, data): + (dstip,dstport) = data.split(',', 1) + dstport = int(dstport) + outwrap = ssnet.connect_dst(dstip,dstport) + handlers.append(Proxy(MuxWrapper(mux, channel), outwrap)) + + mux.new_channel = new_channel + + while mux.ok: + r = set() + w = set() + x = set() + handlers = filter(lambda s: s.ok, handlers) + for s in handlers: + s.pre_select(r,w,x) + log('\n') + log('Waiting: %d[%d,%d,%d]...\n' + % (len(handlers), len(r), len(w), len(x))) + (r,w,x) = select.select(r,w,x) + log('r=%r w=%r x=%r\n' % (r,w,x)) + ready = set(r) | set(w) | set(x) + for s in handlers: + if s.socks & ready: + s.callback() @@ -16,16 +16,27 @@ def _nb_clean(func, *args): try: return func(*args) except socket.error, e: - if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): + raise + else: return None - raise + + +def _try_peername(sock): + try: + return sock.getpeername() + except socket.error, e: + if e.args[0] not in (errno.ENOTCONN,): + raise + else: + return ('0.0.0.0',0) class SockWrapper: def __init__(self, rsock, wsock): self.rsock = rsock self.wsock = wsock - self.peername = self.rsock.getpeername() + self.peername = _try_peername(self.rsock) self.shut_read = self.shut_write = False self.buf = [] @@ -45,11 +56,20 @@ class SockWrapper: if not self.shut_write: log('%r: done writing\n' % self) self.shut_write = True - self.wsock.shutdown(socket.SHUT_WR) + try: + self.wsock.shutdown(socket.SHUT_WR) + except socket.error: + pass def uwrite(self, buf): self.wsock.setblocking(False) - return _nb_clean(self.wsock.send, buf) + try: + return _nb_clean(self.wsock.send, buf) + except socket.error: + # unexpected error... stream is dead + self.nowrite() + self.noread() + return 0 def write(self, buf): assert(buf) @@ -59,7 +79,10 @@ class SockWrapper: if self.shut_read: return self.rsock.setblocking(False) - return _nb_clean(self.rsock.recv, 65536) + try: + return _nb_clean(self.rsock.recv, 65536) + except socket.error: + return '' # unexpected error... we'll call it EOF def fill(self): if self.buf: @@ -133,6 +156,7 @@ class Mux(Handler): Handler.__init__(self, [rsock, wsock]) self.rsock = rsock self.wsock = wsock + self.new_channel = None self.channels = {} self.chani = 0 self.want = 0 @@ -160,12 +184,18 @@ class Mux(Handler): def got_packet(self, channel, cmd, data): log('--got-packet--\n') if cmd == CMD_PING: - self.mux.send(0, CMD_PONG, data) + self.send(0, CMD_PONG, data) + elif cmd == CMD_PONG: + log('received PING response\n') elif cmd == CMD_EXIT: self.ok = False + elif cmd == CMD_CONNECT: + assert(not self.channels.get(channel)) + if self.new_channel: + self.new_channel(channel, data) else: - c = self.channels[channel] - c.got_packet(cmd, data) + callback = self.channels[channel] + callback(cmd, data) def flush(self): self.wsock.setblocking(False) @@ -180,28 +210,30 @@ class Mux(Handler): self.rsock.setblocking(False) b = _nb_clean(self.rsock.recv, 32768) if b == '': # EOF - ok = False + self.ok = False if b: self.inbuf += b def handle(self): - log('inbuf is: %r\n' % self.inbuf) - if len(self.inbuf) >= (self.want or HDR_LEN): - (s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH', - self.inbuf[:HDR_LEN]) - assert(s1 == 'S') - assert(s2 == 'S') - self.want = datalen + HDR_LEN - if self.want and len(self.inbuf) >= self.want: - data = self.inbuf[HDR_LEN:self.want] - self.inbuf = self.inbuf[self.want:] - self.got_packet(channel, cmd, data) - else: - self.fill() + self.fill() + log('inbuf is: (%d,%d) %r\n' % (self.want, len(self.inbuf), self.inbuf)) + while 1: + if len(self.inbuf) >= (self.want or HDR_LEN): + (s1,s2,channel,cmd,datalen) = \ + struct.unpack('!ccHHH', self.inbuf[:HDR_LEN]) + assert(s1 == 'S') + assert(s2 == 'S') + self.want = datalen + HDR_LEN + if self.want and len(self.inbuf) >= self.want: + data = self.inbuf[HDR_LEN:self.want] + self.inbuf = self.inbuf[self.want:] + self.want = 0 + self.got_packet(channel, cmd, data) + else: + break def pre_select(self, r, w, x): - if self.inbuf < (self.want or HDR_LEN): - r.add(self.rsock) + r.add(self.rsock) if self.outbuf: w.add(self.wsock) @@ -218,9 +250,16 @@ class MuxWrapper(SockWrapper): SockWrapper.__init__(self, mux.rsock, mux.wsock) self.mux = mux self.channel = channel - self.mux.channels[channel] = self + self.mux.channels[channel] = self.got_packet log('Created MuxWrapper on channel %d\n' % channel) + def __del__(self): + self.nowrite() + SockWrapper.__del__(self) + + def __repr__(self): + return 'SW%r:Mux#%d' % (self.peername,self.channel) + def noread(self): if not self.shut_read: self.shut_read = True @@ -231,6 +270,8 @@ class MuxWrapper(SockWrapper): self.mux.send(self.channel, CMD_EOF, '') def uwrite(self, buf): + if len(buf) > 65535: + buf = buf[:32768] self.mux.send(self.channel, CMD_DATA, buf) return len(buf) @@ -251,3 +292,14 @@ class MuxWrapper(SockWrapper): else: raise Exception('unknown command %d (%d bytes)' % (cmd, len(data))) + + +def connect_dst(ip, port): + outsock = socket.socket() + outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) + try: + outsock.connect((ip,port)) + except socket.error, e: + if e.args[0] not in [errno.ECONNREFUSED]: + raise + return SockWrapper(outsock,outsock) |