segment data in its own type called SegData; handles EOF (badly)

This commit is contained in:
Milo Turner 2020-03-11 13:13:07 -04:00
parent aefc4d73ab
commit f858cbc398
4 changed files with 115 additions and 32 deletions

View File

@ -7,7 +7,7 @@ extern crate thiserror;
use hptp::logger::Logger; use hptp::logger::Logger;
use hptp::msg::{DownMsg, UpMsg}; use hptp::msg::{DownMsg, UpMsg};
use hptp::peer::{self, DownPeer, Peer}; use hptp::peer::{self, DownPeer, Peer};
use hptp::seg::SegIdx; use hptp::seg::{SegData, SegIdx};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
@ -47,8 +47,9 @@ async fn start(log: &mut Logger) -> Result<(), Error> {
struct SegmentSink<'o, OUT> { struct SegmentSink<'o, OUT> {
out: &'o mut OUT, out: &'o mut OUT,
segs: HashMap<SegIdx, Option<Vec<u8>>>, segs: HashMap<SegIdx, Option<SegData>>,
n_flushed: u32, n_flushed: u32,
complete: bool,
} }
#[derive(Copy, Clone, Eq, PartialEq, Debug)] #[derive(Copy, Clone, Eq, PartialEq, Debug)]
@ -66,6 +67,7 @@ where
out, out,
segs: HashMap::new(), segs: HashMap::new(),
n_flushed: 0, n_flushed: 0,
complete: false,
} }
} }
@ -73,14 +75,17 @@ where
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
while let Some(cache) = self.segs.get_mut(&self.n_flushed) { while let Some(cache) = self.segs.get_mut(&self.n_flushed) {
if let Some(data) = cache.take() { if let Some(data) = cache.take() {
self.out.write_all(&data).await.expect("god help us"); data.write(self.out).await.expect("god help us");
self.out.flush().await.expect("god help us"); if data.is_last_segment {
self.complete = true;
}
} }
self.n_flushed += 1; self.n_flushed += 1;
} }
self.out.flush().await.expect("god help us");
} }
async fn put(&mut self, seg_idx: SegIdx, data: Vec<u8>) -> Put { async fn put(&mut self, seg_idx: SegIdx, data: SegData) -> Put {
if seg_idx < self.n_flushed || self.segs.contains_key(&seg_idx) { if seg_idx < self.n_flushed || self.segs.contains_key(&seg_idx) {
Put::Duplicate Put::Duplicate
} else { } else {
@ -95,6 +100,10 @@ where
idxs: self.segs.keys().cloned().collect(), idxs: self.segs.keys().cloned().collect(),
} }
} }
fn is_file_complete(&self) -> bool {
self.complete
}
} }
async fn download<OUT>(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error> async fn download<OUT>(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error>
@ -104,6 +113,11 @@ where
let mut sink = SegmentSink::new(out); let mut sink = SegmentSink::new(out);
let mut to_send = vec![]; let mut to_send = vec![];
enum Action {
Continue,
Quit,
}
loop { loop {
let msg = match peer.recv().await { let msg = match peer.recv().await {
Ok(m) => m, Ok(m) => m,
@ -116,6 +130,7 @@ where
} }
}; };
let act;
match msg { match msg {
UpMsg::Data { payload, seg_idx } => { UpMsg::Data { payload, seg_idx } => {
let len = payload.len(); let len = payload.len();
@ -136,6 +151,11 @@ where
to_send.push(sink.ack_msg()); to_send.push(sink.ack_msg());
} }
} }
act = if sink.is_file_complete() {
Action::Quit
} else {
Action::Continue
}
} }
} }
@ -146,5 +166,13 @@ where
Err(hptp::peer::SendError::Io { source }) => return Err(source.into()), Err(hptp::peer::SendError::Io { source }) => return Err(source.into()),
} }
} }
match act {
Action::Continue => (),
Action::Quit => break,
}
} }
log.completed().await;
Ok(())
} }

View File

@ -7,7 +7,7 @@ extern crate thiserror;
use hptp::logger::Logger; use hptp::logger::Logger;
use hptp::msg::{DownMsg, UpMsg}; use hptp::msg::{DownMsg, UpMsg};
use hptp::peer::{self, Peer, UpPeer}; use hptp::peer::{self, Peer, UpPeer};
use hptp::seg::SegIdx; use hptp::seg::{SegData, SegIdx};
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
@ -71,10 +71,11 @@ struct SegmentSource<'i, IN> {
inp: &'i mut IN, inp: &'i mut IN,
unacked_segs: HashMap<SegIdx, UnAcked>, unacked_segs: HashMap<SegIdx, UnAcked>,
unacked_upper_bound: u32, unacked_upper_bound: u32,
eof: bool,
} }
struct UnAcked { struct UnAcked {
payload: Vec<u8>, payload: SegData,
nacks: usize, nacks: usize,
} }
@ -87,35 +88,32 @@ where
inp, inp,
unacked_segs: HashMap::new(), unacked_segs: HashMap::new(),
unacked_upper_bound: 0, unacked_upper_bound: 0,
eof: false,
} }
} }
async fn read_segment(&mut self) -> Option<(SegIdx, &[u8])> { async fn read_segment(&mut self) -> Option<(SegIdx, &SegData)> {
use tokio::io::AsyncReadExt; if self.eof {
let mut buf = [0u8; hptp::seg::MAX_SEG_SIZE]; None
let len = self.inp.read(&mut buf).await.unwrap_or(0); } else {
if len > 0 {
let seg_idx = self.unacked_upper_bound; let seg_idx = self.unacked_upper_bound;
let payload = SegData::read(self.inp).await.expect("god help us");
self.eof = payload.is_last_segment;
self.unacked_upper_bound += 1; self.unacked_upper_bound += 1;
let unack = self.unacked_segs.entry(seg_idx); let unack = self.unacked_segs.entry(seg_idx);
let unack = unack.or_insert(UnAcked { let unack = unack.or_insert(UnAcked { payload, nacks: 0 });
payload: buf[..len].into(), Some((seg_idx, &unack.payload))
nacks: 0,
});
Some((seg_idx, unack.payload.as_ref()))
} else {
None
} }
} }
async fn get_segment(&mut self) -> Option<(SegIdx, &[u8])> { async fn get_segment(&mut self) -> Option<(SegIdx, &SegData)> {
if self.unacked_segs.is_empty() { if self.unacked_segs.is_empty() {
self.read_segment().await self.read_segment().await
} else { } else {
self.unacked_segs self.unacked_segs
.iter() .iter()
.next() // get any entry, no ordering guarunteed or needed .next() // get any entry, no ordering guarunteed or needed
.map(|(seg_idx, unack)| (*seg_idx, unack.payload.as_ref())) .map(|(seg_idx, unack)| (*seg_idx, &unack.payload))
} }
} }
@ -179,7 +177,7 @@ where
log.send_data(seg_idx as usize, payload.len()).await; log.send_data(seg_idx as usize, payload.len()).await;
to_send.push(UpMsg::Data { to_send.push(UpMsg::Data {
seg_idx, seg_idx,
payload: Vec::from(payload), payload: payload.clone(),
}); });
Action::Continue Action::Continue
} }

View File

@ -1,17 +1,17 @@
use super::seg::SegIdx; use super::seg::{SegData, SegIdx};
pub use super::seg::{DOWN_HEADER_SIZE, MAX_TOTAL_PACKET_SIZE, UP_HEADER_SIZE}; pub use super::seg::{DOWN_HEADER_SIZE, MAX_TOTAL_PACKET_SIZE, UP_HEADER_SIZE};
use byteorder::ByteOrder; use byteorder::ByteOrder;
#[derive(Clone, Debug)] #[derive(Clone)]
pub enum UpMsg { pub enum UpMsg {
Data { Data {
payload: Vec<u8>, payload: SegData,
seg_idx: SegIdx, seg_idx: SegIdx,
// is_final_packet: bool, // is_final_packet: bool,
}, },
} }
#[derive(Clone, Debug)] #[derive(Clone)]
pub enum DownMsg { pub enum DownMsg {
/// `idxs` must be distinct and in increasing order. /// `idxs` must be distinct and in increasing order.
Ack { idxs: Vec<SegIdx> }, Ack { idxs: Vec<SegIdx> },
@ -30,24 +30,38 @@ pub trait SerDes: Sized {
type BO = byteorder::LE; type BO = byteorder::LE;
const LAST_SEG_MASK: u32 = 1 << 31;
impl SerDes for UpMsg { impl SerDes for UpMsg {
fn des(buf: &[u8]) -> Result<Self, DesError> { fn des(buf: &[u8]) -> Result<Self, DesError> {
if buf.len() < UP_HEADER_SIZE { if buf.len() < UP_HEADER_SIZE {
Err(DesError) Err(DesError)
} else { } else {
let hdr = BO::read_u32(&buf[0..4]);
Ok(UpMsg::Data { Ok(UpMsg::Data {
seg_idx: BO::read_u32(&buf[0..4]), seg_idx: hdr & !LAST_SEG_MASK,
payload: buf[4..].into(), payload: SegData {
bytes: buf[4..].into(),
is_last_segment: (hdr & LAST_SEG_MASK) != 0,
},
}) })
} }
} }
fn ser_to(&self, buf: &mut [u8]) -> usize { fn ser_to(&self, buf: &mut [u8]) -> usize {
match self { match self {
UpMsg::Data { payload, seg_idx } => { UpMsg::Data {
let len = payload.len(); payload:
BO::write_u32(&mut buf[0..4], *seg_idx); SegData {
buf[4..4 + len].copy_from_slice(&payload[..]); bytes,
is_last_segment,
},
seg_idx,
} => {
let hdr = *seg_idx | if *is_last_segment { LAST_SEG_MASK } else { 0 };
BO::write_u32(&mut buf[0..4], hdr);
let len = bytes.len();
buf[4..4 + len].copy_from_slice(&bytes[..]);
4 + len 4 + len
} }
} }

View File

@ -1,3 +1,5 @@
use tokio::io::{AsyncRead, AsyncWrite};
/// Per the assignment spec, `1472` is the maximum size packet we're allowed to send. /// Per the assignment spec, `1472` is the maximum size packet we're allowed to send.
pub const MAX_TOTAL_PACKET_SIZE: usize = 1472; pub const MAX_TOTAL_PACKET_SIZE: usize = 1472;
@ -45,3 +47,44 @@ pub struct SegmentSet {
pub other_segs: std::collections::HashSet<SegIdx>, pub other_segs: std::collections::HashSet<SegIdx>,
} }
*/ */
#[derive(Clone)]
/// One segment of the transferred file, as carried in a single packet.
pub struct SegData {
    // TODO: encoding
    /// Raw payload bytes of this segment (empty for the EOF marker segment).
    pub(crate) bytes: Vec<u8>,
    /// Marks the final segment of the file; set on a zero-length read
    /// (see `SegData::read`) or decoded from the wire header by the peer.
    pub is_last_segment: bool,
}
impl SegData {
    /// Number of payload bytes in this segment.
    pub fn len(&self) -> usize {
        self.bytes.len()
    }

    /// Reads up to `MAX_SEG_SIZE` bytes from `inp` into a fresh segment.
    ///
    /// A zero-length read is interpreted as end-of-file and yields an empty
    /// segment with `is_last_segment` set.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying reader. (Previously the
    /// error was swallowed with `unwrap_or(0)`, which made genuine read
    /// failures indistinguishable from a clean EOF.)
    pub async fn read<IN>(inp: &mut IN) -> Result<SegData, tokio::io::Error>
    where
        IN: AsyncRead + Unpin,
    {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; MAX_SEG_SIZE];
        let len = inp.read(&mut buf).await?;
        Ok(SegData {
            bytes: Vec::from(&buf[..len]),
            // NOTE(review): only a zero-length read marks EOF here; a short
            // but non-empty read says nothing about whether more data follows.
            is_last_segment: len == 0,
        })
    }

    /// Writes this segment's raw payload bytes to `out` (no header).
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying writer.
    pub async fn write<OUT>(&self, out: &mut OUT) -> Result<(), tokio::io::Error>
    where
        OUT: AsyncWrite + Unpin,
    {
        use tokio::io::AsyncWriteExt;
        out.write_all(&self.bytes).await
    }
}