From 6ed05e7d2357c01e2545f36316fc9262409483b6 Mon Sep 17 00:00:00 2001 From: Paul Woolcock Date: Tue, 29 Sep 2020 23:08:36 -0400 Subject: use the async reqwest client but present the same blocking api --- Cargo.toml | 2 ++ src/lib.rs | 86 ++++++++++++++++++++++++++++---------------------- src/macros.rs | 14 ++++---- src/mastodon_client.rs | 1 + src/page.rs | 10 +++--- src/registration.rs | 14 +++++--- 6 files changed, 72 insertions(+), 55 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48bfacd..ad8b9c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ url = "2.1.1" tap-reader = "1" toml = { version = "0.5.0", optional = true } tungstenite = "0.11.0" +async-trait = "0.1.40" +tokio = "0.2.22" [dependencies.chrono] version = "0.4" diff --git a/src/lib.rs b/src/lib.rs index 165c775..cdc0a3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,12 +73,13 @@ use std::{borrow::Cow, io::BufRead, ops}; -use reqwest::blocking::{Client, RequestBuilder, Response}; -use tap_reader::Tap; +use reqwest::{Client, RequestBuilder, Response}; use tungstenite::client::AutoStream; use crate::{entities::prelude::*, page::Page}; +pub use isolang::Language; + pub use crate::{ data::Data, errors::{ApiError, Error, Result}, @@ -94,7 +95,6 @@ pub use crate::{ }, status_builder::{NewStatus, StatusBuilder}, }; -pub use isolang::Language; /// Registering your App pub mod apps; @@ -150,9 +150,10 @@ impl Mastodon { format!("{}{}", self.base, url) } - pub(crate) fn send(&self, req: RequestBuilder) -> Result { + pub(crate) fn send_blocking(&self, req: RequestBuilder) -> Result { let request = req.bearer_auth(&self.token).build()?; - Ok(self.client.execute(request)?) + let handle = tokio::runtime::Handle::current(); + Ok(handle.block_on(self.client.execute(request))?) } } @@ -167,6 +168,7 @@ impl From for Mastodon { } } +#[async_trait::async_trait] impl MastodonClient for Mastodon { type Stream = EventReader; @@ -241,7 +243,7 @@ impl MastodonClient for Mastodon { fn add_filter(&self, request: &mut AddFilterRequest) -> Result { let url = self.route("/api/v1/filters"); - let response = self.send(self.client.post(&url).json(&request))?; + let response = self.send_blocking(self.client.post(&url).json(&request))?; let status = response.status(); @@ -251,13 +253,13 @@ impl MastodonClient for Mastodon { return Err(Error::Server(status)); } - deserialise(response) + deserialise_blocking(response) } /// PUT /api/v1/filters/:id fn update_filter(&self, id: &str, request: &mut AddFilterRequest) -> Result { let url = self.route(&format!("/api/v1/filters/{}", id)); - let response = self.send(self.client.put(&url).json(&request))?; + let response = self.send_blocking(self.client.put(&url).json(&request))?; let status = response.status(); @@ -267,13 +269,13 @@ impl MastodonClient for Mastodon { return Err(Error::Server(status)); } - deserialise(response) + deserialise_blocking(response) } fn update_credentials(&self, builder: &mut UpdateCredsRequest) -> Result { let changes = builder.build()?; let url = self.route("/api/v1/accounts/update_credentials"); - let response = self.send(self.client.patch(&url).json(&changes))?; + let response = self.send_blocking(self.client.patch(&url).json(&changes))?; let status = response.status(); @@ -283,18 +285,18 @@ impl MastodonClient for Mastodon { return Err(Error::Server(status)); } - deserialise(response) + deserialise_blocking(response) } /// Post a new status to the account. fn new_status(&self, status: NewStatus) -> Result { - let response = self.send( + let response = self.send_blocking( self.client .post(&self.route("/api/v1/statuses")) .json(&status), )?; - deserialise(response) + deserialise_blocking(response) } /// Get timeline filtered by a hashtag(eg. `#coffee`) either locally or @@ -307,7 +309,7 @@ impl MastodonClient for Mastodon { self.route(&format!("{}{}", base, hashtag)) }; - Page::new(self, self.send(self.client.get(&url))?) + Page::new(self, self.send_blocking(self.client.get(&url))?) } /// Get statuses of a single account by id. Optionally only with pictures @@ -362,7 +364,7 @@ impl MastodonClient for Mastodon { url = format!("{}{}", url, request.to_querystring()?); } - let response = self.send(self.client.get(&url))?; + let response = self.send_blocking(self.client.get(&url))?; Page::new(self, response) } @@ -384,7 +386,7 @@ impl MastodonClient for Mastodon { url.pop(); } - let response = self.send(self.client.get(&url))?; + let response = self.send_blocking(self.client.get(&url))?; Page::new(self, response) } @@ -392,26 +394,26 @@ impl MastodonClient for Mastodon { /// Add a push notifications subscription fn add_push_subscription(&self, request: &AddPushRequest) -> Result { let request = request.build()?; - let response = self.send( + let response = self.send_blocking( self.client .post(&self.route("/api/v1/push/subscription")) .json(&request), )?; - deserialise(response) + deserialise_blocking(response) } /// Update the `data` portion of the push subscription associated with this /// access token fn update_push_data(&self, request: &UpdatePushRequest) -> Result { let request = request.build(); - let response = self.send( + let response = self.send_blocking( self.client .put(&self.route("/api/v1/push/subscription")) .json(&request), )?; - deserialise(response) + deserialise_blocking(response) } /// Get all accounts that follow the authenticated user @@ -621,9 +623,14 @@ impl MastodonClient for Mastodon { /// Equivalent to /api/v1/media fn media(&self, media_builder: MediaBuilder) -> Result { - use reqwest::blocking::multipart::Form; + use reqwest::multipart::{Form, Part}; + use std::{fs::File, io::Read}; - let mut form_data = Form::new().file("file", media_builder.file.as_ref())?; + let mut f = File::open(media_builder.file.as_ref())?; + let mut bytes = Vec::new(); + f.read_to_end(&mut bytes)?; + let part = Part::stream(bytes); + let mut form_data = Form::new().part("file", part); if let Some(description) = media_builder.description { form_data = form_data.text("description", description); @@ -634,7 +641,7 @@ impl MastodonClient for Mastodon { form_data = form_data.text("focus", string); } - let response = self.send( + let response = self.send_blocking( self.client .post(&self.route("/api/v1/media")) .multipart(form_data), @@ -648,7 +655,7 @@ impl MastodonClient for Mastodon { return Err(Error::Server(status)); } - deserialise(response) + deserialise_blocking(response) } } @@ -820,9 +827,10 @@ impl MastodonUnauth { Ok(self.base.join(url)?) } - fn send(&self, req: RequestBuilder) -> Result { + fn send_blocking(&self, req: RequestBuilder) -> Result { let req = req.build()?; - Ok(self.client.execute(req)?) + let handle = tokio::runtime::Handle::current(); + Ok(handle.block_on(self.client.execute(req))?) } /// Get a stream of the public timeline @@ -852,8 +860,8 @@ impl MastodonUnauthenticated for MastodonUnauth { fn get_status(&self, id: &str) -> Result { let route = self.route("/api/v1/statuses")?; let route = route.join(id)?; - let response = self.send(self.client.get(route))?; - deserialise(response) + let response = self.send_blocking(self.client.get(route))?; + deserialise_blocking(response) } /// GET /api/v1/statuses/:id/context @@ -861,8 +869,8 @@ impl MastodonUnauthenticated for MastodonUnauth { let route = self.route("/api/v1/statuses")?; let route = route.join(id)?; let route = route.join("context")?; - let response = self.send(self.client.get(route))?; - deserialise(response) + let response = self.send_blocking(self.client.get(route))?; + deserialise_blocking(response) } /// GET /api/v1/statuses/:id/card @@ -870,26 +878,28 @@ impl MastodonUnauthenticated for MastodonUnauth { let route = self.route("/api/v1/statuses")?; let route = route.join(id)?; let route = route.join("card")?; - let response = self.send(self.client.get(route))?; - deserialise(response) + let response = self.send_blocking(self.client.get(route))?; + deserialise_blocking(response) } } // Convert the HTTP response body from JSON. Pass up deserialization errors // transparently. -fn deserialise serde::Deserialize<'de>>(response: Response) -> Result { - let mut reader = Tap::new(response); +fn deserialise_blocking serde::Deserialize<'de>>(response: Response) -> Result { + let handle = tokio::runtime::Handle::current(); + + let bytes = handle.block_on(response.bytes())?; - match serde_json::from_reader(&mut reader) { + match serde_json::from_slice(&bytes) { Ok(t) => { - log::debug!("{}", String::from_utf8_lossy(&reader.bytes)); + log::debug!("{}", String::from_utf8_lossy(&bytes)); Ok(t) }, // If deserializing into the desired type fails try again to // see if this is an error response. Err(e) => { - log::error!("{}", String::from_utf8_lossy(&reader.bytes)); - if let Ok(error) = serde_json::from_slice(&reader.bytes) { + log::error!("{}", String::from_utf8_lossy(&bytes)); + if let Ok(error) = serde_json::from_slice(&bytes) { return Err(Error::Api(error)); } Err(e.into()) diff --git a/src/macros.rs b/src/macros.rs index b866339..ac46131 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -4,11 +4,11 @@ macro_rules! methods { fn $method serde::Deserialize<'de>>(&self, url: String) -> Result { - let response = self.send( + let response = self.send_blocking( self.client.$method(&url) )?; - deserialise(response) + deserialise_blocking(response) } )+ }; @@ -42,7 +42,7 @@ macro_rules! paged_routes { ), fn $name(&self) -> Result> { let url = self.route(concat!("/api/v1/", $url)); - let response = self.send( + let response = self.send_blocking( self.client.$method(&url) )?; @@ -88,7 +88,7 @@ macro_rules! paged_routes { let url = format!(concat!("/api/v1/", $url, "?{}"), &qs); - let response = self.send( + let response = self.send_blocking( self.client.get(&url) )?; @@ -199,7 +199,7 @@ macro_rules! route { )* }); - let response = self.send( + let response = self.send_blocking( self.client.$method(&self.route(concat!("/api/v1/", $url))) .json(&form_data) )?; @@ -212,7 +212,7 @@ macro_rules! route { return Err(Error::Server(status)); } - deserialise(response) + deserialise_blocking(response) } } @@ -317,7 +317,7 @@ macro_rules! paged_routes_with_id { ), fn $name(&self, id: &str) -> Result> { let url = self.route(&format!(concat!("/api/v1/", $url), id)); - let response = self.send( + let response = self.send_blocking( self.client.$method(&url) )?; diff --git a/src/mastodon_client.rs b/src/mastodon_client.rs index 3fc0739..d5e935b 100644 --- a/src/mastodon_client.rs +++ b/src/mastodon_client.rs @@ -18,6 +18,7 @@ use crate::{ /// Represents the set of methods that a Mastodon Client can do, so that /// implementations might be swapped out for testing #[allow(unused)] +#[async_trait::async_trait] pub trait MastodonClient { /// Type that wraps streaming API streams type Stream: Iterator; diff --git a/src/page.rs b/src/page.rs index a5efd56..c36b56a 100644 --- a/src/page.rs +++ b/src/page.rs @@ -1,7 +1,7 @@ -use super::{deserialise, Mastodon, Result}; +use super::{deserialise_blocking, Mastodon, Result}; use crate::entities::itemsiter::ItemsIter; use hyper_old_types::header::{parsing, Link, RelationType}; -use reqwest::{blocking::Response, header::LINK}; +use reqwest::{header::LINK, Response}; use serde::Deserialize; use url::Url; @@ -17,7 +17,7 @@ macro_rules! pages { None => return Ok(None), }; - let response = self.mastodon.send( + let response = self.mastodon.send_blocking( self.mastodon.client.get(url) )?; @@ -25,7 +25,7 @@ macro_rules! pages { self.next = next; self.prev = prev; - deserialise(response) + deserialise_blocking(response) }); )* } @@ -110,7 +110,7 @@ impl<'a, T: for<'de> Deserialize<'de>> Page<'a, T> { pub(crate) fn new(mastodon: &'a Mastodon, response: Response) -> Result { let (prev, next) = get_links(&response)?; Ok(Page { - initial_items: deserialise(response)?, + initial_items: deserialise_blocking(response)?, next, prev, mastodon, diff --git a/src/registration.rs b/src/registration.rs index 22228ad..dcc5306 100644 --- a/src/registration.rs +++ b/src/registration.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; -use reqwest::blocking::{Client, RequestBuilder, Response}; +use reqwest::{Client, RequestBuilder, Response}; use serde::Deserialize; use std::convert::TryInto; @@ -99,7 +99,8 @@ impl<'a> Registration<'a> { fn send(&self, req: RequestBuilder) -> Result { let req = req.build()?; - Ok(self.client.execute(req)?) + let handle = tokio::runtime::Handle::current(); + Ok(handle.block_on(self.client.execute(req))?) } /// Register the given application @@ -178,7 +179,8 @@ impl<'a> Registration<'a> { fn send_app(&self, app: &App) -> Result { let url = format!("{}/api/v1/apps", self.base); - Ok(self.send(self.client.post(&url).json(&app))?.json()?) + let handle = tokio::runtime::Handle::current(); + Ok(handle.block_on(self.send(self.client.post(&url).json(&app))?.json())?) } } @@ -234,7 +236,8 @@ impl Registered { impl Registered { fn send(&self, req: RequestBuilder) -> Result { let req = req.build()?; - Ok(self.client.execute(req)?) + let handle = tokio::runtime::Handle::current(); + Ok(handle.block_on(self.client.execute(req))?) } /// Returns the parts of the `Registered` struct that can be used to @@ -309,7 +312,8 @@ impl Registered { self.base, self.client_id, self.client_secret, code, self.redirect ); - let token: AccessToken = self.send(self.client.post(&url))?.json()?; + let handle = tokio::runtime::Handle::current(); + let token: AccessToken = handle.block_on(self.send(self.client.post(&url))?.json())?; let data = Data { base: self.base.clone().into(), -- cgit v1.2.3