CS3700-project3/hptp-send/src/main.rs

270 lines
7.5 KiB
Rust

#![feature(backtrace)]
extern crate hptp;
extern crate tokio;
#[macro_use]
extern crate thiserror;
use hptp::encoding::MeowCoder;
use hptp::logger::Logger;
use hptp::msg::{DownMsg, UpMsg};
use hptp::peer::{self, Peer, UpPeer};
use hptp::seg::{SegData, SegIdx, MAX_SEG_SIZE};
use std::cmp;
use std::collections::HashMap;
use std::net::SocketAddr;
use tokio::io::AsyncRead;
// use tokio::net::UdpSocket;
/// Errors that can abort the sender.
#[derive(Error, Debug)]
enum Error {
    /// An I/O failure (e.g. on the UDP socket). The backtrace is captured
    /// when the underlying error is converted into this variant.
    #[error("io error: {source}")]
    Io {
        #[from]
        source: std::io::Error,
        backtrace: std::backtrace::Backtrace,
    },
    /// The command line did not contain a parseable `<host-ip>:<host-port>`.
    #[error("invalid command-line arguments")]
    InvalidArgs,
}
fn main() {
match entry() {
Err(Error::InvalidArgs) => print_usage(),
Err(e) => {
use std::error::Error;
println!("Error: {:?}", e);
for bt in e.backtrace() {
println!("{}", bt);
}
}
Ok(()) => (),
}
}
/// Prints the command-line usage text to stdout.
fn print_usage() {
    println!("Usage:\n./3700send <host-ip>:<host-port>");
}
/// Builds a Tokio runtime and drives the sender to completion on it.
///
/// # Errors
/// Returns `Error::Io` if the runtime cannot be created. Otherwise returns
/// whatever `start` returns.
fn entry() -> Result<(), Error> {
    // `Runtime::new` is fallible; surface its I/O error through `?` (converted
    // into `Error::Io`) instead of panicking.
    let mut rt = tokio::runtime::Runtime::new()?;
    let mut log = Logger::new();
    rt.block_on(start(&mut log))
}
async fn start(log: &mut Logger) -> Result<(), Error> {
let targ_addr = parse_args(std::env::args())?;
let sock = tokio::net::UdpSocket::bind((std::net::Ipv4Addr::LOCALHOST, 0)).await?;
log.debug_msg(format!("bound on {}", sock.local_addr()?))
.await;
let mut out = tokio::io::stdin();
let mut peer = Peer::new(sock);
peer.set_known_target(targ_addr);
upload(log, &mut peer, &mut out).await
}
fn parse_args(mut args: impl Iterator<Item = String>) -> Result<SocketAddr, Error> {
args.nth(1)
.ok_or(Error::InvalidArgs)?
.parse()
.map_err(|_| Error::InvalidArgs)
}
/// Produces the segments to send and tracks which of them are still unacknowledged.
///
/// On first use the whole input is read and, when possible, meow-encoded. It
/// is then handed out in chunks of at most `MAX_SEG_SIZE` bytes.
struct SegmentSource<'i, IN> {
    /// Input stream; read to the end on the first call to `read_next`.
    inp: &'i mut IN,
    /// The entire (possibly encoded) input, filled on first use.
    inp_buffer: Vec<u8>,
    /// Offset in `inp_buffer` of the first byte not yet handed out.
    inp_offset: usize,
    /// Whether `inp` has been read and `inp_buffer` populated.
    read_input: bool,
    /// Whether `inp_buffer` holds meow-encoded data.
    is_meow: bool,
    /// `is_cut` flag returned by the encoder; reported on the last segment only.
    is_cut: bool,
    /// Segments handed out but not yet acknowledged, keyed by index.
    unacked_segs: HashMap<SegIdx, UnAcked>,
    /// Index that the next fresh segment will receive.
    unacked_upper_bound: u32,
    /// Set once the segment flagged `is_last_segment` has been produced.
    eof: bool,
}

/// A segment that has been produced but not yet acknowledged.
struct UnAcked {
    payload: SegData,
    // NOTE(review): not read or incremented anywhere in the visible code.
    nacks: usize,
}

/// Once more than this many segments are unacknowledged, `get_segment`
/// retransmits an existing segment instead of producing a fresh one.
/// `upload` also derives its send interval from it (`1000 / IN_FLIGHT` ms).
const IN_FLIGHT: usize = 30;

impl<'i, IN> SegmentSource<'i, IN>
where
    IN: AsyncRead + Unpin,
{
    /// Creates a source over `inp`; nothing is read until a segment is requested.
    fn new(inp: &'i mut IN) -> Self {
        SegmentSource {
            inp,
            inp_buffer: Vec::new(),
            inp_offset: 0,
            read_input: false,
            is_meow: false,
            is_cut: false,
            unacked_segs: HashMap::new(),
            unacked_upper_bound: 0,
            eof: false,
        }
    }

    /// Picks a segment to retransmit when more than `IN_FLIGHT` are unacknowledged.
    ///
    /// `HashMap` iteration order is unspecified, so the chosen segment is
    /// arbitrary (not necessarily the oldest).
    fn retrans_seg_idx(&self) -> Option<SegIdx> {
        self.unacked_segs.keys().nth(IN_FLIGHT).cloned()
    }

    /// Returns an arbitrary unacknowledged index, or `None` if every segment is acked.
    fn any_unacked_seg_idx(&self) -> Option<SegIdx> {
        self.unacked_segs.keys().next().cloned()
    }

    /// Reads all of `inp` into `inp_buffer`, then meow-encodes it if possible.
    async fn fill_buffer(&mut self) {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 16384];
        loop {
            match self.inp.read(&mut buf).await {
                Ok(0) => break,
                Ok(len) => self.inp_buffer.extend_from_slice(&buf[..len]),
                // Interrupted reads can be retried; this is not end of input.
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                // NOTE(review): any other read error ends the input early, which
                // silently truncates the upload. `get_segment`'s signature has
                // no way to report the error to its caller.
                Err(_) => break,
            }
        }
        self.read_input = true;
        if MeowCoder::can_be_encoded(&self.inp_buffer, 0) {
            let (encoded, is_cut) = MeowCoder::encode(&self.inp_buffer);
            self.inp_buffer = encoded;
            self.is_cut = is_cut;
            self.is_meow = true;
        }
        // Otherwise the raw bytes are sent as-is. `is_meow` stays false, so every
        // segment is flagged `is_meow_encoded: false` (previously this panicked).
    }

    /// Returns the next chunk of at most `MAX_SEG_SIZE` bytes of the (encoded) input.
    ///
    /// The chunk is flagged as the last segment when it reaches the end of the
    /// buffer. For empty input, that is the very first, empty, segment.
    async fn read_next(&mut self) -> SegData {
        if !self.read_input {
            self.fill_buffer().await;
        }
        let start = self.inp_offset;
        let end = cmp::min(start + MAX_SEG_SIZE, self.inp_buffer.len());
        // Advance an offset instead of draining from the front. Draining would
        // shift the remaining buffer for every segment (quadratic overall).
        let bytes = self.inp_buffer[start..end].iter().copied().collect();
        self.inp_offset = end;
        let is_eof = end == self.inp_buffer.len();
        SegData {
            bytes,
            is_last_segment: is_eof,
            is_meow_encoded: self.is_meow,
            is_cut: is_eof && self.is_cut,
        }
    }

    /// Produces the next fresh segment, records it as unacknowledged, and returns its index.
    ///
    /// Returns `None` once the last segment has already been produced.
    async fn fresh_seg_idx(&mut self) -> Option<SegIdx> {
        if self.eof {
            return None;
        }
        let seg_idx = self.unacked_upper_bound;
        let payload = self.read_next().await;
        self.eof = payload.is_last_segment;
        self.unacked_upper_bound += 1;
        self.unacked_segs
            .insert(seg_idx, UnAcked { payload, nacks: 0 });
        Some(seg_idx)
    }

    /// Chooses the next segment to send, in this order of preference:
    /// 1. a retransmission, when more than `IN_FLIGHT` segments are outstanding;
    /// 2. a fresh segment;
    /// 3. after the last fresh segment, any still-unacknowledged segment.
    ///
    /// Returns `None` only when all input has been sent and every segment acked.
    async fn get_segment(&mut self) -> Option<(SegIdx, &SegData)> {
        let seg_idx = if let Some(si) = self.retrans_seg_idx() {
            si
        } else if let Some(si) = self.fresh_seg_idx().await {
            si
        } else if let Some(si) = self.any_unacked_seg_idx() {
            si
        } else {
            // Nothing left to send or retransmit.
            return None;
        };
        self.unacked_segs
            .get(&seg_idx)
            .map(|u| (seg_idx, &u.payload))
    }

    /// Marks `seg_idxs` as acknowledged, so `get_segment` never selects them again.
    ///
    /// Order does not matter. Indices that are unknown or already acknowledged
    /// are ignored.
    fn ack(&mut self, seg_idxs: &[SegIdx]) {
        for seg_idx in seg_idxs {
            self.unacked_segs.remove(seg_idx);
        }
    }
}
/// Uploads all data read from `inp` to the peer's known target.
///
/// Transmission is paced by a timer. Each time it fires, one segment chosen by
/// `SegmentSource::get_segment` is sent, and the next deadline is scheduled
/// `DELAY_MS` later. ACKs received in between are applied to the source, so
/// acknowledged segments are no longer retransmitted. The function returns
/// once `get_segment` has nothing left to send, i.e. all input was sent and acked.
///
/// # Errors
/// Returns `Error::Io` if receiving from or sending to the socket fails.
/// Corrupt incoming messages are logged and otherwise ignored.
async fn upload<IN>(log: &mut Logger, peer: &mut UpPeer, inp: &mut IN) -> Result<(), Error>
where
    IN: AsyncRead + Unpin,
{
    // Design notes (from the original TODO list):
    // 1. prioritize segments to send (see `SegmentSource::get_segment`)
    // 2. set a timer for sending the next segment
    // 3. receive ACKs in the meantime
    // 4. TODO: adjust the timer delay based on ACK information; `DELAY_MS`
    //    is currently a fixed constant.
    use tokio::time::{Duration, Instant};
    // What woke up the event loop.
    enum Evt {
        Recv(DownMsg),
        Timer,
    }
    // Whether the loop keeps going after handling an event.
    enum Action {
        Continue,
        Quit,
    }
    let mut src = SegmentSource::new(inp);
    // One segment per tick, i.e. roughly `IN_FLIGHT` segments per second.
    const DELAY_MS: u64 = (1000 / IN_FLIGHT) as u64;
    let mut deadline = Instant::now();
    // Messages produced while handling an event; they are sent before the quit check.
    let mut to_send = vec![];
    loop {
        let timer = tokio::time::delay_until(deadline);
        // Wait for whichever comes first: the send deadline or an incoming message.
        let evt = tokio::select!(
            _ = timer => Evt::Timer,
            r = peer.recv() => match r {
                Ok(m) => Evt::Recv(m),
                Err(peer::RecvError::InvalidMessage { .. }) => {
                    // Drop the corrupt message and wait again; the deadline is unchanged.
                    log.recv_corrupt().await;
                    continue;
                }
                Err(peer::RecvError::Io { source }) => {
                    return Err(source.into());
                }
            }
        );
        let act = match evt {
            Evt::Timer => {
                // Advance from the previous deadline, not from now. If the loop
                // falls behind (e.g. while the first `get_segment` reads all of
                // the input), deadlines already in the past fire back-to-back
                // until it catches up.
                deadline += Duration::from_millis(DELAY_MS);
                match src.get_segment().await {
                    Some((seg_idx, payload)) => {
                        log.send_data(seg_idx as usize, payload.len()).await;
                        to_send.push(UpMsg::Data {
                            seg_idx,
                            payload: payload.clone(),
                        });
                        Action::Continue
                    }
                    // Nothing left to send or retransmit: the upload is complete.
                    None => Action::Quit,
                }
            }
            Evt::Recv(DownMsg::Ack { idxs }) => {
                log.debug_msg(format!("got acks: {:?}", idxs)).await;
                src.ack(&idxs);
                Action::Continue
            }
        };
        for m in to_send.drain(..) {
            match peer.send(m).await {
                Ok(()) => (),
                // `start` calls `set_known_target` before calling `upload`.
                Err(hptp::peer::SendError::NoTarget) => unreachable!("no target"),
                Err(hptp::peer::SendError::Io { source }) => return Err(source.into()),
            }
        }
        match act {
            Action::Continue => (),
            Action::Quit => break,
        }
    }
    log.completed().await;
    Ok(())
}