CS3700-project3/hptp-recv/src/main.rs

101 lines
3.0 KiB
Rust

#![feature(backtrace)]
extern crate hptp;
extern crate tokio;
#[macro_use]
extern crate thiserror;
use hptp::logger::Logger;
use hptp::msg::{DownMsg, UpMsg};
use hptp::peer::{self, DownPeer, Peer};
use std::collections::HashMap;
use tokio::io::AsyncWrite;
/// Errors that abort the receiver; `main` reports them and stops.
#[derive(Error, Debug)]
enum Error {
    /// An I/O failure, e.g. from binding/using the UDP socket or writing to stdout.
    #[error("io error: {source}")]
    Io {
        /// The underlying I/O error. `#[from]` makes `?` convert it into `Error::Io`.
        #[from]
        source: tokio::io::Error,
        /// Backtrace for the failure; presumably captured by thiserror's generated `From`
        /// impl at conversion time (needs the nightly `backtrace` feature enabled for this
        /// crate). `main` prints it via `std::error::Error::backtrace`.
        backtrace: std::backtrace::Backtrace,
    },
}
/// Builds the Tokio runtime and a fresh `Logger`, then drives `start` to completion.
///
/// # Errors
/// Returns `Error::Io` if the runtime cannot be created, or whatever `start` returns.
fn entry() -> Result<(), Error> {
    // `Runtime::new` is fallible I/O; propagate its failure instead of panicking so that
    // `main` reports it the same way as every other error.
    let mut rt = tokio::runtime::Runtime::new()?;
    let mut log = Logger::new();
    rt.block_on(start(&mut log))
}
/// Program entry point: runs the receiver and, on failure, reports the error (and its
/// backtrace, if one was captured) and exits with a non-zero status.
fn main() {
    if let Err(e) = entry() {
        // Brings the trait into scope for `backtrace()`; shadows the local `Error` enum
        // only inside this block.
        use std::error::Error;
        // stdout carries the downloaded payload, so diagnostics must go to stderr or they
        // would be appended to the received data.
        eprintln!("Error: {}", e);
        if let Some(bt) = e.backtrace() {
            eprintln!("{}", bt);
        }
        std::process::exit(1);
    }
}
/// Binds a UDP socket on the IPv4 loopback address with an OS-chosen port, reports that
/// port through `log`, and then receives the transfer into standard output.
///
/// # Errors
/// Propagates I/O failures from binding or inspecting the socket, plus anything
/// `download` returns.
async fn start(log: &mut Logger) -> Result<(), Error> {
    let loopback = std::net::Ipv4Addr::LOCALHOST;
    // Port 0 asks the OS for any free port; the one actually assigned is what gets logged.
    let socket = tokio::net::UdpSocket::bind((loopback, 0)).await?;
    let assigned_port = socket.local_addr()?.port();
    log.bound(assigned_port).await;
    let mut stdout = tokio::io::stdout();
    let mut receiver = Peer::new(socket);
    download(log, &mut receiver, &mut stdout).await
}
async fn send(peer: &mut DownPeer, m: DownMsg) -> Result<(), Error> {
match peer.send(m).await {
Ok(()) => Ok(()),
Err(peer::SendError::Io { source }) => Err(source.into()),
Err(peer::SendError::NoTarget) => unreachable!(),
}
}
/// Receives data segments from the sender and writes their payloads to `out` in
/// segment-index order, starting at index 0.
///
/// Every data message, new or duplicate, is answered with an ack listing all segment
/// indices received so far. After each message, any contiguous run of segments starting
/// at the next unwritten index is written and flushed to `out`. The loop has no exit
/// condition of its own: it only returns when receiving, sending, or writing fails.
///
/// # Errors
/// Returns `Error::Io` on socket or output I/O failures. Corrupt datagrams are logged and
/// skipped rather than treated as errors.
async fn download<OUT>(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error>
where
    OUT: AsyncWrite + Unpin,
{
    use tokio::io::AsyncWriteExt;

    // Every segment index ever accepted. The payload is `Some` until it has been written
    // to `out`, then `None`; the key is kept so duplicates are still recognised and the
    // index keeps appearing in every ack.
    // NOTE(review): because keys are never removed, each ack grows linearly with the
    // transfer and may eventually exceed a single datagram — confirm the expected size.
    let mut segs = HashMap::new();
    // Index of the next segment to write to `out`.
    let mut flush_seg_idx = 0;
    loop {
        match peer.recv().await {
            Ok(UpMsg::Data { payload, seg_idx }) => {
                if segs.contains_key(&seg_idx) {
                    log.recv_data_ignored(seg_idx as usize, payload.len()).await;
                } else {
                    // NOTE(review): every accepted segment is logged as out-of-order, even
                    // when it equals `flush_seg_idx`; confirm whether the logger offers an
                    // in-order variant that should be used in that case.
                    log.recv_data_accepted(
                        seg_idx as usize,
                        payload.len(),
                        hptp::logger::OutOfOrder,
                    )
                    .await;
                    segs.insert(seg_idx, Some(payload));
                }
                // Ack duplicates too. A retransmission usually means our previous ack was
                // lost. Assuming the sender retransmits until a segment is acked, staying
                // silent here would stall the transfer when the duplicate is among the last
                // segments, because no later segment would arrive to trigger a fresh ack.
                let ack = DownMsg::Ack {
                    idxs: segs.keys().cloned().collect(),
                };
                log.debug_msg(format!("sent ack: {:?}", {
                    let mut idxs = segs.keys().collect::<Vec<_>>();
                    idxs.sort_unstable();
                    idxs
                }))
                .await;
                send(peer, ack).await?;
            }
            Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await,
            Err(peer::RecvError::Io { source }) => return Err(source.into()),
        }
        // Write out every segment that is now contiguous with what has already been written.
        while let Some(slot) = segs.get_mut(&flush_seg_idx) {
            if let Some(payload) = slot.take() {
                out.write_all(&payload).await?;
                out.flush().await?;
            }
            flush_seg_idx += 1;
        }
    }
}