"""Minimal protobuf encode/decode for qpq v1 messages. Uses the ``protobuf`` library's descriptor-less encoding for simplicity. Each message is represented as a plain dict and encoded/decoded via google.protobuf helpers. This avoids requiring protoc at build time while still producing wire-compatible protobuf bytes. """ from __future__ import annotations from google.protobuf import descriptor_pb2 as _ # noqa: F401 – ensure protobuf is importable from google.protobuf.internal.encoder import _VarintBytes # type: ignore[attr-defined] from google.protobuf.internal.decoder import _DecodeVarint # type: ignore[attr-defined] # --------------------------------------------------------------------------- # Low-level protobuf helpers (wire types 0=varint, 2=length-delimited) # --------------------------------------------------------------------------- def _encode_varint_field(field_number: int, value: int) -> bytes: """Encode a varint (wire type 0) field.""" if value == 0: return b"" tag = (field_number << 3) | 0 return _VarintBytes(tag) + _VarintBytes(value) def _encode_bytes_field(field_number: int, value: bytes) -> bytes: """Encode a length-delimited (wire type 2) field.""" if not value: return b"" tag = (field_number << 3) | 2 return _VarintBytes(tag) + _VarintBytes(len(value)) + value def _encode_string_field(field_number: int, value: str) -> bytes: """Encode a string (wire type 2) field.""" return _encode_bytes_field(field_number, value.encode("utf-8")) def _decode_fields(data: bytes) -> dict[int, list[tuple[int, bytes | int]]]: """Decode a protobuf message into {field_number: [(wire_type, value), ...]}.""" fields: dict[int, list[tuple[int, bytes | int]]] = {} pos = 0 while pos < len(data): tag, pos = _DecodeVarint(data, pos) field_number = tag >> 3 wire_type = tag & 0x07 if wire_type == 0: # varint value, pos = _DecodeVarint(data, pos) fields.setdefault(field_number, []).append((wire_type, value)) elif wire_type == 2: # length-delimited length, pos = _DecodeVarint(data, 
pos) fields.setdefault(field_number, []).append((wire_type, data[pos : pos + length])) pos += length elif wire_type == 5: # 32-bit fixed fields.setdefault(field_number, []).append((wire_type, data[pos : pos + 4])) pos += 4 elif wire_type == 1: # 64-bit fixed fields.setdefault(field_number, []).append((wire_type, data[pos : pos + 8])) pos += 8 else: raise ValueError(f"unsupported wire type {wire_type}") return fields def _get_bytes(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> bytes: entries = fields.get(fn, []) if not entries: return b"" _, val = entries[0] return val if isinstance(val, bytes) else b"" def _get_string(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> str: return _get_bytes(fields, fn).decode("utf-8", errors="replace") def _get_varint(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> int: entries = fields.get(fn, []) if not entries: return 0 _, val = entries[0] return val if isinstance(val, int) else 0 def _get_bool(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> bool: return _get_varint(fields, fn) != 0 def _get_repeated_bytes(fields: dict[int, list[tuple[int, bytes | int]]], fn: int) -> list[bytes]: result: list[bytes] = [] for _, val in fields.get(fn, []): if isinstance(val, bytes): result.append(val) return result # --------------------------------------------------------------------------- # Auth # --------------------------------------------------------------------------- def encode_opaque_register_start(username: str, request: bytes) -> bytes: return _encode_string_field(1, username) + _encode_bytes_field(2, request) def decode_opaque_register_start_response(data: bytes) -> bytes: return _get_bytes(_decode_fields(data), 1) def encode_opaque_register_finish(username: str, upload: bytes, identity_key: bytes) -> bytes: return ( _encode_string_field(1, username) + _encode_bytes_field(2, upload) + _encode_bytes_field(3, identity_key) ) def decode_opaque_register_finish_response(data: bytes) -> 
bool: return _get_bool(_decode_fields(data), 1) def encode_opaque_login_start(username: str, request: bytes) -> bytes: return _encode_string_field(1, username) + _encode_bytes_field(2, request) def decode_opaque_login_start_response(data: bytes) -> bytes: return _get_bytes(_decode_fields(data), 1) def encode_opaque_login_finish(username: str, finalization: bytes, identity_key: bytes) -> bytes: return ( _encode_string_field(1, username) + _encode_bytes_field(2, finalization) + _encode_bytes_field(3, identity_key) ) def decode_opaque_login_finish_response(data: bytes) -> bytes: """Returns session_token.""" return _get_bytes(_decode_fields(data), 1) # --------------------------------------------------------------------------- # Delivery # --------------------------------------------------------------------------- def encode_enqueue( recipient_key: bytes, payload: bytes, channel_id: bytes = b"", ttl_secs: int = 0, message_id: bytes = b"", ) -> bytes: return ( _encode_bytes_field(1, recipient_key) + _encode_bytes_field(2, payload) + _encode_bytes_field(3, channel_id) + _encode_varint_field(4, ttl_secs) + _encode_bytes_field(5, message_id) ) def decode_enqueue_response(data: bytes) -> tuple[int, bytes, bool]: """Returns (seq, delivery_proof, duplicate).""" fields = _decode_fields(data) return _get_varint(fields, 1), _get_bytes(fields, 2), _get_bool(fields, 3) def encode_fetch( recipient_key: bytes, channel_id: bytes = b"", limit: int = 0, device_id: bytes = b"", ) -> bytes: return ( _encode_bytes_field(1, recipient_key) + _encode_bytes_field(2, channel_id) + _encode_varint_field(3, limit) + _encode_bytes_field(4, device_id) ) def decode_fetch_response(data: bytes) -> list[tuple[int, bytes]]: """Returns list of (seq, data) envelopes.""" fields = _decode_fields(data) envelopes: list[tuple[int, bytes]] = [] for _, val in fields.get(1, []): if isinstance(val, bytes): env_fields = _decode_fields(val) envelopes.append((_get_varint(env_fields, 1), _get_bytes(env_fields, 2))) 
return envelopes def encode_fetch_wait( recipient_key: bytes, channel_id: bytes = b"", timeout_ms: int = 5000, limit: int = 0, device_id: bytes = b"", ) -> bytes: return ( _encode_bytes_field(1, recipient_key) + _encode_bytes_field(2, channel_id) + _encode_varint_field(3, timeout_ms) + _encode_varint_field(4, limit) + _encode_bytes_field(5, device_id) ) # decode_fetch_wait_response = decode_fetch_response (same message shape) decode_fetch_wait_response = decode_fetch_response def encode_ack( recipient_key: bytes, seq_up_to: int, channel_id: bytes = b"", device_id: bytes = b"", ) -> bytes: return ( _encode_bytes_field(1, recipient_key) + _encode_bytes_field(2, channel_id) + _encode_varint_field(3, seq_up_to) + _encode_bytes_field(4, device_id) ) # --------------------------------------------------------------------------- # Channel # --------------------------------------------------------------------------- def encode_create_channel(peer_key: bytes) -> bytes: return _encode_bytes_field(1, peer_key) def decode_create_channel_response(data: bytes) -> tuple[bytes, bool]: """Returns (channel_id, was_new).""" fields = _decode_fields(data) return _get_bytes(fields, 1), _get_bool(fields, 2) # --------------------------------------------------------------------------- # User # --------------------------------------------------------------------------- def encode_resolve_user(username: str) -> bytes: return _encode_string_field(1, username) def decode_resolve_user_response(data: bytes) -> tuple[bytes, bytes]: """Returns (identity_key, inclusion_proof).""" fields = _decode_fields(data) return _get_bytes(fields, 1), _get_bytes(fields, 2) def encode_resolve_identity(identity_key: bytes) -> bytes: return _encode_bytes_field(1, identity_key) def decode_resolve_identity_response(data: bytes) -> str: """Returns username.""" return _get_string(_decode_fields(data), 1) # --------------------------------------------------------------------------- # Keys # 
# ---------------------------------------------------------------------------


def encode_upload_key_package(identity_key: bytes, package: bytes) -> bytes:
    """Encode an upload-key-package request: identity_key (1), package (2)."""
    parts = (_encode_bytes_field(1, identity_key), _encode_bytes_field(2, package))
    return b"".join(parts)


def decode_upload_key_package_response(data: bytes) -> bytes:
    """Return the bytes in field 1 of an upload-key-package response (b"" if unset)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)


def encode_fetch_key_package(identity_key: bytes) -> bytes:
    """Encode a fetch-key-package request: identity_key (1)."""
    encoded = _encode_bytes_field(1, identity_key)
    return encoded


def decode_fetch_key_package_response(data: bytes) -> bytes:
    """Return the bytes in field 1 of a fetch-key-package response (b"" if unset)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)


def encode_upload_hybrid_key(identity_key: bytes, hybrid_public_key: bytes) -> bytes:
    """Encode an upload-hybrid-key request: identity_key (1), hybrid_public_key (2)."""
    parts = (
        _encode_bytes_field(1, identity_key),
        _encode_bytes_field(2, hybrid_public_key),
    )
    return b"".join(parts)


def encode_fetch_hybrid_key(identity_key: bytes) -> bytes:
    """Encode a fetch-hybrid-key request: identity_key (1)."""
    encoded = _encode_bytes_field(1, identity_key)
    return encoded


def decode_fetch_hybrid_key_response(data: bytes) -> bytes:
    """Return the bytes in field 1 of a fetch-hybrid-key response (b"" if unset)."""
    fields = _decode_fields(data)
    return _get_bytes(fields, 1)


# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------


def encode_health() -> bytes:
    """Encode a health request; the message has no fields."""
    return bytes()


def decode_health_response(data: bytes) -> dict[str, str | int]:
    """Decode a health response into a dict.

    String keys: status (1), node_id (2), version (3), storage_backend (5).
    Integer key: uptime_secs (4, varint).  Missing fields decode to "" / 0.
    """
    fields = _decode_fields(data)
    status = _get_string(fields, 1)
    node_id = _get_string(fields, 2)
    version = _get_string(fields, 3)
    uptime = _get_varint(fields, 4)
    backend = _get_string(fields, 5)
    return {
        "status": status,
        "node_id": node_id,
        "version": version,
        "uptime_secs": uptime,
        "storage_backend": backend,
    }


# ---------------------------------------------------------------------------
# Delete Account
# ---------------------------------------------------------------------------


def encode_delete_account() -> bytes:
    """Encode a delete-account request; the message has no fields."""
    return bytes()


def decode_delete_account_response(data: bytes) -> bool:
    """Return field 1 of a delete-account response as a bool (False if unset)."""
    fields = _decode_fields(data)
    return _get_bool(fields, 1)