diff --git a/src/lean_spec/subspecs/networking/transport/quic/connection.py b/src/lean_spec/subspecs/networking/transport/quic/connection.py index 51573f5b..5d6e29f1 100644 --- a/src/lean_spec/subspecs/networking/transport/quic/connection.py +++ b/src/lean_spec/subspecs/networking/transport/quic/connection.py @@ -46,7 +46,7 @@ from ..peer_id import PeerId from .stream import QuicStream, QuicTransportError from .stream_adapter import QuicStreamAdapter -from .tls import generate_libp2p_certificate +from .tls import PeerVerificationError, generate_libp2p_certificate, verify_libp2p_certificate logger = logging.getLogger(__name__) @@ -212,6 +212,15 @@ def __init__(self, *args, **kwargs) -> None: self.handshake_complete = asyncio.Event() self._buffered_events: list[QuicEvent] = [] + # Why: libp2p QUIC mandates mutual authentication. + # The server must request the client certificate so the libp2p + # extension is available on both sides of the handshake. + # aioquic does not expose mTLS as a public option; this private + # flag is the established libp2p-on-aioquic mechanism. + quic = getattr(self, "_quic", None) + if quic is not None and not quic.configuration.is_client: + quic.tls._request_client_certificate = True + def quic_event_received(self, event: QuicEvent) -> None: """Handle QUIC events.""" if isinstance(event, HandshakeCompleted): @@ -424,22 +433,23 @@ async def connect(self, multiaddr: str) -> QuicConnection: # Wait for handshake to complete. await protocol.handshake_complete.wait() - # For now, we don't verify peer certificate (requires deeper aioquic integration). - # In production, we would extract and verify the libp2p certificate extension. + # Derive and verify the peer identity from the server's certificate. # - # Without peer ID verification, we trust the connection based on: - # - QUIC encryption (TLS 1.3) - # - The peer being at the expected address - - # Create a placeholder peer_id if we couldn't verify. - # In a real implementation, we'd extract this from the certificate. - if expected_peer_id: - peer_id = expected_peer_id - else: - # Generate a random peer ID for now. - # This is NOT correct for production but allows testing. - temp_key = IdentityKeypair.generate() - peer_id = temp_key.to_peer_id() + # Why: TLS 1.3 always delivers the server's certificate to the + # client. The libp2p extension binds the identity key to the TLS + # key; verifying the signature proves the server controls both. + peer_cert = protocol._quic.tls._peer_certificate + if peer_cert is None: + raise QuicTransportError("Peer did not present a TLS certificate") + peer_id = verify_libp2p_certificate(peer_cert) + + # Invariant: the dialed multiaddr must name the same peer we ended + # up authenticated with. + # Mismatch indicates a man-in-the-middle or a stale ENR; reject it. + if expected_peer_id is not None and peer_id != expected_peer_id: + raise QuicTransportError( + f"Peer identity mismatch: expected {expected_peer_id}, got {peer_id}" + ) conn = QuicConnection( _protocol=protocol, @@ -499,8 +509,23 @@ async def listen( # Callback to set up connection when handshake completes. # Captures this manager's state (self, on_connection, host, port). def handle_handshake(protocol_instance: LibP2PQuicProtocol) -> None: - temp_key = IdentityKeypair.generate() - remote_peer_id = temp_key.to_peer_id() + # Derive and verify the client's peer identity from the cert it + # sent during the mutual-TLS handshake. + # Drop the connection if the cert is missing or the libp2p + # extension fails to verify. + peer_cert = protocol_instance._quic.tls._peer_certificate + if peer_cert is None: + logger.warning("Inbound connection without client certificate, closing") + protocol_instance._quic.close() + protocol_instance.transmit() + return + try: + remote_peer_id = verify_libp2p_certificate(peer_cert) + except PeerVerificationError as exc: + logger.warning("Rejecting inbound connection: %s", exc) + protocol_instance._quic.close() + protocol_instance.transmit() + return remote_addr = f"/ip4/{host}/udp/{port}/quic-v1/p2p/{remote_peer_id}" conn = QuicConnection( diff --git a/src/lean_spec/subspecs/networking/transport/quic/tls.py b/src/lean_spec/subspecs/networking/transport/quic/tls.py index 76783406..64b5182e 100644 --- a/src/lean_spec/subspecs/networking/transport/quic/tls.py +++ b/src/lean_spec/subspecs/networking/transport/quic/tls.py @@ -30,8 +30,10 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec -from ..identity import IdentityKeypair -from ..peer_id import KeyType +from lean_spec.subspecs.networking import varint + +from ..identity import IdentityKeypair, Secp256k1PublicKey +from ..peer_id import KeyType, PeerId, PublicKeyProto LIBP2P_EXTENSION_OID: Final = x509.ObjectIdentifier("1.3.6.1.4.1.53594.1.1") """libp2p TLS extension OID (Protocol Labs assigned).""" @@ -222,3 +224,194 @@ def _encode_asn1_length(length: int) -> bytes: return bytes([0x81, length]) else: return bytes([0x82, length >> 8, length & 0xFF]) + + +class PeerVerificationError(Exception): + """Raised when the libp2p TLS extension fails to validate.""" + + +def verify_libp2p_certificate(cert: x509.Certificate) -> PeerId: + """ + Extract and verify the peer identity from a libp2p TLS certificate. + + The certificate carries the libp2p extension (OID 1.3.6.1.4.1.53594.1.1). + The extension contains a SignedKey envelope binding the peer's identity + public key to the TLS public key: + + SignedKey ::= SEQUENCE { + publicKey OCTET STRING, -- protobuf-encoded PublicKey + signature OCTET STRING -- ECDSA over prefix || tls_pub_der + } + + The signature proves the identity-key holder controls the TLS key. + + Args: + cert: Peer's leaf X.509 certificate from the QUIC handshake. + + Returns: + Canonical PeerId derived from the verified identity public key. + + Raises: + PeerVerificationError: if the extension is missing, malformed, the + signature does not verify, or the key type is unsupported. + """ + # Locate the libp2p extension. + try: + ext = cert.extensions.get_extension_for_oid(LIBP2P_EXTENSION_OID) + except x509.ExtensionNotFound as exc: + raise PeerVerificationError("libp2p extension missing from certificate") from exc + + if not isinstance(ext.value, x509.UnrecognizedExtension): + raise PeerVerificationError("libp2p extension has unexpected type") + + # Parse the ASN.1 SignedKey envelope. + public_key_proto, signature = _parse_asn1_signed_key(ext.value.value) + + # Decode the protobuf PublicKey. + key_type, key_data = _decode_protobuf_public_key(public_key_proto) + + # The spec allows multiple key types, but only secp256k1 is used today. + if key_type != KeyType.SECP256K1: + raise PeerVerificationError(f"unsupported libp2p key type: {key_type}") + + # Reconstruct the identity public key from the compressed point. + try: + identity_key = Secp256k1PublicKey( + _key=ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256K1(), key_data) + ) + except ValueError as exc: + raise PeerVerificationError(f"invalid secp256k1 public key: {exc}") from exc + + # Extract the TLS public key as SubjectPublicKeyInfo DER. + # + # The signature was computed over (prefix || SubjectPublicKeyInfo DER) + # by the certificate generator, so we must reproduce the same bytes here. + tls_public_bytes = cert.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + if not identity_key.verify(SIGNATURE_PREFIX + tls_public_bytes, signature): + raise PeerVerificationError("libp2p extension signature failed verification") + + # Derive the canonical PeerId via the protobuf-encoded identity key. + return PeerId.from_public_key(PublicKeyProto(key_type=key_type, key_data=key_data)) + + +def _parse_asn1_signed_key(payload: bytes) -> tuple[bytes, bytes]: + """ + Parse the SignedKey envelope. + + Mirrors _encode_asn1_signed_key. Accepts the three DER length forms used + by the encoder (short, 0x81, 0x82) and rejects any other form. + + Returns: + (public_key_proto_bytes, signature_bytes) tuple. + + Raises: + PeerVerificationError: if the envelope is malformed. + """ + # Outer SEQUENCE. + body, rest = _parse_asn1_tlv(payload, expected_tag=0x30) + if rest: + raise PeerVerificationError("trailing bytes after SignedKey SEQUENCE") + + # Two OCTET STRINGs inside the SEQUENCE. + public_key_proto, after_first = _parse_asn1_tlv(body, expected_tag=0x04) + signature, after_second = _parse_asn1_tlv(after_first, expected_tag=0x04) + if after_second: + raise PeerVerificationError("trailing bytes after SignedKey contents") + + return public_key_proto, signature + + +def _parse_asn1_tlv(data: bytes, *, expected_tag: int) -> tuple[bytes, bytes]: + """ + Parse a single ASN.1 TLV with the given tag and return (value, rest). + + Raises: + PeerVerificationError: if the data is truncated, the tag mismatches, + or the length uses an unsupported encoding form. + """ + if not data: + raise PeerVerificationError("ASN.1 TLV truncated at tag") + if data[0] != expected_tag: + raise PeerVerificationError( + f"ASN.1 tag mismatch: expected 0x{expected_tag:02x}, got 0x{data[0]:02x}" + ) + + length, length_size = _decode_asn1_length(data[1:]) + start = 1 + length_size + end = start + length + if end > len(data): + raise PeerVerificationError("ASN.1 TLV truncated at value") + return data[start:end], data[end:] + + +def _decode_asn1_length(data: bytes) -> tuple[int, int]: + """ + Decode an ASN.1 DER length and return (length, bytes_consumed). + + Only the three forms emitted by _encode_asn1_length are accepted: + - short form (length < 128, single byte) + - long-1 form (0x81 + 1 byte) + - long-2 form (0x82 + 2 bytes) + """ + if not data: + raise PeerVerificationError("ASN.1 length truncated") + + first = data[0] + if first < 0x80: + return first, 1 + if first == 0x81: + if len(data) < 2: + raise PeerVerificationError("ASN.1 long-1 length truncated") + return data[1], 2 + if first == 0x82: + if len(data) < 3: + raise PeerVerificationError("ASN.1 long-2 length truncated") + return (data[1] << 8) | data[2], 3 + raise PeerVerificationError(f"unsupported ASN.1 length form: 0x{first:02x}") + + +def _decode_protobuf_public_key(payload: bytes) -> tuple[KeyType, bytes]: + """ + Decode the protobuf PublicKey message. + + Wire format (deterministic encoding, fields in tag order): + + [0x08][type_varint][0x12][length_varint][key_bytes] + + Returns: + (key_type, key_data) tuple. + + Raises: + PeerVerificationError: on truncation, missing field, or unknown tag. + """ + if len(payload) < 2 or payload[0] != 0x08: + raise PeerVerificationError("protobuf PublicKey: missing Type tag") + try: + type_value, type_size = varint.decode_varint(payload, offset=1) + except varint.VarintError as exc: + raise PeerVerificationError(f"protobuf PublicKey Type: {exc}") from exc + + try: + key_type = KeyType(type_value) + except ValueError as exc: + raise PeerVerificationError(f"protobuf PublicKey: unknown KeyType {type_value}") from exc + + data_start = 1 + type_size + if data_start >= len(payload) or payload[data_start] != 0x12: + raise PeerVerificationError("protobuf PublicKey: missing Data tag") + try: + data_length, length_size = varint.decode_varint(payload, offset=data_start + 1) + except varint.VarintError as exc: + raise PeerVerificationError(f"protobuf PublicKey Data length: {exc}") from exc + + key_start = data_start + 1 + length_size + key_end = key_start + data_length + if key_end > len(payload): + raise PeerVerificationError("protobuf PublicKey: Data truncated") + if key_end != len(payload): + raise PeerVerificationError("protobuf PublicKey: trailing bytes after Data") + return key_type, payload[key_start:key_end] diff --git a/tests/lean_spec/subspecs/networking/transport/quic/test_connection.py b/tests/lean_spec/subspecs/networking/transport/quic/test_connection.py index f019b4c5..dc7d61df 100644 --- a/tests/lean_spec/subspecs/networking/transport/quic/test_connection.py +++ b/tests/lean_spec/subspecs/networking/transport/quic/test_connection.py @@ -14,6 +14,7 @@ import pytest from lean_spec.subspecs.networking.config import LIBP2P_ALPN_PROTOCOL +from lean_spec.subspecs.networking.transport.identity.keypair import IdentityKeypair from lean_spec.subspecs.networking.transport.peer_id import PeerId from lean_spec.subspecs.networking.transport.quic.connection import ( ConnectionTerminated, @@ -31,6 +32,7 @@ QuicStreamResetError, QuicTransportError, ) +from lean_spec.subspecs.networking.transport.quic.tls import generate_libp2p_certificate from lean_spec.subspecs.networking.types import ProtocolId # --------------------------------------------------------------------------- @@ -995,10 +997,10 @@ class TestQuicConnectionManagerConnect: """Tests for outbound QUIC connection establishment. Connecting parses the multiaddr, creates a QUIC session via aioquic, - waits for the TLS handshake, and wraps the result in a QuicConnection. - If the multiaddr includes a p2p component, the expected peer ID is used - directly. Otherwise a temporary peer ID is generated (full certificate - verification is not yet implemented). + waits for the TLS handshake, verifies the libp2p extension in the + server's certificate, and wraps the result in a QuicConnection. + A peer ID in the multiaddr is matched against the verified identity + to defeat man-in-the-middle attacks. """ @pytest.fixture @@ -1019,77 +1021,87 @@ async def test_connect_non_quic_raises(self, manager: QuicConnectionManager) -> with pytest.raises(QuicTransportError, match=r"Not a QUIC multiaddr"): await manager.connect("/ip4/127.0.0.1/udp/9000") + @staticmethod + def _mock_protocol_with_cert(cert: object) -> MagicMock: + """Build a protocol mock whose TLS layer presents the given peer cert.""" + protocol = MagicMock(spec=LibP2PQuicProtocol) + protocol.handshake_complete = asyncio.Event() + protocol.handshake_complete.set() + protocol.connection = None + protocol._buffered_events = [] + protocol._replay_buffered_events = MagicMock() + protocol._quic = MagicMock() + protocol._quic.tls = MagicMock() + protocol._quic.tls._peer_certificate = cert + return protocol + @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_connect") - async def test_connect_happy_path_with_peer_id( + async def test_connect_verifies_peer_and_matches_multiaddr( self, mock_quic_connect: MagicMock, manager: QuicConnectionManager ) -> None: - """When the multiaddr includes a p2p component, that peer ID is used. - - This is the normal case — the caller knows who they're connecting to. - """ + """The verified peer ID must match the one named in the multiaddr.""" + peer_key = IdentityKeypair.generate() + _, _, peer_cert = generate_libp2p_certificate(peer_key) + expected_peer_id = peer_key.to_peer_id() - # Simulate a protocol whose TLS handshake already completed. - mock_protocol = MagicMock(spec=LibP2PQuicProtocol) - mock_protocol.handshake_complete = asyncio.Event() - mock_protocol.handshake_complete.set() - mock_protocol.connection = None - mock_protocol._buffered_events = [] - mock_protocol._replay_buffered_events = MagicMock() - - # Wire quic_connect to return our pre-configured protocol. + mock_protocol = self._mock_protocol_with_cert(peer_cert) mock_cm = AsyncMock() mock_cm.__aenter__ = AsyncMock(return_value=mock_protocol) mock_quic_connect.return_value = mock_cm - # Connect to a multiaddr that includes a peer ID. - multiaddr = "/ip4/127.0.0.1/udp/9000/quic-v1/p2p/peerB" + multiaddr = f"/ip4/127.0.0.1/udp/9000/quic-v1/p2p/{expected_peer_id}" conn = await manager.connect(multiaddr) - # The peer ID from the multiaddr is used, not a generated one. - assert conn.peer_id == PeerId.from_base58("peerB") + assert conn.peer_id == expected_peer_id assert conn.remote_addr == multiaddr - - # Buffered events from the handshake window are replayed. mock_protocol._replay_buffered_events.assert_called_once() - @patch("lean_spec.subspecs.networking.transport.quic.connection.IdentityKeypair") @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_connect") - async def test_connect_happy_path_without_peer_id( - self, - mock_quic_connect: MagicMock, - mock_identity_cls: MagicMock, - manager: QuicConnectionManager, + async def test_connect_derives_peer_from_cert_when_multiaddr_omits_it( + self, mock_quic_connect: MagicMock, manager: QuicConnectionManager ) -> None: - """Without a p2p component, a temporary peer ID is generated. + """Without a p2p component, the peer ID is taken from the verified cert.""" + peer_key = IdentityKeypair.generate() + _, _, peer_cert = generate_libp2p_certificate(peer_key) - Full peer certificate verification is not yet implemented. - This fallback allows connections to proceed during development. - """ + mock_protocol = self._mock_protocol_with_cert(peer_cert) + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_protocol) + mock_quic_connect.return_value = mock_cm + + conn = await manager.connect("/ip4/127.0.0.1/udp/9000/quic-v1") + assert conn.peer_id == peer_key.to_peer_id() - # Simulate a protocol whose TLS handshake already completed. - mock_protocol = MagicMock(spec=LibP2PQuicProtocol) - mock_protocol.handshake_complete = asyncio.Event() - mock_protocol.handshake_complete.set() - mock_protocol.connection = None - mock_protocol._buffered_events = [] - mock_protocol._replay_buffered_events = MagicMock() + @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_connect") + async def test_connect_peer_id_mismatch_raises( + self, mock_quic_connect: MagicMock, manager: QuicConnectionManager + ) -> None: + """Mismatch between multiaddr p2p component and verified cert is fatal.""" + actual_key = IdentityKeypair.generate() + _, _, peer_cert = generate_libp2p_certificate(actual_key) + other_key = IdentityKeypair.generate() - # Wire quic_connect to return our pre-configured protocol. + mock_protocol = self._mock_protocol_with_cert(peer_cert) mock_cm = AsyncMock() mock_cm.__aenter__ = AsyncMock(return_value=mock_protocol) mock_quic_connect.return_value = mock_cm - # Simulate generation of a temporary identity keypair. - temp_peer = PeerId.from_base58("tempPeer") - mock_temp_key = MagicMock() - mock_temp_key.to_peer_id.return_value = temp_peer - mock_identity_cls.generate.return_value = mock_temp_key + multiaddr = f"/ip4/127.0.0.1/udp/9000/quic-v1/p2p/{other_key.to_peer_id()}" + with pytest.raises(QuicTransportError, match=r"Peer identity mismatch"): + await manager.connect(multiaddr) - # Connect without a peer ID in the multiaddr. - conn = await manager.connect("/ip4/127.0.0.1/udp/9000/quic-v1") + @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_connect") + async def test_connect_missing_peer_cert_raises( + self, mock_quic_connect: MagicMock, manager: QuicConnectionManager + ) -> None: + """A handshake that left no peer cert is rejected up front.""" + mock_protocol = self._mock_protocol_with_cert(None) + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_protocol) + mock_quic_connect.return_value = mock_cm - # A temporary peer ID was generated and used. - assert conn.peer_id == temp_peer + with pytest.raises(QuicTransportError, match=r"did not present a TLS certificate"): + await manager.connect("/ip4/127.0.0.1/udp/9000/quic-v1") @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_connect") async def test_connect_wraps_exception( @@ -1191,27 +1203,24 @@ async def test_listen_configures_server_and_serves( @patch("lean_spec.subspecs.networking.transport.quic.connection.QuicConfiguration") @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_serve") - @patch("lean_spec.subspecs.networking.transport.quic.connection.IdentityKeypair") async def test_listen_handle_handshake_creates_connection( self, - mock_identity_cls: MagicMock, mock_quic_serve: MagicMock, mock_config_cls: MagicMock, manager_with_temp_dir: QuicConnectionManager, ) -> None: - """The handshake callback creates and registers a QuicConnection. + """The handshake callback verifies the client cert and registers the connection. - This test captures the protocol factory that listen() passes to - quic_serve, then invokes the handshake callback to verify that - a connection is correctly wired up and registered. + Captures the protocol factory that listen() passes to quic_serve, + then invokes the handshake callback with a real libp2p-TLS certificate + and asserts that the derived peer ID is used. """ mock_config_cls.return_value = MagicMock() - # Simulate generation of a remote peer identity. - temp_peer = PeerId.from_base58("remotePeer") - mock_temp_key = MagicMock() - mock_temp_key.to_peer_id.return_value = temp_peer - mock_identity_cls.generate.return_value = mock_temp_key + # Real keypair and cert for the remote (client) side of the handshake. + remote_key = IdentityKeypair.generate() + _, _, remote_cert = generate_libp2p_certificate(remote_key) + expected_peer_id = remote_key.to_peer_id() # Capture the protocol factory that listen() passes to quic_serve. captured_create_protocol = None @@ -1250,18 +1259,72 @@ async def capture_serve(*args: object, **kwargs: object) -> MagicMock: # The factory must have attached a handshake callback. assert proto_instance._on_handshake is not None - # Simulate the handshake callback being invoked by aioquic. + # Simulate the handshake callback being invoked by aioquic, with the + # client's verified certificate already attached by the TLS layer. proto_instance._quic = MagicMock() + proto_instance._quic.tls = MagicMock() + proto_instance._quic.tls._peer_certificate = remote_cert proto_instance.transmit = MagicMock() proto_instance.connection = None proto_instance._buffered_events = [] proto_instance._on_handshake(proto_instance) - # The callback creates a connection and registers it in the manager. + # The callback derives the peer ID from the cert and registers the connection. assert proto_instance.connection is not None - assert proto_instance.connection.peer_id == temp_peer - assert temp_peer in manager_with_temp_dir._connections + assert proto_instance.connection.peer_id == expected_peer_id + assert expected_peer_id in manager_with_temp_dir._connections + + @patch("lean_spec.subspecs.networking.transport.quic.connection.QuicConfiguration") + @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_serve") + async def test_listen_handle_handshake_rejects_missing_cert( + self, + mock_quic_serve: MagicMock, + mock_config_cls: MagicMock, + manager_with_temp_dir: QuicConnectionManager, + ) -> None: + """Inbound connection without a client certificate is closed and not registered.""" + mock_config_cls.return_value = MagicMock() + + captured_create_protocol = None + + async def capture_serve(*args: object, **kwargs: object) -> MagicMock: + nonlocal captured_create_protocol + captured_create_protocol = kwargs.get("create_protocol") + return MagicMock() + + mock_quic_serve.side_effect = capture_serve + + mock_event = MagicMock() + mock_event.wait = AsyncMock(side_effect=asyncio.CancelledError) + + with ( + patch( + "lean_spec.subspecs.networking.transport.quic.connection.asyncio.Event" + ) as mock_event_cls, + pytest.raises(asyncio.CancelledError), + ): + mock_event_cls.return_value = mock_event + await manager_with_temp_dir.listen("/ip4/0.0.0.0/udp/9000/quic-v1", AsyncMock()) + + assert captured_create_protocol is not None + with patch.object(LibP2PQuicProtocol.__bases__[0], "__init__", return_value=None): + proto_instance = captured_create_protocol() + + proto_instance._quic = MagicMock() + proto_instance._quic.tls = MagicMock() + proto_instance._quic.tls._peer_certificate = None + proto_instance._quic.close = MagicMock() + proto_instance.transmit = MagicMock() + proto_instance.connection = None + proto_instance._buffered_events = [] + + proto_instance._on_handshake(proto_instance) + + # No connection was registered and the QUIC session was torn down. + assert proto_instance.connection is None + assert manager_with_temp_dir._connections == {} + proto_instance._quic.close.assert_called_once() @patch("lean_spec.subspecs.networking.transport.quic.connection.QuicConfiguration") @patch("lean_spec.subspecs.networking.transport.quic.connection.quic_serve") diff --git a/tests/lean_spec/subspecs/networking/transport/quic/test_tls.py b/tests/lean_spec/subspecs/networking/transport/quic/test_tls.py index 15c6ad03..5a2153a5 100644 --- a/tests/lean_spec/subspecs/networking/transport/quic/test_tls.py +++ b/tests/lean_spec/subspecs/networking/transport/quic/test_tls.py @@ -11,10 +11,11 @@ from __future__ import annotations import hashlib +from datetime import datetime, timedelta, timezone import pytest from cryptography import x509 -from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec from lean_spec.subspecs.networking.transport.identity.keypair import IdentityKeypair @@ -22,12 +23,14 @@ from lean_spec.subspecs.networking.transport.quic.tls import ( LIBP2P_EXTENSION_OID, SIGNATURE_PREFIX, + PeerVerificationError, _create_extension_payload, _encode_asn1_length, _encode_asn1_octet_string, _encode_asn1_sequence, _encode_asn1_signed_key, generate_libp2p_certificate, + verify_libp2p_certificate, ) # --------------------------------------------------------------------------- @@ -399,3 +402,128 @@ def _parse_der_tlv_with_rest(data: bytes) -> tuple[int, bytes, bytes]: value_start = 1 + length_size value_end = value_start + length return tag, data[value_start:value_end], data[value_end:] + + +# --------------------------------------------------------------------------- +# verify_libp2p_certificate — happy path + rejection cases +# --------------------------------------------------------------------------- + + +def _cert_without_libp2p_extension() -> x509.Certificate: + """Build a self-signed cert that has no libp2p extension.""" + key = ec.generate_private_key(ec.SECP256R1()) + now = datetime.now(timezone.utc) + return ( + x509.CertificateBuilder() + .subject_name(x509.Name([])) + .issuer_name(x509.Name([])) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(days=1)) + .not_valid_after(now + timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + + +def _replace_libp2p_extension(cert: x509.Certificate, payload: bytes) -> x509.Certificate: + """Re-sign the cert with the libp2p extension replaced by the given bytes.""" + # Use the same TLS public key but a fresh private key (test-only). + # We can't re-sign with the original key (not exposed), so generate + # a new TLS keypair and rebuild the cert around the new payload. + # The verifier only inspects the libp2p extension and the cert's + # SubjectPublicKeyInfo; it does not check the cert's outer signature. + tls_key = ec.generate_private_key(ec.SECP256R1()) + now = datetime.now(timezone.utc) + return ( + x509.CertificateBuilder() + .subject_name(x509.Name([])) + .issuer_name(x509.Name([])) + .public_key(cert.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(days=1)) + .not_valid_after(now + timedelta(days=1)) + .add_extension( + x509.UnrecognizedExtension(LIBP2P_EXTENSION_OID, payload), + critical=False, + ) + .sign(tls_key, hashes.SHA256()) + ) + + +class TestVerifyLibp2pCertificate: + """Validate end-to-end extraction of peer identity from a libp2p TLS cert.""" + + def test_roundtrip_recovers_peer_id(self, identity_key: IdentityKeypair) -> None: + """A freshly generated cert verifies and yields the original PeerId.""" + _, _, cert = generate_libp2p_certificate(identity_key) + assert verify_libp2p_certificate(cert) == identity_key.to_peer_id() + + def test_distinct_keys_yield_distinct_peer_ids(self) -> None: + """Two independent identity keys produce different verified PeerIds.""" + key_a = IdentityKeypair.generate() + key_b = IdentityKeypair.generate() + _, _, cert_a = generate_libp2p_certificate(key_a) + _, _, cert_b = generate_libp2p_certificate(key_b) + + peer_a = verify_libp2p_certificate(cert_a) + peer_b = verify_libp2p_certificate(cert_b) + + assert peer_a == key_a.to_peer_id() + assert peer_b == key_b.to_peer_id() + assert peer_a != peer_b + + def test_missing_extension_is_rejected(self) -> None: + """A cert without the libp2p extension fails verification.""" + cert = _cert_without_libp2p_extension() + with pytest.raises(PeerVerificationError, match="missing"): + verify_libp2p_certificate(cert) + + def test_tampered_signature_is_rejected(self, identity_key: IdentityKeypair) -> None: + """Flipping a byte inside the SignedKey signature breaks verification.""" + _, _, cert = generate_libp2p_certificate(identity_key) + ext = cert.extensions.get_extension_for_oid(LIBP2P_EXTENSION_OID) + assert isinstance(ext.value, x509.UnrecognizedExtension) + payload = bytearray(ext.value.value) + + # The signature is the second OCTET STRING; flip its last byte. + payload[-1] ^= 0x01 + tampered = _replace_libp2p_extension(cert, bytes(payload)) + + with pytest.raises(PeerVerificationError, match="signature"): + verify_libp2p_certificate(tampered) + + def test_malformed_outer_sequence_is_rejected(self, identity_key: IdentityKeypair) -> None: + """Replacing the SEQUENCE tag with garbage fails ASN.1 parsing.""" + _, _, cert = generate_libp2p_certificate(identity_key) + ext = cert.extensions.get_extension_for_oid(LIBP2P_EXTENSION_OID) + assert isinstance(ext.value, x509.UnrecognizedExtension) + payload = bytearray(ext.value.value) + payload[0] = 0x31 # SET tag instead of SEQUENCE (0x30) + broken = _replace_libp2p_extension(cert, bytes(payload)) + + with pytest.raises(PeerVerificationError, match="tag mismatch"): + verify_libp2p_certificate(broken) + + def test_unknown_key_type_is_rejected(self, identity_key: IdentityKeypair) -> None: + """A SignedKey carrying an unknown KeyType is rejected up front.""" + # Build a SignedKey by hand whose protobuf Type field is 99 (no such key type). + public_key_compressed = identity_key.public_key.to_bytes() + bogus_proto = bytes([0x08, 99, 0x12, len(public_key_compressed)]) + public_key_compressed + signature = identity_key.sign(b"") # signature value is irrelevant here + payload = _encode_asn1_signed_key(bogus_proto, signature) + + _, _, cert = generate_libp2p_certificate(identity_key) + broken = _replace_libp2p_extension(cert, payload) + + with pytest.raises(PeerVerificationError, match="KeyType"): + verify_libp2p_certificate(broken) + + def test_truncated_signed_key_is_rejected(self, identity_key: IdentityKeypair) -> None: + """A SignedKey body shorter than its declared length is rejected.""" + _, _, cert = generate_libp2p_certificate(identity_key) + ext = cert.extensions.get_extension_for_oid(LIBP2P_EXTENSION_OID) + assert isinstance(ext.value, x509.UnrecognizedExtension) + truncated = _replace_libp2p_extension(cert, ext.value.value[:-5]) + + with pytest.raises(PeerVerificationError): + verify_libp2p_certificate(truncated)