summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBrian May <brian@vpac.org>2011-05-12 14:37:19 +1000
committerBrian May <brian@vpac.org>2011-05-12 14:37:19 +1000
commit9915d736fea145f9f6eccd8f6ab153d590ec14c8 (patch)
tree3b8cd6d64d6ff4cb3cf8cc57da95f1da65b2d6f5
parent783d33cada4bbb7282877330abd7b16c7e2ab1fe (diff)
Enhanced DNS support. Initial version.dns
-rw-r--r--client.py112
-rw-r--r--main.py18
2 files changed, 122 insertions, 8 deletions
diff --git a/client.py b/client.py
index 0ff5f2b..21ecd22 100644
--- a/client.py
+++ b/client.py
@@ -175,8 +175,60 @@ class FirewallClient:
raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
+def unpack_dns_name(buf, off):
+ name = ''
+ while True:
+ # get the next octet from buffer
+ n = ord(buf[off])
+
+ # zero octet terminates name
+ if n == 0:
+ off += 1
+ break
+
+ # top two bits on
+ # => a 2 octect pointer to another part of the buffer
+ elif (n & 0xc0) == 0xc0:
+ ptr = struct.unpack('>H', buf[off:off+2])[0] & 0x3fff
+ off = ptr
+
+ # an octet representing the number of bytes to process.
+ else:
+ off += 1
+ name = name + buf[off:off+n] + '.'
+ off += n
+
+ return name.strip('.'), off
+
+class dnspkt:
+ def unpack(self, buf, off):
+ l = len(buf)
+
+ (self.id, self.op, self.qdcount, self.ancount, self.nscount, self.arcount) = struct.unpack("!HHHHHH",buf[off:off+12])
+ off += 12
+
+ self.q = []
+ for i in range(self.qdcount):
+ qname, off = unpack_dns_name(buf, off)
+ qtype, qclass = struct.unpack('!HH', buf[off:off+4])
+ off += 4
+ self.q.append( (qname,qtype,qclass) )
+
+ return off
+
+ def match_q_domain(self, domain):
+ l = len(domain)
+ for qname,qtype,qclass in self.q:
+ if qname[-l:] == domain:
+ if l==len(qname):
+ return True
+ elif qname[-l-1] == '.':
+ return True
+ return False
+
def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
- dnslistener, seed_hosts, auto_nets,
+ dnslistener, dnsforwarder, dns_domains, dns_to,
+ seed_hosts, auto_nets,
syslog, daemon):
handlers = []
if helpers.verbose >= 1:
@@ -283,6 +335,7 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
handlers.append(Handler([listener], onaccept))
dnsreqs = {}
+ dnsforwards = {}
def dns_done(chan, data):
peer,timeout = dnsreqs.get(chan) or (None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
@@ -295,16 +348,54 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
now = time.time()
if pkt:
debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt)))
- chan = mux.next_channel()
- dnsreqs[chan] = peer,now+30
- mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
- mux.channels[chan] = lambda cmd,data: dns_done(chan,data)
+ dns = dnspkt()
+ dns.unpack(pkt, 0)
+
+ match=False
+ if dns_domains is not None:
+ for domain in dns_domains:
+ if dns.match_q_domain(domain):
+ match=True
+ break
+
+ if match:
+ debug3("We need to redirect this request remotely\n")
+ chan = mux.next_channel()
+ dnsreqs[chan] = peer,now+30
+ mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
+ mux.channels[chan] = lambda cmd,data: dns_done(chan,data)
+ else:
+ debug3("We need to forward this request locally\n")
+ dnsforwarder.sendto(pkt, dns_to)
+ dnsforwards[dns.id] = peer,now+30
for chan,(peer,timeout) in dnsreqs.items():
if timeout < now:
del dnsreqs[chan]
+ for chan,(peer,timeout) in dnsforwards.items():
+ if timeout < now:
+ del dnsforwards[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
+ debug3('Remaining DNS forwards: %d\n' % len(dnsforwards))
if dnslistener:
handlers.append(Handler([dnslistener], ondns))
+ def ondnsforward():
+ debug1("We got a response.\n")
+ pkt,server = dnsforwarder.recvfrom(4096)
+ now = time.time()
+ if server[0] != dns_to[0] or server[1] != dns_to[1]:
+ debug1("Ooops. The response came from the wrong server. Ignoring\n")
+ else:
+ dns = dnspkt()
+ dns.unpack(pkt, 0)
+ chan=dns.id
+ peer,timeout = dnsforwards.get(chan) or (None,None)
+ debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
+ if peer:
+ del dnsforwards[chan]
+ debug3('doing sendto %r\n' % (peer,))
+ dnslistener.sendto(pkt, peer)
+ if dnsforwarder:
+ handlers.append(Handler([dnsforwarder], ondnsforward))
if seed_hosts != None:
debug1('seed_hosts: %r\n' % seed_hosts)
@@ -321,7 +412,8 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
mux.callback()
-def main(listenip, ssh_cmd, remotename, python, latency_control, dns,
+def main(listenip, ssh_cmd, remotename, python, latency_control,
+ dns, dns_domains, dns_to,
seed_hosts, auto_nets,
subnets_include, subnets_exclude, syslog, daemon, pidfile):
if syslog:
@@ -366,15 +458,21 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns,
dnsip = dnslistener.getsockname()
debug1('DNS listening on %r.\n' % (dnsip,))
dnsport = dnsip[1]
+
+ dnsforwarder = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ dnsforwarder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ dnsforwarder.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
else:
dnsport = 0
dnslistener = None
+ dnsforwarder = None
fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport)
try:
return _main(listener, fw, ssh_cmd, remotename,
- python, latency_control, dnslistener,
+ python, latency_control,
+ dnslistener, dnsforwarder, dns_domains, dns_to,
seed_hosts, auto_nets, syslog, daemon)
finally:
try:
diff --git a/main.py b/main.py
index 1cf00af..6afeefe 100644
--- a/main.py
+++ b/main.py
@@ -54,6 +54,8 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0]
H,auto-hosts scan for remote hostnames and update local /etc/hosts
N,auto-nets automatically determine subnets to route
dns capture local DNS requests and forward to the remote DNS server
+dns-domains= comma seperated list of DNS domains for DNS forwarding
+dns-to= forward any DNS requests that don't match domains to this address
python= path to python interpreter on the remote server
r,remote= ssh hostname (and optional username) of remote sshuttle server
x,exclude= exclude this subnet (can be used more than once)
@@ -110,12 +112,26 @@ try:
sh = []
else:
sh = None
+ if opt.dns and opt.dns_domains:
+ dns_domains = opt.dns_domains.split(",")
+ if opt.dns_to:
+ addr,colon,port = opt.dns_to.rpartition(":")
+ if colon == ":":
+ dns_to = ( addr, int(port) )
+ else:
+ dns_to = ( port, 53 )
+ else:
+ o.fatal('--dns-to=ip is required with --dns-domains=list')
+ else:
+ dns_domains = None
+ dns_to = None
+
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
opt.ssh_cmd,
remotename,
opt.python,
opt.latency_control,
- opt.dns,
+ opt.dns, dns_domains, dns_to,
sh,
opt.auto_nets,
parse_subnets(includes),