Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions smite/src/bolt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ pub enum BoltError {
/// Unknown even TLV type (must reject per BOLT 1)
#[error("TLV_UNKNOWN_EVEN_TYPE {0}")]
TlvUnknownEvenType(u64),
/// TLV value longer than the known encoding for its type
#[error("TLV_TRAILING_BYTES type {tlv_type} expected {expected} got {actual}")]
TlvTrailingBytes {
tlv_type: u64,
expected: usize,
actual: usize,
},
}

/// BOLT message type constants.
Expand Down
33 changes: 24 additions & 9 deletions smite/src/bolt/channel_ready.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,9 @@ impl ChannelReadyTlvs {
///
/// # Errors
///
/// Returns `Truncated` if the short channel ID TLV has invalid length.
/// Returns a `BoltError` if the short channel ID TLV has invalid length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let short_channel_id = if let Some(data) = stream.get(TLV_SHORT_CHANNEL_ID) {
let mut cursor = data;
let scid = u64::read(&mut cursor)?;
Some(scid)
} else {
None
};

let short_channel_id = stream.get_as::<u64>(TLV_SHORT_CHANNEL_ID)?;
Ok(Self { short_channel_id })
}
}
Expand Down Expand Up @@ -223,6 +216,28 @@ mod tests {
);
}

#[test]
// Test constants are known to fit in u8
#[allow(clippy::cast_possible_truncation)]
fn decode_short_channel_id_reject_trailing_bytes() {
let msg = sample_channel_ready(None);
let mut encoded = msg.encode();

// short_channel_id TLV should be 8 bytes, but we push 9 bytes
encoded.push(TLV_SHORT_CHANNEL_ID as u8);
encoded.push(0x09);
encoded.extend_from_slice(&[0xbb; 9]);

assert_eq!(
ChannelReady::decode(&encoded),
Err(BoltError::TlvTrailingBytes {
tlv_type: TLV_SHORT_CHANNEL_ID,
expected: 8,
actual: 9,
})
);
}

#[test]
fn default_tlvs_are_none() {
let tlvs = ChannelReadyTlvs::default();
Expand Down
21 changes: 5 additions & 16 deletions smite/src/bolt/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,9 @@ impl InitTlvs {
///
/// # Errors
///
/// Returns `Truncated` if the networks TLV has invalid length.
/// Returns a `BoltError` if the networks TLV has invalid length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let networks = if let Some(data) = stream.get(TLV_NETWORKS) {
let (chunks, remainder) = data.as_chunks::<CHAIN_HASH_SIZE>();
if !remainder.is_empty() {
return Err(BoltError::Truncated {
expected: (chunks.len() + 1) * CHAIN_HASH_SIZE,
actual: data.len(),
});
}
Some(chunks.to_vec())
} else {
None
};

let networks = stream.get_as_many::<[u8; 32]>(TLV_NETWORKS)?;
let remote_addr = stream.get(TLV_REMOTE_ADDR).map(Vec::from);

Ok(Self {
Expand Down Expand Up @@ -389,8 +377,9 @@ mod tests {

assert_eq!(
Init::decode(&data),
Err(BoltError::Truncated {
expected: CHAIN_HASH_SIZE * 2, // Next multiple of 32
Err(BoltError::TlvTrailingBytes {
tlv_type: 1,
expected: 32,
actual: 33
})
);
Expand Down
154 changes: 154 additions & 0 deletions smite/src/bolt/tlv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,85 @@ impl TlvStream {
.map(|r| r.value.as_slice())
}

/// Gets a record by type and decodes it as a fixed-size `WireFormat` value.
///
/// Returns `None` if the record is absent. Rejects TLV values that are
/// longer than the type's known wire encoding.
///
/// # Errors
///
/// Returns a `BoltError` if the value is truncated or contains trailing
/// bytes after decoding.
pub fn get_as<T: WireFormat>(&self, tlv_type: u64) -> Result<Option<T>, BoltError> {
self.get(tlv_type)
.map(|data| {
let mut cursor = data;
let value = T::read(&mut cursor)?;
// we must fail to parse the stream
// "if length is not exactly equal to that required for the known encoding for type"
// [BOLT 1]: https://github.com/lightning/bolts/blob/master/01-messaging.md#type-length-value-format
if !cursor.is_empty() {
let bytes_read = data.len() - cursor.len();
return Err(BoltError::TlvTrailingBytes {
tlv_type,
expected: bytes_read,
actual: data.len(),
});
}
Ok(value)
})
.transpose()
}

/// Gets all records by type and decodes them as fixed-size `WireFormat`
/// values.
///
/// Returns `None` if no records are found, or `Some(vec)` if present.
///
/// # Errors
///
/// Returns a `BoltError` if decoding a TLV value fails or the TLV values
/// cannot be divided into fixed-size chunks.
pub fn get_as_many<T: WireFormat>(&self, tlv_type: u64) -> Result<Option<Vec<T>>, BoltError> {
match self.get(tlv_type) {
Some(data) => {
if data.is_empty() {
return Ok(Some(Vec::new()));
}

let total_bytes = data.len();
let mut cursor = data;

// read first element to determine chunk size
let first = T::read(&mut cursor)?;

let chunk_size = total_bytes - cursor.len();
if chunk_size == 0 {
return Err(BoltError::Truncated {
expected: 1,
actual: 0,
});
}
if total_bytes % chunk_size != 0 {
return Err(BoltError::TlvTrailingBytes {
tlv_type,
expected: (total_bytes / chunk_size) * chunk_size,
actual: total_bytes,
});
}

let mut values = Vec::with_capacity(total_bytes / chunk_size);
values.push(first);
for chunk in cursor.chunks(chunk_size) {
let mut chunk_cursor = chunk;
values.push(T::read(&mut chunk_cursor)?);
}
Ok(Some(values))
}
None => Ok(None),
}
}

/// Returns an iterator over all records.
pub fn iter(&self) -> impl Iterator<Item = &TlvRecord> {
self.records.iter()
Expand Down Expand Up @@ -210,6 +289,81 @@ mod tests {
assert_eq!(decoded.get(255), Some(&[0xff; 100][..]));
}

#[test]
fn get_as_missing_returns_none() {
let stream = TlvStream::new();
assert_eq!(stream.get_as::<u64>(1).unwrap(), None);
}

#[test]
fn get_as_exact_length() {
let mut stream = TlvStream::new();
let mut value = Vec::new();
42u64.write(&mut value);
stream.add(1, value);

assert_eq!(stream.get_as::<u64>(1).unwrap(), Some(42));
}

#[test]
fn get_as_overlength_rejected() {
let mut stream = TlvStream::new();
let mut value = Vec::new();
42u64.write(&mut value);
value.push(0xff);
stream.add(1, value);

assert_eq!(
stream.get_as::<u64>(1),
Err(BoltError::TlvTrailingBytes {
tlv_type: 1,
expected: 8,
actual: 9
})
);
}

#[test]
fn get_as_underlength_truncated() {
let mut stream = TlvStream::new();
stream.add(1, vec![0xaa; 4]);

assert_eq!(
stream.get_as::<u64>(1),
Err(BoltError::Truncated {
expected: 8,
actual: 4
})
);
}

#[test]
fn get_as_many() {
let mut stream = TlvStream::new();
let value = [[0u8; 32], [1u8; 32]].as_flattened().to_vec();
stream.add(1, value);

assert_eq!(
stream.get_as_many::<[u8; 32]>(1).unwrap(),
Some(vec![[0u8; 32], [1u8; 32]])
);
}

#[test]
fn get_as_many_reject_trailing_bytes() {
let mut stream = TlvStream::new();
stream.add(1, vec![0u8; 33]);

assert_eq!(
stream.get_as_many::<[u8; 32]>(1),
Err(BoltError::TlvTrailingBytes {
tlv_type: 1,
expected: 32,
actual: 33
})
);
}

#[test]
fn known_even_accepted() {
// type=2 is even but known
Expand Down
12 changes: 3 additions & 9 deletions smite/src/bolt/tx_ack_rbf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,10 @@ impl TxAckRbfTlvs {
///
/// # Errors
///
/// Returns `Truncated` if `funding_output_contribution` has invalid length.
/// Returns a `BoltError` if `funding_output_contribution` has invalid
/// length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let funding_output_contribution =
if let Some(data) = stream.get(TLV_FUNDING_OUTPUT_CONTRIBUTION) {
let mut cursor = data;
Some(i64::read(&mut cursor)?)
} else {
None
};

let funding_output_contribution = stream.get_as::<i64>(TLV_FUNDING_OUTPUT_CONTRIBUTION)?;
let require_confirmed_inputs = stream.get(TLV_REQUIRE_CONFIRMED_INPUTS).is_some();

Ok(Self {
Expand Down
10 changes: 2 additions & 8 deletions smite/src/bolt/tx_add_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,9 @@ impl TxAddInputTlvs {
///
/// # Errors
///
/// Returns `Truncated` if `shared_input_txid` has invalid length.
/// Returns a `BoltError` if `shared_input_txid` has invalid length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let shared_input_txid = stream
.get(TLV_SHARED_INPUT_TXID)
.map(|v| {
let mut cursor = v;
Txid::read(&mut cursor)
})
.transpose()?;
let shared_input_txid = stream.get_as::<Txid>(TLV_SHARED_INPUT_TXID)?;
Ok(Self { shared_input_txid })
}
}
Expand Down
12 changes: 3 additions & 9 deletions smite/src/bolt/tx_init_rbf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,10 @@ impl TxInitRbfTlvs {
///
/// # Errors
///
/// Returns `Truncated` if `funding_output_contribution` has invalid length.
/// Returns a `BoltError` if `funding_output_contribution` has invalid
/// length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let funding_output_contribution =
if let Some(data) = stream.get(TLV_FUNDING_OUTPUT_CONTRIBUTION) {
let mut cursor = data;
Some(i64::read(&mut cursor)?)
} else {
None
};

let funding_output_contribution = stream.get_as::<i64>(TLV_FUNDING_OUTPUT_CONTRIBUTION)?;
let require_confirmed_inputs = stream.get(TLV_REQUIRE_CONFIRMED_INPUTS).is_some();

Ok(Self {
Expand Down
7 changes: 2 additions & 5 deletions smite/src/bolt/update_fail_htlc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,9 @@ impl UpdateFailHtlcTlvs {
///
/// # Errors
///
/// Returns `Truncated` if `attribution_data` has invalid length.
/// Returns a `BoltError` if `attribution_data` has invalid length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let attribution_data = stream
.get(TLV_ATTRIBUTION_DATA)
.map(AttributionData::decode)
.transpose()?;
let attribution_data = stream.get_as::<AttributionData>(TLV_ATTRIBUTION_DATA)?;
Ok(Self { attribution_data })
}
}
Expand Down
7 changes: 2 additions & 5 deletions smite/src/bolt/update_fulfill_htlc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,9 @@ impl UpdateFulfillHtlcTlvs {
///
/// # Errors
///
/// Returns `Truncated` if `attribution_data` has invalid length.
/// Returns a `BoltError` if `attribution_data` has invalid length.
fn from_stream(stream: &TlvStream) -> Result<Self, BoltError> {
let attribution_data = stream
.get(TLV_ATTRIBUTION_DATA)
.map(AttributionData::decode)
.transpose()?;
let attribution_data = stream.get_as::<AttributionData>(TLV_ATTRIBUTION_DATA)?;
Ok(Self { attribution_data })
}
}
Expand Down
Loading