fix nonblocking I/O interference

This commit is contained in:
xenia 2021-05-20 01:16:21 -04:00
parent fb64863bbe
commit 9884973248
1 changed files with 112 additions and 18 deletions

View File

@ -8,9 +8,11 @@ import re
import signal import signal
import sys import sys
import termios import termios
import threading
import tty import tty
import queue
from types import TracebackType from types import TracebackType
from typing import Any, BinaryIO, List, Optional, Tuple, Type from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Type
from typing_extensions import Literal from typing_extensions import Literal
import serial import serial
@ -33,17 +35,21 @@ MODE_LOOKUP = {
class TtyRaw: class TtyRaw:
__slots__ = ["isatty", "infd", "outfd", "settings"] __slots__ = ["isatty", "infd", "outfd", "settings", "read_thread", "write_thread"]
isatty: bool isatty: bool
infd: int infd: int
outfd: int outfd: int
settings: List[Any] settings: List[Any]
read_thread: Optional[threading.Thread]
write_thread: Optional[threading.Thread]
def __init__(self) -> None: def __init__(self) -> None:
self.isatty = False self.isatty = False
self.infd = 0 self.infd = 0
self.outfd = 0 self.outfd = 0
self.settings = [] self.settings = []
self.read_thread = None
self.write_thread = None
def __enter__(self) -> 'TtyRaw': def __enter__(self) -> 'TtyRaw':
if sys.stdin.isatty(): if sys.stdin.isatty():
@ -59,6 +65,7 @@ class TtyRaw:
exc_traceback: Optional[TracebackType]) -> Literal[False]: exc_traceback: Optional[TracebackType]) -> Literal[False]:
if self.isatty: if self.isatty:
termios.tcsetattr(self.infd, termios.TCSADRAIN, self.settings) termios.tcsetattr(self.infd, termios.TCSADRAIN, self.settings)
# unset nonblocking modes # unset nonblocking modes
flags = fcntl.fcntl(self.infd, fcntl.F_GETFL) flags = fcntl.fcntl(self.infd, fcntl.F_GETFL)
flags = flags & (~os.O_NONBLOCK) flags = flags & (~os.O_NONBLOCK)
@ -68,16 +75,102 @@ class TtyRaw:
fcntl.fcntl(self.outfd, fcntl.F_SETFL, flags) fcntl.fcntl(self.outfd, fcntl.F_SETFL, flags)
return False return False
async def setup_async(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
# this is some UNIX bullshit honestly
# if a stdio stream is a tty we assume we have full control over it and thus can set it
# into nonblocking mode (optimal) and do real asyncio
# if it's _not_ a tty, it could be a pipe, and in the UNIX piping scheme, programs really
# don't like it when you set a pipe to nonblocking and they're not expecting that
# therefore in that case we have this really ugly workaround that offloads stdio to thread
# executors which sucks a lot. but, it means megacom now works as expected with regular
# UNIX shell commands
# annoyingly, shell piping is when you'd actually want the higher performance allowed by
# using asyncio, so if you have some very spammy output you'd like to capture the best thing
# is to keep it on the tty and use -l to also save it to a log file
async def setup_async() -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if sys.stdin.isatty():
reader = asyncio.StreamReader() reader = asyncio.StreamReader()
reader_protocol = asyncio.StreamReaderProtocol(reader) reader_protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin.buffer)
else:
reader = asyncio.StreamReader()
def reader_thread():
while True:
# XXX: if this is higher than 1, python treats it like it's a
# "read up to N bytes" thing and will buffer input indefinitely. idk how to
# turn that behavior off
buf = sys.stdin.buffer.read(1)
if len(buf) == 0:
break
loop.call_soon_threadsafe(lambda buf: reader.feed_data(buf), buf)
loop.call_soon_threadsafe(lambda: reader.feed_eof())
self.read_thread = threading.Thread(target=reader_thread, daemon=True)
self.read_thread.start()
if sys.stdout.isatty():
writer_transport, writer_protocol = await loop.connect_write_pipe( writer_transport, writer_protocol = await loop.connect_write_pipe(
asyncio.streams.FlowControlMixin, sys.stdout.buffer) asyncio.streams.FlowControlMixin, sys.stdout.buffer)
writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, loop) writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, loop)
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin.buffer) else:
class ThreadWriter(asyncio.StreamWriter):
__slots__ = ["transport", "queue", "_is_closing", "close_evt"]
transport: asyncio.BaseTransport
queue: queue.Queue[Optional[bytes]]
_is_closing: bool
close_evt: threading.Event
def __init__(self):
self.transport = None
self.queue = queue.Queue()
self._is_closing = False
self.close_evt = threading.Event()
def write(self, data: bytes) -> None:
self.queue.put(data)
def writelines(self, data: Iterable[bytes]) -> None:
for d in data:
self.queue.put(d)
def close(self) -> None:
if not self._is_closing:
self._is_closing = True
self.queue.put(None)
def can_write_eof(self) -> bool:
return True
def write_eof(self) -> None:
self.close()
def get_extra_info(self, name: str, default: Any = None) -> Any:
return default
async def drain(self) -> None:
await asyncio.to_thread(lambda: self.queue.join())
await asyncio.to_thread(lambda: sys.stdout.buffer.flush())
def is_closing(self) -> bool:
return self._is_closing
async def wait_closed(self) -> None:
await asyncio.to_thread(lambda: self.close_evt.wait())
def _writing_thread(self):
while True:
data = self.queue.get()
if data is None:
break
sys.stdout.buffer.write(data)
self.queue.task_done()
sys.stdout.buffer.flush()
self.close_evt.set()
writer = ThreadWriter()
self.write_thread = threading.Thread(target=writer._writing_thread, daemon=True)
self.write_thread.start()
return (reader, writer) return (reader, writer)
@ -123,7 +216,7 @@ class KeycodeHandler:
async def megacom(ttyraw: TtyRaw, tty: str, baud: int, mode: str, logfile: Optional[str]) -> None: async def megacom(ttyraw: TtyRaw, tty: str, baud: int, mode: str, logfile: Optional[str]) -> None:
(stdin, stdout) = await setup_async() (stdin, stdout) = await ttyraw.setup_async()
m = MODE_RE.match(mode) m = MODE_RE.match(mode)
if m is None: if m is None:
@ -198,8 +291,11 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter
await stdout.drain() await stdout.drain()
return return
stdout.write(f"megacom connected to {tty}\r\n".encode()) def print_conn():
await stdout.drain() sys.stderr.write(f"megacom connected to {tty}\r\n")
sys.stderr.flush()
status_task = asyncio.to_thread(print_conn)
async def connect_pipe(pin: asyncio.StreamReader, pout: asyncio.StreamWriter, async def connect_pipe(pin: asyncio.StreamReader, pout: asyncio.StreamWriter,
ctrl: bool = False) -> None: ctrl: bool = False) -> None:
@ -209,14 +305,10 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter
continue continue
if ctrl: if ctrl:
# stdout.write(f"\r\nin char: {c}\r\n".encode())
# await stdout.drain()
c = keycodes.process(c) c = keycodes.process(c)
if len(c) == 0: if len(c) == 0:
continue continue
else: else:
# stdout.write(f"\r\nout char: {c}\r\n".encode())
# await stdout.drain()
if log is not None: if log is not None:
log.write(c) log.write(c)
@ -227,6 +319,8 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter
serial_to_stdout: asyncio.Task = asyncio.create_task(connect_pipe(serialin, stdout)) serial_to_stdout: asyncio.Task = asyncio.create_task(connect_pipe(serialin, stdout))
time_to_exit: asyncio.Task = asyncio.create_task(keycodes.exit_flag.wait()) time_to_exit: asyncio.Task = asyncio.create_task(keycodes.exit_flag.wait())
await status_task
do_retry = False do_retry = False
def handle_done(task): def handle_done(task):