diff --git a/megacom/__init__.py b/megacom/__init__.py index 8636cdb..f48a476 100644 --- a/megacom/__init__.py +++ b/megacom/__init__.py @@ -8,9 +8,11 @@ import re import signal import sys import termios +import threading import tty +import queue from types import TracebackType -from typing import Any, BinaryIO, List, Optional, Tuple, Type +from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Type from typing_extensions import Literal import serial @@ -33,17 +35,21 @@ MODE_LOOKUP = { class TtyRaw: - __slots__ = ["isatty", "infd", "outfd", "settings"] + __slots__ = ["isatty", "infd", "outfd", "settings", "read_thread", "write_thread"] isatty: bool infd: int outfd: int settings: List[Any] + read_thread: Optional[threading.Thread] + write_thread: Optional[threading.Thread] def __init__(self) -> None: self.isatty = False self.infd = 0 self.outfd = 0 self.settings = [] + self.read_thread = None + self.write_thread = None def __enter__(self) -> 'TtyRaw': if sys.stdin.isatty(): @@ -59,6 +65,7 @@ class TtyRaw: exc_traceback: Optional[TracebackType]) -> Literal[False]: if self.isatty: termios.tcsetattr(self.infd, termios.TCSADRAIN, self.settings) + # unset nonblocking modes flags = fcntl.fcntl(self.infd, fcntl.F_GETFL) flags = flags & (~os.O_NONBLOCK) @@ -68,17 +75,103 @@ class TtyRaw: fcntl.fcntl(self.outfd, fcntl.F_SETFL, flags) return False + async def setup_async(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + # this is some UNIX bullshit honestly + # if a stdio stream is a tty we assume we have full control over it and thus can set it + # into nonblocking mode (optimal) and do real asyncio + # if it's _not_ a tty, it could be a pipe, and in the UNIX piping scheme, programs really + # don't like it when you set a pipe to nonblocking and they're not expecting that + # therefore in that case we have this really ugly workaround that offloads stdio to thread + # executors which sucks a lot. but, it means megacom now works as expected with regular + # UNIX shell commands + # annoyingly, shell piping is when you'd actually want the higher performance allowed by + # using asyncio, so if you have some very spammy output you'd like to capture the best thing + # is to keep it on the tty and use -l to also save it to a log file -async def setup_async() -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: - loop = asyncio.get_event_loop() - reader = asyncio.StreamReader() - reader_protocol = asyncio.StreamReaderProtocol(reader) + loop = asyncio.get_event_loop() + if sys.stdin.isatty(): + reader = asyncio.StreamReader() + reader_protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin.buffer) + else: + reader = asyncio.StreamReader() + def reader_thread(): + while True: + # XXX: if this is higher than 1, python treats it like it's a + # "read up to N bytes" thing and will buffer input indefinitely. idk how to + # turn that behavior off + buf = sys.stdin.buffer.read(1) + if len(buf) == 0: + break + loop.call_soon_threadsafe(lambda buf: reader.feed_data(buf), buf) + loop.call_soon_threadsafe(lambda: reader.feed_eof()) + self.read_thread = threading.Thread(target=reader_thread, daemon=True) + self.read_thread.start() - writer_transport, writer_protocol = await loop.connect_write_pipe( - asyncio.streams.FlowControlMixin, sys.stdout.buffer) - writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, loop) - await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin.buffer) - return (reader, writer) + if sys.stdout.isatty(): + writer_transport, writer_protocol = await loop.connect_write_pipe( + asyncio.streams.FlowControlMixin, sys.stdout.buffer) + writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, loop) + else: + class ThreadWriter(asyncio.StreamWriter): + __slots__ = ["transport", "queue", "_is_closing", "close_evt"] + transport: asyncio.BaseTransport + queue: queue.Queue[Optional[bytes]] + _is_closing: bool + close_evt: threading.Event + + def __init__(self): + self.transport = None + self.queue = queue.Queue() + self._is_closing = False + self.close_evt = threading.Event() + + def write(self, data: bytes) -> None: + self.queue.put(data) + + def writelines(self, data: Iterable[bytes]) -> None: + for d in data: + self.queue.put(d) + + def close(self) -> None: + if not self._is_closing: + self._is_closing = True + self.queue.put(None) + + def can_write_eof(self) -> bool: + return True + + def write_eof(self) -> None: + self.close() + + def get_extra_info(self, name: str, default: Any = None) -> Any: + return default + + async def drain(self) -> None: + await asyncio.to_thread(lambda: self.queue.join()) + await asyncio.to_thread(lambda: sys.stdout.buffer.flush()) + + def is_closing(self) -> bool: + return self._is_closing + + async def wait_closed(self) -> None: + await asyncio.to_thread(lambda: self.close_evt.wait()) + + def _writing_thread(self): + while True: + data = self.queue.get() + if data is None: + break + sys.stdout.buffer.write(data) + self.queue.task_done() + sys.stdout.buffer.flush() + self.close_evt.set() + + writer = ThreadWriter() + self.write_thread = threading.Thread(target=writer._writing_thread, daemon=True) + self.write_thread.start() + + return (reader, writer) # CTRL-A @@ -123,7 +216,7 @@ class KeycodeHandler: async def megacom(ttyraw: TtyRaw, tty: str, baud: int, mode: str, logfile: Optional[str]) -> None: - (stdin, stdout) = await setup_async() + (stdin, stdout) = await ttyraw.setup_async() m = MODE_RE.match(mode) if m is None: @@ -198,8 +291,11 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter await stdout.drain() return - stdout.write(f"megacom connected to {tty}\r\n".encode()) - await stdout.drain() + def print_conn(): + sys.stderr.write(f"megacom connected to {tty}\r\n") + sys.stderr.flush() + + status_task = asyncio.to_thread(print_conn) async def connect_pipe(pin: asyncio.StreamReader, pout: asyncio.StreamWriter, ctrl: bool = False) -> None: @@ -209,14 +305,10 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter continue if ctrl: - # stdout.write(f"\r\nin char: {c}\r\n".encode()) - # await stdout.drain() c = keycodes.process(c) if len(c) == 0: continue else: - # stdout.write(f"\r\nout char: {c}\r\n".encode()) - # await stdout.drain() if log is not None: log.write(c) @@ -227,6 +319,8 @@ async def megacom_main(stdin: asyncio.StreamReader, stdout: asyncio.StreamWriter serial_to_stdout: asyncio.Task = asyncio.create_task(connect_pipe(serialin, stdout)) time_to_exit: asyncio.Task = asyncio.create_task(keycodes.exit_flag.wait()) + await status_task + do_retry = False def handle_done(task):