summaryrefslogtreecommitdiffstats
path: root/dns-transport/src/tls.rs
blob: 959dbc9ec43ed86f8450a36d26e3efa9b50f6035 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#![cfg_attr(not(feature = "tls"), allow(unused))]

use std::net::TcpStream;
use std::io::Write;

use log::*;

use dns::{Request, Response};
use super::{Transport, Error, TcpTransport};
use super::tls_stream::TlsStream;


/// The **TLS transport**: DNS wire data is carried over a TCP connection
/// that is wrapped in TLS encryption.
pub struct TlsTransport {
    /// Address of the server to connect to, optionally with a `:port` suffix.
    addr: String,
}

impl TlsTransport {

    /// Creates a new TLS transport that connects to the given host.
    pub fn new(addr: String) -> Self {
        Self { addr }
    }
}



impl Transport for TlsTransport {

    #[cfg(feature = "with_tls")]
    fn send(&self, request: &Request) -> Result<Response, Error> {
        info!("Opening TLS socket");

        let domain = self.sni_domain();
        info!("Connecting using domain {:?}", domain);
        let mut stream =
            if self.addr.contains(':') {
                let mut parts = self.addr.split(":");
                let domain = parts.nth(0).unwrap();
                let port = parts.last().unwrap().parse::<u16>().expect("Invalid port number");

                Self::stream(domain, port)?
            }
            else {
                Self::stream(&*self.addr, 853)?
            };


        debug!("Connected");

        // The message is prepended with the length when sent over TCP,
        // so the server knows how long it is (RFC 1035 §4.2.2)
        let mut bytes_to_send = request.to_bytes().expect("failed to serialise request");
        TcpTransport::prefix_with_length(&mut bytes_to_send);

        info!("Sending {} bytes of data to {} over TLS", bytes_to_send.len(), self.addr);
        stream.write_all(&bytes_to_send)?;
        debug!("Wrote all bytes");

        let read_bytes = TcpTransport::length_prefixed_read(&mut stream)?;
        let response = Response::from_bytes(&read_bytes)?;
        Ok(response)
    }

    #[cfg(not(feature = "with_tls"))]
    fn send(&self, request: &Request) -> Result<Response, Error> {
        unreachable!("TLS feature disabled")
    }
}

impl TlsTransport {
    fn sni_domain(&self) -> &str {
        if let Some(colon_index) = self.addr.find(':') {
            &self.addr[.. colon_index]
        }
        else {
            &self.addr[..]
        }
    }
}