diff options
author | Carl Lerche <me@carllerche.com> | 2019-10-22 10:13:49 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-10-22 10:13:49 -0700 |
commit | cfc15617a5247ea780c32c85b7134b88b6de5845 (patch) | |
tree | ef0a46c61c51505a60f386c9760acac9d1f9b7b1 /examples/chat.rs | |
parent | b8cee1a60ad99ef28ec494ae4230e2ef4399fcf9 (diff) |
codec: move into tokio-util (#1675)
Related to #1318, Tokio APIs that are "less stable" are moved into a new
`tokio-util` crate. This crate will mirror `tokio` and provide
additional APIs that may require a greater rate of breaking changes.
As examples require `tokio-util`, they are moved into a separate
crate (`examples`). This has the added advantage of being able to avoid
example only dependencies in the `tokio` crate.
Diffstat (limited to 'examples/chat.rs')
-rw-r--r-- | examples/chat.rs | 261 |
1 files changed, 261 insertions, 0 deletions
diff --git a/examples/chat.rs b/examples/chat.rs new file mode 100644 index 00000000..0a3976d5 --- /dev/null +++ b/examples/chat.rs @@ -0,0 +1,261 @@ +//! A chat server that broadcasts a message to all connections. +//! +//! This example is explicitly more verbose than it has to be. This is to +//! illustrate more concepts. +//! +//! A chat server for telnet clients. After a telnet client connects, the first +//! line should contain the client's name. After that, all lines sent by a +//! client are broadcasted to all other connected clients. +//! +//! Because the client is telnet, lines are delimited by "\r\n". +//! +//! You can test this out by running: +//! +//! cargo run --example chat +//! +//! And then in another terminal run: +//! +//! telnet localhost 6142 +//! +//! You can run the `telnet` command in any number of additional windows. +//! +//! You can run the second command in multiple windows and then chat between the +//! two, seeing the messages from the other client as they're received. For all +//! connected clients they'll all join the same room and see everyone else's +//! messages. + +#![warn(rust_2018_idioms)] + +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{mpsc, Mutex}; +use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; + +use futures::{Poll, SinkExt, Stream, StreamExt}; +use std::collections::HashMap; +use std::env; +use std::error::Error; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn Error>> { + // Create the shared state. This is how all the peers communicate. + // + // The server task will hold a handle to this. For every new client, the + // `state` handle is cloned and passed into the task that processes the + // client connection. + let state = Arc::new(Mutex::new(Shared::new())); + + let addr = env::args().nth(1).unwrap_or("127.0.0.1:6142".to_string()); + + // Bind a TCP listener to the socket address. + // + // Note that this is the Tokio TcpListener, which is fully async. + let mut listener = TcpListener::bind(&addr).await?; + + println!("server running on {}", addr); + + loop { + // Asynchronously wait for an inbound TcpStream. + let (stream, addr) = listener.accept().await?; + + // Clone a handle to the `Shared` state for the new connection. + let state = Arc::clone(&state); + + // Spawn our handler to be run asynchronously. + tokio::spawn(async move { + if let Err(e) = process(state, stream, addr).await { + println!("an error occured; error = {:?}", e); + } + }); + } +} + +/// Shorthand for the transmit half of the message channel. +type Tx = mpsc::UnboundedSender<String>; + +/// Shorthand for the receive half of the message channel. +type Rx = mpsc::UnboundedReceiver<String>; + +/// Data that is shared between all peers in the chat server. +/// +/// This is the set of `Tx` handles for all connected clients. Whenever a +/// message is received from a client, it is broadcasted to all peers by +/// iterating over the `peers` entries and sending a copy of the message on each +/// `Tx`. +struct Shared { + peers: HashMap<SocketAddr, Tx>, +} + +/// The state for each connected client. +struct Peer { + /// The TCP socket wrapped with the `Lines` codec, defined below. + /// + /// This handles sending and receiving data on the socket. When using + /// `Lines`, we can work at the line level instead of having to manage the + /// raw byte operations. + lines: Framed<TcpStream, LinesCodec>, + + /// Receive half of the message channel. + /// + /// This is used to receive messages from peers. When a message is received + /// off of this `Rx`, it will be written to the socket. + rx: Rx, +} + +impl Shared { + /// Create a new, empty, instance of `Shared`. + fn new() -> Self { + Shared { + peers: HashMap::new(), + } + } + + /// Send a `LineCodec` encoded message to every peer, except + /// for the sender. + async fn broadcast( + &mut self, + sender: SocketAddr, + message: &str, + ) -> Result<(), mpsc::error::UnboundedSendError> { + for peer in self.peers.iter_mut() { + if *peer.0 != sender { + peer.1.send(message.into()).await?; + } + } + + Ok(()) + } +} + +impl Peer { + /// Create a new instance of `Peer`. + async fn new( + state: Arc<Mutex<Shared>>, + lines: Framed<TcpStream, LinesCodec>, + ) -> io::Result<Peer> { + // Get the client socket address + let addr = lines.get_ref().peer_addr()?; + + // Create a channel for this peer + let (tx, rx) = mpsc::unbounded_channel(); + + // Add an entry for this `Peer` in the shared state map. + state.lock().await.peers.insert(addr, tx); + + Ok(Peer { lines, rx }) + } +} + +#[derive(Debug)] +enum Message { + /// A message that should be broadcasted to others. + Broadcast(String), + + /// A message that should be received by a client + Received(String), +} + +// Peer implements `Stream` in a way that polls both the `Rx`, and `Framed` types. +// A message is produced whenever an event is ready until the `Framed` stream returns `None`. +impl Stream for Peer { + type Item = Result<Message, LinesCodecError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + // First poll the `UnboundedReceiver`. + + if let Poll::Ready(Some(v)) = self.rx.poll_next_unpin(cx) { + return Poll::Ready(Some(Ok(Message::Received(v)))); + } + + // Secondly poll the `Framed` stream. + let result: Option<_> = futures::ready!(self.lines.poll_next_unpin(cx)); + + Poll::Ready(match result { + // We've received a message we should broadcast to others. + Some(Ok(message)) => Some(Ok(Message::Broadcast(message))), + + // An error occured. + Some(Err(e)) => Some(Err(e)), + + // The stream has been exhausted. + None => None, + }) + } +} + +/// Process an individual chat client +async fn process( + state: Arc<Mutex<Shared>>, + stream: TcpStream, + addr: SocketAddr, +) -> Result<(), Box<dyn Error>> { + let mut lines = Framed::new(stream, LinesCodec::new()); + + // Send a prompt to the client to enter their username. + lines + .send(String::from("Please enter your username:")) + .await?; + + // Read the first line from the `LineCodec` stream to get the username. + let username = match lines.next().await { + Some(Ok(line)) => line, + // We didn't get a line so we return early here. + _ => { + println!("Failed to get username from {}. Client disconnected.", addr); + return Ok(()); + } + }; + + // Register our peer with state which internally sets up some channels. + let mut peer = Peer::new(state.clone(), lines).await?; + + // A client has connected, let's let everyone know. + { + let mut state = state.lock().await; + let msg = format!("{} has joined the chat", username); + println!("{}", msg); + state.broadcast(addr, &msg).await?; + } + + // Process incoming messages until our stream is exhausted by a disconnect. + while let Some(result) = peer.next().await { + match result { + // A message was received from the current user, we should + // broadcast this message to the other users. + Ok(Message::Broadcast(msg)) => { + let mut state = state.lock().await; + let msg = format!("{}: {}", username, msg); + + state.broadcast(addr, &msg).await?; + } + // A message was received from a peer. Send it to the + // current user. + Ok(Message::Received(msg)) => { + peer.lines.send(msg).await?; + } + Err(e) => { + println!( + "an error occured while processing messages for {}; error = {:?}", + username, e + ); + } + } + } + + // If this section is reached it means that the client was disconnected! + // Let's let everyone still connected know about it. + { + let mut state = state.lock().await; + state.peers.remove(&addr); + + let msg = format!("{} has left the chat", username); + println!("{}", msg); + state.broadcast(addr, &msg).await?; + } + + Ok(()) +} |