diff --git a/hptp-recv/src/main.rs b/hptp-recv/src/main.rs index d45dbda..28b4f35 100644 --- a/hptp-recv/src/main.rs +++ b/hptp-recv/src/main.rs @@ -7,7 +7,7 @@ extern crate thiserror; use hptp::logger::Logger; use hptp::msg::{DownMsg, UpMsg}; use hptp::peer::{self, DownPeer, Peer}; -use hptp::seg::SegIdx; +use hptp::seg::{SegData, SegIdx}; use std::collections::HashMap; use tokio::io::AsyncWrite; @@ -47,8 +47,9 @@ async fn start(log: &mut Logger) -> Result<(), Error> { struct SegmentSink<'o, OUT> { out: &'o mut OUT, - segs: HashMap>>, + segs: HashMap>, n_flushed: u32, + complete: bool, } #[derive(Copy, Clone, Eq, PartialEq, Debug)] @@ -66,6 +67,7 @@ where out, segs: HashMap::new(), n_flushed: 0, + complete: false, } } @@ -73,14 +75,17 @@ where use tokio::io::AsyncWriteExt; while let Some(cache) = self.segs.get_mut(&self.n_flushed) { if let Some(data) = cache.take() { - self.out.write_all(&data).await.expect("god help us"); - self.out.flush().await.expect("god help us"); + data.write(self.out).await.expect("god help us"); + if data.is_last_segment { + self.complete = true; + } } self.n_flushed += 1; } + self.out.flush().await.expect("god help us"); } - async fn put(&mut self, seg_idx: SegIdx, data: Vec) -> Put { + async fn put(&mut self, seg_idx: SegIdx, data: SegData) -> Put { if seg_idx < self.n_flushed || self.segs.contains_key(&seg_idx) { Put::Duplicate } else { @@ -95,6 +100,10 @@ where idxs: self.segs.keys().cloned().collect(), } } + + fn is_file_complete(&self) -> bool { + self.complete + } } async fn download(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error> @@ -104,6 +113,11 @@ where let mut sink = SegmentSink::new(out); let mut to_send = vec![]; + enum Action { + Continue, + Quit, + } + loop { let msg = match peer.recv().await { Ok(m) => m, @@ -116,6 +130,7 @@ where } }; + let act; match msg { UpMsg::Data { payload, seg_idx } => { let len = payload.len(); @@ -136,6 +151,11 @@ where to_send.push(sink.ack_msg()); } } + act = if sink.is_file_complete() { + Action::Quit + } else { + Action::Continue + } } } @@ -146,5 +166,13 @@ where Err(hptp::peer::SendError::Io { source }) => return Err(source.into()), } } + + match act { + Action::Continue => (), + Action::Quit => break, + } } + + log.completed().await; + Ok(()) } diff --git a/hptp-send/src/main.rs b/hptp-send/src/main.rs index 28f9170..f063028 100644 --- a/hptp-send/src/main.rs +++ b/hptp-send/src/main.rs @@ -7,7 +7,7 @@ extern crate thiserror; use hptp::logger::Logger; use hptp::msg::{DownMsg, UpMsg}; use hptp::peer::{self, Peer, UpPeer}; -use hptp::seg::SegIdx; +use hptp::seg::{SegData, SegIdx}; use std::collections::HashMap; use std::net::SocketAddr; use tokio::io::AsyncRead; @@ -71,10 +71,11 @@ struct SegmentSource<'i, IN> { inp: &'i mut IN, unacked_segs: HashMap, unacked_upper_bound: u32, + eof: bool, } struct UnAcked { - payload: Vec, + payload: SegData, nacks: usize, } @@ -87,35 +88,32 @@ where inp, unacked_segs: HashMap::new(), unacked_upper_bound: 0, + eof: false, } } - async fn read_segment(&mut self) -> Option<(SegIdx, &[u8])> { - use tokio::io::AsyncReadExt; - let mut buf = [0u8; hptp::seg::MAX_SEG_SIZE]; - let len = self.inp.read(&mut buf).await.unwrap_or(0); - if len > 0 { + async fn read_segment(&mut self) -> Option<(SegIdx, &SegData)> { + if self.eof { + None + } else { let seg_idx = self.unacked_upper_bound; + let payload = SegData::read(self.inp).await.expect("god help us"); + self.eof = payload.is_last_segment; self.unacked_upper_bound += 1; let unack = self.unacked_segs.entry(seg_idx); - let unack = unack.or_insert(UnAcked { - payload: buf[..len].into(), - nacks: 0, - }); - Some((seg_idx, unack.payload.as_ref())) - } else { - None + let unack = unack.or_insert(UnAcked { payload, nacks: 0 }); + Some((seg_idx, &unack.payload)) } } - async fn get_segment(&mut self) -> Option<(SegIdx, &[u8])> { + async fn get_segment(&mut self) -> Option<(SegIdx, &SegData)> { if self.unacked_segs.is_empty() { self.read_segment().await } else { self.unacked_segs .iter() .next() // get any entry, no ordering guarunteed or needed - .map(|(seg_idx, unack)| (*seg_idx, unack.payload.as_ref())) + .map(|(seg_idx, unack)| (*seg_idx, &unack.payload)) } } @@ -179,7 +177,7 @@ where log.send_data(seg_idx as usize, payload.len()).await; to_send.push(UpMsg::Data { seg_idx, - payload: Vec::from(payload), + payload: payload.clone(), }); Action::Continue } diff --git a/hptp/src/msg.rs b/hptp/src/msg.rs index 4eab760..9f7dc4e 100644 --- a/hptp/src/msg.rs +++ b/hptp/src/msg.rs @@ -1,17 +1,17 @@ -use super::seg::SegIdx; +use super::seg::{SegData, SegIdx}; pub use super::seg::{DOWN_HEADER_SIZE, MAX_TOTAL_PACKET_SIZE, UP_HEADER_SIZE}; use byteorder::ByteOrder; -#[derive(Clone, Debug)] +#[derive(Clone)] pub enum UpMsg { Data { - payload: Vec, + payload: SegData, seg_idx: SegIdx, // is_final_packet: bool, }, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub enum DownMsg { /// `idxs` must be distinct and in increasing order. Ack { idxs: Vec }, @@ -30,24 +30,38 @@ pub trait SerDes: Sized { type BO = byteorder::LE; +const LAST_SEG_MASK: u32 = 1 << 31; + impl SerDes for UpMsg { fn des(buf: &[u8]) -> Result { if buf.len() < UP_HEADER_SIZE { Err(DesError) } else { + let hdr = BO::read_u32(&buf[0..4]); Ok(UpMsg::Data { - seg_idx: BO::read_u32(&buf[0..4]), - payload: buf[4..].into(), + seg_idx: hdr & !LAST_SEG_MASK, + payload: SegData { + bytes: buf[4..].into(), + is_last_segment: (hdr & LAST_SEG_MASK) != 0, + }, }) } } fn ser_to(&self, buf: &mut [u8]) -> usize { match self { - UpMsg::Data { payload, seg_idx } => { - let len = payload.len(); - BO::write_u32(&mut buf[0..4], *seg_idx); - buf[4..4 + len].copy_from_slice(&payload[..]); + UpMsg::Data { + payload: + SegData { + bytes, + is_last_segment, + }, + seg_idx, + } => { + let hdr = *seg_idx | if *is_last_segment { LAST_SEG_MASK } else { 0 }; + BO::write_u32(&mut buf[0..4], hdr); + let len = bytes.len(); + buf[4..4 + len].copy_from_slice(&bytes[..]); 4 + len } } diff --git a/hptp/src/seg.rs b/hptp/src/seg.rs index 040a4ed..b77e4f3 100644 --- a/hptp/src/seg.rs +++ b/hptp/src/seg.rs @@ -1,3 +1,5 @@ +use tokio::io::{AsyncRead, AsyncWrite}; + /// Per the assignment spec, `1472` is the maximum size packet we're allowed to send. pub const MAX_TOTAL_PACKET_SIZE: usize = 1472; @@ -45,3 +47,44 @@ pub struct SegmentSet { pub other_segs: std::collections::HashSet, } */ + +#[derive(Clone)] +pub struct SegData { + // TODO: encoding + pub(crate) bytes: Vec, + pub is_last_segment: bool, +} + +impl SegData { + pub fn len(&self) -> usize { + self.bytes.len() + } + + pub async fn read(inp: &mut IN) -> Result + where + IN: AsyncRead + Unpin, + { + use tokio::io::AsyncReadExt; + let mut buf = [0u8; MAX_SEG_SIZE]; + let len = inp.read(&mut buf).await.unwrap_or(0); + if len > 0 { + Ok(SegData { + bytes: Vec::from(&buf[..len]), + is_last_segment: false, + }) + } else { + Ok(SegData { + bytes: vec![], + is_last_segment: true, + }) + } + } + + pub async fn write(&self, out: &mut OUT) -> Result<(), tokio::io::Error> + where + OUT: AsyncWrite + Unpin, + { + use tokio::io::AsyncWriteExt; + out.write_all(&self.bytes).await + } +}