diff --git a/hptp-recv/src/main.rs b/hptp-recv/src/main.rs index b29a5a0..d45dbda 100644 --- a/hptp-recv/src/main.rs +++ b/hptp-recv/src/main.rs @@ -7,6 +7,7 @@ extern crate thiserror; use hptp::logger::Logger; use hptp::msg::{DownMsg, UpMsg}; use hptp::peer::{self, DownPeer, Peer}; +use hptp::seg::SegIdx; use std::collections::HashMap; use tokio::io::AsyncWrite; @@ -44,11 +45,55 @@ async fn start(log: &mut Logger) -> Result<(), Error> { download(log, &mut peer, &mut out).await } -async fn send(peer: &mut DownPeer, m: DownMsg) -> Result<(), Error> { - match peer.send(m).await { - Ok(()) => Ok(()), - Err(peer::SendError::Io { source }) => Err(source.into()), - Err(peer::SendError::NoTarget) => unreachable!(), +struct SegmentSink<'o, OUT> { + out: &'o mut OUT, + segs: HashMap>>, + n_flushed: u32, +} + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +enum Put { + Duplicate, + Fresh, +} + +impl<'o, OUT> SegmentSink<'o, OUT> +where + OUT: AsyncWrite + Unpin, +{ + fn new(out: &'o mut OUT) -> Self { + SegmentSink { + out, + segs: HashMap::new(), + n_flushed: 0, + } + } + + async fn flush(&mut self) { + use tokio::io::AsyncWriteExt; + while let Some(cache) = self.segs.get_mut(&self.n_flushed) { + if let Some(data) = cache.take() { + self.out.write_all(&data).await.expect("god help us"); + self.out.flush().await.expect("god help us"); + } + self.n_flushed += 1; + } + } + + async fn put(&mut self, seg_idx: SegIdx, data: Vec) -> Put { + if seg_idx < self.n_flushed || self.segs.contains_key(&seg_idx) { + Put::Duplicate + } else { + self.segs.insert(seg_idx, Some(data)); + self.flush().await; + Put::Fresh + } + } + + fn ack_msg(&self) -> DownMsg { + DownMsg::Ack { + idxs: self.segs.keys().cloned().collect(), + } } } @@ -56,45 +101,50 @@ async fn download(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> where OUT: AsyncWrite + Unpin, { - let mut segs = HashMap::new(); - let mut flush_seg_idx = 0; + let mut sink = SegmentSink::new(out); + let mut to_send = vec![]; + loop { - match peer.recv().await { - Ok(UpMsg::Data { payload, seg_idx }) => { - if segs.contains_key(&seg_idx) { - log.recv_data_ignored(seg_idx as usize, payload.len()).await; - } else { - log.recv_data_accepted( - seg_idx as usize, - payload.len(), - hptp::logger::OutOfOrder, - ) - .await; - segs.insert(seg_idx, Some(payload)); - let ack = DownMsg::Ack { - idxs: segs.keys().cloned().collect(), - }; - log.debug_msg(format!("sent ack: {:?}", { - let mut idxs = segs.keys().collect::>(); - idxs.sort_unstable(); - idxs - })) - .await; - send(peer, ack).await?; + let msg = match peer.recv().await { + Ok(m) => m, + Err(peer::RecvError::InvalidMessage { .. }) => { + log.recv_corrupt().await; + continue; + } + Err(peer::RecvError::Io { source }) => { + return Err(source.into()); + } + }; + + match msg { + UpMsg::Data { payload, seg_idx } => { + let len = payload.len(); + match sink.put(seg_idx, payload).await { + Put::Duplicate => { + log.recv_data_ignored(seg_idx as usize, len).await; + } + + Put::Fresh => { + log.recv_data_accepted(seg_idx as usize, len, hptp::logger::OutOfOrder) + .await; + log.debug_msg(format!("sending acks: {:?}", { + let mut idxs = sink.segs.keys().cloned().collect::>(); + idxs.sort(); + idxs + })) + .await; + to_send.push(sink.ack_msg()); + } } } - - Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await, - Err(peer::RecvError::Io { source }) => return Err(source.into()), } - while let Some(v) = segs.get_mut(&flush_seg_idx) { - if let Some(payload) = v.take() { - use tokio::io::AsyncWriteExt; - out.write_all(&payload).await?; - out.flush().await?; + for m in to_send.drain(..) { + match peer.send(m).await { + Ok(()) => (), + Err(hptp::peer::SendError::NoTarget) => log.debug_msg("no target").await, + Err(hptp::peer::SendError::Io { source }) => return Err(source.into()), } - flush_seg_idx += 1; } } }