diff options
author | Paul Woolcock <paul@woolcock.us> | 2020-10-07 05:47:39 -0400 |
---|---|---|
committer | Paul Woolcock <paul@woolcock.us> | 2020-10-07 09:06:13 -0400 |
commit | 02ca0a89515413ac9fb3b655de2f21f6a711e0f2 (patch) | |
tree | 004bcd9f88eca168e10e1ac85c5987fdd6769fcf /src/async | |
parent | 04b5b54212629f058bdab1ba55c89a3d417e0454 (diff) |
Add basic async client
This adds a module, accessible by compiling with `--features async`,
that provides an `elefren::async::Client`. The client is
runtime-agnostic, and currently only provides unauthenticated access,
see the docs for the full list of methods that can be performed* with
this client.
* note that some API calls are publicly available by default, but can be
changed via instance settings to not be publicly accessible
Diffstat (limited to 'src/async')
-rw-r--r-- | src/async/auth.rs | 46 | ||||
-rw-r--r-- | src/async/client.rs | 52 | ||||
-rw-r--r-- | src/async/mod.rs | 263 | ||||
-rw-r--r-- | src/async/page.rs | 95 |
4 files changed, 456 insertions, 0 deletions
diff --git a/src/async/auth.rs b/src/async/auth.rs new file mode 100644 index 0000000..6f593c4 --- /dev/null +++ b/src/async/auth.rs @@ -0,0 +1,46 @@ +//! Authentication mechanisms for async client +use async_mutex::Mutex; +use std::cell::RefCell; + +use crate::{ + entities::{account::Account, card::Card, context::Context, status::Status}, + errors::{Error, Result}, + requests::StatusesRequest, +}; +use http_types::{Method, Request, Response}; +use hyper_old_types::header::{parsing, Link, RelationType}; +use serde::Serialize; +use smol::{prelude::*, Async}; +use std::net::{TcpStream, ToSocketAddrs}; +use url::Url; + +/// strategies for authenticating mastodon requests need to implement this trait +#[async_trait::async_trait] +pub trait Authenticate { + async fn authenticate(&self, request: &mut Request) -> Result<()>; +} + +/// The null-strategy, will only allow the client to call public API endpoints +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Unauthenticated; +#[async_trait::async_trait] +impl Authenticate for Unauthenticated { + async fn authenticate(&self, _: &mut Request) -> Result<()> { + Ok(()) + } +} + +/// Authenticates to the server via oauth +#[derive(Debug, Clone, PartialEq)] +pub struct OAuth { + client_id: String, + client_secret: String, + redirect: String, + token: String, +} +#[async_trait::async_trait] +impl Authenticate for Mutex<RefCell<Option<OAuth>>> { + async fn authenticate(&self, _: &mut Request) -> Result<()> { + unimplemented!() + } +} diff --git a/src/async/client.rs b/src/async/client.rs new file mode 100644 index 0000000..5081c6f --- /dev/null +++ b/src/async/client.rs @@ -0,0 +1,52 @@ +use crate::{ + entities::{account::Account, card::Card, context::Context, status::Status}, + errors::{Error, Result}, +}; +use http_types::{Method, Request, Response}; +use hyper_old_types::header::{parsing, Link, RelationType}; +use smol::{prelude::*, Async}; +use std::net::{TcpStream, ToSocketAddrs}; +use url::Url; + +// taken pretty much verbatim from `smol`s example + +/// Sends a request and fetches the response. +pub(super) async fn fetch(req: Request) -> Result<Response> { + // Figure out the host and the port. + let host = req + .url() + .host() + .ok_or_else(|| String::from("No host found"))? + .to_string(); + let port = req + .url() + .port_or_known_default() + .ok_or_else(|| Error::Other(String::from("No port found")))?; + + // Connect to the host. + let socket_addr = { + let host = host.clone(); + smol::unblock(move || (host.as_str(), port).to_socket_addrs()) + .await? + .next() + .ok_or_else(|| Error::Other(String::from("No socket addr")))? + }; + let stream = Async::<TcpStream>::connect(socket_addr).await?; + + // Send the request and wait for the response. + let resp = match req.url().scheme() { + "http" => async_h1::connect(stream, req).await?, + "https" => { + // In case of HTTPS, establish a secure TLS connection first. + let stream = async_native_tls::connect(&host, stream).await?; + async_h1::connect(stream, req).await? + }, + scheme => return Err(Error::Other(format!("unsupported scheme '{}'", scheme))), + }; + Ok(resp) +} + +pub(super) async fn get(url: Url) -> Result<Response> { + let req = Request::new(Method::Get, url); + Ok(fetch(req).await?) +} diff --git a/src/async/mod.rs b/src/async/mod.rs new file mode 100644 index 0000000..d75e131 --- /dev/null +++ b/src/async/mod.rs @@ -0,0 +1,263 @@ +//! Async Mastodon Client +//! +//! # Example +//! +//! ```rust,no_run +//! use elefren::r#async::Client; +//! use url::Url; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # smol::block_on(async { +//! let client = Client::new("https://mastodon.social")?; +//! +//! // iterate page-by-page +//! // this API isn't ideal, but one day we'll get better +//! // syntax support for iterating over streams and we can +//! // do better +//! let mut pages = client.public_timeline(None).await?; +//! while let Some(statuses) = pages.next_page().await? { +//! for status in statuses { +//! println!("{:?}", status); +//! } +//! } +//! # Ok(()) +//! # }) +//! } +//! ``` +#![allow(warnings)] +#![allow(missing_docs)] +use crate::{ + entities::{ + account::Account, + activity::Activity, + card::Card, + context::Context, + instance::Instance, + poll::Poll, + status::{Emoji, Status, Tag}, + }, + errors::{Error, Result}, + requests::{DirectoryRequest, StatusesRequest}, +}; +use http_types::{Method, Request, Response}; +use std::fmt::Debug; +use url::Url; + +pub use auth::Authenticate; +use auth::{OAuth, Unauthenticated}; +pub use page::Page; + +mod auth; +mod client; +mod page; + +/// Async unauthenticated client +#[derive(Debug)] +pub struct Client<A: Debug + Authenticate> { + base_url: Url, + auth: A, +} +impl Client<Unauthenticated> { + pub fn new<S: AsRef<str>>(base_url: S) -> Result<Client<Unauthenticated>> { + let base_url = Url::parse(base_url.as_ref())?; + Ok(Client { + base_url, + auth: Unauthenticated, + }) + } +} +impl<A: Debug + Authenticate> Client<A> { + async fn send(&self, mut req: Request) -> Result<Response> { + self.auth.authenticate(&mut req).await?; + Ok(client::fetch(req).await?) + } + + /// GET /api/v1/timelines/public + pub async fn public_timeline<'a, 'client: 'a, I: Into<Option<StatusesRequest<'a>>>>( + &'client self, + opts: I, + ) -> Result<Page<'client, Status, A>> { + let mut url = self.base_url.join("api/v1/timelines/public")?; + if let Some(opts) = opts.into() { + let qs = opts.to_querystring()?; + url.set_query(Some(&qs[..])); + }; + Ok(Page::new(Request::new(Method::Get, url), &self.auth)) + } + + /// GET /api/v1/timelines/tag/:tag + pub async fn hashtag_timeline<'a, 'client: 'a, I: Into<Option<StatusesRequest<'a>>>>( + &'client self, + tag: &str, + opts: I, + ) -> Result<Page<'client, Status, A>> { + let mut url = self + .base_url + .join(&format!("api/v1/timelines/tag/{}", tag))?; + if let Some(opts) = opts.into() { + let qs = opts.to_querystring()?; + url.set_query(Some(&qs[..])); + } + Ok(Page::new(Request::new(Method::Get, url), &self.auth)) + } + + /// GET /api/v1/statuses/:id + pub async fn status(&self, id: &str) -> Result<Status> { + let url = self.base_url.join(&format!("api/v1/statuses/{}", id))?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/statuses/:id/context + pub async fn context(&self, id: &str) -> Result<Context> { + let url = self + .base_url + .join(&format!("api/v1/statuses/{}/context", id))?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/statuses/:id/card + pub async fn card(&self, id: &str) -> Result<Card> { + let url = self + .base_url + .join(&format!("api/v1/statuses/{}/card", id))?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/statuses/:id/reblogged_by + pub async fn reblogged_by<'client>( + &'client self, + id: &str, + ) -> Result<Page<'client, Account, A>> { + let url = self + .base_url + .join(&format!("api/v1/statuses/{}/reblogged_by", id))?; + Ok(Page::new(Request::new(Method::Get, url), &self.auth)) + } + + /// GET /api/v1/statuses/:id/favourited_by + pub async fn favourited_by<'client>( + &'client self, + id: &str, + ) -> Result<Page<'client, Account, A>> { + let url = self + .base_url + .join(&format!("api/v1/statuses/{}/favourited_by", id))?; + Ok(Page::new(Request::new(Method::Get, url), &self.auth)) + } + + /// GET /api/v1/accounts/:id + pub async fn account(&self, id: &str) -> Result<Account> { + let url = self.base_url.join(&format!("api/v1/accounts/{}", id))?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/accounts/:id/statuses + pub async fn account_statuses<'a, 'client: 'a, I: Into<Option<StatusesRequest<'a>>>>( + &'client self, + id: &str, + request: I, + ) -> Result<Page<'client, Status, A>> { + let mut url = self + .base_url + .join(&format!("api/v1/accounts/{}/statuses", id))?; + if let Some(request) = request.into() { + let qs = request.to_querystring()?; + url.set_query(Some(&qs[..])); + } + Ok(Page::new(Request::new(Method::Get, url), &self.auth)) + } + + /// GET /api/v1/polls/:id + pub async fn poll(&self, id: &str) -> Result<Poll> { + let url = self.base_url.join(&format!("api/v1/polls/{}", id))?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/instance + pub async fn instance(&self) -> Result<Instance> { + let url = self.base_url.join("api/v1/instance")?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/instance/peers + pub async fn peers(&self) -> Result<Vec<String>> { + let url = self.base_url.join("api/v1/instance/peers")?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/instance/activity + pub async fn activity(&self) -> Result<Option<Vec<Activity>>> { + let url = self.base_url.join("api/v1/instance/activity")?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/custom_emojis + pub async fn custom_emojis(&self) -> Result<Vec<Emoji>> { + let url = self.base_url.join("api/v1/custom_emojis")?; + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/directory + pub async fn directory<'a, I: Into<Option<DirectoryRequest<'a>>>>( + &self, + opts: I, + ) -> Result<Vec<Account>> { + let mut url = self.base_url.join("api/v1/directory")?; + if let Some(opts) = opts.into() { + let qs = opts.to_querystring()?; + url.set_query(Some(&qs[..])); + } + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } + + /// GET /api/v1/trends + pub async fn trends<I: Into<Option<usize>>>(&self, limit: I) -> Result<Vec<Tag>> { + let mut url = self.base_url.join("api/v1/trends")?; + if let Some(limit) = limit.into() { + url.set_query(Some(&format!("?limit={}", limit))); + } + let response = self.send(Request::new(Method::Get, url)).await?; + Ok(deserialize(response).await?) + } +} + +async fn deserialize<T: serde::de::DeserializeOwned>(mut response: Response) -> Result<T> { + let status = response.status(); + if status.is_client_error() { + // TODO + // return Err(Error::Client(status)); + return Err(Error::Other(String::from("4xx status code"))); + } else if status.is_server_error() { + // TODO + // return Err(Error::Server(status)) // TODO + return Err(Error::Other(String::from("5xx status code"))); + } else if status.is_redirection() || status.is_informational() { + return Err(Error::Other(String::from("3xx or 1xx status code"))); + } + let bytes = response.body_bytes().await?; + Ok(match serde_json::from_slice::<T>(&bytes) { + Ok(t) => { + log::debug!("{}", String::from_utf8_lossy(&bytes)); + t + }, + Err(e) => { + log::error!("{}", String::from_utf8_lossy(&bytes)); + let err = if let Ok(error) = serde_json::from_slice(&bytes) { + Error::Api(error) + } else { + e.into() + }; + return Err(err); + }, + }) +} diff --git a/src/async/page.rs b/src/async/page.rs new file mode 100644 index 0000000..4ef4161 --- /dev/null +++ b/src/async/page.rs @@ -0,0 +1,95 @@ +use super::{client, deserialize, Authenticate}; +use crate::{ + entities::{account::Account, card::Card, context::Context, status::Status}, + errors::{Error, Result}, +}; +use http_types::{Method, Request, Response}; +use hyper_old_types::header::{parsing, Link, RelationType}; +use smol::{prelude::*, Async}; +use std::{ + fmt::Debug, + net::{TcpStream, ToSocketAddrs}, +}; +use url::Url; + +// link header name +const LINK: &str = "link"; + +#[derive(Debug)] +pub struct Page<'client, T, A: Authenticate + Debug + 'client> { + next: Option<Request>, + prev: Option<Request>, + auth: &'client A, + _marker: std::marker::PhantomData<T>, +} +impl<'client, T: serde::de::DeserializeOwned, A: Authenticate + Debug + 'client> + Page<'client, T, A> +{ + pub fn new(next: Request, auth: &'client A) -> Page<'client, T, A> { + Page { + next: Some(next), + prev: None, + auth, + _marker: std::marker::PhantomData, + } + } + + pub async fn next_page(&mut self) -> Result<Option<Vec<T>>> { + let mut req = if let Some(next) = self.next.take() { + next + } else { + return Ok(None); + }; + Ok(self.send(req).await?) + } + + pub async fn prev_page(&mut self) -> Result<Option<Vec<T>>> { + let req = if let Some(prev) = self.prev.take() { + prev + } else { + return Ok(None); + }; + Ok(self.send(req).await?) + } + + async fn send(&mut self, mut req: Request) -> Result<Option<Vec<T>>> { + self.auth.authenticate(&mut req).await?; + log::trace!("Request: {:?}", req); + let response = client::fetch(req).await?; + log::trace!("Response: {:?}", response); + self.fill_links_from_resp(&response)?; + let items = deserialize(response).await?; + Ok(items) + } + + fn fill_links_from_resp(&mut self, response: &Response) -> Result<()> { + let (prev, next) = get_links(&response)?; + self.prev = prev.map(|url| Request::new(Method::Get, url)); + self.next = next.map(|url| Request::new(Method::Get, url)); + Ok(()) + } +} + +fn get_links(response: &Response) -> Result<(Option<Url>, Option<Url>)> { + let mut prev = None; + let mut next = None; + + if let Some(link_header) = response.header(LINK) { + let link_header = link_header.as_str(); + let link_header = link_header.as_bytes(); + let link_header: Link = parsing::from_raw_str(&link_header)?; + for value in link_header.values() { + if let Some(relations) = value.rel() { + if relations.contains(&RelationType::Next) { + next = Some(Url::parse(value.link())?); + } + + if relations.contains(&RelationType::Prev) { + prev = Some(Url::parse(value.link())?); + } + } + } + } + + Ok((prev, next)) +} |