summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAvery Pennarun <apenwarr@gmail.com>2010-05-02 00:52:06 -0400
committerAvery Pennarun <apenwarr@gmail.com>2010-05-02 00:52:06 -0400
commit915a96b0ec1f2862431d296e116d687bc79f0ba3 (patch)
tree30ca301bba0814f91287fa4d4581ff721982acdd
parentd435c41bdbf2ffd889cefe6710b8fc126304ef38 (diff)
We now have a server that works... some of the time.
There still seem to be some weird timing and/or closing-related bugs, since I can't load the eqldata project correctly unless I use --noserver.
-rw-r--r--client.py26
-rw-r--r--helpers.py4
-rwxr-xr-xmain.py7
-rw-r--r--server.py43
-rw-r--r--ssnet.py104
5 files changed, 146 insertions, 38 deletions
diff --git a/client.py b/client.py
index 2ca75ba..854005d 100644
--- a/client.py
+++ b/client.py
@@ -1,5 +1,5 @@
import struct, socket, select, subprocess, errno
-import ssnet, ssh
+import ssnet, ssh, helpers
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
@@ -24,10 +24,21 @@ def iptables_setup(port, subnets):
def _main(listener, listenport, use_server, remotename, subnets):
handlers = []
if use_server:
+ helpers.logprefix = 'c : '
(serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock, serversock)
handlers.append(mux)
+ expected = 'SSHUTTLE0001'
+ initstring = serversock.recv(len(expected))
+ if initstring != expected:
+ raise Exception('expected server init string %r; got %r'
+ % (expected, initstring))
+
+ rv = serverproc.poll()
+ if rv:
+ raise Exception('server died with error code %d' % rv)
+
# we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection!
iptables_setup(listenport, subnets)
@@ -45,21 +56,24 @@ def _main(listener, listenport, use_server, remotename, subnets):
mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip)
outwrap = MuxWrapper(mux, chan)
else:
- outsock = socket.socket()
- outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
- outsock.connect(dstip)
- outwrap = SockWrapper(outsock, outsock)
+ outwrap = ssnet.connect_dst(dstip[0], dstip[1])
handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
handlers.append(Handler([listener], onaccept))
while 1:
+ if use_server:
+ rv = serverproc.poll()
+ if rv:
+ raise Exception('server died with error code %d' % rv)
+
r = set()
w = set()
x = set()
handlers = filter(lambda s: s.ok, handlers)
for s in handlers:
s.pre_select(r,w,x)
- log('\nWaiting: %d[%d,%d,%d]...\n'
+ log('\n')
+ log('Waiting: %d[%d,%d,%d]...\n'
% (len(handlers), len(r), len(w), len(x)))
(r,w,x) = select.select(r,w,x)
log('r=%r w=%r x=%r\n' % (r,w,x))
diff --git a/helpers.py b/helpers.py
index 4b46dc3..d397cd8 100644
--- a/helpers.py
+++ b/helpers.py
@@ -1,6 +1,8 @@
import sys, os
+logprefix = ''
+
def log(s):
sys.stdout.flush()
- sys.stderr.write(s)
+ sys.stderr.write(logprefix + s)
sys.stderr.flush()
diff --git a/main.py b/main.py
index 02fa00c..9115e63 100755
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
import sys, os, re
-import options, client, iptables
+import options, client, iptables, server
# list of:
@@ -58,10 +58,7 @@ o = options.Options('sshuttle', optspec)
(opt, flags, extra) = o.parse(sys.argv[1:])
if opt.server:
- #o.fatal('server mode not implemented yet')
- os.dup2(2,1)
- os.execvp('hd', ['hd'])
- sys.exit(1)
+ sys.exit(server.main())
elif opt.iptables:
if len(extra) < 1:
o.fatal('at least one argument expected')
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..499aa4c
--- /dev/null
+++ b/server.py
@@ -0,0 +1,43 @@
+import struct, socket, select
+import ssnet, helpers
+from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
+from helpers import *
+
+
+def main():
+ # synchronization header
+ sys.stdout.write('SSHUTTLE0001')
+ sys.stdout.flush()
+
+ helpers.logprefix = ' s: '
+ handlers = []
+ mux = Mux(socket.fromfd(sys.stdin.fileno(),
+ socket.AF_INET, socket.SOCK_STREAM),
+ socket.fromfd(sys.stdout.fileno(),
+ socket.AF_INET, socket.SOCK_STREAM))
+ handlers.append(mux)
+
+ def new_channel(channel, data):
+ (dstip,dstport) = data.split(',', 1)
+ dstport = int(dstport)
+ outwrap = ssnet.connect_dst(dstip,dstport)
+ handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
+
+ mux.new_channel = new_channel
+
+ while mux.ok:
+ r = set()
+ w = set()
+ x = set()
+ handlers = filter(lambda s: s.ok, handlers)
+ for s in handlers:
+ s.pre_select(r,w,x)
+ log('\n')
+ log('Waiting: %d[%d,%d,%d]...\n'
+ % (len(handlers), len(r), len(w), len(x)))
+ (r,w,x) = select.select(r,w,x)
+ log('r=%r w=%r x=%r\n' % (r,w,x))
+ ready = set(r) | set(w) | set(x)
+ for s in handlers:
+ if s.socks & ready:
+ s.callback()
diff --git a/ssnet.py b/ssnet.py
index 0966cec..0f01d0a 100644
--- a/ssnet.py
+++ b/ssnet.py
@@ -16,16 +16,27 @@ def _nb_clean(func, *args):
try:
return func(*args)
except socket.error, e:
- if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
+ if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
+ raise
+ else:
return None
- raise
+
+
+def _try_peername(sock):
+ try:
+ return sock.getpeername()
+ except socket.error, e:
+ if e.args[0] not in (errno.ENOTCONN,):
+ raise
+ else:
+ return ('0.0.0.0',0)
class SockWrapper:
def __init__(self, rsock, wsock):
self.rsock = rsock
self.wsock = wsock
- self.peername = self.rsock.getpeername()
+ self.peername = _try_peername(self.rsock)
self.shut_read = self.shut_write = False
self.buf = []
@@ -45,11 +56,20 @@ class SockWrapper:
if not self.shut_write:
log('%r: done writing\n' % self)
self.shut_write = True
- self.wsock.shutdown(socket.SHUT_WR)
+ try:
+ self.wsock.shutdown(socket.SHUT_WR)
+ except socket.error:
+ pass
def uwrite(self, buf):
self.wsock.setblocking(False)
- return _nb_clean(self.wsock.send, buf)
+ try:
+ return _nb_clean(self.wsock.send, buf)
+ except socket.error:
+ # unexpected error... stream is dead
+ self.nowrite()
+ self.noread()
+ return 0
def write(self, buf):
assert(buf)
@@ -59,7 +79,10 @@ class SockWrapper:
if self.shut_read:
return
self.rsock.setblocking(False)
- return _nb_clean(self.rsock.recv, 65536)
+ try:
+ return _nb_clean(self.rsock.recv, 65536)
+ except socket.error:
+ return '' # unexpected error... we'll call it EOF
def fill(self):
if self.buf:
@@ -133,6 +156,7 @@ class Mux(Handler):
Handler.__init__(self, [rsock, wsock])
self.rsock = rsock
self.wsock = wsock
+ self.new_channel = None
self.channels = {}
self.chani = 0
self.want = 0
@@ -160,12 +184,18 @@ class Mux(Handler):
def got_packet(self, channel, cmd, data):
log('--got-packet--\n')
if cmd == CMD_PING:
- self.mux.send(0, CMD_PONG, data)
+ self.send(0, CMD_PONG, data)
+ elif cmd == CMD_PONG:
+ log('received PING response\n')
elif cmd == CMD_EXIT:
self.ok = False
+ elif cmd == CMD_CONNECT:
+ assert(not self.channels.get(channel))
+ if self.new_channel:
+ self.new_channel(channel, data)
else:
- c = self.channels[channel]
- c.got_packet(cmd, data)
+ callback = self.channels[channel]
+ callback(cmd, data)
def flush(self):
self.wsock.setblocking(False)
@@ -180,28 +210,30 @@ class Mux(Handler):
self.rsock.setblocking(False)
b = _nb_clean(self.rsock.recv, 32768)
if b == '': # EOF
- ok = False
+ self.ok = False
if b:
self.inbuf += b
def handle(self):
- log('inbuf is: %r\n' % self.inbuf)
- if len(self.inbuf) >= (self.want or HDR_LEN):
- (s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH',
- self.inbuf[:HDR_LEN])
- assert(s1 == 'S')
- assert(s2 == 'S')
- self.want = datalen + HDR_LEN
- if self.want and len(self.inbuf) >= self.want:
- data = self.inbuf[HDR_LEN:self.want]
- self.inbuf = self.inbuf[self.want:]
- self.got_packet(channel, cmd, data)
- else:
- self.fill()
+ self.fill()
+ log('inbuf is: (%d,%d) %r\n' % (self.want, len(self.inbuf), self.inbuf))
+ while 1:
+ if len(self.inbuf) >= (self.want or HDR_LEN):
+ (s1,s2,channel,cmd,datalen) = \
+ struct.unpack('!ccHHH', self.inbuf[:HDR_LEN])
+ assert(s1 == 'S')
+ assert(s2 == 'S')
+ self.want = datalen + HDR_LEN
+ if self.want and len(self.inbuf) >= self.want:
+ data = self.inbuf[HDR_LEN:self.want]
+ self.inbuf = self.inbuf[self.want:]
+ self.want = 0
+ self.got_packet(channel, cmd, data)
+ else:
+ break
def pre_select(self, r, w, x):
- if self.inbuf < (self.want or HDR_LEN):
- r.add(self.rsock)
+ r.add(self.rsock)
if self.outbuf:
w.add(self.wsock)
@@ -218,9 +250,16 @@ class MuxWrapper(SockWrapper):
SockWrapper.__init__(self, mux.rsock, mux.wsock)
self.mux = mux
self.channel = channel
- self.mux.channels[channel] = self
+ self.mux.channels[channel] = self.got_packet
log('Created MuxWrapper on channel %d\n' % channel)
+ def __del__(self):
+ self.nowrite()
+ SockWrapper.__del__(self)
+
+ def __repr__(self):
+ return 'SW%r:Mux#%d' % (self.peername,self.channel)
+
def noread(self):
if not self.shut_read:
self.shut_read = True
@@ -231,6 +270,8 @@ class MuxWrapper(SockWrapper):
self.mux.send(self.channel, CMD_EOF, '')
def uwrite(self, buf):
+ if len(buf) > 65535:
+ buf = buf[:32768]
self.mux.send(self.channel, CMD_DATA, buf)
return len(buf)
@@ -251,3 +292,14 @@ class MuxWrapper(SockWrapper):
else:
raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data)))
+
+
+def connect_dst(ip, port):
+ outsock = socket.socket()
+ outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
+ try:
+ outsock.connect((ip,port))
+ except socket.error, e:
+ if e.args[0] not in [errno.ECONNREFUSED]:
+ raise
+ return SockWrapper(outsock,outsock)