diff --git a/hptp-recv/src/main.rs b/hptp-recv/src/main.rs index 7956ae5..7eba27b 100644 --- a/hptp-recv/src/main.rs +++ b/hptp-recv/src/main.rs @@ -6,7 +6,7 @@ extern crate thiserror; use hptp::log::Logger; use hptp::msg::Msg; -use hptp::peer::DownloadPeer; +use hptp::peer::Peer; use tokio::io::AsyncWrite; #[derive(Error, Debug)] @@ -42,20 +42,23 @@ async fn start(log: &mut Logger) -> Result<(), Error> { .await .map_err(|_| Error::NoPortAvail)?; log.bound(sock.local_addr()?.port()).await; - let mut peer = DownloadPeer::new(tokio::io::stdout(), sock); - download(log, &mut peer).await + let mut out = tokio::io::stdout(); + let mut peer = Peer::new(sock); + download(log, &mut peer, &mut out).await } -async fn download(log: &mut Logger, peer: &mut DownloadPeer) -> Result<(), Error> +async fn download(log: &mut Logger, peer: &mut Peer, out: &mut OUT) -> Result<(), Error> where OUT: AsyncWrite + Unpin, { + use tokio::io::AsyncWriteExt; let mut pos = 0; loop { match peer.recv().await? { Msg::Data(data) => { let len = data.len(); - peer.write_output(&data).await; + out.write_all(&data).await?; + out.flush().await?; log.recv_data_accepted(pos, len, hptp::log::InOrder).await; pos += len; } diff --git a/hptp-send/src/main.rs b/hptp-send/src/main.rs index d429eec..eb41959 100644 --- a/hptp-send/src/main.rs +++ b/hptp-send/src/main.rs @@ -6,8 +6,8 @@ extern crate thiserror; use hptp::log::Logger; use hptp::msg::Msg; -use hptp::peer::UploadPeer; -use std::net::SocketAddrV4; +use hptp::peer::Peer; +use std::net::SocketAddr; use tokio::io::AsyncRead; // use tokio::net::UdpSocket; @@ -56,24 +56,40 @@ async fn start(log: &mut Logger) -> Result<(), Error> { .map_err(|_| Error::NoAvailPort)?; log.debug_msg(format!("bound on {}", sock.local_addr()?)) .await; - let mut peer = UploadPeer::new(tokio::io::stdin(), sock, targ_addr); - upload(log, &mut peer).await + let mut out = tokio::io::stdin(); + let mut peer = Peer::new(sock); + peer.set_known_target(targ_addr); + upload(log, &mut peer, &mut out).await } -fn parse_args(mut args: impl Iterator) -> Result { +fn parse_args(mut args: impl Iterator) -> Result { args.nth(1) .ok_or(Error::InvalidArgs)? .parse() .map_err(|_| Error::InvalidArgs) } -async fn upload(log: &mut Logger, peer: &mut UploadPeer) -> Result<(), Error> +async fn read_some(inp: &mut IN) -> Result>, Error> +where + IN: AsyncRead + Unpin, +{ + use tokio::io::AsyncReadExt; + let mut buf = [0u8; 512]; + let len = inp.read(&mut buf).await?; + Ok(if len > 0 { + Some(buf[..len].into()) + } else { + None + }) +} + +async fn upload(log: &mut Logger, peer: &mut Peer, inp: &mut IN) -> Result<(), Error> where IN: AsyncRead + Unpin, { let mut pos = 0; loop { - match peer.read_input().await { + match read_some(inp).await? { Some(data) => { let len = data.len(); peer.send(Msg::Data(data)).await?; diff --git a/hptp/src/peer.rs b/hptp/src/peer.rs index 7435087..cde4a83 100644 --- a/hptp/src/peer.rs +++ b/hptp/src/peer.rs @@ -1,83 +1,39 @@ -use std::net::SocketAddrV4; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::net::SocketAddr; use tokio::net::UdpSocket; use super::msg::Msg; -pub struct UploadPeer { - input: IN, - sock: UdpSocket, - targ: SocketAddrV4, -} - -pub struct DownloadPeer { - output: OUT, +pub struct Peer { sock: UdpSocket, + targ: Option, } const BUFFER_SIZE: usize = 1000; -impl UploadPeer { - pub fn new(input: IN, sock: UdpSocket, targ: SocketAddrV4) -> Self { - UploadPeer { input, sock, targ } +impl Peer { + pub fn new(sock: UdpSocket) -> Self { + Peer { sock, targ: None } + } + + pub fn set_known_target(&mut self, addr: SocketAddr) { + self.targ = Some(addr); } pub async fn send(&mut self, msg: Msg) -> Result<(), tokio::io::Error> { + let targ = self.targ.expect("no target to send to"); let bs = msg.ser(); let mut i = 0; while i < bs.len() { - let n = self.sock.send_to(&bs[i..], self.targ).await?; + let n = self.sock.send_to(&bs[i..], targ).await?; i += n } Ok(()) } -} - -impl UploadPeer -where - IN: AsyncRead + Unpin, -{ - pub async fn read_input(&mut self) -> Option> { - use tokio::io::AsyncReadExt; - let mut buf = [0u8; BUFFER_SIZE]; - let len = self - .input - .read(&mut buf) - .await - .expect("failed to read from stdin"); - if len == 0 { - None - } else { - Some(buf[..len].into()) - } - } -} - -impl DownloadPeer { - pub fn new(output: OUT, sock: UdpSocket) -> Self { - DownloadPeer { output, sock } - } pub async fn recv(&mut self) -> Result { let mut buf = [0u8; BUFFER_SIZE]; - let (len, _who) = self.sock.recv_from(&mut buf).await?; + let (len, who) = self.sock.recv_from(&mut buf).await?; + self.set_known_target(who); Ok(Msg::des(&buf[..len])) } } - -impl DownloadPeer -where - OUT: AsyncWrite + Unpin, -{ - pub async fn write_output(&mut self, data: &[u8]) { - use tokio::io::AsyncWriteExt; - self.output - .write_all(data) - .await - .expect("failed to write to stdout"); - self.output - .flush() - .await - .expect("failed to write to stdout") - } -}