Rename all crate directories, package names, binary names, proto package/module paths, ALPN strings, env var prefixes, config filenames, mDNS service names, and plugin ABI symbols from quicproquo/qpq to quicprochat/qpc.
182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
"""QUIC transport using aioquic for the v2 wire format.
|
|
|
|
Opens a QUIC connection to the qpq server and provides ``rpc()`` to send
|
|
protobuf-encoded requests over individual QUIC streams, reading back the
|
|
framed response on the same stream.
|
|
|
|
aioquic is imported lazily so that the module can be loaded even when
|
|
aioquic is not installed (e.g. for tests that only exercise wire/proto).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import ssl
|
|
from typing import Any
|
|
|
|
from quicproquo.types import ConnectionError, TimeoutError
|
|
from quicproquo.wire import HEADER_SIZE, encode_frame, decode_header
|
|
|
|
|
|
def _make_protocol_class() -> type:
    """Build the QUIC protocol class at call time so aioquic is imported lazily.

    Returns:
        A ``QuicConnectionProtocol`` subclass whose ``wait_for_stream()``
        returns a future resolved with the first complete frame
        (``HEADER_SIZE`` header + announced payload) received on a stream.
        The future fails with ``ConnectionError`` if the stream or the whole
        connection ends before such a frame arrives.
    """
    from aioquic.asyncio.protocol import QuicConnectionProtocol
    from aioquic.quic.events import ConnectionTerminated, QuicEvent, StreamDataReceived

    def _pop_frame(buf: bytearray) -> bytes | None:
        """Remove and return one complete frame from the front of *buf*.

        Returns ``None`` and leaves *buf* untouched while the header, or the
        payload length it announces, has not fully arrived yet.
        """
        if len(buf) < HEADER_SIZE:
            return None
        _, _, payload_len = decode_header(bytes(buf[:HEADER_SIZE]))
        total = HEADER_SIZE + payload_len
        if len(buf) < total:
            return None
        frame = bytes(buf[:total])
        del buf[:total]
        return frame

    class _QpqQuicProtocol(QuicConnectionProtocol):
        """QUIC protocol handler that dispatches stream data to waiting futures."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # Received bytes not yet returned as a frame, keyed by stream id.
            self._stream_buffers: dict[int, bytearray] = {}
            # Pending response futures, keyed by stream id.
            self._stream_waiters: dict[int, asyncio.Future[bytes]] = {}
            # Set once the connection has terminated so later waiters fail fast.
            self._terminated_reason: str | None = None

        def quic_event_received(self, event: QuicEvent) -> None:
            if isinstance(event, StreamDataReceived):
                # Handled here instead of by the base class so the data is not
                # additionally buffered in the base class's stream readers.
                self._on_stream_data(event)
                return
            if isinstance(event, ConnectionTerminated):
                self._on_terminated(event)
            super().quic_event_received(event)

        def _on_stream_data(self, event: StreamDataReceived) -> None:
            """Buffer stream bytes and resolve the stream's waiter on a full frame."""
            sid = event.stream_id
            buf = self._stream_buffers.setdefault(sid, bytearray())
            buf.extend(event.data)

            frame = _pop_frame(buf)
            if frame is not None:
                waiter = self._stream_waiters.pop(sid, None)
                # A frame that nobody is waiting for is discarded.
                if waiter is not None and not waiter.done():
                    waiter.set_result(frame)

            if event.end_stream:
                # The peer sends nothing more on this stream: drop its state and
                # fail any waiter that can no longer be satisfied.
                self._stream_buffers.pop(sid, None)
                waiter = self._stream_waiters.pop(sid, None)
                if waiter is not None and not waiter.done():
                    waiter.set_exception(
                        ConnectionError(
                            f"stream {sid} closed before a complete response "
                            "frame was received"
                        )
                    )
            elif not buf:
                # Don't keep an empty buffer around for every finished stream.
                del self._stream_buffers[sid]

        def _on_terminated(self, event: ConnectionTerminated) -> None:
            """Fail every pending waiter once the connection is gone."""
            self._terminated_reason = (
                f"connection terminated (error_code={event.error_code}): "
                f"{event.reason_phrase}"
            )
            waiters = list(self._stream_waiters.values())
            self._stream_waiters.clear()
            self._stream_buffers.clear()
            for waiter in waiters:
                if not waiter.done():
                    waiter.set_exception(ConnectionError(self._terminated_reason))

        def wait_for_stream(self, stream_id: int) -> asyncio.Future[bytes]:
            """Return a future for the next complete frame on *stream_id*.

            If a complete frame is already buffered, the future is returned
            already resolved. If the connection has terminated, it is returned
            already failed with ``ConnectionError``.
            """
            fut: asyncio.Future[bytes] = asyncio.get_running_loop().create_future()
            if self._terminated_reason is not None:
                fut.set_exception(ConnectionError(self._terminated_reason))
                return fut

            buf = self._stream_buffers.get(stream_id)
            frame = _pop_frame(buf) if buf is not None else None
            if frame is not None:
                if not buf:
                    del self._stream_buffers[stream_id]
                fut.set_result(frame)
                return fut

            self._stream_waiters[stream_id] = fut

            def _forget(done: asyncio.Future[bytes]) -> None:
                # Runs however the future completes (result, error, or the
                # cancellation caused by a caller-side timeout), so abandoned
                # waiters do not accumulate.
                if self._stream_waiters.get(stream_id) is done:
                    del self._stream_waiters[stream_id]
                if done.cancelled():
                    self._stream_buffers.pop(stream_id, None)

            fut.add_done_callback(_forget)
            return fut

    return _QpqQuicProtocol
|
|
|
|
|
|
class QuicTransport:
    """Async QUIC transport for the qpq v2 wire format.

    Each ``rpc()`` call sends one encoded frame on a new QUIC stream and waits
    for one framed response on that same stream.

    Usage::

        transport = await QuicTransport.connect("127.0.0.1:5001")
        response_bytes = await transport.rpc(method_id, request_payload)
        await transport.aclose()  # or transport.close() from sync code

    The transport is also an async context manager::

        async with await QuicTransport.connect("127.0.0.1:5001") as transport:
            ...
    """

    # Port used when the address given to connect() does not include one.
    _DEFAULT_PORT = 5001

    def __init__(
        self,
        protocol: Any,
        connection: Any,
        request_timeout_ms: int,
    ) -> None:
        """Wrap an established connection.

        Args:
            protocol: The connected protocol instance (built by
                ``_make_protocol_class()``).
            connection: The async context manager that produced *protocol*;
                ``aclose()`` exits it if it has ``__aexit__``.
            request_timeout_ms: Per-RPC timeout in milliseconds.
        """
        self._protocol = protocol
        self._connection = connection
        self._req_id = 0
        self._request_timeout = request_timeout_ms / 1000.0
        self._closed = False

    @staticmethod
    def _split_addr(addr: str, default_port: int = _DEFAULT_PORT) -> tuple[str, int]:
        """Split ``host[:port]`` into ``(host, port)``.

        IPv6 literals may be bracketed (``[::1]:5001``, ``[::1]``) or bare
        (``::1``). A bare literal always uses *default_port*, because its
        last colon cannot be told apart from a port separator.

        Raises:
            ValueError: if the port is not an integer or the brackets are
                malformed.
        """
        if addr.startswith("["):
            host, sep, rest = addr[1:].partition("]")
            if not sep or (rest and not rest.startswith(":")):
                raise ValueError(f"invalid address: {addr!r}")
            return host, int(rest[1:]) if rest else default_port
        if addr.count(":") > 1:
            return addr, default_port
        host, _, port_str = addr.rpartition(":")
        if not host:
            return addr, default_port
        return host, int(port_str)

    @classmethod
    async def connect(
        cls,
        addr: str,
        *,
        ca_cert_path: str = "",
        server_name: str = "",
        insecure_skip_verify: bool = False,
        connect_timeout_ms: int = 5_000,
        request_timeout_ms: int = 10_000,
    ) -> "QuicTransport":
        """Open a QUIC connection to the server.

        Args:
            addr: ``host[:port]``; the port defaults to 5001.
            ca_cert_path: CA bundle to verify the server certificate with.
            server_name: TLS server name; defaults to the host part of *addr*.
            insecure_skip_verify: Disable certificate verification entirely.
            connect_timeout_ms: Timeout for establishing the connection.
            request_timeout_ms: Per-RPC timeout used by ``rpc()``.

        Raises:
            ConnectionError: if the connection cannot be established in time.
        """
        from aioquic.asyncio import connect as quic_connect
        from aioquic.quic.configuration import QuicConfiguration

        host, port = cls._split_addr(addr)

        # NOTE(review): the ALPN value must match what the server advertises.
        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["qpq"],
        )
        if insecure_skip_verify:
            configuration.verify_mode = ssl.CERT_NONE
        elif ca_cert_path:
            configuration.load_verify_locations(ca_cert_path)
        # aioquic's connect() has no server_name argument; the name is read
        # from the configuration instead.
        configuration.server_name = server_name or host

        # aioquic's connect() is an async context manager. Enter it manually
        # so the connection outlives this call; aclose() exits it later.
        context = quic_connect(
            host,
            port,
            configuration=configuration,
            create_protocol=_make_protocol_class(),
        )
        try:
            async with asyncio.timeout(connect_timeout_ms / 1000.0):
                protocol = await context.__aenter__()
        except (OSError, asyncio.TimeoutError) as exc:
            # OSError also covers the builtin ConnectionError that aioquic
            # raises when the handshake fails.
            raise ConnectionError(f"failed to connect to {addr}: {exc}") from exc

        return cls(protocol, context, request_timeout_ms)

    async def rpc(self, method_id: int, payload: bytes) -> bytes:
        """Send an RPC request and return the response payload (protobuf bytes).

        Opens a new QUIC stream for each request.

        Raises:
            ConnectionError: if the transport is closed.
            TimeoutError: if no complete response arrives within the request
                timeout.
        """
        if self._closed:
            raise ConnectionError("transport is closed")

        self._req_id += 1
        req_id = self._req_id
        frame = encode_frame(method_id, req_id, payload)

        quic = self._protocol._quic
        stream_id = quic.get_next_available_stream_id()
        # Register the waiter before the request is sent.
        waiter = self._protocol.wait_for_stream(stream_id)
        quic.send_stream_data(stream_id, frame, end_stream=True)
        self._protocol.transmit()

        try:
            async with asyncio.timeout(self._request_timeout):
                response_frame = await waiter
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"RPC timeout for method {method_id} (req_id={req_id})"
            ) from exc

        # Decode exactly the header bytes, as the stream parser does.
        _, _, resp_len = decode_header(response_frame[:HEADER_SIZE])
        return response_frame[HEADER_SIZE : HEADER_SIZE + resp_len]

    @property
    def closed(self) -> bool:
        """Whether ``close()`` / ``aclose()`` has been called."""
        return self._closed

    def close(self) -> None:
        """Close the QUIC connection (idempotent).

        This only queues and transmits the connection close. Use ``aclose()``
        from async code to also exit the aioquic connection context.
        """
        if not self._closed:
            self._closed = True
            self._protocol._quic.close()
            self._protocol.transmit()

    async def aclose(self) -> None:
        """Close the connection and exit the context entered by ``connect()``.

        Safe to call more than once; the context is exited only once.
        """
        self.close()
        connection, self._connection = self._connection, None
        aexit = getattr(connection, "__aexit__", None)
        if aexit is not None:
            await aexit(None, None, None)

    async def __aenter__(self) -> "QuicTransport":
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self.aclose()
|