diff --git a/Cargo.lock b/Cargo.lock index 5059838..e238591 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,7 @@ name = "hptp" version = "0.1.0" dependencies = [ "rand", + "thiserror", "tokio", ] diff --git a/hptp-recv/src/main.rs b/hptp-recv/src/main.rs index 7eba27b..6679126 100644 --- a/hptp-recv/src/main.rs +++ b/hptp-recv/src/main.rs @@ -6,7 +6,7 @@ extern crate thiserror; use hptp::log::Logger; use hptp::msg::Msg; -use hptp::peer::Peer; +use hptp::peer::{self, Peer}; use tokio::io::AsyncWrite; #[derive(Error, Debug)] @@ -47,6 +47,14 @@ async fn start(log: &mut Logger) -> Result<(), Error> { download(log, &mut peer, &mut out).await } +async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { + match peer.send(m).await { + Ok(()) => Ok(()), + Err(peer::SendError::Io { source }) => Err(source.into()), + Err(peer::SendError::NoTarget) => unreachable!(), + } +} + async fn download(log: &mut Logger, peer: &mut Peer, out: &mut OUT) -> Result<(), Error> where OUT: AsyncWrite + Unpin, @@ -54,14 +62,24 @@ where use tokio::io::AsyncWriteExt; let mut pos = 0; loop { - match peer.recv().await? { - Msg::Data(data) => { + match peer.recv().await { + Ok(Msg::Ack) => log.debug_msg("not expected an ack...").await, + + Ok(Msg::Data(data)) => { let len = data.len(); out.write_all(&data).await?; out.flush().await?; log.recv_data_accepted(pos, len, hptp::log::InOrder).await; + send(peer, Msg::Ack).await?; pos += len; } + + Err(peer::RecvError::InvalidMessage { .. }) => { + log.debug_msg(format!("got an invalid message; discarding")) + .await + } + + Err(peer::RecvError::Io { source }) => return Err(source.into()), } } } diff --git a/hptp-send/src/main.rs b/hptp-send/src/main.rs index 11e5eb5..ce5e959 100644 --- a/hptp-send/src/main.rs +++ b/hptp-send/src/main.rs @@ -6,7 +6,7 @@ extern crate thiserror; use hptp::log::Logger; use hptp::msg::{self, Msg}; -use hptp::peer::Peer; +use hptp::peer::{self, Peer}; use std::net::SocketAddr; use tokio::io::AsyncRead; // use tokio::net::UdpSocket; @@ -83,6 +83,25 @@ where }) } +async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { + match peer.send(m).await { + Ok(()) => Ok(()), + Err(peer::SendError::Io { source }) => Err(source.into()), + Err(peer::SendError::NoTarget) => panic!("tried to send w/ no target!"), + } +} + +async fn recv(log: &mut Logger, peer: &mut Peer) -> Result, Error> { + match peer.recv().await { + Ok(m) => Ok(Some(m)), + Err(peer::RecvError::Io { source }) => Err(source.into()), + Err(peer::RecvError::InvalidMessage { .. }) => { + log.debug_msg("invalid message; discarding").await; + Ok(None) + } + } +} + async fn upload(log: &mut Logger, peer: &mut Peer, inp: &mut IN) -> Result<(), Error> where IN: AsyncRead + Unpin, @@ -92,8 +111,13 @@ where match read_data(inp).await? { Some(data) => { let len = data.len(); - peer.send(Msg::Data(data)).await?; + send(peer, Msg::Data(data)).await?; log.send_data(pos, len).await; + if let Some(Msg::Ack) = recv(log, peer).await? { + log.debug_msg("got ack").await; + } else { + log.debug_msg("didn't get ack??").await; + } pos += len; } None => return Ok(()), diff --git a/hptp/Cargo.toml b/hptp/Cargo.toml index 4f06334..c9b2d59 100644 --- a/hptp/Cargo.toml +++ b/hptp/Cargo.toml @@ -9,4 +9,5 @@ edition = "2018" [dependencies] tokio = {version = "0.2.*", features = ["io-std", "io-util", "udp"]} -rand = "0.7.*" \ No newline at end of file +rand = "0.7.*" +thiserror = "*" \ No newline at end of file diff --git a/hptp/src/lib.rs b/hptp/src/lib.rs index 0851005..d3d937d 100644 --- a/hptp/src/lib.rs +++ b/hptp/src/lib.rs @@ -1,4 +1,6 @@ extern crate rand; +#[macro_use] +extern crate thiserror; pub mod log; pub mod msg; diff --git a/hptp/src/msg.rs b/hptp/src/msg.rs index e6839cc..fadf36e 100644 --- a/hptp/src/msg.rs +++ b/hptp/src/msg.rs @@ -1,19 +1,34 @@ #[derive(Clone)] pub enum Msg { Data(Vec), + Ack, } pub const MAX_DATA_SIZE: usize = 999; -pub const MAX_SERIALIZED_SIZE: usize = MAX_DATA_SIZE; +pub const MAX_SERIALIZED_SIZE: usize = 1 + MAX_DATA_SIZE; + +#[derive(Error, Debug)] +#[error("deserialization failed; malformed packet")] +pub struct DesError; impl Msg { - pub fn des(data: &[u8]) -> Msg { - Msg::Data(data.into()) + pub fn des(data: &[u8]) -> Result { + match data.first() { + Some(0) => Ok(Msg::Data(data[1..].into())), + Some(1) => Ok(Msg::Ack), + _ => Err(DesError), + } } pub fn ser(&self) -> Vec { + let mut buf = Vec::new(); match self { - Msg::Data(data) => data.clone(), + Msg::Data(data) => { + buf.push(0); + buf.extend_from_slice(data); + } + Msg::Ack => buf.push(1), } + buf } } diff --git a/hptp/src/peer.rs b/hptp/src/peer.rs index 082d16b..70cd96f 100644 --- a/hptp/src/peer.rs +++ b/hptp/src/peer.rs @@ -8,6 +8,31 @@ pub struct Peer { targ: Option, } +#[derive(Error, Debug)] +pub enum RecvError { + #[error("io error: {source}")] + Io { + #[from] + source: tokio::io::Error, + }, + #[error("{source}")] + InvalidMessage { + #[from] + source: msg::DesError, + }, +} + +#[derive(Error, Debug)] +pub enum SendError { + #[error("io error: {source}")] + Io { + #[from] + source: tokio::io::Error, + }, + #[error("no target to send to")] + NoTarget, +} + impl Peer { pub fn new(sock: UdpSocket) -> Self { Peer { sock, targ: None } @@ -17,17 +42,17 @@ impl Peer { self.targ = Some(addr); } - pub async fn send(&mut self, msg: Msg) -> Result<(), tokio::io::Error> { - let targ = self.targ.expect("no target to send to"); + pub async fn send(&mut self, msg: Msg) -> Result<(), SendError> { + let targ = self.targ.ok_or(SendError::NoTarget)?; let bs = msg.ser(); let _n_sent = self.sock.send_to(&bs, targ).await?; Ok(()) } - pub async fn recv(&mut self) -> Result { + pub async fn recv(&mut self) -> Result { let mut buf = [0u8; msg::MAX_SERIALIZED_SIZE]; let (len, who) = self.sock.recv_from(&mut buf).await?; self.set_known_target(who); - Ok(Msg::des(&buf[..len])) + Ok(Msg::des(&buf[..len])?) } }