Rename all crate directories, package names, binary names, proto package/module paths, ALPN strings, env var prefixes, config filenames, mDNS service names, and plugin ABI symbols from quicproquo/qpq to quicprochat/qpc.
304 lines
10 KiB
Python
304 lines
10 KiB
Python
"""Minimal protobuf encode/decode for qpq v1 messages.
|
||
|
||
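Messages are hand-encoded field by field, without generated classes or
descriptors. Each ``encode_*`` function takes the message fields as
arguments and returns the serialized bytes. Each ``decode_*`` function
parses bytes back into plain Python values (bytes, str, int, bool, tuples,
lists or a dict). Only the varint primitives come from ``google.protobuf``'s
internal encoder/decoder modules.

This avoids requiring protoc at build time while still producing
wire-compatible protobuf bytes. As in proto3, fields holding their default
value (0, empty bytes/string) are omitted on encode and read back as that
default on decode.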
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from google.protobuf import descriptor_pb2 as _ # noqa: F401 – ensure protobuf is importable
|
||
from google.protobuf.internal.encoder import _VarintBytes # type: ignore[attr-defined]
|
||
from google.protobuf.internal.decoder import _DecodeVarint # type: ignore[attr-defined]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Low-level protobuf helpers (wire types 0=varint, 2=length-delimited)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _encode_varint_field(field_number: int, value: int) -> bytes:
|
||
"""Encode a varint (wire type 0) field."""
|
||
if value == 0:
|
||
return b""
|
||
tag = (field_number << 3) | 0
|
||
return _VarintBytes(tag) + _VarintBytes(value)
|
||
|
||
|
||
def _encode_bytes_field(field_number: int, value: bytes) -> bytes:
|
||
"""Encode a length-delimited (wire type 2) field."""
|
||
if not value:
|
||
return b""
|
||
tag = (field_number << 3) | 2
|
||
return _VarintBytes(tag) + _VarintBytes(len(value)) + value
|
||
|
||
|
||
def _encode_string_field(field_number: int, value: str) -> bytes:
    """Serialize *value* as a UTF-8 string field (wire type 2); ``""`` -> ``b""``."""
    raw = value.encode("utf-8")
    return _encode_bytes_field(field_number, raw)
|
||
|
||
|
||
def _decode_fields(data: bytes) -> dict[int, list[tuple[int, bytes | int]]]:
|
||
"""Decode a protobuf message into {field_number: [(wire_type, value), ...]}."""
|
||
fields: dict[int, list[tuple[int, bytes | int]]] = {}
|
||
pos = 0
|
||
while pos < len(data):
|
||
tag, pos = _DecodeVarint(data, pos)
|
||
field_number = tag >> 3
|
||
wire_type = tag & 0x07
|
||
if wire_type == 0: # varint
|
||
value, pos = _DecodeVarint(data, pos)
|
||
fields.setdefault(field_number, []).append((wire_type, value))
|
||
elif wire_type == 2: # length-delimited
|
||
length, pos = _DecodeVarint(data, pos)
|
||
fields.setdefault(field_number, []).append((wire_type, data[pos : pos + length]))
|
||
pos += length
|
||
elif wire_type == 5: # 32-bit fixed
|
||
fields.setdefault(field_number, []).append((wire_type, data[pos : pos + 4]))
|
||
pos += 4
|
||
elif wire_type == 1: # 64-bit fixed
|
||
fields.setdefault(field_number, []).append((wire_type, data[pos : pos + 8]))
|
||
pos += 8
|
||
else:
|
||
raise ValueError(f"unsupported wire type {wire_type}")
|
||
return fields
|
||
|
||
|
||
def _get_bytes(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> bytes:
|
||
entries = fields.get(fn, [])
|
||
if not entries:
|
||
return b""
|
||
_, val = entries[0]
|
||
return val if isinstance(val, bytes) else b""
|
||
|
||
|
||
def _get_string(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> str:
    """Return field *fn* decoded as UTF-8 text.

    Invalid byte sequences are replaced with U+FFFD rather than raising; an
    absent field yields an empty string.
    """
    raw = _get_bytes(fields, fn)
    return raw.decode("utf-8", errors="replace")
|
||
|
||
|
||
def _get_varint(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> int:
|
||
entries = fields.get(fn, [])
|
||
if not entries:
|
||
return 0
|
||
_, val = entries[0]
|
||
return val if isinstance(val, int) else 0
|
||
|
||
|
||
def _get_bool(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> bool:
    """Return field *fn* as a bool: any non-zero varint is ``True``."""
    return bool(_get_varint(fields, fn))
|
||
|
||
|
||
def _get_repeated_bytes(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> list[bytes]:
|
||
result: list[bytes] = []
|
||
for _, val in fields.get(fn, []):
|
||
if isinstance(val, bytes):
|
||
result.append(val)
|
||
return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Auth
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_opaque_register_start(username: str, request: bytes) -> bytes:
    """Serialize a register-start request: username (field 1), request (field 2)."""
    parts = (
        _encode_string_field(1, username),
        _encode_bytes_field(2, request),
    )
    return b"".join(parts)
|
||
|
||
def decode_opaque_register_start_response(data: bytes) -> bytes:
    """Return the bytes carried in field 1 of the response (``b""`` if absent)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)
|
||
|
||
def encode_opaque_register_finish(username: str, upload: bytes, identity_key: bytes) -> bytes:
    """Serialize a register-finish request.

    Fields: 1 username (string), 2 upload, 3 identity_key (bytes).
    """
    return b"".join(
        [
            _encode_string_field(1, username),
            _encode_bytes_field(2, upload),
            _encode_bytes_field(3, identity_key),
        ]
    )
|
||
|
||
def decode_opaque_register_finish_response(data: bytes) -> bool:
    """Return the boolean flag in field 1 of the response (``False`` if absent)."""
    fields = _decode_fields(data)
    return _get_bool(fields, 1)
|
||
|
||
def encode_opaque_login_start(username: str, request: bytes) -> bytes:
    """Serialize a login-start request: username (field 1), request (field 2)."""
    parts = (
        _encode_string_field(1, username),
        _encode_bytes_field(2, request),
    )
    return b"".join(parts)
|
||
|
||
def decode_opaque_login_start_response(data: bytes) -> bytes:
    """Return the bytes carried in field 1 of the response (``b""`` if absent)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)
|
||
|
||
def encode_opaque_login_finish(username: str, finalization: bytes, identity_key: bytes) -> bytes:
    """Serialize a login-finish request.

    Fields: 1 username (string), 2 finalization, 3 identity_key (bytes).
    """
    return b"".join(
        [
            _encode_string_field(1, username),
            _encode_bytes_field(2, finalization),
            _encode_bytes_field(3, identity_key),
        ]
    )
|
||
|
||
def decode_opaque_login_finish_response(data: bytes) -> bytes:
    """Return the session token carried in field 1 (``b""`` if absent)."""
    fields = _decode_fields(data)
    session_token = _get_bytes(fields, 1)
    return session_token
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Delivery
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_enqueue(
    recipient_key: bytes,
    payload: bytes,
    channel_id: bytes = b"",
    ttl_secs: int = 0,
    message_id: bytes = b"",
) -> bytes:
    """Serialize an enqueue request.

    Fields: 1 recipient_key, 2 payload, 3 channel_id, 4 ttl_secs (varint),
    5 message_id. Optional fields left at their empty/zero default are not
    written to the wire.
    """
    return b"".join(
        [
            _encode_bytes_field(1, recipient_key),
            _encode_bytes_field(2, payload),
            _encode_bytes_field(3, channel_id),
            _encode_varint_field(4, ttl_secs),
            _encode_bytes_field(5, message_id),
        ]
    )
|
||
|
||
def decode_enqueue_response(data: bytes) -> tuple[int, bytes, bool]:
    """Decode an enqueue response into ``(seq, delivery_proof, duplicate)``."""
    fields = _decode_fields(data)
    seq = _get_varint(fields, 1)
    delivery_proof = _get_bytes(fields, 2)
    duplicate = _get_bool(fields, 3)
    return seq, delivery_proof, duplicate
|
||
|
||
def encode_fetch(
    recipient_key: bytes,
    channel_id: bytes = b"",
    limit: int = 0,
    device_id: bytes = b"",
) -> bytes:
    """Serialize a fetch request.

    Fields: 1 recipient_key, 2 channel_id, 3 limit (varint), 4 device_id.
    Empty/zero optional fields are omitted.
    """
    return b"".join(
        [
            _encode_bytes_field(1, recipient_key),
            _encode_bytes_field(2, channel_id),
            _encode_varint_field(3, limit),
            _encode_bytes_field(4, device_id),
        ]
    )
|
||
|
||
def decode_fetch_response(data: bytes) -> list[tuple[int, bytes]]:
    """Decode a fetch response into a list of ``(seq, data)`` envelopes.

    Field 1 of the response is a repeated embedded envelope message. Within
    each envelope, field 1 is the sequence number (varint) and field 2 the
    payload bytes. Envelopes are returned in wire order. Non-bytes entries
    under field 1 are ignored.
    """
    # Reuse the shared repeated-field helper instead of re-filtering by type here.
    envelopes: list[tuple[int, bytes]] = []
    for raw in _get_repeated_bytes(_decode_fields(data), 1):
        env = _decode_fields(raw)
        envelopes.append((_get_varint(env, 1), _get_bytes(env, 2)))
    return envelopes
|
||
|
||
def encode_fetch_wait(
    recipient_key: bytes,
    channel_id: bytes = b"",
    timeout_ms: int = 5000,
    limit: int = 0,
    device_id: bytes = b"",
) -> bytes:
    """Serialize a fetch-wait request.

    Fields: 1 recipient_key, 2 channel_id, 3 timeout_ms (varint, default
    5000), 4 limit (varint), 5 device_id. Empty/zero fields are omitted.
    """
    return b"".join(
        [
            _encode_bytes_field(1, recipient_key),
            _encode_bytes_field(2, channel_id),
            _encode_varint_field(3, timeout_ms),
            _encode_varint_field(4, limit),
            _encode_bytes_field(5, device_id),
        ]
    )
|
||
|
||
# Fetch-wait responses have the same shape as fetch responses (repeated
# (seq, data) envelopes in field 1), so the fetch decoder is reused as-is
# under a second name; both names refer to the same function object.
decode_fetch_wait_response = decode_fetch_response
|
||
|
||
def encode_ack(
    recipient_key: bytes,
    seq_up_to: int,
    channel_id: bytes = b"",
    device_id: bytes = b"",
) -> bytes:
    """Serialize an ack request.

    Fields: 1 recipient_key, 2 channel_id, 3 seq_up_to (varint), 4 device_id.
    Note that the wire order differs from the parameter order. Empty/zero
    fields are omitted.
    """
    return b"".join(
        [
            _encode_bytes_field(1, recipient_key),
            _encode_bytes_field(2, channel_id),
            _encode_varint_field(3, seq_up_to),
            _encode_bytes_field(4, device_id),
        ]
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Channel
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_create_channel(peer_key: bytes) -> bytes:
    """Serialize a create-channel request carrying *peer_key* in field 1."""
    encoded = _encode_bytes_field(1, peer_key)
    return encoded
|
||
|
||
def decode_create_channel_response(data: bytes) -> tuple[bytes, bool]:
    """Decode a create-channel response into ``(channel_id, was_new)``."""
    fields = _decode_fields(data)
    channel_id = _get_bytes(fields, 1)
    was_new = _get_bool(fields, 2)
    return channel_id, was_new
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# User
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_resolve_user(username: str) -> bytes:
    """Serialize a resolve-user request carrying *username* in field 1."""
    encoded = _encode_string_field(1, username)
    return encoded
|
||
|
||
def decode_resolve_user_response(data: bytes) -> tuple[bytes, bytes]:
    """Decode a resolve-user response into ``(identity_key, inclusion_proof)``."""
    fields = _decode_fields(data)
    identity_key = _get_bytes(fields, 1)
    inclusion_proof = _get_bytes(fields, 2)
    return identity_key, inclusion_proof
|
||
|
||
def encode_resolve_identity(identity_key: bytes) -> bytes:
    """Serialize a resolve-identity request carrying *identity_key* in field 1."""
    encoded = _encode_bytes_field(1, identity_key)
    return encoded
|
||
|
||
def decode_resolve_identity_response(data: bytes) -> str:
    """Return the username carried in field 1 of the response (``""`` if absent)."""
    fields = _decode_fields(data)
    username = _get_string(fields, 1)
    return username
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Keys
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_upload_key_package(identity_key: bytes, package: bytes) -> bytes:
    """Serialize an upload-key-package request: identity_key (1), package (2)."""
    parts = (
        _encode_bytes_field(1, identity_key),
        _encode_bytes_field(2, package),
    )
    return b"".join(parts)
|
||
|
||
def decode_upload_key_package_response(data: bytes) -> bytes:
    """Return the bytes carried in field 1 of the response (``b""`` if absent)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)
|
||
|
||
def encode_fetch_key_package(identity_key: bytes) -> bytes:
    """Serialize a fetch-key-package request carrying *identity_key* in field 1."""
    encoded = _encode_bytes_field(1, identity_key)
    return encoded
|
||
|
||
def decode_fetch_key_package_response(data: bytes) -> bytes:
    """Return the bytes carried in field 1 of the response (``b""`` if absent)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)
|
||
|
||
def encode_upload_hybrid_key(identity_key: bytes, hybrid_public_key: bytes) -> bytes:
    """Serialize an upload-hybrid-key request: identity_key (1), hybrid_public_key (2)."""
    parts = (
        _encode_bytes_field(1, identity_key),
        _encode_bytes_field(2, hybrid_public_key),
    )
    return b"".join(parts)
|
||
|
||
def encode_fetch_hybrid_key(identity_key: bytes) -> bytes:
    """Serialize a fetch-hybrid-key request carrying *identity_key* in field 1."""
    encoded = _encode_bytes_field(1, identity_key)
    return encoded
|
||
|
||
def decode_fetch_hybrid_key_response(data: bytes) -> bytes:
    """Return the bytes carried in field 1 of the response (``b""`` if absent)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Health
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_health() -> bytes:
    """Serialize a health request; it has no fields, so the encoding is empty."""
    return bytes()
|
||
|
||
def decode_health_response(data: bytes) -> dict[str, str | int]:
    """Decode a health response into a dict.

    Keys: ``status`` (field 1), ``node_id`` (2), ``version`` (3) and
    ``storage_backend`` (5) as strings; ``uptime_secs`` (4) as an int.
    Absent fields come back as ``""`` or ``0``.
    """
    fields = _decode_fields(data)
    status = _get_string(fields, 1)
    node_id = _get_string(fields, 2)
    version = _get_string(fields, 3)
    uptime_secs = _get_varint(fields, 4)
    storage_backend = _get_string(fields, 5)
    return {
        "status": status,
        "node_id": node_id,
        "version": version,
        "uptime_secs": uptime_secs,
        "storage_backend": storage_backend,
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Delete Account
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def encode_delete_account() -> bytes:
    """Serialize a delete-account request; it has no fields, so the encoding is empty."""
    return bytes()
|
||
|
||
def decode_delete_account_response(data: bytes) -> bool:
    """Return the boolean flag in field 1 of the response (``False`` if absent)."""
    fields = _decode_fields(data)
    return _get_bool(fields, 1)
|