"""QUIC transport using aioquic for the v2 wire format. Opens a QUIC connection to the qpq server and provides ``rpc()`` to send protobuf-encoded requests over individual QUIC streams, reading back the framed response on the same stream. aioquic is imported lazily so that the module can be loaded even when aioquic is not installed (e.g. for tests that only exercise wire/proto). """ from __future__ import annotations import asyncio import ssl from typing import Any from quicprochat.types import ConnectionError, TimeoutError from quicprochat.wire import HEADER_SIZE, encode_frame, decode_header def _make_protocol_class() -> type: """Build the protocol class at call time so aioquic is imported lazily.""" from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.quic.events import StreamDataReceived, QuicEvent class _QpqQuicProtocol(QuicConnectionProtocol): """QUIC protocol handler that dispatches stream data to waiting futures.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._stream_buffers: dict[int, bytearray] = {} self._stream_waiters: dict[int, asyncio.Future[bytes]] = {} def quic_event_received(self, event: QuicEvent) -> None: if isinstance(event, StreamDataReceived): sid = event.stream_id buf = self._stream_buffers.setdefault(sid, bytearray()) buf.extend(event.data) if len(buf) >= HEADER_SIZE: _, _, payload_len = decode_header(bytes(buf[:HEADER_SIZE])) total = HEADER_SIZE + payload_len if len(buf) >= total: frame = bytes(buf[:total]) del buf[:total] waiter = self._stream_waiters.pop(sid, None) if waiter and not waiter.done(): waiter.set_result(frame) def wait_for_stream(self, stream_id: int) -> asyncio.Future[bytes]: loop = asyncio.get_event_loop() fut: asyncio.Future[bytes] = loop.create_future() self._stream_waiters[stream_id] = fut buf = self._stream_buffers.get(stream_id, bytearray()) if len(buf) >= HEADER_SIZE: _, _, payload_len = decode_header(bytes(buf[:HEADER_SIZE])) total = HEADER_SIZE + payload_len if len(buf) >= total: frame = bytes(buf[:total]) del buf[:total] if not fut.done(): fut.set_result(frame) return fut return _QpqQuicProtocol class QuicTransport: """Async QUIC transport for the qpq v2 wire format. Usage:: transport = await QuicTransport.connect("127.0.0.1:5001") response_bytes = await transport.rpc(method_id, request_payload) transport.close() """ def __init__( self, protocol: Any, connection: Any, request_timeout_ms: int, ) -> None: self._protocol = protocol self._connection = connection self._req_id = 0 self._request_timeout = request_timeout_ms / 1000.0 self._closed = False @staticmethod async def connect( addr: str, *, ca_cert_path: str = "", server_name: str = "", insecure_skip_verify: bool = False, connect_timeout_ms: int = 5_000, request_timeout_ms: int = 10_000, ) -> "QuicTransport": """Open a QUIC connection to the server.""" from aioquic.asyncio import connect as quic_connect from aioquic.quic.configuration import QuicConfiguration host, _, port_str = addr.rpartition(":") if not host: host = addr port_str = "5001" port = int(port_str) configuration = QuicConfiguration( is_client=True, alpn_protocols=["qpq"], ) if insecure_skip_verify: configuration.verify_mode = ssl.CERT_NONE elif ca_cert_path: configuration.load_verify_locations(ca_cert_path) if not server_name: server_name = host proto_cls = _make_protocol_class() try: async with asyncio.timeout(connect_timeout_ms / 1000.0): connection = await quic_connect( host, port, configuration=configuration, create_protocol=proto_cls, server_name=server_name, ) except (OSError, asyncio.TimeoutError) as exc: raise ConnectionError(f"failed to connect to {addr}: {exc}") from exc protocol = connection._protocol # type: ignore[attr-defined] return QuicTransport(protocol, connection, request_timeout_ms) async def rpc(self, method_id: int, payload: bytes) -> bytes: """Send an RPC request and return the response payload (protobuf bytes). Opens a new QUIC stream for each request. """ if self._closed: raise ConnectionError("transport is closed") self._req_id += 1 req_id = self._req_id frame = encode_frame(method_id, req_id, payload) stream_id = self._protocol._quic.get_next_available_stream_id() waiter = self._protocol.wait_for_stream(stream_id) self._protocol._quic.send_stream_data(stream_id, frame, end_stream=True) self._protocol.transmit() try: async with asyncio.timeout(self._request_timeout): response_frame = await waiter except asyncio.TimeoutError as exc: raise TimeoutError( f"RPC timeout for method {method_id} (req_id={req_id})" ) from exc _, _, resp_len = decode_header(response_frame) return response_frame[HEADER_SIZE : HEADER_SIZE + resp_len] @property def closed(self) -> bool: return self._closed def close(self) -> None: """Close the QUIC connection.""" if not self._closed: self._closed = True self._protocol._quic.close() self._protocol.transmit()