diff options
Diffstat (limited to 'src/transport.rs')
-rw-r--r-- | src/transport.rs | 280 |
1 files changed, 179 insertions, 101 deletions
diff --git a/src/transport.rs b/src/transport.rs index 7cc1fb1..112d548 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,13 +1,15 @@ //! Transports for communicating with the docker daemon use crate::{Error, Result}; -use futures::{ - future::{self, Either}, - Future, IntoFuture, Stream, +use futures_util::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, + StreamExt, TryFutureExt, }; use hyper::{ + body::Bytes, client::{Client, HttpConnector}, - header, Body, Chunk, Method, Request, StatusCode, + header, Body, Method, Request, StatusCode, }; #[cfg(feature = "tls")] use hyper_openssl::HttpsConnector; @@ -15,12 +17,14 @@ use hyper_openssl::HttpsConnector; use hyperlocal::UnixConnector; #[cfg(feature = "unix-socket")] use hyperlocal::Uri as DomainUri; -use log::debug; use mime::Mime; +use pin_project::pin_project; use serde::{Deserialize, Serialize}; -use serde_json; -use std::{fmt, iter}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{ + fmt, io, iter, + pin::Pin, + task::{Context, Poll}, +}; pub fn tar() -> Mime { "application/tar".parse().unwrap() @@ -66,116 +70,133 @@ impl fmt::Debug for Transport { impl Transport { /// Make a request and return the whole response in a `String` - pub fn request<B>( + pub async fn request<B>( &self, method: Method, - endpoint: &str, + endpoint: impl AsRef<str>, body: Option<(B, Mime)>, - ) -> impl Future<Item = String, Error = Error> + ) -> Result<String> where B: Into<Body>, { - let endpoint = endpoint.to_string(); - self.stream_chunks(method, &endpoint, body, None::<iter::Empty<_>>) - .concat2() - .and_then(|v| { - String::from_utf8(v.to_vec()) - .map_err(Error::Encoding) - .into_future() - }) - .inspect(move |body| debug!("{} raw response: {}", endpoint, body)) + let body = self + .get_body(method, endpoint, body, None::<iter::Empty<_>>) + .await?; + let bytes = hyper::body::to_bytes(body).await?; + let string = String::from_utf8(bytes.to_vec())?; + + Ok(string) } - /// Make a request and return a `Stream` of `Chunks` as they are returned. - pub fn stream_chunks<B, H>( + async fn get_body<B, H>( &self, method: Method, - endpoint: &str, + endpoint: impl AsRef<str>, body: Option<(B, Mime)>, headers: Option<H>, - ) -> impl Stream<Item = Chunk, Error = Error> + ) -> Result<Body> where B: Into<Body>, H: IntoIterator<Item = (&'static str, String)>, { let req = self - .build_request(method, endpoint, body, headers, |_| ()) + .build_request(method, endpoint, body, headers, Request::builder()) .expect("Failed to build request!"); - self.send_request(req) - .and_then(|res| { - let status = res.status(); - match status { - // Success case: pass on the response - StatusCode::OK - | StatusCode::CREATED - | StatusCode::SWITCHING_PROTOCOLS - | StatusCode::NO_CONTENT => Either::A(future::ok(res)), - // Error case: parse the body to try to extract the error message - _ => Either::B( - res.into_body() - .concat2() - .map_err(Error::Hyper) - .and_then(|v| { - String::from_utf8(v.into_iter().collect::<Vec<u8>>()) - .map_err(Error::Encoding) - }) - .and_then(move |body| { - future::err(Error::Fault { - code: status, - message: Self::get_error_message(&body).unwrap_or_else(|| { - status - .canonical_reason() - .unwrap_or_else(|| "unknown error code") - .to_owned() - }), - }) - }), - ), - } - }) - .map(|r| { - // Convert the response body into a stream of chunks - r.into_body().map_err(Error::Hyper) - }) - .flatten_stream() + let response = self.send_request(req).await?; + + let status = response.status(); + + match status { + // Success case: pass on the response + StatusCode::OK + | StatusCode::CREATED + | StatusCode::SWITCHING_PROTOCOLS + | StatusCode::NO_CONTENT => Ok(response.into_body()), + _ => { + let bytes = hyper::body::to_bytes(response.into_body()).await?; + let message_body = String::from_utf8(bytes.to_vec())?; + + Err(Error::Fault { + code: status, + message: Self::get_error_message(&message_body).unwrap_or_else(|| { + status + .canonical_reason() + .unwrap_or_else(|| "unknown error code") + .to_owned() + }), + }) + } + } + } + + async fn get_chunk_stream<B, H>( + &self, + method: Method, + endpoint: impl AsRef<str>, + body: Option<(B, Mime)>, + headers: Option<H>, + ) -> Result<impl Stream<Item = Result<Bytes>>> + where + B: Into<Body>, + H: IntoIterator<Item = (&'static str, String)>, + { + let body = self.get_body(method, endpoint, body, headers).await?; + + Ok(stream_body(body)) + } + + pub fn stream_chunks<'a, H, B>( + &'a self, + method: Method, + endpoint: impl AsRef<str> + 'a, + body: Option<(B, Mime)>, + headers: Option<H>, + ) -> impl Stream<Item = Result<Bytes>> + 'a + where + H: IntoIterator<Item = (&'static str, String)> + 'a, + B: Into<Body> + 'a, + { + self.get_chunk_stream(method, endpoint, body, headers) + .try_flatten_stream() } /// Builds an HTTP request. fn build_request<B, H>( &self, method: Method, - endpoint: &str, + endpoint: impl AsRef<str>, body: Option<(B, Mime)>, headers: Option<H>, - f: impl FnOnce(&mut ::hyper::http::request::Builder), + builder: hyper::http::request::Builder, ) -> Result<Request<Body>> where B: Into<Body>, H: IntoIterator<Item = (&'static str, String)>, { - let mut builder = Request::builder(); - f(&mut builder); - let req = match *self { Transport::Tcp { ref host, .. } => { - builder.method(method).uri(&format!("{}{}", host, endpoint)) + builder + .method(method) + .uri(&format!("{}{}", host, endpoint.as_ref())) } #[cfg(feature = "tls")] Transport::EncryptedTcp { ref host, .. } => { - builder.method(method).uri(&format!("{}{}", host, endpoint)) + builder + .method(method) + .uri(&format!("{}{}", host, endpoint.as_ref())) } #[cfg(feature = "unix-socket")] Transport::Unix { ref path, .. } => { - let uri: hyper::Uri = DomainUri::new(&path, endpoint).into(); - builder.method(method).uri(&uri.to_string()) + let uri = DomainUri::new(&path, endpoint.as_ref()); + builder.method(method).uri(uri) } }; - let req = req.header(header::HOST, ""); + let mut req = req.header(header::HOST, ""); if let Some(h) = headers { for (k, v) in h.into_iter() { - req.header(k, v); + req = req.header(k, v); } } @@ -188,19 +209,17 @@ impl Transport { } /// Send the given request to the docker daemon and return a Future of the response. - fn send_request( + async fn send_request( &self, req: Request<hyper::Body>, - ) -> impl Future<Item = hyper::Response<Body>, Error = Error> { - let req = match self { - Transport::Tcp { ref client, .. } => client.request(req), + ) -> Result<hyper::Response<Body>> { + match self { + Transport::Tcp { ref client, .. } => Ok(client.request(req).await?), #[cfg(feature = "tls")] - Transport::EncryptedTcp { ref client, .. } => client.request(req), + Transport::EncryptedTcp { ref client, .. } => Ok(client.request(req).await?), #[cfg(feature = "unix-socket")] - Transport::Unix { ref client, .. } => client.request(req), - }; - - req.map_err(Error::Hyper) + Transport::Unix { ref client, .. } => Ok(client.request(req).await?), + } } /// Makes an HTTP request, upgrading the connection to a TCP @@ -208,12 +227,12 @@ impl Transport { /// /// This method can be used for operations such as viewing /// docker container logs interactively. - pub fn stream_upgrade<B>( + async fn stream_upgrade_tokio<B>( &self, method: Method, - endpoint: &str, + endpoint: impl AsRef<str>, body: Option<(B, Mime)>, - ) -> impl Future<Item = impl AsyncRead + AsyncWrite, Error = Error> + ) -> Result<hyper::upgrade::Upgraded> where B: Into<Body>, { @@ -226,32 +245,37 @@ impl Transport { }; let req = self - .build_request(method, endpoint, body, None::<iter::Empty<_>>, |builder| { - builder + .build_request( + method, + endpoint, + body, + None::<iter::Empty<_>>, + Request::builder() .header(header::CONNECTION, "Upgrade") - .header(header::UPGRADE, "tcp"); - }) + .header(header::UPGRADE, "tcp"), + ) .expect("Failed to build request!"); - self.send_request(req) - .and_then(|res| match res.status() { - StatusCode::SWITCHING_PROTOCOLS => Ok(res), - _ => Err(Error::ConnectionNotUpgraded), - }) - .and_then(|res| res.into_body().on_upgrade().from_err()) + let response = self.send_request(req).await?; + + match response.status() { + StatusCode::SWITCHING_PROTOCOLS => Ok(response.into_body().on_upgrade().await?), + _ => Err(Error::ConnectionNotUpgraded), + } } - pub fn stream_upgrade_multiplexed<B>( + pub async fn stream_upgrade<B>( &self, method: Method, - endpoint: &str, + endpoint: impl AsRef<str>, body: Option<(B, Mime)>, - ) -> impl Future<Item = crate::tty::Multiplexed, Error = Error> + ) -> Result<impl AsyncRead + AsyncWrite> where - B: Into<Body> + 'static, + B: Into<Body>, { - self.stream_upgrade(method, endpoint, body) - .map(crate::tty::Multiplexed::new) + let tokio_multiplexer = self.stream_upgrade_tokio(method, endpoint, body).await?; + + Ok(Compat { tokio_multiplexer }) } /// Extract the error message content from an HTTP response that @@ -263,7 +287,61 @@ impl Transport { } } +#[pin_project] +struct Compat<S> { + #[pin] + tokio_multiplexer: S, +} + +impl<S> AsyncRead for Compat<S> +where + S: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<usize>> { + self.project().tokio_multiplexer.poll_read(cx, buf) + } +} + +impl<S> AsyncWrite for Compat<S> +where + S: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.project().tokio_multiplexer.poll_write(cx, buf) + } + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<()>> { + self.project().tokio_multiplexer.poll_flush(cx) + } + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<()>> { + self.project().tokio_multiplexer.poll_shutdown(cx) + } +} + #[derive(Serialize, Deserialize)] struct ErrorResponse { message: String, } + +fn stream_body(body: Body) -> impl Stream<Item = Result<Bytes>> { + async fn unfold(mut body: Body) -> Option<(Result<Bytes>, Body)> { + let chunk_result = body.next().await?.map_err(Error::from); + + Some((chunk_result, body)) + } + + futures_util::stream::unfold(body, unfold) +} |