Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions src/lean_spec/subspecs/networking/transport/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..peer_id import PeerId
from .stream import QuicStream, QuicTransportError
from .stream_adapter import QuicStreamAdapter
from .tls import generate_libp2p_certificate
from .tls import PeerVerificationError, generate_libp2p_certificate, verify_libp2p_certificate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -212,6 +212,15 @@ def __init__(self, *args, **kwargs) -> None:
self.handshake_complete = asyncio.Event()
self._buffered_events: list[QuicEvent] = []

# Why: libp2p QUIC mandates mutual authentication.
# The server must request the client certificate so the libp2p
# extension is available on both sides of the handshake.
# aioquic does not expose mTLS as a public option; this private
# flag is the established libp2p-on-aioquic mechanism.
quic = getattr(self, "_quic", None)
if quic is not None and not quic.configuration.is_client:
quic.tls._request_client_certificate = True

def quic_event_received(self, event: QuicEvent) -> None:
"""Handle QUIC events."""
if isinstance(event, HandshakeCompleted):
Expand Down Expand Up @@ -424,22 +433,23 @@ async def connect(self, multiaddr: str) -> QuicConnection:
# Wait for handshake to complete.
await protocol.handshake_complete.wait()

# For now, we don't verify peer certificate (requires deeper aioquic integration).
# In production, we would extract and verify the libp2p certificate extension.
# Derive and verify the peer identity from the server's certificate.
#
# Without peer ID verification, we trust the connection based on:
# - QUIC encryption (TLS 1.3)
# - The peer being at the expected address

# Create a placeholder peer_id if we couldn't verify.
# In a real implementation, we'd extract this from the certificate.
if expected_peer_id:
peer_id = expected_peer_id
else:
# Generate a random peer ID for now.
# This is NOT correct for production but allows testing.
temp_key = IdentityKeypair.generate()
peer_id = temp_key.to_peer_id()
# Why: TLS 1.3 always delivers the server's certificate to the
# client. The libp2p extension binds the identity key to the TLS
# key; verifying the signature proves the server controls both.
peer_cert = protocol._quic.tls._peer_certificate
if peer_cert is None:
raise QuicTransportError("Peer did not present a TLS certificate")
peer_id = verify_libp2p_certificate(peer_cert)

# Invariant: the dialed multiaddr must name the same peer we ended
# up authenticated with.
# Mismatch indicates a man-in-the-middle or a stale ENR; reject it.
if expected_peer_id is not None and peer_id != expected_peer_id:
raise QuicTransportError(
f"Peer identity mismatch: expected {expected_peer_id}, got {peer_id}"
)

conn = QuicConnection(
_protocol=protocol,
Expand Down Expand Up @@ -499,8 +509,23 @@ async def listen(
# Callback to set up connection when handshake completes.
# Captures this manager's state (self, on_connection, host, port).
def handle_handshake(protocol_instance: LibP2PQuicProtocol) -> None:
temp_key = IdentityKeypair.generate()
remote_peer_id = temp_key.to_peer_id()
# Derive and verify the client's peer identity from the cert it
# sent during the mutual-TLS handshake.
# Drop the connection if the cert is missing or the libp2p
# extension fails to verify.
peer_cert = protocol_instance._quic.tls._peer_certificate
if peer_cert is None:
logger.warning("Inbound connection without client certificate, closing")
protocol_instance._quic.close()
protocol_instance.transmit()
return
try:
remote_peer_id = verify_libp2p_certificate(peer_cert)
except PeerVerificationError as exc:
logger.warning("Rejecting inbound connection: %s", exc)
protocol_instance._quic.close()
protocol_instance.transmit()
return

remote_addr = f"/ip4/{host}/udp/{port}/quic-v1/p2p/{remote_peer_id}"
conn = QuicConnection(
Expand Down
197 changes: 195 additions & 2 deletions src/lean_spec/subspecs/networking/transport/quic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec

from ..identity import IdentityKeypair
from ..peer_id import KeyType
from lean_spec.subspecs.networking import varint

from ..identity import IdentityKeypair, Secp256k1PublicKey
from ..peer_id import KeyType, PeerId, PublicKeyProto

LIBP2P_EXTENSION_OID: Final = x509.ObjectIdentifier("1.3.6.1.4.1.53594.1.1")
"""libp2p TLS extension OID (Protocol Labs assigned)."""
Expand Down Expand Up @@ -222,3 +224,194 @@ def _encode_asn1_length(length: int) -> bytes:
return bytes([0x81, length])
else:
return bytes([0x82, length >> 8, length & 0xFF])


class PeerVerificationError(Exception):
"""Raised when the libp2p TLS extension fails to validate."""


def verify_libp2p_certificate(cert: x509.Certificate) -> PeerId:
"""
Extract and verify the peer identity from a libp2p TLS certificate.

The certificate carries the libp2p extension (OID 1.3.6.1.4.1.53594.1.1).
The extension contains a SignedKey envelope binding the peer's identity
public key to the TLS public key:

SignedKey ::= SEQUENCE {
publicKey OCTET STRING, -- protobuf-encoded PublicKey
signature OCTET STRING -- ECDSA over prefix || tls_pub_der
}

The signature proves the identity-key holder controls the TLS key.

Args:
cert: Peer's leaf X.509 certificate from the QUIC handshake.

Returns:
Canonical PeerId derived from the verified identity public key.

Raises:
PeerVerificationError: if the extension is missing, malformed, the
signature does not verify, or the key type is unsupported.
"""
# Locate the libp2p extension.
try:
ext = cert.extensions.get_extension_for_oid(LIBP2P_EXTENSION_OID)
except x509.ExtensionNotFound as exc:
raise PeerVerificationError("libp2p extension missing from certificate") from exc

if not isinstance(ext.value, x509.UnrecognizedExtension):
raise PeerVerificationError("libp2p extension has unexpected type")

# Parse the ASN.1 SignedKey envelope.
public_key_proto, signature = _parse_asn1_signed_key(ext.value.value)

# Decode the protobuf PublicKey.
key_type, key_data = _decode_protobuf_public_key(public_key_proto)

# The spec allows multiple key types, but only secp256k1 is used today.
if key_type != KeyType.SECP256K1:
raise PeerVerificationError(f"unsupported libp2p key type: {key_type}")

# Reconstruct the identity public key from the compressed point.
try:
identity_key = Secp256k1PublicKey(
_key=ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256K1(), key_data)
)
except ValueError as exc:
raise PeerVerificationError(f"invalid secp256k1 public key: {exc}") from exc

# Extract the TLS public key as SubjectPublicKeyInfo DER.
#
# The signature was computed over (prefix || SubjectPublicKeyInfo DER)
# by the certificate generator, so we must reproduce the same bytes here.
tls_public_bytes = cert.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

if not identity_key.verify(SIGNATURE_PREFIX + tls_public_bytes, signature):
raise PeerVerificationError("libp2p extension signature failed verification")

# Derive the canonical PeerId via the protobuf-encoded identity key.
return PeerId.from_public_key(PublicKeyProto(key_type=key_type, key_data=key_data))


def _parse_asn1_signed_key(payload: bytes) -> tuple[bytes, bytes]:
"""
Parse the SignedKey envelope.

Mirrors _encode_asn1_signed_key. Accepts the three DER length forms used
by the encoder (short, 0x81, 0x82) and rejects any other form.

Returns:
(public_key_proto_bytes, signature_bytes) tuple.

Raises:
PeerVerificationError: if the envelope is malformed.
"""
# Outer SEQUENCE.
body, rest = _parse_asn1_tlv(payload, expected_tag=0x30)
if rest:
raise PeerVerificationError("trailing bytes after SignedKey SEQUENCE")

# Two OCTET STRINGs inside the SEQUENCE.
public_key_proto, after_first = _parse_asn1_tlv(body, expected_tag=0x04)
signature, after_second = _parse_asn1_tlv(after_first, expected_tag=0x04)
if after_second:
raise PeerVerificationError("trailing bytes after SignedKey contents")

return public_key_proto, signature


def _parse_asn1_tlv(data: bytes, *, expected_tag: int) -> tuple[bytes, bytes]:
"""
Parse a single ASN.1 TLV with the given tag and return (value, rest).

Raises:
PeerVerificationError: if the data is truncated, the tag mismatches,
or the length uses an unsupported encoding form.
"""
if not data:
raise PeerVerificationError("ASN.1 TLV truncated at tag")
if data[0] != expected_tag:
raise PeerVerificationError(
f"ASN.1 tag mismatch: expected 0x{expected_tag:02x}, got 0x{data[0]:02x}"
)

length, length_size = _decode_asn1_length(data[1:])
start = 1 + length_size
end = start + length
if end > len(data):
raise PeerVerificationError("ASN.1 TLV truncated at value")
return data[start:end], data[end:]


def _decode_asn1_length(data: bytes) -> tuple[int, int]:
"""
Decode an ASN.1 DER length and return (length, bytes_consumed).

Only the three forms emitted by _encode_asn1_length are accepted:
- short form (length < 128, single byte)
- long-1 form (0x81 + 1 byte)
- long-2 form (0x82 + 2 bytes)
"""
if not data:
raise PeerVerificationError("ASN.1 length truncated")

first = data[0]
if first < 0x80:
return first, 1
if first == 0x81:
if len(data) < 2:
raise PeerVerificationError("ASN.1 long-1 length truncated")
return data[1], 2
if first == 0x82:
if len(data) < 3:
raise PeerVerificationError("ASN.1 long-2 length truncated")
return (data[1] << 8) | data[2], 3
raise PeerVerificationError(f"unsupported ASN.1 length form: 0x{first:02x}")


def _decode_protobuf_public_key(payload: bytes) -> tuple[KeyType, bytes]:
"""
Decode the protobuf PublicKey message.

Wire format (deterministic encoding, fields in tag order):

[0x08][type_varint][0x12][length_varint][key_bytes]

Returns:
(key_type, key_data) tuple.

Raises:
PeerVerificationError: on truncation, missing field, or unknown tag.
"""
if len(payload) < 2 or payload[0] != 0x08:
raise PeerVerificationError("protobuf PublicKey: missing Type tag")
try:
type_value, type_size = varint.decode_varint(payload, offset=1)
except varint.VarintError as exc:
raise PeerVerificationError(f"protobuf PublicKey Type: {exc}") from exc

try:
key_type = KeyType(type_value)
except ValueError as exc:
raise PeerVerificationError(f"protobuf PublicKey: unknown KeyType {type_value}") from exc

data_start = 1 + type_size
if data_start >= len(payload) or payload[data_start] != 0x12:
raise PeerVerificationError("protobuf PublicKey: missing Data tag")
try:
data_length, length_size = varint.decode_varint(payload, offset=data_start + 1)
except varint.VarintError as exc:
raise PeerVerificationError(f"protobuf PublicKey Data length: {exc}") from exc

key_start = data_start + 1 + length_size
key_end = key_start + data_length
if key_end > len(payload):
raise PeerVerificationError("protobuf PublicKey: Data truncated")
if key_end != len(payload):
raise PeerVerificationError("protobuf PublicKey: trailing bytes after Data")
return key_type, payload[key_start:key_end]
Loading
Loading