From 6652ddb3d132a3f5e5ae7efc0dca1f37bd9fdb30 Mon Sep 17 00:00:00 2001 From: Milo Turner Date: Mon, 9 Mar 2020 14:11:12 -0400 Subject: [PATCH] separate types for upload/download messages --- hptp-recv/src/main.rs | 15 ++++++-------- hptp-send/src/main.rs | 13 ++++++------ hptp/src/msg.rs | 47 +++++++++++++++++++++++++++---------------- hptp/src/peer.rs | 33 +++++++++++++++++++++--------- 4 files changed, 66 insertions(+), 42 deletions(-) diff --git a/hptp-recv/src/main.rs b/hptp-recv/src/main.rs index b1b0ee6..d9b0184 100644 --- a/hptp-recv/src/main.rs +++ b/hptp-recv/src/main.rs @@ -5,8 +5,8 @@ extern crate tokio; extern crate thiserror; use hptp::logger::Logger; -use hptp::msg::Msg; -use hptp::peer::{self, Peer}; +use hptp::msg::{DownMsg, UpMsg}; +use hptp::peer::{self, DownPeer, Peer}; use tokio::io::AsyncWrite; #[derive(Error, Debug)] @@ -47,7 +47,7 @@ async fn start(log: &mut Logger) -> Result<(), Error> { download(log, &mut peer, &mut out).await } -async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { +async fn send(peer: &mut DownPeer, m: DownMsg) -> Result<(), Error> { match peer.send(m).await { Ok(()) => Ok(()), Err(peer::SendError::Io { source }) => Err(source.into()), @@ -55,7 +55,7 @@ async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { } } -async fn download(log: &mut Logger, peer: &mut Peer, out: &mut OUT) -> Result<(), Error> +async fn download(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error> where OUT: AsyncWrite + Unpin, { @@ -63,20 +63,17 @@ where let mut pos = 0; loop { match peer.recv().await { - Ok(Msg::Ack) => log.debug_msg("not expecting an ack...").await, - - Ok(Msg::Data(data)) => { + Ok(UpMsg::Data(data)) => { let len = data.len(); out.write_all(&data).await?; out.flush().await?; log.recv_data_accepted(pos, len, hptp::logger::InOrder) .await; - send(peer, Msg::Ack).await?; + send(peer, DownMsg::Ack).await?; pos += len; } Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await, - Err(peer::RecvError::Io { source }) => return Err(source.into()), } } diff --git a/hptp-send/src/main.rs b/hptp-send/src/main.rs index e9ffbc8..9f8b9d6 100644 --- a/hptp-send/src/main.rs +++ b/hptp-send/src/main.rs @@ -5,8 +5,8 @@ extern crate tokio; extern crate thiserror; use hptp::logger::Logger; -use hptp::msg::{self, Msg}; -use hptp::peer::{self, Peer}; +use hptp::msg::{self, DownMsg, UpMsg}; +use hptp::peer::{self, Peer, UpPeer}; use std::net::SocketAddr; use tokio::io::AsyncRead; // use tokio::net::UdpSocket; @@ -83,7 +83,7 @@ where }) } -async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { +async fn send(peer: &mut UpPeer, m: UpMsg) -> Result<(), Error> { match peer.send(m).await { Ok(()) => Ok(()), Err(peer::SendError::Io { source }) => Err(source.into()), @@ -91,7 +91,7 @@ async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { } } -async fn upload(log: &mut Logger, peer: &mut Peer, inp: &mut IN) -> Result<(), Error> +async fn upload(log: &mut Logger, peer: &mut UpPeer, inp: &mut IN) -> Result<(), Error> where IN: AsyncRead + Unpin, { @@ -100,12 +100,11 @@ where loop { if let Some(ack_len) = next_ack_len { match peer.recv().await { - Ok(Msg::Ack) => { + Ok(DownMsg::Ack) => { log.recv_ack(pos).await; pos += ack_len; next_ack_len = None; } - Ok(_) => log.debug_msg("got some other packet??").await, Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await, Err(peer::RecvError::Io { source }) => return Err(source.into()), @@ -115,7 +114,7 @@ where Some(data) => { next_ack_len = Some(data.len()); log.send_data(pos, data.len()).await; - send(peer, Msg::Data(data)).await?; + send(peer, UpMsg::Data(data)).await?; } None => break, } diff --git a/hptp/src/msg.rs b/hptp/src/msg.rs index fadf36e..f6f2951 100644 --- a/hptp/src/msg.rs +++ b/hptp/src/msg.rs @@ -1,6 +1,10 @@ #[derive(Clone)] -pub enum Msg { +pub enum UpMsg { Data(Vec), +} + +#[derive(Clone)] +pub enum DownMsg { Ack, } @@ -11,24 +15,33 @@ pub const MAX_SERIALIZED_SIZE: usize = 1 + MAX_DATA_SIZE; #[error("deserialization failed; malformed packet")] pub struct DesError; -impl Msg { - pub fn des(data: &[u8]) -> Result { - match data.first() { - Some(0) => Ok(Msg::Data(data[1..].into())), - Some(1) => Ok(Msg::Ack), - _ => Err(DesError), +pub trait SerDes: Sized { + fn des(data: &[u8]) -> Result; + fn ser_into(self) -> Vec; +} + +impl SerDes for UpMsg { + fn des(data: &[u8]) -> Result { + Ok(UpMsg::Data(data.into())) + } + + fn ser_into(self) -> Vec { + match self { + UpMsg::Data(data) => data, + } + } +} + +impl SerDes for DownMsg { + fn des(data: &[u8]) -> Result { + if data == [0] { + Ok(DownMsg::Ack) + } else { + Err(DesError) } } - pub fn ser(&self) -> Vec { - let mut buf = Vec::new(); - match self { - Msg::Data(data) => { - buf.push(0); - buf.extend_from_slice(data); - } - Msg::Ack => buf.push(1), - } - buf + fn ser_into(self) -> Vec { + vec![0] } } diff --git a/hptp/src/peer.rs b/hptp/src/peer.rs index 70cd96f..47a5353 100644 --- a/hptp/src/peer.rs +++ b/hptp/src/peer.rs @@ -1,13 +1,18 @@ +use std::marker::PhantomData; use std::net::SocketAddr; use tokio::net::UdpSocket; -use super::msg::{self, Msg}; +use super::msg::{self, SerDes}; -pub struct Peer { +pub struct Peer { sock: UdpSocket, targ: Option, + _phantom: PhantomData T>, } +pub type UpPeer = Peer; +pub type DownPeer = Peer; + #[derive(Error, Debug)] pub enum RecvError { #[error("io error: {source}")] @@ -33,26 +38,36 @@ pub enum SendError { NoTarget, } -impl Peer { +impl Peer { pub fn new(sock: UdpSocket) -> Self { - Peer { sock, targ: None } + Peer { + sock, + targ: None, + _phantom: PhantomData, + } } pub fn set_known_target(&mut self, addr: SocketAddr) { self.targ = Some(addr); } - pub async fn send(&mut self, msg: Msg) -> Result<(), SendError> { + pub async fn send(&mut self, msg: F) -> Result<(), SendError> + where + F: SerDes, + { let targ = self.targ.ok_or(SendError::NoTarget)?; - let bs = msg.ser(); - let _n_sent = self.sock.send_to(&bs, targ).await?; + let data = msg.ser_into(); + let _n_sent = self.sock.send_to(&data, targ).await?; Ok(()) } - pub async fn recv(&mut self) -> Result { + pub async fn recv(&mut self) -> Result + where + T: SerDes, + { let mut buf = [0u8; msg::MAX_SERIALIZED_SIZE]; let (len, who) = self.sock.recv_from(&mut buf).await?; self.set_known_target(who); - Ok(Msg::des(&buf[..len])?) + Ok(T::des(&buf[..len])?) } }