separate types for upload/download messages

This commit is contained in:
Milo Turner 2020-03-09 14:11:12 -04:00
parent 9553c25b26
commit 6652ddb3d1
4 changed files with 66 additions and 42 deletions

View File

@ -5,8 +5,8 @@ extern crate tokio;
extern crate thiserror; extern crate thiserror;
use hptp::logger::Logger; use hptp::logger::Logger;
use hptp::msg::Msg; use hptp::msg::{DownMsg, UpMsg};
use hptp::peer::{self, Peer}; use hptp::peer::{self, DownPeer, Peer};
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -47,7 +47,7 @@ async fn start(log: &mut Logger) -> Result<(), Error> {
download(log, &mut peer, &mut out).await download(log, &mut peer, &mut out).await
} }
async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { async fn send(peer: &mut DownPeer, m: DownMsg) -> Result<(), Error> {
match peer.send(m).await { match peer.send(m).await {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(peer::SendError::Io { source }) => Err(source.into()), Err(peer::SendError::Io { source }) => Err(source.into()),
@ -55,7 +55,7 @@ async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> {
} }
} }
async fn download<OUT>(log: &mut Logger, peer: &mut Peer, out: &mut OUT) -> Result<(), Error> async fn download<OUT>(log: &mut Logger, peer: &mut DownPeer, out: &mut OUT) -> Result<(), Error>
where where
OUT: AsyncWrite + Unpin, OUT: AsyncWrite + Unpin,
{ {
@ -63,20 +63,17 @@ where
let mut pos = 0; let mut pos = 0;
loop { loop {
match peer.recv().await { match peer.recv().await {
Ok(Msg::Ack) => log.debug_msg("not expecting an ack...").await, Ok(UpMsg::Data(data)) => {
Ok(Msg::Data(data)) => {
let len = data.len(); let len = data.len();
out.write_all(&data).await?; out.write_all(&data).await?;
out.flush().await?; out.flush().await?;
log.recv_data_accepted(pos, len, hptp::logger::InOrder) log.recv_data_accepted(pos, len, hptp::logger::InOrder)
.await; .await;
send(peer, Msg::Ack).await?; send(peer, DownMsg::Ack).await?;
pos += len; pos += len;
} }
Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await, Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await,
Err(peer::RecvError::Io { source }) => return Err(source.into()), Err(peer::RecvError::Io { source }) => return Err(source.into()),
} }
} }

View File

@ -5,8 +5,8 @@ extern crate tokio;
extern crate thiserror; extern crate thiserror;
use hptp::logger::Logger; use hptp::logger::Logger;
use hptp::msg::{self, Msg}; use hptp::msg::{self, DownMsg, UpMsg};
use hptp::peer::{self, Peer}; use hptp::peer::{self, Peer, UpPeer};
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
// use tokio::net::UdpSocket; // use tokio::net::UdpSocket;
@ -83,7 +83,7 @@ where
}) })
} }
async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> { async fn send(peer: &mut UpPeer, m: UpMsg) -> Result<(), Error> {
match peer.send(m).await { match peer.send(m).await {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(peer::SendError::Io { source }) => Err(source.into()), Err(peer::SendError::Io { source }) => Err(source.into()),
@ -91,7 +91,7 @@ async fn send(peer: &mut Peer, m: Msg) -> Result<(), Error> {
} }
} }
async fn upload<IN>(log: &mut Logger, peer: &mut Peer, inp: &mut IN) -> Result<(), Error> async fn upload<IN>(log: &mut Logger, peer: &mut UpPeer, inp: &mut IN) -> Result<(), Error>
where where
IN: AsyncRead + Unpin, IN: AsyncRead + Unpin,
{ {
@ -100,12 +100,11 @@ where
loop { loop {
if let Some(ack_len) = next_ack_len { if let Some(ack_len) = next_ack_len {
match peer.recv().await { match peer.recv().await {
Ok(Msg::Ack) => { Ok(DownMsg::Ack) => {
log.recv_ack(pos).await; log.recv_ack(pos).await;
pos += ack_len; pos += ack_len;
next_ack_len = None; next_ack_len = None;
} }
Ok(_) => log.debug_msg("got some other packet??").await,
Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await, Err(peer::RecvError::InvalidMessage { .. }) => log.recv_corrupt().await,
Err(peer::RecvError::Io { source }) => return Err(source.into()), Err(peer::RecvError::Io { source }) => return Err(source.into()),
@ -115,7 +114,7 @@ where
Some(data) => { Some(data) => {
next_ack_len = Some(data.len()); next_ack_len = Some(data.len());
log.send_data(pos, data.len()).await; log.send_data(pos, data.len()).await;
send(peer, Msg::Data(data)).await?; send(peer, UpMsg::Data(data)).await?;
} }
None => break, None => break,
} }

View File

@ -1,6 +1,10 @@
#[derive(Clone)] #[derive(Clone)]
pub enum Msg { pub enum UpMsg {
Data(Vec<u8>), Data(Vec<u8>),
}
#[derive(Clone)]
pub enum DownMsg {
Ack, Ack,
} }
@ -11,24 +15,33 @@ pub const MAX_SERIALIZED_SIZE: usize = 1 + MAX_DATA_SIZE;
#[error("deserialization failed; malformed packet")] #[error("deserialization failed; malformed packet")]
pub struct DesError; pub struct DesError;
impl Msg { pub trait SerDes: Sized {
pub fn des(data: &[u8]) -> Result<Msg, DesError> { fn des(data: &[u8]) -> Result<Self, DesError>;
match data.first() { fn ser_into(self) -> Vec<u8>;
Some(0) => Ok(Msg::Data(data[1..].into())), }
Some(1) => Ok(Msg::Ack),
_ => Err(DesError), impl SerDes for UpMsg {
fn des(data: &[u8]) -> Result<Self, DesError> {
Ok(UpMsg::Data(data.into()))
}
fn ser_into(self) -> Vec<u8> {
match self {
UpMsg::Data(data) => data,
}
}
}
impl SerDes for DownMsg {
fn des(data: &[u8]) -> Result<Self, DesError> {
if data == [0] {
Ok(DownMsg::Ack)
} else {
Err(DesError)
} }
} }
pub fn ser(&self) -> Vec<u8> { fn ser_into(self) -> Vec<u8> {
let mut buf = Vec::new(); vec![0]
match self {
Msg::Data(data) => {
buf.push(0);
buf.extend_from_slice(data);
}
Msg::Ack => buf.push(1),
}
buf
} }
} }

View File

@ -1,13 +1,18 @@
use std::marker::PhantomData;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use super::msg::{self, Msg}; use super::msg::{self, SerDes};
pub struct Peer { pub struct Peer<F, T> {
sock: UdpSocket, sock: UdpSocket,
targ: Option<SocketAddr>, targ: Option<SocketAddr>,
_phantom: PhantomData<fn(F) -> T>,
} }
pub type UpPeer = Peer<msg::UpMsg, msg::DownMsg>;
pub type DownPeer = Peer<msg::DownMsg, msg::UpMsg>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum RecvError { pub enum RecvError {
#[error("io error: {source}")] #[error("io error: {source}")]
@ -33,26 +38,36 @@ pub enum SendError {
NoTarget, NoTarget,
} }
impl Peer { impl<F, T> Peer<F, T> {
pub fn new(sock: UdpSocket) -> Self { pub fn new(sock: UdpSocket) -> Self {
Peer { sock, targ: None } Peer {
sock,
targ: None,
_phantom: PhantomData,
}
} }
pub fn set_known_target(&mut self, addr: SocketAddr) { pub fn set_known_target(&mut self, addr: SocketAddr) {
self.targ = Some(addr); self.targ = Some(addr);
} }
pub async fn send(&mut self, msg: Msg) -> Result<(), SendError> { pub async fn send(&mut self, msg: F) -> Result<(), SendError>
where
F: SerDes,
{
let targ = self.targ.ok_or(SendError::NoTarget)?; let targ = self.targ.ok_or(SendError::NoTarget)?;
let bs = msg.ser(); let data = msg.ser_into();
let _n_sent = self.sock.send_to(&bs, targ).await?; let _n_sent = self.sock.send_to(&data, targ).await?;
Ok(()) Ok(())
} }
pub async fn recv(&mut self) -> Result<Msg, RecvError> { pub async fn recv(&mut self) -> Result<T, RecvError>
where
T: SerDes,
{
let mut buf = [0u8; msg::MAX_SERIALIZED_SIZE]; let mut buf = [0u8; msg::MAX_SERIALIZED_SIZE];
let (len, who) = self.sock.recv_from(&mut buf).await?; let (len, who) = self.sock.recv_from(&mut buf).await?;
self.set_known_target(who); self.set_known_target(who);
Ok(Msg::des(&buf[..len])?) Ok(T::des(&buf[..len])?)
} }
} }