diff --git a/chanbackup/backup.go b/chanbackup/backup.go index cf7217ae383..6edae255c81 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" ) @@ -14,11 +15,11 @@ import ( // commitment transaction broadcast. type LiveChannelSource interface { // FetchAllChannels returns all known live channels. - FetchAllChannels() ([]*channeldb.OpenChannel, error) + FetchAllChannels() ([]*chanstate.OpenChannel, error) // FetchChannel attempts to locate a live channel identified by the // passed chanPoint. Optionally an existing db tx can be supplied. - FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) + FetchChannel(chanPoint wire.OutPoint) (*chanstate.OpenChannel, error) } // assembleChanBackup attempts to assemble a static channel backup for the @@ -26,7 +27,7 @@ type LiveChannelSource interface { // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource, - openChan *channeldb.OpenChannel) (*Single, error) { + openChan *chanstate.OpenChannel) (*Single, error) { log.Debugf("Crafting backup for ChannelPoint(%v)", openChan.FundingOutpoint) @@ -55,7 +56,7 @@ func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource, // in loss of funds! This may happen if an outdated channel backup is attempted // to be used to force close the channel. func buildCloseTxInputs( - targetChan *channeldb.OpenChannel) fn.Option[CloseTxInputs] { + targetChan *chanstate.OpenChannel) fn.Option[CloseTxInputs] { log.Debugf("Crafting CloseTxInputs for ChannelPoint(%v)", targetChan.FundingOutpoint) diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index 05a24090c0a..bf55deecb71 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -8,12 +8,12 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/stretchr/testify/require" ) type mockChannelSource struct { - chans map[wire.OutPoint]*channeldb.OpenChannel + chans map[wire.OutPoint]*chanstate.OpenChannel failQuery bool @@ -22,17 +22,19 @@ type mockChannelSource struct { func newMockChannelSource() *mockChannelSource { return &mockChannelSource{ - chans: make(map[wire.OutPoint]*channeldb.OpenChannel), + chans: make(map[wire.OutPoint]*chanstate.OpenChannel), addrs: make(map[[33]byte][]net.Addr), } } -func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) { +func (m *mockChannelSource) FetchAllChannels() ( + []*chanstate.OpenChannel, error) { + if m.failQuery { return nil, fmt.Errorf("fail") } - chans := make([]*channeldb.OpenChannel, 0, len(m.chans)) + chans := make([]*chanstate.OpenChannel, 0, len(m.chans)) for _, channel := range m.chans { chans = append(chans, channel) } @@ -41,7 +43,7 @@ func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) } func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) ( - *channeldb.OpenChannel, error) { + *chanstate.OpenChannel, error) { if m.failQuery { return nil, fmt.Errorf("fail") diff --git a/chanbackup/pubsub.go b/chanbackup/pubsub.go index 6304f0cb376..9a50bfba0b1 100644 --- a/chanbackup/pubsub.go +++ b/chanbackup/pubsub.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnutils" ) @@ -31,7 +31,7 @@ type Swapper interface { // ChannelWithAddrs bundles an open channel along with all the addresses for // the channel peer. type ChannelWithAddrs struct { - *channeldb.OpenChannel + *chanstate.OpenChannel // Addrs is the set of addresses that we can use to reach the target // peer. diff --git a/chanbackup/single.go b/chanbackup/single.go index 0a38d6b561f..2e2c22682cd 100644 --- a/chanbackup/single.go +++ b/chanbackup/single.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" @@ -169,7 +169,7 @@ type Single struct { // // NOTE: Of the items in the ChannelConstraints, we only write the CSV // delay. - LocalChanCfg channeldb.ChannelConfig + LocalChanCfg chanstate.ChannelConfig // RemoteChanCfg is the remote channel confirmation. We store this as // well since we'll need some of their keys to re-derive things like @@ -178,7 +178,7 @@ type Single struct { // // NOTE: Of the items in the ChannelConstraints, we only write the CSV // delay. - RemoteChanCfg channeldb.ChannelConfig + RemoteChanCfg chanstate.ChannelConfig // ShaChainRootDesc describes how to derive the private key that was // used as the shachain root for this channel. @@ -234,7 +234,7 @@ type CloseTxInputs struct { // connect to the channel peer. If possible, we include the data needed to // produce a force close transaction from the most recent state using externally // provided private key. -func NewSingle(channel *channeldb.OpenChannel, +func NewSingle(channel *chanstate.OpenChannel, nodeAddrs []net.Addr) Single { var shaChainRootDesc keychain.KeyDescriptor diff --git a/chanbackup/single_test.go b/chanbackup/single_test.go index f1f805c1435..1bf763618a3 100644 --- a/chanbackup/single_test.go +++ b/chanbackup/single_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnencrypt" @@ -135,7 +135,7 @@ func assertSingleEqual(t *testing.T, a, b Single) { } } -func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { +func genRandomOpenChannelShell() (*chanstate.OpenChannel, error) { var testPriv [32]byte if _, err := rand.Read(testPriv[:]); err != nil { return nil, err @@ -162,11 +162,11 @@ func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { isInitiator = true } - chanType := channeldb.ChannelType(rand.Intn(1 << 12)) + chanType := chanstate.ChannelType(rand.Intn(1 << 12)) - localCfg := channeldb.ChannelConfig{ - ChannelStateBounds: channeldb.ChannelStateBounds{}, - CommitmentParams: channeldb.CommitmentParams{ + localCfg := chanstate.ChannelConfig{ + ChannelStateBounds: chanstate.ChannelStateBounds{}, + CommitmentParams: chanstate.CommitmentParams{ CsvDelay: uint16(rand.Int63()), }, MultiSigKey: keychain.KeyDescriptor{ @@ -201,8 +201,8 @@ func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { }, } - remoteCfg := channeldb.ChannelConfig{ - CommitmentParams: channeldb.CommitmentParams{ + remoteCfg := chanstate.ChannelConfig{ + CommitmentParams: chanstate.CommitmentParams{ CsvDelay: uint16(rand.Int63()), }, MultiSigKey: keychain.KeyDescriptor{ @@ -222,14 +222,14 @@ func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { }, } - var localCommit channeldb.ChannelCommitment + var localCommit chanstate.ChannelCommitment if chanType.IsTaproot() { var commitSig [64]byte if _, err := rand.Read(commitSig[:]); err != nil { return nil, err } - localCommit = channeldb.ChannelCommitment{ + localCommit = chanstate.ChannelCommitment{ CommitTx: sampleCommitTx, CommitSig: commitSig[:], CommitHeight: rand.Uint64(), @@ -245,7 +245,7 @@ func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { tapscriptRootOption = fn.Some(tapscriptRoot) } - return &channeldb.OpenChannel{ + return &chanstate.OpenChannel{ ChainHash: chainHash, ChanType: chanType, IsInitiator: isInitiator, diff --git a/chanfitness/chaneventstore.go b/chanfitness/chaneventstore.go index 881e9a35ea2..23dcbffc13e 100644 --- a/chanfitness/chaneventstore.go +++ b/chanfitness/chaneventstore.go @@ -20,6 +20,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/peernotifier" "github.com/lightningnetwork/lnd/routing/route" @@ -84,7 +85,7 @@ type Config struct { // GetOpenChannels provides a list of existing open channels which is // used to populate the ChannelEventStore with a set of channels on // startup. - GetOpenChannels func() ([]*channeldb.OpenChannel, error) + GetOpenChannels func() ([]*chanstate.OpenChannel, error) // Clock is the time source that the subsystem uses, provided here // for ease of testing. diff --git a/chanfitness/chaneventstore_test.go b/chanfitness/chaneventstore_test.go index ecec3ea4717..35ecdbaa4e7 100644 --- a/chanfitness/chaneventstore_test.go +++ b/chanfitness/chaneventstore_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/subscribe" @@ -35,7 +36,7 @@ func TestStartStoreError(t *testing.T) { name string ChannelEvents func() (subscribe.Subscription, error) PeerEvents func() (subscribe.Subscription, error) - GetChannels func() ([]*channeldb.OpenChannel, error) + GetChannels func() ([]*chanstate.OpenChannel, error) }{ { name: "Channel events fail", @@ -50,7 +51,7 @@ func TestStartStoreError(t *testing.T) { name: "Get open channels fails", ChannelEvents: okSubscribeFunc, PeerEvents: okSubscribeFunc, - GetChannels: func() ([]*channeldb.OpenChannel, error) { + GetChannels: func() ([]*chanstate.OpenChannel, error) { return nil, errors.New("intentional test err") }, }, diff --git a/chanfitness/chaneventstore_testctx_test.go b/chanfitness/chaneventstore_testctx_test.go index aff4c5fca5b..41043d2459b 100644 --- a/chanfitness/chaneventstore_testctx_test.go +++ b/chanfitness/chaneventstore_testctx_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/peernotifier" "github.com/lightningnetwork/lnd/routing/route" @@ -72,7 +73,7 @@ func newChanEventStoreTestCtx(t *testing.T) *chanEventStoreTestCtx { SubscribePeerEvents: func() (subscribe.Subscription, error) { return testCtx.peerSubscription, nil }, - GetOpenChannels: func() ([]*channeldb.OpenChannel, error) { + GetOpenChannels: func() ([]*chanstate.OpenChannel, error) { return nil, nil }, WriteFlapCount: func(updates map[route.Vertex]*channeldb.FlapCount) error { @@ -181,7 +182,7 @@ func (c *chanEventStoreTestCtx) closeChannel(channel wire.OutPoint, peer *btcec.PublicKey) { update := channelnotifier.ClosedChannelEvent{ - CloseSummary: &channeldb.ChannelCloseSummary{ + CloseSummary: &chanstate.ChannelCloseSummary{ ChanPoint: channel, RemotePub: peer, }, @@ -221,7 +222,7 @@ func (c *chanEventStoreTestCtx) sendChannelOpenedUpdate(pubkey *btcec.PublicKey, channel wire.OutPoint) { update := channelnotifier.OpenChannelEvent{ - Channel: &channeldb.OpenChannel{ + Channel: &chanstate.OpenChannel{ FundingOutpoint: channel, IdentityPub: pubkey, }, diff --git a/channel_notifier.go b/channel_notifier.go index 8affd48f08d..0a53502dd96 100644 --- a/channel_notifier.go +++ b/channel_notifier.go @@ -8,6 +8,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" ) // channelNotifier is an implementation of the chanbackup.ChannelNotifier @@ -46,7 +47,7 @@ func (c *channelNotifier) SubscribeChans(ctx context.Context, // sendChanOpenUpdate is a closure that sends a ChannelEvent to the // chanUpdates channel to inform subscribers about new pending or // confirmed channels. - sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) { + sendChanOpenUpdate := func(newOrPendingChan *chanstate.OpenChannel) { _, nodeAddrs, err := c.addrs.AddrsForNode( ctx, newOrPendingChan.IdentityPub, ) diff --git a/channeldb/channel.go b/channeldb/channel.go index 127e0ac9c4d..f1b86c71e8e 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1,30 +1,13 @@ package channeldb import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/binary" - "errors" - "fmt" "io" "net" - "strconv" - "strings" - "sync" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" - "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcwallet/walletdb" + cstate "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" - graphdb "github.com/lightningnetwork/lnd/graph/db" - "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/htlcswitch/hop" - "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -33,339 +16,101 @@ import ( ) const ( - // AbsoluteThawHeightThreshold is the threshold at which a thaw height - // begins to be interpreted as an absolute block height, rather than a - // relative one. - AbsoluteThawHeightThreshold uint32 = 500000 - // HTLCBlindingPointTLV is the tlv type used for storing blinding // points with HTLCs. HTLCBlindingPointTLV tlv.Type = 0 ) +const ( + // AbsoluteThawHeightThreshold is the threshold at which a thaw height + // begins to be interpreted as an absolute block height, rather than a + // relative one. + AbsoluteThawHeightThreshold = cstate.AbsoluteThawHeightThreshold +) + var ( - // closedChannelBucket stores summarization information concerning - // previously open, but now closed channels. - closedChannelBucket = []byte("closed-chan-bucket") - - // openChannelBucket stores all the currently open channels. This bucket - // has a second, nested bucket which is keyed by a node's ID. Within - // that node ID bucket, all attributes required to track, update, and - // close a channel are stored. - // - // openChan -> nodeID -> chanPoint - // - // TODO(roasbeef): flesh out comment - openChannelBucket = []byte("open-chan-bucket") - - // outpointBucket stores all of our channel outpoints and a tlv - // stream containing channel data. - // - // outpoint -> tlv stream. - // - outpointBucket = []byte("outpoint-bucket") - - // chanIDBucket stores all of the 32-byte channel ID's we know about. - // These could be derived from outpointBucket, but it is more - // convenient to have these in their own bucket. - // - // chanID -> tlv stream. - // - chanIDBucket = []byte("chan-id-bucket") - - // historicalChannelBucket stores all channels that have seen their - // commitment tx confirm. All information from their previous open state - // is retained. - historicalChannelBucket = []byte("historical-chan-bucket") - - // chanInfoKey can be accessed within the bucket for a channel - // (identified by its chanPoint). This key stores all the static - // information for a channel which is decided at the end of the - // funding flow. - chanInfoKey = []byte("chan-info-key") - - // localUpfrontShutdownKey can be accessed within the bucket for a channel - // (identified by its chanPoint). This key stores an optional upfront - // shutdown script for the local peer. - localUpfrontShutdownKey = []byte("local-upfront-shutdown-key") - - // remoteUpfrontShutdownKey can be accessed within the bucket for a channel - // (identified by its chanPoint). This key stores an optional upfront - // shutdown script for the remote peer. - remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key") - - // chanCommitmentKey can be accessed within the sub-bucket for a - // particular channel. This key stores the up to date commitment state - // for a particular channel party. Appending a 0 to the end of this key - // indicates it's the commitment for the local party, and appending a 1 - // to the end of this key indicates it's the commitment for the remote - // party. - chanCommitmentKey = []byte("chan-commitment-key") - - // unsignedAckedUpdatesKey is an entry in the channel bucket that - // contains the remote updates that we have acked, but not yet signed - // for in one of our remote commits. - unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key") - - // remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that - // contains the local updates that the remote party has acked, but - // has not yet signed for in one of their local commits. - remoteUnsignedLocalUpdatesKey = []byte("remote-unsigned-local-updates-key") - - // revocationStateKey stores their current revocation hash, our - // preimage producer and their preimage store. - revocationStateKey = []byte("revocation-state-key") - - // dataLossCommitPointKey stores the commitment point received from the - // remote peer during a channel sync in case we have lost channel state. - dataLossCommitPointKey = []byte("data-loss-commit-point-key") - - // forceCloseTxKey points to a the unilateral closing tx that we - // broadcasted when moving the channel to state CommitBroadcasted. - forceCloseTxKey = []byte("closing-tx-key") - - // coopCloseTxKey points to a the cooperative closing tx that we - // broadcasted when moving the channel to state CoopBroadcasted. - coopCloseTxKey = []byte("coop-closing-tx-key") - - // shutdownInfoKey points to the serialised shutdown info that has been - // persisted for a channel. The existence of this info means that we - // have sent the Shutdown message before and so should re-initiate the - // shutdown on re-establish. - shutdownInfoKey = []byte("shutdown-info-key") - - // commitDiffKey stores the current pending commitment state we've - // extended to the remote party (if any). Each time we propose a new - // state, we store the information necessary to reconstruct this state - // from the prior commitment. This allows us to resync the remote party - // to their expected state in the case of message loss. - // - // TODO(roasbeef): rename to commit chain? - commitDiffKey = []byte("commit-diff-key") - - // frozenChanKey is the key where we store the information for any - // active "frozen" channels. This key is present only in the leaf - // bucket for a given channel. - frozenChanKey = []byte("frozen-chans") - - // lastWasRevokeKey is a key that stores true when the last update we - // sent was a revocation and false when it was a commitment signature. - // This is nil in the case of new channels with no updates exchanged. - lastWasRevokeKey = []byte("last-was-revoke") - - // finalHtlcsBucket contains the htlcs that have been resolved - // definitively. Within this bucket, there is a sub-bucket for each - // channel. In each channel bucket, the htlc indices are stored along - // with final outcome. - // - // final-htlcs -> chanID -> htlcIndex -> outcome - // - // 'outcome' is a byte value that encodes: - // - // | true false - // ------+------------------ - // bit 0 | settled failed - // bit 1 | offchain onchain - // - // This bucket is positioned at the root level, because its contents - // will be kept independent of the channel lifecycle. This is to avoid - // the situation where a channel force-closes autonomously and the user - // not being able to query for htlc outcomes anymore. - finalHtlcsBucket = []byte("final-htlcs") + closedChannelBucket = cstate.ClosedChannelBucketKey() + openChannelBucket = cstate.OpenChannelBucketKey() + outpointBucket = cstate.OutpointBucketKey() + chanIDBucket = cstate.ChanIDBucketKey() + historicalChannelBucket = cstate.HistoricalChannelBucketKey() ) var ( // ErrNoCommitmentsFound is returned when a channel has not set // commitment states. - ErrNoCommitmentsFound = fmt.Errorf("no commitments found") + ErrNoCommitmentsFound = cstate.ErrNoCommitmentsFound // ErrNoChanInfoFound is returned when a particular channel does not // have any channels state. - ErrNoChanInfoFound = fmt.Errorf("no chan info found") + ErrNoChanInfoFound = cstate.ErrNoChanInfoFound // ErrNoRevocationsFound is returned when revocation state for a // particular channel cannot be found. - ErrNoRevocationsFound = fmt.Errorf("no revocations found") + ErrNoRevocationsFound = cstate.ErrNoRevocationsFound // ErrNoPendingCommit is returned when there is not a pending // commitment for a remote party. A new commitment is written to disk // each time we write a new state in order to be properly fault // tolerant. - ErrNoPendingCommit = fmt.Errorf("no pending commits found") + ErrNoPendingCommit = cstate.ErrNoPendingCommit // ErrNoCommitPoint is returned when no data loss commit point is found // in the database. - ErrNoCommitPoint = fmt.Errorf("no commit point found") + ErrNoCommitPoint = cstate.ErrNoCommitPoint // ErrNoCloseTx is returned when no closing tx is found for a channel // in the state CommitBroadcasted. - ErrNoCloseTx = fmt.Errorf("no closing tx found") + ErrNoCloseTx = cstate.ErrNoCloseTx // ErrNoShutdownInfo is returned when no shutdown info has been // persisted for a channel. - ErrNoShutdownInfo = errors.New("no shutdown info") + ErrNoShutdownInfo = cstate.ErrNoShutdownInfo // ErrNoRestoredChannelMutation is returned when a caller attempts to // mutate a channel that's been recovered. - ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + - "channel state") + ErrNoRestoredChannelMutation = cstate.ErrNoRestoredChannelMutation // ErrChanBorked is returned when a caller attempts to mutate a borked // channel. - ErrChanBorked = fmt.Errorf("cannot mutate borked channel") + ErrChanBorked = cstate.ErrChanBorked // ErrMissingIndexEntry is returned when a caller attempts to close a // channel and the outpoint is missing from the index. - ErrMissingIndexEntry = fmt.Errorf("missing outpoint from index") + ErrMissingIndexEntry = cstate.ErrMissingIndexEntry // ErrOnionBlobLength is returned is an onion blob with incorrect // length is read from disk. - ErrOnionBlobLength = errors.New("onion blob < 1366 bytes") -) - -const ( - // A tlv type definition used to serialize an outpoint's indexStatus - // for use in the outpoint index. - indexStatusType tlv.Type = 0 + ErrOnionBlobLength = cstate.ErrOnionBlobLength ) -// openChannelTlvData houses the new data fields that are stored for each -// channel in a TLV stream within the root bucket. This is stored as a TLV -// stream appended to the existing hard-coded fields in the channel's root -// bucket. New fields being added to the channel state should be added here. -// -// NOTE: This struct is used for serialization purposes only and its fields -// should be accessed via the OpenChannel struct while in memory. -type openChannelTlvData struct { - // revokeKeyLoc is the key locator for the revocation key. - revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] - - // initialLocalBalance is the initial local balance of the channel. - initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] - - // initialRemoteBalance is the initial remote balance of the channel. - initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] - - // realScid is the real short channel ID of the channel corresponding to - // the on-chain outpoint. - realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] - - // memo is an optional text field that gives context to the user about - // the channel. - memo tlv.OptionalRecordT[tlv.TlvType5, []byte] - - // tapscriptRoot is the optional Tapscript root the channel funding - // output commits to. - tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] - - // customBlob is an optional TLV encoded blob of data representing - // custom channel funding information. - customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] - - // confirmationHeight records the block height at which the funding - // transaction was first confirmed. - confirmationHeight tlv.RecordT[tlv.TlvType8, uint32] - - // closeConfirmationHeight records the block height at which the closing - // transaction was first confirmed. This is used to calculate the - // remaining confirmations until the channel is considered fully closed. - // Note: if not set, it means either the channel has not been - // closed yet, or it was closed before this field was introduced. - closeConfirmationHeight tlv.OptionalRecordT[tlv.TlvType9, uint32] -} - -// encode serializes the openChannelTlvData to the given io.Writer. -func (c *openChannelTlvData) encode(w io.Writer) error { - tlvRecords := []tlv.Record{ - c.revokeKeyLoc.Record(), - c.initialLocalBalance.Record(), - c.initialRemoteBalance.Record(), - c.realScid.Record(), - c.confirmationHeight.Record(), - } - c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { - tlvRecords = append(tlvRecords, memo.Record()) - }) - c.tapscriptRoot.WhenSome( - func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { - tlvRecords = append(tlvRecords, root.Record()) - }, - ) - c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { - tlvRecords = append(tlvRecords, blob.Record()) - }) - c.closeConfirmationHeight.WhenSome( - func(h tlv.RecordT[tlv.TlvType9, uint32]) { - tlvRecords = append(tlvRecords, h.Record()) - }, - ) - - tlv.SortRecords(tlvRecords) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream(tlvRecords...) - if err != nil { - return err - } +type ( + indexStatus = cstate.IndexStatus - return tlvStream.Encode(w) -} - -// decode deserializes the openChannelTlvData from the given io.Reader. -func (c *openChannelTlvData) decode(r io.Reader) error { - memo := c.memo.Zero() - tapscriptRoot := c.tapscriptRoot.Zero() - blob := c.customBlob.Zero() - closeConfHeight := c.closeConfirmationHeight.Zero() - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - c.revokeKeyLoc.Record(), - c.initialLocalBalance.Record(), - c.initialRemoteBalance.Record(), - c.realScid.Record(), - memo.Record(), - tapscriptRoot.Record(), - blob.Record(), - c.confirmationHeight.Record(), - closeConfHeight.Record(), - ) - if err != nil { - return err - } + // OpenChannel encapsulates the persistent and dynamic state of an open + // channel with a remote node. + OpenChannel = cstate.OpenChannel - tlvs, err := tlvStream.DecodeWithParsedTypes(r) - if err != nil { - return err - } + // ChannelCommitment is a snapshot of the commitment state at a + // particular point in the commitment chain. + ChannelCommitment = cstate.ChannelCommitment - if _, ok := tlvs[memo.TlvType()]; ok { - c.memo = tlv.SomeRecordT(memo) - } - if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { - c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) - } - if _, ok := tlvs[c.customBlob.TlvType()]; ok { - c.customBlob = tlv.SomeRecordT(blob) - } - if _, ok := tlvs[closeConfHeight.TlvType()]; ok { - c.closeConfirmationHeight = tlv.SomeRecordT(closeConfHeight) - } + // HTLC is the on-disk representation of a hash time-locked contract. + HTLC = cstate.HTLC - return nil -} + // LogUpdate represents a pending update to the remote commitment + // chain. + LogUpdate = cstate.LogUpdate -// indexStatus is an enum-like type that describes what state the -// outpoint is in. Currently only two possible values. -type indexStatus uint8 + // CommitDiff represents the delta needed to apply the state + // transition between two subsequent commitment states. + CommitDiff = cstate.CommitDiff +) const ( - // outpointOpen represents an outpoint that is open in the outpoint index. - outpointOpen indexStatus = 0 - - // outpointClosed represents an outpoint that is closed in the outpoint - // index. - outpointClosed indexStatus = 1 + indexStatusType = cstate.IndexStatusType + outpointOpen = cstate.OutpointOpen + outpointClosed = cstate.OutpointClosed ) // isOutpointClosed reports whether the supplied chanKey has been flipped to @@ -377,4705 +122,603 @@ const ( // fetch outpointBucket once and pass it in, which lets loop-style readers // hoist the bucket lookup out of the inner loop. func isOutpointClosed(opBucket kvdb.RBucket, chanKey []byte) (bool, error) { - if opBucket == nil { - return false, nil - } - raw := opBucket.Get(chanKey) - if raw == nil { - return false, nil - } - - var status uint8 - statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status) - stream, err := tlv.NewStream(statusRecord) - if err != nil { - return false, err - } - if err := stream.Decode(bytes.NewReader(raw)); err != nil { - return false, fmt.Errorf("decode outpoint status for "+ - "chan_key=%x: %w", chanKey, err) - } - - return indexStatus(status) == outpointClosed, nil + return cstate.IsOutpointClosed(opBucket, chanKey) } // ChannelType is an enum-like type that describes one of several possible -// channel types. Each open channel is associated with a particular type as the -// channel type may determine how higher level operations are conducted such as -// fee negotiation, channel closing, the format of HTLCs, etc. Structure-wise, -// a ChannelType is a bit field, with each bit denoting a modification from the -// base channel type of single funder. -type ChannelType uint64 +// channel types. +type ChannelType = cstate.ChannelType const ( - // NOTE: iota isn't used here for this enum needs to be stable - // long-term as it will be persisted to the database. - // SingleFunderBit represents a channel wherein one party solely funds // the entire capacity of the channel. - SingleFunderBit ChannelType = 0 + SingleFunderBit = cstate.SingleFunderBit // DualFunderBit represents a channel wherein both parties contribute - // funds towards the total capacity of the channel. The channel may be - // funded symmetrically or asymmetrically. - DualFunderBit ChannelType = 1 << 0 + // funds towards the total capacity of the channel. + DualFunderBit = cstate.DualFunderBit // SingleFunderTweaklessBit is similar to the basic SingleFunder channel - // type, but it omits the tweak for one's key in the commitment - // transaction of the remote party. - SingleFunderTweaklessBit ChannelType = 1 << 1 + // type, but it omits the tweak for one's key. + SingleFunderTweaklessBit = cstate.SingleFunderTweaklessBit // NoFundingTxBit denotes if we have the funding transaction locally on - // disk. This bit may be on if the funding transaction was crafted by a - // wallet external to the primary daemon. - NoFundingTxBit ChannelType = 1 << 2 + // disk. + NoFundingTxBit = cstate.NoFundingTxBit // AnchorOutputsBit indicates that the channel makes use of anchor - // outputs to bump the commitment transaction's effective feerate. This - // channel type also uses a delayed to_remote output script. - AnchorOutputsBit ChannelType = 1 << 3 + // outputs to bump the commitment transaction's effective feerate. + AnchorOutputsBit = cstate.AnchorOutputsBit // FrozenBit indicates that the channel is a frozen channel, meaning // that only the responder can decide to cooperatively close the // channel. - FrozenBit ChannelType = 1 << 4 + FrozenBit = cstate.FrozenBit // ZeroHtlcTxFeeBit indicates that the channel should use zero-fee // second-level HTLC transactions. - ZeroHtlcTxFeeBit ChannelType = 1 << 5 + ZeroHtlcTxFeeBit = cstate.ZeroHtlcTxFeeBit // LeaseExpirationBit indicates that the channel has been leased for a - // period of time, constraining every output that pays to the channel - // initiator with an additional CLTV of the lease maturity. - LeaseExpirationBit ChannelType = 1 << 6 + // period of time. + LeaseExpirationBit = cstate.LeaseExpirationBit // ZeroConfBit indicates that the channel is a zero-conf channel. - ZeroConfBit ChannelType = 1 << 7 + ZeroConfBit = cstate.ZeroConfBit // ScidAliasChanBit indicates that the channel has negotiated the // scid-alias channel type. - ScidAliasChanBit ChannelType = 1 << 8 + ScidAliasChanBit = cstate.ScidAliasChanBit // ScidAliasFeatureBit indicates that the scid-alias feature bit was // negotiated during the lifetime of this channel. - ScidAliasFeatureBit ChannelType = 1 << 9 + ScidAliasFeatureBit = cstate.ScidAliasFeatureBit // SimpleTaprootFeatureBit indicates that the simple-taproot-chans // feature bit was negotiated during the lifetime of the channel. - SimpleTaprootFeatureBit ChannelType = 1 << 10 + SimpleTaprootFeatureBit = cstate.SimpleTaprootFeatureBit // TapscriptRootBit indicates that this is a MuSig2 channel with a top - // level tapscript commitment. This MUST be set along with the - // SimpleTaprootFeatureBit. - TapscriptRootBit ChannelType = 1 << 11 + // level tapscript commitment. + TapscriptRootBit = cstate.TapscriptRootBit // TaprootFinalBit indicates that this is a MuSig2 channel using the - // final/production taproot scripts and feature bits 80/81. This MUST - // be set along with the SimpleTaprootFeatureBit. - TaprootFinalBit ChannelType = 1 << 12 + // final/production taproot scripts and feature bits 80/81. + TaprootFinalBit = cstate.TaprootFinalBit ) -// IsSingleFunder returns true if the channel type if one of the known single -// funder variants. -func (c ChannelType) IsSingleFunder() bool { - return c&DualFunderBit == 0 -} +// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel +// that bound the abstract channel state. +type ChannelStateBounds = cstate.ChannelStateBounds -// IsDualFunder returns true if the ChannelType has the DualFunderBit set. -func (c ChannelType) IsDualFunder() bool { - return c&DualFunderBit == DualFunderBit -} +// CommitmentParams are the parameters from OpenChannel and AcceptChannel that +// are required to render an abstract channel state to a concrete commitment +// transaction. +type CommitmentParams = cstate.CommitmentParams -// IsTweakless returns true if the target channel uses a commitment that -// doesn't tweak the key for the remote party. -func (c ChannelType) IsTweakless() bool { - return c&SingleFunderTweaklessBit == SingleFunderTweaklessBit -} +// ChannelConfig houses the channel configuration for one side of a channel. +type ChannelConfig = cstate.ChannelConfig -// HasFundingTx returns true if this channel type is one that has a funding -// transaction stored locally. -func (c ChannelType) HasFundingTx() bool { - return c&NoFundingTxBit == 0 -} +// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in +// the default usable state, or a state where it shouldn't be used. +type ChannelStatus = cstate.ChannelStatus -// HasAnchors returns true if this channel type has anchor outputs on its -// commitment. -func (c ChannelType) HasAnchors() bool { - return c&AnchorOutputsBit == AnchorOutputsBit -} +var ( + // ChanStatusDefault is the normal state of an open channel. + ChanStatusDefault = cstate.ChanStatusDefault -// ZeroHtlcTxFee returns true if this channel type uses second-level HTLC -// transactions signed with zero-fee. -func (c ChannelType) ZeroHtlcTxFee() bool { - return c&ZeroHtlcTxFeeBit == ZeroHtlcTxFeeBit -} + // ChanStatusBorked indicates that the channel has entered an + // irreconcilable state. + ChanStatusBorked = cstate.ChanStatusBorked -// IsFrozen returns true if the channel is considered to be "frozen". A frozen -// channel means that only the responder can initiate a cooperative channel -// closure. -func (c ChannelType) IsFrozen() bool { - return c&FrozenBit == FrozenBit -} + // ChanStatusCommitBroadcasted indicates that a commitment for this + // channel has been broadcasted. + ChanStatusCommitBroadcasted = cstate.ChanStatusCommitBroadcasted -// HasLeaseExpiration returns true if the channel originated from a lease. -func (c ChannelType) HasLeaseExpiration() bool { - return c&LeaseExpirationBit == LeaseExpirationBit -} + // ChanStatusLocalDataLoss indicates that we have lost channel state + // for this channel. + ChanStatusLocalDataLoss = cstate.ChanStatusLocalDataLoss -// HasZeroConf returns true if the channel is a zero-conf channel. -func (c ChannelType) HasZeroConf() bool { - return c&ZeroConfBit == ZeroConfBit -} + // ChanStatusRestored signals that the channel has been restored and + // doesn't have all fields a typical channel will have. + ChanStatusRestored = cstate.ChanStatusRestored -// HasScidAliasChan returns true if the scid-alias channel type was negotiated. -func (c ChannelType) HasScidAliasChan() bool { - return c&ScidAliasChanBit == ScidAliasChanBit -} + // ChanStatusCoopBroadcasted indicates that a cooperative close for this + // channel has been broadcasted. + ChanStatusCoopBroadcasted = cstate.ChanStatusCoopBroadcasted -// HasScidAliasFeature returns true if the scid-alias feature bit was -// negotiated during the lifetime of this channel. -func (c ChannelType) HasScidAliasFeature() bool { - return c&ScidAliasFeatureBit == ScidAliasFeatureBit -} + // ChanStatusLocalCloseInitiator indicates that we initiated closing the + // channel. + ChanStatusLocalCloseInitiator = cstate.ChanStatusLocalCloseInitiator -// IsTaproot returns true if the channel is using taproot features. -func (c ChannelType) IsTaproot() bool { - return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit -} + // ChanStatusRemoteCloseInitiator indicates that the remote node + // initiated closing the channel. + ChanStatusRemoteCloseInitiator = cstate.ChanStatusRemoteCloseInitiator +) -// HasTapscriptRoot returns true if the channel is using a top level tapscript -// root commitment. -func (c ChannelType) HasTapscriptRoot() bool { - return c&TapscriptRootBit == TapscriptRootBit -} +// FinalHtlcByte is a type alias for a byte that encodes information about the +// final htlc resolution. +type FinalHtlcByte = cstate.FinalHtlcByte -// IsTaprootFinal returns true if the channel is using final/production taproot -// scripts and feature bits. -func (c ChannelType) IsTaprootFinal() bool { - return c&TaprootFinalBit == TaprootFinalBit -} +const ( + // FinalHtlcSettledBit is the bit that encodes whether the htlc was + // settled or failed. + FinalHtlcSettledBit = cstate.FinalHtlcSettledBit -// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel -// that are responsible for providing bounds on the state space of the abstract -// channel state. These values must be remembered for normal channel operation -// but they do not impact how we compute the commitment transactions themselves. -type ChannelStateBounds struct { - // ChanReserve is an absolute reservation on the channel for the - // owner of this set of constraints. This means that the current - // settled balance for this node CANNOT dip below the reservation - // amount. This acts as a defense against costless attacks when - // either side no longer has any skin in the game. - ChanReserve btcutil.Amount - - // MaxPendingAmount is the maximum pending HTLC value that the - // owner of these constraints can offer the remote node at a - // particular time. - MaxPendingAmount lnwire.MilliSatoshi - - // MinHTLC is the minimum HTLC value that the owner of these - // constraints can offer the remote node. If any HTLCs below this - // amount are offered, then the HTLC will be rejected. This, in - // tandem with the dust limit allows a node to regulate the - // smallest HTLC that it deems economically relevant. - MinHTLC lnwire.MilliSatoshi - - // MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of - // this set of constraints can offer the remote node. This allows each - // node to limit their over all exposure to HTLCs that may need to be - // acted upon in the case of a unilateral channel closure or a contract - // breach. - MaxAcceptedHtlcs uint16 -} + // FinalHtlcOffchainBit is the bit that encodes whether the htlc was + // resolved offchain or onchain. + FinalHtlcOffchainBit = cstate.FinalHtlcOffchainBit +) -// CommitmentParams are the parameters from OpenChannel and -// AcceptChannel that are required to render an abstract channel state to a -// concrete commitment transaction. These values are necessary to (re)compute -// the commitment transaction. We treat these differently than the state space -// bounds because their history needs to be stored in order to properly handle -// chain resolution. -type CommitmentParams struct { - // DustLimit is the threshold (in satoshis) below which any outputs - // should be trimmed. When an output is trimmed, it isn't materialized - // as an actual output, but is instead burned to miner's fees. - DustLimit btcutil.Amount - - // CsvDelay is the relative time lock delay expressed in blocks. Any - // settled outputs that pay to the owner of this channel configuration - // MUST ensure that the delay branch uses this value as the relative - // time lock. Similarly, any HTLC's offered by this node should use - // this value as well. - CsvDelay uint16 +// RefreshChannel updates the in-memory channel state using the latest state +// observed on disk. +func (c *ChannelStateDB) RefreshChannel(channel *OpenChannel) error { + return c.kvStore.RefreshChannel(channel) } -// ChannelConfig is a struct that houses the various configuration opens for -// channels. Each side maintains an instance of this configuration file as it -// governs: how the funding and commitment transaction to be created, the -// nature of HTLC's allotted, the keys to be used for delivery, and relative -// time lock parameters. -type ChannelConfig struct { - // ChannelStateBounds is the set of constraints that must be - // upheld for the duration of the channel for the owner of this channel - // configuration. Constraints govern a number of flow control related - // parameters, also including the smallest HTLC that will be accepted - // by a participant. - ChannelStateBounds - - // CommitmentParams is an embedding of the parameters - // required to render an abstract channel state into a concrete - // commitment transaction. - CommitmentParams - - // MultiSigKey is the key to be used within the 2-of-2 output script - // for the owner of this channel config. - MultiSigKey keychain.KeyDescriptor - - // RevocationBasePoint is the base public key to be used when deriving - // revocation keys for the remote node's commitment transaction. This - // will be combined along with a per commitment secret to derive a - // unique revocation key for each state. - RevocationBasePoint keychain.KeyDescriptor - - // PaymentBasePoint is the base public key to be used when deriving - // the key used within the non-delayed pay-to-self output on the - // commitment transaction for a node. This will be combined with a - // tweak derived from the per-commitment point to ensure unique keys - // for each commitment transaction. - PaymentBasePoint keychain.KeyDescriptor - - // DelayBasePoint is the base public key to be used when deriving the - // key used within the delayed pay-to-self output on the commitment - // transaction for a node. This will be combined with a tweak derived - // from the per-commitment point to ensure unique keys for each - // commitment transaction. - DelayBasePoint keychain.KeyDescriptor - - // HtlcBasePoint is the base public key to be used when deriving the - // local HTLC key. The derived key (combined with the tweak derived - // from the per-commitment point) is used within the "to self" clause - // within any HTLC output scripts. - HtlcBasePoint keychain.KeyDescriptor -} +func fetchFinalHtlcsBucketRw(tx kvdb.RwTx, + chanID lnwire.ShortChannelID) (kvdb.RwBucket, error) { -// commitTlvData stores all the optional data that may be stored as a TLV stream -// at the _end_ of the normal serialized commit on disk. -type commitTlvData struct { - // customBlob is a custom blob that may store extra data for custom - // channels. - customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] + return cstate.FetchFinalHtlcsBucketRw(tx, chanID) } -// encode encodes the aux data into the passed io.Writer. -func (c *commitTlvData) encode(w io.Writer) error { - var tlvRecords []tlv.Record - c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { - tlvRecords = append(tlvRecords, blob.Record()) - }) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream(tlvRecords...) - if err != nil { - return err - } +// MarkChannelConfirmationHeight updates the channel's confirmation height once +// the channel opening transaction receives one confirmation. +func (c *ChannelStateDB) MarkChannelConfirmationHeight(channel *OpenChannel, + height uint32) error { - return tlvStream.Encode(w) + return c.kvStore.MarkChannelConfirmationHeight(channel, height) } -// decode attempts to decode the aux data from the passed io.Reader. -func (c *commitTlvData) decode(r io.Reader) error { - blob := c.customBlob.Zero() - - tlvStream, err := tlv.NewStream( - blob.Record(), - ) - if err != nil { - return err - } +// MarkChannelCloseConfirmationHeight updates the channel's close confirmation +// height when the closing transaction is first detected in a block. +func (c *ChannelStateDB) MarkChannelCloseConfirmationHeight( + channel *OpenChannel, height fn.Option[uint32]) error { - tlvs, err := tlvStream.DecodeWithParsedTypes(r) - if err != nil { - return err - } + return c.kvStore.MarkChannelCloseConfirmationHeight(channel, height) +} - if _, ok := tlvs[c.customBlob.TlvType()]; ok { - c.customBlob = tlv.SomeRecordT(blob) - } +// MarkChannelOpen marks a channel as fully open given a locator that uniquely +// describes its location within the chain. +func (c *ChannelStateDB) MarkChannelOpen(channel *OpenChannel, + openLoc lnwire.ShortChannelID) error { - return nil + return c.kvStore.MarkChannelOpen(channel, openLoc) } -// ChannelCommitment is a snapshot of the commitment state at a particular -// point in the commitment chain. With each state transition, a snapshot of the -// current state along with all non-settled HTLCs are recorded. These snapshots -// detail the state of the _remote_ party's commitment at a particular state -// number. For ourselves (the local node) we ONLY store our most recent -// (unrevoked) state for safety purposes. -type ChannelCommitment struct { - // CommitHeight is the update number that this ChannelDelta represents - // the total number of commitment updates to this point. This can be - // viewed as sort of a "commitment height" as this number is - // monotonically increasing. - CommitHeight uint64 - - // LocalLogIndex is the cumulative log index index of the local node at - // this point in the commitment chain. This value will be incremented - // for each _update_ added to the local update log. - LocalLogIndex uint64 - - // LocalHtlcIndex is the current local running HTLC index. This value - // will be incremented for each outgoing HTLC the local node offers. - LocalHtlcIndex uint64 - - // RemoteLogIndex is the cumulative log index index of the remote node - // at this point in the commitment chain. This value will be - // incremented for each _update_ added to the remote update log. - RemoteLogIndex uint64 - - // RemoteHtlcIndex is the current remote running HTLC index. This value - // will be incremented for each outgoing HTLC the remote node offers. - RemoteHtlcIndex uint64 - - // LocalBalance is the current available settled balance within the - // channel directly spendable by us. - // - // NOTE: This is the balance *after* subtracting any commitment fee, - // AND anchor output values. - LocalBalance lnwire.MilliSatoshi - - // RemoteBalance is the current available settled balance within the - // channel directly spendable by the remote node. - // - // NOTE: This is the balance *after* subtracting any commitment fee, - // AND anchor output values. - RemoteBalance lnwire.MilliSatoshi - - // CommitFee is the amount calculated to be paid in fees for the - // current set of commitment transactions. The fee amount is persisted - // with the channel in order to allow the fee amount to be removed and - // recalculated with each channel state update, including updates that - // happen after a system restart. - CommitFee btcutil.Amount - - // FeePerKw is the min satoshis/kilo-weight that should be paid within - // the commitment transaction for the entire duration of the channel's - // lifetime. This field may be updated during normal operation of the - // channel as on-chain conditions change. - // - // TODO(halseth): make this SatPerKWeight. Cannot be done atm because - // this will cause the import cycle lnwallet<->channeldb. Fee - // estimation stuff should be in its own package. - FeePerKw btcutil.Amount - - // CommitTx is the latest version of the commitment state, broadcast - // able by us. - CommitTx *wire.MsgTx - - // CustomBlob is an optional blob that can be used to store information - // specific to a custom channel type. This may track some custom - // specific state for this given commitment. - CustomBlob fn.Option[tlv.Blob] - - // CommitSig is one half of the signature required to fully complete - // the script for the commitment transaction above. This is the - // signature signed by the remote party for our version of the - // commitment transactions. - CommitSig []byte - - // Htlcs is the set of HTLC's that are pending at this particular - // commitment height. - Htlcs []HTLC -} +// MarkChannelRealScid marks the zero-conf channel's confirmed ShortChannelID. +func (c *ChannelStateDB) MarkChannelRealScid(channel *OpenChannel, + realScid lnwire.ShortChannelID) error { -// amendTlvData updates the channel with the given auxiliary TLV data. -func (c *ChannelCommitment) amendTlvData(auxData commitTlvData) { - auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { - c.CustomBlob = fn.Some(blob) - }) + return c.kvStore.MarkChannelRealScid(channel, realScid) } -// extractTlvData creates a new commitTlvData from the given commitment. -func (c *ChannelCommitment) extractTlvData() commitTlvData { - var auxData commitTlvData - - c.CustomBlob.WhenSome(func(blob tlv.Blob) { - auxData.customBlob = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType1](blob), - ) - }) +// MarkChannelScidAliasNegotiated adds ScidAliasFeatureBit to ChanType in the +// database. +func (c *ChannelStateDB) MarkChannelScidAliasNegotiated( + channel *OpenChannel) error { - return auxData + return c.kvStore.MarkChannelScidAliasNegotiated(channel) } -// copy returns a deep copy of the channel commitment. -func (c *ChannelCommitment) copy() ChannelCommitment { - c2 := *c - if c.CommitTx != nil { - c2.CommitTx = c.CommitTx.Copy() - } - if len(c.CommitSig) > 0 { - c2.CommitSig = make([]byte, len(c.CommitSig)) - copy(c2.CommitSig, c.CommitSig) - } +// MarkChannelDataLoss marks the channel as local-data-loss and stores the +// commit point needed if the remote force closes. +func (c *ChannelStateDB) MarkChannelDataLoss(channel *OpenChannel, + commitPoint *btcec.PublicKey) error { - c.CustomBlob.WhenSome(func(blob tlv.Blob) { - blobCopy := make([]byte, len(blob)) - copy(blobCopy, blob) - c2.CustomBlob = fn.Some(blobCopy) - }) - - if len(c.Htlcs) > 0 { - c2.Htlcs = make([]HTLC, len(c.Htlcs)) - for i, h := range c.Htlcs { - c2.Htlcs[i] = h.Copy() - } - } + return c.kvStore.MarkChannelDataLoss(channel, commitPoint) +} + +// FetchChannelDataLossCommitPoint retrieves the commit point stored when the +// channel was marked as local-data-loss. +func (c *ChannelStateDB) FetchChannelDataLossCommitPoint( + channel *OpenChannel) (*btcec.PublicKey, error) { - return c2 + return c.kvStore.FetchChannelDataLossCommitPoint(channel) } -// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in -// the default usable state, or a state where it shouldn't be used. -type ChannelStatus uint64 +// MarkChannelBorked marks the channel as irreconcilable. +func (c *ChannelStateDB) MarkChannelBorked(channel *OpenChannel) error { + return c.kvStore.MarkChannelBorked(channel) +} var ( - // ChanStatusDefault is the normal state of an open channel. - ChanStatusDefault ChannelStatus + // DeriveMusig2Shachain derives a shachain producer for the taproot + // channel from normal shachain revocation root. + DeriveMusig2Shachain = cstate.DeriveMusig2Shachain - // ChanStatusBorked indicates that the channel has entered an - // irreconcilable state, triggered by a state desynchronization or - // channel breach. Channels in this state should never be added to the - // htlc switch. - ChanStatusBorked ChannelStatus = 1 + // NewMusigVerificationNonce generates the local or verification nonce + // for another musig2 session. + NewMusigVerificationNonce = cstate.NewMusigVerificationNonce +) - // ChanStatusCommitBroadcasted indicates that a commitment for this - // channel has been broadcasted. - ChanStatusCommitBroadcasted ChannelStatus = 1 << 1 +// StoreChannelShutdownInfo persists the ShutdownInfo for the target channel. +func (c *ChannelStateDB) StoreChannelShutdownInfo(channel *OpenChannel, + info *ShutdownInfo) error { - // ChanStatusLocalDataLoss indicates that we have lost channel state - // for this channel, and broadcasting our latest commitment might be - // considered a breach. - // - // TODO(halseh): actually enforce that we are not force closing such a - // channel. - ChanStatusLocalDataLoss ChannelStatus = 1 << 2 - - // ChanStatusRestored is a status flag that signals that the channel - // has been restored, and doesn't have all the fields a typical channel - // will have. - ChanStatusRestored ChannelStatus = 1 << 3 - - // ChanStatusCoopBroadcasted indicates that a cooperative close for - // this channel has been broadcasted. Older cooperatively closed - // channels will only have this status set. Newer ones will also have - // close initiator information stored using the local/remote initiator - // status. This status is set in conjunction with the initiator status - // so that we do not need to check multiple channel statues for - // cooperative closes. - ChanStatusCoopBroadcasted ChannelStatus = 1 << 4 - - // ChanStatusLocalCloseInitiator indicates that we initiated closing - // the channel. - ChanStatusLocalCloseInitiator ChannelStatus = 1 << 5 + return c.kvStore.StoreChannelShutdownInfo(channel, info) +} - // ChanStatusRemoteCloseInitiator indicates that the remote node - // initiated closing the channel. - ChanStatusRemoteCloseInitiator ChannelStatus = 1 << 6 -) +// FetchChannelShutdownInfo fetches the persisted ShutdownInfo for the target +// channel. +func (c *ChannelStateDB) FetchChannelShutdownInfo( + channel *OpenChannel) (fn.Option[ShutdownInfo], error) { -// chanStatusStrings maps a ChannelStatus to a human friendly string that -// describes that status. -var chanStatusStrings = map[ChannelStatus]string{ - ChanStatusDefault: "ChanStatusDefault", - ChanStatusBorked: "ChanStatusBorked", - ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", - ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", - ChanStatusRestored: "ChanStatusRestored", - ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted", - ChanStatusLocalCloseInitiator: "ChanStatusLocalCloseInitiator", - ChanStatusRemoteCloseInitiator: "ChanStatusRemoteCloseInitiator", + return c.kvStore.FetchChannelShutdownInfo(channel) } -// orderedChanStatusFlags is an in-order list of all that channel status flags. -var orderedChanStatusFlags = []ChannelStatus{ - ChanStatusBorked, - ChanStatusCommitBroadcasted, - ChanStatusLocalDataLoss, - ChanStatusRestored, - ChanStatusCoopBroadcasted, - ChanStatusLocalCloseInitiator, - ChanStatusRemoteCloseInitiator, +// MarkChannelCommitmentBroadcasted marks the channel as having a commitment +// transaction broadcast. +func (c *ChannelStateDB) MarkChannelCommitmentBroadcasted( + channel *OpenChannel, closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + return c.kvStore.MarkChannelCommitmentBroadcasted( + channel, closeTx, closer, + ) } -// String returns a human-readable representation of the ChannelStatus. -func (c ChannelStatus) String() string { - // If no flags are set, then this is the default case. - if c == ChanStatusDefault { - return chanStatusStrings[ChanStatusDefault] - } +// MarkChannelCoopBroadcasted marks the channel as having a cooperative close +// transaction broadcast. +func (c *ChannelStateDB) MarkChannelCoopBroadcasted(channel *OpenChannel, + closeTx *wire.MsgTx, closer lntypes.ChannelParty) error { - // Add individual bit flags. - statusStr := "" - for _, flag := range orderedChanStatusFlags { - if c&flag == flag { - statusStr += chanStatusStrings[flag] + "|" - c -= flag - } - } + return c.kvStore.MarkChannelCoopBroadcasted(channel, closeTx, closer) +} - // Remove anything to the right of the final bar, including it as well. - statusStr = strings.TrimRight(statusStr, "|") +// FetchChannelBroadcastedCommitment fetches the stored unilateral closing +// transaction. +func (c *ChannelStateDB) FetchChannelBroadcastedCommitment( + channel *OpenChannel) (*wire.MsgTx, error) { - // Add any remaining flags which aren't accounted for as hex. - if c != 0 { - statusStr += "|0x" + strconv.FormatUint(uint64(c), 16) - } + return c.kvStore.FetchChannelBroadcastedCommitment(channel) +} - // If this was purely an unknown flag, then remove the extra bar at the - // start of the string. - statusStr = strings.TrimLeft(statusStr, "|") +// FetchChannelBroadcastedCooperative fetches the stored cooperative closing +// transaction. +func (c *ChannelStateDB) FetchChannelBroadcastedCooperative( + channel *OpenChannel) (*wire.MsgTx, error) { - return statusStr + return c.kvStore.FetchChannelBroadcastedCooperative(channel) } -// FinalHtlcByte defines a byte type that encodes information about the final -// htlc resolution. -type FinalHtlcByte byte +// ApplyChannelStatus adds the target status to the channel's persisted status +// bit field. +func (c *ChannelStateDB) ApplyChannelStatus(channel *OpenChannel, + status ChannelStatus) error { -const ( - // FinalHtlcSettledBit is the bit that encodes whether the htlc was - // settled or failed. - FinalHtlcSettledBit FinalHtlcByte = 1 << 0 + return c.kvStore.ApplyChannelStatus(channel, status) +} - // FinalHtlcOffchainBit is the bit that encodes whether the htlc was - // resolved offchain or onchain. - FinalHtlcOffchainBit FinalHtlcByte = 1 << 1 -) +// ClearChannelStatus clears the target status from the channel's persisted +// status bit field. +func (c *ChannelStateDB) ClearChannelStatus(channel *OpenChannel, + status ChannelStatus) error { -// OpenChannel encapsulates the persistent and dynamic state of an open channel -// with a remote node. An open channel supports several options for on-disk -// serialization depending on the exact context. Full (upon channel creation) -// state commitments, and partial (due to a commitment update) writes are -// supported. Each partial write due to a state update appends the new update -// to an on-disk log, which can then subsequently be queried in order to -// "time-travel" to a prior state. -type OpenChannel struct { - // ChanType denotes which type of channel this is. - ChanType ChannelType - - // ChainHash is a hash which represents the blockchain that this - // channel will be opened within. This value is typically the genesis - // hash. In the case that the original chain went through a contentious - // hard-fork, then this value will be tweaked using the unique fork - // point on each branch. - ChainHash chainhash.Hash - - // FundingOutpoint is the outpoint of the final funding transaction. - // This value uniquely and globally identifies the channel within the - // target blockchain as specified by the chain hash parameter. - FundingOutpoint wire.OutPoint - - // ShortChannelID encodes the exact location in the chain in which the - // channel was initially confirmed. This includes: the block height, - // transaction index, and the output within the target transaction. - // - // If IsZeroConf(), then this will the "base" (very first) ALIAS scid - // and the confirmed SCID will be stored in ConfirmedScid. - ShortChannelID lnwire.ShortChannelID - - // IsPending indicates whether a channel's funding transaction has been - // confirmed. - IsPending bool - - // IsInitiator is a bool which indicates if we were the original - // initiator for the channel. This value may affect how higher levels - // negotiate fees, or close the channel. - IsInitiator bool - - // chanStatus is the current status of this channel. If it is not in - // the state Default, it should not be used for forwarding payments. - chanStatus ChannelStatus - - // FundingBroadcastHeight is the height in which the funding - // transaction was broadcast. This value can be used by higher level - // sub-systems to determine if a channel is stale and/or should have - // been confirmed before a certain height. - FundingBroadcastHeight uint32 - - // ConfirmationHeight records the block height at which the funding - // transaction was first confirmed. - ConfirmationHeight uint32 - - // CloseConfirmationHeight records the block height at which the closing - // transaction was first confirmed. This is used to track remaining - // confirmations until the channel is considered fully closed. It is - // None if the closing transaction has not yet been confirmed, or if - // this data was not available (e.g. channels closed before this - // field was introduced). - CloseConfirmationHeight fn.Option[uint32] - - // NumConfsRequired is the number of confirmations a channel's funding - // transaction must have received in order to be considered available - // for normal transactional use. - NumConfsRequired uint16 - - // ChannelFlags holds the flags that were sent as part of the - // open_channel message. - ChannelFlags lnwire.FundingFlag - - // IdentityPub is the identity public key of the remote node this - // channel has been established with. - IdentityPub *btcec.PublicKey - - // Capacity is the total capacity of this channel. - Capacity btcutil.Amount - - // TotalMSatSent is the total number of milli-satoshis we've sent - // within this channel. - TotalMSatSent lnwire.MilliSatoshi - - // TotalMSatReceived is the total number of milli-satoshis we've - // received within this channel. - TotalMSatReceived lnwire.MilliSatoshi - - // InitialLocalBalance is the balance we have during the channel - // opening. When we are not the initiator, this value represents the - // push amount. - InitialLocalBalance lnwire.MilliSatoshi - - // InitialRemoteBalance is the balance they have during the channel - // opening. - InitialRemoteBalance lnwire.MilliSatoshi - - // LocalChanCfg is the channel configuration for the local node. - LocalChanCfg ChannelConfig - - // RemoteChanCfg is the channel configuration for the remote node. - RemoteChanCfg ChannelConfig - - // LocalCommitment is the current local commitment state for the local - // party. This is stored distinct from the state of the remote party - // as there are certain asymmetric parameters which affect the - // structure of each commitment. - LocalCommitment ChannelCommitment - - // RemoteCommitment is the current remote commitment state for the - // remote party. This is stored distinct from the state of the local - // party as there are certain asymmetric parameters which affect the - // structure of each commitment. - RemoteCommitment ChannelCommitment - - // RemoteCurrentRevocation is the current revocation for their - // commitment transaction. However, since this the derived public key, - // we don't yet have the private key so we aren't yet able to verify - // that it's actually in the hash chain. - RemoteCurrentRevocation *btcec.PublicKey - - // RemoteNextRevocation is the revocation key to be used for the *next* - // commitment transaction we create for the local node. Within the - // specification, this value is referred to as the - // per-commitment-point. - RemoteNextRevocation *btcec.PublicKey - - // RevocationProducer is used to generate the revocation in such a way - // that remote side might store it efficiently and have the ability to - // restore the revocation by index if needed. Current implementation of - // secret producer is shachain producer. - RevocationProducer shachain.Producer - - // RevocationStore is used to efficiently store the revocations for - // previous channels states sent to us by remote side. Current - // implementation of secret store is shachain store. - RevocationStore shachain.Store - - // Packager is used to create and update forwarding packages for this - // channel, which encodes all necessary information to recover from - // failures and reforward HTLCs that were not fully processed. - Packager FwdPackager - - // FundingTxn is the transaction containing this channel's funding - // outpoint. Upon restarts, this txn will be rebroadcast if the channel - // is found to be pending. - // - // NOTE: This value will only be populated for single-funder channels - // for which we are the initiator, and that we also have the funding - // transaction for. One can check this by using the HasFundingTx() - // method on the ChanType field. - FundingTxn *wire.MsgTx - - // LocalShutdownScript is set to a pre-set script if the channel was opened - // by the local node with option_upfront_shutdown_script set. If the option - // was not set, the field is empty. - LocalShutdownScript lnwire.DeliveryAddress - - // RemoteShutdownScript is set to a pre-set script if the channel was opened - // by the remote node with option_upfront_shutdown_script set. If the option - // was not set, the field is empty. - RemoteShutdownScript lnwire.DeliveryAddress - - // ThawHeight is the height when a frozen channel once again becomes a - // normal channel. If this is zero, then there're no restrictions on - // this channel. If the value is lower than 500,000, then it's - // interpreted as a relative height, or an absolute height otherwise. - ThawHeight uint32 - - // LastWasRevoke is a boolean that determines if the last update we sent - // was a revocation (true) or a commitment signature (false). - LastWasRevoke bool - - // RevocationKeyLocator stores the KeyLocator information that we will - // need to derive the shachain root for this channel. This allows us to - // have private key isolation from lnd. - RevocationKeyLocator keychain.KeyLocator - - // confirmedScid is the confirmed ShortChannelID for a zero-conf - // channel. If the channel is unconfirmed, then this will be the - // default ShortChannelID. This is only set for zero-conf channels. - confirmedScid lnwire.ShortChannelID - - // Memo is any arbitrary information we wish to store locally about the - // channel that will be useful to our future selves. - Memo []byte - - // TapscriptRoot is an optional tapscript root used to derive the MuSig2 - // funding output. - TapscriptRoot fn.Option[chainhash.Hash] - - // CustomBlob is an optional blob that can be used to store information - // specific to a custom channel type. This information is only created - // at channel funding time, and after wards is to be considered - // immutable. - CustomBlob fn.Option[tlv.Blob] - - // TODO(roasbeef): eww - Db *ChannelStateDB - - // TODO(roasbeef): just need to store local and remote HTLC's? - - sync.RWMutex + return c.kvStore.ClearChannelStatus(channel, status) } -// String returns a string representation of the channel. -func (c *OpenChannel) String() string { - indexStr := "height=%v, local_htlc_index=%v, local_log_index=%v, " + - "remote_htlc_index=%v, remote_log_index=%v" +// SyncPendingChannel writes a pending channel to the store and records the +// funding broadcast height. +func (c *ChannelStateDB) SyncPendingChannel(channel *OpenChannel, + addr net.Addr, pendingHeight uint32) error { - commit := c.LocalCommitment - local := fmt.Sprintf(indexStr, commit.CommitHeight, - commit.LocalHtlcIndex, commit.LocalLogIndex, - commit.RemoteHtlcIndex, commit.RemoteLogIndex, - ) + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { + return syncNewChannel( + tx, channel, []net.Addr{addr}, c.backend, + pendingHeight, + ) + }, func() {}) +} - commit = c.RemoteCommitment - remote := fmt.Sprintf(indexStr, commit.CommitHeight, - commit.LocalHtlcIndex, commit.LocalLogIndex, - commit.RemoteHtlcIndex, commit.RemoteLogIndex, - ) +// syncNewChannel will write the passed channel to disk, and also create a +// LinkNode (if needed) for the channel peer. +func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr, + backend kvdb.Backend, pendingHeight uint32) error { - return fmt.Sprintf("SCID=%v, status=%v, initiator=%v, pending=%v, "+ - "local commitment has %s, remote commitment has %s", - c.ShortChannelID, c.chanStatus, c.IsInitiator, c.IsPending, - local, remote, - ) -} + // First, sync all the persistent channel state to disk. + err := cstate.SyncPendingOpenChannel(tx, c, pendingHeight) + if err != nil { + return err + } -// Initiator returns the ChannelParty that originally opened this channel. -func (c *OpenChannel) Initiator() lntypes.ChannelParty { - c.RLock() - defer c.RUnlock() + nodeInfoBucket, err := tx.CreateTopLevelBucket(nodeInfoBucket) + if err != nil { + return err + } - if c.IsInitiator { - return lntypes.Local + // If a LinkNode for this identity public key already exists, + // then we can exit early. + nodePub := c.IdentityPub.SerializeCompressed() + if nodeInfoBucket.Get(nodePub) != nil { + return nil } - return lntypes.Remote -} + // Next, we need to establish a (possibly) new LinkNode relationship + // for this channel. The LinkNode metadata contains reachability, + // up-time, and service bits related information. + linkNode := NewLinkNode( + &LinkNodeDB{backend: backend}, wire.MainNet, c.IdentityPub, + addrs..., + ) -// ShortChanID returns the current ShortChannelID of this channel. -func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { - c.RLock() - defer c.RUnlock() + // TODO(roasbeef): do away with link node all together? - return c.ShortChannelID + return putLinkNode(nodeInfoBucket, linkNode) } -// ZeroConfRealScid returns the zero-conf channel's confirmed scid. This should -// only be called if IsZeroConf returns true. -func (c *OpenChannel) ZeroConfRealScid() lnwire.ShortChannelID { - c.RLock() - defer c.RUnlock() +// UpdateChannelCommitment updates the local commitment state. +func (c *ChannelStateDB) UpdateChannelCommitment(channel *OpenChannel, + newCommitment *ChannelCommitment, + unsignedAckedUpdates []LogUpdate) (map[uint64]bool, error) { - return c.confirmedScid + return c.kvStore.UpdateChannelCommitment( + channel, newCommitment, unsignedAckedUpdates, + ) } -// ZeroConfConfirmed returns whether the zero-conf channel has confirmed. This -// should only be called if IsZeroConf returns true. -func (c *OpenChannel) ZeroConfConfirmed() bool { - c.RLock() - defer c.RUnlock() +// SerializeHtlcs writes out the passed set of HTLC's into the passed writer +// using the current default on-disk serialization format. +// +// This inline serialization has been extended to allow storage of extra data +// associated with a HTLC in the following way: +// - The known-length onion blob (1366 bytes) is serialized as var bytes in +// WriteElements (ie, the length 1366 was written, followed by the 1366 +// onion bytes). +// - To include extra data, we append any extra data present to this one +// variable length of data. Since we know that the onion is strictly 1366 +// bytes, any length after that should be considered to be extra data. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { + return cstate.SerializeHtlcs(b, htlcs...) +} - return c.confirmedScid != hop.Source +// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed +// io.Reader. The bytes within the passed reader MUST have been previously +// written to using the SerializeHtlcs function. +// +// This inline deserialization has been extended to allow storage of extra data +// associated with a HTLC in the following way: +// - The known-length onion blob (1366 bytes) and any additional data present +// are read out as a single blob of variable byte data. +// - They are stored like this to take advantage of the variable space +// available for extension without migration (see SerializeHtlcs). +// - The first 1366 bytes are interpreted as the onion blob, and any remaining +// bytes as extra HTLC data. +// - This extra HTLC data is expected to be serialized as a TLV stream, and +// its parsing is left to higher layers. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { + return cstate.DeserializeHtlcs(r) } -// IsZeroConf returns whether the option_zeroconf channel type was negotiated. -func (c *OpenChannel) IsZeroConf() bool { - c.RLock() - defer c.RUnlock() +// AppendRemoteCommitChain appends a new CommitDiff to the remote party's +// commitment chain. +func (c *ChannelStateDB) AppendRemoteCommitChain(channel *OpenChannel, + diff *CommitDiff) error { - return c.ChanType.HasZeroConf() + return c.kvStore.AppendRemoteCommitChain(channel, diff) } -// IsOptionScidAlias returns whether the option_scid_alias channel type was -// negotiated. -func (c *OpenChannel) IsOptionScidAlias() bool { - c.RLock() - defer c.RUnlock() +// RemoteCommitChainTip returns the "tip" of the current remote commitment +// chain. +func (c *ChannelStateDB) RemoteCommitChainTip(channel *OpenChannel) ( + *CommitDiff, error) { - return c.ChanType.HasScidAliasChan() + return c.kvStore.RemoteCommitChainTip(channel) } -// NegotiatedAliasFeature returns whether the option-scid-alias feature bit was -// negotiated. -func (c *OpenChannel) NegotiatedAliasFeature() bool { - c.RLock() - defer c.RUnlock() +// UnsignedAckedUpdates retrieves the persisted unsigned acked remote log +// updates that still need to be signed for. +func (c *ChannelStateDB) UnsignedAckedUpdates(channel *OpenChannel) ( + []LogUpdate, error) { - return c.ChanType.HasScidAliasFeature() + return c.kvStore.UnsignedAckedUpdates(channel) } -// ChanStatus returns the current ChannelStatus of this channel. -func (c *OpenChannel) ChanStatus() ChannelStatus { - c.RLock() - defer c.RUnlock() +// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log +// updates that the remote still needs to sign for. +func (c *ChannelStateDB) RemoteUnsignedLocalUpdates(channel *OpenChannel) ( + []LogUpdate, error) { - return c.chanStatus + return c.kvStore.RemoteUnsignedLocalUpdates(channel) } -// ApplyChanStatus allows the caller to modify the internal channel state in a -// thead-safe manner. -func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() +// InsertNextRevocation inserts the next commitment point into the persisted +// channel state. +func (c *ChannelStateDB) InsertNextRevocation(channel *OpenChannel, + revKey *btcec.PublicKey) error { - return c.putChanStatus(status) + return c.kvStore.InsertNextRevocation(channel, revKey) } -// ClearChanStatus allows the caller to clear a particular channel status from -// the primary channel status bit field. After this method returns, a call to -// HasChanStatus(status) should return false. -func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() +// AdvanceCommitChainTail records the new state transition within the +// revocation log and promotes the pending remote commitment to the current +// remote commitment. +func (c *ChannelStateDB) AdvanceCommitChainTail(channel *OpenChannel, + fwdPkg *FwdPkg, updates []LogUpdate, ourOutputIndex, + theirOutputIndex uint32) error { - return c.clearChanStatus(status) + return c.kvStore.AdvanceCommitChainTail( + channel, fwdPkg, updates, ourOutputIndex, theirOutputIndex, + ) } -// HasChanStatus returns true if the internal bitfield channel status of the -// target channel has the specified status bit set. -func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { - c.RLock() - defer c.RUnlock() +// FinalHtlcInfo contains information about the final outcome of an htlc. +type FinalHtlcInfo = cstate.FinalHtlcInfo + +// putFinalHtlc writes the final htlc outcome to the database. Additionally it +// records whether the htlc was resolved off-chain or on-chain. +func putFinalHtlc(finalHtlcsBucket kvdb.RwBucket, id uint64, + info FinalHtlcInfo) error { - return c.hasChanStatus(status) + return cstate.PutFinalHtlc(finalHtlcsBucket, id, info) } -func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { - // Special case ChanStatusDefualt since it isn't actually flag, but a - // particular combination (or lack-there-of) of flags. - if status == ChanStatusDefault { - return c.chanStatus == ChanStatusDefault - } +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (c *ChannelStateDB) LoadFwdPkgs(channel *OpenChannel) ([]*FwdPkg, + error) { - return c.chanStatus&status == status + return c.kvStore.LoadFwdPkgs(channel) } -// BroadcastHeight returns the height at which the funding tx was broadcast. -func (c *OpenChannel) BroadcastHeight() uint32 { - c.RLock() - defer c.RUnlock() +// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs +// indicating that a response to this Add has been committed to the remote party. +// Doing so will prevent these Add HTLCs from being reforwarded internally. +// +//nolint:ll +func (c *ChannelStateDB) AckAddHtlcs(channel *OpenChannel, + addRefs ...AddRef) error { - return c.FundingBroadcastHeight + return c.kvStore.AckAddHtlcs(channel, addRefs...) } -// SetBroadcastHeight sets the FundingBroadcastHeight. -func (c *OpenChannel) SetBroadcastHeight(height uint32) { - c.Lock() - defer c.Unlock() +// AckSettleFails updates the SettleFailFilter containing any of the provided +// SettleFailRefs, indicating that the response has been delivered to the +// incoming link, corresponding to a particular AddRef. Doing so will prevent +// the responses from being retransmitted internally. +func (c *ChannelStateDB) AckSettleFails(channel *OpenChannel, + settleFailRefs ...SettleFailRef) error { - c.FundingBroadcastHeight = height + return c.kvStore.AckSettleFails(channel, settleFailRefs...) } -// amendTlvData updates the channel with the given auxiliary TLV data. -func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { - c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator - c.InitialLocalBalance = lnwire.MilliSatoshi( - auxData.initialLocalBalance.Val, - ) - c.InitialRemoteBalance = lnwire.MilliSatoshi( - auxData.initialRemoteBalance.Val, - ) - c.confirmedScid = auxData.realScid.Val - c.ConfirmationHeight = auxData.confirmationHeight.Val - - auxData.memo.WhenSomeV(func(memo []byte) { - c.Memo = memo - }) - auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { - c.TapscriptRoot = fn.Some[chainhash.Hash](h) - }) - auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { - c.CustomBlob = fn.Some(blob) - }) - auxData.closeConfirmationHeight.WhenSomeV(func(h uint32) { - c.CloseConfirmationHeight = fn.Some(h) - }) -} +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (c *ChannelStateDB) SetFwdFilter(channel *OpenChannel, height uint64, + fwdFilter *PkgFilter) error { -// extractTlvData creates a new openChannelTlvData from the given channel. -func (c *OpenChannel) extractTlvData() openChannelTlvData { - auxData := openChannelTlvData{ - revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( - keyLocRecord{c.RevocationKeyLocator}, - ), - initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( - uint64(c.InitialLocalBalance), - ), - initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( - uint64(c.InitialRemoteBalance), - ), - realScid: tlv.NewRecordT[tlv.TlvType4]( - c.confirmedScid, - ), - confirmationHeight: tlv.NewPrimitiveRecord[tlv.TlvType8]( - c.ConfirmationHeight, - ), - } + return c.kvStore.SetFwdFilter(channel, height, fwdFilter) +} - if len(c.Memo) != 0 { - auxData.memo = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), - ) - } - c.TapscriptRoot.WhenSome(func(h chainhash.Hash) { - auxData.tapscriptRoot = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), - ) - }) - c.CustomBlob.WhenSome(func(blob tlv.Blob) { - auxData.customBlob = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType7](blob), - ) - }) - c.CloseConfirmationHeight.WhenSome(func(h uint32) { - auxData.closeConfirmationHeight = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType9](h), - ) - }) +// RemoveFwdPkgs atomically removes forwarding packages specified by the remote +// commitment heights. If one of the intermediate RemovePkg calls fails, then the +// later packages won't be removed. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +// +//nolint:ll +func (c *ChannelStateDB) RemoveFwdPkgs(channel *OpenChannel, + heights ...uint64) error { - return auxData + return c.kvStore.RemoveFwdPkgs(channel, heights...) } -// Refresh updates the in-memory channel state using the latest state observed -// on disk. -func (c *OpenChannel) Refresh() error { - c.Lock() - defer c.Unlock() +// revocationLogTailCommitHeight returns the commit height at the end of the +// revocation log. +func (c *ChannelStateDB) revocationLogTailCommitHeight( + channel *OpenChannel) (uint64, error) { - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // We'll re-populating the in-memory channel with the info - // fetched from disk. - if err := fetchChanInfo(chanBucket, c); err != nil { - return fmt.Errorf("unable to fetch chan info: %w", err) - } - - // Also populate the channel's commitment states for both sides - // of the channel. - if err := fetchChanCommitments(chanBucket, c); err != nil { - return fmt.Errorf("unable to fetch chan commitments: "+ - "%v", err) - } - - // Also retrieve the current revocation state. - if err := fetchChanRevocationState(chanBucket, c); err != nil { - return fmt.Errorf("unable to fetch chan revocations: "+ - "%v", err) - } + return c.kvStore.RevocationLogTailCommitHeight(channel) +} - return nil - }, func() {}) - if err != nil { - return err - } - - return nil -} - -// fetchChanBucket is a helper function that returns the bucket where a -// channel's data resides in given: the public key for the node, the outpoint, -// and the chainhash that the channel resides on. -func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, - outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RBucket, error) { - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket := tx.ReadBucket(openChannelBucket) - if openChanBucket == nil { - return nil, ErrNoChanDBExists - } - - // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like - // CreateIfNotExists, will return error - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := nodeKey.SerializeCompressed() - nodeChanBucket := openChanBucket.NestedReadBucket(nodePub) - if nodeChanBucket == nil { - return nil, ErrNoActiveChannels - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket := nodeChanBucket.NestedReadBucket(chainHash[:]) - if chainBucket == nil { - return nil, ErrNoActiveChannels - } - - // With the bucket for the node and chain fetched, we can now go down - // another level, for this channel itself. - var chanPointBuf bytes.Buffer - if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, err - } - chanKey := chanPointBuf.Bytes() - - // Treat already-closed channels as gone. The chanBucket may still - // exist on tombstone-enabled backends; the outpoint flip is the - // source of truth. - closed, err := isOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) - if err != nil { - return nil, err - } - if closed { - return nil, ErrChannelNotFound - } - - chanBucket := chainBucket.NestedReadBucket(chanKey) - if chanBucket == nil { - return nil, ErrChannelNotFound - } - - return chanBucket, nil -} - -// fetchChanBucketRw is a helper function that returns the bucket where a -// channel's data resides in given: the public key for the node, the outpoint, -// and the chainhash that the channel resides on. This differs from -// fetchChanBucket in that it returns a writeable bucket. -func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, - outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, - error) { - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket := tx.ReadWriteBucket(openChannelBucket) - if openChanBucket == nil { - return nil, ErrNoChanDBExists - } - - // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like - // CreateIfNotExists, will return error - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := nodeKey.SerializeCompressed() - nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) - if nodeChanBucket == nil { - return nil, ErrNoActiveChannels - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:]) - if chainBucket == nil { - return nil, ErrNoActiveChannels - } - - // With the bucket for the node and chain fetched, we can now go down - // another level, for this channel itself. - var chanPointBuf bytes.Buffer - if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, err - } - chanKey := chanPointBuf.Bytes() - - // Treat already-closed channels as gone. The chanBucket may still - // exist on tombstone-enabled backends; the outpoint flip is the - // source of truth. - closed, err := isOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) - if err != nil { - return nil, err - } - if closed { - return nil, ErrChannelNotFound - } - - chanBucket := chainBucket.NestedReadWriteBucket(chanKey) - if chanBucket == nil { - return nil, ErrChannelNotFound - } - - return chanBucket, nil -} - -func fetchFinalHtlcsBucketRw(tx kvdb.RwTx, - chanID lnwire.ShortChannelID) (kvdb.RwBucket, error) { - - finalHtlcsBucket, err := tx.CreateTopLevelBucket(finalHtlcsBucket) - if err != nil { - return nil, err - } - - var chanIDBytes [8]byte - byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64()) - chanBucket, err := finalHtlcsBucket.CreateBucketIfNotExists( - chanIDBytes[:], - ) - if err != nil { - return nil, err - } - - return chanBucket, nil -} - -// fullSync syncs the contents of an OpenChannel while re-using an existing -// database transaction. -func (c *OpenChannel) fullSync(tx kvdb.RwTx) error { - // Fetch the outpoint bucket and check if the outpoint already exists. - opBucket := tx.ReadWriteBucket(outpointBucket) - if opBucket == nil { - return ErrNoChanDBExists - } - cidBucket := tx.ReadWriteBucket(chanIDBucket) - if cidBucket == nil { - return ErrNoChanDBExists - } - - var chanPointBuf bytes.Buffer - err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) - if err != nil { - return err - } - - // Now, check if the outpoint exists in our index. - if opBucket.Get(chanPointBuf.Bytes()) != nil { - return ErrChanAlreadyExists - } - - cid := lnwire.NewChanIDFromOutPoint(c.FundingOutpoint) - if cidBucket.Get(cid[:]) != nil { - return ErrChanAlreadyExists - } - - status := uint8(outpointOpen) - - // Write the status of this outpoint as the first entry in a tlv - // stream. - statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status) - opStream, err := tlv.NewStream(statusRecord) - if err != nil { - return err - } - - var b bytes.Buffer - if err := opStream.Encode(&b); err != nil { - return err - } - - // Add the outpoint to our outpoint index with the tlv stream. - if err := opBucket.Put(chanPointBuf.Bytes(), b.Bytes()); err != nil { - return err - } - - if err := cidBucket.Put(cid[:], []byte{}); err != nil { - return err - } - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) - if err != nil { - return err - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := c.IdentityPub.SerializeCompressed() - nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) - if err != nil { - return err - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:]) - if err != nil { - return err - } - - // With the bucket for the node fetched, we can now go down another - // level, creating the bucket for this channel itself. - chanBucket, err := chainBucket.CreateBucket( - chanPointBuf.Bytes(), - ) - switch { - case err == kvdb.ErrBucketExists: - // If this channel already exists, then in order to avoid - // overriding it, we'll return an error back up to the caller. - return ErrChanAlreadyExists - case err != nil: - return err - } - - return putOpenChannel(chanBucket, c) -} - -// MarkConfirmationHeight updates the channel's confirmation height once the -// channel opening transaction receives one confirmation. -func (c *OpenChannel) MarkConfirmationHeight(height uint32) error { - c.Lock() - defer c.Unlock() - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - channel.ConfirmationHeight = height - - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - c.ConfirmationHeight = height - - return nil -} - -// ResetCloseConfirmationHeight clears the channel's close confirmation height -// when the spending transaction is reorged out. -func (c *OpenChannel) ResetCloseConfirmationHeight() error { - return c.MarkCloseConfirmationHeight(fn.None[uint32]()) -} - -// MarkCloseConfirmationHeight updates the channel's close confirmation height -// when the closing transaction is first detected in a block (spend height). -func (c *OpenChannel) MarkCloseConfirmationHeight( - height fn.Option[uint32]) error { - - c.Lock() - defer c.Unlock() - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - channel.CloseConfirmationHeight = height - - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - c.CloseConfirmationHeight = height - - return nil -} - -// MarkAsOpen marks a channel as fully open given a locator that uniquely -// describes its location within the chain. -func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { - c.Lock() - defer c.Unlock() - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - channel.IsPending = false - channel.ShortChannelID = openLoc - - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - c.IsPending = false - c.ShortChannelID = openLoc - c.Packager = NewChannelPackager(openLoc) - - return nil -} - -// MarkRealScid marks the zero-conf channel's confirmed ShortChannelID. This -// should only be done if IsZeroConf returns true. -func (c *OpenChannel) MarkRealScid(realScid lnwire.ShortChannelID) error { - c.Lock() - defer c.Unlock() - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel( - chanBucket, &c.FundingOutpoint, - ) - if err != nil { - return err - } - - channel.confirmedScid = realScid - - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - c.confirmedScid = realScid - - return nil -} - -// MarkScidAliasNegotiated adds ScidAliasFeatureBit to ChanType in-memory and -// in the database. -func (c *OpenChannel) MarkScidAliasNegotiated() error { - c.Lock() - defer c.Unlock() - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel( - chanBucket, &c.FundingOutpoint, - ) - if err != nil { - return err - } - - channel.ChanType |= ScidAliasFeatureBit - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - c.ChanType |= ScidAliasFeatureBit - - return nil -} - -// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the -// passed commitPoint for use to retrieve funds in case the remote force closes -// the channel. -func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - var b bytes.Buffer - if err := WriteElement(&b, commitPoint); err != nil { - return err - } - - putCommitPoint := func(chanBucket kvdb.RwBucket) error { - return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) - } - - return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) -} - -// DataLossCommitPoint retrieves the stored commit point set during -// MarkDataLoss. If not found ErrNoCommitPoint is returned. -func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { - var commitPoint *btcec.PublicKey - - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCommitPoint - default: - return err - } - - bs := chanBucket.Get(dataLossCommitPointKey) - if bs == nil { - return ErrNoCommitPoint - } - r := bytes.NewReader(bs) - if err := ReadElements(r, &commitPoint); err != nil { - return err - } - - return nil - }, func() { - commitPoint = nil - }) - if err != nil { - return nil, err - } - - return commitPoint, nil -} - -// MarkBorked marks the event when the channel as reached an irreconcilable -// state, such as a channel breach or state desynchronization. Borked channels -// should never be added to the switch. -func (c *OpenChannel) MarkBorked() error { - c.Lock() - defer c.Unlock() - - return c.putChanStatus(ChanStatusBorked) -} - -// SecondCommitmentPoint returns the second per-commitment-point for use in the -// channel_ready message. -func (c *OpenChannel) SecondCommitmentPoint() (*btcec.PublicKey, error) { - c.RLock() - defer c.RUnlock() - - // Since we start at commitment height = 0, the second per commitment - // point is actually at the 1st index. - revocation, err := c.RevocationProducer.AtIndex(1) - if err != nil { - return nil, err - } - - return input.ComputeCommitmentPoint(revocation[:]), nil -} - -var ( - // taprootRevRootKey is the key used to derive the revocation root for - // the taproot nonces. This is done via HMAC of the existing revocation - // root. - taprootRevRootKey = []byte("taproot-rev-root") -) - -// DeriveMusig2Shachain derives a shachain producer for the taproot channel -// from normal shachain revocation root. -func DeriveMusig2Shachain(revRoot shachain.Producer) (shachain.Producer, error) { //nolint:ll - // In order to obtain the revocation root hash to create the taproot - // revocation, we'll encode the producer into a buffer, then use that - // to derive the shachain root needed. - var rootHashBuf bytes.Buffer - if err := revRoot.Encode(&rootHashBuf); err != nil { - return nil, fmt.Errorf("unable to encode producer: %w", err) - } - - revRootHash := chainhash.HashH(rootHashBuf.Bytes()) - - // For taproot channel types, we'll also generate a distinct shachain - // root using the same seed information. We'll use this to generate - // verification nonces for the channel. We'll bind with this a simple - // hmac. - taprootRevHmac := hmac.New(sha256.New, taprootRevRootKey) - if _, err := taprootRevHmac.Write(revRootHash[:]); err != nil { - return nil, err - } - - taprootRevRoot := taprootRevHmac.Sum(nil) - - // Once we have the root, we can then generate our shachain producer - // and from that generate the per-commitment point. - return shachain.NewRevocationProducerFromBytes( - taprootRevRoot, - ) -} - -// NewMusigVerificationNonce generates the local or verification nonce for -// another musig2 session. In order to permit our implementation to not have to -// write any secret nonce state to disk, we'll use the _next_ shachain -// pre-image as our primary randomness source. When used to generate the nonce -// again to broadcast our commitment hte current height will be used. -func NewMusigVerificationNonce(pubKey *btcec.PublicKey, targetHeight uint64, - shaGen shachain.Producer) (*musig2.Nonces, error) { - - // Now that we know what height we need, we'll grab the shachain - // pre-image at the target destination. - nextPreimage, err := shaGen.AtIndex(targetHeight) - if err != nil { - return nil, err - } - - shaChainRand := musig2.WithCustomRand(bytes.NewBuffer(nextPreimage[:])) - pubKeyOpt := musig2.WithPublicKey(pubKey) - - return musig2.GenNonces(pubKeyOpt, shaChainRand) -} - -// ChanSyncMsg returns the ChannelReestablish message that should be sent upon -// reconnection with the remote peer that we're maintaining this channel with. -// The information contained within this message is necessary to re-sync our -// commitment chains in the case of a last or only partially processed message. -// When the remote party receives this message one of three things may happen: -// -// 1. We're fully synced and no messages need to be sent. -// 2. We didn't get the last CommitSig message they sent, so they'll re-send -// it. -// 3. We didn't get the last RevokeAndAck message they sent, so they'll -// re-send it. -// -// If this is a restored channel, having status ChanStatusRestored, then we'll -// modify our typical chan sync message to ensure they force close even if -// we're on the very first state. -func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { - - c.Lock() - defer c.Unlock() - - // The remote commitment height that we'll send in the - // ChannelReestablish message is our current commitment height plus - // one. If the receiver thinks that our commitment height is actually - // *equal* to this value, then they'll re-send the last commitment that - // they sent but we never fully processed. - localHeight := c.LocalCommitment.CommitHeight - nextLocalCommitHeight := localHeight + 1 - - // The second value we'll send is the height of the remote commitment - // from our PoV. If the receiver thinks that their height is actually - // *one plus* this value, then they'll re-send their last revocation. - remoteChainTipHeight := c.RemoteCommitment.CommitHeight - - // If this channel has undergone a commitment update, then in order to - // prove to the remote party our knowledge of their prior commitment - // state, we'll also send over the last commitment secret that the - // remote party sent. - var lastCommitSecret [32]byte - if remoteChainTipHeight != 0 { - remoteSecret, err := c.RevocationStore.LookUp( - remoteChainTipHeight - 1, - ) - if err != nil { - return nil, err - } - lastCommitSecret = [32]byte(*remoteSecret) - } - - // Additionally, we'll send over the current unrevoked commitment on - // our local commitment transaction. - currentCommitSecret, err := c.RevocationProducer.AtIndex( - localHeight, - ) - if err != nil { - return nil, err - } - - // If we've restored this channel, then we'll purposefully give them an - // invalid LocalUnrevokedCommitPoint so they'll force close the channel - // allowing us to sweep our funds. - if c.hasChanStatus(ChanStatusRestored) { - currentCommitSecret[0] ^= 1 - - // If this is a tweakless channel, then we'll purposefully send - // a next local height taht's invalid to trigger a force close - // on their end. We do this as tweakless channels don't require - // that the commitment point is valid, only that it's present. - if c.ChanType.IsTweakless() { - nextLocalCommitHeight = 0 - } - } - - // If this is a taproot channel, then we'll need to generate our next - // verification nonce to send to the remote party. They'll use this to - // sign the next update to our commitment transaction. - var ( - nextTaprootNonce lnwire.OptMusig2NonceTLV - nextLocalNonces lnwire.OptLocalNonces - ) - if c.ChanType.IsTaproot() { - taprootRevProducer, err := DeriveMusig2Shachain( - c.RevocationProducer, - ) - if err != nil { - return nil, err - } - - nextNonce, err := NewMusigVerificationNonce( - c.LocalChanCfg.MultiSigKey.PubKey, - nextLocalCommitHeight, taprootRevProducer, - ) - if err != nil { - return nil, fmt.Errorf("unable to gen next "+ - "nonce: %w", err) - } - - fundingTxid := c.FundingOutpoint.Hash - nonce := nextNonce.PubNonce - - // Final taproot channels use the map-based LocalNonces - // field keyed by funding TXID. Staging channels use the - // legacy single LocalNonce field. - if c.ChanType.IsTaprootFinal() { - noncesMap := make(map[chainhash.Hash]lnwire.Musig2Nonce) - noncesMap[fundingTxid] = nonce - nextLocalNonces = lnwire.SomeLocalNonces( - lnwire.LocalNoncesData{NoncesMap: noncesMap}, - ) - } else { - nextTaprootNonce = lnwire.SomeMusig2Nonce(nonce) - } - } - - return &lnwire.ChannelReestablish{ - ChanID: lnwire.NewChanIDFromOutPoint( - c.FundingOutpoint, - ), - NextLocalCommitHeight: nextLocalCommitHeight, - RemoteCommitTailHeight: remoteChainTipHeight, - LastRemoteCommitSecret: lastCommitSecret, - LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( - currentCommitSecret[:], - ), - LocalNonce: nextTaprootNonce, - LocalNonces: nextLocalNonces, - }, nil -} - -// MarkShutdownSent serialises and persist the given ShutdownInfo for this -// channel. Persisting this info represents the fact that we have sent the -// Shutdown message to the remote side and hence that we should re-transmit the -// same Shutdown message on re-establish. -func (c *OpenChannel) MarkShutdownSent(info *ShutdownInfo) error { - c.Lock() - defer c.Unlock() - - return c.storeShutdownInfo(info) -} - -// storeShutdownInfo serialises the ShutdownInfo and persists it under the -// shutdownInfoKey. -func (c *OpenChannel) storeShutdownInfo(info *ShutdownInfo) error { - var b bytes.Buffer - err := info.encode(&b) - if err != nil { - return err - } - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return chanBucket.Put(shutdownInfoKey, b.Bytes()) - }, func() {}) -} - -// ShutdownInfo decodes the shutdown info stored for this channel and returns -// the result. If no shutdown info has been persisted for this channel then the -// ErrNoShutdownInfo error is returned. -func (c *OpenChannel) ShutdownInfo() (fn.Option[ShutdownInfo], error) { - c.RLock() - defer c.RUnlock() - - var shutdownInfo *ShutdownInfo - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch { - case err == nil: - case errors.Is(err, ErrNoChanDBExists), - errors.Is(err, ErrNoActiveChannels), - errors.Is(err, ErrChannelNotFound): - - return ErrNoShutdownInfo - default: - return err - } - - shutdownInfoBytes := chanBucket.Get(shutdownInfoKey) - if shutdownInfoBytes == nil { - return ErrNoShutdownInfo - } - - shutdownInfo, err = decodeShutdownInfo(shutdownInfoBytes) - - return err - }, func() { - shutdownInfo = nil - }) - if err != nil { - return fn.None[ShutdownInfo](), err - } - - return fn.Some[ShutdownInfo](*shutdownInfo), nil -} - -// isBorked returns true if the channel has been marked as borked in the -// database. This requires an existing database transaction to already be -// active. -// -// NOTE: The primary mutex should already be held before this method is called. -func (c *OpenChannel) isBorked(chanBucket kvdb.RBucket) (bool, error) { - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return false, err - } - - return channel.chanStatus != ChanStatusDefault, nil -} - -// MarkCommitmentBroadcasted marks the channel as a commitment transaction has -// been broadcast, either our own or the remote, and we should watch the chain -// for it to confirm before taking any further action. It takes as argument the -// closing tx _we believe_ will appear in the chain. This is only used to -// republish this tx at startup to ensure propagation, and we should still -// handle the case where a different tx actually hits the chain. -func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, - closer lntypes.ChannelParty) error { - - return c.markBroadcasted( - ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx, - closer, - ) -} - -// MarkCoopBroadcasted marks the channel to indicate that a cooperative close -// transaction has been broadcast, either our own or the remote, and that we -// should watch the chain for it to confirm before taking further action. It -// takes as argument a cooperative close tx that could appear on chain, and -// should be rebroadcast upon startup. This is only used to republish and -// ensure propagation, and we should still handle the case where a different tx -// actually hits the chain. -func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, - closer lntypes.ChannelParty) error { - - return c.markBroadcasted( - ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx, - closer, - ) -} - -// markBroadcasted is a helper function which modifies the channel status of the -// receiving channel and inserts a close transaction under the requested key, -// which should specify either a coop or force close. It adds a status which -// indicates the party that initiated the channel close. -func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, - closeTx *wire.MsgTx, closer lntypes.ChannelParty) error { - - c.Lock() - defer c.Unlock() - - // If a closing tx is provided, we'll generate a closure to write the - // transaction in the appropriate bucket under the given key. - var putClosingTx func(kvdb.RwBucket) error - if closeTx != nil { - var b bytes.Buffer - if err := WriteElement(&b, closeTx); err != nil { - return err - } - - putClosingTx = func(chanBucket kvdb.RwBucket) error { - return chanBucket.Put(key, b.Bytes()) - } - } - - // Add the initiator status to the status provided. These statuses are - // set in addition to the broadcast status so that we do not need to - // migrate the original logic which does not store initiator. - if closer.IsLocal() { - status |= ChanStatusLocalCloseInitiator - } else { - status |= ChanStatusRemoteCloseInitiator - } - - return c.putChanStatus(status, putClosingTx) -} - -// BroadcastedCommitment retrieves the stored unilateral closing tx set during -// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. -func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { - return c.getClosingTx(forceCloseTxKey) -} - -// BroadcastedCooperative retrieves the stored cooperative closing tx set during -// MarkCoopBroadcasted. If not found ErrNoCloseTx is returned. -func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) { - return c.getClosingTx(coopCloseTxKey) -} - -// getClosingTx is a helper method which returns the stored closing transaction -// for key. The caller should use either the force or coop closing keys. -func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { - var closeTx *wire.MsgTx - - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCloseTx - default: - return err - } - - bs := chanBucket.Get(key) - if bs == nil { - return ErrNoCloseTx - } - r := bytes.NewReader(bs) - return ReadElement(r, &closeTx) - }, func() { - closeTx = nil - }) - if err != nil { - return nil, err - } - - return closeTx, nil -} - -// putChanStatus appends the given status to the channel. fs is an optional -// list of closures that are given the chanBucket in order to atomically add -// extra information together with the new status. -func (c *OpenChannel) putChanStatus(status ChannelStatus, - fs ...func(kvdb.RwBucket) error) error { - - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Add this status to the existing bitvector found in the DB. - status = channel.chanStatus | status - channel.chanStatus = status - - if err := putOpenChannel(chanBucket, channel); err != nil { - return err - } - - for _, f := range fs { - // Skip execution of nil closures. - if f == nil { - continue - } - - if err := f(chanBucket); err != nil { - return err - } - } - - return nil - }, func() {}); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Unset this bit in the bitvector on disk. - status = channel.chanStatus & ^status - channel.chanStatus = status - - return putOpenChannel(chanBucket, channel) - }, func() {}); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -// putOpenChannel serializes, and stores the current state of the channel in its -// entirety. -func putOpenChannel(chanBucket kvdb.RwBucket, channel *OpenChannel) error { - // First, we'll write out all the relatively static fields, that are - // decided upon initial channel creation. - if err := putChanInfo(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan info: %w", err) - } - - // With the static channel info written out, we'll now write out the - // current commitment state for both parties. - if err := putChanCommitments(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan commitments: %w", err) - } - - // Next, if this is a frozen channel, we'll add in the axillary - // information we need to store. - if channel.ChanType.IsFrozen() || channel.ChanType.HasLeaseExpiration() { - err := storeThawHeight( - chanBucket, channel.ThawHeight, - ) - if err != nil { - return fmt.Errorf("unable to store thaw height: %w", - err) - } - } - - // Finally, we'll write out the revocation state for both parties - // within a distinct key space. - if err := putChanRevocationState(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan revocations: %w", err) - } - - return nil -} - -// fetchOpenChannel retrieves, and deserializes (including decrypting -// sensitive) the complete channel currently active with the passed nodeID. -func fetchOpenChannel(chanBucket kvdb.RBucket, - chanPoint *wire.OutPoint) (*OpenChannel, error) { - - channel := &OpenChannel{ - FundingOutpoint: *chanPoint, - } - - // First, we'll read all the static information that changes less - // frequently from disk. - if err := fetchChanInfo(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan info: %w", err) - } - - // With the static information read, we'll now read the current - // commitment state for both sides of the channel. - if err := fetchChanCommitments(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan commitments: %w", - err) - } - - // Next, if this is a frozen channel, we'll add in the axillary - // information we need to store. - if channel.ChanType.IsFrozen() || channel.ChanType.HasLeaseExpiration() { - thawHeight, err := fetchThawHeight(chanBucket) - if err != nil { - return nil, fmt.Errorf("unable to store thaw "+ - "height: %v", err) - } - - channel.ThawHeight = thawHeight - } - - // Finally, we'll retrieve the current revocation state so we can - // properly - if err := fetchChanRevocationState(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan revocations: %w", - err) - } - - channel.Packager = NewChannelPackager(channel.ShortChannelID) - - return channel, nil -} - -// SyncPending writes the contents of the channel to the database while it's in -// the pending (waiting for funding confirmation) state. The IsPending flag -// will be set to true. When the channel's funding transaction is confirmed, -// the channel should be marked as "open" and the IsPending flag set to false. -// Note that this function also creates a LinkNode relationship between this -// newly created channel and a new LinkNode instance. This allows listing all -// channels in the database globally, or according to the LinkNode they were -// created with. -// -// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type -// that includes service bits. -func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { - c.Lock() - defer c.Unlock() - - c.FundingBroadcastHeight = pendingHeight - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - return syncNewChannel(tx, c, []net.Addr{addr}) - }, func() {}) -} - -// syncNewChannel will write the passed channel to disk, and also create a -// LinkNode (if needed) for the channel peer. -func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error { - // First, sync all the persistent channel state to disk. - if err := c.fullSync(tx); err != nil { - return err - } - - nodeInfoBucket, err := tx.CreateTopLevelBucket(nodeInfoBucket) - if err != nil { - return err - } - - // If a LinkNode for this identity public key already exists, - // then we can exit early. - nodePub := c.IdentityPub.SerializeCompressed() - if nodeInfoBucket.Get(nodePub) != nil { - return nil - } - - // Next, we need to establish a (possibly) new LinkNode relationship - // for this channel. The LinkNode metadata contains reachability, - // up-time, and service bits related information. - linkNode := NewLinkNode( - &LinkNodeDB{backend: c.Db.backend}, - wire.MainNet, c.IdentityPub, addrs..., - ) - - // TODO(roasbeef): do away with link node all together? - - return putLinkNode(nodeInfoBucket, linkNode) -} - -// UpdateCommitment updates the local commitment state. It locks in the pending -// local updates that were received by us from the remote party. The commitment -// state completely describes the balance state at this point in the commitment -// chain. In addition to that, it persists all the remote log updates that we -// have acked, but not signed a remote commitment for yet. These need to be -// persisted to be able to produce a valid commit signature if a restart would -// occur. This method its to be called when we revoke our prior commitment -// state. -// -// A map is returned of all the htlc resolutions that were locked in this -// commitment. Keys correspond to htlc indices and values indicate whether the -// htlc was settled or failed. -func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, - unsignedAckedUpdates []LogUpdate) (map[uint64]bool, error) { - - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state as all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return nil, ErrNoRestoredChannelMutation - } - - var finalHtlcs = make(map[uint64]bool) - - err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - if err = putChanInfo(chanBucket, c); err != nil { - return fmt.Errorf("unable to store chan info: %w", err) - } - - // With the proper bucket fetched, we'll now write the latest - // commitment state to disk for the target party. - err = putChanCommitment( - chanBucket, newCommitment, true, - ) - if err != nil { - return fmt.Errorf("unable to store chan "+ - "revocations: %v", err) - } - - // Persist unsigned but acked remote updates that need to be - // restored after a restart. - var b bytes.Buffer - err = serializeLogUpdates(&b, unsignedAckedUpdates) - if err != nil { - return err - } - - err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes()) - if err != nil { - return fmt.Errorf("unable to store dangline remote "+ - "updates: %v", err) - } - - // Since we have just sent the counterparty a revocation, store true - // under lastWasRevokeKey. - var b2 bytes.Buffer - if err := WriteElements(&b2, true); err != nil { - return err - } - - if err := chanBucket.Put(lastWasRevokeKey, b2.Bytes()); err != nil { - return err - } - - // Persist the remote unsigned local updates that are not included - // in our new commitment. - updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) - if updateBytes == nil { - return nil - } - - r := bytes.NewReader(updateBytes) - updates, err := deserializeLogUpdates(r) - if err != nil { - return err - } - - // Get the bucket where settled htlcs are recorded if the user - // opted in to storing this information. - var finalHtlcsBucket kvdb.RwBucket - if c.Db.parent.storeFinalHtlcResolutions { - bucket, err := fetchFinalHtlcsBucketRw( - tx, c.ShortChannelID, - ) - if err != nil { - return err - } - - finalHtlcsBucket = bucket - } - - var unsignedUpdates []LogUpdate - for _, upd := range updates { - // Gather updates that are not on our local commitment. - if upd.LogIndex >= newCommitment.LocalLogIndex { - unsignedUpdates = append(unsignedUpdates, upd) - - continue - } - - // The update was locked in. If the update was a - // resolution, then store it in the database. - err := processFinalHtlc( - finalHtlcsBucket, upd, finalHtlcs, - ) - if err != nil { - return err - } - } - - var b3 bytes.Buffer - err = serializeLogUpdates(&b3, unsignedUpdates) - if err != nil { - return fmt.Errorf("unable to serialize log updates: %w", - err) - } - - err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b3.Bytes()) - if err != nil { - return fmt.Errorf("unable to restore chanbucket: %w", - err) - } - - return nil - }, func() { - finalHtlcs = make(map[uint64]bool) - }) - if err != nil { - return nil, err - } - - c.LocalCommitment = *newCommitment - - return finalHtlcs, nil -} - -// processFinalHtlc stores a final htlc outcome in the database if signaled via -// the supplied log update. An in-memory htlcs map is updated too. -func processFinalHtlc(finalHtlcsBucket walletdb.ReadWriteBucket, upd LogUpdate, - finalHtlcs map[uint64]bool) error { - - var ( - settled bool - id uint64 - ) - - switch msg := upd.UpdateMsg.(type) { - case *lnwire.UpdateFulfillHTLC: - settled = true - id = msg.ID - - case *lnwire.UpdateFailHTLC: - settled = false - id = msg.ID - - case *lnwire.UpdateFailMalformedHTLC: - settled = false - id = msg.ID - - default: - return nil - } - - // Store the final resolution in the database if a bucket is provided. - if finalHtlcsBucket != nil { - err := putFinalHtlc( - finalHtlcsBucket, id, - FinalHtlcInfo{ - Settled: settled, - Offchain: true, - }, - ) - if err != nil { - return err - } - } - - finalHtlcs[id] = settled - - return nil -} - -// ActiveHtlcs returns a slice of HTLC's which are currently active on *both* -// commitment transactions. -func (c *OpenChannel) ActiveHtlcs() []HTLC { - c.RLock() - defer c.RUnlock() - - // We'll only return HTLC's that are locked into *both* commitment - // transactions. So we'll iterate through their set of HTLC's to note - // which ones are present on their commitment. - remoteHtlcs := make(map[[32]byte]struct{}) - for _, htlc := range c.RemoteCommitment.Htlcs { - log.Tracef("RemoteCommitment has htlc: id=%v, update=%v "+ - "incoming=%v", htlc.HtlcIndex, htlc.LogIndex, - htlc.Incoming) - - onionHash := sha256.Sum256(htlc.OnionBlob[:]) - remoteHtlcs[onionHash] = struct{}{} - } - - // Now that we know which HTLC's they have, we'll only mark the HTLC's - // as active if *we* know them as well. - activeHtlcs := make([]HTLC, 0, len(remoteHtlcs)) - for _, htlc := range c.LocalCommitment.Htlcs { - log.Tracef("LocalCommitment has htlc: id=%v, update=%v "+ - "incoming=%v", htlc.HtlcIndex, htlc.LogIndex, - htlc.Incoming) - - onionHash := sha256.Sum256(htlc.OnionBlob[:]) - if _, ok := remoteHtlcs[onionHash]; !ok { - log.Tracef("Skipped htlc due to onion mismatched: "+ - "id=%v, update=%v incoming=%v", - htlc.HtlcIndex, htlc.LogIndex, htlc.Incoming) - - continue - } - - activeHtlcs = append(activeHtlcs, htlc) - } - - return activeHtlcs -} - -// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are -// contained within ChannelDeltas which encode the current state of the -// commitment between state updates. -// -// TODO(roasbeef): save space by using smaller ints at tail end? -type HTLC struct { - // TODO(yy): can embed an HTLCEntry here. - - // Signature is the signature for the second level covenant transaction - // for this HTLC. The second level transaction is a timeout tx in the - // case that this is an outgoing HTLC, and a success tx in the case - // that this is an incoming HTLC. - // - // TODO(roasbeef): make [64]byte instead? - Signature []byte - - // RHash is the payment hash of the HTLC. - RHash [32]byte - - // Amt is the amount of milli-satoshis this HTLC escrows. - Amt lnwire.MilliSatoshi - - // RefundTimeout is the absolute timeout on the HTLC that the sender - // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 - - // OutputIndex is the output index for this particular HTLC output - // within the commitment transaction. - OutputIndex int32 - - // Incoming denotes whether we're the receiver or the sender of this - // HTLC. - Incoming bool - - // OnionBlob is an opaque blob which is used to complete multi-hop - // routing. - OnionBlob [lnwire.OnionPacketSize]byte - - // HtlcIndex is the HTLC counter index of this active, outstanding - // HTLC. This differs from the LogIndex, as the HtlcIndex is only - // incremented for each offered HTLC, while they LogIndex is - // incremented for each update (includes settle+fail). - HtlcIndex uint64 - - // LogIndex is the cumulative log index of this HTLC. This differs - // from the HtlcIndex as this will be incremented for each new log - // update added. - LogIndex uint64 - - // ExtraData contains any additional information that was transmitted - // with the HTLC via TLVs. This data *must* already be encoded as a - // TLV stream, and may be empty. The length of this data is naturally - // limited by the space available to TLVs in update_add_htlc: - // = 65535 bytes (bolt 8 maximum message size): - // - 2 bytes (bolt 1 message_type) - // - 32 bytes (channel_id) - // - 8 bytes (id) - // - 8 bytes (amount_msat) - // - 32 bytes (payment_hash) - // - 4 bytes (cltv_expiry) - // - 1366 bytes (onion_routing_packet) - // = 64083 bytes maximum possible TLV stream - // - // Note that this extra data is stored inline with the OnionBlob for - // legacy reasons, see serialization/deserialization functions for - // detail. - ExtraData lnwire.ExtraOpaqueData - - // BlindingPoint is an optional blinding point included with the HTLC. - // - // Note: this field is not a part of on-disk representation of the - // HTLC. It is stored in the ExtraData field, which is used to store - // a TLV stream of additional information associated with the HTLC. - BlindingPoint lnwire.BlindingPointRecord - - // CustomRecords is a set of custom TLV records that are associated with - // this HTLC. These records are used to store additional information - // about the HTLC that is not part of the standard HTLC fields. This - // field is encoded within the ExtraData field. - CustomRecords lnwire.CustomRecords -} - -// serializeExtraData encodes a TLV stream of extra data to be stored with a -// HTLC. It uses the update_add_htlc TLV types, because this is where extra -// data is passed with a HTLC. At present blinding points are the only extra -// data that we will store, and the function is a no-op if a nil blinding -// point is provided. -// -// This function MUST be called to persist all HTLC values when they are -// serialized. -func (h *HTLC) serializeExtraData() error { - var records []tlv.RecordProducer - h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType, - *btcec.PublicKey]) { - - records = append(records, &b) - }) - - records, err := h.CustomRecords.ExtendRecordProducers(records) - if err != nil { - return err - } - - return h.ExtraData.PackRecords(records...) -} - -// deserializeExtraData extracts TLVs from the extra data persisted for the -// htlc and populates values in the struct accordingly. -// -// This function MUST be called to populate the struct properly when HTLCs -// are deserialized. -func (h *HTLC) deserializeExtraData() error { - if len(h.ExtraData) == 0 { - return nil - } - - blindingPoint := h.BlindingPoint.Zero() - tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint) - if err != nil { - return err - } - - if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { - h.BlindingPoint = tlv.SomeRecordT(blindingPoint) - - // Remove the entry from the TLV map. Anything left in the map - // will be included in the custom records field. - delete(tlvMap, h.BlindingPoint.TlvType()) - } - - // Set the custom records field to the remaining TLV records. - customRecords, err := lnwire.NewCustomRecords(tlvMap) - if err != nil { - return err - } - h.CustomRecords = customRecords - - return nil -} - -// SerializeHtlcs writes out the passed set of HTLC's into the passed writer -// using the current default on-disk serialization format. -// -// This inline serialization has been extended to allow storage of extra data -// associated with a HTLC in the following way: -// - The known-length onion blob (1366 bytes) is serialized as var bytes in -// WriteElements (ie, the length 1366 was written, followed by the 1366 -// onion bytes). -// - To include extra data, we append any extra data present to this one -// variable length of data. Since we know that the onion is strictly 1366 -// bytes, any length after that should be considered to be extra data. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { - numHtlcs := uint16(len(htlcs)) - if err := WriteElement(b, numHtlcs); err != nil { - return err - } - - for _, htlc := range htlcs { - // Populate TLV stream for any additional fields contained - // in the TLV. - if err := htlc.serializeExtraData(); err != nil { - return err - } - - // The onion blob and hltc data are stored as a single var - // bytes blob. - onionAndExtraData := make( - []byte, lnwire.OnionPacketSize+len(htlc.ExtraData), - ) - copy(onionAndExtraData, htlc.OnionBlob[:]) - copy(onionAndExtraData[lnwire.OnionPacketSize:], htlc.ExtraData) - - if err := WriteElements(b, - htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, - htlc.OutputIndex, htlc.Incoming, onionAndExtraData, - htlc.HtlcIndex, htlc.LogIndex, - ); err != nil { - return err - } - } - - return nil -} - -// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed -// io.Reader. The bytes within the passed reader MUST have been previously -// written to using the SerializeHtlcs function. -// -// This inline deserialization has been extended to allow storage of extra data -// associated with a HTLC in the following way: -// - The known-length onion blob (1366 bytes) and any additional data present -// are read out as a single blob of variable byte data. -// - They are stored like this to take advantage of the variable space -// available for extension without migration (see SerializeHtlcs). -// - The first 1366 bytes are interpreted as the onion blob, and any remaining -// bytes as extra HTLC data. -// - This extra HTLC data is expected to be serialized as a TLV stream, and -// its parsing is left to higher layers. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { - var numHtlcs uint16 - if err := ReadElement(r, &numHtlcs); err != nil { - return nil, err - } - - var htlcs []HTLC - if numHtlcs == 0 { - return htlcs, nil - } - - htlcs = make([]HTLC, numHtlcs) - for i := uint16(0); i < numHtlcs; i++ { - var onionAndExtraData []byte - if err := ReadElements(r, - &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, - &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, - &htlcs[i].Incoming, &onionAndExtraData, - &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, - ); err != nil { - return htlcs, err - } - - // Sanity check that we have at least the onion blob size we - // expect. - if len(onionAndExtraData) < lnwire.OnionPacketSize { - return nil, ErrOnionBlobLength - } - - // First OnionPacketSize bytes are our fixed length onion - // packet. - copy( - htlcs[i].OnionBlob[:], - onionAndExtraData[0:lnwire.OnionPacketSize], - ) - - // Any additional bytes belong to extra data. ExtraDataLen - // will be >= 0, because we know that we always have a fixed - // length onion packet. - extraDataLen := len(onionAndExtraData) - lnwire.OnionPacketSize - if extraDataLen > 0 { - htlcs[i].ExtraData = make([]byte, extraDataLen) - - copy( - htlcs[i].ExtraData, - onionAndExtraData[lnwire.OnionPacketSize:], - ) - } - - // Finally, deserialize any TLVs contained in that extra data - // if they are present. - if err := htlcs[i].deserializeExtraData(); err != nil { - return nil, err - } - } - - return htlcs, nil -} - -// Copy returns a full copy of the target HTLC. -func (h *HTLC) Copy() HTLC { - clone := HTLC{ - Incoming: h.Incoming, - Amt: h.Amt, - RefundTimeout: h.RefundTimeout, - OutputIndex: h.OutputIndex, - } - copy(clone.Signature[:], h.Signature) - copy(clone.RHash[:], h.RHash[:]) - copy(clone.ExtraData, h.ExtraData) - clone.BlindingPoint = h.BlindingPoint - clone.CustomRecords = h.CustomRecords.Copy() - - return clone -} - -// LogUpdate represents a pending update to the remote commitment chain. The -// log update may be an add, fail, or settle entry. We maintain this data in -// order to be able to properly retransmit our proposed state if necessary. -type LogUpdate struct { - // LogIndex is the log index of this proposed commitment update entry. - LogIndex uint64 - - // UpdateMsg is the update message that was included within our - // local update log. The LogIndex value denotes the log index of this - // update which will be used when restoring our local update log if - // we're left with a dangling update on restart. - UpdateMsg lnwire.Message -} - -// serializeLogUpdate writes a log update to the provided io.Writer. -func serializeLogUpdate(w io.Writer, l *LogUpdate) error { - return WriteElements(w, l.LogIndex, l.UpdateMsg) -} - -// deserializeLogUpdate reads a log update from the provided io.Reader. -func deserializeLogUpdate(r io.Reader) (*LogUpdate, error) { - l := &LogUpdate{} - if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { - return nil, err - } - - return l, nil -} - -// CommitDiff represents the delta needed to apply the state transition between -// two subsequent commitment states. Given state N and state N+1, one is able -// to apply the set of messages contained within the CommitDiff to N to arrive -// at state N+1. Each time a new commitment is extended, we'll write a new -// commitment (along with the full commitment state) to disk so we can -// re-transmit the state in the case of a connection loss or message drop. -type CommitDiff struct { - // ChannelCommitment is the full commitment state that one would arrive - // at by applying the set of messages contained in the UpdateDiff to - // the prior accepted commitment. - Commitment ChannelCommitment - - // LogUpdates is the set of messages sent prior to the commitment state - // transition in question. Upon reconnection, if we detect that they - // don't have the commitment, then we re-send this along with the - // proper signature. - LogUpdates []LogUpdate - - // CommitSig is the exact CommitSig message that should be sent after - // the set of LogUpdates above has been retransmitted. The signatures - // within this message should properly cover the new commitment state - // and also the HTLC's within the new commitment state. - CommitSig *lnwire.CommitSig - - // OpenedCircuitKeys is a set of unique identifiers for any downstream - // Add packets included in this commitment txn. After a restart, this - // set of htlcs is acked from the link's incoming mailbox to ensure - // there isn't an attempt to re-add them to this commitment txn. - OpenedCircuitKeys []models.CircuitKey - - // ClosedCircuitKeys records the unique identifiers for any settle/fail - // packets that were resolved by this commitment txn. After a restart, - // this is used to ensure those circuits are removed from the circuit - // map, and the downstream packets in the link's mailbox are removed. - ClosedCircuitKeys []models.CircuitKey - - // AddAcks specifies the locations (commit height, pkg index) of any - // Adds that were failed/settled in this commit diff. This will ack - // entries in *this* channel's forwarding packages. - // - // NOTE: This value is not serialized, it is used to atomically mark the - // resolution of adds, such that they will not be reprocessed after a - // restart. - AddAcks []AddRef - - // SettleFailAcks specifies the locations (chan id, commit height, pkg - // index) of any Settles or Fails that were locked into this commit - // diff, and originate from *another* channel, i.e. the outgoing link. - // - // NOTE: This value is not serialized, it is used to atomically acks - // settles and fails from the forwarding packages of other channels, - // such that they will not be reforwarded internally after a restart. - SettleFailAcks []SettleFailRef -} - -// serializeLogUpdates serializes provided list of updates to a stream. -func serializeLogUpdates(w io.Writer, logUpdates []LogUpdate) error { - numUpdates := uint16(len(logUpdates)) - if err := binary.Write(w, byteOrder, numUpdates); err != nil { - return err - } - - for _, diff := range logUpdates { - err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) - if err != nil { - return err - } - } - - return nil -} - -// deserializeLogUpdates deserializes a list of updates from a stream. -func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) { - var numUpdates uint16 - if err := binary.Read(r, byteOrder, &numUpdates); err != nil { - return nil, err - } - - logUpdates := make([]LogUpdate, numUpdates) - for i := 0; i < int(numUpdates); i++ { - err := ReadElements(r, - &logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg, - ) - if err != nil { - return nil, err - } - } - return logUpdates, nil -} - -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl - if err := serializeChanCommit(w, &diff.Commitment); err != nil { - return err - } - - if err := WriteElements(w, diff.CommitSig); err != nil { - return err - } - - if err := serializeLogUpdates(w, diff.LogUpdates); err != nil { - return err - } - - numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) - if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { - return err - } - - for _, openRef := range diff.OpenedCircuitKeys { - err := WriteElements(w, openRef.ChanID, openRef.HtlcID) - if err != nil { - return err - } - } - - numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) - if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { - return err - } - - for _, closedRef := range diff.ClosedCircuitKeys { - err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) - if err != nil { - return err - } - } - - // We'll also encode the commit aux data stream here. We do this here - // rather than above (at the call to serializeChanCommit), to ensure - // backwards compat for reads to existing non-custom channels. - auxData := diff.Commitment.extractTlvData() - if err := auxData.encode(w); err != nil { - return fmt.Errorf("unable to write aux data: %w", err) - } - - return nil -} - -func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { - var ( - d CommitDiff - err error - ) - - d.Commitment, err = deserializeChanCommit(r) - if err != nil { - return nil, err - } - - var msg lnwire.Message - if err := ReadElements(r, &msg); err != nil { - return nil, err - } - commitSig, ok := msg.(*lnwire.CommitSig) - if !ok { - return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ - "read: %T", msg) - } - d.CommitSig = commitSig - - d.LogUpdates, err = deserializeLogUpdates(r) - if err != nil { - return nil, err - } - - var numOpenRefs uint16 - if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { - return nil, err - } - - d.OpenedCircuitKeys = make([]models.CircuitKey, numOpenRefs) - for i := 0; i < int(numOpenRefs); i++ { - err := ReadElements(r, - &d.OpenedCircuitKeys[i].ChanID, - &d.OpenedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - var numClosedRefs uint16 - if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { - return nil, err - } - - d.ClosedCircuitKeys = make([]models.CircuitKey, numClosedRefs) - for i := 0; i < int(numClosedRefs); i++ { - err := ReadElements(r, - &d.ClosedCircuitKeys[i].ChanID, - &d.ClosedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - // As a final step, we'll read out any aux commit data that we have at - // the end of this byte stream. We do this here to ensure backward - // compatibility, as otherwise we risk erroneously reading into the - // wrong field. - var auxData commitTlvData - if err := auxData.decode(r); err != nil { - return nil, fmt.Errorf("unable to decode aux data: %w", err) - } - - d.Commitment.amendTlvData(auxData) - - return &d, nil -} - -// AppendRemoteCommitChain appends a new CommitDiff to the end of the -// commitment chain for the remote party. This method is to be used once we -// have prepared a new commitment state for the remote party, but before we -// transmit it to the remote party. The contents of the argument should be -// sufficient to retransmit the updates and signature needed to reconstruct the -// state in full, in the case that we need to retransmit. -func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - // First, we'll grab the writable bucket where this channel's - // data resides. - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Any outgoing settles and fails necessarily have a - // corresponding adds in this channel's forwarding packages. - // Mark all of these as being fully processed in our forwarding - // package, which prevents us from reprocessing them after - // startup. - err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) - if err != nil { - return err - } - - // Additionally, we ack from any fails or settles that are - // persisted in another channel's forwarding package. This - // prevents the same fails and settles from being retransmitted - // after restarts. The actual fail or settle we need to - // propagate to the remote party is now in the commit diff. - err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) - if err != nil { - return err - } - - // We are sending a commitment signature so lastWasRevokeKey should - // store false. - var b bytes.Buffer - if err := WriteElements(&b, false); err != nil { - return err - } - if err := chanBucket.Put(lastWasRevokeKey, b.Bytes()); err != nil { - return err - } - - // TODO(roasbeef): use seqno to derive key for later LCP - - // With the bucket retrieved, we'll now serialize the commit - // diff itself, and write it to disk. - var b2 bytes.Buffer - if err := serializeCommitDiff(&b2, diff); err != nil { - return err - } - return chanBucket.Put(commitDiffKey, b2.Bytes()) - }, func() {}) -} - -// RemoteCommitChainTip returns the "tip" of the current remote commitment -// chain. This value will be non-nil iff, we've created a new commitment for -// the remote party that they haven't yet ACK'd. In this case, their commitment -// chain will have a length of two: their current unrevoked commitment, and -// this new pending commitment. Once they revoked their prior state, we'll swap -// these pointers, causing the tip and the tail to point to the same entry. -func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { - var cd *CommitDiff - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoPendingCommit - default: - return err - } - - tipBytes := chanBucket.Get(commitDiffKey) - if tipBytes == nil { - return ErrNoPendingCommit - } - - tipReader := bytes.NewReader(tipBytes) - dcd, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - - cd = dcd - return nil - }, func() { - cd = nil - }) - if err != nil { - return nil, err - } - - return cd, nil -} - -// UnsignedAckedUpdates retrieves the persisted unsigned acked remote log -// updates that still need to be signed for. -func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { - var updates []LogUpdate - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return nil - default: - return err - } - - updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) - if updateBytes == nil { - return nil - } - - r := bytes.NewReader(updateBytes) - updates, err = deserializeLogUpdates(r) - return err - }, func() { - updates = nil - }) - if err != nil { - return nil, err - } - - return updates, nil -} - -// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log -// updates that the remote still needs to sign for. -func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { - var updates []LogUpdate - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - break - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return nil - default: - return err - } - - updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) - if updateBytes == nil { - return nil - } - - r := bytes.NewReader(updateBytes) - updates, err = deserializeLogUpdates(r) - return err - }, func() { - updates = nil - }) - if err != nil { - return nil, err - } - - return updates, nil -} - -// InsertNextRevocation inserts the _next_ commitment point (revocation) into -// the database, and also modifies the internal RemoteNextRevocation attribute -// to point to the passed key. This method is to be using during final channel -// set up, _after_ the channel has been fully confirmed. -// -// NOTE: If this method isn't called, then the target channel won't be able to -// propose new states for the commitment state of the remote party. -func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - c.RemoteNextRevocation = revKey - - err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return putChanRevocationState(chanBucket, c) - }, func() {}) - if err != nil { - return err - } - - return nil -} - -// AdvanceCommitChainTail records the new state transition within an on-disk -// append-only log which records all state transitions by the remote peer. In -// the case of an uncooperative broadcast of a prior state by the remote peer, -// this log can be consulted in order to reconstruct the state needed to -// rectify the situation. This method will add the current commitment for the -// remote party to the revocation log, and promote the current pending -// commitment to the current remote commitment. The updates parameter is the -// set of local updates that the peer still needs to send us a signature for. -// We store this set of updates in case we go down. -func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, - updates []LogUpdate, ourOutputIndex, theirOutputIndex uint32) error { - - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - var newRemoteCommit *ChannelCommitment - - err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucketRw( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Persist the latest preimage state to disk as the remote peer - // has just added to our local preimage store, and given us a - // new pending revocation key. - if err := putChanRevocationState(chanBucket, c); err != nil { - return err - } - - // With the current preimage producer/store state updated, - // append a new log entry recording this the delta of this - // state transition. - // - // TODO(roasbeef): could make the deltas relative, would save - // space, but then tradeoff for more disk-seeks to recover the - // full state. - logKey := revocationLogBucket - logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) - if err != nil { - return err - } - - // Before we append this revoked state to the revocation log, - // we'll swap out what's currently the tail of the commit tip, - // with the current locked-in commitment for the remote party. - tipBytes := chanBucket.Get(commitDiffKey) - tipReader := bytes.NewReader(tipBytes) - newCommit, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - err = putChanCommitment( - chanBucket, &newCommit.Commitment, false, - ) - if err != nil { - return err - } - if err := chanBucket.Delete(commitDiffKey); err != nil { - return err - } - - // With the commitment pointer swapped, we can now add the - // revoked (prior) state to the revocation log. - err = putRevocationLog( - logBucket, &c.RemoteCommitment, ourOutputIndex, - theirOutputIndex, c.Db.parent.noRevLogAmtData, - ) - if err != nil { - return err - } - - // Lastly, we write the forwarding package to disk so that we - // can properly recover from failures and reforward HTLCs that - // have not received a corresponding settle/fail. - if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { - return err - } - - // Persist the unsigned acked updates that are not included - // in their new commitment. - updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) - if updateBytes == nil { - // This shouldn't normally happen as we always store - // the number of updates, but could still be - // encountered by nodes that are upgrading. - newRemoteCommit = &newCommit.Commitment - return nil - } - - r := bytes.NewReader(updateBytes) - unsignedUpdates, err := deserializeLogUpdates(r) - if err != nil { - return err - } - - var validUpdates []LogUpdate - for _, upd := range unsignedUpdates { - lIdx := upd.LogIndex - - // Filter for updates that are not on the remote - // commitment. - if lIdx >= newCommit.Commitment.RemoteLogIndex { - validUpdates = append(validUpdates, upd) - } - } - - var b bytes.Buffer - err = serializeLogUpdates(&b, validUpdates) - if err != nil { - return fmt.Errorf("unable to serialize log updates: %w", - err) - } - - err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes()) - if err != nil { - return fmt.Errorf("unable to store under "+ - "unsignedAckedUpdatesKey: %w", err) - } - - // Persist the local updates the peer hasn't yet signed so they - // can be restored after restart. - var b2 bytes.Buffer - err = serializeLogUpdates(&b2, updates) - if err != nil { - return err - } - - err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b2.Bytes()) - if err != nil { - return fmt.Errorf("unable to restore remote unsigned "+ - "local updates: %v", err) - } - - newRemoteCommit = &newCommit.Commitment - - return nil - }, func() { - newRemoteCommit = nil - }) - if err != nil { - return err - } - - // With the db transaction complete, we'll swap over the in-memory - // pointer of the new remote commitment, which was previously the tip - // of the commit chain. - c.RemoteCommitment = *newRemoteCommit - - return nil -} - -// FinalHtlcInfo contains information about the final outcome of an htlc. -type FinalHtlcInfo struct { - // Settled is true is the htlc was settled. If false, the htlc was - // failed. - Settled bool - - // Offchain indicates whether the htlc was resolved off-chain or - // on-chain. - Offchain bool -} - -// putFinalHtlc writes the final htlc outcome to the database. Additionally it -// records whether the htlc was resolved off-chain or on-chain. -func putFinalHtlc(finalHtlcsBucket kvdb.RwBucket, id uint64, - info FinalHtlcInfo) error { - - var key [8]byte - byteOrder.PutUint64(key[:], id) - - var finalHtlcByte FinalHtlcByte - if info.Settled { - finalHtlcByte |= FinalHtlcSettledBit - } - if info.Offchain { - finalHtlcByte |= FinalHtlcOffchainBit - } - - return finalHtlcsBucket.Put(key[:], []byte{byte(finalHtlcByte)}) -} - -// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure -// this always returns the next index that has been not been allocated, this -// will first try to examine any pending commitments, before falling back to the -// last locked-in remote commitment. -func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { - // First, load the most recent commit diff that we initiated for the - // remote party. If no pending commit is found, this is not treated as - // a critical error, since we can always fall back. - pendingRemoteCommit, err := c.RemoteCommitChainTip() - if err != nil && err != ErrNoPendingCommit { - return 0, err - } - - // If a pending commit was found, its local htlc index will be at least - // as large as the one on our local commitment. - if pendingRemoteCommit != nil { - return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil - } - - // Otherwise, fallback to using the local htlc index of their commitment. - return c.RemoteCommitment.LocalHtlcIndex, nil -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in map indexed by the -// remote commitment height at which the updates were locked in. -func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { - c.RLock() - defer c.RUnlock() - - var fwdPkgs []*FwdPkg - if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - var err error - fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) - return err - }, func() { - fwdPkgs = nil - }); err != nil { - return nil, err - } - - return fwdPkgs, nil -} - -// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs -// indicating that a response to this Add has been committed to the remote party. -// Doing so will prevent these Add HTLCs from being reforwarded internally. -func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { - c.Lock() - defer c.Unlock() - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - return c.Packager.AckAddHtlcs(tx, addRefs...) - }, func() {}) -} - -// AckSettleFails updates the SettleFailFilter containing any of the provided -// SettleFailRefs, indicating that the response has been delivered to the -// incoming link, corresponding to a particular AddRef. Doing so will prevent -// the responses from being retransmitted internally. -func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { - c.Lock() - defer c.Unlock() - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - return c.Packager.AckSettleFails(tx, settleFailRefs...) - }, func() {}) -} - -// SetFwdFilter atomically sets the forwarding filter for the forwarding package -// identified by `height`. -func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { - c.Lock() - defer c.Unlock() - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - return c.Packager.SetFwdFilter(tx, height, fwdFilter) - }, func() {}) -} - -// RemoveFwdPkgs atomically removes forwarding packages specified by the remote -// commitment heights. If one of the intermediate RemovePkg calls fails, then the -// later packages won't be removed. -// -// NOTE: This method should only be called on packages marked FwdStateCompleted. -func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error { - c.Lock() - defer c.Unlock() - - return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { - for _, height := range heights { - err := c.Packager.RemovePkg(tx, height) - if err != nil { - return err - } - } - - return nil - }, func() {}) -} - -// revocationLogTailCommitHeight returns the commit height at the end of the -// revocation log. This entry represents the last previous state for the remote -// node's commitment chain. The ChannelDelta returned by this method will -// always lag one state behind the most current (unrevoked) state of the remote -// node's commitment chain. -// NOTE: used in unit test only. -func (c *OpenChannel) revocationLogTailCommitHeight() (uint64, error) { - c.RLock() - defer c.RUnlock() - - var height uint64 - - // If we haven't created any state updates yet, then we'll exit early as - // there's nothing to be found on disk in the revocation bucket. - if c.RemoteCommitment.CommitHeight == 0 { - return height, nil - } - - if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - logBucket, err := fetchLogBucket(chanBucket) - if err != nil { - return err - } - - // Once we have the bucket that stores the revocation log from - // this channel, we'll jump to the _last_ key in bucket. Since - // the key is the commit height, we'll decode the bytes and - // return it. - cursor := logBucket.ReadCursor() - rawHeight, _ := cursor.Last() - height = byteOrder.Uint64(rawHeight) - - return nil - }, func() {}); err != nil { - return height, err - } - - return height, nil -} - -// CommitmentHeight returns the current commitment height. The commitment -// height represents the number of updates to the commitment state to date. -// This value is always monotonically increasing. This method is provided in -// order to allow multiple instances of a particular open channel to obtain a -// consistent view of the number of channel updates to date. -func (c *OpenChannel) CommitmentHeight() (uint64, error) { - c.RLock() - defer c.RUnlock() - - var height uint64 - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - commit, err := fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - - height = commit.CommitHeight - return nil - }, func() { - height = 0 - }) - if err != nil { - return 0, err - } - - return height, nil -} - -// FindPreviousState scans through the append-only log in an attempt to recover -// the previous channel state indicated by the update number. This method is -// intended to be used for obtaining the relevant data needed to claim all -// funds rightfully spendable in the case of an on-chain broadcast of the -// commitment transaction. -func (c *OpenChannel) FindPreviousState( - updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { - - c.RLock() - defer c.RUnlock() - - commit := &ChannelCommitment{} - rl := &RevocationLog{} - - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // Find the revocation log from both the new and the old - // bucket. - r, c, err := fetchRevocationLogCompatible(chanBucket, updateNum) - if err != nil { - return err - } - - rl = r - commit = c - return nil - }, func() {}) - if err != nil { - return nil, nil, err - } - - // Either the `rl` or the `commit` is nil here. We return them as-is - // and leave it to the caller to decide its following action. - return rl, commit, nil -} - -// ClosureType is an enum like structure that details exactly _how_ a channel -// was closed. Three closure types are currently possible: none, cooperative, -// local force close, remote force close, and (remote) breach. -type ClosureType uint8 - -const ( - // CooperativeClose indicates that a channel has been closed - // cooperatively. This means that both channel peers were online and - // signed a new transaction paying out the settled balance of the - // contract. - CooperativeClose ClosureType = 0 - - // LocalForceClose indicates that we have unilaterally broadcast our - // current commitment state on-chain. - LocalForceClose ClosureType = 1 - - // RemoteForceClose indicates that the remote peer has unilaterally - // broadcast their current commitment state on-chain. - RemoteForceClose ClosureType = 4 - - // BreachClose indicates that the remote peer attempted to broadcast a - // prior _revoked_ channel state. - BreachClose ClosureType = 2 - - // FundingCanceled indicates that the channel never was fully opened - // before it was marked as closed in the database. This can happen if - // we or the remote fail at some point during the opening workflow, or - // we timeout waiting for the funding transaction to be confirmed. - FundingCanceled ClosureType = 3 - - // Abandoned indicates that the channel state was removed without - // any further actions. This is intended to clean up unusable - // channels during development. - Abandoned ClosureType = 5 -) - -// ChannelCloseSummary contains the final state of a channel at the point it -// was closed. Once a channel is closed, all the information pertaining to that -// channel within the openChannelBucket is deleted, and a compact summary is -// put in place instead. -type ChannelCloseSummary struct { - // ChanPoint is the outpoint for this channel's funding transaction, - // and is used as a unique identifier for the channel. - ChanPoint wire.OutPoint - - // ShortChanID encodes the exact location in the chain in which the - // channel was initially confirmed. This includes: the block height, - // transaction index, and the output within the target transaction. - ShortChanID lnwire.ShortChannelID - - // ChainHash is the hash of the genesis block that this channel resides - // within. - ChainHash chainhash.Hash - - // ClosingTXID is the txid of the transaction which ultimately closed - // this channel. - ClosingTXID chainhash.Hash - - // RemotePub is the public key of the remote peer that we formerly had - // a channel with. - RemotePub *btcec.PublicKey - - // Capacity was the total capacity of the channel. - Capacity btcutil.Amount - - // CloseHeight is the height at which the funding transaction was - // spent. - CloseHeight uint32 - - // SettledBalance is our total balance settled balance at the time of - // channel closure. This _does not_ include the sum of any outputs that - // have been time-locked as a result of the unilateral channel closure. - SettledBalance btcutil.Amount - - // TimeLockedBalance is the sum of all the time-locked outputs at the - // time of channel closure. If we triggered the force closure of this - // channel, then this value will be non-zero if our settled output is - // above the dust limit. If we were on the receiving side of a channel - // force closure, then this value will be non-zero if we had any - // outstanding outgoing HTLC's at the time of channel closure. - TimeLockedBalance btcutil.Amount - - // CloseType details exactly _how_ the channel was closed. Five closure - // types are possible: cooperative, local force, remote force, breach - // and funding canceled. - CloseType ClosureType - - // IsPending indicates whether this channel is in the 'pending close' - // state, which means the channel closing transaction has been - // confirmed, but not yet been fully resolved. In the case of a channel - // that has been cooperatively closed, it will go straight into the - // fully resolved state as soon as the closing transaction has been - // confirmed. However, for channels that have been force closed, they'll - // stay marked as "pending" until _all_ the pending funds have been - // swept. - IsPending bool - - // RemoteCurrentRevocation is the current revocation for their - // commitment transaction. However, since this is the derived public key, - // we don't yet have the private key so we aren't yet able to verify - // that it's actually in the hash chain. - RemoteCurrentRevocation *btcec.PublicKey - - // RemoteNextRevocation is the revocation key to be used for the *next* - // commitment transaction we create for the local node. Within the - // specification, this value is referred to as the - // per-commitment-point. - RemoteNextRevocation *btcec.PublicKey - - // LocalChanConfig is the channel configuration for the local node. - LocalChanConfig ChannelConfig - - // LastChanSyncMsg is the ChannelReestablish message for this channel - // for the state at the point where it was closed. - LastChanSyncMsg *lnwire.ChannelReestablish -} - -// CloseChannel closes a previously active Lightning channel. Closing a -// channel entails persisting a record of the close while either purging the -// nested per-channel state inline (synchronous backends like bbolt and etcd) -// or skipping the cascading delete on tombstone-enabled backends, where the -// outpoint-index flip to outpointClosed is the authoritative marker. The -// compact summary written to closedChannelBucket and the historical record -// under historicalChannelBucket are populated identically across both paths, -// so historical reads remain uniform regardless of backend. The optional set -// of channel statuses is OR'd into the chanStatus written to the historical -// bucket and is used to record close initiators. -func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, - statuses ...ChannelStatus) error { - - c.Lock() - defer c.Unlock() - - return c.Db.CloseChannel(c, summary, statuses...) -} - -// CloseChannel closes the supplied channel via the strategy selected at DB -// construction. On synchronous backends the channel's nested state — the -// revocation log, the per-channel forwarding-package bucket, and the -// chanBucket itself — is deleted inline. On tombstone-enabled backends none -// of the bulk state is touched; the outpointBucket flip to outpointClosed -// signals that the channel is logically closed. -func (c *ChannelStateDB) CloseChannel(channel *OpenChannel, - summary *ChannelCloseSummary, statuses ...ChannelStatus) error { - - if c.tombstoneClosedChannels { - return c.closeChannelTombstone(channel, summary, statuses...) - } - - return c.closeChannelSync(channel, summary, statuses...) -} - -// locateOpenChannel performs the open-channel-bucket descent for a -// CloseChannel transaction: it returns the chain bucket, the channel bucket, -// and the serialized chanKey for the supplied OpenChannel. A chanKey already -// flipped to outpointClosed surfaces ErrChannelNotFound so a redundant -// CloseChannel does not re-archive or re-flip the index. -func locateOpenChannel(tx kvdb.RwTx, channel *OpenChannel) (kvdb.RwBucket, - kvdb.RwBucket, []byte, error) { - - openChanBucket := tx.ReadWriteBucket(openChannelBucket) - if openChanBucket == nil { - return nil, nil, nil, ErrNoChanDBExists - } - - nodePub := channel.IdentityPub.SerializeCompressed() - nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) - if nodeChanBucket == nil { - return nil, nil, nil, ErrNoActiveChannels - } - - chainBucket := nodeChanBucket.NestedReadWriteBucket( - channel.ChainHash[:], - ) - if chainBucket == nil { - return nil, nil, nil, ErrNoActiveChannels - } - - var chanPointBuf bytes.Buffer - if err := graphdb.WriteOutpoint( - &chanPointBuf, &channel.FundingOutpoint, - ); err != nil { - return nil, nil, nil, err - } - chanKey := chanPointBuf.Bytes() - - chanBucket := chainBucket.NestedReadWriteBucket(chanKey) - if chanBucket == nil { - return nil, nil, nil, ErrNoActiveChannels - } - - // A channel whose outpoint is already flipped to outpointClosed must - // not be re-closed: on tombstone backends the chanBucket survives a - // previous close, but the index flip is the authoritative record that - // the channel is gone from the open-channel view. - closed, err := isOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) - if err != nil { - return nil, nil, nil, err - } - if closed { - return nil, nil, nil, ErrChannelNotFound - } - - return chainBucket, chanBucket, chanKey, nil -} - -// updateClosedOutpointIndex flips the outpoint index entry for chanKey from -// open to closed. The index entry must already exist; it was placed there -// when the channel was opened. -func updateClosedOutpointIndex(tx kvdb.RwTx, chanKey []byte) error { - opBucket := tx.ReadWriteBucket(outpointBucket) - if opBucket == nil { - return ErrNoChanDBExists - } - if opBucket.Get(chanKey) == nil { - return ErrMissingIndexEntry - } - - status := uint8(outpointClosed) - statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status) - opStream, err := tlv.NewStream(statusRecord) - if err != nil { - return err - } - - var b bytes.Buffer - if err := opStream.Encode(&b); err != nil { - return err - } - - return opBucket.Put(chanKey, b.Bytes()) -} - -// archiveClosedChannel writes the immutable close-time records of the -// channel: a copy of the open-channel state under historicalChannelBucket -// (with the supplied close statuses OR'd into chanStatus) and the close -// summary under closeSummaryBucket. -func archiveClosedChannel(tx kvdb.RwTx, chanKey []byte, - chanState *OpenChannel, summary *ChannelCloseSummary, - statuses ...ChannelStatus) error { - - historicalBucket, err := tx.CreateTopLevelBucket( - historicalChannelBucket, - ) - if err != nil { - return err - } - historicalChanBucket, err := historicalBucket.CreateBucketIfNotExists( - chanKey, - ) - if err != nil { - return err - } - - for _, s := range statuses { - chanState.chanStatus |= s - } - - if err := putOpenChannel(historicalChanBucket, chanState); err != nil { - return err - } - - return putChannelCloseSummary(tx, chanKey, summary, chanState) -} - -// closeChannelSync performs the historical synchronous close path: in a -// single write transaction it wipes the forwarding-package state, deletes -// the channel bucket and its nested revocation log entries, updates the -// outpoint index, and archives the close summary. It is used by backends -// where nested-bucket deletion is cheap (bbolt, etcd). -func (c *ChannelStateDB) closeChannelSync(channel *OpenChannel, - summary *ChannelCloseSummary, statuses ...ChannelStatus) error { - - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - chainBucket, chanBucket, chanKey, err := locateOpenChannel( - tx, channel, - ) - if err != nil { - return err - } - - chanState, err := fetchOpenChannel( - chanBucket, &channel.FundingOutpoint, - ) - if err != nil { - return err - } - - if err = chanState.Packager.Wipe(tx); err != nil { - return err - } - - if err := deleteOpenChannel(chanBucket); err != nil { - return err - } - - if channel.ChanType.IsFrozen() || - channel.ChanType.HasLeaseExpiration() { - - if err := deleteThawHeight(chanBucket); err != nil { - return err - } - } - - if err := deleteLogBucket(chanBucket); err != nil { - return err - } - - if err := chainBucket.DeleteNestedBucket(chanKey); err != nil { - return err - } - - if err := updateClosedOutpointIndex(tx, chanKey); err != nil { - return err - } - - return archiveClosedChannel( - tx, chanKey, chanState, summary, statuses..., - ) - }, func() {}) -} - -// closeChannelTombstone performs the tombstone close path used by -// KV-over-SQL backends. The channel's per-channel state is left intact — -// touching it would trigger the cascading nested-bucket delete this path -// exists to avoid — and the outpointBucket flip from outpointOpen to -// outpointClosed serves as the authoritative closed-channel marker. The -// disk space is reclaimed wholesale by the upcoming native-SQL -// channel-state migration. -func (c *ChannelStateDB) closeChannelTombstone(channel *OpenChannel, - summary *ChannelCloseSummary, statuses ...ChannelStatus) error { - - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - _, chanBucket, chanKey, err := locateOpenChannel(tx, channel) - if err != nil { - return err - } - - chanState, err := fetchOpenChannel( - chanBucket, &channel.FundingOutpoint, - ) - if err != nil { - return err - } - - if err := updateClosedOutpointIndex(tx, chanKey); err != nil { - return err - } - - return archiveClosedChannel( - tx, chanKey, chanState, summary, statuses..., - ) - }, func() {}) -} - -// ChannelSnapshot is a frozen snapshot of the current channel state. A -// snapshot is detached from the original channel that generated it, providing -// read-only access to the current or prior state of an active channel. -// -// TODO(roasbeef): remove all together? pretty much just commitment -type ChannelSnapshot struct { - // RemoteIdentity is the identity public key of the remote node that we - // are maintaining the open channel with. - RemoteIdentity btcec.PublicKey - - // ChanPoint is the outpoint that created the channel. This output is - // found within the funding transaction and uniquely identified the - // channel on the resident chain. - ChannelPoint wire.OutPoint - - // ChainHash is the genesis hash of the chain that the channel resides - // within. - ChainHash chainhash.Hash - - // Capacity is the total capacity of the channel. - Capacity btcutil.Amount - - // TotalMSatSent is the total number of milli-satoshis we've sent - // within this channel. - TotalMSatSent lnwire.MilliSatoshi - - // TotalMSatReceived is the total number of milli-satoshis we've - // received within this channel. - TotalMSatReceived lnwire.MilliSatoshi - - // ChannelCommitment is the current up-to-date commitment for the - // target channel. - ChannelCommitment -} - -// Snapshot returns a read-only snapshot of the current channel state. This -// snapshot includes information concerning the current settled balance within -// the channel, metadata detailing total flows, and any outstanding HTLCs. -func (c *OpenChannel) Snapshot() *ChannelSnapshot { - c.RLock() - defer c.RUnlock() - - localCommit := c.LocalCommitment - snapshot := &ChannelSnapshot{ - RemoteIdentity: *c.IdentityPub, - ChannelPoint: c.FundingOutpoint, - Capacity: c.Capacity, - TotalMSatSent: c.TotalMSatSent, - TotalMSatReceived: c.TotalMSatReceived, - ChainHash: c.ChainHash, - ChannelCommitment: ChannelCommitment{ - LocalBalance: localCommit.LocalBalance, - RemoteBalance: localCommit.RemoteBalance, - CommitHeight: localCommit.CommitHeight, - CommitFee: localCommit.CommitFee, - }, - } - - localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { - blobCopy := make([]byte, len(blob)) - copy(blobCopy, blob) - - snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) - }) - - // Copy over the current set of HTLCs to ensure the caller can't mutate - // our internal state. - snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) - for i, h := range localCommit.Htlcs { - snapshot.Htlcs[i] = h.Copy() - } - - return snapshot -} - -// Copy returns a deep copy of the channel state. -func (c *OpenChannel) Copy() *OpenChannel { - c.RLock() - defer c.RUnlock() - - clone := &OpenChannel{ - ChanType: c.ChanType, - ChainHash: c.ChainHash, - FundingOutpoint: c.FundingOutpoint, - ShortChannelID: c.ShortChannelID, - IsPending: c.IsPending, - IsInitiator: c.IsInitiator, - chanStatus: c.chanStatus, - FundingBroadcastHeight: c.FundingBroadcastHeight, - ConfirmationHeight: c.ConfirmationHeight, - NumConfsRequired: c.NumConfsRequired, - ChannelFlags: c.ChannelFlags, - IdentityPub: c.IdentityPub, - Capacity: c.Capacity, - TotalMSatSent: c.TotalMSatSent, - TotalMSatReceived: c.TotalMSatReceived, - InitialLocalBalance: c.InitialLocalBalance, - InitialRemoteBalance: c.InitialRemoteBalance, - LocalChanCfg: c.LocalChanCfg, - RemoteChanCfg: c.RemoteChanCfg, - LocalCommitment: c.LocalCommitment.copy(), - RemoteCommitment: c.RemoteCommitment.copy(), - RemoteCurrentRevocation: c.RemoteCurrentRevocation, - RemoteNextRevocation: c.RemoteNextRevocation, - RevocationProducer: c.RevocationProducer, - RevocationStore: c.RevocationStore, - Packager: c.Packager, - ThawHeight: c.ThawHeight, - LastWasRevoke: c.LastWasRevoke, - RevocationKeyLocator: c.RevocationKeyLocator, - confirmedScid: c.confirmedScid, - TapscriptRoot: c.TapscriptRoot, - } - - if c.FundingTxn != nil { - clone.FundingTxn = c.FundingTxn.Copy() - } - - if len(c.LocalShutdownScript) > 0 { - clone.LocalShutdownScript = make( - lnwire.DeliveryAddress, - len(c.LocalShutdownScript), - ) - copy(clone.LocalShutdownScript, c.LocalShutdownScript) - } - if len(c.RemoteShutdownScript) > 0 { - clone.RemoteShutdownScript = make( - lnwire.DeliveryAddress, - len(c.RemoteShutdownScript), - ) - copy(clone.RemoteShutdownScript, c.RemoteShutdownScript) - } - - if len(c.Memo) > 0 { - clone.Memo = make([]byte, len(c.Memo)) - copy(clone.Memo, c.Memo) - } - - c.CustomBlob.WhenSome(func(blob tlv.Blob) { - blobCopy := make([]byte, len(blob)) - copy(blobCopy, blob) - clone.CustomBlob = fn.Some(blobCopy) - }) - - return clone -} - -// LatestCommitments returns the two latest commitments for both the local and -// remote party. These commitments are read from disk to ensure that only the -// latest fully committed state is returned. The first commitment returned is -// the local commitment, and the second returned is the remote commitment. -func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanCommitments(chanBucket, c) - }, func() {}) - if err != nil { - return nil, nil, err - } - - return &c.LocalCommitment, &c.RemoteCommitment, nil -} - -// RemoteRevocationStore returns the most up to date commitment version of the -// revocation storage tree for the remote party. This method can be used when -// acting on a possible contract breach to ensure, that the caller has the most -// up to date information required to deliver justice. -func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanRevocationState(chanBucket, c) - }, func() {}) - if err != nil { - return nil, err - } - - return c.RevocationStore, nil -} - -// AbsoluteThawHeight determines a frozen channel's absolute thaw height. If the -// channel is not frozen, then 0 is returned. -func (c *OpenChannel) AbsoluteThawHeight() (uint32, error) { - // Only frozen channels have a thaw height. - if !c.ChanType.IsFrozen() && !c.ChanType.HasLeaseExpiration() { - return 0, nil - } - - // If the channel has the frozen bit set and it's thaw height is below - // the absolute threshold, then it's interpreted as a relative height to - // the chain's current height. - if c.ChanType.IsFrozen() && c.ThawHeight < AbsoluteThawHeightThreshold { - // We'll only known of the channel's short ID once it's - // confirmed. - if c.IsPending { - return 0, errors.New("cannot use relative thaw " + - "height for unconfirmed channel") - } - - // For non-zero-conf channels, this is the base height to use. - blockHeightBase := c.ShortChannelID.BlockHeight - - // If this is a zero-conf channel, the ShortChannelID will be - // an alias. - if c.IsZeroConf() { - if !c.ZeroConfConfirmed() { - return 0, errors.New("cannot use relative " + - "height for unconfirmed zero-conf " + - "channel") - } - - // Use the confirmed SCID's BlockHeight. - blockHeightBase = c.confirmedScid.BlockHeight - } - - return blockHeightBase + c.ThawHeight, nil - } - - return c.ThawHeight, nil -} - -// DeriveHeightHint derives the block height for the channel opening. -func (c *OpenChannel) DeriveHeightHint() uint32 { - // As a height hint, we'll try to use the opening height, but if the - // channel isn't yet open, then we'll use the height it was broadcast - // at. This may be an unconfirmed zero-conf channel. - heightHint := c.ShortChanID().BlockHeight - if heightHint == 0 { - heightHint = c.BroadcastHeight() - } - - // Since no zero-conf state is stored in a channel backup, the below - // logic will not be triggered for restored, zero-conf channels. Set - // the height hint for zero-conf channels. - if c.IsZeroConf() { - if c.ZeroConfConfirmed() { - // If the zero-conf channel is confirmed, we'll use the - // confirmed SCID's block height. - heightHint = c.ZeroConfRealScid().BlockHeight - } else { - // The zero-conf channel is unconfirmed. We'll need to - // use the FundingBroadcastHeight. - heightHint = c.BroadcastHeight() - } - } - - return heightHint -} - -func putChannelCloseSummary(tx kvdb.RwTx, chanID []byte, - summary *ChannelCloseSummary, lastChanState *OpenChannel) error { - - closedChanBucket, err := tx.CreateTopLevelBucket(closedChannelBucket) - if err != nil { - return err - } - - summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation - summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation - summary.LocalChanConfig = lastChanState.LocalChanCfg - - var b bytes.Buffer - if err := serializeChannelCloseSummary(&b, summary); err != nil { - return err - } - - return closedChanBucket.Put(chanID, b.Bytes()) -} - -func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { - err := WriteElements(w, - cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, - cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, - cs.TimeLockedBalance, cs.CloseType, cs.IsPending, - ) - if err != nil { - return err - } - - // If this is a close channel summary created before the addition of - // the new fields, then we can exit here. - if cs.RemoteCurrentRevocation == nil { - return WriteElements(w, false) - } - - // If fields are present, write boolean to indicate this, and continue. - if err := WriteElements(w, true); err != nil { - return err - } - - if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { - return err - } - - if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { - return err - } - - // The RemoteNextRevocation field is optional, as it's possible for a - // channel to be closed before we learn of the next unrevoked - // revocation point for the remote party. Write a boolean indicating - // whether this field is present or not. - if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { - return err - } - - // Write the field, if present. - if cs.RemoteNextRevocation != nil { - if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { - return err - } - } - - // Write whether the channel sync message is present. - if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { - return err - } - - // Write the channel sync message, if present. - if cs.LastChanSyncMsg != nil { - if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { - return err - } - } - - return nil -} - -func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { - c := &ChannelCloseSummary{} - - err := ReadElements(r, - &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, - &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, - &c.TimeLockedBalance, &c.CloseType, &c.IsPending, - ) - if err != nil { - return nil, err - } - - // We'll now check to see if the channel close summary was encoded with - // any of the additional optional fields. - var hasNewFields bool - err = ReadElements(r, &hasNewFields) - if err != nil { - return nil, err - } - - // If fields are not present, we can return. - if !hasNewFields { - return c, nil - } - - // Otherwise read the new fields. - if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { - return nil, err - } - - if err := readChanConfig(r, &c.LocalChanConfig); err != nil { - return nil, err - } - - // Finally, we'll attempt to read the next unrevoked commitment point - // for the remote party. If we closed the channel before receiving a - // channel_ready message then this might not be present. A boolean - // indicating whether the field is present will come first. - var hasRemoteNextRevocation bool - err = ReadElements(r, &hasRemoteNextRevocation) - if err != nil { - return nil, err - } - - // If this field was written, read it. - if hasRemoteNextRevocation { - err = ReadElements(r, &c.RemoteNextRevocation) - if err != nil { - return nil, err - } - } - - // Check if we have a channel sync message to read. - var hasChanSyncMsg bool - err = ReadElements(r, &hasChanSyncMsg) - if err == io.EOF { - return c, nil - } else if err != nil { - return nil, err - } - - // If a chan sync message is present, read it. - if hasChanSyncMsg { - // We must pass in reference to a lnwire.Message for the codec - // to support it. - var msg lnwire.Message - if err := ReadElements(r, &msg); err != nil { - return nil, err - } - - chanSync, ok := msg.(*lnwire.ChannelReestablish) - if !ok { - return nil, errors.New("unable cast db Message to " + - "ChannelReestablish") - } - c.LastChanSyncMsg = chanSync - } - - return c, nil -} - -func writeChanConfig(b io.Writer, c *ChannelConfig) error { - return WriteElements(b, - c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, - c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, - c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, - c.HtlcBasePoint, - ) -} - -// fundingTxPresent returns true if expect the funding transcation to be found -// on disk or already populated within the passed open channel struct. -func fundingTxPresent(channel *OpenChannel) bool { - chanType := channel.ChanType - - return chanType.IsSingleFunder() && chanType.HasFundingTx() && - channel.IsInitiator && - !channel.hasChanStatus(ChanStatusRestored) -} - -func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { - var w bytes.Buffer - if err := WriteElements(&w, - channel.ChanType, channel.ChainHash, channel.FundingOutpoint, - channel.ShortChannelID, channel.IsPending, channel.IsInitiator, - channel.chanStatus, channel.FundingBroadcastHeight, - channel.NumConfsRequired, channel.ChannelFlags, - channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, - channel.TotalMSatReceived, - ); err != nil { - return err - } - - // For single funder channels that we initiated, and we have the - // funding transaction, then write the funding txn. - if fundingTxPresent(channel) { - if err := WriteElement(&w, channel.FundingTxn); err != nil { - return err - } - } - - if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil { - return err - } - if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil { - return err - } - - auxData := channel.extractTlvData() - if err := auxData.encode(&w); err != nil { - return fmt.Errorf("unable to encode aux data: %w", err) - } - - if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { - return err - } - - // Finally, add optional shutdown scripts for the local and remote peer if - // they are present. - if err := putOptionalUpfrontShutdownScript( - chanBucket, localUpfrontShutdownKey, channel.LocalShutdownScript, - ); err != nil { - return err - } - - return putOptionalUpfrontShutdownScript( - chanBucket, remoteUpfrontShutdownKey, channel.RemoteShutdownScript, - ) -} - -// putOptionalUpfrontShutdownScript adds a shutdown script under the key -// provided if it has a non-zero length. -func putOptionalUpfrontShutdownScript(chanBucket kvdb.RwBucket, key []byte, - script []byte) error { - // If the script is empty, we do not need to add anything. - if len(script) == 0 { - return nil - } - - var w bytes.Buffer - if err := WriteElement(&w, script); err != nil { - return err - } - - return chanBucket.Put(key, w.Bytes()) -} - -// getOptionalUpfrontShutdownScript reads the shutdown script stored under the -// key provided if it is present. Upfront shutdown scripts are optional, so the -// function returns with no error if the key is not present. -func getOptionalUpfrontShutdownScript(chanBucket kvdb.RBucket, key []byte, - script *lnwire.DeliveryAddress) error { - - // Return early if the bucket does not exit, a shutdown script was not set. - bs := chanBucket.Get(key) - if bs == nil { - return nil - } - - var tempScript []byte - r := bytes.NewReader(bs) - if err := ReadElement(r, &tempScript); err != nil { - return err - } - *script = tempScript - - return nil -} - -func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { - if err := WriteElements(w, - c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, - c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, - c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, - c.CommitSig, - ); err != nil { - return err - } - - return SerializeHtlcs(w, c.Htlcs...) -} - -func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, - local bool) error { - - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - var b bytes.Buffer - if err := serializeChanCommit(&b, c); err != nil { - return err - } - - // Before we write to disk, we'll also write our aux data as well. - auxData := c.extractTlvData() - if err := auxData.encode(&b); err != nil { - return fmt.Errorf("unable to write aux data: %w", err) - } - - return chanBucket.Put(commitKey, b.Bytes()) -} - -func putChanCommitments(chanBucket kvdb.RwBucket, channel *OpenChannel) error { - // If this is a restored channel, then we don't have any commitments to - // write. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - err := putChanCommitment( - chanBucket, &channel.LocalCommitment, true, - ) - if err != nil { - return err - } +// CommitmentHeight returns the current commitment height. The commitment +// height represents the number of updates to the commitment state to date. +// This value is always monotonically increasing. This method is provided in +// order to allow multiple instances of a particular open channel to obtain a +// consistent view of the number of channel updates to date. +func (c *ChannelStateDB) CommitmentHeight(channel *OpenChannel) ( + uint64, error) { - return putChanCommitment( - chanBucket, &channel.RemoteCommitment, false, - ) + return c.kvStore.CommitmentHeight(channel) } -func putChanRevocationState(chanBucket kvdb.RwBucket, channel *OpenChannel) error { - var b bytes.Buffer - err := WriteElements( - &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, - channel.RevocationStore, - ) - if err != nil { - return err - } - - // If the next revocation is present, which is only the case after the - // ChannelReady message has been sent, then we'll write it to disk. - if channel.RemoteNextRevocation != nil { - err = WriteElements(&b, channel.RemoteNextRevocation) - if err != nil { - return err - } - } +// FindPreviousState scans through the append-only log in an attempt to recover +// the previous channel state indicated by the update number. This method is +// intended to be used for obtaining the relevant data needed to claim all +// funds rightfully spendable in the case of an on-chain broadcast of the +// commitment transaction. +func (c *ChannelStateDB) FindPreviousState(channel *OpenChannel, + updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { - return chanBucket.Put(revocationStateKey, b.Bytes()) + return c.kvStore.FindPreviousState(channel, updateNum) } -func readChanConfig(b io.Reader, c *ChannelConfig) error { - return ReadElements(b, - &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, - &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, - &c.MultiSigKey, &c.RevocationBasePoint, - &c.PaymentBasePoint, &c.DelayBasePoint, - &c.HtlcBasePoint, - ) -} +// ClosureType is an enum like structure that details exactly how a channel was +// closed. +type ClosureType = cstate.ClosureType -func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { - infoBytes := chanBucket.Get(chanInfoKey) - if infoBytes == nil { - return ErrNoChanInfoFound - } - r := bytes.NewReader(infoBytes) - - if err := ReadElements(r, - &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, - &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, - &channel.chanStatus, &channel.FundingBroadcastHeight, - &channel.NumConfsRequired, &channel.ChannelFlags, - &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, - &channel.TotalMSatReceived, - ); err != nil { - return err - } +const ( + // CooperativeClose indicates that a channel has been closed + // cooperatively. + CooperativeClose = cstate.CooperativeClose - // For single funder channels that we initiated and have the funding - // transaction to, read the funding txn. - if fundingTxPresent(channel) { - if err := ReadElement(r, &channel.FundingTxn); err != nil { - return err - } - } + // LocalForceClose indicates that we have unilaterally broadcast our + // current commitment state on-chain. + LocalForceClose = cstate.LocalForceClose - if err := readChanConfig(r, &channel.LocalChanCfg); err != nil { - return err - } - if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil { - return err - } + // RemoteForceClose indicates that the remote peer has unilaterally + // broadcast their current commitment state on-chain. + RemoteForceClose = cstate.RemoteForceClose - // Retrieve the boolean stored under lastWasRevokeKey. - lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey) - if lastWasRevokeBytes == nil { - // If nothing has been stored under this key, we store false in the - // OpenChannel struct. - channel.LastWasRevoke = false - } else { - // Otherwise, read the value into the LastWasRevoke field. - revokeReader := bytes.NewReader(lastWasRevokeBytes) - err := ReadElements(revokeReader, &channel.LastWasRevoke) - if err != nil { - return err - } - } + // BreachClose indicates that the remote peer attempted to broadcast a + // prior revoked channel state. + BreachClose = cstate.BreachClose - var auxData openChannelTlvData - if err := auxData.decode(r); err != nil { - return fmt.Errorf("unable to decode aux data: %w", err) - } + // FundingCanceled indicates that the channel never was fully opened + // before it was marked as closed in the database. + FundingCanceled = cstate.FundingCanceled - // Assign all the relevant fields from the aux data into the actual - // open channel. - channel.amendTlvData(auxData) + // Abandoned indicates that the channel state was removed without any + // further actions. + Abandoned = cstate.Abandoned +) - channel.Packager = NewChannelPackager(channel.ShortChannelID) +// ChannelCloseSummary contains the final state of a channel at the point it +// was closed. +type ChannelCloseSummary = cstate.ChannelCloseSummary - // Finally, read the optional shutdown scripts. - if err := getOptionalUpfrontShutdownScript( - chanBucket, localUpfrontShutdownKey, &channel.LocalShutdownScript, - ); err != nil { - return err - } +// CloseChannel closes the supplied channel via the strategy selected at DB +// construction. On synchronous backends the channel's nested state — the +// revocation log, the per-channel forwarding-package bucket, and the +// chanBucket itself — is deleted inline. On tombstone-enabled backends none +// of the bulk state is touched; the outpointBucket flip to outpointClosed +// signals that the channel is logically closed. +func (c *ChannelStateDB) CloseChannel(channel *OpenChannel, + summary *ChannelCloseSummary, statuses ...ChannelStatus) error { - return getOptionalUpfrontShutdownScript( - chanBucket, remoteUpfrontShutdownKey, &channel.RemoteShutdownScript, - ) + return c.kvStore.CloseChannel(channel, summary, statuses...) } -func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { - var c ChannelCommitment - - err := ReadElements(r, - &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, - &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, - &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, - ) - if err != nil { - return c, err - } +// ChannelSnapshot is a frozen snapshot of the current channel state. +type ChannelSnapshot = cstate.ChannelSnapshot - c.Htlcs, err = DeserializeHtlcs(r) - if err != nil { - return c, err - } +// LatestCommitments returns the two latest commitments for both the local and +// remote party. These commitments are read from disk to ensure that only the +// latest fully committed state is returned. The first commitment returned is +// the local commitment, and the second returned is the remote commitment. +func (c *ChannelStateDB) LatestCommitments(channel *OpenChannel) ( + *ChannelCommitment, *ChannelCommitment, error) { - return c, nil + return c.kvStore.LatestCommitments(channel) } -func fetchChanCommitment(chanBucket kvdb.RBucket, - local bool) (ChannelCommitment, error) { - - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - commitBytes := chanBucket.Get(commitKey) - if commitBytes == nil { - return ChannelCommitment{}, ErrNoCommitmentsFound - } - - r := bytes.NewReader(commitBytes) - chanCommit, err := deserializeChanCommit(r) - if err != nil { - return ChannelCommitment{}, fmt.Errorf("unable to decode "+ - "chan commit: %w", err) - } - - // We'll also check to see if we have any aux data stored as the end of - // the stream. - var auxData commitTlvData - if err := auxData.decode(r); err != nil { - return ChannelCommitment{}, fmt.Errorf("unable to decode "+ - "chan aux data: %w", err) - } - - chanCommit.amendTlvData(auxData) +// RemoteRevocationStore returns the most up to date commitment version of the +// revocation storage tree for the remote party. This method can be used when +// acting on a possible contract breach to ensure, that the caller has the most +// up to date information required to deliver justice. +func (c *ChannelStateDB) RemoteRevocationStore(channel *OpenChannel) ( + shachain.Store, error) { - return chanCommit, nil + return c.kvStore.RemoteRevocationStore(channel) } -func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error { - var err error - - // If this is a restored channel, then we don't have any commitments to - // read. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false) - if err != nil { - return err - } - - return nil +func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { + return cstate.SerializeChannelCloseSummary(w, cs) } -func fetchChanRevocationState(chanBucket kvdb.RBucket, channel *OpenChannel) error { - revBytes := chanBucket.Get(revocationStateKey) - if revBytes == nil { - return ErrNoRevocationsFound - } - r := bytes.NewReader(revBytes) - - err := ReadElements( - r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, - &channel.RevocationStore, - ) - if err != nil { - return err - } - - // If there aren't any bytes left in the buffer, then we don't yet have - // the next remote revocation, so we can exit early here. - if r.Len() == 0 { - return nil - } - - // Otherwise we'll read the next revocation for the remote party which - // is always the last item within the buffer. - return ReadElements(r, &channel.RemoteNextRevocation) +func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { + return cstate.DeserializeCloseChannelSummary(r) } -func deleteOpenChannel(chanBucket kvdb.RwBucket) error { - if err := chanBucket.Delete(chanInfoKey); err != nil { - return err - } - - err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00))) - if err != nil { - return err - } - err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01))) - if err != nil { - return err - } - - if err := chanBucket.Delete(revocationStateKey); err != nil { - return err - } - - if diff := chanBucket.Get(commitDiffKey); diff != nil { - return chanBucket.Delete(commitDiffKey) - } - - return nil +func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { + return cstate.SerializeChanCommit(w, c) } // makeLogKey converts a uint64 into an 8 byte array. @@ -5086,139 +729,30 @@ func makeLogKey(updateNum uint64) [8]byte { } func fetchThawHeight(chanBucket kvdb.RBucket) (uint32, error) { - var height uint32 - - heightBytes := chanBucket.Get(frozenChanKey) - heightReader := bytes.NewReader(heightBytes) - - if err := ReadElements(heightReader, &height); err != nil { - return 0, err - } - - return height, nil + return cstate.FetchThawHeight(chanBucket) } func storeThawHeight(chanBucket kvdb.RwBucket, height uint32) error { - var heightBuf bytes.Buffer - if err := WriteElements(&heightBuf, height); err != nil { - return err - } - - return chanBucket.Put(frozenChanKey, heightBuf.Bytes()) -} - -func deleteThawHeight(chanBucket kvdb.RwBucket) error { - return chanBucket.Delete(frozenChanKey) -} - -// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the -// tlv.RecordProducer interface. -type keyLocRecord struct { - keychain.KeyLocator -} - -// Record creates a Record out of a KeyLocator using the passed Type and the -// EKeyLocator and DKeyLocator functions. The size will always be 8 as -// KeyFamily is uint32 and the Index is uint32. -// -// NOTE: This is part of the tlv.RecordProducer interface. -func (k *keyLocRecord) Record() tlv.Record { - // Note that we set the type here as zero, as when used with a - // tlv.RecordT, the type param will be used as the type. - return tlv.MakeStaticRecord( - 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, - ) + return cstate.StoreThawHeight(chanBucket, height) } // EKeyLocator is an encoder for keychain.KeyLocator. func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { - if v, ok := val.(*keychain.KeyLocator); ok { - err := tlv.EUint32T(w, uint32(v.Family), buf) - if err != nil { - return err - } - - return tlv.EUint32T(w, v.Index, buf) - } - return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator") + return cstate.EKeyLocator(w, val, buf) } // DKeyLocator is a decoder for keychain.KeyLocator. func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - if v, ok := val.(*keychain.KeyLocator); ok { - var family uint32 - err := tlv.DUint32(r, &family, buf, 4) - if err != nil { - return err - } - v.Family = keychain.KeyFamily(family) - - return tlv.DUint32(r, &v.Index, buf, 4) - } - return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) + return cstate.DKeyLocator(r, val, buf, l) } // ShutdownInfo contains various info about the shutdown initiation of a // channel. -type ShutdownInfo struct { - // DeliveryScript is the address that we have included in any previous - // Shutdown message for a particular channel and so should include in - // any future re-sends of the Shutdown message. - DeliveryScript tlv.RecordT[tlv.TlvType0, lnwire.DeliveryAddress] - - // LocalInitiator is true if we sent a Shutdown message before ever - // receiving a Shutdown message from the remote peer. - LocalInitiator tlv.RecordT[tlv.TlvType1, bool] -} +type ShutdownInfo = cstate.ShutdownInfo // NewShutdownInfo constructs a new ShutdownInfo object. func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress, locallyInitiated bool) *ShutdownInfo { - return &ShutdownInfo{ - DeliveryScript: tlv.NewRecordT[tlv.TlvType0](deliveryScript), - LocalInitiator: tlv.NewPrimitiveRecord[tlv.TlvType1]( - locallyInitiated, - ), - } -} - -// Closer identifies the ChannelParty that initiated the coop-closure process. -func (s ShutdownInfo) Closer() lntypes.ChannelParty { - if s.LocalInitiator.Val { - return lntypes.Local - } - - return lntypes.Remote -} - -// encode serialises the ShutdownInfo to the given io.Writer. -func (s *ShutdownInfo) encode(w io.Writer) error { - records := []tlv.Record{ - s.DeliveryScript.Record(), - s.LocalInitiator.Record(), - } - - stream, err := tlv.NewStream(records...) - if err != nil { - return err - } - - return stream.Encode(w) -} - -// decodeShutdownInfo constructs a ShutdownInfo struct by decoding the given -// byte slice. -func decodeShutdownInfo(b []byte) (*ShutdownInfo, error) { - tlvStream := lnwire.ExtraOpaqueData(b) - - var info ShutdownInfo - records := []tlv.RecordProducer{ - &info.DeliveryScript, - &info.LocalInitiator, - } - - _, err := tlvStream.ExtractRecords(records...) - - return &info, err + return cstate.NewShutdownInfo(deliveryScript, locallyInitiated) } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 47504067780..bf2e0b7227d 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -414,7 +414,6 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { RevocationProducer: producer, RevocationStore: store, Db: cdb, - Packager: NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, ThawHeight: uint32(defaultPendingHeight), InitialLocalBalance: lnwire.MilliSatoshi(9000), @@ -879,7 +878,7 @@ func TestChannelStateTransition(t *testing.T) { // The state number recovered from the tail of the revocation log // should be identical to this current state. - logTailHeight, err := channel.revocationLogTailCommitHeight() + logTailHeight, err := cdb.revocationLogTailCommitHeight(channel) require.NoError(t, err, "unable to retrieve log") if logTailHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") @@ -922,7 +921,7 @@ func TestChannelStateTransition(t *testing.T) { // Once again, state number recovered from the tail of the revocation // log should be identical to this current state. - logTailHeight, err = channel.revocationLogTailCommitHeight() + logTailHeight, err = cdb.revocationLogTailCommitHeight(channel) require.NoError(t, err, "unable to retrieve log") if logTailHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") @@ -939,7 +938,9 @@ func TestChannelStateTransition(t *testing.T) { } // At this point, we should have 2 forwarding packages added. - fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) + fwdPkgs := loadFwdPkgs( + t, cdb.backend, NewChannelPackager(channel.ShortChanID()), + ) require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") // Now attempt to delete the channel from the database. @@ -974,7 +975,9 @@ func TestChannelStateTransition(t *testing.T) { } // All forwarding packages of this channel has been deleted too. - fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) + fwdPkgs = loadFwdPkgs( + t, cdb.backend, NewChannelPackager(channel.ShortChanID()), + ) require.Empty(t, fwdPkgs, "no forwarding packages should exist") } @@ -1424,16 +1427,6 @@ func TestRefresh(t *testing.T) { "updated before refreshing short_chan_id") } - // Now that the receiver's short channel id has been updated, check to - // ensure that the channel packager's source has been updated as well. - // This ensures that the packager will read and write to buckets - // corresponding to the new short chan id, instead of the prior. - if state.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - state.Packager.(*ChannelPackager).source) - } - // Now, refresh the state of the pending channel. err = pendingChannel.Refresh() require.NoError(t, err, "unable to refresh short_chan_id") @@ -1446,16 +1439,6 @@ func TestRefresh(t *testing.T) { pendingChannel.ShortChanID()) } - // Check to ensure that the _other_ OpenChannel channel packager's - // source has also been updated after the refresh. This ensures that the - // other packagers will read and write to buckets corresponding to the - // updated short chan id. - if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - pendingChannel.Packager.(*ChannelPackager).source) - } - // Check to ensure that this channel is no longer pending and this field // is up to date. if pendingChannel.IsPending { @@ -1542,9 +1525,7 @@ func TestCloseInitiator(t *testing.T) { } // Lookup open channels in the database. - dbChans, err := fetchChannels( - cdb, pendingChannelFilter(false), - ) + dbChans, err := cdb.FetchAllChannels() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -1559,7 +1540,7 @@ func TestCloseInitiator(t *testing.T) { if !dbChans[0].HasChanStatus(status) { t.Fatalf("expected channel to have "+ "status: %v, has status: %v", - status, dbChans[0].chanStatus) + status, dbChans[0].ChanStatus()) } } }) @@ -1642,9 +1623,8 @@ func TestHasChanStatus(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - c := &OpenChannel{ - chanStatus: test.status, - } + c := &OpenChannel{} + c.SetChannelStatusForStore(test.status) for status, expHas := range test.expHas { has := c.HasChanStatus(status) diff --git a/channeldb/chanstate_assertions.go b/channeldb/chanstate_assertions.go new file mode 100644 index 00000000000..cce703b3906 --- /dev/null +++ b/channeldb/chanstate_assertions.go @@ -0,0 +1,7 @@ +package channeldb + +import "github.com/lightningnetwork/lnd/chanstate" + +// Compile-time assertion that ChannelStateDB satisfies the channel-state store +// contract while the KV implementation still lives in channeldb. +var _ chanstate.Store = (*ChannelStateDB)(nil) diff --git a/channeldb/close_channel_test.go b/channeldb/close_channel_test.go index 3a940b78dec..43466a44b1f 100644 --- a/channeldb/close_channel_test.go +++ b/channeldb/close_channel_test.go @@ -15,10 +15,12 @@ import ( // revocationLogBucket of the given channel. The helper navigates the raw KV // tree so the test does not depend on the higher-level commit-chain // machinery. -func writeTestRevlogEntries(t *testing.T, ch *OpenChannel, n int) { +func writeTestRevlogEntries(t *testing.T, cdb *ChannelStateDB, + ch *OpenChannel, n int) { + t.Helper() - err := kvdb.Update(ch.Db.backend, func(tx kvdb.RwTx) error { + err := kvdb.Update(cdb.backend, func(tx kvdb.RwTx) error { openChanBkt := tx.ReadWriteBucket(openChannelBucket) require.NotNil(t, openChanBkt, "openChannelBucket missing") @@ -56,11 +58,13 @@ func writeTestRevlogEntries(t *testing.T, ch *OpenChannel, n int) { // writeTestForwardingPackages writes n empty forwarding packages for the // given channel using distinct remote commitment heights. -func writeTestForwardingPackages(t *testing.T, ch *OpenChannel, n int) { +func writeTestForwardingPackages(t *testing.T, cdb *ChannelStateDB, + ch *OpenChannel, n int) { + t.Helper() packager := NewChannelPackager(ch.ShortChanID()) - err := kvdb.Update(ch.Db.backend, func(tx kvdb.RwTx) error { + err := kvdb.Update(cdb.backend, func(tx kvdb.RwTx) error { for i := range n { pkg := NewFwdPkg( ch.ShortChanID(), uint64(i), nil, nil, @@ -78,11 +82,13 @@ func writeTestForwardingPackages(t *testing.T, ch *OpenChannel, n int) { // countRevlogEntries returns the number of entries in the revocationLogBucket // for the given channel, or -1 if the channel bucket no longer exists in // openChannelBucket. -func countRevlogEntries(t *testing.T, ch *OpenChannel) int { +func countRevlogEntries(t *testing.T, cdb *ChannelStateDB, + ch *OpenChannel) int { + t.Helper() count := -1 - err := kvdb.View(ch.Db.backend, func(tx kvdb.RTx) error { + err := kvdb.View(cdb.backend, func(tx kvdb.RTx) error { openChanBkt := tx.ReadBucket(openChannelBucket) if openChanBkt == nil { return nil @@ -202,8 +208,8 @@ func TestCloseChannelTombstoneWritePath(t *testing.T) { const numRevlogEntries = 5 const numFwdPkgs = 3 - writeTestRevlogEntries(t, ch, numRevlogEntries) - writeTestForwardingPackages(t, ch, numFwdPkgs) + writeTestRevlogEntries(t, cdb, ch, numRevlogEntries) + writeTestForwardingPackages(t, cdb, ch, numFwdPkgs) closeChannelForTest(t, cdb, ch) @@ -224,7 +230,7 @@ func TestCloseChannelTombstoneWritePath(t *testing.T) { require.Equal(t, ch.FundingOutpoint, closeSummary.ChanPoint) // Bulk state preserved on disk — tombstoning's whole point. - require.Equal(t, numRevlogEntries, countRevlogEntries(t, ch)) + require.Equal(t, numRevlogEntries, countRevlogEntries(t, cdb, ch)) packager := NewChannelPackager(ch.ShortChanID()) var fwdPkgs []*FwdPkg @@ -281,7 +287,7 @@ func TestCloseChannelTombstoneRemovesFromOpenScans(t *testing.T) { ch2 := createTestChannel(t, cdb, openChannelOption()) const numRevlogEntries = 5 - writeTestRevlogEntries(t, ch1, numRevlogEntries) + writeTestRevlogEntries(t, cdb, ch1, numRevlogEntries) openChans, err := cdb.FetchAllChannels() require.NoError(t, err) @@ -313,7 +319,7 @@ func TestCloseChannelTombstoneRemovesFromOpenScans(t *testing.T) { // The bulk historical state stays put — that is the whole point of // the tombstone path on these backends. - require.Equal(t, numRevlogEntries, countRevlogEntries(t, ch1)) + require.Equal(t, numRevlogEntries, countRevlogEntries(t, cdb, ch1)) // The outpoint index for ch1 must flip to closed; ch2's stays open. require.Equal(t, outpointClosed, readOutpointStatus( @@ -380,14 +386,14 @@ func TestCloseChannelSync(t *testing.T) { ch := createTestChannel(t, cdb, openChannelOption()) const numRevlogEntries = 4 - writeTestRevlogEntries(t, ch, numRevlogEntries) - writeTestForwardingPackages(t, ch, 3) + writeTestRevlogEntries(t, cdb, ch, numRevlogEntries) + writeTestForwardingPackages(t, cdb, ch, 3) closeChannelForTest(t, cdb, ch) // The synchronous path wipes the chanBucket inline, so // countRevlogEntries must report -1 (bucket is gone, not just empty). - require.Equal(t, -1, countRevlogEntries(t, ch), + require.Equal(t, -1, countRevlogEntries(t, cdb, ch), "channel bucket must be deleted after sync close") // Forwarding packages are wiped inline. diff --git a/channeldb/codec.go b/channeldb/codec.go index e5bab3d5f76..fe35c63da49 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -1,41 +1,20 @@ package channeldb import ( - "bytes" - "encoding/binary" - "fmt" "io" - "net" "time" - "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - graphdb "github.com/lightningnetwork/lnd/graph/db" - "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/shachain" - "github.com/lightningnetwork/lnd/tlv" + cstate "github.com/lightningnetwork/lnd/chanstate" ) // UnknownElementType is an error returned when the codec is unable to encode or // decode a particular type. -type UnknownElementType struct { - method string - element interface{} -} +type UnknownElementType = cstate.UnknownElementType // NewUnknownElementType creates a new UnknownElementType error from the passed // method name and element. func NewUnknownElementType(method string, el interface{}) UnknownElementType { - return UnknownElementType{method: method, element: el} -} - -// Error returns the name of the method that encountered the error, as well as -// the type that was unsupported. -func (e UnknownElementType) Error() string { - return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) + return cstate.NewUnknownElementType(method, el) } // WriteElement is a one-stop shop to write the big endian representation of @@ -43,419 +22,26 @@ func (e UnknownElementType) Error() string { // io.Writer should be backed by an appropriately sized byte slice, or be able // to dynamically expand to accommodate additional data. func WriteElement(w io.Writer, element interface{}) error { - switch e := element.(type) { - case keychain.KeyDescriptor: - if err := binary.Write(w, byteOrder, e.Family); err != nil { - return err - } - if err := binary.Write(w, byteOrder, e.Index); err != nil { - return err - } - - if e.PubKey != nil { - if err := binary.Write(w, byteOrder, true); err != nil { - return fmt.Errorf("error writing serialized "+ - "element: %w", err) - } - - return WriteElement(w, e.PubKey) - } - - return binary.Write(w, byteOrder, false) - case ChannelType: - var buf [8]byte - if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil { - return err - } - - case chainhash.Hash: - if _, err := w.Write(e[:]); err != nil { - return err - } - - case wire.OutPoint: - return graphdb.WriteOutpoint(w, &e) - - case lnwire.ShortChannelID: - if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { - return err - } - - case lnwire.ChannelID: - if _, err := w.Write(e[:]); err != nil { - return err - } - - case int64, uint64: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case uint32: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case int32: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case uint16: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case uint8: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case bool: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case btcutil.Amount: - if err := binary.Write(w, byteOrder, uint64(e)); err != nil { - return err - } - - case lnwire.MilliSatoshi: - if err := binary.Write(w, byteOrder, uint64(e)); err != nil { - return err - } - - case *btcec.PrivateKey: - b := e.Serialize() - if _, err := w.Write(b); err != nil { - return err - } - - case *btcec.PublicKey: - b := e.SerializeCompressed() - if _, err := w.Write(b); err != nil { - return err - } - - case shachain.Producer: - return e.Encode(w) - - case shachain.Store: - return e.Encode(w) - - case *wire.MsgTx: - return e.Serialize(w) - - case [32]byte: - if _, err := w.Write(e[:]); err != nil { - return err - } - - case []byte: - if err := wire.WriteVarBytes(w, 0, e); err != nil { - return err - } - - case lnwire.Message: - var msgBuf bytes.Buffer - if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { - return err - } - - msgLen := uint16(len(msgBuf.Bytes())) - if err := WriteElements(w, msgLen); err != nil { - return err - } - - if _, err := w.Write(msgBuf.Bytes()); err != nil { - return err - } - - case ChannelStatus: - var buf [8]byte - if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil { - return err - } - - case ClosureType: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case lnwire.FundingFlag: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - - case net.Addr: - if err := graphdb.SerializeAddr(w, e); err != nil { - return err - } - - case []net.Addr: - if err := WriteElement(w, uint32(len(e))); err != nil { - return err - } - - for _, addr := range e { - if err := graphdb.SerializeAddr(w, addr); err != nil { - return err - } - } - - default: - return UnknownElementType{"WriteElement", e} - } - - return nil + return cstate.WriteElement(w, element) } // WriteElements is writes each element in the elements slice to the passed // io.Writer using WriteElement. func WriteElements(w io.Writer, elements ...interface{}) error { - for _, element := range elements { - err := WriteElement(w, element) - if err != nil { - return err - } - } - return nil + return cstate.WriteElements(w, elements...) } // ReadElement is a one-stop utility function to deserialize any datastructure // encoded using the serialization format of the database. func ReadElement(r io.Reader, element interface{}) error { - switch e := element.(type) { - case *keychain.KeyDescriptor: - if err := binary.Read(r, byteOrder, &e.Family); err != nil { - return err - } - if err := binary.Read(r, byteOrder, &e.Index); err != nil { - return err - } - - var hasPubKey bool - if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { - return err - } - - if hasPubKey { - return ReadElement(r, &e.PubKey) - } - - case *ChannelType: - var buf [8]byte - ctype, err := tlv.ReadVarInt(r, &buf) - if err != nil { - return err - } - - *e = ChannelType(ctype) - - case *chainhash.Hash: - if _, err := io.ReadFull(r, e[:]); err != nil { - return err - } - - case *wire.OutPoint: - return graphdb.ReadOutpoint(r, e) - - case *lnwire.ShortChannelID: - var a uint64 - if err := binary.Read(r, byteOrder, &a); err != nil { - return err - } - *e = lnwire.NewShortChanIDFromInt(a) - - case *lnwire.ChannelID: - if _, err := io.ReadFull(r, e[:]); err != nil { - return err - } - - case *int64, *uint64: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *uint32: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *int32: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *uint16: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *uint8: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *bool: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *btcutil.Amount: - var a uint64 - if err := binary.Read(r, byteOrder, &a); err != nil { - return err - } - - *e = btcutil.Amount(a) - - case *lnwire.MilliSatoshi: - var a uint64 - if err := binary.Read(r, byteOrder, &a); err != nil { - return err - } - - *e = lnwire.MilliSatoshi(a) - - case **btcec.PrivateKey: - var b [btcec.PrivKeyBytesLen]byte - if _, err := io.ReadFull(r, b[:]); err != nil { - return err - } - - priv, _ := btcec.PrivKeyFromBytes(b[:]) - *e = priv - - case **btcec.PublicKey: - var b [btcec.PubKeyBytesLenCompressed]byte - if _, err := io.ReadFull(r, b[:]); err != nil { - return err - } - - pubKey, err := btcec.ParsePubKey(b[:]) - if err != nil { - return err - } - *e = pubKey - - case *shachain.Producer: - var root [32]byte - if _, err := io.ReadFull(r, root[:]); err != nil { - return err - } - - // TODO(roasbeef): remove - producer, err := shachain.NewRevocationProducerFromBytes(root[:]) - if err != nil { - return err - } - - *e = producer - - case *shachain.Store: - store, err := shachain.NewRevocationStoreFromBytes(r) - if err != nil { - return err - } - - *e = store - - case **wire.MsgTx: - tx := wire.NewMsgTx(2) - if err := tx.Deserialize(r); err != nil { - return err - } - - *e = tx - - case *[32]byte: - if _, err := io.ReadFull(r, e[:]); err != nil { - return err - } - - case *[]byte: - bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") - if err != nil { - return err - } - - *e = bytes - - case *lnwire.Message: - var msgLen uint16 - if err := ReadElement(r, &msgLen); err != nil { - return err - } - - msgReader := io.LimitReader(r, int64(msgLen)) - msg, err := lnwire.ReadMessage(msgReader, 0) - if err != nil { - return err - } - - *e = msg - - case *ChannelStatus: - var buf [8]byte - status, err := tlv.ReadVarInt(r, &buf) - if err != nil { - return err - } - - *e = ChannelStatus(status) - - case *ClosureType: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *lnwire.FundingFlag: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - - case *net.Addr: - addr, err := graphdb.DeserializeAddr(r) - if err != nil { - return err - } - *e = addr - - case *[]net.Addr: - var numAddrs uint32 - if err := ReadElement(r, &numAddrs); err != nil { - return err - } - - *e = make([]net.Addr, numAddrs) - for i := uint32(0); i < numAddrs; i++ { - addr, err := graphdb.DeserializeAddr(r) - if err != nil { - return err - } - (*e)[i] = addr - } - - default: - return UnknownElementType{"ReadElement", e} - } - - return nil + return cstate.ReadElement(r, element) } // ReadElements deserializes a variable number of elements into the passed // io.Reader, with each element being deserialized according to the ReadElement // function. func ReadElements(r io.Reader, elements ...interface{}) error { - for _, element := range elements { - err := ReadElement(r, element) - if err != nil { - return err - } - } - return nil + return cstate.ReadElements(r, elements...) } // deserializeTime deserializes time as unix nanoseconds. diff --git a/channeldb/db.go b/channeldb/db.go index 3f0036b112c..5c935133891 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -12,7 +12,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcwallet/walletdb" mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" @@ -32,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration34" "github.com/lightningnetwork/lnd/channeldb/migration35" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" @@ -51,13 +51,11 @@ var ( // ErrFinalHtlcsBucketNotFound signals that the top-level final htlcs // bucket does not exist. - ErrFinalHtlcsBucketNotFound = errors.New("final htlcs bucket not " + - "found") + ErrFinalHtlcsBucketNotFound = chanstate.ErrFinalHtlcsBucketNotFound // ErrFinalChannelBucketNotFound signals that the channel bucket for // final htlc outcomes does not exist. - ErrFinalChannelBucketNotFound = errors.New("final htlcs channel " + - "bucket not found") + ErrFinalChannelBucketNotFound = chanstate.ErrFinalChannelBucketNotFound ) // migration is a function which takes a prior outdated version of the database @@ -343,11 +341,6 @@ var ( // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. byteOrder = binary.BigEndian - - // channelOpeningStateBucket is the database bucket used to store the - // channelOpeningState for each channel that is currently in the process - // of being opened. - channelOpeningStateBucket = []byte("channelOpeningState") ) // DB is the primary datastore for the lnd daemon. The database stores @@ -418,6 +411,12 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, linkNodeDB: &LinkNodeDB{ backend: backend, }, + kvStore: chanstate.NewKVStore( + backend, + opts.storeFinalHtlcResolutions, + opts.NoRevLogAmtData, + opts.tombstoneClosedChannels, + ), backend: backend, tombstoneClosedChannels: opts.tombstoneClosedChannels, }, @@ -550,6 +549,10 @@ type ChannelStateDB struct { // database. This may be a real backend or a cache middleware. backend kvdb.Backend + // kvStore is the chanstate-owned KV implementation. ChannelStateDB + // keeps compatibility wrappers while callers still import channeldb. + kvStore *chanstate.KVStore + // tombstoneClosedChannels is set by OptionTombstoneClosedChannels. // When true, CloseChannel skips deleting nested per-channel state and // relies on the outpointBucket flip to outpointClosed as the @@ -576,16 +579,12 @@ func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB { func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) ( []*OpenChannel, error) { - var channels []*OpenChannel - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - var err error - channels, err = c.fetchOpenChannels(tx, nodeID) - return err - }, func() { - channels = nil - }) + channels, err := c.kvStore.FetchOpenChannels(nodeID) + if err != nil { + return nil, err + } - return channels, err + return c.attachOpenChannelStores(channels), nil } // fetchOpenChannels uses and existing database transaction and returns all @@ -595,115 +594,32 @@ func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) ( func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx, nodeID *btcec.PublicKey) ([]*OpenChannel, error) { - // Get the bucket dedicated to storing the metadata for open channels. - openChanBucket := tx.ReadBucket(openChannelBucket) - if openChanBucket == nil { - return nil, nil - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - pub := nodeID.SerializeCompressed() - nodeChanBucket := openChanBucket.NestedReadBucket(pub) - if nodeChanBucket == nil { - return nil, nil + channels, err := chanstate.FetchOpenChannelsTx(tx, nodeID) + if err != nil { + return nil, err } - // Next, we'll need to go down an additional layer in order to retrieve - // the channels for each chain the node knows of. - var channels []*OpenChannel - err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, then we'll retrieve - // that so we can extract all the channels. - chainBucket := nodeChanBucket.NestedReadBucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read bucket for chain=%x", - chainHash[:]) - } - - // Finally, we both of the necessary buckets retrieved, fetch - // all the active channels related to this node. - nodeChannels, err := c.fetchNodeChannels(tx, chainBucket) - if err != nil { - return fmt.Errorf("unable to read channel for "+ - "chain_hash=%x, node_key=%x: %v", - chainHash[:], pub, err) - } - - channels = append(channels, nodeChannels...) - return nil - }) - - return channels, err + return c.attachOpenChannelStores(channels), nil } -// fetchNodeChannels retrieves all active channels from the target chainBucket -// which is under a node's dedicated channel bucket. This function is typically -// used to fetch all the active channels related to a particular node. Channels -// already flipped to outpointClosed in the outpoint index are skipped silently -// — readers see only channels that are still considered open. -func (c *ChannelStateDB) fetchNodeChannels(tx kvdb.RTx, - chainBucket kvdb.RBucket) ([]*OpenChannel, error) { - - var channels []*OpenChannel - - // Hoist the outpoint-bucket lookup so the closed-channel check inside - // the loop is a per-iteration map probe rather than a tx-level bucket - // resolve. - opBucket := tx.ReadBucket(outpointBucket) - - // A node may have channels on several chains, so for each known chain, - // we'll extract all the channels. - err := chainBucket.ForEach(func(chanPoint, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } +func (c *ChannelStateDB) attachOpenChannelStore( + channel *OpenChannel) *OpenChannel { - // Skip already-closed channels. The chanBucket still exists - // on disk on tombstone-enabled backends; the outpoint flip is - // the sole signal that the channel should be treated as - // closed. - isClosed, err := isOutpointClosed(opBucket, chanPoint) - if err != nil { - return err - } - if isClosed { - return nil - } - - // Once we've found a valid channel bucket, we'll extract it - // from the node's chain bucket. - chanBucket := chainBucket.NestedReadBucket(chanPoint) + if channel != nil { + channel.Db = c + } - var outPoint wire.OutPoint - err = graphdb.ReadOutpoint( - bytes.NewReader(chanPoint), &outPoint, - ) - if err != nil { - return err - } - oChannel, err := fetchOpenChannel(chanBucket, &outPoint) - if err != nil { - return fmt.Errorf("unable to read channel data for "+ - "chan_point=%v: %w", outPoint, err) - } - oChannel.Db = c + return channel +} - channels = append(channels, oChannel) +func (c *ChannelStateDB) attachOpenChannelStores( + channels []*OpenChannel) []*OpenChannel { - return nil - }) - if err != nil { - return nil, err + for _, channel := range channels { + c.attachOpenChannelStore(channel) } - return channels, nil + return channels } // FetchChannel attempts to locate a channel specified by the passed channel @@ -711,20 +627,12 @@ func (c *ChannelStateDB) fetchNodeChannels(tx kvdb.RTx, func (c *ChannelStateDB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { - var targetChanPoint bytes.Buffer - err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint) + channel, err := c.kvStore.FetchChannel(chanPoint) if err != nil { return nil, err } - targetChanPointBytes := targetChanPoint.Bytes() - selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, - error) { - - return targetChanPointBytes, &chanPoint, nil - } - - return c.channelScanner(nil, selector) + return c.attachOpenChannelStore(channel), nil } // FetchChannelByID attempts to locate a channel specified by the passed channel @@ -732,55 +640,16 @@ func (c *ChannelStateDB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, func (c *ChannelStateDB) FetchChannelByID(id lnwire.ChannelID) (*OpenChannel, error) { - selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, - error) { - - var ( - targetChanPointBytes []byte - targetChanPoint *wire.OutPoint - - // errChanFound is used to signal that the channel has - // been found so that iteration through the DB buckets - // can stop. - errChanFound = errors.New("channel found") - ) - err := chainBkt.ForEach(func(k, _ []byte) error { - var outPoint wire.OutPoint - err := graphdb.ReadOutpoint( - bytes.NewReader(k), &outPoint, - ) - if err != nil { - return err - } - - chanID := lnwire.NewChanIDFromOutPoint(outPoint) - if chanID != id { - return nil - } - - targetChanPoint = &outPoint - targetChanPointBytes = k - - return errChanFound - }) - if err != nil && !errors.Is(err, errChanFound) { - return nil, nil, err - } - if targetChanPoint == nil { - return nil, nil, ErrChannelNotFound - } - - return targetChanPointBytes, targetChanPoint, nil + channel, err := c.kvStore.FetchChannelByID(id) + if err != nil { + return nil, err } - return c.channelScanner(nil, selector) + return c.attachOpenChannelStore(channel), nil } // ChanCount is used by the server in determining access control. -type ChanCount struct { - HasOpenOrClosedChan bool - PendingOpenCount uint64 -} +type ChanCount = chanstate.ChanCount // FetchPermAndTempPeers returns a map where the key is the remote node's // public key and the value is a struct that has a tally of the pending-open @@ -788,373 +657,43 @@ type ChanCount struct { func (c *ChannelStateDB) FetchPermAndTempPeers( chainHash []byte) (map[string]ChanCount, error) { - peerChanInfo := make(map[string]ChanCount) - - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - openChanBucket := tx.ReadBucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoChanDBExists - } - - // Hoist the outpoint-bucket lookup so the closed-channel check - // inside the nested chainBucket.ForEach below is a per-channel - // map probe rather than a tx-level bucket resolve. - opBucket := tx.ReadBucket(outpointBucket) - - openChanErr := openChanBucket.ForEach(func(nodePub, - v []byte) error { - - // If there is a value, this is not a bucket. - if v != nil { - return nil - } - - nodeChanBucket := openChanBucket.NestedReadBucket( - nodePub, - ) - if nodeChanBucket == nil { - return nil - } - - chainBucket := nodeChanBucket.NestedReadBucket( - chainHash, - ) - if chainBucket == nil { - return fmt.Errorf("no chain bucket exists") - } - - var isPermPeer bool - var pendingOpenCount uint64 - - internalErr := chainBucket.ForEach(func(chanPoint, - val []byte) error { - - // If there is a value, this is not a bucket. - if val != nil { - return nil - } - - // Skip already-closed channels: they are - // logically closed even though their - // per-channel state still resides under - // chainBucket. The closed peer's protected - // status is established below via the - // historical-channel scan. - isClosed, err := isOutpointClosed( - opBucket, chanPoint, - ) - if err != nil { - return err - } - if isClosed { - return nil - } - - chanBucket := chainBucket.NestedReadBucket( - chanPoint, - ) - if chanBucket == nil { - return nil - } - - var op wire.OutPoint - readErr := graphdb.ReadOutpoint( - bytes.NewReader(chanPoint), &op, - ) - if readErr != nil { - return readErr - } - - // We need to go through each channel and look - // at the IsPending status. - openChan, err := fetchOpenChannel( - chanBucket, &op, - ) - if err != nil { - return err - } - - if openChan.IsPending { - // Add to the pending-open count since - // this is a temp peer. - pendingOpenCount++ - return nil - } - - // Since IsPending is false, this is a perm - // peer. - isPermPeer = true - - return nil - }) - if internalErr != nil { - return internalErr - } - - peerCount := ChanCount{ - HasOpenOrClosedChan: isPermPeer, - PendingOpenCount: pendingOpenCount, - } - peerChanInfo[string(nodePub)] = peerCount - - return nil - }) - if openChanErr != nil { - return openChanErr - } - - // Now check the closed channel bucket. - historicalChanBucket := tx.ReadBucket(historicalChannelBucket) - if historicalChanBucket == nil { - return ErrNoHistoricalBucket - } - - historicalErr := historicalChanBucket.ForEach(func(chanPoint, - v []byte) error { - // Parse each nested bucket and the chanInfoKey to get - // the IsPending bool. This determines whether the - // peer is protected or not. - if v != nil { - // This is not a bucket. This is currently not - // possible. - return nil - } - - chanBucket := historicalChanBucket.NestedReadBucket( - chanPoint, - ) - if chanBucket == nil { - // This is not possible. - return fmt.Errorf("no historical channel " + - "bucket exists") - } - - var op wire.OutPoint - readErr := graphdb.ReadOutpoint( - bytes.NewReader(chanPoint), &op, - ) - if readErr != nil { - return readErr - } - - // This channel is closed, but the structure of the - // historical bucket is the same. This is by design, - // which means we can call fetchOpenChannel. - channel, fetchErr := fetchOpenChannel(chanBucket, &op) - if fetchErr != nil { - return fetchErr - } - - // Only include this peer in the protected class if - // the closing transaction confirmed. Note that - // CloseChannel can be called in the funding manager - // while IsPending is true which is why we need this - // special-casing to not count premature funding - // manager calls to CloseChannel. - if !channel.IsPending { - // Fetch the public key of the remote node. We - // need to use the string-ified serialized, - // compressed bytes as the key. - remotePub := channel.IdentityPub - remoteSer := remotePub.SerializeCompressed() - remoteKey := string(remoteSer) - - count, exists := peerChanInfo[remoteKey] - if exists { - count.HasOpenOrClosedChan = true - peerChanInfo[remoteKey] = count - } else { - peerCount := ChanCount{ - HasOpenOrClosedChan: true, - } - peerChanInfo[remoteKey] = peerCount - } - } - - return nil - }) - if historicalErr != nil { - return historicalErr - } - - return nil - }, func() { - clear(peerChanInfo) - }) - - return peerChanInfo, err -} - -// channelSelector describes a function that takes a chain-hash bucket from -// within the open-channel DB and returns the wanted channel point bytes, and -// channel point. It must return the ErrChannelNotFound error if the wanted -// channel is not in the given bucket. -type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, - error) - -// channelScanner will traverse the DB to each chain-hash bucket of each node -// pub-key bucket in the open-channel-bucket. The chanSelector will then be used -// to fetch the wanted channel outpoint from the chain bucket. -func (c *ChannelStateDB) channelScanner(tx kvdb.RTx, - chanSelect channelSelector) (*OpenChannel, error) { - - var ( - targetChan *OpenChannel - - // errChanFound is used to signal that the channel has been - // found so that iteration through the DB buckets can stop. - errChanFound = errors.New("channel found") - ) - - // chanScan will traverse the following bucket structure: - // * nodePub => chainHash => chanPoint - // - // At each level we go one further, ensuring that we're traversing the - // proper key (that's actually a bucket). By only reading the bucket - // structure and skipping fully decoding each channel, we save a good - // bit of CPU as we don't need to do things like decompress public - // keys. - chanScan := func(tx kvdb.RTx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.ReadBucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Hoist the outpoint-bucket lookup so the closed-channel - // check inside the per-chain ForEach below pays one tx-level - // bucket resolve total instead of one per visited chanKey. - opBucket := tx.ReadBucket(outpointBucket) - - // Within the node channel bucket, are the set of node pubkeys - // we have channels with, we don't know the entire set, so we'll - // check them all. - return openChanBucket.ForEach(func(nodePub, v []byte) error { - // Ensure that this is a key the same size as a pubkey, - // and also that it leads directly to a bucket. - if len(nodePub) != 33 || v != nil { - return nil - } - - nodeChanBucket := openChanBucket.NestedReadBucket( - nodePub, - ) - if nodeChanBucket == nil { - return nil - } - - // The next layer down is all the chains that this node - // has channels on with us. - return nodeChanBucket.ForEach(func(chainHash, - v []byte) error { - - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - chainBucket := nodeChanBucket.NestedReadBucket( - chainHash, - ) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", - chainHash) - } - - // Finally, we reach the leaf bucket that stores - // all the chanPoints for this node. - targetChanBytes, chanPoint, err := chanSelect( - chainBucket, - ) - if errors.Is(err, ErrChannelNotFound) { - return nil - } else if err != nil { - return err - } - - // An already-closed channel is logically gone - // and must not be surfaced by lookup-style - // scans. - isClosed, err := isOutpointClosed( - opBucket, targetChanBytes, - ) - if err != nil { - return err - } - if isClosed { - return nil - } - - chanBucket := chainBucket.NestedReadBucket( - targetChanBytes, - ) - if chanBucket == nil { - return nil - } - - channel, err := fetchOpenChannel( - chanBucket, chanPoint, - ) - if err != nil { - return err - } - - targetChan = channel - targetChan.Db = c - - return errChanFound - }) - }) - } - - var err error - if tx == nil { - err = kvdb.View(c.backend, chanScan, func() {}) - } else { - err = chanScan(tx) - } - if err != nil && !errors.Is(err, errChanFound) { - return nil, err - } - - if targetChan != nil { - return targetChan, nil - } - - // If we can't find the channel, then we return with an error, as we - // have nothing to back up. - return nil, ErrChannelNotFound + return c.kvStore.FetchPermAndTempPeers(chainHash) } // FetchAllChannels attempts to retrieve all open channels currently stored // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) { - return fetchChannels(c) + channels, err := c.kvStore.FetchAllChannels() + if err != nil { + return nil, err + } + + return c.attachOpenChannelStores(channels), nil } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and is not waiting for a closing transaction to be // confirmed. func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) { - return fetchChannels( - c, - pendingChannelFilter(false), - waitingCloseFilter(false), - ) + channels, err := c.kvStore.FetchAllOpenChannels() + if err != nil { + return nil, err + } + + return c.attachOpenChannelStores(channels), nil } // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain. func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(c, - pendingChannelFilter(true), - waitingCloseFilter(false), - ) + channels, err := c.kvStore.FetchPendingChannels() + if err != nil { + return nil, err + } + + return c.attachOpenChannelStores(channels), nil } // FetchWaitingCloseChannels will return all channels that have been opened, @@ -1162,138 +701,12 @@ func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) { // // NOTE: This includes channels that are also pending to be opened. func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { - return fetchChannels( - c, waitingCloseFilter(true), - ) -} - -// fetchChannelsFilter applies a filter to channels retrieved in fetchchannels. -// A set of filters can be combined to filter across multiple dimensions. -type fetchChannelsFilter func(channel *OpenChannel) bool - -// pendingChannelFilter returns a filter based on whether channels are pending -// (ie, their funding transaction still needs to confirm). If pending is false, -// channels with confirmed funding transactions are returned. -func pendingChannelFilter(pending bool) fetchChannelsFilter { - return func(channel *OpenChannel) bool { - return channel.IsPending == pending - } -} - -// waitingCloseFilter returns a filter which filters channels based on whether -// they are awaiting the confirmation of their closing transaction. If waiting -// close is true, channels that have had their closing tx broadcast are -// included. If it is false, channels that are not awaiting confirmation of -// their close transaction are returned. -func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { - return func(channel *OpenChannel) bool { - // If the channel is in any other state than Default, - // then it means it is waiting to be closed. - channelWaitingClose := - channel.ChanStatus() != ChanStatusDefault - - // Include the channel if it matches the value for - // waiting close that we are filtering on. - return channelWaitingClose == waitingClose - } -} - -// fetchChannels attempts to retrieve channels currently stored in the -// database. It takes a set of filters which are applied to each channel to -// obtain a set of channels with the desired set of properties. Only channels -// which have a true value returned for *all* of the filters will be returned. -// If no filters are provided, every channel in the open channels bucket will -// be returned. -func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) ( - []*OpenChannel, error) { - - var channels []*OpenChannel - - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.ReadBucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Next, fetch the bucket dedicated to storing metadata related - // to all nodes. All keys within this bucket are the serialized - // public keys of all our direct counterparties. - nodeMetaBucket := tx.ReadBucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return fmt.Errorf("node bucket not created") - } - - // Finally for each node public key in the bucket, fetch all - // the channels related to this particular node. - return nodeMetaBucket.ForEach(func(k, v []byte) error { - nodeChanBucket := openChanBucket.NestedReadBucket(k) - if nodeChanBucket == nil { - return nil - } - - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, - // then we'll retrieve that so we can extract - // all the channels. - chainBucket := nodeChanBucket.NestedReadBucket( - chainHash, - ) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) - } - - nodeChans, err := c.fetchNodeChannels( - tx, chainBucket, - ) - if err != nil { - return fmt.Errorf("unable to read "+ - "channel for chain_hash=%x, "+ - "node_key=%x: %v", chainHash[:], k, err) - } - for _, channel := range nodeChans { - // includeChannel indicates whether the channel - // meets the criteria specified by our filters. - includeChannel := true - - // Run through each filter and check whether the - // channel should be included. - for _, f := range filters { - // If the channel fails the filter, set - // includeChannel to false and don't bother - // checking the remaining filters. - if !f(channel) { - includeChannel = false - break - } - } - - // If the channel passed every filter, include it in - // our set of channels. - if includeChannel { - channels = append(channels, channel) - } - } - return nil - }) - - }) - }, func() { - channels = nil - }) + channels, err := c.kvStore.FetchWaitingCloseChannels() if err != nil { return nil, err } - return channels, nil + return c.attachOpenChannelStores(channels), nil } // FetchClosedChannels attempts to fetch all closed channels from the database. @@ -1305,78 +718,15 @@ func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) ( func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) ( []*ChannelCloseSummary, error) { - var chanSummaries []*ChannelCloseSummary - - if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - closeBucket := tx.ReadBucket(closedChannelBucket) - if closeBucket == nil { - return ErrNoClosedChannels - } - - return closeBucket.ForEach(func(chanID []byte, summaryBytes []byte) error { - summaryReader := bytes.NewReader(summaryBytes) - chanSummary, err := deserializeCloseChannelSummary(summaryReader) - if err != nil { - return err - } - - // If the query specified to only include pending - // channels, then we'll skip any channels which aren't - // currently pending. - if !chanSummary.IsPending && pendingOnly { - return nil - } - - chanSummaries = append(chanSummaries, chanSummary) - return nil - }) - }, func() { - chanSummaries = nil - }); err != nil { - return nil, err - } - - return chanSummaries, nil + return c.kvStore.FetchClosedChannels(pendingOnly) } -// ErrClosedChannelNotFound signals that a closed channel could not be found in -// the channeldb. -var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary") - // FetchClosedChannel queries for a channel close summary using the channel // point of the channel in question. func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( *ChannelCloseSummary, error) { - var chanSummary *ChannelCloseSummary - if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - closeBucket := tx.ReadBucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - var b bytes.Buffer - var err error - if err = graphdb.WriteOutpoint(&b, chanID); err != nil { - return err - } - - summaryBytes := closeBucket.Get(b.Bytes()) - if summaryBytes == nil { - return ErrClosedChannelNotFound - } - - summaryReader := bytes.NewReader(summaryBytes) - chanSummary, err = deserializeCloseChannelSummary(summaryReader) - - return err - }, func() { - chanSummary = nil - }); err != nil { - return nil, err - } - - return chanSummary, nil + return c.kvStore.FetchClosedChannel(chanID) } // FetchClosedChannelForID queries for a channel close summary using the @@ -1384,51 +734,7 @@ func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( *ChannelCloseSummary, error) { - var chanSummary *ChannelCloseSummary - if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - closeBucket := tx.ReadBucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - // The first 30 bytes of the channel ID and outpoint will be - // equal. - cursor := closeBucket.ReadCursor() - op, c := cursor.Seek(cid[:30]) - - // We scan over all possible candidates for this channel ID. - for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { - var outPoint wire.OutPoint - err := graphdb.ReadOutpoint( - bytes.NewReader(op), &outPoint, - ) - if err != nil { - return err - } - - // If the found outpoint does not correspond to this - // channel ID, we continue. - if !cid.IsChanPoint(&outPoint) { - continue - } - - // Deserialize the close summary and return. - r := bytes.NewReader(c) - chanSummary, err = deserializeCloseChannelSummary(r) - if err != nil { - return err - } - - return nil - } - return ErrClosedChannelNotFound - }, func() { - chanSummary = nil - }); err != nil { - return nil, err - } - - return chanSummary, nil + return c.kvStore.FetchClosedChannelForID(cid) } // MarkChanFullyClosed marks a channel as fully closed within the database. A @@ -1678,17 +984,8 @@ func (c *ChannelStateDB) RepairLinkNodes(network wire.BitcoinNet) error { } // ChannelShell is a shell of a channel that is meant to be used for channel -// recovery purposes. It contains a minimal OpenChannel instance along with -// addresses for that target node. -type ChannelShell struct { - // NodeAddrs the set of addresses that this node has known to be - // reachable at in the past. - NodeAddrs []net.Addr - - // Chan is a shell of an OpenChannel, it contains only the items - // required to restore the channel on disk. - Chan *OpenChannel -} +// recovery purposes. +type ChannelShell = chanstate.ChannelShell // RestoreChannelShells is a method that allows the caller to reconstruct the // state of an OpenChannel from the ChannelShell. We'll attempt to write the @@ -1705,7 +1002,10 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er // been restored, this will signal to other sub-systems // to not attempt to use the channel as if it was a // regular one. - channel.chanStatus |= ChanStatusRestored + channel.SetChannelStatusForStore( + channel.ChannelStatusForStore() | + ChanStatusRestored, + ) // First, we'll attempt to create a new open channel // and link node for this channel. If the channel @@ -1713,7 +1013,8 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er // is idempotent, we'll continue to the next step. channel.Db = c err := syncNewChannel( - tx, channel, channelShell.NodeAddrs, + tx, channel, channelShell.NodeAddrs, c.backend, + channel.FundingBroadcastHeight, ) if err != nil { return err @@ -1758,48 +1059,7 @@ func (d *DB) AddrsForNode(_ context.Context, nodePub *btcec.PublicKey) (bool, func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { - // With the chanPoint constructed, we'll attempt to find the target - // channel in the database. If we can't find the channel, then we'll - // return the error back to the caller. - dbChan, err := c.FetchChannel(*chanPoint) - switch { - // If the channel wasn't found, then it's possible that it was already - // abandoned from the database. - case err == ErrChannelNotFound: - _, closedErr := c.FetchClosedChannel(chanPoint) - if closedErr != nil { - return closedErr - } - - // If the channel was already closed, then we don't return an - // error as we'd like this step to be repeatable. - return nil - case err != nil: - return err - } - - // Now that we've found the channel, we'll populate a close summary for - // the channel, so we can store as much information for this abounded - // channel as possible. We also ensure that we set Pending to false, to - // indicate that this channel has been "fully" closed. - summary := &ChannelCloseSummary{ - CloseType: Abandoned, - ChanPoint: *chanPoint, - ChainHash: dbChan.ChainHash, - CloseHeight: bestHeight, - RemotePub: dbChan.IdentityPub, - Capacity: dbChan.Capacity, - SettledBalance: dbChan.LocalCommitment.LocalBalance.ToSatoshis(), - ShortChanID: dbChan.ShortChanID(), - RemoteCurrentRevocation: dbChan.RemoteCurrentRevocation, - RemoteNextRevocation: dbChan.RemoteNextRevocation, - LocalChanConfig: dbChan.LocalChanCfg, - } - - // Finally, we'll close the channel in the DB, and return back to the - // caller. We set ourselves as the close initiator because we abandoned - // the channel. - return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator) + return c.kvStore.AbandonChannel(chanPoint, bestHeight) } // SaveChannelOpeningState saves the serialized channel state for the provided @@ -1807,14 +1067,7 @@ func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, func (c *ChannelStateDB) SaveChannelOpeningState(outPoint, serializedState []byte) error { - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) - if err != nil { - return err - } - - return bucket.Put(outPoint, serializedState) - }, func() {}) + return c.kvStore.SaveChannelOpeningState(outPoint, serializedState) } // GetChannelOpeningState fetches the serialized channel state for the provided @@ -1823,39 +1076,12 @@ func (c *ChannelStateDB) SaveChannelOpeningState(outPoint, func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { - var serializedState []byte - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - bucket := tx.ReadBucket(channelOpeningStateBucket) - if bucket == nil { - // If the bucket does not exist, it means we never added - // a channel to the db, so return ErrChannelNotFound. - return ErrChannelNotFound - } - - stateBytes := bucket.Get(outPoint) - if stateBytes == nil { - return ErrChannelNotFound - } - - serializedState = append(serializedState, stateBytes...) - - return nil - }, func() { - serializedState = nil - }) - return serializedState, err + return c.kvStore.GetChannelOpeningState(outPoint) } // DeleteChannelOpeningState removes any state for outPoint from the database. func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error { - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - bucket := tx.ReadWriteBucket(channelOpeningStateBucket) - if bucket == nil { - return ErrChannelNotFound - } - - return bucket.Delete(outPoint) - }, func() {}) + return c.kvStore.DeleteChannelOpeningState(outPoint) } // syncVersions function is used for safe db version synchronization. It @@ -2054,136 +1280,29 @@ func getMigrationsToApply(versions []mandatoryVersion, return migrations, migrationVersions } -// fetchHistoricalChanBucket returns a the channel bucket for a given outpoint -// from the historical channel bucket. If the bucket does not exist, -// ErrNoHistoricalBucket is returned. -func fetchHistoricalChanBucket(tx kvdb.RTx, - outPoint *wire.OutPoint) (kvdb.RBucket, error) { - - // First fetch the top level bucket which stores all data related to - // historically stored channels. - historicalChanBucket := tx.ReadBucket(historicalChannelBucket) - if historicalChanBucket == nil { - return nil, ErrNoHistoricalBucket - } - - // With the bucket for the node and chain fetched, we can now go down - // another level, for the channel itself. - var chanPointBuf bytes.Buffer - if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, err - } - chanBucket := historicalChanBucket.NestedReadBucket( - chanPointBuf.Bytes(), - ) - if chanBucket == nil { - return nil, ErrChannelNotFound - } - - return chanBucket, nil -} - // FetchHistoricalChannel fetches open channel data from the historical channel // bucket. func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) ( *OpenChannel, error) { - var channel *OpenChannel - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - chanBucket, err := fetchHistoricalChanBucket(tx, outPoint) - if err != nil { - return err - } - - channel, err = fetchOpenChannel(chanBucket, outPoint) - if err != nil { - return err - } - - channel.Db = c - return nil - }, func() { - channel = nil - }) + channel, err := c.kvStore.FetchHistoricalChannel(outPoint) if err != nil { return nil, err } - return channel, nil -} - -func fetchFinalHtlcsBucket(tx kvdb.RTx, - chanID lnwire.ShortChannelID) (kvdb.RBucket, error) { + channel.Db = c - finalHtlcsBucket := tx.ReadBucket(finalHtlcsBucket) - if finalHtlcsBucket == nil { - return nil, ErrFinalHtlcsBucketNotFound - } - - var chanIDBytes [8]byte - byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64()) - - chanBucket := finalHtlcsBucket.NestedReadBucket(chanIDBytes[:]) - if chanBucket == nil { - return nil, ErrFinalChannelBucketNotFound - } - - return chanBucket, nil + return channel, nil } -var ErrHtlcUnknown = errors.New("htlc unknown") +var ErrHtlcUnknown = chanstate.ErrHtlcUnknown // LookupFinalHtlc retrieves a final htlc resolution from the database. If the // htlc has no final resolution yet, ErrHtlcUnknown is returned. func (c *ChannelStateDB) LookupFinalHtlc(chanID lnwire.ShortChannelID, htlcIndex uint64) (*FinalHtlcInfo, error) { - var idBytes [8]byte - byteOrder.PutUint64(idBytes[:], htlcIndex) - - var settledByte byte - - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - finalHtlcsBucket, err := fetchFinalHtlcsBucket( - tx, chanID, - ) - switch { - case errors.Is(err, ErrFinalHtlcsBucketNotFound): - fallthrough - - case errors.Is(err, ErrFinalChannelBucketNotFound): - return ErrHtlcUnknown - - case err != nil: - return fmt.Errorf("cannot fetch final htlcs bucket: %w", - err) - } - - value := finalHtlcsBucket.Get(idBytes[:]) - if value == nil { - return ErrHtlcUnknown - } - - if len(value) != 1 { - return errors.New("unexpected final htlc value length") - } - - settledByte = value[0] - - return nil - }, func() { - settledByte = 0 - }) - if err != nil { - return nil, err - } - - info := FinalHtlcInfo{ - Settled: settledByte&byte(FinalHtlcSettledBit) != 0, - Offchain: settledByte&byte(FinalHtlcOffchainBit) != 0, - } - - return &info, nil + return c.kvStore.LookupFinalHtlc(chanID, htlcIndex) } // PutOnchainFinalHtlcOutcome stores the final on-chain outcome of an htlc in @@ -2191,25 +1310,7 @@ func (c *ChannelStateDB) LookupFinalHtlc(chanID lnwire.ShortChannelID, func (c *ChannelStateDB) PutOnchainFinalHtlcOutcome( chanID lnwire.ShortChannelID, htlcID uint64, settled bool) error { - // Skip if the user did not opt in to storing final resolutions. - if !c.parent.storeFinalHtlcResolutions { - return nil - } - - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - finalHtlcsBucket, err := fetchFinalHtlcsBucketRw(tx, chanID) - if err != nil { - return err - } - - return putFinalHtlc( - finalHtlcsBucket, htlcID, - FinalHtlcInfo{ - Settled: settled, - Offchain: false, - }, - ) - }, func() {}) + return c.kvStore.PutOnchainFinalHtlcOutcome(chanID, htlcID, settled) } // MakeTestInvoiceDB is used to create a test invoice database for testing diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 277820b10cd..d66bafc2991 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -307,33 +307,37 @@ func genRandomChannelShell() (*ChannelShell, error) { CsvDelay: uint16(rand.Int63()), } + channel := &OpenChannel{ + ChainHash: rev, + FundingOutpoint: chanPoint, + ShortChannelID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + IdentityPub: pub, + LocalChanCfg: ChannelConfig{ + CommitmentParams: commitParams, + PaymentBasePoint: keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamily( + rand.Int63(), + ), + Index: uint32(rand.Int63()), + }, + }, + }, + RemoteCurrentRevocation: pub, + IsPending: false, + RevocationStore: shachain.NewRevocationStore(), + RevocationProducer: shaChainProducer, + } + channel.SetChannelStatusForStore(chanStatus) + return &ChannelShell{ NodeAddrs: []net.Addr{&net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18555, }}, - Chan: &OpenChannel{ - chanStatus: chanStatus, - ChainHash: rev, - FundingOutpoint: chanPoint, - ShortChannelID: lnwire.NewShortChanIDFromInt( - uint64(rand.Int63()), - ), - IdentityPub: pub, - LocalChanCfg: ChannelConfig{ - CommitmentParams: commitParams, - PaymentBasePoint: keychain.KeyDescriptor{ - KeyLocator: keychain.KeyLocator{ - Family: keychain.KeyFamily(rand.Int63()), - Index: uint32(rand.Int63()), - }, - }, - }, - RemoteCurrentRevocation: pub, - IsPending: false, - RevocationStore: shachain.NewRevocationStore(), - RevocationProducer: shaChainProducer, - }, + Chan: channel, }, nil } @@ -403,7 +407,7 @@ func TestRestoreChannelShells(t *testing.T) { } if !nodeChans[0].HasChanStatus(ChanStatusRestored) { t.Fatalf("node has wrong status flags: %v", - nodeChans[0].chanStatus) + nodeChans[0].ChanStatus()) } // We should also be able to find the channel if we query for it @@ -473,12 +477,12 @@ func TestAbandonChannel(t *testing.T) { require.NoError(t, err, "unable to abandon channel") } -// TestFetchChannels tests the filtering of open channels in fetchChannels. -// It tests the case where no filters are provided (which is equivalent to -// FetchAllOpenChannels) and every combination of pending and waiting close. +// TestFetchChannels tests the filtering of open channels exposed by the +// public fetch methods. func TestFetchChannels(t *testing.T) { // Create static channel IDs for each kind of channel retrieved by - // fetchChannels so that the expected channel IDs can be set in tests. + // the fetch methods so that the expected channel IDs can be set in + // tests. var ( // Pending is a channel that is pending open, and has not had // a close initiated. @@ -498,12 +502,12 @@ func TestFetchChannels(t *testing.T) { tests := []struct { name string - filters []fetchChannelsFilter + fetch func(*ChannelStateDB) ([]*OpenChannel, error) expectedChannels map[lnwire.ShortChannelID]bool }{ { - name: "get all channels", - filters: []fetchChannelsFilter{}, + name: "get all channels", + fetch: (*ChannelStateDB).FetchAllChannels, expectedChannels: map[lnwire.ShortChannelID]bool{ pendingChan: true, pendingWaitingChan: true, @@ -512,30 +516,22 @@ func TestFetchChannels(t *testing.T) { }, }, { - name: "pending channels", - filters: []fetchChannelsFilter{ - pendingChannelFilter(true), - }, + name: "pending channels", + fetch: (*ChannelStateDB).FetchPendingChannels, expectedChannels: map[lnwire.ShortChannelID]bool{ - pendingChan: true, - pendingWaitingChan: true, + pendingChan: true, }, }, { - name: "open channels", - filters: []fetchChannelsFilter{ - pendingChannelFilter(false), - }, + name: "open channels", + fetch: (*ChannelStateDB).FetchAllOpenChannels, expectedChannels: map[lnwire.ShortChannelID]bool{ - openChan: true, - openWaitingChan: true, + openChan: true, }, }, { - name: "waiting close channels", - filters: []fetchChannelsFilter{ - waitingCloseFilter(true), - }, + name: "waiting close channels", + fetch: (*ChannelStateDB).FetchWaitingCloseChannels, expectedChannels: map[lnwire.ShortChannelID]bool{ pendingWaitingChan: true, openWaitingChan: true, @@ -543,54 +539,30 @@ func TestFetchChannels(t *testing.T) { }, { name: "not waiting close channels", - filters: []fetchChannelsFilter{ - waitingCloseFilter(false), + fetch: func(cdb *ChannelStateDB) ( + []*OpenChannel, error) { + + pendingChans, err := cdb.FetchPendingChannels() + if err != nil { + return nil, err + } + + openChannels, err := cdb.FetchAllOpenChannels() + if err != nil { + return nil, err + } + + pendingChans = append( + pendingChans, openChannels..., + ) + + return pendingChans, nil }, expectedChannels: map[lnwire.ShortChannelID]bool{ pendingChan: true, openChan: true, }, }, - { - name: "pending waiting", - filters: []fetchChannelsFilter{ - pendingChannelFilter(true), - waitingCloseFilter(true), - }, - expectedChannels: map[lnwire.ShortChannelID]bool{ - pendingWaitingChan: true, - }, - }, - { - name: "pending, not waiting", - filters: []fetchChannelsFilter{ - pendingChannelFilter(true), - waitingCloseFilter(false), - }, - expectedChannels: map[lnwire.ShortChannelID]bool{ - pendingChan: true, - }, - }, - { - name: "open waiting", - filters: []fetchChannelsFilter{ - pendingChannelFilter(false), - waitingCloseFilter(true), - }, - expectedChannels: map[lnwire.ShortChannelID]bool{ - openWaitingChan: true, - }, - }, - { - name: "open, not waiting", - filters: []fetchChannelsFilter{ - pendingChannelFilter(false), - waitingCloseFilter(false), - }, - expectedChannels: map[lnwire.ShortChannelID]bool{ - openChan: true, - }, - }, } for _, test := range tests { @@ -649,7 +621,7 @@ func TestFetchChannels(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - channels, err := fetchChannels(cdb, test.filters...) + channels, err := test.fetch(cdb) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/channeldb/error.go b/channeldb/error.go index c2b2dde0d73..3a7e5e24394 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -2,17 +2,18 @@ package channeldb import ( "fmt" + + cstate "github.com/lightningnetwork/lnd/chanstate" ) var ( // ErrNoChanDBExists is returned when a channel bucket hasn't been // created. - ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") + ErrNoChanDBExists = cstate.ErrNoChanDBExists // ErrNoHistoricalBucket is returned when the historical channel bucket // not been created yet. - ErrNoHistoricalBucket = fmt.Errorf("historical channel bucket has " + - "not yet been created") + ErrNoHistoricalBucket = cstate.ErrNoHistoricalBucket // ErrDBReversion is returned when detecting an attempt to revert to a // prior database version. @@ -24,11 +25,11 @@ var ( // ErrNoActiveChannels is returned when there is no active (open) // channels within the database. - ErrNoActiveChannels = fmt.Errorf("no active channels exist") + ErrNoActiveChannels = cstate.ErrNoActiveChannels // ErrNoPastDeltas is returned when the channel delta bucket hasn't been // created. - ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") + ErrNoPastDeltas = cstate.ErrNoPastDeltas // ErrNodeNotFound is returned when node bucket exists, but node with // specific identity can't be found. @@ -36,7 +37,7 @@ var ( // ErrChannelNotFound is returned when we attempt to locate a channel // for a specific chain, but it is not found. - ErrChannelNotFound = fmt.Errorf("channel not found") + ErrChannelNotFound = cstate.ErrChannelNotFound // ErrMetaNotFound is returned when meta bucket hasn't been // created. @@ -44,7 +45,11 @@ var ( // ErrNoClosedChannels is returned when a node is queries for all the // channels it has closed, but it hasn't yet closed any channels. - ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") + ErrNoClosedChannels = cstate.ErrNoClosedChannels + + // ErrClosedChannelNotFound signals that a closed channel could not be + // found in the channeldb. + ErrClosedChannelNotFound = cstate.ErrClosedChannelNotFound // ErrNoForwardingEvents is returned in the case that a query fails due // to the log not having any recorded events. @@ -53,5 +58,5 @@ var ( // ErrChanAlreadyExists is return when the caller attempts to create a // channel with a channel point that is already present in the // database. - ErrChanAlreadyExists = fmt.Errorf("channel already exists") + ErrChanAlreadyExists = cstate.ErrChanAlreadyExists ) diff --git a/channeldb/forwarding_package.go b/channeldb/forwarding_package.go index c393a53b37f..08d85bf1fe0 100644 --- a/channeldb/forwarding_package.go +++ b/channeldb/forwarding_package.go @@ -1,1019 +1,85 @@ package channeldb import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - - "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/lnwire" + cstate "github.com/lightningnetwork/lnd/chanstate" ) -// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding -// package has potentially been mangled. -var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") +type ( + // AddRef is used to identify a particular Add in a FwdPkg. + AddRef = cstate.AddRef -// FwdState is an enum used to describe the lifecycle of a FwdPkg. -type FwdState byte + // SettleFailRef is used to locate a Settle/Fail in another channel's + // FwdPkg. + SettleFailRef = cstate.SettleFailRef -const ( - // FwdStateLockedIn is the starting state for all forwarding packages. - // Packages in this state have not yet committed to the exact set of - // Adds to forward to the switch. - FwdStateLockedIn FwdState = iota + // FwdState is an enum used to describe the lifecycle of a FwdPkg. + FwdState = cstate.FwdState - // FwdStateProcessed marks the state in which all Adds have been - // locally processed and the forwarding decision to the switch has been - // persisted. - FwdStateProcessed + // PkgFilter is used to compactly represent a particular subset of the + // Adds in a forwarding package. + PkgFilter = cstate.PkgFilter - // FwdStateCompleted signals that all Adds have been acked, and that all - // settles and fails have been delivered to their sources. Packages in - // this state can be removed permanently. - FwdStateCompleted -) + // FwdPkg records all adds, settles, and fails that were locked in as a + // result of the remote peer sending us a revocation. + FwdPkg = cstate.FwdPkg -var ( - // fwdPackagesKey is the root-level bucket that all forwarding packages - // are written. This bucket is further subdivided based on the short - // channel ID of each channel. - // - // Bucket hierarchy: - // - // fwdPackagesKey(root-bucket) - // | - // |-- - // | | - // | |-- - // | | |-- ackFilterKey: - // | | |-- settleFailFilterKey: - // | | |-- fwdFilterKey: - // | | | - // | | |-- addBucketKey - // | | | |-- : - // | | | |-- : - // | | | ... - // | | | - // | | |-- failSettleBucketKey - // | | |-- : - // | | |-- : - // | | ... - // | | - // | |-- - // | | | - // | ... ... - // | - // | - // |-- - // | | - // | ... - // ... - // - fwdPackagesKey = []byte("fwd-packages") + // SettleFailAcker is a generic interface providing the ability to + // acknowledge settle/fail HTLCs stored in forwarding packages. + SettleFailAcker = cstate.SettleFailAcker - // addBucketKey is the bucket to which all Add log updates are written. - addBucketKey = []byte("add-updates") + // GlobalFwdPkgReader is an interface used to retrieve the forwarding + // packages of any active channel. + GlobalFwdPkgReader = cstate.GlobalFwdPkgReader - // failSettleBucketKey is the bucket to which all Settle/Fail log - // updates are written. - failSettleBucketKey = []byte("fail-settle-updates") + // FwdOperator defines the interfaces for managing forwarding packages + // that are external to a particular channel. + FwdOperator = cstate.FwdOperator - // fwdFilterKey is a key used to write the set of Adds that passed - // validation and are to be forwarded to the switch. - // NOTE: The presence of this key within a forwarding package indicates - // that the package has reached FwdStateProcessed. - fwdFilterKey = []byte("fwd-filter-key") + // FwdPackager supports all operations required to modify fwd packages, + // such as creation, updates, reading, and removal. + FwdPackager = cstate.FwdPackager - // ackFilterKey is a key used to access the PkgFilter indicating which - // Adds have received a Settle/Fail. This response may come from a - // number of sources, including: exitHop settle/fails, switch failures, - // chain arbiter interjections, as well as settle/fails from the - // next hop in the route. - ackFilterKey = []byte("ack-filter-key") + // SwitchPackager is a concrete implementation of the FwdOperator + // interface. + SwitchPackager = cstate.SwitchPackager - // settleFailFilterKey is a key used to access the PkgFilter indicating - // which Settles/Fails in have been received and processed by the link - // that originally received the Add. - settleFailFilterKey = []byte("settle-fail-filter-key") + // ChannelPackager is used by a channel to manage the lifecycle of its + // forwarding packages. + ChannelPackager = cstate.ChannelPackager ) -// PkgFilter is used to compactly represent a particular subset of the Adds in a -// forwarding package. Each filter is represented as a simple, statically-sized -// bitvector, where the elements are intended to be the indices of the Adds as -// they are written in the FwdPkg. -type PkgFilter struct { - count uint16 - filter []byte -} - -// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. -func NewPkgFilter(count uint16) *PkgFilter { - // We add 7 to ensure that the integer division yields properly rounded - // values. - filterLen := (count + 7) / 8 - - return &PkgFilter{ - count: count, - filter: make([]byte, filterLen), - } -} - -// Count returns the number of elements represented by this PkgFilter. -func (f *PkgFilter) Count() uint16 { - return f.count -} - -// Set marks the `i`-th element as included by this filter. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Set(i uint16) { - byt := i / 8 - bit := i % 8 - - // Set the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - f.filter[byt] |= byte(1 << (7 - bit)) -} - -// Contains queries the filter for membership of index `i`. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Contains(i uint16) bool { - byt := i / 8 - bit := i % 8 - - // Read the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - return f.filter[byt]&(1<<(7-bit)) != 0 -} - -// Equal checks two PkgFilters for equality. -func (f *PkgFilter) Equal(f2 *PkgFilter) bool { - if f == f2 { - return true - } - if f.count != f2.count { - return false - } - - return bytes.Equal(f.filter, f2.filter) -} - -// IsFull returns true if every element in the filter has been Set, and false -// otherwise. -func (f *PkgFilter) IsFull() bool { - // Batch validate bytes that are fully used. - for i := uint16(0); i < f.count/8; i++ { - if f.filter[i] != 0xFF { - return false - } - } - - // If the count is not a multiple of 8, check that the filter contains - // all remaining bits. - rem := f.count % 8 - for idx := f.count - rem; idx < f.count; idx++ { - if !f.Contains(idx) { - return false - } - } - - return true -} - -// Size returns number of bytes produced when the PkgFilter is serialized. -func (f *PkgFilter) Size() uint16 { - // 2 bytes for uint16 `count`, then round up number of bytes required to - // represent `count` bits. - return 2 + (f.count+7)/8 -} - -// Encode writes the filter to the provided io.Writer. -func (f *PkgFilter) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, f.count); err != nil { - return err - } - - _, err := w.Write(f.filter) - - return err -} - -// Decode reads the filter from the provided io.Reader. -func (f *PkgFilter) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { - return err - } - - f.filter = make([]byte, f.Size()-2) - _, err := io.ReadFull(r, f.filter) - - return err -} - -// String returns a human-readable string. -func (f *PkgFilter) String() string { - return fmt.Sprintf("count=%v, filter=%v", f.count, f.filter) -} - -// FwdPkg records all adds, settles, and fails that were locked in as a result -// of the remote peer sending us a revocation. Each package is identified by -// the short chanid and remote commitment height corresponding to the revocation -// that locked in the HTLCs. For everything except a locally initiated payment, -// settles and fails in a forwarding package must have a corresponding Add in -// another package, and can be removed individually once the source link has -// received the fail/settle. -// -// Adds cannot be removed, as we need to present the same batch of Adds to -// properly handle replay protection. Instead, we use a PkgFilter to mark that -// we have finished processing a particular Add. A FwdPkg should only be deleted -// after the AckFilter is full and all settles and fails have been persistently -// removed. -type FwdPkg struct { - // Source identifies the channel that wrote this forwarding package. - Source lnwire.ShortChannelID - - // Height is the height of the remote commitment chain that locked in - // this forwarding package. - Height uint64 - - // State signals the persistent condition of the package and directs how - // to reprocess the package in the event of failures. - State FwdState - - // Adds contains all add messages which need to be processed and - // forwarded to the switch. Adds does not change over the life of a - // forwarding package. - Adds []LogUpdate - - // FwdFilter is a filter containing the indices of all Adds that were - // forwarded to the switch. - // - // NOTE: This value signals when persisted to disk that the fwd package - // has been processed and garbage collection can happen. So it also - // has to be set for packages with no adds (empty packages or only - // settle/fail packages) so that they can be garbage collected as well. - FwdFilter *PkgFilter - - // AckFilter is a filter containing the indices of all Adds for which - // the source has received a settle or fail and is reflected in the next - // commitment txn. A package should not be removed until IsFull() - // returns true. - AckFilter *PkgFilter - - // SettleFails contains all settle and fail messages that should be - // forwarded to the switch. - SettleFails []LogUpdate - - // SettleFailFilter is a filter containing the indices of all Settle or - // Fails originating in this package that have been received and locked - // into the incoming link's commitment state. - SettleFailFilter *PkgFilter -} - -// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This -// should be used to create a package at the time we receive a revocation. -func NewFwdPkg(source lnwire.ShortChannelID, height uint64, - addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { - - nAddUpdates := uint16(len(addUpdates)) - nSettleFailUpdates := uint16(len(settleFailUpdates)) - - return &FwdPkg{ - Source: source, - Height: height, - State: FwdStateLockedIn, - Adds: addUpdates, - FwdFilter: NewPkgFilter(nAddUpdates), - AckFilter: NewPkgFilter(nAddUpdates), - SettleFails: settleFailUpdates, - SettleFailFilter: NewPkgFilter(nSettleFailUpdates), - } -} - -// SourceRef is a convenience method that returns an AddRef to this forwarding -// package for the index in the argument. It is the caller's responsibility -// to ensure that the index is in bounds. -func (f *FwdPkg) SourceRef(i uint16) AddRef { - return AddRef{ - Height: f.Height, - Index: i, - } -} - -// DestRef is a convenience method that returns a SettleFailRef to this -// forwarding package for the index in the argument. It is the caller's -// responsibility to ensure that the index is in bounds. -func (f *FwdPkg) DestRef(i uint16) SettleFailRef { - return SettleFailRef{ - Source: f.Source, - Height: f.Height, - Index: i, - } -} - -// ID returns an unique identifier for this package, used to ensure that sphinx -// replay processing of this batch is idempotent. -func (f *FwdPkg) ID() []byte { - var id = make([]byte, 16) - byteOrder.PutUint64(id[:8], f.Source.ToUint64()) - byteOrder.PutUint64(id[8:], f.Height) - return id -} - -// String returns a human-readable description of the forwarding package. -func (f *FwdPkg) String() string { - return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", - f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) -} - -// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID -// is assumed to be that of the packager. -type AddRef struct { - // Height is the remote commitment height that locked in the Add. - Height uint64 - - // Index is the index of the Add within the fwd pkg's Adds. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// Encode serializes the AddRef to the given io.Writer. -func (a *AddRef) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { - return err - } - - return binary.Write(w, binary.BigEndian, a.Index) -} - -// Decode deserializes the AddRef from the given io.Reader. -func (a *AddRef) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { - return err - } - - return binary.Read(r, binary.BigEndian, &a.Index) -} - -// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A -// channel does not remove its own Settle/Fail htlcs, so the source is provided -// to locate a db bucket belonging to another channel. -type SettleFailRef struct { - // Source identifies the outgoing link that locked in the settle or - // fail. This is then used by the *incoming* link to find the settle - // fail in another link's forwarding packages. - Source lnwire.ShortChannelID - - // Height is the remote commitment height that locked in this - // Settle/Fail. - Height uint64 - - // Index is the index of the Add with the fwd pkg's SettleFails. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// SettleFailAcker is a generic interface providing the ability to acknowledge -// settle/fail HTLCs stored in forwarding packages. -type SettleFailAcker interface { - // AckSettleFails atomically updates the settle-fail filters in *other* - // channels' forwarding packages. - AckSettleFails(tx kvdb.RwTx, settleFailRefs ...SettleFailRef) error -} - -// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages -// of any active channel. -type GlobalFwdPkgReader interface { - // LoadChannelFwdPkgs loads all known forwarding packages for the given - // channel. - LoadChannelFwdPkgs(tx kvdb.RTx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) -} - -// FwdOperator defines the interfaces for managing forwarding packages that are -// external to a particular channel. This interface is used by the switch to -// read forwarding packages from arbitrary channels, and acknowledge settles and -// fails for locally-sourced payments. -type FwdOperator interface { - // GlobalFwdPkgReader provides read access to all known forwarding - // packages - GlobalFwdPkgReader - - // SettleFailAcker grants the ability to acknowledge settles or fails - // residing in arbitrary forwarding packages. - SettleFailAcker -} - -// SwitchPackager is a concrete implementation of the FwdOperator interface. -// A SwitchPackager offers the ability to read any forwarding package, and ack -// arbitrary settle and fail HTLCs. -type SwitchPackager struct{} - -// NewSwitchPackager instantiates a new SwitchPackager. -func NewSwitchPackager() *SwitchPackager { - return &SwitchPackager{} -} - -// AckSettleFails atomically updates the settle-fail filters in *other* -// channels' forwarding packages, to mark that the switch has received a settle -// or fail residing in the forwarding package of a link. -func (*SwitchPackager) AckSettleFails(tx kvdb.RwTx, - settleFailRefs ...SettleFailRef) error { - - return ackSettleFails(tx, settleFailRefs) -} - -// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. -func (*SwitchPackager) LoadChannelFwdPkgs(tx kvdb.RTx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) { - - return loadChannelFwdPkgs(tx, source) -} - -// FwdPackager supports all operations required to modify fwd packages, such as -// creation, updates, reading, and removal. The interfaces are broken down in -// this way to support future delegation of the subinterfaces. -type FwdPackager interface { - // AddFwdPkg serializes and writes a FwdPkg for this channel at the - // remote commitment height included in the forwarding package. - AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error - - // SetFwdFilter looks up the forwarding package at the remote `height` - // and sets the `fwdFilter`, marking the Adds for which: - // 1) We are not the exit node - // 2) Passed all validation - // 3) Should be forwarded to the switch immediately after a failure - SetFwdFilter(tx kvdb.RwTx, height uint64, fwdFilter *PkgFilter) error - - // AckAddHtlcs atomically updates the add filters in this channel's - // forwarding packages to mark the resolution of an Add that was - // received from the remote party. - AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error - - // SettleFailAcker allows a link to acknowledge settle/fail HTLCs - // belonging to other channels. - SettleFailAcker - - // LoadFwdPkgs loads all known forwarding packages owned by this - // channel. - LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error) - - // RemovePkg deletes a forwarding package owned by this channel at - // the provided remote `height`. - RemovePkg(tx kvdb.RwTx, height uint64) error - - // Wipe deletes all the forwarding packages owned by this channel. - Wipe(tx kvdb.RwTx) error -} - -// ChannelPackager is used by a channel to manage the lifecycle of its forwarding -// packages. The packager is tied to a particular source channel ID, allowing it -// to create and edit its own packages. Each packager also has the ability to -// remove fail/settle htlcs that correspond to an add contained in one of -// source's packages. -type ChannelPackager struct { - source lnwire.ShortChannelID -} - -// NewChannelPackager creates a new packager for a single channel. -func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { - return &ChannelPackager{ - source: source, - } -} - -// AddFwdPkg writes a newly locked in forwarding package to disk. -func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { // nolint: dupl - fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) - if err != nil { - return err - } - - source := makeLogKey(fwdPkg.Source.ToUint64()) - sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) - if err != nil { - return err - } - - heightKey := makeLogKey(fwdPkg.Height) - heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) - if err != nil { - return err - } - - // Write ADD updates we received at this commit height. - addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) - if err != nil { - return err - } - - // Write SETTLE/FAIL updates we received at this commit height. - failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) - if err != nil { - return err - } - - for i := range fwdPkg.Adds { - err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) - if err != nil { - return err - } - } - - // Persist the initialized pkg filter, which will be used to determine - // when we can remove this forwarding package from disk. - var ackFilterBuf bytes.Buffer - if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { - return err - } - - for i := range fwdPkg.SettleFails { - err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) - if err != nil { - return err - } - } - - var settleFailFilterBuf bytes.Buffer - err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) - if err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. -func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *LogUpdate) error { - var b bytes.Buffer - if err := serializeLogUpdate(&b, htlc); err != nil { - return err - } - - return bkt.Put(uint16Key(idx), b.Bytes()) -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in a map indexed by the -// remote commitment height at which the updates were locked in. -func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error) { - return loadChannelFwdPkgs(tx, p.source) -} - -// loadChannelFwdPkgs loads all forwarding packages owned by `source`. -func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*FwdPkg, error) { - fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil, nil - } - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) - if sourceBkt == nil { - return nil, nil - } - - var heights []uint64 - if err := sourceBkt.ForEach(func(k, _ []byte) error { - if len(k) != 8 { - return ErrCorruptedFwdPkg - } - - heights = append(heights, byteOrder.Uint64(k)) - - return nil - }); err != nil { - return nil, err - } - - // Load the forwarding package for each retrieved height. - fwdPkgs := make([]*FwdPkg, 0, len(heights)) - for _, height := range heights { - fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) - if err != nil { - return nil, err - } - - fwdPkgs = append(fwdPkgs, fwdPkg) - } - - return fwdPkgs, nil -} - -// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the -// appropriate FwdState. -func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, - height uint64) (*FwdPkg, error) { - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) - if sourceBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.NestedReadBucket(heightKey[:]) - if heightBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - // Load ADDs from disk. - addBkt := heightBkt.NestedReadBucket(addBucketKey) - if addBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - adds, err := loadHtlcs(addBkt) - if err != nil { - return nil, err - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - ackFilterReader := bytes.NewReader(ackFilterBytes) - - ackFilter := &PkgFilter{} - if err := ackFilter.Decode(ackFilterReader); err != nil { - return nil, err - } - - // Load SETTLE/FAILs from disk. - failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey) - if failSettleBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - failSettles, err := loadHtlcs(failSettleBkt) - if err != nil { - return nil, err - } - - // Load settle fail filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - - settleFailFilter := &PkgFilter{} - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return nil, err - } - - // Initialize the fwding package, which always starts in the - // FwdStateLockedIn. We can determine what state the package was left in - // by examining constraints on the information loaded from disk. - fwdPkg := &FwdPkg{ - Source: source, - State: FwdStateLockedIn, - Height: height, - Adds: adds, - AckFilter: ackFilter, - SettleFails: failSettles, - SettleFailFilter: settleFailFilter, - } - - // Check if the forward filter has been persisted to disk. - // This indicates whether the Adds in this package have been processed. - // - // NOTE: We also expect packages with no Adds (settle/fail only packages - // or empty packages) to have the fwd filter set to signal that the - // packages have been processed. - fwdFilterBytes := heightBkt.Get(fwdFilterKey) - - // Handle packages with Adds that haven't been processed yet. - if fwdFilterBytes == nil { - // Create a new forward filter for the unprocessed Adds. - nAdds := uint16(len(adds)) - fwdPkg.FwdFilter = NewPkgFilter(nAdds) - - return fwdPkg, nil - } - - // Load the existing forward filter from disk. - fwdFilterReader := bytes.NewReader(fwdFilterBytes) - fwdPkg.FwdFilter = &PkgFilter{} - if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { - return nil, err - } - - // Mark the package as processed since the forward filter exists. - fwdPkg.State = FwdStateProcessed - - // If every add, settle, and fail has been fully acknowledged, we can - // safely set the package's state to FwdStateCompleted, signalling that - // it can be garbage collected. - if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { - fwdPkg.State = FwdStateCompleted - } - - return fwdPkg, nil -} - -// loadHtlcs retrieves all serialized htlcs in a bucket, returning -// them in order of the indexes they were written under. -func loadHtlcs(bkt kvdb.RBucket) ([]LogUpdate, error) { - var htlcs []LogUpdate - if err := bkt.ForEach(func(_, v []byte) error { - htlc, err := deserializeLogUpdate(bytes.NewReader(v)) - if err != nil { - return err - } - - htlcs = append(htlcs, *htlc) - - return nil - }); err != nil { - return nil, err - } - - return htlcs, nil -} - -// SetFwdFilter writes the set of indexes corresponding to Adds at the -// `height` that are to be forwarded to the switch. Calling this method causes -// the forwarding package at `height` to be in FwdStateProcessed. We write this -// forwarding decision so that we always arrive at the same behavior for HTLCs -// leaving this channel. After a restart, we skip validation of these Adds, -// since they are assumed to have already been validated, and make the switch or -// outgoing link responsible for handling replays. -func (p *ChannelPackager) SetFwdFilter(tx kvdb.RwTx, height uint64, - fwdFilter *PkgFilter) error { - - fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - source := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.NestedReadWriteBucket(source[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:]) - if heightBkt == nil { - return ErrCorruptedFwdPkg - } - - // If the fwd filter has already been written, we return early to avoid - // modifying the persistent state. - forwardedAddsBytes := heightBkt.Get(fwdFilterKey) - if forwardedAddsBytes != nil { - return nil - } - - // Otherwise we serialize and write the provided fwd filter. - var b bytes.Buffer - if err := fwdFilter.Encode(&b); err != nil { - return err - } - - return heightBkt.Put(fwdFilterKey, b.Bytes()) -} - -// AckAddHtlcs accepts a list of references to add htlcs, and updates the -// AckAddFilter of those forwarding packages to indicate that a settle or fail -// has been received in response to the add. -func (p *ChannelPackager) AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error { - if len(addRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - sourceKey := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceKey[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique height. - heightDiffs := make(map[uint64][]uint16) - for _, addRef := range addRefs { - heightDiffs[addRef.Height] = append( - heightDiffs[addRef.Height], - addRef.Index, - ) - } - - // Load each height bucket once and remove all acked htlcs at that - // height. - for height, indexes := range heightDiffs { - err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) - if err != nil { - return err - } - } - - return nil -} - -// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package -// with a list of indexes, writing the resulting filter back in its place. -func ackAddHtlcsAtHeight(sourceBkt kvdb.RwBucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is successful, as there is nothing - // to ack. - return nil - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - ackFilter := &PkgFilter{} - ackFilterReader := bytes.NewReader(ackFilterBytes) - if err := ackFilter.Decode(ackFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - ackFilter.Set(index) - } - - // Write the resulting filter to disk. - var ackFilterBuf bytes.Buffer - if err := ackFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) -} - -// AckSettleFails persistently acknowledges settles or fails from a remote forwarding -// package. This should only be called after the source of the Add has locked in -// the settle/fail, or it becomes otherwise safe to forgo retransmitting the -// settle/fail after a restart. -func (p *ChannelPackager) AckSettleFails(tx kvdb.RwTx, settleFailRefs ...SettleFailRef) error { - return ackSettleFails(tx, settleFailRefs) -} - -// ackSettleFails persistently acknowledges a batch of settle fail references. -func ackSettleFails(tx kvdb.RwTx, settleFailRefs []SettleFailRef) error { - if len(settleFailRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique destination-height pair. - destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) - for _, settleFailRef := range settleFailRefs { - destHeights, ok := destHeightDiffs[settleFailRef.Source] - if !ok { - destHeights = make(map[uint64][]uint16) - destHeightDiffs[settleFailRef.Source] = destHeights - } - - destHeights[settleFailRef.Height] = append( - destHeights[settleFailRef.Height], - settleFailRef.Index, - ) - } - - // With the references organized by destination and height, we now load - // each remote bucket, and update the settle fail filter for any - // settle/fail htlcs. - for dest, destHeights := range destHeightDiffs { - destKey := makeLogKey(dest.ToUint64()) - destBkt := fwdPkgBkt.NestedReadWriteBucket(destKey[:]) - if destBkt == nil { - // If the destination bucket is not found, this is - // likely the result of the destination channel being - // closed and having it's forwarding packages wiped. We - // won't treat this as an error, because the response - // will no longer be retransmitted internally. - continue - } - - for height, indexes := range destHeights { - err := ackSettleFailsAtHeight(destBkt, height, indexes) - if err != nil { - return err - } - } - } - - return nil -} - -// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes -// at particular a height by updating the settle fail filter. -func ackSettleFailsAtHeight(destBkt kvdb.RwBucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := destBkt.NestedReadWriteBucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is as there is nothing to ack. - return nil - } - - // Load ack filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - settleFailFilter := &PkgFilter{} - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - settleFailFilter.Set(index) - } - - // Write the resulting filter to disk. - var settleFailFilterBuf bytes.Buffer - if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// RemovePkg deletes the forwarding package at the given height from the -// packager's source bucket. -func (p *ChannelPackager) RemovePkg(tx kvdb.RwTx, height uint64) error { - fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil - } - - sourceBytes := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) +const ( + // FwdStateLockedIn is the starting state for all forwarding packages. + FwdStateLockedIn = cstate.FwdStateLockedIn - return sourceBkt.DeleteNestedBucket(heightKey[:]) -} + // FwdStateProcessed marks the state in which all Adds have been + // locally processed. + FwdStateProcessed = cstate.FwdStateProcessed -// Wipe deletes all the channel's forwarding packages, if any. -func (p *ChannelPackager) Wipe(tx kvdb.RwTx) error { - // If the root bucket doesn't exist, there's no need to delete. - fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil - } + // FwdStateCompleted signals that all Adds have been acked, and that + // all settles and fails have been delivered to their sources. + FwdStateCompleted = cstate.FwdStateCompleted +) - sourceBytes := makeLogKey(p.source.ToUint64()) +var ( + // fwdPackagesKey is retained while the root channeldb bucket setup + // remains in this package. + fwdPackagesKey = cstate.FwdPackagesBucketKey() - // If the nested bucket doesn't exist, there's no need to delete. - if fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:]) == nil { - return nil - } + // NewPkgFilter initializes an empty PkgFilter supporting `count` + // elements. + NewPkgFilter = cstate.NewPkgFilter - return fwdPkgBkt.DeleteNestedBucket(sourceBytes[:]) -} + // NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. + NewFwdPkg = cstate.NewFwdPkg -// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. -func uint16Key(i uint16) []byte { - key := make([]byte, 2) - byteOrder.PutUint16(key, i) - return key -} + // ErrCorruptedFwdPkg signals that the on-disk structure of the + // forwarding package has potentially been mangled. + ErrCorruptedFwdPkg = cstate.ErrCorruptedFwdPkg -// Compile-time constraint to ensure that ChannelPackager implements the public -// FwdPackager interface. -var _ FwdPackager = (*ChannelPackager)(nil) + // NewSwitchPackager instantiates a new SwitchPackager. + NewSwitchPackager = cstate.NewSwitchPackager -// Compile-time constraint to ensure that SwitchPackager implements the public -// FwdOperator interface. -var _ FwdOperator = (*SwitchPackager)(nil) + // NewChannelPackager creates a new packager for a single channel. + NewChannelPackager = cstate.NewChannelPackager +) diff --git a/channeldb/forwarding_policy.go b/channeldb/forwarding_policy.go index 2df2e308f8f..3bcfbda0028 100644 --- a/channeldb/forwarding_policy.go +++ b/channeldb/forwarding_policy.go @@ -2,44 +2,15 @@ package channeldb import ( "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) -var ( - // initialChannelForwardingPolicyBucket is the database bucket used to - // store the forwarding policy for each permanent channel that is - // currently in the process of being opened. - initialChannelForwardingPolicyBucket = []byte( - "initialChannelFwdingPolicy", - ) -) - // SaveInitialForwardingPolicy saves the serialized forwarding policy for the // provided permanent channel id to the initialChannelForwardingPolicyBucket. func (c *ChannelStateDB) SaveInitialForwardingPolicy(chanID lnwire.ChannelID, forwardingPolicy *models.ForwardingPolicy) error { - chanIDCopy := make([]byte, 32) - copy(chanIDCopy, chanID[:]) - - scratch := make([]byte, 36) - byteOrder.PutUint64(scratch[:8], uint64(forwardingPolicy.MinHTLCOut)) - byteOrder.PutUint64(scratch[8:16], uint64(forwardingPolicy.MaxHTLC)) - byteOrder.PutUint64(scratch[16:24], uint64(forwardingPolicy.BaseFee)) - byteOrder.PutUint64(scratch[24:32], uint64(forwardingPolicy.FeeRate)) - byteOrder.PutUint32(scratch[32:], forwardingPolicy.TimeLockDelta) - - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - bucket, err := tx.CreateTopLevelBucket( - initialChannelForwardingPolicyBucket, - ) - if err != nil { - return err - } - - return bucket.Put(chanIDCopy, scratch) - }, func() {}) + return c.kvStore.SaveInitialForwardingPolicy(chanID, forwardingPolicy) } // GetInitialForwardingPolicy fetches the serialized forwarding policy for the @@ -48,46 +19,7 @@ func (c *ChannelStateDB) SaveInitialForwardingPolicy(chanID lnwire.ChannelID, func (c *ChannelStateDB) GetInitialForwardingPolicy( chanID lnwire.ChannelID) (*models.ForwardingPolicy, error) { - chanIDCopy := make([]byte, 32) - copy(chanIDCopy, chanID[:]) - - var forwardingPolicy *models.ForwardingPolicy - err := kvdb.View(c.backend, func(tx kvdb.RTx) error { - bucket := tx.ReadBucket(initialChannelForwardingPolicyBucket) - if bucket == nil { - // If the bucket does not exist, it means we - // never added a channel fees to the db, so - // return ErrChannelNotFound. - return ErrChannelNotFound - } - - stateBytes := bucket.Get(chanIDCopy) - if stateBytes == nil { - return ErrChannelNotFound - } - - forwardingPolicy = &models.ForwardingPolicy{ - MinHTLCOut: lnwire.MilliSatoshi( - byteOrder.Uint64(stateBytes[:8]), - ), - MaxHTLC: lnwire.MilliSatoshi( - byteOrder.Uint64(stateBytes[8:16]), - ), - BaseFee: lnwire.MilliSatoshi( - byteOrder.Uint64(stateBytes[16:24]), - ), - FeeRate: lnwire.MilliSatoshi( - byteOrder.Uint64(stateBytes[24:32]), - ), - TimeLockDelta: byteOrder.Uint32(stateBytes[32:36]), - } - - return nil - }, func() { - forwardingPolicy = nil - }) - - return forwardingPolicy, err + return c.kvStore.GetInitialForwardingPolicy(chanID) } // DeleteInitialForwardingPolicy removes the forwarding policy for a given @@ -95,17 +27,5 @@ func (c *ChannelStateDB) GetInitialForwardingPolicy( func (c *ChannelStateDB) DeleteInitialForwardingPolicy( chanID lnwire.ChannelID) error { - chanIDCopy := make([]byte, 32) - copy(chanIDCopy, chanID[:]) - - return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { - bucket := tx.ReadWriteBucket( - initialChannelForwardingPolicyBucket, - ) - if bucket == nil { - return ErrChannelNotFound - } - - return bucket.Delete(chanIDCopy) - }, func() {}) + return c.kvStore.DeleteInitialForwardingPolicy(chanID) } diff --git a/channeldb/legacy_serialization.go b/channeldb/legacy_serialization.go index c2e636c5eaa..9bfa1180a71 100644 --- a/channeldb/legacy_serialization.go +++ b/channeldb/legacy_serialization.go @@ -2,6 +2,8 @@ package channeldb import ( "io" + + cstate "github.com/lightningnetwork/lnd/chanstate" ) // deserializeCloseChannelSummaryV6 reads the v6 database format for @@ -34,7 +36,7 @@ func deserializeCloseChannelSummaryV6(r io.Reader) (*ChannelCloseSummary, error) return nil, err } - if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + if err := cstate.ReadChanConfig(r, &c.LocalChanConfig); err != nil { return nil, err } diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 5a7f7a76bea..59fcf687349 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -1,35 +1,52 @@ package channeldb import ( - "bytes" - "encoding/binary" - "errors" "io" - "math" - "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/fn/v2" + cstate "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) const ( // OutputIndexEmpty is used when the output index doesn't exist. - OutputIndexEmpty = math.MaxUint16 + OutputIndexEmpty = cstate.OutputIndexEmpty ) type ( // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. - BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + BigSizeAmount = cstate.BigSizeAmount // BigSizeMilliSatoshi is a type alias for a TLV record of a // lnwire.MilliSatoshi. - BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] + BigSizeMilliSatoshi = cstate.BigSizeMilliSatoshi + + // SparsePayHash is a type alias for a 32 byte array, which when + // serialized is able to save some space by not including an empty + // payment hash on disk. + SparsePayHash = cstate.SparsePayHash + + // HTLCEntry specifies the minimal info needed to be stored on disk for + // ALL the historical HTLCs, which is useful for constructing + // RevocationLog when a breach is detected. + HTLCEntry = cstate.HTLCEntry + + // RevocationLog stores the info needed to construct a breach + // retribution. + RevocationLog = cstate.RevocationLog ) var ( + // NewSparsePayHash creates a new SparsePayHash from a 32 byte array. + NewSparsePayHash = cstate.NewSparsePayHash + + // NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. + NewHTLCEntryFromHTLC = cstate.NewHTLCEntryFromHTLC + + // NewRevocationLog creates a new RevocationLog from the given + // parameters. + NewRevocationLog = cstate.NewRevocationLog + // revocationLogBucketDeprecated is dedicated for storing the necessary // delta state between channel updates required to re-construct a past // state in order to punish a counterparty attempting a non-cooperative @@ -38,283 +55,24 @@ var ( // // Deprecated: This bucket is kept for read-only in case the user // choose not to migrate the old data. - revocationLogBucketDeprecated = []byte("revocation-log-key") + //nolint:ll + revocationLogBucketDeprecated = cstate.RevocationLogBucketDeprecatedKey() // revocationLogBucket is a sub-bucket under openChannelBucket. This // sub-bucket is dedicated for storing the minimal info required to // re-construct a past state in order to punish a counterparty // attempting a non-cooperative channel closure. - revocationLogBucket = []byte("revocation-log") + revocationLogBucket = cstate.RevocationLogBucketKey() // ErrLogEntryNotFound is returned when we cannot find a log entry at // the height requested in the revocation log. - ErrLogEntryNotFound = errors.New("log entry not found") + ErrLogEntryNotFound = cstate.ErrLogEntryNotFound // ErrOutputIndexTooBig is returned when the output index is greater // than uint16. - ErrOutputIndexTooBig = errors.New("output index is over uint16") + ErrOutputIndexTooBig = cstate.ErrOutputIndexTooBig ) -// SparsePayHash is a type alias for a 32 byte array, which when serialized is -// able to save some space by not including an empty payment hash on disk. -type SparsePayHash [32]byte - -// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. -func NewSparsePayHash(rHash [32]byte) SparsePayHash { - return SparsePayHash(rHash) -} - -// Record returns a tlv record for the SparsePayHash. -func (s *SparsePayHash) Record() tlv.Record { - // We use a zero for the type here, as this'll be used along with the - // RecordT type. - return tlv.MakeDynamicRecord( - 0, s, s.hashLen, - sparseHashEncoder, sparseHashDecoder, - ) -} - -// hashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (s *SparsePayHash) hashLen() uint64 { - if bytes.Equal(s[:], lntypes.ZeroHash[:]) { - return 0 - } - - return 32 -} - -// sparseHashEncoder is the customized encoder which skips encoding the empty -// hash. -func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*SparsePayHash) - if !ok { - return tlv.NewTypeForEncodingErr(val, "SparsePayHash") - } - - // If the value is an empty hash, we will skip encoding it. - if bytes.Equal(v[:], lntypes.ZeroHash[:]) { - return nil - } - - vArray := (*[32]byte)(v) - - return tlv.EBytes32(w, vArray, buf) -} - -// sparseHashDecoder is the customized decoder which skips decoding the empty -// hash. -func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, - l uint64) error { - - v, ok := val.(*SparsePayHash) - if !ok { - return tlv.NewTypeForEncodingErr(val, "SparsePayHash") - } - - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } - - vArray := (*[32]byte)(v) - - return tlv.DBytes32(r, vArray, buf, 32) -} - -// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the -// historical HTLCs, which is useful for constructing RevocationLog when a -// breach is detected. -// The actual size of each HTLCEntry varies based on its RHash and Amt(sat), -// summarized as follows, -// -// | RHash empty | Amt<=252 | Amt<=65,535 | Amt<=4,294,967,295 | otherwise | -// |:-----------:|:--------:|:-----------:|:------------------:|:---------:| -// | true | 19 | 21 | 23 | 26 | -// | false | 51 | 53 | 55 | 58 | -// -// So the size varies from 19 bytes to 58 bytes, where most likely to be 23 or -// 55 bytes. -// -// NOTE: all the fields saved to disk use the primitive go types so they can be -// made into tlv records without further conversion. -type HTLCEntry struct { - // RHash is the payment hash of the HTLC. - RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] - - // RefundTimeout is the absolute timeout on the HTLC that the sender - // must wait before reclaiming the funds in limbo. - RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] - - // OutputIndex is the output index for this particular HTLC output - // within the commitment transaction. - // - // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which - // gives us a max number of HTLCs of 65K. - OutputIndex tlv.RecordT[tlv.TlvType2, uint16] - - // Incoming denotes whether we're the receiver or the sender of this - // HTLC. - Incoming tlv.RecordT[tlv.TlvType3, bool] - - // Amt is the amount of satoshis this HTLC escrows. - Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] - - // CustomBlob is an optional blob that can be used to store information - // specific to revocation handling for a custom channel type. - CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] - - // HtlcIndex is the index of the HTLC in the channel. - HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, tlv.BigSizeT[uint64]] -} - -// toTlvStream converts an HTLCEntry record into a tlv representation. -func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - records := []tlv.Record{ - h.RHash.Record(), - h.RefundTimeout.Record(), - h.OutputIndex.Record(), - h.Incoming.Record(), - h.Amt.Record(), - } - - h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { - records = append(records, r.Record()) - }) - - h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, - tlv.BigSizeT[uint64]]) { - - records = append(records, r.Record()) - }) - - tlv.SortRecords(records) - - return tlv.NewStream(records...) -} - -// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. -func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { - h := &HTLCEntry{ - RHash: tlv.NewRecordT[tlv.TlvType0]( - NewSparsePayHash(htlc.RHash), - ), - RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( - htlc.RefundTimeout, - ), - OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( - uint16(htlc.OutputIndex), - ), - Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), - Amt: tlv.NewRecordT[tlv.TlvType4]( - tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), - ), - HtlcIndex: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType6]( - tlv.NewBigSizeT(htlc.HtlcIndex), - )), - } - - if len(htlc.CustomRecords) != 0 { - blob, err := htlc.CustomRecords.Serialize() - if err != nil { - return nil, err - } - - h.CustomBlob = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), - ) - } - - return h, nil -} - -// RevocationLog stores the info needed to construct a breach retribution. Its -// fields can be viewed as a subset of a ChannelCommitment's. In the database, -// all historical versions of the RevocationLog are saved using the -// CommitHeight as the key. -type RevocationLog struct { - // OurOutputIndex specifies our output index in this commitment. In a - // remote commitment transaction, this is the to remote output index. - OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] - - // TheirOutputIndex specifies their output index in this commitment. In - // a remote commitment transaction, this is the to local output index. - TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] - - // CommitTxHash is the hash of the latest version of the commitment - // state, broadcast able by us. - CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] - - // HTLCEntries is the set of HTLCEntry's that are pending at this - // particular commitment height. - HTLCEntries []*HTLCEntry - - // OurBalance is the current available balance within the channel - // directly spendable by us. In other words, it is the value of the - // to_remote output on the remote parties' commitment transaction. - // - // NOTE: this is an option so that it is clear if the value is zero or - // nil. Since migration 30 of the channeldb initially did not include - // this field, it could be the case that the field is not present for - // all revocation logs. - OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] - - // TheirBalance is the current available balance within the channel - // directly spendable by the remote node. In other words, it is the - // value of the to_local output on the remote parties' commitment. - // - // NOTE: this is an option so that it is clear if the value is zero or - // nil. Since migration 30 of the channeldb initially did not include - // this field, it could be the case that the field is not present for - // all revocation logs. - TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] - - // CustomBlob is an optional blob that can be used to store information - // specific to a custom channel type. This information is only created - // at channel funding time, and after wards is to be considered - // immutable. - CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] -} - -// NewRevocationLog creates a new RevocationLog from the given parameters. -func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, - commitHash [32]byte, ourBalance, - theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, - customBlob fn.Option[tlv.Blob]) RevocationLog { - - rl := RevocationLog{ - OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( - ourOutputIndex, - ), - TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( - theirOutputIndex, - ), - CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), - HTLCEntries: htlcs, - } - - ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { - rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( - tlv.NewBigSizeT(balance), - )) - }) - - theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { - rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( - tlv.NewBigSizeT(balance), - )) - }) - - customBlob.WhenSome(func(blob tlv.Blob) { - rl.CustomBlob = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), - ) - }) - - return rl -} - // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a // ChannelCommitment to construct a revocation log entry and saves them to // disk. It also saves our output index and their output index, which are @@ -322,70 +80,9 @@ func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, ourOutputIndex, theirOutputIndex uint32, noAmtData bool) error { - // Sanity check that the output indexes can be safely converted. - if ourOutputIndex > math.MaxUint16 { - return ErrOutputIndexTooBig - } - if theirOutputIndex > math.MaxUint16 { - return ErrOutputIndexTooBig - } - - rl := &RevocationLog{ - OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( - uint16(ourOutputIndex), - ), - TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( - uint16(theirOutputIndex), - ), - CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( - commit.CommitTx.TxHash(), - ), - HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), - } - - commit.CustomBlob.WhenSome(func(blob tlv.Blob) { - rl.CustomBlob = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), - ) - }) - - if !noAmtData { - rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( - tlv.NewBigSizeT(commit.LocalBalance), - )) - - rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( - tlv.NewBigSizeT(commit.RemoteBalance), - )) - } - - for _, htlc := range commit.Htlcs { - // Skip dust HTLCs. - if htlc.OutputIndex < 0 { - continue - } - - // Sanity check that the output indexes can be safely - // converted. - if htlc.OutputIndex > math.MaxUint16 { - return ErrOutputIndexTooBig - } - - entry, err := NewHTLCEntryFromHTLC(htlc) - if err != nil { - return err - } - rl.HTLCEntries = append(rl.HTLCEntries, entry) - } - - var b bytes.Buffer - err := serializeRevocationLog(&b, rl) - if err != nil { - return err - } - - logEntrykey := makeLogKey(commit.CommitHeight) - return bucket.Put(logEntrykey[:], b.Bytes()) + return cstate.PutRevocationLog( + bucket, commit, ourOutputIndex, theirOutputIndex, noAmtData, + ) } // fetchRevocationLog queries the revocation log bucket to find an log entry. @@ -393,283 +90,42 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, func fetchRevocationLog(log kvdb.RBucket, updateNum uint64) (RevocationLog, error) { - logEntrykey := makeLogKey(updateNum) - commitBytes := log.Get(logEntrykey[:]) - if commitBytes == nil { - return RevocationLog{}, ErrLogEntryNotFound - } - - commitReader := bytes.NewReader(commitBytes) - - return deserializeRevocationLog(commitReader) + return cstate.FetchRevocationLog(log, updateNum) } // serializeRevocationLog serializes a RevocationLog record based on tlv // format. func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { - // Add the tlv records for all non-optional fields. - records := []tlv.Record{ - rl.OurOutputIndex.Record(), - rl.TheirOutputIndex.Record(), - rl.CommitTxHash.Record(), - } - - // Now we add any optional fields that are non-nil. - rl.OurBalance.WhenSome( - func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { - records = append(records, r.Record()) - }, - ) - - rl.TheirBalance.WhenSome( - func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { - records = append(records, r.Record()) - }, - ) - - rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { - records = append(records, r.Record()) - }) - - // Create the tlv stream. - tlvStream, err := tlv.NewStream(records...) - if err != nil { - return err - } - - // Write the tlv stream. - if err := writeTlvStream(w, tlvStream); err != nil { - return err - } - - // Write the HTLCs. - return serializeHTLCEntries(w, rl.HTLCEntries) + return cstate.SerializeRevocationLog(w, rl) } // serializeHTLCEntries serializes a list of HTLCEntry records based on tlv // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { - for _, htlc := range htlcs { - // Create the tlv stream. - tlvStream, err := htlc.toTlvStream() - if err != nil { - return err - } - - // Write the tlv stream. - if err := writeTlvStream(w, tlvStream); err != nil { - return err - } - } - - return nil + return cstate.SerializeHTLCEntries(w, htlcs) } // deserializeRevocationLog deserializes a RevocationLog based on tlv format. func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { - var rl RevocationLog - - ourBalance := rl.OurBalance.Zero() - theirBalance := rl.TheirBalance.Zero() - customBlob := rl.CustomBlob.Zero() - - // Create the tlv stream. - tlvStream, err := tlv.NewStream( - rl.OurOutputIndex.Record(), - rl.TheirOutputIndex.Record(), - rl.CommitTxHash.Record(), - ourBalance.Record(), - theirBalance.Record(), - customBlob.Record(), - ) - if err != nil { - return rl, err - } - - // Read the tlv stream. - parsedTypes, err := readTlvStream(r, tlvStream) - if err != nil { - return rl, err - } - - if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { - rl.OurBalance = tlv.SomeRecordT(ourBalance) - } - - if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { - rl.TheirBalance = tlv.SomeRecordT(theirBalance) - } - - if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { - rl.CustomBlob = tlv.SomeRecordT(customBlob) - } - - // Read the HTLC entries. - rl.HTLCEntries, err = deserializeHTLCEntries(r) - - return rl, err + return cstate.DeserializeRevocationLog(r) } // deserializeHTLCEntries deserializes a list of HTLC entries based on tlv // format. func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { - var ( - htlcs []*HTLCEntry - - // htlcIndexBlob defines the tlv record type to be used when - // decoding from the disk. We use it instead of the one defined - // in `HTLCEntry.HtlcIndex` as previously this field was encoded - // using `uint16`, thus we will read it as raw bytes and - // deserialize it further below. - htlcIndexBlob tlv.OptionalRecordT[tlv.TlvType6, tlv.Blob] - ) - - for { - var htlc HTLCEntry - - customBlob := htlc.CustomBlob.Zero() - htlcIndex := htlcIndexBlob.Zero() - - // Create the tlv stream. - records := []tlv.Record{ - htlc.RHash.Record(), - htlc.RefundTimeout.Record(), - htlc.OutputIndex.Record(), - htlc.Incoming.Record(), - htlc.Amt.Record(), - customBlob.Record(), - htlcIndex.Record(), - } - - tlvStream, err := tlv.NewStream(records...) - if err != nil { - return nil, err - } - - // Read the HTLC entry. - parsedTypes, err := readTlvStream(r, tlvStream) - if err != nil { - // We've reached the end when hitting an EOF. - if err == io.ErrUnexpectedEOF { - break - } - return nil, err - } - - if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { - htlc.CustomBlob = tlv.SomeRecordT(customBlob) - } - - if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil { - record, err := deserializeHtlcIndexCompatible( - htlcIndex.Val, - ) - if err != nil { - return nil, err - } - - htlc.HtlcIndex = record - } - - // Append the entry. - htlcs = append(htlcs, &htlc) - } - - return htlcs, nil -} - -// deserializeHtlcIndexCompatible takes raw bytes and decodes it into an -// optional record that's assigned to the entry's HtlcIndex. -// -// NOTE: previously this `HtlcIndex` was a tlv record that used `uint16` to -// encode its value. Given now its value is encoded using BigSizeT, and for any -// BigSizeT, its possible length values are 1, 3, 5, and 8. This means if the -// tlv record has a length of 2, we know for sure it must be an old record -// whose value was encoded using uint16. -func deserializeHtlcIndexCompatible(rawBytes []byte) ( - tlv.OptionalRecordT[tlv.TlvType6, tlv.BigSizeT[uint64]], error) { - - var ( - // record defines the record that's used by the HtlcIndex in the - // entry. - record tlv.OptionalRecordT[ - tlv.TlvType6, tlv.BigSizeT[uint64], - ] - - // htlcIndexVal is the decoded uint64 value. - htlcIndexVal uint64 - ) - - // If the length of the tlv record is 2, it must be encoded using uint16 - // as the BigSizeT encoding cannot have this length. - if len(rawBytes) == 2 { - // Decode the raw bytes into uint16 and convert it into uint64. - htlcIndexVal = uint64(binary.BigEndian.Uint16(rawBytes)) - } else { - // This value is encoded using BigSizeT, we now use the decoder - // to deserialize the raw bytes. - r := bytes.NewBuffer(rawBytes) - - // Create a buffer to be used in the decoding process. - buf := [8]byte{} - - // Use the BigSizeT's decoder. - err := tlv.DBigSize(r, &htlcIndexVal, &buf, 8) - if err != nil { - return record, err - } - } - - record = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType6]( - tlv.NewBigSizeT(htlcIndexVal), - )) - - return record, nil + return cstate.DeserializeHTLCEntries(r) } // writeTlvStream is a helper function that encodes the tlv stream into the // writer. func writeTlvStream(w io.Writer, s *tlv.Stream) error { - var b bytes.Buffer - if err := s.Encode(&b); err != nil { - return err - } - - // Write the stream's length as a varint. - err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) - if err != nil { - return err - } - - if _, err = w.Write(b.Bytes()); err != nil { - return err - } - - return nil + return cstate.WriteTlvStream(w, s) } // readTlvStream is a helper function that decodes the tlv stream from the // reader. func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) { - var bodyLen uint64 - - // Read the stream's length. - bodyLen, err := tlv.ReadVarInt(r, &[8]byte{}) - switch { - // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an - // invalid record. - case err == io.EOF: - return nil, io.ErrUnexpectedEOF - - // Other unexpected errors. - case err != nil: - return nil, err - } - - // TODO(yy): add overflow check. - lr := io.LimitReader(r, int64(bodyLen)) - - return s.DecodeWithParsedTypes(lr) + return cstate.ReadTlvStream(r, s) } // fetchOldRevocationLog finds the revocation log from the deprecated @@ -677,14 +133,7 @@ func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) { func fetchOldRevocationLog(log kvdb.RBucket, updateNum uint64) (ChannelCommitment, error) { - logEntrykey := makeLogKey(updateNum) - commitBytes := log.Get(logEntrykey[:]) - if commitBytes == nil { - return ChannelCommitment{}, ErrLogEntryNotFound - } - - commitReader := bytes.NewReader(commitBytes) - return deserializeChanCommit(commitReader) + return cstate.FetchOldRevocationLog(log, updateNum) } // fetchRevocationLogCompatible finds the revocation log from both the @@ -698,86 +147,16 @@ func fetchOldRevocationLog(log kvdb.RBucket, func fetchRevocationLogCompatible(chanBucket kvdb.RBucket, updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { - // Look into the new bucket first. - logBucket := chanBucket.NestedReadBucket(revocationLogBucket) - if logBucket != nil { - rl, err := fetchRevocationLog(logBucket, updateNum) - // We've found the record, no need to visit the old bucket. - if err == nil { - return &rl, nil, nil - } - - // Return the error if it doesn't say the log cannot be found. - if err != ErrLogEntryNotFound { - return nil, nil, err - } - } - - // Otherwise, look into the old bucket and try to find the log there. - oldBucket := chanBucket.NestedReadBucket(revocationLogBucketDeprecated) - if oldBucket != nil { - c, err := fetchOldRevocationLog(oldBucket, updateNum) - if err != nil { - return nil, nil, err - } - - // Found an old record and return it. - return nil, &c, nil - } - - // If both the buckets are nil, then the sub-buckets haven't been - // created yet. - if logBucket == nil && oldBucket == nil { - return nil, nil, ErrNoPastDeltas - } - - // Otherwise, we've tried to query the new bucket but the log cannot be - // found. - return nil, nil, ErrLogEntryNotFound + return cstate.FetchRevocationLogCompatible(chanBucket, updateNum) } // fetchLogBucket returns a read bucket by visiting both the new and the old // bucket. func fetchLogBucket(chanBucket kvdb.RBucket) (kvdb.RBucket, error) { - logBucket := chanBucket.NestedReadBucket(revocationLogBucket) - if logBucket == nil { - logBucket = chanBucket.NestedReadBucket( - revocationLogBucketDeprecated, - ) - if logBucket == nil { - return nil, ErrNoPastDeltas - } - } - - return logBucket, nil + return cstate.FetchLogBucket(chanBucket) } // deleteLogBucket deletes the both the new and old revocation log buckets. func deleteLogBucket(chanBucket kvdb.RwBucket) error { - // Check if the bucket exists and delete it. - logBucket := chanBucket.NestedReadWriteBucket( - revocationLogBucket, - ) - if logBucket != nil { - err := chanBucket.DeleteNestedBucket(revocationLogBucket) - if err != nil { - return err - } - } - - // We also check whether the old revocation log bucket exists - // and delete it if so. - oldLogBucket := chanBucket.NestedReadWriteBucket( - revocationLogBucketDeprecated, - ) - if oldLogBucket != nil { - err := chanBucket.DeleteNestedBucket( - revocationLogBucketDeprecated, - ) - if err != nil { - return err - } - } - - return nil + return cstate.DeleteLogBucket(chanBucket) } diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 06f3e67c0c7..bf01b5a5dbc 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -4,7 +4,6 @@ import ( "sync" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/subscribe" ) @@ -31,14 +30,14 @@ type PendingOpenChannelEvent struct { // channel. This might not have been persisted to the channel DB yet // because we are still waiting for the final message from the remote // peer. - PendingChannel *channeldb.OpenChannel + PendingChannel *chanstate.OpenChannel } // OpenChannelEvent represents a new event where a channel goes from pending // open to open. type OpenChannelEvent struct { // Channel is the channel that has become open. - Channel *channeldb.OpenChannel + Channel *chanstate.OpenChannel } // ActiveLinkEvent represents a new event where the link becomes active in the @@ -70,13 +69,13 @@ type InactiveChannelEvent struct { // ClosedChannelEvent represents a new event where a channel becomes closed. type ClosedChannelEvent struct { // CloseSummary is the summary of the channel close that has occurred. - CloseSummary *channeldb.ChannelCloseSummary + CloseSummary *chanstate.ChannelCloseSummary } // ChannelUpdateEvent represents a new event where a channel's state is updated. type ChannelUpdateEvent struct { // Channel is the channel that has been updated. - Channel *channeldb.OpenChannel + Channel *chanstate.OpenChannel } // FullyResolvedChannelEvent represents a new event where a channel becomes @@ -143,7 +142,7 @@ func (c *ChannelNotifier) SubscribeChannelEvents() (*subscribe.Client, error) { // persisted to the DB because we still wait for the final message from the // remote peer. func (c *ChannelNotifier) NotifyPendingOpenChannelEvent(chanPoint wire.OutPoint, - pendingChan *channeldb.OpenChannel) { + pendingChan *chanstate.OpenChannel) { event := PendingOpenChannelEvent{ ChannelPoint: &chanPoint, @@ -249,7 +248,7 @@ func (c *ChannelNotifier) NotifyInactiveChannelEvent(chanPoint wire.OutPoint) { // NotifyChannelUpdateEvent notifies subscribers that a channel's state has been // updated. func (c *ChannelNotifier) NotifyChannelUpdateEvent( - channel *channeldb.OpenChannel) { + channel *chanstate.OpenChannel) { event := ChannelUpdateEvent{Channel: channel} if err := c.ntfnServer.SendUpdate(event); err != nil { diff --git a/channelnotifier/channelnotifier_test.go b/channelnotifier/channelnotifier_test.go index 5dbdb4a4579..7ecaf217281 100644 --- a/channelnotifier/channelnotifier_test.go +++ b/channelnotifier/channelnotifier_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/stretchr/testify/require" ) @@ -25,7 +25,7 @@ func TestChannelUpdateEvent(t *testing.T) { defer sub.Cancel() // Create a mock channel state. - channel := &channeldb.OpenChannel{} + channel := &chanstate.OpenChannel{} // Notify the server of a channel update event. ntfnServer.NotifyChannelUpdateEvent(channel) diff --git a/chanrestore.go b/chanrestore.go index 407cdfbc7ad..c45a2700f2d 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -187,7 +187,7 @@ func (c *chanDBRestorer) openChannelShell(backup chanbackup.Single) ( chanShell := channeldb.ChannelShell{ NodeAddrs: backup.Addresses, - Chan: &channeldb.OpenChannel{ + Chan: &chanstate.OpenChannel{ ChanType: chanType, ChainHash: backup.ChainHash, IsInitiator: backup.IsInitiator, diff --git a/chanstate/channel.go b/chanstate/channel.go new file mode 100644 index 00000000000..0950f4c7294 --- /dev/null +++ b/chanstate/channel.go @@ -0,0 +1,18 @@ +package chanstate + +// ChanCount is used by the server in determining access control. +type ChanCount struct { + HasOpenOrClosedChan bool + PendingOpenCount uint64 +} + +// FinalHtlcInfo contains information about the final outcome of an htlc. +type FinalHtlcInfo struct { + // Settled is true is the htlc was settled. If false, the htlc was + // failed. + Settled bool + + // Offchain indicates whether the htlc was resolved off-chain or + // on-chain. + Offchain bool +} diff --git a/chanstate/channel_status.go b/chanstate/channel_status.go new file mode 100644 index 00000000000..b19fe3659ea --- /dev/null +++ b/chanstate/channel_status.go @@ -0,0 +1,110 @@ +package chanstate + +import ( + "strconv" + "strings" +) + +// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in +// the default usable state, or a state where it shouldn't be used. +type ChannelStatus uint64 + +var ( + // ChanStatusDefault is the normal state of an open channel. + ChanStatusDefault ChannelStatus + + // ChanStatusBorked indicates that the channel has entered an + // irreconcilable state, triggered by a state desynchronization or + // channel breach. Channels in this state should never be added to the + // htlc switch. + ChanStatusBorked ChannelStatus = 1 + + // ChanStatusCommitBroadcasted indicates that a commitment for this + // channel has been broadcasted. + ChanStatusCommitBroadcasted ChannelStatus = 1 << 1 + + // ChanStatusLocalDataLoss indicates that we have lost channel state + // for this channel, and broadcasting our latest commitment might be + // considered a breach. + // + // TODO(halseh): actually enforce that we are not force closing such a + // channel. + ChanStatusLocalDataLoss ChannelStatus = 1 << 2 + + // ChanStatusRestored is a status flag that signals that the channel + // has been restored, and doesn't have all the fields a typical channel + // will have. + ChanStatusRestored ChannelStatus = 1 << 3 + + // ChanStatusCoopBroadcasted indicates that a cooperative close for + // this channel has been broadcasted. Older cooperatively closed + // channels will only have this status set. Newer ones will also have + // close initiator information stored using the local/remote initiator + // status. This status is set in conjunction with the initiator status + // so that we do not need to check multiple channel statues for + // cooperative closes. + ChanStatusCoopBroadcasted ChannelStatus = 1 << 4 + + // ChanStatusLocalCloseInitiator indicates that we initiated closing + // the channel. + ChanStatusLocalCloseInitiator ChannelStatus = 1 << 5 + + // ChanStatusRemoteCloseInitiator indicates that the remote node + // initiated closing the channel. + ChanStatusRemoteCloseInitiator ChannelStatus = 1 << 6 +) + +// chanStatusStrings maps a ChannelStatus to a human friendly string that +// describes that status. +var chanStatusStrings = map[ChannelStatus]string{ + ChanStatusDefault: "ChanStatusDefault", + ChanStatusBorked: "ChanStatusBorked", + ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", + ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", + ChanStatusRestored: "ChanStatusRestored", + ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted", + ChanStatusLocalCloseInitiator: "ChanStatusLocalCloseInitiator", + ChanStatusRemoteCloseInitiator: "ChanStatusRemoteCloseInitiator", +} + +// orderedChanStatusFlags is an in-order list of all that channel status flags. +var orderedChanStatusFlags = []ChannelStatus{ + ChanStatusBorked, + ChanStatusCommitBroadcasted, + ChanStatusLocalDataLoss, + ChanStatusRestored, + ChanStatusCoopBroadcasted, + ChanStatusLocalCloseInitiator, + ChanStatusRemoteCloseInitiator, +} + +// String returns a human-readable representation of the ChannelStatus. +func (c ChannelStatus) String() string { + // If no flags are set, then this is the default case. + if c == ChanStatusDefault { + return chanStatusStrings[ChanStatusDefault] + } + + // Add individual bit flags. + statusStr := "" + for _, flag := range orderedChanStatusFlags { + if c&flag == flag { + statusStr += chanStatusStrings[flag] + "|" + c -= flag + } + } + + // Remove anything to the right of the final bar, including it as well. + statusStr = strings.TrimRight(statusStr, "|") + + // Add any remaining flags which aren't accounted for as hex. + if c != 0 { + statusStr += "|0x" + strconv.FormatUint(uint64(c), 16) + } + + // If this was purely an unknown flag, then remove the extra bar at the + // start of the string. + statusStr = strings.TrimLeft(statusStr, "|") + + return statusStr +} diff --git a/chanstate/channel_type.go b/chanstate/channel_type.go new file mode 100644 index 00000000000..9666307bef1 --- /dev/null +++ b/chanstate/channel_type.go @@ -0,0 +1,157 @@ +package chanstate + +// ChannelType is an enum-like type that describes one of several possible +// channel types. Each open channel is associated with a particular type as the +// channel type may determine how higher level operations are conducted such as +// fee negotiation, channel closing, the format of HTLCs, etc. Structure-wise, +// a ChannelType is a bit field, with each bit denoting a modification from the +// base channel type of single funder. +type ChannelType uint64 + +const ( + // NOTE: iota isn't used here for this enum needs to be stable + // long-term as it will be persisted to the database. + + // SingleFunderBit represents a channel wherein one party solely funds + // the entire capacity of the channel. + SingleFunderBit ChannelType = 0 + + // DualFunderBit represents a channel wherein both parties contribute + // funds towards the total capacity of the channel. The channel may be + // funded symmetrically or asymmetrically. + DualFunderBit ChannelType = 1 << 0 + + // SingleFunderTweaklessBit is similar to the basic SingleFunder channel + // type, but it omits the tweak for one's key in the commitment + // transaction of the remote party. + SingleFunderTweaklessBit ChannelType = 1 << 1 + + // NoFundingTxBit denotes if we have the funding transaction locally on + // disk. This bit may be on if the funding transaction was crafted by a + // wallet external to the primary daemon. + NoFundingTxBit ChannelType = 1 << 2 + + // AnchorOutputsBit indicates that the channel makes use of anchor + // outputs to bump the commitment transaction's effective feerate. This + // channel type also uses a delayed to_remote output script. + AnchorOutputsBit ChannelType = 1 << 3 + + // FrozenBit indicates that the channel is a frozen channel, meaning + // that only the responder can decide to cooperatively close the + // channel. + FrozenBit ChannelType = 1 << 4 + + // ZeroHtlcTxFeeBit indicates that the channel should use zero-fee + // second-level HTLC transactions. + ZeroHtlcTxFeeBit ChannelType = 1 << 5 + + // LeaseExpirationBit indicates that the channel has been leased for a + // period of time, constraining every output that pays to the channel + // initiator with an additional CLTV of the lease maturity. + LeaseExpirationBit ChannelType = 1 << 6 + + // ZeroConfBit indicates that the channel is a zero-conf channel. + ZeroConfBit ChannelType = 1 << 7 + + // ScidAliasChanBit indicates that the channel has negotiated the + // scid-alias channel type. + ScidAliasChanBit ChannelType = 1 << 8 + + // ScidAliasFeatureBit indicates that the scid-alias feature bit was + // negotiated during the lifetime of this channel. + ScidAliasFeatureBit ChannelType = 1 << 9 + + // SimpleTaprootFeatureBit indicates that the simple-taproot-chans + // feature bit was negotiated during the lifetime of the channel. + SimpleTaprootFeatureBit ChannelType = 1 << 10 + + // TapscriptRootBit indicates that this is a MuSig2 channel with a top + // level tapscript commitment. This MUST be set along with the + // SimpleTaprootFeatureBit. + TapscriptRootBit ChannelType = 1 << 11 + + // TaprootFinalBit indicates that this is a MuSig2 channel using the + // final/production taproot scripts and feature bits 80/81. This MUST + // be set along with the SimpleTaprootFeatureBit. + TaprootFinalBit ChannelType = 1 << 12 +) + +// IsSingleFunder returns true if the channel type if one of the known single +// funder variants. +func (c ChannelType) IsSingleFunder() bool { + return c&DualFunderBit == 0 +} + +// IsDualFunder returns true if the ChannelType has the DualFunderBit set. +func (c ChannelType) IsDualFunder() bool { + return c&DualFunderBit == DualFunderBit +} + +// IsTweakless returns true if the target channel uses a commitment that +// doesn't tweak the key for the remote party. +func (c ChannelType) IsTweakless() bool { + return c&SingleFunderTweaklessBit == SingleFunderTweaklessBit +} + +// HasFundingTx returns true if this channel type is one that has a funding +// transaction stored locally. +func (c ChannelType) HasFundingTx() bool { + return c&NoFundingTxBit == 0 +} + +// HasAnchors returns true if this channel type has anchor outputs on its +// commitment. +func (c ChannelType) HasAnchors() bool { + return c&AnchorOutputsBit == AnchorOutputsBit +} + +// ZeroHtlcTxFee returns true if this channel type uses second-level HTLC +// transactions signed with zero-fee. +func (c ChannelType) ZeroHtlcTxFee() bool { + return c&ZeroHtlcTxFeeBit == ZeroHtlcTxFeeBit +} + +// IsFrozen returns true if the channel is considered to be "frozen". A frozen +// channel means that only the responder can initiate a cooperative channel +// closure. +func (c ChannelType) IsFrozen() bool { + return c&FrozenBit == FrozenBit +} + +// HasLeaseExpiration returns true if the channel originated from a lease. +func (c ChannelType) HasLeaseExpiration() bool { + return c&LeaseExpirationBit == LeaseExpirationBit +} + +// HasZeroConf returns true if the channel is a zero-conf channel. +func (c ChannelType) HasZeroConf() bool { + return c&ZeroConfBit == ZeroConfBit +} + +// HasScidAliasChan returns true if the scid-alias channel type was negotiated. +func (c ChannelType) HasScidAliasChan() bool { + return c&ScidAliasChanBit == ScidAliasChanBit +} + +// HasScidAliasFeature returns true if the scid-alias feature bit was +// negotiated during the lifetime of this channel. +func (c ChannelType) HasScidAliasFeature() bool { + return c&ScidAliasFeatureBit == ScidAliasFeatureBit +} + +// IsTaproot returns true if the channel is using taproot features. +func (c ChannelType) IsTaproot() bool { + return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit +} + +// HasTapscriptRoot returns true if the channel is using a top level tapscript +// root commitment. +func (c ChannelType) HasTapscriptRoot() bool { + return c&TapscriptRootBit == TapscriptRootBit +} + +// IsTaprootFinal returns true if the channel is using final/production taproot +// scripts and feature bits. +func (c ChannelType) IsTaprootFinal() bool { + return c&TaprootFinalBit == TaprootFinalBit +} diff --git a/chanstate/close_summary.go b/chanstate/close_summary.go new file mode 100644 index 00000000000..779a4c638f4 --- /dev/null +++ b/chanstate/close_summary.go @@ -0,0 +1,126 @@ +package chanstate + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ClosureType is an enum like structure that details exactly _how_ a channel +// was closed. Three closure types are currently possible: none, cooperative, +// local force close, remote force close, and (remote) breach. +type ClosureType uint8 + +const ( + // CooperativeClose indicates that a channel has been closed + // cooperatively. This means that both channel peers were online and + // signed a new transaction paying out the settled balance of the + // contract. + CooperativeClose ClosureType = 0 + + // LocalForceClose indicates that we have unilaterally broadcast our + // current commitment state on-chain. + LocalForceClose ClosureType = 1 + + // RemoteForceClose indicates that the remote peer has unilaterally + // broadcast their current commitment state on-chain. + RemoteForceClose ClosureType = 4 + + // BreachClose indicates that the remote peer attempted to broadcast a + // prior _revoked_ channel state. + BreachClose ClosureType = 2 + + // FundingCanceled indicates that the channel never was fully opened + // before it was marked as closed in the database. This can happen if + // we or the remote fail at some point during the opening workflow, or + // we timeout waiting for the funding transaction to be confirmed. + FundingCanceled ClosureType = 3 + + // Abandoned indicates that the channel state was removed without + // any further actions. This is intended to clean up unusable + // channels during development. + Abandoned ClosureType = 5 +) + +// ChannelCloseSummary contains the final state of a channel at the point it +// was closed. Once a channel is closed, all the information pertaining to that +// channel within the openChannelBucket is deleted, and a compact summary is +// put in place instead. +type ChannelCloseSummary struct { + // ChanPoint is the outpoint for this channel's funding transaction, + // and is used as a unique identifier for the channel. + ChanPoint wire.OutPoint + + // ShortChanID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChanID lnwire.ShortChannelID + + // ChainHash is the hash of the genesis block that this channel resides + // within. + ChainHash chainhash.Hash + + // ClosingTXID is the txid of the transaction which ultimately closed + // this channel. + ClosingTXID chainhash.Hash + + // RemotePub is the public key of the remote peer that we formerly had + // a channel with. + RemotePub *btcec.PublicKey + + // Capacity was the total capacity of the channel. + Capacity btcutil.Amount + + // CloseHeight is the height at which the funding transaction was + // spent. + CloseHeight uint32 + + // SettledBalance is our total balance settled balance at the time of + // channel closure. This _does not_ include the sum of any outputs that + // have been time-locked as a result of the unilateral channel closure. + SettledBalance btcutil.Amount + + // TimeLockedBalance is the sum of all the time-locked outputs at the + // time of channel closure. If we triggered the force closure of this + // channel, then this value will be non-zero if our settled output is + // above the dust limit. If we were on the receiving side of a channel + // force closure, then this value will be non-zero if we had any + // outstanding outgoing HTLC's at the time of channel closure. + TimeLockedBalance btcutil.Amount + + // CloseType details exactly _how_ the channel was closed. Five closure + // types are possible: cooperative, local force, remote force, breach + // and funding canceled. + CloseType ClosureType + + // IsPending indicates whether this channel is in the 'pending close' + // state, which means the channel closing transaction has been + // confirmed, but not yet been fully resolved. In the case of a channel + // that has been cooperatively closed, it will go straight into the + // fully resolved state as soon as the closing transaction has been + // confirmed. However, for channels that have been force closed, they'll + // stay marked as "pending" until _all_ the pending funds have been + // swept. + IsPending bool + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this is the derived public + // key, we don't yet have the private key so we aren't yet able to + // verify that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // LocalChanConfig is the channel configuration for the local node. + LocalChanConfig ChannelConfig + + // LastChanSyncMsg is the ChannelReestablish message for this channel + // for the state at the point where it was closed. + LastChanSyncMsg *lnwire.ChannelReestablish +} diff --git a/chanstate/codec.go b/chanstate/codec.go new file mode 100644 index 00000000000..12cf6bbef07 --- /dev/null +++ b/chanstate/codec.go @@ -0,0 +1,464 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" +) + +var byteOrder = binary.BigEndian + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { //nolint:funlen + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized "+ + "element: %w", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + case ChannelType: + var buf [8]byte + if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil { + return err + } + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return graphdb.WriteOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { + return err + } + + case ChannelStatus: + var buf [8]byte + if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil { + return err + } + + case ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case net.Addr: + if err := graphdb.SerializeAddr(w, e); err != nil { + return err + } + + case []net.Addr: + if err := WriteElement(w, uint32(len(e))); err != nil { + return err + } + + for _, addr := range e { + if err := graphdb.SerializeAddr(w, addr); err != nil { + return err + } + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { //nolint:funlen + switch e := element.(type) { + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + case *ChannelType: + var buf [8]byte + ctype, err := tlv.ReadVarInt(r, &buf) + if err != nil { + return err + } + + *e = ChannelType(ctype) + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return graphdb.ReadOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(b[:]) + *e = priv + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:]) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes( + root[:], + ) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) + if err != nil { + return err + } + + *e = msg + + case *ChannelStatus: + var buf [8]byte + status, err := tlv.ReadVarInt(r, &buf) + if err != nil { + return err + } + + *e = ChannelStatus(status) + + case *ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *net.Addr: + addr, err := graphdb.DeserializeAddr(r) + if err != nil { + return err + } + *e = addr + + case *[]net.Addr: + var numAddrs uint32 + if err := ReadElement(r, &numAddrs); err != nil { + return err + } + + *e = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := graphdb.DeserializeAddr(r) + if err != nil { + return err + } + (*e)[i] = addr + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + + return nil +} diff --git a/chanstate/commitment.go b/chanstate/commitment.go new file mode 100644 index 00000000000..0cb49e02313 --- /dev/null +++ b/chanstate/commitment.go @@ -0,0 +1,304 @@ +package chanstate + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// ChannelCommitment is a snapshot of the commitment state at a particular +// point in the commitment chain. With each state transition, a snapshot of the +// current state along with all non-settled HTLCs are recorded. These snapshots +// detail the state of the _remote_ party's commitment at a particular state +// number. For ourselves (the local node) we ONLY store our most recent +// (unrevoked) state for safety purposes. +type ChannelCommitment struct { + // CommitHeight is the update number that this ChannelDelta represents + // the total number of commitment updates to this point. This can be + // viewed as sort of a "commitment height" as this number is + // monotonically increasing. + CommitHeight uint64 + + // LocalLogIndex is the cumulative log index of the local node at this + // point in the commitment chain. This value will be incremented for + // each _update_ added to the local update log. + LocalLogIndex uint64 + + // LocalHtlcIndex is the current local running HTLC index. This value + // will be incremented for each outgoing HTLC the local node offers. + LocalHtlcIndex uint64 + + // RemoteLogIndex is the cumulative log index of the remote node at + // this point in the commitment chain. This value will be incremented + // for each _update_ added to the remote update log. + RemoteLogIndex uint64 + + // RemoteHtlcIndex is the current remote running HTLC index. This value + // will be incremented for each outgoing HTLC the remote node offers. + RemoteHtlcIndex uint64 + + // LocalBalance is the current available settled balance within the + // channel directly spendable by us. + // + // NOTE: This is the balance *after* subtracting any commitment fee, + // AND anchor output values. + LocalBalance lnwire.MilliSatoshi + + // RemoteBalance is the current available settled balance within the + // channel directly spendable by the remote node. + // + // NOTE: This is the balance *after* subtracting any commitment fee, + // AND anchor output values. + RemoteBalance lnwire.MilliSatoshi + + // CommitFee is the amount calculated to be paid in fees for the + // current set of commitment transactions. The fee amount is persisted + // with the channel in order to allow the fee amount to be removed and + // recalculated with each channel state update, including updates that + // happen after a system restart. + CommitFee btcutil.Amount + + // FeePerKw is the min satoshis/kilo-weight that should be paid within + // the commitment transaction for the entire duration of the channel's + // lifetime. This field may be updated during normal operation of the + // channel as on-chain conditions change. + // + // TODO(halseth): make this SatPerKWeight. Cannot be done atm because + // this will cause the import cycle lnwallet<->channeldb. Fee + // estimation stuff should be in its own package. + FeePerKw btcutil.Amount + + // CommitTx is the latest version of the commitment state, broadcast + // able by us. + CommitTx *wire.MsgTx + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This may track some custom + // specific state for this given commitment. + CustomBlob fn.Option[tlv.Blob] + + // CommitSig is one half of the signature required to fully complete + // the script for the commitment transaction above. This is the + // signature signed by the remote party for our version of the + // commitment transactions. + CommitSig []byte + + // Htlcs is the set of HTLC's that are pending at this particular + // commitment height. + Htlcs []HTLC +} + +// Copy returns a deep copy of the channel commitment. +func (c *ChannelCommitment) Copy() ChannelCommitment { + c2 := *c + if c.CommitTx != nil { + c2.CommitTx = c.CommitTx.Copy() + } + if len(c.CommitSig) > 0 { + c2.CommitSig = make([]byte, len(c.CommitSig)) + copy(c2.CommitSig, c.CommitSig) + } + + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + c2.CustomBlob = fn.Some(blobCopy) + }) + + if len(c.Htlcs) > 0 { + c2.Htlcs = make([]HTLC, len(c.Htlcs)) + for i, h := range c.Htlcs { + c2.Htlcs[i] = h.Copy() + } + } + + return c2 +} + +// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are +// contained within ChannelDeltas which encode the current state of the +// commitment between state updates. +// +// TODO(roasbeef): save space by using smaller ints at tail end? +type HTLC struct { + // TODO(yy): can embed an HTLCEntry here. + + // Signature is the signature for the second level covenant transaction + // for this HTLC. The second level transaction is a timeout tx in the + // case that this is an outgoing HTLC, and a success tx in the case + // that this is an incoming HTLC. + // + // TODO(roasbeef): make [64]byte instead? + Signature []byte + + // RHash is the payment hash of the HTLC. + RHash [32]byte + + // Amt is the amount of milli-satoshis this HTLC escrows. + Amt lnwire.MilliSatoshi + + // RefundTimeout is the absolute timeout on the HTLC that the sender + // must wait before reclaiming the funds in limbo. + RefundTimeout uint32 + + // OutputIndex is the output index for this particular HTLC output + // within the commitment transaction. + OutputIndex int32 + + // Incoming denotes whether we're the receiver or the sender of this + // HTLC. + Incoming bool + + // OnionBlob is an opaque blob which is used to complete multi-hop + // routing. + OnionBlob [lnwire.OnionPacketSize]byte + + // HtlcIndex is the HTLC counter index of this active, outstanding + // HTLC. This differs from the LogIndex, as the HtlcIndex is only + // incremented for each offered HTLC, while they LogIndex is + // incremented for each update (includes settle+fail). + HtlcIndex uint64 + + // LogIndex is the cumulative log index of this HTLC. This differs + // from the HtlcIndex as this will be incremented for each new log + // update added. + LogIndex uint64 + + // ExtraData contains any additional information that was transmitted + // with the HTLC via TLVs. This data *must* already be encoded as a + // TLV stream, and may be empty. The length of this data is naturally + // limited by the space available to TLVs in update_add_htlc: + // = 65535 bytes (bolt 8 maximum message size): + // - 2 bytes (bolt 1 message_type) + // - 32 bytes (channel_id) + // - 8 bytes (id) + // - 8 bytes (amount_msat) + // - 32 bytes (payment_hash) + // - 4 bytes (cltv_expiry) + // - 1366 bytes (onion_routing_packet) + // = 64083 bytes maximum possible TLV stream + // + // Note that this extra data is stored inline with the OnionBlob for + // legacy reasons, see serialization/deserialization functions for + // detail. + ExtraData lnwire.ExtraOpaqueData + + // BlindingPoint is an optional blinding point included with the HTLC. + // + // Note: this field is not a part of on-disk representation of the + // HTLC. It is stored in the ExtraData field, which is used to store + // a TLV stream of additional information associated with the HTLC. + BlindingPoint lnwire.BlindingPointRecord + + // CustomRecords is a set of custom TLV records that are associated with + // this HTLC. These records are used to store additional information + // about the HTLC that is not part of the standard HTLC fields. This + // field is encoded within the ExtraData field. + CustomRecords lnwire.CustomRecords +} + +// Copy returns a full copy of the target HTLC. +func (h *HTLC) Copy() HTLC { + clone := HTLC{ + Incoming: h.Incoming, + Amt: h.Amt, + RefundTimeout: h.RefundTimeout, + OutputIndex: h.OutputIndex, + RHash: h.RHash, + OnionBlob: h.OnionBlob, + HtlcIndex: h.HtlcIndex, + LogIndex: h.LogIndex, + } + if len(h.Signature) > 0 { + clone.Signature = make([]byte, len(h.Signature)) + copy(clone.Signature, h.Signature) + } + if len(h.ExtraData) > 0 { + clone.ExtraData = make(lnwire.ExtraOpaqueData, len(h.ExtraData)) + copy(clone.ExtraData, h.ExtraData) + } + clone.BlindingPoint = h.BlindingPoint + if h.CustomRecords != nil { + clone.CustomRecords = make( + lnwire.CustomRecords, len(h.CustomRecords), + ) + for k, v := range h.CustomRecords { + clone.CustomRecords[k] = make([]byte, len(v)) + copy(clone.CustomRecords[k], v) + } + } + + return clone +} + +// LogUpdate represents a pending update to the remote commitment chain. The +// log update may be an add, fail, or settle entry. We maintain this data in +// order to be able to properly retransmit our proposed state if necessary. +type LogUpdate struct { + // LogIndex is the log index of this proposed commitment update entry. + LogIndex uint64 + + // UpdateMsg is the update message that was included within our + // local update log. The LogIndex value denotes the log index of this + // update which will be used when restoring our local update log if + // we're left with a dangling update on restart. + UpdateMsg lnwire.Message +} + +// CommitDiff represents the delta needed to apply the state transition between +// two subsequent commitment states. Given state N and state N+1, one is able +// to apply the set of messages contained within the CommitDiff to N to arrive +// at state N+1. Each time a new commitment is extended, we'll write a new +// commitment (along with the full commitment state) to disk so we can +// re-transmit the state in the case of a connection loss or message drop. +type CommitDiff struct { + // ChannelCommitment is the full commitment state that one would arrive + // at by applying the set of messages contained in the UpdateDiff to + // the prior accepted commitment. + Commitment ChannelCommitment + + // LogUpdates is the set of messages sent prior to the commitment state + // transition in question. Upon reconnection, if we detect that they + // don't have the commitment, then we re-send this along with the + // proper signature. + LogUpdates []LogUpdate + + // CommitSig is the exact CommitSig message that should be sent after + // the set of LogUpdates above has been retransmitted. The signatures + // within this message should properly cover the new commitment state + // and also the HTLC's within the new commitment state. + CommitSig *lnwire.CommitSig + + // OpenedCircuitKeys is a set of unique identifiers for any downstream + // Add packets included in this commitment txn. After a restart, this + // set of htlcs is acked from the link's incoming mailbox to ensure + // there isn't an attempt to re-add them to this commitment txn. + OpenedCircuitKeys []models.CircuitKey + + // ClosedCircuitKeys records the unique identifiers for any settle/fail + // packets that were resolved by this commitment txn. After a restart, + // this is used to ensure those circuits are removed from the circuit + // map, and the downstream packets in the link's mailbox are removed. + ClosedCircuitKeys []models.CircuitKey + + // AddAcks specifies the locations (commit height, pkg index) of any + // Adds that were failed/settled in this commit diff. This will ack + // entries in *this* channel's forwarding packages. + // + // NOTE: This value is not serialized, it is used to atomically mark the + // resolution of adds, such that they will not be reprocessed after a + // restart. + AddAcks []AddRef + + // SettleFailAcks specifies the locations (chan id, commit height, pkg + // index) of any Settles or Fails that were locked into this commit + // diff, and originate from *another* channel, i.e. the outgoing link. + // + // NOTE: This value is not serialized, it is used to atomically acks + // settles and fails from the forwarding packages of other channels, + // such that they will not be reforwarded internally after a restart. + SettleFailAcks []SettleFailRef +} diff --git a/chanstate/commitment_test.go b/chanstate/commitment_test.go new file mode 100644 index 00000000000..ab1794d7712 --- /dev/null +++ b/chanstate/commitment_test.go @@ -0,0 +1,66 @@ +package chanstate + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +func TestHTLCCopy(t *testing.T) { + t.Parallel() + + _, blindingPoint := btcec.PrivKeyFromBytes(bytes.Repeat([]byte{1}, 32)) + + var rHash [32]byte + copy(rHash[:], bytes.Repeat([]byte{2}, len(rHash))) + + var onionBlob [lnwire.OnionPacketSize]byte + copy(onionBlob[:], bytes.Repeat([]byte{3}, len(onionBlob))) + + htlc := HTLC{ + Signature: []byte{4, 5, 6}, + RHash: rHash, + Amt: 1000, + RefundTimeout: 144, + OutputIndex: 3, + Incoming: true, + OnionBlob: onionBlob, + HtlcIndex: 42, + LogIndex: 43, + ExtraData: lnwire.ExtraOpaqueData{7, 8, 9}, + BlindingPoint: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + blindingPoint, + ), + ), + CustomRecords: lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType: []byte{10, 11, 12}, + }, + } + + clone := htlc.Copy() + require.Equal(t, htlc, clone) + + clone.Signature[0] = 0 + require.Equal(t, byte(4), htlc.Signature[0]) + + clone.ExtraData[0] = 0 + require.Equal(t, byte(7), htlc.ExtraData[0]) + + clone.CustomRecords[lnwire.MinCustomRecordsTlvType] = []byte{0} + require.Equal( + t, []byte{10, 11, 12}, + htlc.CustomRecords[lnwire.MinCustomRecordsTlvType], + ) + + clone = htlc.Copy() + clone.CustomRecords[lnwire.MinCustomRecordsTlvType][0] = 0 + require.Equal( + t, []byte{10, 11, 12}, + htlc.CustomRecords[lnwire.MinCustomRecordsTlvType], + ) +} diff --git a/chanstate/config.go b/chanstate/config.go new file mode 100644 index 00000000000..17e3e5e4fae --- /dev/null +++ b/chanstate/config.go @@ -0,0 +1,108 @@ +package chanstate + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel +// that are responsible for providing bounds on the state space of the abstract +// channel state. These values must be remembered for normal channel operation +// but they do not impact how we compute the commitment transactions themselves. +type ChannelStateBounds struct { + // ChanReserve is an absolute reservation on the channel for the + // owner of this set of constraints. This means that the current + // settled balance for this node CANNOT dip below the reservation + // amount. This acts as a defense against costless attacks when + // either side no longer has any skin in the game. + ChanReserve btcutil.Amount + + // MaxPendingAmount is the maximum pending HTLC value that the + // owner of these constraints can offer the remote node at a + // particular time. + MaxPendingAmount lnwire.MilliSatoshi + + // MinHTLC is the minimum HTLC value that the owner of these + // constraints can offer the remote node. If any HTLCs below this + // amount are offered, then the HTLC will be rejected. This, in + // tandem with the dust limit allows a node to regulate the + // smallest HTLC that it deems economically relevant. + MinHTLC lnwire.MilliSatoshi + + // MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of + // this set of constraints can offer the remote node. This allows each + // node to limit their over all exposure to HTLCs that may need to be + // acted upon in the case of a unilateral channel closure or a contract + // breach. + MaxAcceptedHtlcs uint16 +} + +// CommitmentParams are the parameters from OpenChannel and +// AcceptChannel that are required to render an abstract channel state to a +// concrete commitment transaction. These values are necessary to (re)compute +// the commitment transaction. We treat these differently than the state space +// bounds because their history needs to be stored in order to properly handle +// chain resolution. +type CommitmentParams struct { + // DustLimit is the threshold (in satoshis) below which any outputs + // should be trimmed. When an output is trimmed, it isn't materialized + // as an actual output, but is instead burned to miner's fees. + DustLimit btcutil.Amount + + // CsvDelay is the relative time lock delay expressed in blocks. Any + // settled outputs that pay to the owner of this channel configuration + // MUST ensure that the delay branch uses this value as the relative + // time lock. Similarly, any HTLC's offered by this node should use + // this value as well. + CsvDelay uint16 +} + +// ChannelConfig is a struct that houses the various configuration opens for +// channels. Each side maintains an instance of this configuration file as it +// governs: how the funding and commitment transaction to be created, the +// nature of HTLC's allotted, the keys to be used for delivery, and relative +// time lock parameters. +type ChannelConfig struct { + // ChannelStateBounds is the set of constraints that must be + // upheld for the duration of the channel for the owner of this channel + // configuration. Constraints govern a number of flow control related + // parameters, also including the smallest HTLC that will be accepted + // by a participant. + ChannelStateBounds + + // CommitmentParams is an embedding of the parameters + // required to render an abstract channel state into a concrete + // commitment transaction. + CommitmentParams + + // MultiSigKey is the key to be used within the 2-of-2 output script + // for the owner of this channel config. + MultiSigKey keychain.KeyDescriptor + + // RevocationBasePoint is the base public key to be used when deriving + // revocation keys for the remote node's commitment transaction. This + // will be combined along with a per commitment secret to derive a + // unique revocation key for each state. + RevocationBasePoint keychain.KeyDescriptor + + // PaymentBasePoint is the base public key to be used when deriving + // the key used within the non-delayed pay-to-self output on the + // commitment transaction for a node. This will be combined with a + // tweak derived from the per-commitment point to ensure unique keys + // for each commitment transaction. + PaymentBasePoint keychain.KeyDescriptor + + // DelayBasePoint is the base public key to be used when deriving the + // key used within the delayed pay-to-self output on the commitment + // transaction for a node. This will be combined with a tweak derived + // from the per-commitment point to ensure unique keys for each + // commitment transaction. + DelayBasePoint keychain.KeyDescriptor + + // HtlcBasePoint is the base public key to be used when deriving the + // local HTLC key. The derived key (combined with the tweak derived + // from the per-commitment point) is used within the "to self" clause + // within any HTLC output scripts. + HtlcBasePoint keychain.KeyDescriptor +} diff --git a/chanstate/errors.go b/chanstate/errors.go new file mode 100644 index 00000000000..43cf787112c --- /dev/null +++ b/chanstate/errors.go @@ -0,0 +1,90 @@ +package chanstate + +import ( + "errors" + "fmt" +) + +var ( + // ErrNoChanDBExists is returned when a channel bucket hasn't been + // created. + ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") + + // ErrNoCommitmentsFound is returned when a channel has not set + // commitment states. + ErrNoCommitmentsFound = fmt.Errorf("no commitments found") + + // ErrNoChanInfoFound is returned when a particular channel does not + // have any channels state. + ErrNoChanInfoFound = fmt.Errorf("no chan info found") + + // ErrChannelNotFound is returned when we attempt to locate a channel + // for a specific chain, but it is not found. + ErrChannelNotFound = fmt.Errorf("channel not found") + + // ErrChanAlreadyExists is return when the caller attempts to create a + // channel with a channel point that is already present in the + // database. + ErrChanAlreadyExists = fmt.Errorf("channel already exists") + + // ErrNoRevocationsFound is returned when revocation state for a + // particular channel cannot be found. + ErrNoRevocationsFound = fmt.Errorf("no revocations found") + + // ErrNoPendingCommit is returned when there is not a pending + // commitment for a remote party. A new commitment is written to disk + // each time we write a new state in order to be properly fault + // tolerant. + ErrNoPendingCommit = fmt.Errorf("no pending commits found") + + // ErrNoActiveChannels is returned when there is no active (open) + // channels within the database. + ErrNoActiveChannels = fmt.Errorf("no active channels exist") + + // ErrNoHistoricalBucket is returned when the historical channel + // bucket not been created yet. + ErrNoHistoricalBucket = fmt.Errorf("historical channel bucket has " + + "not yet been created") + + // ErrNoClosedChannels is returned when a node is queries for all the + // channels it has closed, but it hasn't yet closed any channels. + ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") + + // ErrClosedChannelNotFound signals that a closed channel could not be + // found in the channel state store. + ErrClosedChannelNotFound = errors.New("unable to find closed " + + "channel summary") + + // ErrNoPastDeltas is returned when the channel delta bucket hasn't + // been created. + ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") + + // ErrNoCommitPoint is returned when no data loss commit point is found + // in the database. + ErrNoCommitPoint = fmt.Errorf("no commit point found") + + // ErrNoCloseTx is returned when no closing tx is found for a channel + // in the state CommitBroadcasted. + ErrNoCloseTx = fmt.Errorf("no closing tx found") + + // ErrNoShutdownInfo is returned when no shutdown info has been + // persisted for a channel. + ErrNoShutdownInfo = errors.New("no shutdown info") + + // ErrNoRestoredChannelMutation is returned when a caller attempts to + // mutate a channel that's been recovered. + ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + + "channel state") + + // ErrChanBorked is returned when a caller attempts to mutate a borked + // channel. + ErrChanBorked = fmt.Errorf("cannot mutate borked channel") + + // ErrMissingIndexEntry is returned when a caller attempts to close a + // channel and the outpoint is missing from the index. + ErrMissingIndexEntry = fmt.Errorf("missing outpoint from index") + + // ErrOnionBlobLength is returned is an onion blob with incorrect + // length is read from disk. + ErrOnionBlobLength = errors.New("onion blob < 1366 bytes") +) diff --git a/chanstate/forwarding.go b/chanstate/forwarding.go new file mode 100644 index 00000000000..dc101ef76b1 --- /dev/null +++ b/chanstate/forwarding.go @@ -0,0 +1,259 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID +// is assumed to be that of the packager. +type AddRef struct { + // Height is the remote commitment height that locked in the Add. + Height uint64 + + // Index is the index of the Add within the fwd pkg's Adds. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A +// channel does not remove its own Settle/Fail htlcs, so the source is provided +// to locate a db bucket belonging to another channel. +type SettleFailRef struct { + // Source identifies the outgoing link that locked in the settle or + // fail. This is then used by the *incoming* link to find the settle + // fail in another link's forwarding packages. + Source lnwire.ShortChannelID + + // Height is the remote commitment height that locked in this + // Settle/Fail. + Height uint64 + + // Index is the index of the Add with the fwd pkg's SettleFails. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// FwdState is an enum used to describe the lifecycle of a FwdPkg. +type FwdState byte + +const ( + // FwdStateLockedIn is the starting state for all forwarding packages. + // Packages in this state have not yet committed to the exact set of + // Adds to forward to the switch. + FwdStateLockedIn FwdState = iota + + // FwdStateProcessed marks the state in which all Adds have been + // locally processed and the forwarding decision to the switch has been + // persisted. + FwdStateProcessed + + // FwdStateCompleted signals that all Adds have been acked, and that all + // settles and fails have been delivered to their sources. Packages in + // this state can be removed permanently. + FwdStateCompleted +) + +// PkgFilter is used to compactly represent a particular subset of the Adds in a +// forwarding package. Each filter is represented as a simple, statically-sized +// bitvector, where the elements are intended to be the indices of the Adds as +// they are written in the FwdPkg. +type PkgFilter struct { + count uint16 + filter []byte +} + +// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. +func NewPkgFilter(count uint16) *PkgFilter { + // We add 7 to ensure that the integer division yields properly rounded + // values. + filterLen := (count + 7) / 8 + + return &PkgFilter{ + count: count, + filter: make([]byte, filterLen), + } +} + +// Count returns the number of elements represented by this PkgFilter. +func (f *PkgFilter) Count() uint16 { + return f.count +} + +// Set marks the `i`-th element as included by this filter. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Set(i uint16) { + byt := i / 8 + bit := i % 8 + + // Set the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + f.filter[byt] |= byte(1 << (7 - bit)) +} + +// Contains queries the filter for membership of index `i`. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Contains(i uint16) bool { + byt := i / 8 + bit := i % 8 + + // Read the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + return f.filter[byt]&(1<<(7-bit)) != 0 +} + +// Equal checks two PkgFilters for equality. +func (f *PkgFilter) Equal(f2 *PkgFilter) bool { + if f == f2 { + return true + } + if f.count != f2.count { + return false + } + + return bytes.Equal(f.filter, f2.filter) +} + +// IsFull returns true if every element in the filter has been Set, and false +// otherwise. +func (f *PkgFilter) IsFull() bool { + // Batch validate bytes that are fully used. + for i := uint16(0); i < f.count/8; i++ { + if f.filter[i] != 0xFF { + return false + } + } + + // If the count is not a multiple of 8, check that the filter contains + // all remaining bits. + rem := f.count % 8 + for idx := f.count - rem; idx < f.count; idx++ { + if !f.Contains(idx) { + return false + } + } + + return true +} + +// String returns a human-readable string. +func (f *PkgFilter) String() string { + return fmt.Sprintf("count=%v, filter=%v", f.count, f.filter) +} + +// FwdPkg records all adds, settles, and fails that were locked in as a result +// of the remote peer sending us a revocation. Each package is identified by +// the short chanid and remote commitment height corresponding to the revocation +// that locked in the HTLCs. For everything except a locally initiated payment, +// settles and fails in a forwarding package must have a corresponding Add in +// another package, and can be removed individually once the source link has +// received the fail/settle. +// +// Adds cannot be removed, as we need to present the same batch of Adds to +// properly handle replay protection. Instead, we use a PkgFilter to mark that +// we have finished processing a particular Add. A FwdPkg should only be deleted +// after the AckFilter is full and all settles and fails have been persistently +// removed. +type FwdPkg struct { + // Source identifies the channel that wrote this forwarding package. + Source lnwire.ShortChannelID + + // Height is the height of the remote commitment chain that locked in + // this forwarding package. + Height uint64 + + // State signals the persistent condition of the package and directs how + // to reprocess the package in the event of failures. + State FwdState + + // Adds contains all add messages which need to be processed and + // forwarded to the switch. Adds does not change over the life of a + // forwarding package. + Adds []LogUpdate + + // FwdFilter is a filter containing the indices of all Adds that were + // forwarded to the switch. + // + // NOTE: This value signals when persisted to disk that the fwd package + // has been processed and garbage collection can happen. So it also + // has to be set for packages with no adds (empty packages or only + // settle/fail packages) so that they can be garbage collected as well. + FwdFilter *PkgFilter + + // AckFilter is a filter containing the indices of all Adds for which + // the source has received a settle or fail and is reflected in the next + // commitment txn. A package should not be removed until IsFull() + // returns true. + AckFilter *PkgFilter + + // SettleFails contains all settle and fail messages that should be + // forwarded to the switch. + SettleFails []LogUpdate + + // SettleFailFilter is a filter containing the indices of all Settle or + // Fails originating in this package that have been received and locked + // into the incoming link's commitment state. + SettleFailFilter *PkgFilter +} + +// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This +// should be used to create a package at the time we receive a revocation. +func NewFwdPkg(source lnwire.ShortChannelID, height uint64, + addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { + + nAddUpdates := uint16(len(addUpdates)) + nSettleFailUpdates := uint16(len(settleFailUpdates)) + + return &FwdPkg{ + Source: source, + Height: height, + State: FwdStateLockedIn, + Adds: addUpdates, + FwdFilter: NewPkgFilter(nAddUpdates), + AckFilter: NewPkgFilter(nAddUpdates), + SettleFails: settleFailUpdates, + SettleFailFilter: NewPkgFilter(nSettleFailUpdates), + } +} + +// SourceRef is a convenience method that returns an AddRef to this forwarding +// package for the index in the argument. It is the caller's responsibility +// to ensure that the index is in bounds. +func (f *FwdPkg) SourceRef(i uint16) AddRef { + return AddRef{ + Height: f.Height, + Index: i, + } +} + +// DestRef is a convenience method that returns a SettleFailRef to this +// forwarding package for the index in the argument. It is the caller's +// responsibility to ensure that the index is in bounds. +func (f *FwdPkg) DestRef(i uint16) SettleFailRef { + return SettleFailRef{ + Source: f.Source, + Height: f.Height, + Index: i, + } +} + +// ID returns an unique identifier for this package, used to ensure that sphinx +// replay processing of this batch is idempotent. +func (f *FwdPkg) ID() []byte { + var id = make([]byte, 16) + binary.BigEndian.PutUint64(id[:8], f.Source.ToUint64()) + binary.BigEndian.PutUint64(id[8:], f.Height) + + return id +} + +// String returns a human-readable description of the forwarding package. +func (f *FwdPkg) String() string { + return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", + f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) +} diff --git a/chanstate/interface.go b/chanstate/interface.go index 49344855128..a75a354fcae 100644 --- a/chanstate/interface.go +++ b/chanstate/interface.go @@ -1,11 +1,15 @@ package chanstate import ( + "net" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" ) // Store is the full persistence contract for the channel-state subsystem. @@ -24,6 +28,28 @@ type Store interface { // HistoricalChannelStore owns the post-close historical channel view. HistoricalChannelStore + // OpenChannelLifecycleStore owns persisted lifecycle state for open + // channel records. + OpenChannelLifecycleStore + + // OpenChannelStatusStore owns persisted status flags for open channel + // records. + OpenChannelStatusStore + + // OpenChannelShutdownStore owns persisted shutdown state. + OpenChannelShutdownStore + + // OpenChannelCloseTxStore owns persisted closing transaction state. + OpenChannelCloseTxStore + + // OpenChannelCommitmentStore owns persisted commitment state for open + // channel records. + OpenChannelCommitmentStore + + // OpenChannelFwdPkgStore owns forwarding packages tied to open + // channel records. + OpenChannelFwdPkgStore + // ClosedChannelStore owns closed-channel summaries and lifecycle // mutations. ClosedChannelStore @@ -47,48 +73,46 @@ type OpenChannelStore interface { // target nodeID. In the case that no active channels are known to // have been created with this node, then a zero-length slice is // returned. - FetchOpenChannels(nodeID *btcec.PublicKey) ( - []*channeldb.OpenChannel, error) + FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) // FetchChannel attempts to locate a channel specified by the passed // channel point. If the channel cannot be found, then an error will // be returned. - FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) + FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) // FetchChannelByID attempts to locate a channel specified by the // passed channel ID. If the channel cannot be found, then an error // will be returned. - FetchChannelByID(id lnwire.ChannelID) (*channeldb.OpenChannel, error) + FetchChannelByID(id lnwire.ChannelID) (*OpenChannel, error) // FetchAllChannels attempts to retrieve all open channels currently // stored within the database, including pending open, fully open and // channels waiting for a closing transaction to confirm. - FetchAllChannels() ([]*channeldb.OpenChannel, error) + FetchAllChannels() ([]*OpenChannel, error) // FetchAllOpenChannels will return all channels that have the // funding transaction confirmed, and is not waiting for a closing // transaction to be confirmed. - FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) + FetchAllOpenChannels() ([]*OpenChannel, error) // FetchPendingChannels will return channels that have completed the // process of generating and broadcasting funding transactions, but // whose funding transactions have yet to be confirmed on the // blockchain. - FetchPendingChannels() ([]*channeldb.OpenChannel, error) + FetchPendingChannels() ([]*OpenChannel, error) // FetchWaitingCloseChannels will return all channels that have been // opened, but are now waiting for a closing transaction to be // confirmed. // // NOTE: This includes channels that are also pending to be opened. - FetchWaitingCloseChannels() ([]*channeldb.OpenChannel, error) + FetchWaitingCloseChannels() ([]*OpenChannel, error) // FetchPermAndTempPeers returns a map where the key is the remote // node's public key and the value is a struct that has a tally of // the pending-open channels and whether the peer has an open or // closed channel with us. - FetchPermAndTempPeers(chainHash []byte) ( - map[string]channeldb.ChanCount, error) + FetchPermAndTempPeers(chainHash []byte) (map[string]ChanCount, error) // RestoreChannelShells reconstructs the state of an OpenChannel from // the ChannelShell. We'll attempt to write the new channel to disk, @@ -96,15 +120,205 @@ type OpenChannelStore interface { // finally create an edge within the graph for the channel as well. // This method is idempotent, so repeated calls with the same set of // channel shells won't modify the database after the initial call. - RestoreChannelShells(channelShells ...*channeldb.ChannelShell) error + RestoreChannelShells(channelShells ...*ChannelShell) error } // HistoricalChannelStore owns the post-close historical channel view. type HistoricalChannelStore interface { // FetchHistoricalChannel fetches open channel data from the // historical channel bucket. - FetchHistoricalChannel(outPoint *wire.OutPoint) ( - *channeldb.OpenChannel, error) + FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) +} + +// OpenChannelLifecycleStore owns persisted lifecycle state for open channel +// records. +type OpenChannelLifecycleStore interface { + // SyncPendingChannel writes a pending channel to the store and records + // the funding broadcast height. + SyncPendingChannel(channel *OpenChannel, addr net.Addr, + pendingHeight uint32) error + + // RefreshChannel updates the in-memory channel state using the latest + // state observed on disk. + RefreshChannel(channel *OpenChannel) error + + // MarkChannelConfirmationHeight updates the channel's confirmation + // height once the channel opening transaction receives one + // confirmation. + MarkChannelConfirmationHeight(channel *OpenChannel, height uint32) error + + // MarkChannelCloseConfirmationHeight updates the channel's close + // confirmation height when the closing transaction is first detected + // in a block. + MarkChannelCloseConfirmationHeight(channel *OpenChannel, + height fn.Option[uint32]) error + + // MarkChannelOpen marks a channel as fully open given a locator that + // uniquely describes its location within the chain. + MarkChannelOpen(channel *OpenChannel, + openLoc lnwire.ShortChannelID) error + + // MarkChannelRealScid marks the zero-conf channel's confirmed + // ShortChannelID. + MarkChannelRealScid(channel *OpenChannel, + realScid lnwire.ShortChannelID) error + + // MarkChannelScidAliasNegotiated marks that the scid-alias feature + // bit was negotiated during the lifetime of the channel. + MarkChannelScidAliasNegotiated(channel *OpenChannel) error +} + +// OpenChannelStatusStore owns persisted status flags for open channel records. +type OpenChannelStatusStore interface { + // ApplyChannelStatus adds the target status to the channel's + // persisted status bit field. + ApplyChannelStatus(channel *OpenChannel, status ChannelStatus) error + + // ClearChannelStatus clears the target status from the channel's + // persisted status bit field. + ClearChannelStatus(channel *OpenChannel, status ChannelStatus) error + + // MarkChannelDataLoss marks the channel as local-data-loss and stores + // the commit point needed if the remote force closes. + MarkChannelDataLoss(channel *OpenChannel, + commitPoint *btcec.PublicKey) error + + // FetchChannelDataLossCommitPoint retrieves the commit point stored + // when the channel was marked as local-data-loss. + FetchChannelDataLossCommitPoint(channel *OpenChannel) ( + *btcec.PublicKey, error) + + // MarkChannelBorked marks the channel as irreconcilable. + MarkChannelBorked(channel *OpenChannel) error +} + +// OpenChannelShutdownStore owns persisted shutdown state. +type OpenChannelShutdownStore interface { + // StoreChannelShutdownInfo persists the ShutdownInfo for the target + // channel. + StoreChannelShutdownInfo(channel *OpenChannel, info *ShutdownInfo) error + + // FetchChannelShutdownInfo fetches the persisted ShutdownInfo for the + // target channel. + FetchChannelShutdownInfo(channel *OpenChannel) ( + fn.Option[ShutdownInfo], error) +} + +// OpenChannelCloseTxStore owns persisted closing transaction state. +type OpenChannelCloseTxStore interface { + // MarkChannelCommitmentBroadcasted marks the channel as having a + // commitment transaction broadcast. + MarkChannelCommitmentBroadcasted(channel *OpenChannel, + closeTx *wire.MsgTx, closer lntypes.ChannelParty) error + + // MarkChannelCoopBroadcasted marks the channel as having a + // cooperative close transaction broadcast. + MarkChannelCoopBroadcasted(channel *OpenChannel, closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error + + // FetchChannelBroadcastedCommitment fetches the stored unilateral + // closing transaction. + FetchChannelBroadcastedCommitment(channel *OpenChannel) (*wire.MsgTx, + error) + + // FetchChannelBroadcastedCooperative fetches the stored cooperative + // closing transaction. + FetchChannelBroadcastedCooperative(channel *OpenChannel) (*wire.MsgTx, + error) +} + +// OpenChannelCommitmentStore owns persisted commitment state for open channel +// records. +type OpenChannelCommitmentStore interface { + OpenChannelCommitmentMutationStore + OpenChannelCommitmentQueryStore +} + +// OpenChannelCommitmentMutationStore owns persisted commitment mutations for +// open channel records. +type OpenChannelCommitmentMutationStore interface { + // UpdateChannelCommitment updates the local commitment state. It + // locks in pending local updates received from the remote party and + // persists remote log updates that have been acked, but not signed + // for yet. The returned map contains all HTLC resolutions locked into + // this commitment, keyed by HTLC index. + UpdateChannelCommitment(channel *OpenChannel, + newCommitment *ChannelCommitment, + unsignedAckedUpdates []LogUpdate) (map[uint64]bool, error) + + // AppendRemoteCommitChain appends a new CommitDiff to the remote + // party's commitment chain. This is used after preparing a new remote + // commitment state, before transmitting it to the remote party. + AppendRemoteCommitChain(channel *OpenChannel, diff *CommitDiff) error + + // RemoteCommitChainTip returns the "tip" of the current remote + // commitment chain. + RemoteCommitChainTip(channel *OpenChannel) (*CommitDiff, error) + + // UnsignedAckedUpdates retrieves the persisted unsigned acked remote + // log updates that still need to be signed for. + UnsignedAckedUpdates(channel *OpenChannel) ([]LogUpdate, error) + + // RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local + // log updates that the remote still needs to sign for. + RemoteUnsignedLocalUpdates(channel *OpenChannel) ([]LogUpdate, error) + + // InsertNextRevocation inserts the next commitment point into the + // persisted channel state. + InsertNextRevocation(channel *OpenChannel, + revKey *btcec.PublicKey) error + + // AdvanceCommitChainTail records the new state transition within the + // revocation log and promotes the pending remote commitment to the + // current remote commitment. + AdvanceCommitChainTail(channel *OpenChannel, fwdPkg *FwdPkg, + updates []LogUpdate, ourOutputIndex, + theirOutputIndex uint32) error +} + +// OpenChannelCommitmentQueryStore owns persisted commitment queries for open +// channel records. +type OpenChannelCommitmentQueryStore interface { + // CommitmentHeight returns the current persisted commitment height. + CommitmentHeight(channel *OpenChannel) (uint64, error) + + // LatestCommitments returns the two latest commitments for both the + // local and remote party. + LatestCommitments(channel *OpenChannel) (*ChannelCommitment, + *ChannelCommitment, error) + + // RemoteRevocationStore returns the most up to date commitment version + // of the revocation storage tree for the remote party. + RemoteRevocationStore(channel *OpenChannel) (shachain.Store, error) + + // FindPreviousState scans through the append-only log in an attempt to + // recover the previous channel state indicated by the update number. + FindPreviousState(channel *OpenChannel, updateNum uint64) ( + *RevocationLog, *ChannelCommitment, error) +} + +// OpenChannelFwdPkgStore owns forwarding packages tied to open channel +// records. +type OpenChannelFwdPkgStore interface { + // LoadFwdPkgs loads forwarding packages that have not been processed. + LoadFwdPkgs(channel *OpenChannel) ([]*FwdPkg, error) + + // AckAddHtlcs marks add HTLCs in forwarding packages as resolved. + AckAddHtlcs(channel *OpenChannel, addRefs ...AddRef) error + + // AckSettleFails marks settles or fails as delivered to the incoming + // link. + AckSettleFails(channel *OpenChannel, + settleFailRefs ...SettleFailRef) error + + // SetFwdFilter writes the forwarding filter for the forwarding package + // identified by height. + SetFwdFilter(channel *OpenChannel, height uint64, + fwdFilter *PkgFilter) error + + // RemoveFwdPkgs removes forwarding packages by remote commitment + // height. + RemoveFwdPkgs(channel *OpenChannel, heights ...uint64) error } // ClosedChannelStore owns closed-channel summaries and lifecycle mutations. @@ -117,17 +331,17 @@ type ClosedChannelStore interface { // become fully closed after _all_ the pending funds (if any) have // been swept. FetchClosedChannels(pendingOnly bool) ( - []*channeldb.ChannelCloseSummary, error) + []*ChannelCloseSummary, error) // FetchClosedChannel queries for a channel close summary using the // channel point of the channel in question. FetchClosedChannel(chanID *wire.OutPoint) ( - *channeldb.ChannelCloseSummary, error) + *ChannelCloseSummary, error) // FetchClosedChannelForID queries for a channel close summary using // the channel ID of the channel in question. FetchClosedChannelForID(cid lnwire.ChannelID) ( - *channeldb.ChannelCloseSummary, error) + *ChannelCloseSummary, error) // MarkChanFullyClosed marks a channel as fully closed within the // database. A channel should be marked as fully closed if the @@ -142,9 +356,8 @@ type ClosedChannelStore interface { // FetchClosedChannel and FetchClosedChannelForID. Any ChannelStatus // values are merged into the archived summary. Returns // ErrChannelCloseSummaryNil if summary is nil. - CloseChannel(channel *channeldb.OpenChannel, - summary *channeldb.ChannelCloseSummary, - statuses ...channeldb.ChannelStatus) error + CloseChannel(channel *OpenChannel, summary *ChannelCloseSummary, + statuses ...ChannelStatus) error // AbandonChannel attempts to remove the target channel from the open // channel database. If the channel was already removed (has a closed @@ -159,7 +372,7 @@ type FinalHTLCStore interface { // database. If the htlc has no final resolution yet, ErrHtlcUnknown // is returned. LookupFinalHtlc(chanID lnwire.ShortChannelID, - htlcIndex uint64) (*channeldb.FinalHtlcInfo, error) + htlcIndex uint64) (*FinalHtlcInfo, error) // PutOnchainFinalHtlcOutcome stores the final on-chain outcome of an // htlc in the database. @@ -211,18 +424,3 @@ type LinkNodeMaintainer interface { // called on startup to ensure that our database is consistent. RepairLinkNodes(network wire.BitcoinNet) error } - -// Compile-time assertion that channeldb.ChannelStateDB satisfies the Store -// contract. If a method signature drifts on the concrete type, -// this assertion will fail to build before any consumer migration. -// -// NOTE: This assertion lives in the interface file as a temporary exception to -// the established pattern (see invoices/sql_store.go, payments/db/kv_store.go, -// graph/db/kv_store.go), where each implementation asserts itself in its own -// file. The implementation still lives in channeldb/, and channeldb must not -// import chanstate to avoid a cycle, so the assertion has no local -// implementation file to live in yet. When the KV implementation moves into -// this package (chanstate/kv_store.go), this assertion MUST be removed from -// here and re-stated next to the local implementation, matching the precedent -// packages. -var _ Store = (*channeldb.ChannelStateDB)(nil) diff --git a/chanstate/kv_channel_setup.go b/chanstate/kv_channel_setup.go new file mode 100644 index 00000000000..fc06232f9cc --- /dev/null +++ b/chanstate/kv_channel_setup.go @@ -0,0 +1,208 @@ +package chanstate + +import ( + graphmodels "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // channelOpeningStateBucket is the database bucket used to store the + // channelOpeningState for each channel that is currently in the process + // of being opened. + channelOpeningStateBucket = []byte("channelOpeningState") + + // initialChannelForwardingPolicyBucket is the database bucket used to + // store the forwarding policy for each permanent channel that is + // currently in the process of being opened. + initialChannelForwardingPolicyBucket = []byte( + "initialChannelFwdingPolicy", + ) +) + +// ChannelOpeningStateBucketKey returns the top-level bucket key used to store +// serialized channel opening state. +func ChannelOpeningStateBucketKey() []byte { + return channelOpeningStateBucket +} + +// InitialChannelForwardingPolicyBucketKey returns the top-level bucket key used +// to store initial channel forwarding policies. +func InitialChannelForwardingPolicyBucketKey() []byte { + return initialChannelForwardingPolicyBucket +} + +// SaveChannelOpeningState saves the serialized channel state for the provided +// chanPoint to the channelOpeningStateBucket. +func SaveChannelOpeningState(tx kvdb.RwTx, outPoint, + serializedState []byte) error { + + bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) + if err != nil { + return err + } + + return bucket.Put(outPoint, serializedState) +} + +// SaveChannelOpeningState saves the serialized channel state for the provided +// chanPoint to the channelOpeningStateBucket. +func (s *KVStore) SaveChannelOpeningState(outPoint, + serializedState []byte) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + return SaveChannelOpeningState(tx, outPoint, serializedState) + }, func() {}) +} + +// GetChannelOpeningState fetches the serialized channel state for the provided +// outPoint from the database, or returns ErrChannelNotFound if the channel is +// not found. +func GetChannelOpeningState(tx kvdb.RTx, outPoint []byte) ([]byte, error) { + bucket := tx.ReadBucket(channelOpeningStateBucket) + if bucket == nil { + // If the bucket does not exist, it means we never added + // a channel to the db, so return ErrChannelNotFound. + return nil, ErrChannelNotFound + } + + stateBytes := bucket.Get(outPoint) + if stateBytes == nil { + return nil, ErrChannelNotFound + } + + var serializedState []byte + serializedState = append(serializedState, stateBytes...) + + return serializedState, nil +} + +// GetChannelOpeningState fetches the serialized channel state for the provided +// outPoint from the database, or returns ErrChannelNotFound if the channel is +// not found. +func (s *KVStore) GetChannelOpeningState(outPoint []byte) ([]byte, error) { + var serializedState []byte + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + var err error + serializedState, err = GetChannelOpeningState(tx, outPoint) + + return err + }, func() { + serializedState = nil + }) + + return serializedState, err +} + +// DeleteChannelOpeningState removes any state for outPoint from the database. +func DeleteChannelOpeningState(tx kvdb.RwTx, outPoint []byte) error { + bucket := tx.ReadWriteBucket(channelOpeningStateBucket) + if bucket == nil { + return ErrChannelNotFound + } + + return bucket.Delete(outPoint) +} + +// DeleteChannelOpeningState removes any state for outPoint from the database. +func (s *KVStore) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + return DeleteChannelOpeningState(tx, outPoint) + }, func() {}) +} + +// SaveInitialForwardingPolicy saves the serialized forwarding policy for the +// provided permanent channel id to the initialChannelForwardingPolicyBucket. +func (s *KVStore) SaveInitialForwardingPolicy(chanID lnwire.ChannelID, + forwardingPolicy *graphmodels.ForwardingPolicy) error { + + chanIDCopy := make([]byte, 32) + copy(chanIDCopy, chanID[:]) + + scratch := make([]byte, 36) + byteOrder.PutUint64(scratch[:8], uint64(forwardingPolicy.MinHTLCOut)) + byteOrder.PutUint64(scratch[8:16], uint64(forwardingPolicy.MaxHTLC)) + byteOrder.PutUint64(scratch[16:24], uint64(forwardingPolicy.BaseFee)) + byteOrder.PutUint64(scratch[24:32], uint64(forwardingPolicy.FeeRate)) + byteOrder.PutUint32(scratch[32:], forwardingPolicy.TimeLockDelta) + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + bucket, err := tx.CreateTopLevelBucket( + initialChannelForwardingPolicyBucket, + ) + if err != nil { + return err + } + + return bucket.Put(chanIDCopy, scratch) + }, func() {}) +} + +// GetInitialForwardingPolicy fetches the serialized forwarding policy for the +// provided channel id from the database, or returns ErrChannelNotFound if +// a forwarding policy for this channel id is not found. +func (s *KVStore) GetInitialForwardingPolicy( + chanID lnwire.ChannelID) (*graphmodels.ForwardingPolicy, error) { + + chanIDCopy := make([]byte, 32) + copy(chanIDCopy, chanID[:]) + + var forwardingPolicy *graphmodels.ForwardingPolicy + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + bucket := tx.ReadBucket(initialChannelForwardingPolicyBucket) + if bucket == nil { + // If the bucket does not exist, it means we + // never added a channel fees to the db, so + // return ErrChannelNotFound. + return ErrChannelNotFound + } + + stateBytes := bucket.Get(chanIDCopy) + if stateBytes == nil { + return ErrChannelNotFound + } + + forwardingPolicy = &graphmodels.ForwardingPolicy{ + MinHTLCOut: lnwire.MilliSatoshi( + byteOrder.Uint64(stateBytes[:8]), + ), + MaxHTLC: lnwire.MilliSatoshi( + byteOrder.Uint64(stateBytes[8:16]), + ), + BaseFee: lnwire.MilliSatoshi( + byteOrder.Uint64(stateBytes[16:24]), + ), + FeeRate: lnwire.MilliSatoshi( + byteOrder.Uint64(stateBytes[24:32]), + ), + TimeLockDelta: byteOrder.Uint32(stateBytes[32:36]), + } + + return nil + }, func() { + forwardingPolicy = nil + }) + + return forwardingPolicy, err +} + +// DeleteInitialForwardingPolicy removes the forwarding policy for a given +// channel from the database. +func (s *KVStore) DeleteInitialForwardingPolicy( + chanID lnwire.ChannelID) error { + + chanIDCopy := make([]byte, 32) + copy(chanIDCopy, chanID[:]) + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + bucket := tx.ReadWriteBucket( + initialChannelForwardingPolicyBucket, + ) + if bucket == nil { + return ErrChannelNotFound + } + + return bucket.Delete(chanIDCopy) + }, func() {}) +} diff --git a/chanstate/kv_close_summary.go b/chanstate/kv_close_summary.go new file mode 100644 index 00000000000..77489c79c96 --- /dev/null +++ b/chanstate/kv_close_summary.go @@ -0,0 +1,577 @@ +package chanstate + +import ( + "bytes" + "errors" + "io" + + "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // closedChannelBucket stores summarization information concerning + // previously open, but now closed channels. + closedChannelBucket = []byte("closed-chan-bucket") +) + +// ClosedChannelBucketKey returns the top-level closed-channel summary bucket +// key. +func ClosedChannelBucketKey() []byte { + return closedChannelBucket +} + +// PutChannelCloseSummary writes the immutable close-time summary of a channel +// under the closed channel bucket. +func PutChannelCloseSummary(tx kvdb.RwTx, chanID []byte, + summary *ChannelCloseSummary, lastChanState *OpenChannel) error { + + closedChanBucket, err := tx.CreateTopLevelBucket(closedChannelBucket) + if err != nil { + return err + } + + summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation + summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation + summary.LocalChanConfig = lastChanState.LocalChanCfg + + var b bytes.Buffer + if err := SerializeChannelCloseSummary(&b, summary); err != nil { + return err + } + + return closedChanBucket.Put(chanID, b.Bytes()) +} + +// CloseChannel closes the supplied channel via the selected close strategy. On +// synchronous backends the channel's nested state — the revocation log, the +// per-channel forwarding-package bucket, and the chanBucket itself — is +// deleted inline. On tombstone-enabled backends none of the bulk state is +// touched; the outpointBucket flip to outpointClosed signals that the channel +// is logically closed. +func (s *KVStore) CloseChannel(channel *OpenChannel, + summary *ChannelCloseSummary, + statuses ...ChannelStatus) error { + + if s.tombstoneClosedChannels { + return s.closeChannelTombstone(channel, summary, statuses...) + } + + return s.closeChannelSync(channel, summary, statuses...) +} + +// LocateOpenChannel performs the open-channel-bucket descent for a CloseChannel +// transaction: it returns the chain bucket, the channel bucket, and the +// serialized chanKey for the supplied OpenChannel. A chanKey already flipped to +// outpointClosed surfaces ErrChannelNotFound so a redundant CloseChannel does +// not re-archive or re-flip the index. +func LocateOpenChannel(tx kvdb.RwTx, channel *OpenChannel) (kvdb.RwBucket, + kvdb.RwBucket, []byte, error) { + + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil, nil, ErrNoChanDBExists + } + + nodePub := channel.IdentityPub.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) + if nodeChanBucket == nil { + return nil, nil, nil, ErrNoActiveChannels + } + + chainBucket := nodeChanBucket.NestedReadWriteBucket( + channel.ChainHash[:], + ) + if chainBucket == nil { + return nil, nil, nil, ErrNoActiveChannels + } + + var chanPointBuf bytes.Buffer + if err := graphdb.WriteOutpoint( + &chanPointBuf, &channel.FundingOutpoint, + ); err != nil { + return nil, nil, nil, err + } + chanKey := chanPointBuf.Bytes() + + chanBucket := chainBucket.NestedReadWriteBucket(chanKey) + if chanBucket == nil { + return nil, nil, nil, ErrNoActiveChannels + } + + // A channel whose outpoint is already flipped to outpointClosed must + // not be re-closed: on tombstone backends the chanBucket survives a + // previous close, but the index flip is the authoritative record that + // the channel is gone from the open-channel view. + closed, err := IsOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) + if err != nil { + return nil, nil, nil, err + } + if closed { + return nil, nil, nil, ErrChannelNotFound + } + + return chainBucket, chanBucket, chanKey, nil +} + +// ArchiveClosedChannel writes the immutable close-time records of the channel: +// a copy of the open-channel state under historicalChannelBucket (with the +// supplied close statuses OR'd into chanStatus) and the close summary under +// closeSummaryBucket. +func ArchiveClosedChannel(tx kvdb.RwTx, chanKey []byte, + chanState *OpenChannel, summary *ChannelCloseSummary, + statuses ...ChannelStatus) error { + + historicalBucket, err := tx.CreateTopLevelBucket( + historicalChannelBucket, + ) + if err != nil { + return err + } + historicalChanBucket, err := historicalBucket.CreateBucketIfNotExists( + chanKey, + ) + if err != nil { + return err + } + + for _, s := range statuses { + chanState.SetChannelStatusForStore( + chanState.ChannelStatusForStore() | s, + ) + } + + if err := PutOpenChannel(historicalChanBucket, chanState); err != nil { + return err + } + + return PutChannelCloseSummary(tx, chanKey, summary, chanState) +} + +// closeChannelSync performs the historical synchronous close path: in a single +// write transaction it wipes the forwarding-package state, deletes the channel +// bucket and its nested revocation log entries, updates the outpoint index, and +// archives the close summary. It is used by backends where nested-bucket +// deletion is cheap (bbolt, etcd). +func (s *KVStore) closeChannelSync(channel *OpenChannel, + summary *ChannelCloseSummary, statuses ...ChannelStatus) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chainBucket, chanBucket, chanKey, err := LocateOpenChannel( + tx, channel, + ) + if err != nil { + return err + } + + chanState, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + packager := NewChannelPackager(chanState.ShortChannelID) + if err = packager.Wipe(tx); err != nil { + return err + } + + if err := DeleteOpenChannel(chanBucket); err != nil { + return err + } + + if channel.ChanType.IsFrozen() || + channel.ChanType.HasLeaseExpiration() { + + if err := DeleteThawHeight(chanBucket); err != nil { + return err + } + } + + if err := DeleteLogBucket(chanBucket); err != nil { + return err + } + + if err := chainBucket.DeleteNestedBucket(chanKey); err != nil { + return err + } + + if err := UpdateClosedOutpointIndex(tx, chanKey); err != nil { + return err + } + + return ArchiveClosedChannel( + tx, chanKey, chanState, summary, statuses..., + ) + }, func() {}) +} + +// closeChannelTombstone performs the tombstone close path used by KV-over-SQL +// backends. The channel's per-channel state is left intact — touching it +// would trigger the cascading nested-bucket delete this path exists to avoid +// — and the outpointBucket flip from outpointOpen to outpointClosed serves as +// the authoritative closed-channel marker. The disk space is reclaimed +// wholesale by the upcoming native-SQL channel-state migration. +func (s *KVStore) closeChannelTombstone(channel *OpenChannel, + summary *ChannelCloseSummary, statuses ...ChannelStatus) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + _, chanBucket, chanKey, err := LocateOpenChannel(tx, channel) + if err != nil { + return err + } + + chanState, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + if err := UpdateClosedOutpointIndex(tx, chanKey); err != nil { + return err + } + + return ArchiveClosedChannel( + tx, chanKey, chanState, summary, statuses..., + ) + }, func() {}) +} + +// FetchClosedChannels attempts to fetch all closed channels from the database. +// The pendingOnly bool toggles if channels that aren't yet fully closed should +// be returned in the response or not. When a channel was cooperatively closed, +// it becomes fully closed after a single confirmation. When a channel was +// forcibly closed, it will become fully closed after _all_ the pending funds +// (if any) have been swept. +func (s *KVStore) FetchClosedChannels(pendingOnly bool) ( + []*ChannelCloseSummary, error) { + + var chanSummaries []*ChannelCloseSummary + + if err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + closeBucket := tx.ReadBucket(closedChannelBucket) + if closeBucket == nil { + return ErrNoClosedChannels + } + + return closeBucket.ForEach(func(chanID []byte, + summaryBytes []byte) error { + + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err := DeserializeCloseChannelSummary( + summaryReader, + ) + if err != nil { + return err + } + + // If the query specified to only include pending + // channels, then we'll skip any channels which aren't + // currently pending. + if !chanSummary.IsPending && pendingOnly { + return nil + } + + chanSummaries = append(chanSummaries, chanSummary) + + return nil + }) + }, func() { + chanSummaries = nil + }); err != nil { + return nil, err + } + + return chanSummaries, nil +} + +// FetchClosedChannel queries for a channel close summary using the channel +// point of the channel in question. +func (s *KVStore) FetchClosedChannel(chanID *wire.OutPoint) ( + *ChannelCloseSummary, error) { + + var chanSummary *ChannelCloseSummary + if err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + closeBucket := tx.ReadBucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + var b bytes.Buffer + var err error + if err = graphdb.WriteOutpoint(&b, chanID); err != nil { + return err + } + + summaryBytes := closeBucket.Get(b.Bytes()) + if summaryBytes == nil { + return ErrClosedChannelNotFound + } + + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err = DeserializeCloseChannelSummary( + summaryReader, + ) + + return err + }, func() { + chanSummary = nil + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// FetchClosedChannelForID queries for a channel close summary using the +// channel ID of the channel in question. +func (s *KVStore) FetchClosedChannelForID(cid lnwire.ChannelID) ( + *ChannelCloseSummary, error) { + + var chanSummary *ChannelCloseSummary + if err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + closeBucket := tx.ReadBucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + // The first 30 bytes of the channel ID and outpoint will be + // equal. + cursor := closeBucket.ReadCursor() + op, c := cursor.Seek(cid[:30]) + + // We scan over all possible candidates for this channel ID. + for op != nil && bytes.Compare(cid[:30], op[:30]) <= 0 { + var outPoint wire.OutPoint + err := graphdb.ReadOutpoint( + bytes.NewReader(op), &outPoint, + ) + if err != nil { + return err + } + + // If the found outpoint corresponds to this channel ID, + // deserialize the close summary and return. + if cid.IsChanPoint(&outPoint) { + r := bytes.NewReader(c) + cs, err := DeserializeCloseChannelSummary(r) + if err != nil { + return err + } + + chanSummary = cs + + return nil + } + + op, c = cursor.Next() + } + + return ErrClosedChannelNotFound + }, func() { + chanSummary = nil + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// AbandonChannel attempts to remove the target channel from the open channel +// database. If the channel was already removed (has a closed channel entry), +// then we'll return a nil error. Otherwise, we'll insert a new close summary +// into the database. +func (s *KVStore) AbandonChannel(chanPoint *wire.OutPoint, + bestHeight uint32) error { + + // With the chanPoint constructed, we'll attempt to find the target + // channel in the database. If we can't find the channel, then we'll + // return the error back to the caller. + dbChan, err := s.FetchChannel(*chanPoint) + switch { + // If the channel wasn't found, then it's possible that it was already + // abandoned from the database. + case errors.Is(err, ErrChannelNotFound): + _, closedErr := s.FetchClosedChannel(chanPoint) + if closedErr != nil { + return closedErr + } + + // If the channel was already closed, then we don't return an + // error as we'd like this step to be repeatable. + return nil + case err != nil: + return err + } + + // Now that we've found the channel, we'll populate a close summary for + // the channel, so we can store as much information for this abounded + // channel as possible. We also ensure that we set Pending to false, to + // indicate that this channel has been "fully" closed. + settledBalance := dbChan.LocalCommitment.LocalBalance.ToSatoshis() + summary := &ChannelCloseSummary{ + CloseType: Abandoned, + ChanPoint: *chanPoint, + ChainHash: dbChan.ChainHash, + CloseHeight: bestHeight, + RemotePub: dbChan.IdentityPub, + Capacity: dbChan.Capacity, + SettledBalance: settledBalance, + ShortChanID: dbChan.ShortChanID(), + RemoteCurrentRevocation: dbChan.RemoteCurrentRevocation, + RemoteNextRevocation: dbChan.RemoteNextRevocation, + LocalChanConfig: dbChan.LocalChanCfg, + } + + // Finally, we'll close the channel in the DB, and return back to the + // caller. We set ourselves as the close initiator because we abandoned + // the channel. + return s.CloseChannel(dbChan, summary, ChanStatusLocalCloseInitiator) +} + +// SerializeChannelCloseSummary serializes a channel close summary. +func SerializeChannelCloseSummary(w io.Writer, + cs *ChannelCloseSummary) error { + + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := WriteChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolean indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { + return err + } + } + + return nil +} + +// DeserializeCloseChannelSummary deserializes a channel close summary. +func DeserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { + c := &ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := ReadChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // channel_ready message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if errors.Is(err, io.EOF) { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} diff --git a/chanstate/kv_close_tx.go b/chanstate/kv_close_tx.go new file mode 100644 index 00000000000..5a7380d8055 --- /dev/null +++ b/chanstate/kv_close_tx.go @@ -0,0 +1,171 @@ +package chanstate + +import ( + "bytes" + "errors" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + // forceCloseTxKey points to a the unilateral closing tx that we + // broadcasted when moving the channel to state CommitBroadcasted. + forceCloseTxKey = []byte("closing-tx-key") + + // coopCloseTxKey points to a the cooperative closing tx that we + // broadcasted when moving the channel to state CoopBroadcasted. + coopCloseTxKey = []byte("coop-closing-tx-key") +) + +// ForceCloseTxKey returns the key used to store the unilateral closing +// transaction in a channel bucket. +func ForceCloseTxKey() []byte { + return forceCloseTxKey +} + +// CoopCloseTxKey returns the key used to store the cooperative closing +// transaction in a channel bucket. +func CoopCloseTxKey() []byte { + return coopCloseTxKey +} + +// PutChannelCloseTx stores the closing transaction under the requested key in +// the target channel bucket. +func PutChannelCloseTx(chanBucket kvdb.RwBucket, key []byte, + closeTx *wire.MsgTx) error { + + var b bytes.Buffer + if err := closeTx.Serialize(&b); err != nil { + return err + } + + return chanBucket.Put(key, b.Bytes()) +} + +// FetchChannelCloseTx retrieves the closing transaction stored under the +// requested key in the target channel bucket. +func FetchChannelCloseTx(chanBucket kvdb.RBucket, + key []byte) (*wire.MsgTx, error) { + + bs := chanBucket.Get(key) + if bs == nil { + return nil, ErrNoCloseTx + } + + closeTx := wire.NewMsgTx(2) + r := bytes.NewReader(bs) + if err := closeTx.Deserialize(r); err != nil { + return nil, err + } + + return closeTx, nil +} + +// MarkChannelCommitmentBroadcasted marks the channel as having a commitment +// transaction broadcast. +func (s *KVStore) MarkChannelCommitmentBroadcasted( + channel *OpenChannel, closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + return s.markBroadcasted( + channel, ChanStatusCommitBroadcasted, forceCloseTxKey, + closeTx, closer, + ) +} + +// MarkChannelCoopBroadcasted marks the channel as having a cooperative close +// transaction broadcast. +func (s *KVStore) MarkChannelCoopBroadcasted( + channel *OpenChannel, closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + return s.markBroadcasted( + channel, ChanStatusCoopBroadcasted, coopCloseTxKey, + closeTx, closer, + ) +} + +// markBroadcasted modifies the channel status and inserts a close transaction +// under the requested key, which should specify either a coop or force close. +// It adds a status which indicates the party that initiated the channel close. +func (s *KVStore) markBroadcasted(channel *OpenChannel, + status ChannelStatus, key []byte, closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + channel.Lock() + defer channel.Unlock() + + // If a closing tx is provided, we'll generate a closure to write the + // transaction in the appropriate bucket under the given key. + var putClosingTx func(kvdb.RwBucket) error + if closeTx != nil { + putClosingTx = func(chanBucket kvdb.RwBucket) error { + return PutChannelCloseTx(chanBucket, key, closeTx) + } + } + + // Add the initiator status to the status provided. These statuses are + // set in addition to the broadcast status so that we do not need to + // migrate the original logic which does not store initiator. + if closer.IsLocal() { + status |= ChanStatusLocalCloseInitiator + } else { + status |= ChanStatusRemoteCloseInitiator + } + + return s.putChanStatus(channel, status, putClosingTx) +} + +// FetchChannelBroadcastedCommitment fetches the stored unilateral closing +// transaction. +func (s *KVStore) FetchChannelBroadcastedCommitment( + channel *OpenChannel) (*wire.MsgTx, error) { + + return s.fetchClosingTx(channel, forceCloseTxKey) +} + +// FetchChannelBroadcastedCooperative fetches the stored cooperative closing +// transaction. +func (s *KVStore) FetchChannelBroadcastedCooperative( + channel *OpenChannel) (*wire.MsgTx, error) { + + return s.fetchClosingTx(channel, coopCloseTxKey) +} + +// fetchClosingTx returns the stored closing transaction for key. The caller +// should use either the force or coop closing keys. +func (s *KVStore) fetchClosingTx(channel *OpenChannel, + key []byte) (*wire.MsgTx, error) { + + var closeTx *wire.MsgTx + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return ErrNoCloseTx + default: + return err + } + + closeTx, err = FetchChannelCloseTx(chanBucket, key) + + return err + }, func() { + closeTx = nil + }) + if err != nil { + return nil, err + } + + return closeTx, nil +} diff --git a/chanstate/kv_commitment.go b/chanstate/kv_commitment.go new file mode 100644 index 00000000000..8fa94e62a4a --- /dev/null +++ b/chanstate/kv_commitment.go @@ -0,0 +1,1350 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // chanCommitmentKey can be accessed within the sub-bucket for a + // particular channel. This key stores the up to date commitment state + // for a particular channel party. Appending a 0 to the end of this key + // indicates it's the commitment for the local party, and appending a 1 + // to the end of this key indicates it's the commitment for the remote + // party. + chanCommitmentKey = []byte("chan-commitment-key") + + // unsignedAckedUpdatesKey is an entry in the channel bucket that + // contains the remote updates that we have acked, but not yet signed + // for in one of our remote commits. + unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key") + + // remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that + // contains the local updates that the remote party has acked, but + // has not yet signed for in one of their local commits. + remoteUnsignedLocalUpdatesKey = []byte( + "remote-unsigned-local-updates-key", + ) + + // revocationStateKey stores their current revocation hash, our + // preimage producer and their preimage store. + revocationStateKey = []byte("revocation-state-key") + + // commitDiffKey stores the current pending commitment state we've + // extended to the remote party (if any). Each time we propose a new + // state, we store the information necessary to reconstruct this state + // from the prior commitment. This allows us to resync the remote party + // to their expected state in the case of message loss. + // + // TODO(roasbeef): rename to commit chain? + commitDiffKey = []byte("commit-diff-key") + + // lastWasRevokeKey is a key that stores true when the last update we + // sent was a revocation and false when it was a commitment signature. + // This is nil in the case of new channels with no updates exchanged. + lastWasRevokeKey = []byte("last-was-revoke") +) + +// ChanCommitmentKey returns the channel-bucket key prefix for channel +// commitments. +func ChanCommitmentKey() []byte { + return chanCommitmentKey +} + +// UnsignedAckedUpdatesKey returns the channel-bucket key for unsigned acked +// remote updates. +func UnsignedAckedUpdatesKey() []byte { + return unsignedAckedUpdatesKey +} + +// RemoteUnsignedLocalUpdatesKey returns the channel-bucket key for remote +// unsigned local updates. +func RemoteUnsignedLocalUpdatesKey() []byte { + return remoteUnsignedLocalUpdatesKey +} + +// RevocationStateKey returns the channel-bucket key for revocation state. +func RevocationStateKey() []byte { + return revocationStateKey +} + +// CommitDiffKey returns the channel-bucket key for the current pending +// commitment diff. +func CommitDiffKey() []byte { + return commitDiffKey +} + +// LastWasRevokeKey returns the channel-bucket key for the last update type. +func LastWasRevokeKey() []byte { + return lastWasRevokeKey +} + +// serializeHtlcExtraData encodes a TLV stream of extra data to be stored with a +// HTLC. It uses the update_add_htlc TLV types, because this is where extra +// data is passed with a HTLC. At present blinding points are the only extra +// data that we will store, and the function is a no-op if a nil blinding +// point is provided. +// +// This function MUST be called to persist all HTLC values when they are +// serialized. +func serializeHtlcExtraData(h *HTLC) error { + var records []tlv.RecordProducer + h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + + records = append(records, &b) + }) + + records, err := h.CustomRecords.ExtendRecordProducers(records) + if err != nil { + return err + } + + return h.ExtraData.PackRecords(records...) +} + +// deserializeHtlcExtraData extracts TLVs from the extra data persisted for the +// HTLC and populates values in the struct accordingly. +// +// This function MUST be called to populate the struct properly when HTLCs +// are deserialized. +func deserializeHtlcExtraData(h *HTLC) error { + if len(h.ExtraData) == 0 { + return nil + } + + blindingPoint := h.BlindingPoint.Zero() + tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint) + if err != nil { + return err + } + + if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { + h.BlindingPoint = tlv.SomeRecordT(blindingPoint) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(tlvMap, h.BlindingPoint.TlvType()) + } + + // Set the custom records field to the remaining TLV records. + customRecords, err := lnwire.NewCustomRecords(tlvMap) + if err != nil { + return err + } + h.CustomRecords = customRecords + + return nil +} + +// SerializeHtlcs writes out the passed set of HTLC's into the passed writer +// using the current default on-disk serialization format. +// +// This inline serialization has been extended to allow storage of extra data +// associated with a HTLC in the following way: +// - The known-length onion blob (1366 bytes) is serialized as var bytes in +// WriteElements (ie, the length 1366 was written, followed by the 1366 +// onion bytes). +// - To include extra data, we append any extra data present to this one +// variable length of data. Since we know that the onion is strictly 1366 +// bytes, any length after that should be considered to be extra data. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + // Populate TLV stream for any additional fields contained + // in the TLV. + if err := serializeHtlcExtraData(&htlc); err != nil { + return err + } + + // The onion blob and hltc data are stored as a single var + // bytes blob. + onionAndExtraData := make( + []byte, lnwire.OnionPacketSize+len(htlc.ExtraData), + ) + copy(onionAndExtraData, htlc.OnionBlob[:]) + copy(onionAndExtraData[lnwire.OnionPacketSize:], htlc.ExtraData) + + if err := WriteElements(b, + //nolint:ll + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, onionAndExtraData, + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed +// io.Reader. The bytes within the passed reader MUST have been previously +// written to using the SerializeHtlcs function. +// +// This inline deserialization has been extended to allow storage of extra data +// associated with a HTLC in the following way: +// - The known-length onion blob (1366 bytes) and any additional data present +// are read out as a single blob of variable byte data. +// - They are stored like this to take advantage of the variable space +// available for extension without migration (see SerializeHtlcs). +// - The first 1366 bytes are interpreted as the onion blob, and any remaining +// bytes as extra HTLC data. +// - This extra HTLC data is expected to be serialized as a TLV stream, and +// its parsing is left to higher layers. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + var onionAndExtraData []byte + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &onionAndExtraData, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + + // Sanity check that we have at least the onion blob size we + // expect. + if len(onionAndExtraData) < lnwire.OnionPacketSize { + return nil, ErrOnionBlobLength + } + + // First OnionPacketSize bytes are our fixed length onion + // packet. + copy( + htlcs[i].OnionBlob[:], + onionAndExtraData[0:lnwire.OnionPacketSize], + ) + + // Any additional bytes belong to extra data. ExtraDataLen + // will be >= 0, because we know that we always have a fixed + // length onion packet. + extraDataLen := len(onionAndExtraData) - lnwire.OnionPacketSize + if extraDataLen > 0 { + htlcs[i].ExtraData = make([]byte, extraDataLen) + + copy( + htlcs[i].ExtraData, + onionAndExtraData[lnwire.OnionPacketSize:], + ) + } + + // Finally, deserialize any TLVs contained in that extra data + // if they are present. + if err := deserializeHtlcExtraData(&htlcs[i]); err != nil { + return nil, err + } + } + + return htlcs, nil +} + +// SerializeChanCommit serializes the channel commitment. +func SerializeChanCommit(w io.Writer, c *ChannelCommitment) error { + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return SerializeHtlcs(w, c.Htlcs...) +} + +// DeserializeChanCommit deserializes the channel commitment. +func DeserializeChanCommit(r io.Reader) (ChannelCommitment, error) { + var c ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, + &c.RemoteLogIndex, &c.RemoteHtlcIndex, &c.LocalBalance, + &c.RemoteBalance, &c.CommitFee, &c.FeePerKw, &c.CommitTx, + &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = DeserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func chanCommitKey(local bool) []byte { + commitKey := make([]byte, 0, len(chanCommitmentKey)+1) + commitKey = append(commitKey, chanCommitmentKey...) + if local { + return append(commitKey, byte(0x00)) + } + + return append(commitKey, byte(0x01)) +} + +// PutChanCommitment writes a channel commitment to the channel bucket. +func PutChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, + local bool) error { + + var b bytes.Buffer + if err := SerializeChanCommit(&b, c); err != nil { + return err + } + + // Before we write to disk, we'll also write our aux data as well. + if err := EncodeCommitTlvData(&b, c); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + + return chanBucket.Put(chanCommitKey(local), b.Bytes()) +} + +// PutChanCommitments writes the local and remote commitments to the channel +// bucket. +func PutChanCommitments(chanBucket kvdb.RwBucket, + channel *OpenChannel) error { + + // If this is a restored channel, then we don't have any commitments to + // write. + if channel.HasChanStatusForStore(ChanStatusRestored) { + return nil + } + + err := PutChanCommitment( + chanBucket, &channel.LocalCommitment, true, + ) + if err != nil { + return err + } + + return PutChanCommitment( + chanBucket, &channel.RemoteCommitment, false, + ) +} + +// PutChanRevocationState writes the remote revocation state to the channel +// bucket. +func PutChanRevocationState(chanBucket kvdb.RwBucket, + channel *OpenChannel) error { + + var b bytes.Buffer + err := WriteElements( + &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, + channel.RevocationStore, + ) + if err != nil { + return err + } + + // If the next revocation is present, which is only the case after the + // ChannelReady message has been sent, then we'll write it to disk. + if channel.RemoteNextRevocation != nil { + err = WriteElements(&b, channel.RemoteNextRevocation) + if err != nil { + return err + } + } + + return chanBucket.Put(revocationStateKey, b.Bytes()) +} + +// FetchChanCommitment reads a channel commitment from the channel bucket. +func FetchChanCommitment(chanBucket kvdb.RBucket, + local bool) (ChannelCommitment, error) { + + commitBytes := chanBucket.Get(chanCommitKey(local)) + if commitBytes == nil { + return ChannelCommitment{}, ErrNoCommitmentsFound + } + + r := bytes.NewReader(commitBytes) + chanCommit, err := DeserializeChanCommit(r) + if err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan commit: %w", err) + } + + // We'll also check to see if we have any aux data stored as the end of + // the stream. + if err := DecodeCommitTlvData(r, &chanCommit); err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan aux data: %w", err) + } + + return chanCommit, nil +} + +// FetchChanCommitments reads the local and remote commitments from the channel +// bucket. +func FetchChanCommitments(chanBucket kvdb.RBucket, + channel *OpenChannel) error { + + var err error + + // If this is a restored channel, then we don't have any commitments to + // read. + if channel.HasChanStatusForStore(ChanStatusRestored) { + return nil + } + + channel.LocalCommitment, err = FetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + channel.RemoteCommitment, err = FetchChanCommitment(chanBucket, false) + if err != nil { + return err + } + + return nil +} + +// FetchChanRevocationState reads the remote revocation state from the channel +// bucket. +func FetchChanRevocationState(chanBucket kvdb.RBucket, + channel *OpenChannel) error { + + revBytes := chanBucket.Get(revocationStateKey) + if revBytes == nil { + return ErrNoRevocationsFound + } + r := bytes.NewReader(revBytes) + + err := ReadElements( + r, + &channel.RemoteCurrentRevocation, &channel.RevocationProducer, + &channel.RevocationStore, + ) + if err != nil { + return err + } + + // If there aren't any bytes left in the buffer, then we don't yet have + // the next remote revocation, so we can exit early here. + if r.Len() == 0 { + return nil + } + + // Otherwise we'll read the next revocation for the remote party which + // is always the last item within the buffer. + return ReadElements(r, &channel.RemoteNextRevocation) +} + +// DeleteOpenChannel deletes the persisted open channel state from the channel +// bucket. +func DeleteOpenChannel(chanBucket kvdb.RwBucket) error { + if err := chanBucket.Delete(chanInfoKey); err != nil { + return err + } + + err := chanBucket.Delete(chanCommitKey(true)) + if err != nil { + return err + } + err = chanBucket.Delete(chanCommitKey(false)) + if err != nil { + return err + } + + if err := chanBucket.Delete(revocationStateKey); err != nil { + return err + } + + if diff := chanBucket.Get(commitDiffKey); diff != nil { + return chanBucket.Delete(commitDiffKey) + } + + return nil +} + +// RemoteCommitChainTip returns the "tip" of the current remote commitment +// chain. +func (s *KVStore) RemoteCommitChainTip( + channel *OpenChannel) (*CommitDiff, error) { + + var cd *CommitDiff + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return ErrNoPendingCommit + default: + return err + } + + tipBytes := chanBucket.Get(commitDiffKey) + if tipBytes == nil { + return ErrNoPendingCommit + } + + tipReader := bytes.NewReader(tipBytes) + dcd, err := DeserializeCommitDiff(tipReader) + if err != nil { + return err + } + + cd = dcd + + return nil + }, func() { + cd = nil + }) + if err != nil { + return nil, err + } + + return cd, nil +} + +// UnsignedAckedUpdates retrieves the persisted unsigned acked remote log +// updates that still need to be signed for. +func (s *KVStore) UnsignedAckedUpdates( + channel *OpenChannel) ([]LogUpdate, error) { + + var updates []LogUpdate + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return nil + default: + return err + } + + updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) + if updateBytes == nil { + return nil + } + + r := bytes.NewReader(updateBytes) + updates, err = DeserializeLogUpdates(r) + + return err + }, func() { + updates = nil + }) + if err != nil { + return nil, err + } + + return updates, nil +} + +// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log +// updates that the remote still needs to sign for. +func (s *KVStore) RemoteUnsignedLocalUpdates( + channel *OpenChannel) ([]LogUpdate, error) { + + var updates []LogUpdate + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return nil + default: + return err + } + + updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes == nil { + return nil + } + + r := bytes.NewReader(updateBytes) + updates, err = DeserializeLogUpdates(r) + + return err + }, func() { + updates = nil + }) + if err != nil { + return nil, err + } + + return updates, nil +} + +// InsertNextRevocation inserts the next commitment point into the persisted +// channel state. +func (s *KVStore) InsertNextRevocation(channel *OpenChannel, + revKey *btcec.PublicKey) error { + + channel.RemoteNextRevocation = revKey + + err := kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + return PutChanRevocationState(chanBucket, channel) + }, func() {}) + if err != nil { + return err + } + + return nil +} + +// UpdateChannelCommitment updates the local commitment state. +func (s *KVStore) UpdateChannelCommitment(channel *OpenChannel, + newCommitment *ChannelCommitment, + unsignedAckedUpdates []LogUpdate) ( + map[uint64]bool, error) { + + var finalHtlcs = make(map[uint64]bool) + + err := kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := IsChannelBorked(channel, chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + if err = PutChanInfo(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan info: %w", err) + } + + // With the proper bucket fetched, we'll now write the latest + // commitment state to disk for the target party. + err = PutChanCommitment( + chanBucket, newCommitment, true, + ) + if err != nil { + return fmt.Errorf("unable to store chan "+ + "revocations: %v", err) + } + + // Persist unsigned but acked remote updates that need to be + // restored after a restart. + var b bytes.Buffer + err = SerializeLogUpdates(&b, unsignedAckedUpdates) + if err != nil { + return err + } + + err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes()) + if err != nil { + return fmt.Errorf("unable to store dangline remote "+ + "updates: %v", err) + } + + //nolint:ll + // Since we have just sent the counterparty a revocation, store true + // under lastWasRevokeKey. + var b2 bytes.Buffer + if err := WriteElements(&b2, true); err != nil { + return err + } + + err = chanBucket.Put(lastWasRevokeKey, b2.Bytes()) + if err != nil { + return err + } + + //nolint:ll + // Persist the remote unsigned local updates that are not included + // in our new commitment. + updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes == nil { + return nil + } + + r := bytes.NewReader(updateBytes) + updates, err := DeserializeLogUpdates(r) + if err != nil { + return err + } + + // Get the bucket where settled htlcs are recorded if the user + // opted in to storing this information. + var finalHtlcsBucket kvdb.RwBucket + if s.storeFinalHtlcResolutions { + bucket, err := FetchFinalHtlcsBucketRw( + tx, channel.ShortChannelID, + ) + if err != nil { + return err + } + + finalHtlcsBucket = bucket + } + + var unsignedUpdates []LogUpdate + for _, upd := range updates { + // Gather updates that are not on our local commitment. + if upd.LogIndex >= newCommitment.LocalLogIndex { + unsignedUpdates = append(unsignedUpdates, upd) + + continue + } + + // The update was locked in. If the update was a + // resolution, then store it in the database. + err := ProcessFinalHtlc( + finalHtlcsBucket, upd, finalHtlcs, + ) + if err != nil { + return err + } + } + + var b3 bytes.Buffer + err = SerializeLogUpdates(&b3, unsignedUpdates) + if err != nil { + return fmt.Errorf("unable to serialize log updates: %w", + err) + } + + err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b3.Bytes()) + if err != nil { + return fmt.Errorf("unable to restore chanbucket: %w", + err) + } + + return nil + }, func() { + finalHtlcs = make(map[uint64]bool) + }) + if err != nil { + return nil, err + } + + return finalHtlcs, nil +} + +// AppendRemoteCommitChain appends a new CommitDiff to the remote party's +// commitment chain. +func (s *KVStore) AppendRemoteCommitChain(channel *OpenChannel, + diff *CommitDiff) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + // First, we'll grab the writable bucket where this channel's + // data resides. + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := IsChannelBorked(channel, chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Any outgoing settles and fails necessarily have a + // corresponding adds in this channel's forwarding packages. + // Mark all of these as being fully processed in our forwarding + // package, which prevents us from reprocessing them after + // startup. + packager := NewChannelPackager(channel.ShortChannelID) + + err = packager.AckAddHtlcs(tx, diff.AddAcks...) + if err != nil { + return err + } + + // Additionally, we ack from any fails or settles that are + // persisted in another channel's forwarding package. This + // prevents the same fails and settles from being retransmitted + // after restarts. The actual fail or settle we need to + // propagate to the remote party is now in the commit diff. + err = packager.AckSettleFails( + tx, diff.SettleFailAcks..., + ) + if err != nil { + return err + } + + //nolint:ll + // We are sending a commitment signature so lastWasRevokeKey should + // store false. + var b bytes.Buffer + if err := WriteElements(&b, false); err != nil { + return err + } + err = chanBucket.Put(lastWasRevokeKey, b.Bytes()) + if err != nil { + return err + } + + // TODO(roasbeef): use seqno to derive key for later LCP + + // With the bucket retrieved, we'll now serialize the commit + // diff itself, and write it to disk. + var b2 bytes.Buffer + if err := SerializeCommitDiff(&b2, diff); err != nil { + return err + } + + return chanBucket.Put(commitDiffKey, b2.Bytes()) + }, func() {}) +} + +// AdvanceCommitChainTail records the new state transition within the +// revocation log and promotes the pending remote commitment to the current +// remote commitment. +func (s *KVStore) AdvanceCommitChainTail(channel *OpenChannel, + fwdPkg *FwdPkg, updates []LogUpdate, ourOutputIndex, + theirOutputIndex uint32) error { + + var newRemoteCommit *ChannelCommitment + + err := kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := IsChannelBorked(channel, chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Persist the latest preimage state to disk as the remote peer + // has just added to our local preimage store, and given us a + // new pending revocation key. + err = PutChanRevocationState(chanBucket, channel) + if err != nil { + return err + } + + // With the current preimage producer/store state updated, + // append a new log entry recording this the delta of this + // state transition. + // + // TODO(roasbeef): could make the deltas relative, would save + // space, but then tradeoff for more disk-seeks to recover the + // full state. + logKey := revocationLogBucket + logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) + if err != nil { + return err + } + + // Before we append this revoked state to the revocation log, + // we'll swap out what's currently the tail of the commit tip, + // with the current locked-in commitment for the remote party. + tipBytes := chanBucket.Get(commitDiffKey) + tipReader := bytes.NewReader(tipBytes) + newCommit, err := DeserializeCommitDiff(tipReader) + if err != nil { + return err + } + err = PutChanCommitment( + chanBucket, &newCommit.Commitment, false, + ) + if err != nil { + return err + } + if err := chanBucket.Delete(commitDiffKey); err != nil { + return err + } + + // With the commitment pointer swapped, we can now add the + // revoked (prior) state to the revocation log. + err = PutRevocationLog( + logBucket, &channel.RemoteCommitment, ourOutputIndex, + theirOutputIndex, s.noRevLogAmtData, + ) + if err != nil { + return err + } + + // Lastly, we write the forwarding package to disk so that we + // can properly recover from failures and reforward HTLCs that + // have not received a corresponding settle/fail. + err = NewChannelPackager(channel.ShortChannelID).AddFwdPkg( + tx, fwdPkg, + ) + if err != nil { + return err + } + + // Persist the unsigned acked updates that are not included + // in their new commitment. + updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) + if updateBytes == nil { + // This shouldn't normally happen as we always store + // the number of updates, but could still be + // encountered by nodes that are upgrading. + newRemoteCommit = &newCommit.Commitment + return nil + } + + r := bytes.NewReader(updateBytes) + unsignedUpdates, err := DeserializeLogUpdates(r) + if err != nil { + return err + } + + var validUpdates []LogUpdate + for _, upd := range unsignedUpdates { + lIdx := upd.LogIndex + + // Filter for updates that are not on the remote + // commitment. + if lIdx >= newCommit.Commitment.RemoteLogIndex { + validUpdates = append(validUpdates, upd) + } + } + + var b bytes.Buffer + err = SerializeLogUpdates(&b, validUpdates) + if err != nil { + return fmt.Errorf("unable to serialize log updates: %w", + err) + } + + err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes()) + if err != nil { + return fmt.Errorf("unable to store under "+ + "unsignedAckedUpdatesKey: %w", err) + } + + // Persist the local updates the peer hasn't yet signed so they + // can be restored after restart. + var b2 bytes.Buffer + err = SerializeLogUpdates(&b2, updates) + if err != nil { + return err + } + + err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b2.Bytes()) + if err != nil { + return fmt.Errorf("unable to restore remote unsigned "+ + "local updates: %v", err) + } + + newRemoteCommit = &newCommit.Commitment + + return nil + }, func() { + newRemoteCommit = nil + }) + if err != nil { + return err + } + + // With the db transaction complete, we'll swap over the in-memory + // pointer of the new remote commitment, which was previously the tip + // of the commit chain. + channel.RemoteCommitment = *newRemoteCommit + + return nil +} + +// CommitmentHeight returns the current commitment height. The commitment +// height represents the number of updates to the commitment state to date. +// This value is always monotonically increasing. This method is provided in +// order to allow multiple instances of a particular open channel to obtain a +// consistent view of the number of channel updates to date. +func (s *KVStore) CommitmentHeight(channel *OpenChannel) ( + uint64, error) { + + var height uint64 + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + commit, err := FetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + + height = commit.CommitHeight + + return nil + }, func() { + height = 0 + }) + if err != nil { + return 0, err + } + + return height, nil +} + +// LatestCommitments returns the two latest commitments for both the local and +// remote party. These commitments are read from disk to ensure that only the +// latest fully committed state is returned. The first commitment returned is +// the local commitment, and the second returned is the remote commitment. +func (s *KVStore) LatestCommitments(channel *OpenChannel) ( + *ChannelCommitment, *ChannelCommitment, error) { + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + return FetchChanCommitments(chanBucket, channel) + }, func() {}) + if err != nil { + return nil, nil, err + } + + return &channel.LocalCommitment, &channel.RemoteCommitment, nil +} + +// RemoteRevocationStore returns the most up to date commitment version of the +// revocation storage tree for the remote party. This method can be used when +// acting on a possible contract breach to ensure, that the caller has the most +// up to date information required to deliver justice. +func (s *KVStore) RemoteRevocationStore( + channel *OpenChannel) (shachain.Store, error) { + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + return FetchChanRevocationState(chanBucket, channel) + }, func() {}) + if err != nil { + return nil, err + } + + return channel.RevocationStore, nil +} + +// commitTlvData stores all the optional data that may be stored as a TLV stream +// at the _end_ of the normal serialized commit on disk. +type commitTlvData struct { + // customBlob is a custom blob that may store extra data for custom + // channels. + customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] +} + +// encode encodes the aux data into the passed io.Writer. +func (c *commitTlvData) encode(w io.Writer) error { + var tlvRecords []tlv.Record + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode attempts to decode the aux data from the passed io.Reader. +func (c *commitTlvData) decode(r io.Reader) error { + blob := c.customBlob.Zero() + + tlvStream, err := tlv.NewStream( + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + +// DecodeCommitTlvData decodes and applies auxiliary TLV data to a commitment. +func DecodeCommitTlvData(r io.Reader, c *ChannelCommitment) error { + var auxData commitTlvData + if err := auxData.decode(r); err != nil { + return err + } + + amendCommitTlvData(c, auxData) + + return nil +} + +// EncodeCommitTlvData extracts and encodes auxiliary TLV data from a +// commitment. +func EncodeCommitTlvData(w io.Writer, c *ChannelCommitment) error { + auxData := extractCommitTlvData(c) + return auxData.encode(w) +} + +// amendCommitTlvData updates the commitment with the given auxiliary TLV data. +func amendCommitTlvData(c *ChannelCommitment, auxData commitTlvData) { + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + c.CustomBlob = fn.Some(blob) + }) +} + +// extractCommitTlvData creates a new commitTlvData from the given commitment. +func extractCommitTlvData(c *ChannelCommitment) commitTlvData { + var auxData commitTlvData + + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](blob), + ) + }) + + return auxData +} + +// SerializeLogUpdates serializes provided list of updates to a stream. +func SerializeLogUpdates(w io.Writer, logUpdates []LogUpdate) error { + numUpdates := uint16(len(logUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range logUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + return nil +} + +// DeserializeLogUpdates deserializes a list of updates from a stream. +func DeserializeLogUpdates(r io.Reader) ([]LogUpdate, error) { + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + logUpdates := make([]LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + + return logUpdates, nil +} + +// SerializeCommitDiff serializes the commit diff. +func SerializeCommitDiff(w io.Writer, diff *CommitDiff) error { + if err := SerializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := WriteElements(w, diff.CommitSig); err != nil { + return err + } + + if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil { + return err + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + // We'll also encode the commit aux data stream here. We do this here + // rather than above (at the call to serializeChanCommit), to ensure + // backwards compat for reads to existing non-custom channels. + if err := EncodeCommitTlvData(w, &diff.Commitment); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + + return nil +} + +// DeserializeCommitDiff deserializes the commit diff. +func DeserializeCommitDiff(r io.Reader) (*CommitDiff, error) { + var ( + d CommitDiff + err error + ) + + d.Commitment, err = DeserializeChanCommit(r) + if err != nil { + return nil, err + } + + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig + + d.LogUpdates, err = DeserializeLogUpdates(r) + if err != nil { + return nil, err + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]models.CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]models.CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + // As a final step, we'll read out any aux commit data that we have at + // the end of this byte stream. We do this here to ensure backward + // compatibility, as otherwise we risk erroneously reading into the + // wrong field. + if err := DecodeCommitTlvData(r, &d.Commitment); err != nil { + return nil, fmt.Errorf("unable to decode aux data: %w", err) + } + + return &d, nil +} diff --git a/chanstate/kv_final_htlc.go b/chanstate/kv_final_htlc.go new file mode 100644 index 00000000000..2e838875d13 --- /dev/null +++ b/chanstate/kv_final_htlc.go @@ -0,0 +1,256 @@ +package chanstate + +import ( + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrFinalHtlcsBucketNotFound signals that the top-level final htlcs + // bucket does not exist. + ErrFinalHtlcsBucketNotFound = errors.New("final htlcs bucket not " + + "found") + + // ErrFinalChannelBucketNotFound signals that the channel bucket for + // final htlc outcomes does not exist. + ErrFinalChannelBucketNotFound = errors.New("final htlcs channel " + + "bucket not found") + + // ErrHtlcUnknown signals that an htlc has no final resolution yet. + ErrHtlcUnknown = errors.New("htlc unknown") +) + +var ( + // finalHtlcsBucket contains the htlcs that have been resolved + // definitively. Within this bucket, there is a sub-bucket for each + // channel. In each channel bucket, the htlc indices are stored along + // with final outcome. + // + // final-htlcs -> chanID -> htlcIndex -> outcome + // + // 'outcome' is a byte value that encodes: + // + // | true false + // ------+------------------ + // bit 0 | settled failed + // bit 1 | offchain onchain + // + // This bucket is positioned at the root level, because its contents + // will be kept independent of the channel lifecycle. This is to avoid + // the situation where a channel force-closes autonomously and the user + // not being able to query for htlc outcomes anymore. + finalHtlcsBucket = []byte("final-htlcs") +) + +// FinalHtlcsBucketKey returns the top-level bucket key that stores final htlc +// outcomes. +func FinalHtlcsBucketKey() []byte { + return finalHtlcsBucket +} + +// FinalHtlcByte defines a byte type that encodes information about the final +// htlc resolution. +type FinalHtlcByte byte + +const ( + // FinalHtlcSettledBit is the bit that encodes whether the htlc was + // settled or failed. + FinalHtlcSettledBit FinalHtlcByte = 1 << 0 + + // FinalHtlcOffchainBit is the bit that encodes whether the htlc was + // resolved offchain or onchain. + FinalHtlcOffchainBit FinalHtlcByte = 1 << 1 +) + +// FetchFinalHtlcsBucket returns the read-only final htlc bucket for a channel. +func FetchFinalHtlcsBucket(tx kvdb.RTx, + chanID lnwire.ShortChannelID) (kvdb.RBucket, error) { + + finalHtlcsBucket := tx.ReadBucket(finalHtlcsBucket) + if finalHtlcsBucket == nil { + return nil, ErrFinalHtlcsBucketNotFound + } + + var chanIDBytes [8]byte + byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64()) + + chanBucket := finalHtlcsBucket.NestedReadBucket(chanIDBytes[:]) + if chanBucket == nil { + return nil, ErrFinalChannelBucketNotFound + } + + return chanBucket, nil +} + +// FetchFinalHtlcsBucketRw returns the writable final htlc bucket for a channel. +func FetchFinalHtlcsBucketRw(tx kvdb.RwTx, + chanID lnwire.ShortChannelID) (kvdb.RwBucket, error) { + + finalHtlcsBucket, err := tx.CreateTopLevelBucket(finalHtlcsBucket) + if err != nil { + return nil, err + } + + var chanIDBytes [8]byte + byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64()) + chanBucket, err := finalHtlcsBucket.CreateBucketIfNotExists( + chanIDBytes[:], + ) + if err != nil { + return nil, err + } + + return chanBucket, nil +} + +// PutFinalHtlc writes the final htlc outcome to the database. Additionally it +// records whether the htlc was resolved off-chain or on-chain. +func PutFinalHtlc(finalHtlcsBucket kvdb.RwBucket, id uint64, + info FinalHtlcInfo) error { + + var key [8]byte + byteOrder.PutUint64(key[:], id) + + var finalHtlcByte FinalHtlcByte + if info.Settled { + finalHtlcByte |= FinalHtlcSettledBit + } + if info.Offchain { + finalHtlcByte |= FinalHtlcOffchainBit + } + + return finalHtlcsBucket.Put(key[:], []byte{byte(finalHtlcByte)}) +} + +// FetchFinalHtlc reads a final htlc outcome from the final htlc channel bucket. +func FetchFinalHtlc(finalHtlcsBucket kvdb.RBucket, + htlcIndex uint64) (*FinalHtlcInfo, error) { + + var idBytes [8]byte + byteOrder.PutUint64(idBytes[:], htlcIndex) + + value := finalHtlcsBucket.Get(idBytes[:]) + if value == nil { + return nil, ErrHtlcUnknown + } + + if len(value) != 1 { + return nil, errors.New("unexpected final htlc value length") + } + + info := FinalHtlcInfo{ + Settled: value[0]&byte(FinalHtlcSettledBit) != 0, + Offchain: value[0]&byte(FinalHtlcOffchainBit) != 0, + } + + return &info, nil +} + +// LookupFinalHtlc retrieves a final htlc resolution from the database. If the +// htlc has no final resolution yet, ErrHtlcUnknown is returned. +func (s *KVStore) LookupFinalHtlc(chanID lnwire.ShortChannelID, + htlcIndex uint64) (*FinalHtlcInfo, error) { + + var info *FinalHtlcInfo + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + finalHtlcsBucket, err := FetchFinalHtlcsBucket(tx, chanID) + switch { + case errors.Is(err, ErrFinalHtlcsBucketNotFound): + fallthrough + + case errors.Is(err, ErrFinalChannelBucketNotFound): + return ErrHtlcUnknown + + case err != nil: + return fmt.Errorf("cannot fetch final htlcs bucket: %w", + err) + } + + info, err = FetchFinalHtlc(finalHtlcsBucket, htlcIndex) + + return err + }, func() { + info = nil + }) + if err != nil { + return nil, err + } + + return info, nil +} + +// PutOnchainFinalHtlcOutcome stores the final on-chain outcome of an htlc in +// the database. +func (s *KVStore) PutOnchainFinalHtlcOutcome(chanID lnwire.ShortChannelID, + htlcID uint64, settled bool) error { + + // Skip if the user did not opt in to storing final resolutions. + if !s.storeFinalHtlcResolutions { + return nil + } + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + finalHtlcsBucket, err := FetchFinalHtlcsBucketRw(tx, chanID) + if err != nil { + return err + } + + return PutFinalHtlc( + finalHtlcsBucket, htlcID, + FinalHtlcInfo{ + Settled: settled, + Offchain: false, + }, + ) + }, func() {}) +} + +// ProcessFinalHtlc stores a final htlc outcome in the database if signaled via +// the supplied log update. An in-memory htlcs map is updated too. +func ProcessFinalHtlc(finalHtlcsBucket kvdb.RwBucket, upd LogUpdate, + finalHtlcs map[uint64]bool) error { + + var ( + settled bool + id uint64 + ) + + switch msg := upd.UpdateMsg.(type) { + case *lnwire.UpdateFulfillHTLC: + settled = true + id = msg.ID + + case *lnwire.UpdateFailHTLC: + settled = false + id = msg.ID + + case *lnwire.UpdateFailMalformedHTLC: + settled = false + id = msg.ID + + default: + return nil + } + + // Store the final resolution in the database if a bucket is provided. + if finalHtlcsBucket != nil { + err := PutFinalHtlc( + finalHtlcsBucket, id, + FinalHtlcInfo{ + Settled: settled, + Offchain: true, + }, + ) + if err != nil { + return err + } + } + + finalHtlcs[id] = settled + + return nil +} diff --git a/chanstate/kv_forwarding_package.go b/chanstate/kv_forwarding_package.go new file mode 100644 index 00000000000..bd703204c27 --- /dev/null +++ b/chanstate/kv_forwarding_package.go @@ -0,0 +1,892 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" +) + +//nolint:ll +var ( + // ErrCorruptedFwdPkg signals that the on-disk structure of the + // forwarding package has potentially been mangled. + ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + // + // Bucket hierarchy: + // + // fwdPackagesKey(root-bucket) + // | + // |-- + // | | + // | |-- + // | | |-- ackFilterKey: + // | | |-- settleFailFilterKey: + // | | |-- fwdFilterKey: + // | | | + // | | |-- addBucketKey + // | | | |-- : + // | | | |-- : + // | | | ... + // | | | + // | | |-- failSettleBucketKey + // | | |-- : + // | | |-- : + // | | ... + // | | + // | |-- + // | | | + // | ... ... + // | + // | + // |-- + // | | + // | ... + // ... + // + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +// FwdPackagesBucketKey returns the root-level bucket key that stores +// forwarding packages. +func FwdPackagesBucketKey() []byte { + return fwdPackagesKey +} + +// Encode serializes the AddRef to the given io.Writer. +func (a *AddRef) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { + return err + } + + return binary.Write(w, binary.BigEndian, a.Index) +} + +// Decode deserializes the AddRef from the given io.Reader. +func (a *AddRef) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { + return err + } + + return binary.Read(r, binary.BigEndian, &a.Index) +} + +// Size returns number of bytes produced when the PkgFilter is serialized. +func (f *PkgFilter) Size() uint16 { + // 2 bytes for uint16 `count`, then round up number of bytes required to + // represent `count` bits. + return 2 + (f.count+7)/8 +} + +// Encode writes the filter to the provided io.Writer. +func (f *PkgFilter) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, f.count); err != nil { + return err + } + + _, err := w.Write(f.filter) + + return err +} + +// Decode reads the filter from the provided io.Reader. +func (f *PkgFilter) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { + return err + } + + f.filter = make([]byte, f.Size()-2) + _, err := io.ReadFull(r, f.filter) + + return err +} + +// SettleFailAcker is a generic interface providing the ability to acknowledge +// settle/fail HTLCs stored in forwarding packages. +type SettleFailAcker interface { + // AckSettleFails atomically updates the settle-fail filters in *other* + // channels' forwarding packages. + AckSettleFails(tx kvdb.RwTx, settleFailRefs ...SettleFailRef) error +} + +// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages +// of any active channel. +type GlobalFwdPkgReader interface { + // LoadChannelFwdPkgs loads all known forwarding packages for the given + // channel. + LoadChannelFwdPkgs(tx kvdb.RTx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) +} + +// FwdOperator defines the interfaces for managing forwarding packages that are +// external to a particular channel. This interface is used by the switch to +// read forwarding packages from arbitrary channels, and acknowledge settles and +// fails for locally-sourced payments. +type FwdOperator interface { + // GlobalFwdPkgReader provides read access to all known forwarding + // packages + GlobalFwdPkgReader + + // SettleFailAcker grants the ability to acknowledge settles or fails + // residing in arbitrary forwarding packages. + SettleFailAcker +} + +// SwitchPackager is a concrete implementation of the FwdOperator interface. +// A SwitchPackager offers the ability to read any forwarding package, and ack +// arbitrary settle and fail HTLCs. +type SwitchPackager struct{} + +// NewSwitchPackager instantiates a new SwitchPackager. +func NewSwitchPackager() *SwitchPackager { + return &SwitchPackager{} +} + +// AckSettleFails atomically updates the settle-fail filters in *other* +// channels' forwarding packages, to mark that the switch has received a settle +// or fail residing in the forwarding package of a link. +func (*SwitchPackager) AckSettleFails(tx kvdb.RwTx, + settleFailRefs ...SettleFailRef) error { + + return ackSettleFails(tx, settleFailRefs) +} + +// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. +func (*SwitchPackager) LoadChannelFwdPkgs(tx kvdb.RTx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) { + + return loadChannelFwdPkgs(tx, source) +} + +// FwdPackager supports all operations required to modify fwd packages, such as +// creation, updates, reading, and removal. The interfaces are broken down in +// this way to support future delegation of the subinterfaces. +type FwdPackager interface { + // AddFwdPkg serializes and writes a FwdPkg for this channel at the + // remote commitment height included in the forwarding package. + AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error + + // SetFwdFilter looks up the forwarding package at the remote `height` + // and sets the `fwdFilter`, marking the Adds for which: + // 1) We are not the exit node + // 2) Passed all validation + // 3) Should be forwarded to the switch immediately after a failure + SetFwdFilter(tx kvdb.RwTx, height uint64, fwdFilter *PkgFilter) error + + // AckAddHtlcs atomically updates the add filters in this channel's + // forwarding packages to mark the resolution of an Add that was + // received from the remote party. + AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error + + // SettleFailAcker allows a link to acknowledge settle/fail HTLCs + // belonging to other channels. + SettleFailAcker + + // LoadFwdPkgs loads all known forwarding packages owned by this + // channel. + LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error) + + // RemovePkg deletes a forwarding package owned by this channel at + // the provided remote `height`. + RemovePkg(tx kvdb.RwTx, height uint64) error + + // Wipe deletes all the forwarding packages owned by this channel. + Wipe(tx kvdb.RwTx) error +} + +// ChannelPackager is used by a channel to manage the lifecycle of its +// forwarding packages. The packager is tied to a particular source channel ID, +// allowing it to create and edit its own packages. Each packager also has the +// ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +func newChannelPackager(channel *OpenChannel) *ChannelPackager { + return NewChannelPackager(channel.ShortChannelID) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (s *KVStore) LoadFwdPkgs(channel *OpenChannel) ([]*FwdPkg, + error) { + + var fwdPkgs []*FwdPkg + if err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + var err error + fwdPkgs, err = newChannelPackager(channel).LoadFwdPkgs(tx) + return err + }, func() { + fwdPkgs = nil + }); err != nil { + return nil, err + } + + return fwdPkgs, nil +} + +// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs +// indicating that a response to this Add has been committed to the remote +// party. Doing so will prevent these Add HTLCs from being reforwarded +// internally. +func (s *KVStore) AckAddHtlcs(channel *OpenChannel, + addRefs ...AddRef) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + return newChannelPackager(channel).AckAddHtlcs(tx, addRefs...) + }, func() {}) +} + +// AckSettleFails updates the SettleFailFilter containing any of the provided +// SettleFailRefs, indicating that the response has been delivered to the +// incoming link, corresponding to a particular AddRef. Doing so will prevent +// the responses from being retransmitted internally. +func (s *KVStore) AckSettleFails(channel *OpenChannel, + settleFailRefs ...SettleFailRef) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + return newChannelPackager(channel).AckSettleFails( + tx, settleFailRefs..., + ) + }, func() {}) +} + +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (s *KVStore) SetFwdFilter(channel *OpenChannel, height uint64, + fwdFilter *PkgFilter) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + return newChannelPackager(channel).SetFwdFilter( + tx, height, fwdFilter, + ) + }, func() {}) +} + +// RemoveFwdPkgs atomically removes forwarding packages specified by the remote +// commitment heights. If one of the intermediate RemovePkg calls fails, then +// the later packages won't be removed. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +func (s *KVStore) RemoveFwdPkgs(channel *OpenChannel, + heights ...uint64) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + packager := newChannelPackager(channel) + + for _, height := range heights { + err := packager.RemovePkg(tx, height) + if err != nil { + return err + } + } + + return nil + }, func() {}) +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { + fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) + if err != nil { + return err + } + + source := forwardingLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := forwardingLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists( + failSettleBucketKey, + ) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + err = heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) + if err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate( + failSettleBkt, uint16(i), &fwdPkg.SettleFails[i], + ) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *LogUpdate) error { + var b bytes.Buffer + if err := serializeLogUpdate(&b, htlc); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// serializeLogUpdate writes a log update to the provided io.Writer. +func serializeLogUpdate(w io.Writer, l *LogUpdate) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// deserializeLogUpdate reads a log update from the provided io.Reader. +func deserializeLogUpdate(r io.Reader) (*LogUpdate, error) { + l := &LogUpdate{} + if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { + return nil, err + } + + return l, nil +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx kvdb.RTx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) { + + fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := forwardingLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, + height uint64) (*FwdPkg, error) { + + sourceKey := forwardingLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := forwardingLogKey(height) + heightBkt := sourceBkt.NestedReadBucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.NestedReadBucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &FwdPkg{ + Source: source, + State: FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check if the forward filter has been persisted to disk. + // This indicates whether the Adds in this package have been processed. + // + // NOTE: We also expect packages with no Adds (settle/fail only packages + // or empty packages) to have the fwd filter set to signal that the + // packages have been processed. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + + // Handle packages with Adds that haven't been processed yet. + if fwdFilterBytes == nil { + // Create a new forward filter for the unprocessed Adds. + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = NewPkgFilter(nAdds) + + return fwdPkg, nil + } + + // Load the existing forward filter from disk. + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Mark the package as processed since the forward filter exists. + fwdPkg.State = FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt kvdb.RBucket) ([]LogUpdate, error) { + var htlcs []LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + htlc, err := deserializeLogUpdate(bytes.NewReader(v)) + if err != nil { + return err + } + + htlcs = append(htlcs, *htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// SetFwdFilter writes the set of indexes corresponding to Adds at the +// `height` that are to be forwarded to the switch. Calling this method causes +// the forwarding package at `height` to be in FwdStateProcessed. We write this +// forwarding decision so that we always arrive at the same behavior for HTLCs +// leaving this channel. After a restart, we skip validation of these Adds, +// since they are assumed to have already been validated, and make the switch or +// outgoing link responsible for handling replays. +func (p *ChannelPackager) SetFwdFilter(tx kvdb.RwTx, height uint64, + fwdFilter *PkgFilter) error { + + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + source := forwardingLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadWriteBucket(source[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := forwardingLogKey(height) + heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:]) + if heightBkt == nil { + return ErrCorruptedFwdPkg + } + + // If the fwd filter has already been written, we return early to avoid + // modifying the persistent state. + forwardedAddsBytes := heightBkt.Get(fwdFilterKey) + if forwardedAddsBytes != nil { + return nil + } + + // Otherwise we serialize and write the provided fwd filter. + var b bytes.Buffer + if err := fwdFilter.Encode(&b); err != nil { + return err + } + + return heightBkt.Put(fwdFilterKey, b.Bytes()) +} + +// AckAddHtlcs accepts a list of references to add htlcs, and updates the +// AckAddFilter of those forwarding packages to indicate that a settle or fail +// has been received in response to the add. +func (p *ChannelPackager) AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error { + if len(addRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + sourceKey := forwardingLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceKey[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique height. + heightDiffs := make(map[uint64][]uint16) + for _, addRef := range addRefs { + heightDiffs[addRef.Height] = append( + heightDiffs[addRef.Height], + addRef.Index, + ) + } + + // Load each height bucket once and remove all acked htlcs at that + // height. + for height, indexes := range heightDiffs { + err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) + if err != nil { + return err + } + } + + return nil +} + +// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package +// with a list of indexes, writing the resulting filter back in its place. +func ackAddHtlcsAtHeight(sourceBkt kvdb.RwBucket, height uint64, + indexes []uint16) error { + + heightKey := forwardingLogKey(height) + heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is successful, as there is nothing + // to ack. + return nil + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + ackFilter := &PkgFilter{} + ackFilterReader := bytes.NewReader(ackFilterBytes) + if err := ackFilter.Decode(ackFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + ackFilter.Set(index) + } + + // Write the resulting filter to disk. + var ackFilterBuf bytes.Buffer + if err := ackFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) +} + +// AckSettleFails persistently acknowledges settles or fails from a remote +// forwarding package. This should only be called after the source of the Add +// has locked in the settle/fail, or it becomes otherwise safe to forgo +// retransmitting the settle/fail after a restart. +func (p *ChannelPackager) AckSettleFails(tx kvdb.RwTx, + settleFailRefs ...SettleFailRef) error { + + return ackSettleFails(tx, settleFailRefs) +} + +// ackSettleFails persistently acknowledges a batch of settle fail references. +func ackSettleFails(tx kvdb.RwTx, settleFailRefs []SettleFailRef) error { + if len(settleFailRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique destination-height pair. + destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) + for _, settleFailRef := range settleFailRefs { + destHeights, ok := destHeightDiffs[settleFailRef.Source] + if !ok { + destHeights = make(map[uint64][]uint16) + destHeightDiffs[settleFailRef.Source] = destHeights + } + + destHeights[settleFailRef.Height] = append( + destHeights[settleFailRef.Height], + settleFailRef.Index, + ) + } + + // With the references organized by destination and height, we now load + // each remote bucket, and update the settle fail filter for any + // settle/fail htlcs. + for dest, destHeights := range destHeightDiffs { + destKey := forwardingLogKey(dest.ToUint64()) + destBkt := fwdPkgBkt.NestedReadWriteBucket(destKey[:]) + if destBkt == nil { + // If the destination bucket is not found, this is + // likely the result of the destination channel being + // closed and having it's forwarding packages wiped. We + // won't treat this as an error, because the response + // will no longer be retransmitted internally. + continue + } + + for height, indexes := range destHeights { + err := ackSettleFailsAtHeight(destBkt, height, indexes) + if err != nil { + return err + } + } + } + + return nil +} + +// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes +// at particular a height by updating the settle fail filter. +func ackSettleFailsAtHeight(destBkt kvdb.RwBucket, height uint64, + indexes []uint16) error { + + heightKey := forwardingLogKey(height) + heightBkt := destBkt.NestedReadWriteBucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is as there is nothing to ack. + return nil + } + + // Load ack filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + settleFailFilter := &PkgFilter{} + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + settleFailFilter.Set(index) + } + + // Write the resulting filter to disk. + var settleFailFilterBuf bytes.Buffer + if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// RemovePkg deletes the forwarding package at the given height from the +// packager's source bucket. +func (p *ChannelPackager) RemovePkg(tx kvdb.RwTx, height uint64) error { + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil + } + + sourceBytes := forwardingLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := forwardingLogKey(height) + + return sourceBkt.DeleteNestedBucket(heightKey[:]) +} + +// Wipe deletes all the channel's forwarding packages, if any. +func (p *ChannelPackager) Wipe(tx kvdb.RwTx) error { + // If the root bucket doesn't exist, there's no need to delete. + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil + } + + sourceBytes := forwardingLogKey(p.source.ToUint64()) + + // If the nested bucket doesn't exist, there's no need to delete. + if fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:]) == nil { + return nil + } + + return fwdPkgBkt.DeleteNestedBucket(sourceBytes[:]) +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// forwardingLogKey converts a uint64 into an 8 byte forwarding package key. +func forwardingLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +// Compile-time constraint to ensure that ChannelPackager implements the public +// FwdPackager interface. +var _ FwdPackager = (*ChannelPackager)(nil) + +// Compile-time constraint to ensure that SwitchPackager implements the public +// FwdOperator interface. +var _ FwdOperator = (*SwitchPackager)(nil) diff --git a/channeldb/forwarding_package_test.go b/chanstate/kv_forwarding_package_test.go similarity index 90% rename from channeldb/forwarding_package_test.go rename to chanstate/kv_forwarding_package_test.go index b11764bee93..9da6fc3196f 100644 --- a/channeldb/forwarding_package_test.go +++ b/chanstate/kv_forwarding_package_test.go @@ -1,4 +1,4 @@ -package channeldb_test +package chanstate_test import ( "bytes" @@ -7,7 +7,7 @@ import ( "testing" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" @@ -27,7 +27,7 @@ func TestPkgFilterBruteForce(t *testing.T) { // properly for all relevant sizes of `high`. func checkPkgFilterRange(t *testing.T, high int) { for i := uint16(0); i < uint16(high); i++ { - f := channeldb.NewPkgFilter(i) + f := chanstate.NewPkgFilter(i) if f.Count() != i { t.Fatalf("pkg filter count=%d is actually %d", @@ -74,7 +74,7 @@ func TestPkgFilterRand(t *testing.T) { // is parameterized by a base `b` coprime to `p`, and using modular // exponentiation to generate all elements in [1,p). func checkPkgFilterRand(t *testing.T, b, p uint16) { - f := channeldb.NewPkgFilter(p) + f := chanstate.NewPkgFilter(p) var j = b for i := uint16(1); i < p; i++ { if f.Contains(j) { @@ -113,7 +113,9 @@ func checkPkgFilterRand(t *testing.T, b, p uint16) { // 2. verifying the number of bytes written matches the filter's Size() // 3. reconstructing the filter decoding the bytes // 4. checking that the two filters are the same according to Equal -func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) { +func checkPkgFilterEncodeDecode(t *testing.T, i uint16, + f *chanstate.PkgFilter) { + var b bytes.Buffer if err := f.Encode(&b); err != nil { t.Fatalf("unable to serialize pkg filter: %v", err) @@ -128,7 +130,7 @@ func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) reader := bytes.NewReader(b.Bytes()) - f2 := &channeldb.PkgFilter{} + f2 := &chanstate.PkgFilter{} if err := f2.Decode(reader); err != nil { t.Fatalf("unable to deserialize pkg filter: %v", err) } @@ -144,8 +146,8 @@ var ( chanID = lnwire.NewChanIDFromOutPoint(wire.OutPoint{}) ) -func testSettleFails() []channeldb.LogUpdate { - return []channeldb.LogUpdate{ +func testSettleFails() []chanstate.LogUpdate { + return []chanstate.LogUpdate{ { LogIndex: 2, UpdateMsg: &lnwire.UpdateFulfillHTLC{ @@ -165,8 +167,8 @@ func testSettleFails() []channeldb.LogUpdate { } } -func testAdds() []channeldb.LogUpdate { - return []channeldb.LogUpdate{ +func testAdds() []chanstate.LogUpdate { + return []chanstate.LogUpdate{ { LogIndex: 0, UpdateMsg: &lnwire.UpdateAddHTLC{ @@ -200,7 +202,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -209,7 +211,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) { } // Next, create and write a new forwarding package with no htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) + fwdPkg := chanstate.NewFwdPkg(shortChanID, 0, nil, nil) if err := kvdb.Update(db, func(tx kvdb.RwTx) error { return packager.AddFwdPkg(tx, fwdPkg) @@ -224,7 +226,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateLockedIn) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -243,7 +245,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateCompleted) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -269,7 +271,7 @@ func TestPackagerOnlyAdds(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -281,7 +283,7 @@ func TestPackagerOnlyAdds(t *testing.T) { // Next, create and write a new forwarding package that only has add // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil) + fwdPkg := chanstate.NewFwdPkg(shortChanID, 0, adds, nil) nAdds := len(adds) @@ -298,7 +300,7 @@ func TestPackagerOnlyAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateLockedIn) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) assertAckFilterIsFull(t, fwdPkgs[0], false) @@ -321,11 +323,11 @@ func TestPackagerOnlyAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) assertAckFilterIsFull(t, fwdPkgs[0], false) - addRef := channeldb.AddRef{ + addRef := chanstate.AddRef{ Height: fwdPkg.Height, Index: uint16(i), } @@ -344,7 +346,7 @@ func TestPackagerOnlyAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateCompleted) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -373,7 +375,7 @@ func TestPackagerOnlySettleFails(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -384,7 +386,7 @@ func TestPackagerOnlySettleFails(t *testing.T) { // Next, create and write a new forwarding package that only has add // htlcs. settleFails := testSettleFails() - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails) + fwdPkg := chanstate.NewFwdPkg(shortChanID, 0, nil, settleFails) nSettleFails := len(settleFails) @@ -401,7 +403,7 @@ func TestPackagerOnlySettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateLockedIn) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -424,12 +426,12 @@ func TestPackagerOnlySettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], false) assertAckFilterIsFull(t, fwdPkgs[0], true) - failSettleRef := channeldb.SettleFailRef{ + failSettleRef := chanstate.SettleFailRef{ Source: shortChanID, Height: fwdPkg.Height, Index: uint16(i), @@ -449,7 +451,7 @@ func TestPackagerOnlySettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateCompleted) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], true) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -478,7 +480,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -491,7 +493,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { // Next, create and write a new forwarding package that only has add // htlcs. settleFails := testSettleFails() - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + fwdPkg := chanstate.NewFwdPkg(shortChanID, 0, adds, settleFails) nAdds := len(adds) nSettleFails := len(settleFails) @@ -509,7 +511,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateLockedIn) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertAckFilterIsFull(t, fwdPkgs[0], false) @@ -532,12 +534,12 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], false) assertAckFilterIsFull(t, fwdPkgs[0], false) - addRef := channeldb.AddRef{ + addRef := chanstate.AddRef{ Height: fwdPkg.Height, Index: uint16(i), } @@ -558,12 +560,12 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], false) assertAckFilterIsFull(t, fwdPkgs[0], true) - failSettleRef := channeldb.SettleFailRef{ + failSettleRef := chanstate.SettleFailRef{ Source: shortChanID, Height: fwdPkg.Height, Index: uint16(i), @@ -583,7 +585,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateCompleted) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], true) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -614,7 +616,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -627,7 +629,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { // Next, create and write a new forwarding package that has both add // and settle/fail htlcs. settleFails := testSettleFails() - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + fwdPkg := chanstate.NewFwdPkg(shortChanID, 0, adds, settleFails) nAdds := len(adds) nSettleFails := len(settleFails) @@ -645,7 +647,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateLockedIn) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertAckFilterIsFull(t, fwdPkgs[0], false) @@ -671,12 +673,12 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], false) assertAckFilterIsFull(t, fwdPkgs[0], false) - failSettleRef := channeldb.SettleFailRef{ + failSettleRef := chanstate.SettleFailRef{ Source: shortChanID, Height: fwdPkg.Height, Index: uint16(i), @@ -699,12 +701,12 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateProcessed) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], true) assertAckFilterIsFull(t, fwdPkgs[0], false) - addRef := channeldb.AddRef{ + addRef := chanstate.AddRef{ Height: fwdPkg.Height, Index: uint16(i), } @@ -723,7 +725,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) { if len(fwdPkgs) != 1 { t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgState(t, fwdPkgs[0], chanstate.FwdStateCompleted) assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) assertSettleFailFilterIsFull(t, fwdPkgs[0], true) assertAckFilterIsFull(t, fwdPkgs[0], true) @@ -750,7 +752,7 @@ func TestPackagerWipeAll(t *testing.T) { db := makeFwdPkgDB(t, "") shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) + packager := chanstate.NewChannelPackager(shortChanID) // To begin, there should be no forwarding packages on disk. fwdPkgs := loadFwdPkgs(t, db, packager) @@ -761,8 +763,8 @@ func TestPackagerWipeAll(t *testing.T) { require.NoError(t, err, "unable to wipe fwdpkg") // Next, create and write two forwarding packages with no htlcs. - fwdPkg1 := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) - fwdPkg2 := channeldb.NewFwdPkg(shortChanID, 1, nil, nil) + fwdPkg1 := chanstate.NewFwdPkg(shortChanID, 0, nil, nil) + fwdPkg2 := chanstate.NewFwdPkg(shortChanID, 1, nil, nil) err = kvdb.Update(db, func(tx kvdb.RwTx) error { if err := packager.AddFwdPkg(tx, fwdPkg2); err != nil { @@ -787,8 +789,9 @@ func TestPackagerWipeAll(t *testing.T) { // assertFwdPkgState checks the current state of a fwdpkg meets our // expectations. -func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, - state channeldb.FwdState) { +func assertFwdPkgState(t *testing.T, fwdPkg *chanstate.FwdPkg, + state chanstate.FwdState) { + _, _, line, _ := runtime.Caller(1) if fwdPkg.State != state { t.Fatalf("line %d: expected fwdpkg in state %v, found %v", @@ -798,7 +801,7 @@ func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, // assertFwdPkgNumAddsSettleFails checks that the number of adds and // settle/fail log updates are correct. -func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, +func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *chanstate.FwdPkg, expectedNumAdds, expectedNumSettleFails int) { _, _, line, _ := runtime.Caller(1) if len(fwdPkg.Adds) != expectedNumAdds { @@ -814,7 +817,9 @@ func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, // assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our // expected full-ness. -func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { +func assertAckFilterIsFull(t *testing.T, fwdPkg *chanstate.FwdPkg, + expected bool) { + _, _, line, _ := runtime.Caller(1) if fwdPkg.AckFilter.IsFull() != expected { t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+ @@ -824,7 +829,9 @@ func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool // assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail // filter matches our expected full-ness. -func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { +func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *chanstate.FwdPkg, + expected bool) { + _, _, line, _ := runtime.Caller(1) if fwdPkg.SettleFailFilter.IsFull() != expected { t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+ @@ -835,9 +842,9 @@ func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expect // loadFwdPkgs is a helper method that reads all forwarding packages for a // particular packager. func loadFwdPkgs(t *testing.T, db kvdb.Backend, - packager channeldb.FwdPackager) []*channeldb.FwdPkg { + packager chanstate.FwdPackager) []*chanstate.FwdPkg { - var fwdPkgs []*channeldb.FwdPkg + var fwdPkgs []*chanstate.FwdPkg if err := kvdb.View(db, func(tx kvdb.RTx) error { var err error fwdPkgs, err = packager.LoadFwdPkgs(tx) diff --git a/chanstate/kv_open_channel.go b/chanstate/kv_open_channel.go new file mode 100644 index 00000000000..7028a048b5b --- /dev/null +++ b/chanstate/kv_open_channel.go @@ -0,0 +1,2150 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/fn/v2" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // openChannelBucket stores all the currently open channels. This bucket + // has a second, nested bucket which is keyed by a node's ID. Within + // that node ID bucket, all attributes required to track, update, and + // close a channel are stored. + // + // openChan -> nodeID -> chanPoint + // + // TODO(roasbeef): flesh out comment. + openChannelBucket = []byte("open-chan-bucket") + + // outpointBucket stores all of our channel outpoints and a tlv + // stream containing channel data. + // + // outpoint -> tlv stream. + // + outpointBucket = []byte("outpoint-bucket") + + // chanIDBucket stores all of the 32-byte channel ID's we know about. + // These could be derived from outpointBucket, but it is more + // convenient to have these in their own bucket. + // + // chanID -> tlv stream. + // + chanIDBucket = []byte("chan-id-bucket") + + // historicalChannelBucket stores all channels that have seen their + // commitment tx confirm. All information from their previous open state + // is retained. + historicalChannelBucket = []byte("historical-chan-bucket") + + // chanInfoKey can be accessed within the bucket for a channel + // (identified by its chanPoint). This key stores all the static + // information for a channel which is decided at the end of the + // funding flow. + chanInfoKey = []byte("chan-info-key") + + // localUpfrontShutdownKey can be accessed within the bucket for a + // channel (identified by its chanPoint). This key stores an optional + // upfront shutdown script for the local peer. + localUpfrontShutdownKey = []byte("local-upfront-shutdown-key") + + // remoteUpfrontShutdownKey can be accessed within the bucket for a + // channel (identified by its chanPoint). This key stores an optional + // upfront shutdown script for the remote peer. + remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key") + + // frozenChanKey is the key where we store the information for any + // active "frozen" channels. This key is present only in the leaf + // bucket for a given channel. + frozenChanKey = []byte("frozen-chans") + + // dataLossCommitPointKey stores the commitment point received from the + // remote peer during a channel sync in case we have lost channel state. + dataLossCommitPointKey = []byte("data-loss-commit-point-key") +) + +// OpenChannelBucketKey returns the top-level open-channel bucket key. +func OpenChannelBucketKey() []byte { + return openChannelBucket +} + +// OutpointBucketKey returns the top-level outpoint index bucket key. +func OutpointBucketKey() []byte { + return outpointBucket +} + +// ChanIDBucketKey returns the top-level channel ID index bucket key. +func ChanIDBucketKey() []byte { + return chanIDBucket +} + +// HistoricalChannelBucketKey returns the top-level historical channel bucket +// key. +func HistoricalChannelBucketKey() []byte { + return historicalChannelBucket +} + +// ChanInfoKey returns the channel-bucket key for static channel information. +func ChanInfoKey() []byte { + return chanInfoKey +} + +// LocalUpfrontShutdownKey returns the channel-bucket key for the local upfront +// shutdown script. +func LocalUpfrontShutdownKey() []byte { + return localUpfrontShutdownKey +} + +// RemoteUpfrontShutdownKey returns the channel-bucket key for the remote +// upfront shutdown script. +func RemoteUpfrontShutdownKey() []byte { + return remoteUpfrontShutdownKey +} + +// FrozenChanKey returns the key used to store a channel's thaw height. +func FrozenChanKey() []byte { + return frozenChanKey +} + +// DataLossCommitPointKey returns the key used to store the data-loss commit +// point in a channel bucket. +func DataLossCommitPointKey() []byte { + return dataLossCommitPointKey +} + +const ( + // A tlv type definition used to serialize an outpoint's indexStatus + // for use in the outpoint index. + indexStatusType tlv.Type = 0 + + // IndexStatusType is the TLV type used to serialize an outpoint's + // indexStatus for use in the outpoint index. + IndexStatusType = indexStatusType +) + +// indexStatus is an enum-like type that describes what state the outpoint is +// in. Currently only two possible values. +type indexStatus uint8 + +// IndexStatus is an enum-like type that describes what state the outpoint is +// in. Currently only two possible values. +type IndexStatus = indexStatus + +const ( + // outpointOpen represents an outpoint that is open in the outpoint + // index. + outpointOpen indexStatus = 0 + + // OutpointOpen represents an outpoint that is open in the outpoint + // index. + OutpointOpen = outpointOpen + + // outpointClosed represents an outpoint that is closed in the outpoint + // index. + outpointClosed indexStatus = 1 + + // OutpointClosed represents an outpoint that is closed in the outpoint + // index. + OutpointClosed = outpointClosed +) + +func putOutpointIndexStatus(opBucket kvdb.RwBucket, chanKey []byte, + status indexStatus) error { + + statusByte := uint8(status) + statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &statusByte) + opStream, err := tlv.NewStream(statusRecord) + if err != nil { + return err + } + + var b bytes.Buffer + if err := opStream.Encode(&b); err != nil { + return err + } + + return opBucket.Put(chanKey, b.Bytes()) +} + +// PutOpenOutpointIndex stores chanKey in the outpoint index as an open +// outpoint. +func PutOpenOutpointIndex(opBucket kvdb.RwBucket, chanKey []byte) error { + return putOutpointIndexStatus(opBucket, chanKey, outpointOpen) +} + +// UpdateClosedOutpointIndex flips the outpoint index entry for chanKey from +// open to closed. The index entry must already exist; it was placed there when +// the channel was opened. +func UpdateClosedOutpointIndex(tx kvdb.RwTx, chanKey []byte) error { + opBucket := tx.ReadWriteBucket(outpointBucket) + if opBucket == nil { + return ErrNoChanDBExists + } + if opBucket.Get(chanKey) == nil { + return ErrMissingIndexEntry + } + + return putOutpointIndexStatus(opBucket, chanKey, outpointClosed) +} + +// IsOutpointClosed reports whether the supplied chanKey has been flipped to +// outpointClosed in the supplied outpointBucket. The flip is performed in the +// same transaction as the rest of CloseChannel (sync and tombstone paths +// alike), so a true result is the authoritative "this channel went through +// CloseChannel" signal. On tombstone-enabled backends the chanBucket may still +// exist on disk; readers consult this helper to skip those entries. Callers +// fetch outpointBucket once and pass it in, which lets loop-style readers +// hoist the bucket lookup out of the inner loop. +func IsOutpointClosed(opBucket kvdb.RBucket, chanKey []byte) (bool, error) { + if opBucket == nil { + return false, nil + } + raw := opBucket.Get(chanKey) + if raw == nil { + return false, nil + } + + var status uint8 + statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status) + stream, err := tlv.NewStream(statusRecord) + if err != nil { + return false, err + } + if err := stream.Decode(bytes.NewReader(raw)); err != nil { + return false, fmt.Errorf("decode outpoint status for "+ + "chan_key=%x: %w", chanKey, err) + } + + return indexStatus(status) == outpointClosed, nil +} + +// FetchChanBucket is a helper function that returns the bucket where a +// channel's data resides in given: the public key for the node, the outpoint, +// and the chainhash that the channel resides on. +func FetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, + outPoint *wire.OutPoint, chainHash chainhash.Hash) ( + kvdb.RBucket, error) { + + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.ReadBucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists + } + + // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like + // CreateIfNotExists, will return error. + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadBucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.NestedReadBucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanKey := chanPointBuf.Bytes() + + // Treat already-closed channels as gone. The chanBucket may still + // exist on tombstone-enabled backends; the outpoint flip is the + // source of truth. + closed, err := IsOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) + if err != nil { + return nil, err + } + if closed { + return nil, ErrChannelNotFound + } + + chanBucket := chainBucket.NestedReadBucket(chanKey) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil +} + +// FetchChanBucketRw is a helper function that returns the bucket where a +// channel's data resides in given: the public key for the node, the outpoint, +// and the chainhash that the channel resides on. This differs from +// FetchChanBucket in that it returns a writeable bucket. +func FetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, + outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, + error) { + + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists + } + + // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like + // CreateIfNotExists, will return error. + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanKey := chanPointBuf.Bytes() + + // Treat already-closed channels as gone. The chanBucket may still + // exist on tombstone-enabled backends; the outpoint flip is the + // source of truth. + closed, err := IsOutpointClosed(tx.ReadBucket(outpointBucket), chanKey) + if err != nil { + return nil, err + } + if closed { + return nil, ErrChannelNotFound + } + + chanBucket := chainBucket.NestedReadWriteBucket(chanKey) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil +} + +// FetchThawHeight fetches a channel's thaw height from the channel bucket. +func FetchThawHeight(chanBucket kvdb.RBucket) (uint32, error) { + var height uint32 + + heightBytes := chanBucket.Get(frozenChanKey) + heightReader := bytes.NewReader(heightBytes) + + if err := binary.Read(heightReader, byteOrder, &height); err != nil { + return 0, err + } + + return height, nil +} + +// StoreThawHeight stores a channel's thaw height in the channel bucket. +func StoreThawHeight(chanBucket kvdb.RwBucket, height uint32) error { + var heightBuf bytes.Buffer + if err := binary.Write(&heightBuf, byteOrder, height); err != nil { + return err + } + + return chanBucket.Put(frozenChanKey, heightBuf.Bytes()) +} + +// DeleteThawHeight deletes a channel's thaw height from the channel bucket. +func DeleteThawHeight(chanBucket kvdb.RwBucket) error { + return chanBucket.Delete(frozenChanKey) +} + +// FullSyncOpenChannel syncs the contents of an OpenChannel while re-using an +// existing database transaction. +func FullSyncOpenChannel(tx kvdb.RwTx, c *OpenChannel) error { + // Fetch the outpoint bucket and check if the outpoint already exists. + opBucket := tx.ReadWriteBucket(outpointBucket) + if opBucket == nil { + return ErrNoChanDBExists + } + cidBucket := tx.ReadWriteBucket(chanIDBucket) + if cidBucket == nil { + return ErrNoChanDBExists + } + + var chanPointBuf bytes.Buffer + err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) + if err != nil { + return err + } + + // Now, check if the outpoint exists in our index. + if opBucket.Get(chanPointBuf.Bytes()) != nil { + return ErrChanAlreadyExists + } + + cid := lnwire.NewChanIDFromOutPoint(c.FundingOutpoint) + if cidBucket.Get(cid[:]) != nil { + return ErrChanAlreadyExists + } + + // Add the outpoint to our outpoint index with the tlv stream. + if err := PutOpenOutpointIndex( + opBucket, chanPointBuf.Bytes(), + ); err != nil { + return err + } + + if err := cidBucket.Put(cid[:], []byte{}); err != nil { + return err + } + + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) + if err != nil { + return err + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := c.IdentityPub.SerializeCompressed() + nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) + if err != nil { + return err + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket, err := nodeChanBucket.CreateBucketIfNotExists( + c.ChainHash[:], + ) + if err != nil { + return err + } + + // With the bucket for the node fetched, we can now go down another + // level, creating the bucket for this channel itself. + chanBucket, err := chainBucket.CreateBucket( + chanPointBuf.Bytes(), + ) + switch { + case errors.Is(err, kvdb.ErrBucketExists): + // If this channel already exists, then in order to avoid + // overriding it, we'll return an error back up to the caller. + return ErrChanAlreadyExists + case err != nil: + return err + } + + return PutOpenChannel(chanBucket, c) +} + +// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the +// tlv.RecordProducer interface. +type keyLocRecord struct { + keychain.KeyLocator +} + +// Record creates a Record out of a KeyLocator using the passed Type and the +// EKeyLocator and DKeyLocator functions. The size will always be 8 as +// KeyFamily is uint32 and the Index is uint32. +// +// NOTE: This is part of the tlv.RecordProducer interface. +func (k *keyLocRecord) Record() tlv.Record { + // Note that we set the type here as zero, as when used with a + // tlv.RecordT, the type param will be used as the type. + return tlv.MakeStaticRecord( + 0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator, + ) +} + +// EKeyLocator is an encoder for keychain.KeyLocator. +func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*keychain.KeyLocator); ok { + err := tlv.EUint32T(w, uint32(v.Family), buf) + if err != nil { + return err + } + + return tlv.EUint32T(w, v.Index, buf) + } + + return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator") +} + +// DKeyLocator is a decoder for keychain.KeyLocator. +func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*keychain.KeyLocator); ok { + var family uint32 + err := tlv.DUint32(r, &family, buf, 4) + if err != nil { + return err + } + v.Family = keychain.KeyFamily(family) + + return tlv.DUint32(r, &v.Index, buf, 4) + } + + return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8) +} + +// WriteChanConfig serializes a channel config. +func WriteChanConfig(b io.Writer, c *ChannelConfig) error { + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +// ReadChanConfig deserializes a channel config. +func ReadChanConfig(b io.Reader, c *ChannelConfig) error { + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +// PutChanInfo serializes the static channel info into the channel bucket. +func PutChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error { + var w bytes.Buffer + if err := WriteElements(&w, + channel.ChanType, channel.ChainHash, channel.FundingOutpoint, + channel.ShortChannelID, channel.IsPending, channel.IsInitiator, + channel.ChannelStatusForStore(), channel.FundingBroadcastHeight, + channel.NumConfsRequired, channel.ChannelFlags, + channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, + channel.TotalMSatReceived, + ); err != nil { + return err + } + + // For single funder channels that we initiated, and we have the + // funding transaction, then write the funding txn. + if channel.FundingTxPresent() { + if err := WriteElement(&w, channel.FundingTxn); err != nil { + return err + } + } + + if err := WriteChanConfig(&w, &channel.LocalChanCfg); err != nil { + return err + } + if err := WriteChanConfig(&w, &channel.RemoteChanCfg); err != nil { + return err + } + + if err := EncodeOpenChannelTlvData(&w, channel); err != nil { + return fmt.Errorf("unable to encode aux data: %w", err) + } + + if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { + return err + } + + // Finally, add optional shutdown scripts for the local and remote peer + // if they are present. + if err := putOptionalUpfrontShutdownScript( + chanBucket, localUpfrontShutdownKey, + channel.LocalShutdownScript, + ); err != nil { + return err + } + + return putOptionalUpfrontShutdownScript( + chanBucket, remoteUpfrontShutdownKey, + channel.RemoteShutdownScript, + ) +} + +// putOptionalUpfrontShutdownScript adds a shutdown script under the key +// provided if it has a non-zero length. +func putOptionalUpfrontShutdownScript(chanBucket kvdb.RwBucket, key []byte, + script []byte) error { + + // If the script is empty, we do not need to add anything. + if len(script) == 0 { + return nil + } + + var w bytes.Buffer + if err := WriteElement(&w, script); err != nil { + return err + } + + return chanBucket.Put(key, w.Bytes()) +} + +// FetchChanInfo deserializes the static channel info from the channel bucket. +func FetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error { + infoBytes := chanBucket.Get(chanInfoKey) + if infoBytes == nil { + return ErrNoChanInfoFound + } + r := bytes.NewReader(infoBytes) + + var chanStatus ChannelStatus + if err := ReadElements(r, + &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, + &channel.ShortChannelID, &channel.IsPending, + &channel.IsInitiator, + &chanStatus, &channel.FundingBroadcastHeight, + &channel.NumConfsRequired, &channel.ChannelFlags, + &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, + &channel.TotalMSatReceived, + ); err != nil { + return err + } + channel.SetChannelStatusForStore(chanStatus) + + // For single funder channels that we initiated and have the funding + // transaction to, read the funding txn. + if channel.FundingTxPresent() { + if err := ReadElement(r, &channel.FundingTxn); err != nil { + return err + } + } + + if err := ReadChanConfig(r, &channel.LocalChanCfg); err != nil { + return err + } + if err := ReadChanConfig(r, &channel.RemoteChanCfg); err != nil { + return err + } + + // Retrieve the boolean stored under lastWasRevokeKey. + lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey) + if lastWasRevokeBytes == nil { + // If nothing has been stored under this key, we store false + // in the OpenChannel struct. + channel.LastWasRevoke = false + } else { + // Otherwise, read the value into the LastWasRevoke field. + revokeReader := bytes.NewReader(lastWasRevokeBytes) + err := ReadElements(revokeReader, &channel.LastWasRevoke) + if err != nil { + return err + } + } + + if err := DecodeOpenChannelTlvData(r, channel); err != nil { + return fmt.Errorf("unable to decode aux data: %w", err) + } + + // Finally, read the optional shutdown scripts. + if err := getOptionalUpfrontShutdownScript( + chanBucket, localUpfrontShutdownKey, + &channel.LocalShutdownScript, + ); err != nil { + return err + } + + return getOptionalUpfrontShutdownScript( + chanBucket, remoteUpfrontShutdownKey, + &channel.RemoteShutdownScript, + ) +} + +// getOptionalUpfrontShutdownScript reads the shutdown script stored under the +// key provided if it is present. Upfront shutdown scripts are optional, so the +// function returns with no error if the key is not present. +func getOptionalUpfrontShutdownScript(chanBucket kvdb.RBucket, key []byte, + script *lnwire.DeliveryAddress) error { + + // Return early if the bucket does not exit, a shutdown script was not + // set. + bs := chanBucket.Get(key) + if bs == nil { + return nil + } + + var tempScript []byte + r := bytes.NewReader(bs) + if err := ReadElement(r, &tempScript); err != nil { + return err + } + *script = tempScript + + return nil +} + +// PutOpenChannel serializes, and stores the current state of the channel in +// its entirety. +func PutOpenChannel(chanBucket kvdb.RwBucket, channel *OpenChannel) error { + // First, we'll write out all the relatively static fields, that are + // decided upon initial channel creation. + if err := PutChanInfo(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan info: %w", err) + } + + // With the static channel info written out, we'll now write out the + // current commitment state for both parties. + if err := PutChanCommitments(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan commitments: %w", err) + } + + // Next, if this is a frozen channel, we'll add in the axillary + // information we need to store. + if channel.ChanType.IsFrozen() || + channel.ChanType.HasLeaseExpiration() { + + err := StoreThawHeight( + chanBucket, channel.ThawHeight, + ) + if err != nil { + return fmt.Errorf("unable to store thaw height: %w", + err) + } + } + + // Finally, we'll write out the revocation state for both parties + // within a distinct key space. + if err := PutChanRevocationState(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan revocations: %w", err) + } + + return nil +} + +// FetchOpenChannel retrieves, and deserializes (including decrypting +// sensitive) the complete channel currently active with the passed nodeID. +func FetchOpenChannel(chanBucket kvdb.RBucket, + chanPoint *wire.OutPoint) (*OpenChannel, error) { + + channel := &OpenChannel{ + FundingOutpoint: *chanPoint, + } + + // First, we'll read all the static information that changes less + // frequently from disk. + if err := FetchChanInfo(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan info: %w", err) + } + + // With the static information read, we'll now read the current + // commitment state for both sides of the channel. + if err := FetchChanCommitments(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan commitments: %w", + err) + } + + // Next, if this is a frozen channel, we'll add in the axillary + // information we need to store. + if channel.ChanType.IsFrozen() || + channel.ChanType.HasLeaseExpiration() { + + thawHeight, err := FetchThawHeight(chanBucket) + if err != nil { + return nil, fmt.Errorf("unable to store thaw "+ + "height: %v", err) + } + + channel.ThawHeight = thawHeight + } + + // Finally, we'll retrieve the current revocation state so we can + // properly + if err := FetchChanRevocationState(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan revocations: %w", + err) + } + + return channel, nil +} + +// SyncPendingOpenChannel writes a pending channel to the store and records the +// funding broadcast height using an existing database transaction. +func SyncPendingOpenChannel(tx kvdb.RwTx, channel *OpenChannel, + pendingHeight uint32) error { + + channel.FundingBroadcastHeight = pendingHeight + + return FullSyncOpenChannel(tx, channel) +} + +// FetchOpenChannels starts a new database transaction and returns all stored +// currently active/open channels associated with the target nodeID. In the +// case that no active channels are known to have been created with this node, +// then a zero-length slice is returned. +func (s *KVStore) FetchOpenChannels(nodeID *btcec.PublicKey) ( + []*OpenChannel, error) { + + var channels []*OpenChannel + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + var err error + channels, err = FetchOpenChannelsTx(tx, nodeID) + + return err + }, func() { + channels = nil + }) + + return channels, err +} + +// FetchOpenChannelsTx uses an existing database transaction and returns all +// stored currently active/open channels associated with the target nodeID. In +// the case that no active channels are known to have been created with this +// node, then a zero-length slice is returned. +func FetchOpenChannelsTx(tx kvdb.RTx, + nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + + // Get the bucket dedicated to storing the metadata for open channels. + openChanBucket := tx.ReadBucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + pub := nodeID.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadBucket(pub) + if nodeChanBucket == nil { + return nil, nil + } + + // Next, we'll need to go down an additional layer in order to retrieve + // the channels for each chain the node knows of. + var channels []*OpenChannel + err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, then we'll retrieve + // that so we can extract all the channels. + chainBucket := nodeChanBucket.NestedReadBucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read bucket for chain=%x", + chainHash) + } + + // Finally, we both of the necessary buckets retrieved, fetch + // all the active channels related to this node. + nodeChannels, err := FetchNodeChannels(tx, chainBucket) + if err != nil { + return fmt.Errorf("unable to read channel for "+ + "chain_hash=%x, node_key=%x: %v", + chainHash, pub, err) + } + + channels = append(channels, nodeChannels...) + + return nil + }) + + return channels, err +} + +// FetchNodeChannels retrieves all active channels from the target chainBucket +// which is under a node's dedicated channel bucket. This function is typically +// used to fetch all the active channels related to a particular node. Channels +// already flipped to outpointClosed in the outpoint index are skipped silently +// — readers see only channels that are still considered open. +func FetchNodeChannels(tx kvdb.RTx, + chainBucket kvdb.RBucket) ([]*OpenChannel, error) { + + var channels []*OpenChannel + + // Hoist the outpoint-bucket lookup so the closed-channel check inside + // the loop is a per-iteration map probe rather than a tx-level bucket + // resolve. + opBucket := tx.ReadBucket(outpointBucket) + + // A node may have channels on several chains, so for each known chain, + // we'll extract all the channels. + err := chainBucket.ForEach(func(chanPoint, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // Skip already-closed channels. The chanBucket still exists + // on disk on tombstone-enabled backends; the outpoint flip is + // the sole signal that the channel should be treated as + // closed. + isClosed, err := IsOutpointClosed(opBucket, chanPoint) + if err != nil { + return err + } + if isClosed { + return nil + } + + // Once we've found a valid channel bucket, we'll extract it + // from the node's chain bucket. + chanBucket := chainBucket.NestedReadBucket(chanPoint) + + var outPoint wire.OutPoint + err = graphdb.ReadOutpoint( + bytes.NewReader(chanPoint), &outPoint, + ) + if err != nil { + return err + } + oChannel, err := FetchOpenChannel(chanBucket, &outPoint) + if err != nil { + return fmt.Errorf("unable to read channel data for "+ + "chan_point=%v: %w", outPoint, err) + } + + channels = append(channels, oChannel) + + return nil + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchChannel attempts to locate a channel specified by the passed channel +// point. If the channel cannot be found, then an error will be returned. +func (s *KVStore) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, + error) { + + var targetChanPoint bytes.Buffer + err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint) + if err != nil { + return nil, err + } + + targetChanPointBytes := targetChanPoint.Bytes() + selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) { + + return targetChanPointBytes, &chanPoint, nil + } + + return s.channelScanner(nil, selector) +} + +// FetchChannelByID attempts to locate a channel specified by the passed channel +// ID. If the channel cannot be found, then an error will be returned. +func (s *KVStore) FetchChannelByID(id lnwire.ChannelID) (*OpenChannel, + error) { + + selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint, + error) { + + var ( + targetChanPointBytes []byte + targetChanPoint *wire.OutPoint + + // errChanFound is used to signal that the channel has + // been found so that iteration through the DB buckets + // can stop. + errChanFound = errors.New("channel found") + ) + err := chainBkt.ForEach(func(k, _ []byte) error { + var outPoint wire.OutPoint + err := graphdb.ReadOutpoint( + bytes.NewReader(k), &outPoint, + ) + if err != nil { + return err + } + + chanID := lnwire.NewChanIDFromOutPoint(outPoint) + if chanID != id { + return nil + } + + targetChanPoint = &outPoint + targetChanPointBytes = k + + return errChanFound + }) + if err != nil && !errors.Is(err, errChanFound) { + return nil, nil, err + } + if targetChanPoint == nil { + return nil, nil, ErrChannelNotFound + } + + return targetChanPointBytes, targetChanPoint, nil + } + + return s.channelScanner(nil, selector) +} + +// FetchPermAndTempPeers returns a map where the key is the remote node's +// public key and the value is a struct that has a tally of the pending-open +// channels and whether the peer has an open or closed channel with us. +func (s *KVStore) FetchPermAndTempPeers( + chainHash []byte) (map[string]ChanCount, error) { + + peerChanInfo := make(map[string]ChanCount) + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + openChanBucket := tx.ReadBucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoChanDBExists + } + + // Hoist the outpoint-bucket lookup so the closed-channel check + // inside the nested chainBucket.ForEach below is a per-channel + // map probe rather than a tx-level bucket resolve. + opBucket := tx.ReadBucket(outpointBucket) + + openChanErr := openChanBucket.ForEach(func(nodePub, + v []byte) error { + + // If there is a value, this is not a bucket. + if v != nil { + return nil + } + + nodeChanBucket := openChanBucket.NestedReadBucket( + nodePub, + ) + if nodeChanBucket == nil { + return nil + } + + chainBucket := nodeChanBucket.NestedReadBucket( + chainHash, + ) + if chainBucket == nil { + return fmt.Errorf("no chain bucket exists") + } + + var isPermPeer bool + var pendingOpenCount uint64 + + internalErr := chainBucket.ForEach(func(chanPoint, + val []byte) error { + + // If there is a value, this is not a bucket. + if val != nil { + return nil + } + + // Skip already-closed channels: they are + // logically closed even though their + // per-channel state still resides under + // chainBucket. The closed peer's protected + // status is established below via the + // historical-channel scan. + isClosed, err := IsOutpointClosed( + opBucket, chanPoint, + ) + if err != nil { + return err + } + if isClosed { + return nil + } + + chanBucket := chainBucket.NestedReadBucket( + chanPoint, + ) + if chanBucket == nil { + return nil + } + + var op wire.OutPoint + readErr := graphdb.ReadOutpoint( + bytes.NewReader(chanPoint), &op, + ) + if readErr != nil { + return readErr + } + + // We need to go through each channel and look + // at the IsPending status. + openChan, err := FetchOpenChannel( + chanBucket, &op, + ) + if err != nil { + return err + } + + if openChan.IsPending { + // Add to the pending-open count since + // this is a temp peer. + pendingOpenCount++ + return nil + } + + // Since IsPending is false, this is a perm + // peer. + isPermPeer = true + + return nil + }) + if internalErr != nil { + return internalErr + } + + peerCount := ChanCount{ + HasOpenOrClosedChan: isPermPeer, + PendingOpenCount: pendingOpenCount, + } + peerChanInfo[string(nodePub)] = peerCount + + return nil + }) + if openChanErr != nil { + return openChanErr + } + + // Now check the closed channel bucket. + historicalChanBucket := tx.ReadBucket(historicalChannelBucket) + if historicalChanBucket == nil { + return ErrNoHistoricalBucket + } + + historicalErr := historicalChanBucket.ForEach(func(chanPoint, + v []byte) error { + + // Parse each nested bucket and the chanInfoKey to get + // the IsPending bool. This determines whether the + // peer is protected or not. + if v != nil { + // This is not a bucket. This is currently not + // possible. + return nil + } + + chanBucket := historicalChanBucket.NestedReadBucket( + chanPoint, + ) + if chanBucket == nil { + // This is not possible. + return fmt.Errorf("no historical channel " + + "bucket exists") + } + + var op wire.OutPoint + readErr := graphdb.ReadOutpoint( + bytes.NewReader(chanPoint), &op, + ) + if readErr != nil { + return readErr + } + + // This channel is closed, but the structure of the + // historical bucket is the same. This is by design, + // which means we can call FetchOpenChannel. + channel, fetchErr := FetchOpenChannel(chanBucket, &op) + if fetchErr != nil { + return fetchErr + } + + // Only include this peer in the protected class if + // the closing transaction confirmed. Note that + // CloseChannel can be called in the funding manager + // while IsPending is true which is why we need this + // special-casing to not count premature funding + // manager calls to CloseChannel. + if !channel.IsPending { + // Fetch the public key of the remote node. We + // need to use the string-ified serialized, + // compressed bytes as the key. + remotePub := channel.IdentityPub + remoteSer := remotePub.SerializeCompressed() + remoteKey := string(remoteSer) + + count, exists := peerChanInfo[remoteKey] + if exists { + count.HasOpenOrClosedChan = true + peerChanInfo[remoteKey] = count + } else { + peerCount := ChanCount{ + HasOpenOrClosedChan: true, + } + peerChanInfo[remoteKey] = peerCount + } + } + + return nil + }) + if historicalErr != nil { + return historicalErr + } + + return nil + }, func() { + clear(peerChanInfo) + }) + + return peerChanInfo, err +} + +// channelSelector describes a function that takes a chain-hash bucket from +// within the open-channel DB and returns the wanted channel point bytes, and +// channel point. It must return the ErrChannelNotFound error if the wanted +// channel is not in the given bucket. +type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, + *wire.OutPoint, error) + +// channelScanner will traverse the DB to each chain-hash bucket of each node +// pub-key bucket in the open-channel-bucket. The chanSelector will then be used +// to fetch the wanted channel outpoint from the chain bucket. +func (s *KVStore) channelScanner(tx kvdb.RTx, + chanSelect channelSelector) (*OpenChannel, error) { + + var ( + targetChan *OpenChannel + + // errChanFound is used to signal that the channel has been + // found so that iteration through the DB buckets can stop. + errChanFound = errors.New("channel found") + ) + + // chanScan will traverse the following bucket structure: + // * nodePub => chainHash => chanPoint + // + // At each level we go one further, ensuring that we're traversing the + // proper key (that's actually a bucket). By only reading the bucket + // structure and skipping fully decoding each channel, we save a good + // bit of CPU as we don't need to do things like decompress public + // keys. + chanScan := func(tx kvdb.RTx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.ReadBucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Hoist the outpoint-bucket lookup so the closed-channel + // check inside the per-chain ForEach below pays one tx-level + // bucket resolve total instead of one per visited chanKey. + opBucket := tx.ReadBucket(outpointBucket) + + // Within the node channel bucket, are the set of node pubkeys + // we have channels with, we don't know the entire set, so we'll + // check them all. + return openChanBucket.ForEach(func(nodePub, v []byte) error { + // Ensure that this is a key the same size as a pubkey, + // and also that it leads directly to a bucket. + if len(nodePub) != 33 || v != nil { + return nil + } + + nodeChanBucket := openChanBucket.NestedReadBucket( + nodePub, + ) + if nodeChanBucket == nil { + return nil + } + + // The next layer down is all the chains that this node + // has channels on with us. + return nodeChanBucket.ForEach(func(chainHash, + v []byte) error { + + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + chainBucket := nodeChanBucket.NestedReadBucket( + chainHash, + ) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", + chainHash) + } + + // Finally, we reach the leaf bucket that stores + // all the chanPoints for this node. + targetChanBytes, chanPoint, err := chanSelect( + chainBucket, + ) + if errors.Is(err, ErrChannelNotFound) { + return nil + } else if err != nil { + return err + } + + // An already-closed channel is logically gone + // and must not be surfaced by lookup-style + // scans. + isClosed, err := IsOutpointClosed( + opBucket, targetChanBytes, + ) + if err != nil { + return err + } + if isClosed { + return nil + } + + chanBucket := chainBucket.NestedReadBucket( + targetChanBytes, + ) + if chanBucket == nil { + return nil + } + + channel, err := FetchOpenChannel( + chanBucket, chanPoint, + ) + if err != nil { + return err + } + + targetChan = channel + + return errChanFound + }) + }) + } + + var err error + if tx == nil { + err = kvdb.View(s.backend, chanScan, func() {}) + } else { + err = chanScan(tx) + } + if err != nil && !errors.Is(err, errChanFound) { + return nil, err + } + + if targetChan != nil { + return targetChan, nil + } + + // If we can't find the channel, then we return with an error, as we + // have nothing to back up. + return nil, ErrChannelNotFound +} + +// FetchAllChannels attempts to retrieve all open channels currently stored +// within the database, including pending open, fully open and channels waiting +// for a closing transaction to confirm. +func (s *KVStore) FetchAllChannels() ([]*OpenChannel, error) { + return s.fetchChannels() +} + +// FetchAllOpenChannels will return all channels that have the funding +// transaction confirmed, and is not waiting for a closing transaction to be +// confirmed. +func (s *KVStore) FetchAllOpenChannels() ([]*OpenChannel, error) { + return s.fetchChannels( + pendingChannelFilter(false), + waitingCloseFilter(false), + ) +} + +// FetchPendingChannels will return channels that have completed the process of +// generating and broadcasting funding transactions, but whose funding +// transactions have yet to be confirmed on the blockchain. +func (s *KVStore) FetchPendingChannels() ([]*OpenChannel, error) { + return s.fetchChannels( + pendingChannelFilter(true), + waitingCloseFilter(false), + ) +} + +// FetchWaitingCloseChannels will return all channels that have been opened, +// but are now waiting for a closing transaction to be confirmed. +// +// NOTE: This includes channels that are also pending to be opened. +func (s *KVStore) FetchWaitingCloseChannels() ([]*OpenChannel, error) { + return s.fetchChannels(waitingCloseFilter(true)) +} + +// fetchChannelsFilter applies a filter to channels retrieved in fetchchannels. +// A set of filters can be combined to filter across multiple dimensions. +type fetchChannelsFilter func(channel *OpenChannel) bool + +// pendingChannelFilter returns a filter based on whether channels are pending +// (ie, their funding transaction still needs to confirm). If pending is false, +// channels with confirmed funding transactions are returned. +func pendingChannelFilter(pending bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + return channel.IsPending == pending + } +} + +// waitingCloseFilter returns a filter which filters channels based on whether +// they are awaiting the confirmation of their closing transaction. If waiting +// close is true, channels that have had their closing tx broadcast are +// included. If it is false, channels that are not awaiting confirmation of +// their close transaction are returned. +func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + // If the channel is in any other state than Default, + // then it means it is waiting to be closed. + channelWaitingClose := + channel.ChanStatus() != ChanStatusDefault + + // Include the channel if it matches the value for + // waiting close that we are filtering on. + return channelWaitingClose == waitingClose + } +} + +// fetchChannels attempts to retrieve channels currently stored in the +// database. It takes a set of filters which are applied to each channel to +// obtain a set of channels with the desired set of properties. Only channels +// which have a true value returned for *all* of the filters will be returned. +// If no filters are provided, every channel in the open channels bucket will +// be returned. +func (s *KVStore) fetchChannels(filters ...fetchChannelsFilter) ( + []*OpenChannel, error) { + + var channels []*OpenChannel + addChannel := func(channel *OpenChannel) { + channels = append(channels, channel) + } + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.ReadBucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Finally for each node public key in the open channel + // bucket, fetch all the channels related to this particular + // node. + return openChanBucket.ForEach(func(k, v []byte) error { + // Ensure that this is a key the same size as a pubkey, + // and also that it leads directly to a bucket. + if len(k) != 33 || v != nil { + return nil + } + + nodeChanBucket := openChanBucket.NestedReadBucket(k) + if nodeChanBucket == nil { + return nil + } + + return nodeChanBucket.ForEach(func(chainHash, + v []byte) error { + + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, + // then we'll retrieve that so we can extract + // all the channels. + chainBucket := nodeChanBucket.NestedReadBucket( + chainHash, + ) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "chain bucket %x", chainHash) + } + + nodeChans, err := FetchNodeChannels( + tx, chainBucket, + ) + if err != nil { + return fmt.Errorf("unable to read "+ + "channel chain=%x node=%x: %v", + chainHash, k, err) + } + for _, channel := range nodeChans { + // includeChannel indicates whether + // the channel meets our filters. + includeChannel := true + + // Check each filter. + for _, f := range filters { + // Stop once one filter fails. + if !f(channel) { + includeChannel = false + break + } + } + + // If the channel passed every filter, + // include it in our set of channels. + if includeChannel { + addChannel(channel) + } + } + + return nil + }) + }) + }, func() { + channels = nil + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchHistoricalChanBucket returns a the channel bucket for a given outpoint +// from the historical channel bucket. If the bucket does not exist, +// ErrNoHistoricalBucket is returned. +func FetchHistoricalChanBucket(tx kvdb.RTx, + outPoint *wire.OutPoint) (kvdb.RBucket, error) { + + // First fetch the top level bucket which stores all data related to + // historically stored channels. + historicalChanBucket := tx.ReadBucket(historicalChannelBucket) + if historicalChanBucket == nil { + return nil, ErrNoHistoricalBucket + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for the channel itself. + var chanPointBuf bytes.Buffer + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanBucket := historicalChanBucket.NestedReadBucket( + chanPointBuf.Bytes(), + ) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil +} + +// FetchHistoricalChannel fetches open channel data from the historical channel +// bucket. +func (s *KVStore) FetchHistoricalChannel(outPoint *wire.OutPoint) ( + *OpenChannel, error) { + + var channel *OpenChannel + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchHistoricalChanBucket(tx, outPoint) + if err != nil { + return err + } + + channel, err = FetchOpenChannel(chanBucket, outPoint) + + return err + }, func() { + channel = nil + }) + if err != nil { + return nil, err + } + + return channel, nil +} + +// RefreshChannel updates the in-memory channel state using the latest state +// observed on disk. +func (s *KVStore) RefreshChannel(channel *OpenChannel) error { + return kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + // We'll re-populating the in-memory channel with the info + // fetched from disk. + if err := FetchChanInfo(chanBucket, channel); err != nil { + return fmt.Errorf("unable to fetch chan info: %w", err) + } + + // Also populate the channel's commitment states for both sides + // of the channel. + err = FetchChanCommitments(chanBucket, channel) + if err != nil { + return fmt.Errorf("unable to fetch chan commitments: "+ + "%v", err) + } + + // Also retrieve the current revocation state. + err = FetchChanRevocationState(chanBucket, channel) + if err != nil { + return fmt.Errorf("unable to fetch chan revocations: "+ + "%v", err) + } + + return nil + }, func() {}) +} + +// MarkChannelConfirmationHeight updates the channel's confirmation height once +// the channel opening transaction receives one confirmation. +func (s *KVStore) MarkChannelConfirmationHeight(channel *OpenChannel, + height uint32) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + diskChannel.ConfirmationHeight = height + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}) +} + +// MarkChannelCloseConfirmationHeight updates the channel's close confirmation +// height when the closing transaction is first detected in a block. +func (s *KVStore) MarkChannelCloseConfirmationHeight( + channel *OpenChannel, height fn.Option[uint32]) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + diskChannel.CloseConfirmationHeight = height + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}) +} + +// MarkChannelOpen marks a channel as fully open given a locator that uniquely +// describes its location within the chain. +func (s *KVStore) MarkChannelOpen(channel *OpenChannel, + openLoc lnwire.ShortChannelID) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + diskChannel.IsPending = false + diskChannel.ShortChannelID = openLoc + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}) +} + +// MarkChannelRealScid marks the zero-conf channel's confirmed ShortChannelID. +func (s *KVStore) MarkChannelRealScid(channel *OpenChannel, + realScid lnwire.ShortChannelID) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + diskChannel.SetConfirmedScidForStore(realScid) + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}) +} + +// MarkChannelScidAliasNegotiated adds ScidAliasFeatureBit to ChanType in the +// database. +func (s *KVStore) MarkChannelScidAliasNegotiated( + channel *OpenChannel) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + diskChannel.ChanType |= ScidAliasFeatureBit + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}) +} + +// ApplyChannelStatus adds the target status to the channel's persisted status +// bit field. +func (s *KVStore) ApplyChannelStatus(channel *OpenChannel, + status ChannelStatus) error { + + return s.putChanStatus(channel, status) +} + +// PutChannelDataLossCommitPoint stores the data-loss commit point in the +// target channel bucket. +func PutChannelDataLossCommitPoint(chanBucket kvdb.RwBucket, + commitPoint *btcec.PublicKey) error { + + return chanBucket.Put( + dataLossCommitPointKey, commitPoint.SerializeCompressed(), + ) +} + +// FetchChannelDataLossCommitPoint retrieves the data-loss commit point from the +// target channel bucket. +func FetchChannelDataLossCommitPoint( + chanBucket kvdb.RBucket) (*btcec.PublicKey, error) { + + bs := chanBucket.Get(dataLossCommitPointKey) + if bs == nil { + return nil, ErrNoCommitPoint + } + + var b [btcec.PubKeyBytesLenCompressed]byte + r := bytes.NewReader(bs) + if _, err := io.ReadFull(r, b[:]); err != nil { + return nil, err + } + + return btcec.ParsePubKey(b[:]) +} + +// MarkChannelDataLoss marks the channel as local-data-loss and stores the +// commit point needed if the remote force closes. +func (s *KVStore) MarkChannelDataLoss(channel *OpenChannel, + commitPoint *btcec.PublicKey) error { + + putCommitPoint := func(chanBucket kvdb.RwBucket) error { + return PutChannelDataLossCommitPoint(chanBucket, commitPoint) + } + + return s.putChanStatus( + channel, ChanStatusLocalDataLoss, putCommitPoint, + ) +} + +// FetchChannelDataLossCommitPoint retrieves the commit point stored when the +// channel was marked as local-data-loss. +func (s *KVStore) FetchChannelDataLossCommitPoint( + channel *OpenChannel) (*btcec.PublicKey, error) { + + var commitPoint *btcec.PublicKey + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return ErrNoCommitPoint + default: + return err + } + + commitPoint, err = FetchChannelDataLossCommitPoint(chanBucket) + + return err + }, func() { + commitPoint = nil + }) + if err != nil { + return nil, err + } + + return commitPoint, nil +} + +// MarkChannelBorked marks the channel as irreconcilable. +func (s *KVStore) MarkChannelBorked(channel *OpenChannel) error { + return s.ApplyChannelStatus(channel, ChanStatusBorked) +} + +// putChanStatus appends the given status to the channel. fs is an optional +// list of closures that are given the chanBucket in order to atomically add +// extra information together with the new status. +func (s *KVStore) putChanStatus(channel *OpenChannel, + status ChannelStatus, fs ...func(kvdb.RwBucket) error) error { + + if err := kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + // Add this status to the existing bitvector found in the DB. + status = diskChannel.ChannelStatusForStore() | status + diskChannel.SetChannelStatusForStore(status) + + if err := PutOpenChannel(chanBucket, diskChannel); err != nil { + return err + } + + for _, f := range fs { + // Skip execution of nil closures. + if f == nil { + continue + } + + if err := f(chanBucket); err != nil { + return err + } + } + + return nil + }, func() {}); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + channel.SetChannelStatusForStore(status) + + return nil +} + +// ClearChannelStatus clears the target status from the channel's persisted +// status bit field. +func (s *KVStore) ClearChannelStatus(channel *OpenChannel, + status ChannelStatus) error { + + if err := kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return err + } + + // Unset this bit in the bitvector on disk. + status = diskChannel.ChannelStatusForStore() & ^status + diskChannel.SetChannelStatusForStore(status) + + return PutOpenChannel(chanBucket, diskChannel) + }, func() {}); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + channel.SetChannelStatusForStore(status) + + return nil +} + +// IsChannelBorked returns true if the channel has been marked as borked in the +// database. This requires an existing database transaction to already be +// active. +// +// NOTE: The primary mutex should already be held before this method is called. +func IsChannelBorked(channel *OpenChannel, chanBucket kvdb.RBucket) ( + bool, error) { + + diskChannel, err := FetchOpenChannel( + chanBucket, &channel.FundingOutpoint, + ) + if err != nil { + return false, err + } + + return diskChannel.ChannelStatusForStore() != ChanStatusDefault, nil +} + +// openChannelTlvData houses the new data fields that are stored for each +// channel in a TLV stream within the root bucket. This is stored as a TLV +// stream appended to the existing hard-coded fields in the channel's root +// bucket. New fields being added to the channel state should be added here. +// +// NOTE: This struct is used for serialization purposes only and its fields +// should be accessed via the OpenChannel struct while in memory. +type openChannelTlvData struct { + // revokeKeyLoc is the key locator for the revocation key. + revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord] + + // initialLocalBalance is the initial local balance of the channel. + initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64] + + // initialRemoteBalance is the initial remote balance of the channel. + initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64] + + // realScid is the real short channel ID of the channel corresponding to + // the on-chain outpoint. + realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID] + + // memo is an optional text field that gives context to the user about + // the channel. + memo tlv.OptionalRecordT[tlv.TlvType5, []byte] + + // tapscriptRoot is the optional Tapscript root the channel funding + // output commits to. + tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] + + // customBlob is an optional TLV encoded blob of data representing + // custom channel funding information. + customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] + + // confirmationHeight records the block height at which the funding + // transaction was first confirmed. + confirmationHeight tlv.RecordT[tlv.TlvType8, uint32] + + // closeConfirmationHeight records the block height at which the closing + // transaction was first confirmed. This is used to calculate the + // remaining confirmations until the channel is considered fully closed. + // Note: if not set, it means either the channel has not been + // closed yet, or it was closed before this field was introduced. + closeConfirmationHeight tlv.OptionalRecordT[tlv.TlvType9, uint32] +} + +// encode serializes the openChannelTlvData to the given io.Writer. +func (c *openChannelTlvData) encode(w io.Writer) error { + tlvRecords := []tlv.Record{ + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + c.confirmationHeight.Record(), + } + c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { + tlvRecords = append(tlvRecords, memo.Record()) + }) + c.tapscriptRoot.WhenSome( + func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { + tlvRecords = append(tlvRecords, root.Record()) + }, + ) + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + c.closeConfirmationHeight.WhenSome( + func(h tlv.RecordT[tlv.TlvType9, uint32]) { + tlvRecords = append(tlvRecords, h.Record()) + }, + ) + + tlv.SortRecords(tlvRecords) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode deserializes the openChannelTlvData from the given io.Reader. +func (c *openChannelTlvData) decode(r io.Reader) error { + memo := c.memo.Zero() + tapscriptRoot := c.tapscriptRoot.Zero() + blob := c.customBlob.Zero() + closeConfHeight := c.closeConfirmationHeight.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + c.revokeKeyLoc.Record(), + c.initialLocalBalance.Record(), + c.initialRemoteBalance.Record(), + c.realScid.Record(), + memo.Record(), + tapscriptRoot.Record(), + blob.Record(), + c.confirmationHeight.Record(), + closeConfHeight.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[memo.TlvType()]; ok { + c.memo = tlv.SomeRecordT(memo) + } + if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { + c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) + } + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + if _, ok := tlvs[closeConfHeight.TlvType()]; ok { + c.closeConfirmationHeight = tlv.SomeRecordT(closeConfHeight) + } + + return nil +} + +// DecodeOpenChannelTlvData decodes and applies auxiliary TLV data to an open +// channel. +func DecodeOpenChannelTlvData(r io.Reader, channel *OpenChannel) error { + var auxData openChannelTlvData + if err := auxData.decode(r); err != nil { + return err + } + + amendOpenChannelTlvData(channel, auxData) + + return nil +} + +// EncodeOpenChannelTlvData extracts and encodes auxiliary TLV data from an open +// channel. +func EncodeOpenChannelTlvData(w io.Writer, channel *OpenChannel) error { + auxData := extractOpenChannelTlvData(channel) + return auxData.encode(w) +} + +// amendOpenChannelTlvData updates the channel with the given auxiliary TLV +// data. +func amendOpenChannelTlvData(channel *OpenChannel, auxData openChannelTlvData) { + channel.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator + channel.InitialLocalBalance = lnwire.MilliSatoshi( + auxData.initialLocalBalance.Val, + ) + channel.InitialRemoteBalance = lnwire.MilliSatoshi( + auxData.initialRemoteBalance.Val, + ) + channel.SetConfirmedScidForStore(auxData.realScid.Val) + channel.ConfirmationHeight = auxData.confirmationHeight.Val + + auxData.memo.WhenSomeV(func(memo []byte) { + channel.Memo = memo + }) + auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { + channel.TapscriptRoot = fn.Some[chainhash.Hash](h) + }) + auxData.customBlob.WhenSomeV(func(blob tlv.Blob) { + channel.CustomBlob = fn.Some(blob) + }) + auxData.closeConfirmationHeight.WhenSomeV(func(h uint32) { + channel.CloseConfirmationHeight = fn.Some(h) + }) +} + +// extractOpenChannelTlvData creates a new openChannelTlvData from the given +// channel. +func extractOpenChannelTlvData(channel *OpenChannel) openChannelTlvData { + auxData := openChannelTlvData{ + revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1]( + keyLocRecord{channel.RevocationKeyLocator}, + ), + initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint64(channel.InitialLocalBalance), + ), + initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3]( + uint64(channel.InitialRemoteBalance), + ), + realScid: tlv.NewRecordT[tlv.TlvType4]( + channel.ConfirmedScidForStore(), + ), + confirmationHeight: tlv.NewPrimitiveRecord[tlv.TlvType8]( + channel.ConfirmationHeight, + ), + } + + if len(channel.Memo) != 0 { + auxData.memo = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](channel.Memo), + ) + } + channel.TapscriptRoot.WhenSome(func(h chainhash.Hash) { + auxData.tapscriptRoot = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), + ) + }) + channel.CustomBlob.WhenSome(func(blob tlv.Blob) { + auxData.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType7](blob), + ) + }) + channel.CloseConfirmationHeight.WhenSome(func(h uint32) { + auxData.closeConfirmationHeight = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType9](h), + ) + }) + + return auxData +} diff --git a/chanstate/kv_revocation_log.go b/chanstate/kv_revocation_log.go new file mode 100644 index 00000000000..76ac41005f8 --- /dev/null +++ b/chanstate/kv_revocation_log.go @@ -0,0 +1,698 @@ +package chanstate + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // revocationLogBucketDeprecated is dedicated for storing the necessary + // delta state between channel updates required to re-construct a past + // state in order to punish a counterparty attempting a non-cooperative + // channel closure. This key should be accessed from within the + // sub-bucket of a target channel, identified by its channel point. + // + // Deprecated: This bucket is kept for read-only in case the user + // choose not to migrate the old data. + revocationLogBucketDeprecated = []byte("revocation-log-key") + + // revocationLogBucket is a sub-bucket under openChannelBucket. This + // sub-bucket is dedicated for storing the minimal info required to + // re-construct a past state in order to punish a counterparty + // attempting a non-cooperative channel closure. + revocationLogBucket = []byte("revocation-log") + + // ErrLogEntryNotFound is returned when we cannot find a log entry at + // the height requested in the revocation log. + ErrLogEntryNotFound = errors.New("log entry not found") + + // ErrOutputIndexTooBig is returned when the output index is greater + // than uint16. + ErrOutputIndexTooBig = errors.New("output index is over uint16") +) + +// RevocationLogBucketKey returns the sub-bucket key that stores the current +// revocation log format. +func RevocationLogBucketKey() []byte { + return revocationLogBucket +} + +// RevocationLogBucketDeprecatedKey returns the deprecated revocation-log bucket +// key. +func RevocationLogBucketDeprecatedKey() []byte { + return revocationLogBucketDeprecated +} + +// PutRevocationLog uses the fields `CommitTx` and `Htlcs` from a +// ChannelCommitment to construct a revocation log entry and saves them to +// disk. It also saves our output index and their output index, which are +// useful when creating breach retribution. +func PutRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, + ourOutputIndex, theirOutputIndex uint32, noAmtData bool) error { + + // Sanity check that the output indexes can be safely converted. + if ourOutputIndex > math.MaxUint16 { + return ErrOutputIndexTooBig + } + if theirOutputIndex > math.MaxUint16 { + return ErrOutputIndexTooBig + } + + rl := &RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(ourOutputIndex), + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(theirOutputIndex), + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( + commit.CommitTx.TxHash(), + ), + HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), + } + + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + + if !noAmtData { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(commit.LocalBalance), + )) + + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(commit.RemoteBalance), + )) + } + + for _, htlc := range commit.Htlcs { + // Skip dust HTLCs. + if htlc.OutputIndex < 0 { + continue + } + + // Sanity check that the output indexes can be safely + // converted. + if htlc.OutputIndex > math.MaxUint16 { + return ErrOutputIndexTooBig + } + + entry, err := NewHTLCEntryFromHTLC(htlc) + if err != nil { + return err + } + rl.HTLCEntries = append(rl.HTLCEntries, entry) + } + + var b bytes.Buffer + err := SerializeRevocationLog(&b, rl) + if err != nil { + return err + } + + logEntrykey := revocationLogKey(commit.CommitHeight) + + return bucket.Put(logEntrykey[:], b.Bytes()) +} + +// FetchRevocationLog queries the revocation log bucket to find an log entry. +// Return an error if not found. +func FetchRevocationLog(log kvdb.RBucket, + updateNum uint64) (RevocationLog, error) { + + logEntrykey := revocationLogKey(updateNum) + commitBytes := log.Get(logEntrykey[:]) + if commitBytes == nil { + return RevocationLog{}, ErrLogEntryNotFound + } + + commitReader := bytes.NewReader(commitBytes) + + return DeserializeRevocationLog(commitReader) +} + +// FetchOldRevocationLog finds the revocation log from the deprecated +// sub-bucket. +func FetchOldRevocationLog(log kvdb.RBucket, + updateNum uint64) (ChannelCommitment, error) { + + logEntrykey := revocationLogKey(updateNum) + commitBytes := log.Get(logEntrykey[:]) + if commitBytes == nil { + return ChannelCommitment{}, ErrLogEntryNotFound + } + + commitReader := bytes.NewReader(commitBytes) + + return DeserializeChanCommit(commitReader) +} + +// FetchRevocationLogCompatible finds the revocation log from both the +// revocationLogBucket and revocationLogBucketDeprecated for compatibility +// concern. It returns three values, +// - RevocationLog, if this is non-nil, it means we've found the log in the +// new bucket. +// - ChannelCommitment, if this is non-nil, it means we've found the log +// in the old bucket. +// - error, this can happen if the log cannot be found in neither buckets. +func FetchRevocationLogCompatible(chanBucket kvdb.RBucket, + updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { + + // Look into the new bucket first. + logBucket := chanBucket.NestedReadBucket(revocationLogBucket) + if logBucket != nil { + rl, err := FetchRevocationLog(logBucket, updateNum) + // We've found the record, no need to visit the old bucket. + if err == nil { + return &rl, nil, nil + } + + // Return the error if it doesn't say the log cannot be found. + if !errors.Is(err, ErrLogEntryNotFound) { + return nil, nil, err + } + } + + // Otherwise, look into the old bucket and try to find the log there. + oldBucket := chanBucket.NestedReadBucket(revocationLogBucketDeprecated) + if oldBucket != nil { + c, err := FetchOldRevocationLog(oldBucket, updateNum) + if err != nil { + return nil, nil, err + } + + // Found an old record and return it. + return nil, &c, nil + } + + // If both the buckets are nil, then the sub-buckets haven't been + // created yet. + if logBucket == nil && oldBucket == nil { + return nil, nil, ErrNoPastDeltas + } + + // Otherwise, we've tried to query the new bucket but the log cannot be + // found. + return nil, nil, ErrLogEntryNotFound +} + +// FetchLogBucket returns a read bucket by visiting both the new and the old +// bucket. +func FetchLogBucket(chanBucket kvdb.RBucket) (kvdb.RBucket, error) { + logBucket := chanBucket.NestedReadBucket(revocationLogBucket) + if logBucket == nil { + logBucket = chanBucket.NestedReadBucket( + revocationLogBucketDeprecated, + ) + if logBucket == nil { + return nil, ErrNoPastDeltas + } + } + + return logBucket, nil +} + +// DeleteLogBucket deletes the both the new and old revocation log buckets. +func DeleteLogBucket(chanBucket kvdb.RwBucket) error { + // Check if the bucket exists and delete it. + logBucket := chanBucket.NestedReadWriteBucket( + revocationLogBucket, + ) + if logBucket != nil { + err := chanBucket.DeleteNestedBucket(revocationLogBucket) + if err != nil { + return err + } + } + + // We also check whether the old revocation log bucket exists + // and delete it if so. + oldLogBucket := chanBucket.NestedReadWriteBucket( + revocationLogBucketDeprecated, + ) + if oldLogBucket != nil { + err := chanBucket.DeleteNestedBucket( + revocationLogBucketDeprecated, + ) + if err != nil { + return err + } + } + + return nil +} + +// RevocationLogTailCommitHeight returns the commit height at the end of the +// revocation log. +func (s *KVStore) RevocationLogTailCommitHeight( + channel *OpenChannel) (uint64, error) { + + var height uint64 + + // If we haven't created any state updates yet, then we'll exit early as + // there's nothing to be found on disk in the revocation bucket. + if channel.RemoteCommitment.CommitHeight == 0 { + return height, nil + } + + if err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + logBucket, err := FetchLogBucket(chanBucket) + if err != nil { + return err + } + + // Once we have the bucket that stores the revocation log from + // this channel, we'll jump to the _last_ key in bucket. Since + // the key is the commit height, we'll decode the bytes and + // return it. + cursor := logBucket.ReadCursor() + rawHeight, _ := cursor.Last() + height = byteOrder.Uint64(rawHeight) + + return nil + }, func() {}); err != nil { + return height, err + } + + return height, nil +} + +// FindPreviousState scans through the append-only log in an attempt to recover +// the previous channel state indicated by the update number. This method is +// intended to be used for obtaining the relevant data needed to claim all +// funds rightfully spendable in the case of an on-chain broadcast of the +// commitment transaction. +func (s *KVStore) FindPreviousState(channel *OpenChannel, + updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { + + commit := &ChannelCommitment{} + rl := &RevocationLog{} + + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + // Find the revocation log from both the new and the old + // bucket. + r, c, err := FetchRevocationLogCompatible( + chanBucket, updateNum, + ) + if err != nil { + return err + } + + rl = r + commit = c + + return nil + }, func() {}) + if err != nil { + return nil, nil, err + } + + // Either the `rl` or the `commit` is nil here. We return them as-is + // and leave it to the caller to decide its following action. + return rl, commit, nil +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.DBytes32(r, vArray, buf, 32) +} + +// toTlvStream converts an HTLCEntry record into a tlv representation. +func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { + records := []tlv.Record{ + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), + } + + h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + + h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, + tlv.BigSizeT[uint64]]) { + + records = append(records, r.Record()) + }) + + tlv.SortRecords(records) + + return tlv.NewStream(records...) +} + +// SerializeRevocationLog serializes a RevocationLog record based on tlv +// format. +func SerializeRevocationLog(w io.Writer, rl *RevocationLog) error { + // Add the tlv records for all non-optional fields. + records := []tlv.Record{ + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + } + + // Now we add any optional fields that are non-nil. + rl.OurBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) + + rl.TheirBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) + + rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + // Write the tlv stream. + if err := WriteTlvStream(w, tlvStream); err != nil { + return err + } + + // Write the HTLCs. + return SerializeHTLCEntries(w, rl.HTLCEntries) +} + +// SerializeHTLCEntries serializes a list of HTLCEntry records based on tlv +// format. +func SerializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { + for _, htlc := range htlcs { + // Create the tlv stream. + tlvStream, err := htlc.toTlvStream() + if err != nil { + return err + } + + // Write the tlv stream. + if err := WriteTlvStream(w, tlvStream); err != nil { + return err + } + } + + return nil +} + +// DeserializeRevocationLog deserializes a RevocationLog based on tlv format. +func DeserializeRevocationLog(r io.Reader) (RevocationLog, error) { + var rl RevocationLog + + ourBalance := rl.OurBalance.Zero() + theirBalance := rl.TheirBalance.Zero() + customBlob := rl.CustomBlob.Zero() + + // Create the tlv stream. + tlvStream, err := tlv.NewStream( + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + ourBalance.Record(), + theirBalance.Record(), + customBlob.Record(), + ) + if err != nil { + return rl, err + } + + // Read the tlv stream. + parsedTypes, err := ReadTlvStream(r, tlvStream) + if err != nil { + return rl, err + } + + if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { + rl.OurBalance = tlv.SomeRecordT(ourBalance) + } + + if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { + rl.TheirBalance = tlv.SomeRecordT(theirBalance) + } + + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + rl.CustomBlob = tlv.SomeRecordT(customBlob) + } + + // Read the HTLC entries. + rl.HTLCEntries, err = DeserializeHTLCEntries(r) + + return rl, err +} + +// DeserializeHTLCEntries deserializes a list of HTLC entries based on tlv +// format. +func DeserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { + var ( + htlcs []*HTLCEntry + + // htlcIndexBlob defines the tlv record type to be used when + // decoding from the disk. We use it instead of the one defined + // in `HTLCEntry.HtlcIndex` as previously this field was encoded + // using `uint16`, thus we will read it as raw bytes and + // deserialize it further below. + htlcIndexBlob tlv.OptionalRecordT[tlv.TlvType6, tlv.Blob] + ) + + for { + var htlc HTLCEntry + + customBlob := htlc.CustomBlob.Zero() + htlcIndex := htlcIndexBlob.Zero() + + // Create the tlv stream. + records := []tlv.Record{ + htlc.RHash.Record(), + htlc.RefundTimeout.Record(), + htlc.OutputIndex.Record(), + htlc.Incoming.Record(), + htlc.Amt.Record(), + customBlob.Record(), + htlcIndex.Record(), + } + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + // Read the HTLC entry. + parsedTypes, err := ReadTlvStream(r, tlvStream) + if err != nil { + // We've reached the end when hitting an EOF. + if errors.Is(err, io.ErrUnexpectedEOF) { + break + } + + return nil, err + } + + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + htlc.CustomBlob = tlv.SomeRecordT(customBlob) + } + + if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil { + record, err := deserializeHtlcIndexCompatible( + htlcIndex.Val, + ) + if err != nil { + return nil, err + } + + htlc.HtlcIndex = record + } + + // Append the entry. + htlcs = append(htlcs, &htlc) + } + + return htlcs, nil +} + +// deserializeHtlcIndexCompatible takes raw bytes and decodes it into an +// optional record that's assigned to the entry's HtlcIndex. +// +// NOTE: previously this `HtlcIndex` was a tlv record that used `uint16` to +// encode its value. Given now its value is encoded using BigSizeT, and for any +// BigSizeT, its possible length values are 1, 3, 5, and 8. This means if the +// tlv record has a length of 2, we know for sure it must be an old record +// whose value was encoded using uint16. +func deserializeHtlcIndexCompatible(rawBytes []byte) ( + tlv.OptionalRecordT[tlv.TlvType6, tlv.BigSizeT[uint64]], error) { + + var ( + // record defines the record that's used by the HtlcIndex in the + // entry. + record tlv.OptionalRecordT[ + tlv.TlvType6, tlv.BigSizeT[uint64], + ] + + // htlcIndexVal is the decoded uint64 value. + htlcIndexVal uint64 + ) + + // If the length of the tlv record is 2, it must be encoded using uint16 + // as the BigSizeT encoding cannot have this length. + if len(rawBytes) == 2 { + // Decode the raw bytes into uint16 and convert it into uint64. + htlcIndexVal = uint64(binary.BigEndian.Uint16(rawBytes)) + } else { + // This value is encoded using BigSizeT, we now use the decoder + // to deserialize the raw bytes. + r := bytes.NewBuffer(rawBytes) + + // Create a buffer to be used in the decoding process. + buf := [8]byte{} + + // Use the BigSizeT's decoder. + err := tlv.DBigSize(r, &htlcIndexVal, &buf, 8) + if err != nil { + return record, err + } + } + + record = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType6]( + tlv.NewBigSizeT(htlcIndexVal), + )) + + return record, nil +} + +// WriteTlvStream is a helper function that encodes the tlv stream into the +// writer. +func WriteTlvStream(w io.Writer, s *tlv.Stream) error { + var b bytes.Buffer + if err := s.Encode(&b); err != nil { + return err + } + + // Write the stream's length as a varint. + err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) + if err != nil { + return err + } + + if _, err = w.Write(b.Bytes()); err != nil { + return err + } + + return nil +} + +// ReadTlvStream is a helper function that decodes the tlv stream from the +// reader. +func ReadTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) { + var bodyLen uint64 + + // Read the stream's length. + bodyLen, err := tlv.ReadVarInt(r, &[8]byte{}) + switch { + // We'll convert any EOFs to ErrUnexpectedEOF, since this results in an + // invalid record. + case errors.Is(err, io.EOF): + return nil, io.ErrUnexpectedEOF + + // Other unexpected errors. + case err != nil: + return nil, err + } + + // TODO(yy): add overflow check. + lr := io.LimitReader(r, int64(bodyLen)) + + return s.DecodeWithParsedTypes(lr) +} + +// revocationLogKey converts a uint64 into an 8 byte revocation log key. +func revocationLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} diff --git a/chanstate/kv_shutdown.go b/chanstate/kv_shutdown.go new file mode 100644 index 00000000000..89a3304dd17 --- /dev/null +++ b/chanstate/kv_shutdown.go @@ -0,0 +1,136 @@ +package chanstate + +import ( + "bytes" + "errors" + "io" + + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // shutdownInfoKey points to the serialised shutdown info that has been + // persisted for a channel. The existence of this info means that we + // have sent the Shutdown message before and so should re-initiate the + // shutdown on re-establish. + shutdownInfoKey = []byte("shutdown-info-key") +) + +// ShutdownInfoKey returns the key for the serialised shutdown info stored in a +// channel bucket. +func ShutdownInfoKey() []byte { + return shutdownInfoKey +} + +// PutChannelShutdownInfo persists the ShutdownInfo in the target channel +// bucket. +func PutChannelShutdownInfo(chanBucket kvdb.RwBucket, + info *ShutdownInfo) error { + + var b bytes.Buffer + err := EncodeShutdownInfo(info, &b) + if err != nil { + return err + } + + return chanBucket.Put(shutdownInfoKey, b.Bytes()) +} + +// FetchChannelShutdownInfo fetches the persisted ShutdownInfo from the target +// channel bucket. +func FetchChannelShutdownInfo(chanBucket kvdb.RBucket) ( + *ShutdownInfo, error) { + + shutdownInfoBytes := chanBucket.Get(shutdownInfoKey) + if shutdownInfoBytes == nil { + return nil, ErrNoShutdownInfo + } + + return DecodeShutdownInfo(shutdownInfoBytes) +} + +// StoreChannelShutdownInfo persists the ShutdownInfo for the target channel. +func (s *KVStore) StoreChannelShutdownInfo(channel *OpenChannel, + info *ShutdownInfo) error { + + return kvdb.Update(s.backend, func(tx kvdb.RwTx) error { + chanBucket, err := FetchChanBucketRw( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + if err != nil { + return err + } + + return PutChannelShutdownInfo(chanBucket, info) + }, func() {}) +} + +// FetchChannelShutdownInfo fetches the persisted ShutdownInfo for the target +// channel. +func (s *KVStore) FetchChannelShutdownInfo( + channel *OpenChannel) (fn.Option[ShutdownInfo], error) { + + var shutdownInfo *ShutdownInfo + err := kvdb.View(s.backend, func(tx kvdb.RTx) error { + chanBucket, err := FetchChanBucket( + tx, channel.IdentityPub, &channel.FundingOutpoint, + channel.ChainHash, + ) + switch { + case err == nil: + case errors.Is(err, ErrNoChanDBExists), + errors.Is(err, ErrNoActiveChannels), + errors.Is(err, ErrChannelNotFound): + + return ErrNoShutdownInfo + default: + return err + } + + shutdownInfo, err = FetchChannelShutdownInfo(chanBucket) + + return err + }, func() { + shutdownInfo = nil + }) + if err != nil { + return fn.None[ShutdownInfo](), err + } + + return fn.Some[ShutdownInfo](*shutdownInfo), nil +} + +// EncodeShutdownInfo serialises the ShutdownInfo to the given io.Writer. +func EncodeShutdownInfo(s *ShutdownInfo, w io.Writer) error { + records := []tlv.Record{ + s.DeliveryScript.Record(), + s.LocalInitiator.Record(), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +// DecodeShutdownInfo constructs a ShutdownInfo struct by decoding the given +// byte slice. +func DecodeShutdownInfo(b []byte) (*ShutdownInfo, error) { + tlvStream := lnwire.ExtraOpaqueData(b) + + var info ShutdownInfo + records := []tlv.RecordProducer{ + &info.DeliveryScript, + &info.LocalInitiator, + } + + _, err := tlvStream.ExtractRecords(records...) + + return &info, err +} diff --git a/chanstate/kv_store.go b/chanstate/kv_store.go new file mode 100644 index 00000000000..2d1c187a6fd --- /dev/null +++ b/chanstate/kv_store.go @@ -0,0 +1,35 @@ +package chanstate + +import "github.com/lightningnetwork/lnd/kvdb" + +// KVStore is the KV-backed implementation of the channel-state store facets. +// Store facets are moved onto this type incrementally while channeldb keeps +// compatibility wrappers for callers that still depend on the old package. +type KVStore struct { + backend kvdb.Backend + noRevLogAmtData bool + storeFinalHtlcResolutions bool + tombstoneClosedChannels bool +} + +// NewKVStore creates a KV-backed channel-state store. +func NewKVStore(backend kvdb.Backend, + storeFinalHtlcResolutions, noRevLogAmtData, + tombstoneClosedChannels bool) *KVStore { + + return &KVStore{ + backend: backend, + noRevLogAmtData: noRevLogAmtData, + storeFinalHtlcResolutions: storeFinalHtlcResolutions, + tombstoneClosedChannels: tombstoneClosedChannels, + } +} + +var _ ChannelSetupStore = (*KVStore)(nil) +var _ FinalHTLCStore = (*KVStore)(nil) +var _ OpenChannelFwdPkgStore = (*KVStore)(nil) +var _ OpenChannelShutdownStore = (*KVStore)(nil) +var _ OpenChannelCloseTxStore = (*KVStore)(nil) +var _ OpenChannelStatusStore = (*KVStore)(nil) +var _ OpenChannelCommitmentStore = (*KVStore)(nil) +var _ HistoricalChannelStore = (*KVStore)(nil) diff --git a/chanstate/open_channel.go b/chanstate/open_channel.go new file mode 100644 index 00000000000..b8e4dd1c60e --- /dev/null +++ b/chanstate/open_channel.go @@ -0,0 +1,1256 @@ +package chanstate + +import ( + "crypto/sha256" + "errors" + "fmt" + "net" + "sync" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" +) + +// OpenChannel encapsulates the persistent and dynamic state of an open channel +// with a remote node. An open channel supports several options for on-disk +// serialization depending on the exact context. Full (upon channel creation) +// state commitments, and partial (due to a commitment update) writes are +// supported. Each partial write due to a state update appends the new update +// to an on-disk log, which can then subsequently be queried in order to +// "time-travel" to a prior state. +type OpenChannel struct { + // ChanType denotes which type of channel this is. + ChanType ChannelType + + // ChainHash is a hash which represents the blockchain that this + // channel will be opened within. This value is typically the genesis + // hash. In the case that the original chain went through a contentious + // hard-fork, then this value will be tweaked using the unique fork + // point on each branch. + ChainHash chainhash.Hash + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identifies the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + // + // If IsZeroConf(), then this will the "base" (very first) ALIAS scid + // and the confirmed SCID will be stored in ConfirmedScid. + ShortChannelID lnwire.ShortChannelID + + // IsPending indicates whether a channel's funding transaction has been + // confirmed. + IsPending bool + + // IsInitiator is a bool which indicates if we were the original + // initiator for the channel. This value may affect how higher levels + // negotiate fees, or close the channel. + IsInitiator bool + + // chanStatus is the current status of this channel. If it is not in + // the state Default, it should not be used for forwarding payments. + chanStatus ChannelStatus + + // FundingBroadcastHeight is the height in which the funding + // transaction was broadcast. This value can be used by higher level + // sub-systems to determine if a channel is stale and/or should have + // been confirmed before a certain height. + FundingBroadcastHeight uint32 + + // ConfirmationHeight records the block height at which the funding + // transaction was first confirmed. + ConfirmationHeight uint32 + + // CloseConfirmationHeight records the block height at which the closing + // transaction was first confirmed. This is used to track remaining + // confirmations until the channel is considered fully closed. It is + // None if the closing transaction has not yet been confirmed, or if + // this data was not available (e.g. channels closed before this + // field was introduced). + CloseConfirmationHeight fn.Option[uint32] + + // NumConfsRequired is the number of confirmations a channel's funding + // transaction must have received in order to be considered available + // for normal transactional use. + NumConfsRequired uint16 + + // ChannelFlags holds the flags that were sent as part of the + // open_channel message. + ChannelFlags lnwire.FundingFlag + + // IdentityPub is the identity public key of the remote node this + // channel has been established with. + IdentityPub *btcec.PublicKey + + // Capacity is the total capacity of this channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // InitialLocalBalance is the balance we have during the channel + // opening. When we are not the initiator, this value represents the + // push amount. + InitialLocalBalance lnwire.MilliSatoshi + + // InitialRemoteBalance is the balance they have during the channel + // opening. + InitialRemoteBalance lnwire.MilliSatoshi + + // LocalChanCfg is the channel configuration for the local node. + LocalChanCfg ChannelConfig + + // RemoteChanCfg is the channel configuration for the remote node. + RemoteChanCfg ChannelConfig + + // LocalCommitment is the current local commitment state for the local + // party. This is stored distinct from the state of the remote party + // as there are certain asymmetric parameters which affect the + // structure of each commitment. + LocalCommitment ChannelCommitment + + // RemoteCommitment is the current remote commitment state for the + // remote party. This is stored distinct from the state of the local + // party as there are certain asymmetric parameters which affect the + // structure of each commitment. + RemoteCommitment ChannelCommitment + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // RevocationProducer is used to generate the revocation in such a way + // that remote side might store it efficiently and have the ability to + // restore the revocation by index if needed. Current implementation of + // secret producer is shachain producer. + RevocationProducer shachain.Producer + + // RevocationStore is used to efficiently store the revocations for + // previous channels states sent to us by remote side. Current + // implementation of secret store is shachain store. + RevocationStore shachain.Store + + // FundingTxn is the transaction containing this channel's funding + // outpoint. Upon restarts, this txn will be rebroadcast if the channel + // is found to be pending. + // + // NOTE: This value will only be populated for single-funder channels + // for which we are the initiator, and that we also have the funding + // transaction for. One can check this by using the HasFundingTx() + // method on the ChanType field. + FundingTxn *wire.MsgTx + + // LocalShutdownScript is set to a pre-set script if the channel was + // opened by the local node with option_upfront_shutdown_script set. If + // the option was not set, the field is empty. + LocalShutdownScript lnwire.DeliveryAddress + + // RemoteShutdownScript is set to a pre-set script if the channel was + // opened by the remote node with option_upfront_shutdown_script set. If + // the option was not set, the field is empty. + RemoteShutdownScript lnwire.DeliveryAddress + + // ThawHeight is the height when a frozen channel once again becomes a + // normal channel. If this is zero, then there're no restrictions on + // this channel. If the value is lower than 500,000, then it's + // interpreted as a relative height, or an absolute height otherwise. + ThawHeight uint32 + + // LastWasRevoke is a boolean that determines if the last update we sent + // was a revocation (true) or a commitment signature (false). + LastWasRevoke bool + + // RevocationKeyLocator stores the KeyLocator information that we will + // need to derive the shachain root for this channel. This allows us to + // have private key isolation from lnd. + RevocationKeyLocator keychain.KeyLocator + + // confirmedScid is the confirmed ShortChannelID for a zero-conf + // channel. If the channel is unconfirmed, then this will be the + // default ShortChannelID. This is only set for zero-conf channels. + confirmedScid lnwire.ShortChannelID + + // Memo is any arbitrary information we wish to store locally about the + // channel that will be useful to our future selves. + Memo []byte + + // TapscriptRoot is an optional tapscript root used to derive the MuSig2 + // funding output. + TapscriptRoot fn.Option[chainhash.Hash] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] + + // Db persists channel state through the Store contract. This field + // intentionally keeps the existing name while callers still construct + // channels through the channeldb compatibility alias. The store + // interface keeps receiver methods backend independent while the KV + // implementation remains in channeldb. + Db Store + + // TODO(roasbeef): just need to store local and remote HTLC's? + + sync.RWMutex +} + +// String returns a string representation of the channel. +func (c *OpenChannel) String() string { + indexStr := "height=%v, local_htlc_index=%v, local_log_index=%v, " + + "remote_htlc_index=%v, remote_log_index=%v" + + commit := c.LocalCommitment + local := fmt.Sprintf(indexStr, commit.CommitHeight, + commit.LocalHtlcIndex, commit.LocalLogIndex, + commit.RemoteHtlcIndex, commit.RemoteLogIndex, + ) + + commit = c.RemoteCommitment + remote := fmt.Sprintf(indexStr, commit.CommitHeight, + commit.LocalHtlcIndex, commit.LocalLogIndex, + commit.RemoteHtlcIndex, commit.RemoteLogIndex, + ) + + return fmt.Sprintf("SCID=%v, status=%v, initiator=%v, pending=%v, "+ + "local commitment has %s, remote commitment has %s", + c.ShortChannelID, c.chanStatus, c.IsInitiator, c.IsPending, + local, remote, + ) +} + +// Initiator returns the ChannelParty that originally opened this channel. +func (c *OpenChannel) Initiator() lntypes.ChannelParty { + c.RLock() + defer c.RUnlock() + + if c.IsInitiator { + return lntypes.Local + } + + return lntypes.Remote +} + +// ShortChanID returns the current ShortChannelID of this channel. +func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { + c.RLock() + defer c.RUnlock() + + return c.ShortChannelID +} + +// ZeroConfRealScid returns the zero-conf channel's confirmed scid. This should +// only be called if IsZeroConf returns true. +func (c *OpenChannel) ZeroConfRealScid() lnwire.ShortChannelID { + c.RLock() + defer c.RUnlock() + + return c.confirmedScid +} + +// ZeroConfConfirmed returns whether the zero-conf channel has confirmed. This +// should only be called if IsZeroConf returns true. +func (c *OpenChannel) ZeroConfConfirmed() bool { + c.RLock() + defer c.RUnlock() + + return c.confirmedScid != hop.Source +} + +// IsZeroConf returns whether the option_zeroconf channel type was negotiated. +func (c *OpenChannel) IsZeroConf() bool { + c.RLock() + defer c.RUnlock() + + return c.ChanType.HasZeroConf() +} + +// IsOptionScidAlias returns whether the option_scid_alias channel type was +// negotiated. +func (c *OpenChannel) IsOptionScidAlias() bool { + c.RLock() + defer c.RUnlock() + + return c.ChanType.HasScidAliasChan() +} + +// NegotiatedAliasFeature returns whether the option-scid-alias feature bit was +// negotiated. +func (c *OpenChannel) NegotiatedAliasFeature() bool { + c.RLock() + defer c.RUnlock() + + return c.ChanType.HasScidAliasFeature() +} + +// ChanStatus returns the current ChannelStatus of this channel. +func (c *OpenChannel) ChanStatus() ChannelStatus { + c.RLock() + defer c.RUnlock() + + return c.chanStatus +} + +// ChannelStatusForStore returns the in-memory channel status without taking +// the channel mutex. +// +// NOTE: This is a preliminary migration hook for KV-backed store code that +// still lives in channeldb during this refactor. Callers are responsible for +// synchronization. Normal callers should use ChanStatus. +func (c *OpenChannel) ChannelStatusForStore() ChannelStatus { + return c.chanStatus +} + +// SetChannelStatusForStore updates the in-memory channel status without taking +// the channel mutex. +// +// NOTE: This is a preliminary migration hook for KV-backed store code that +// still lives in channeldb during this refactor. Callers are responsible for +// synchronization. Normal callers should use ApplyChanStatus or +// ClearChanStatus when the status change must be persisted. +func (c *OpenChannel) SetChannelStatusForStore(status ChannelStatus) { + c.chanStatus = status +} + +// ApplyChanStatus allows the caller to modify the internal channel state in a +// thead-safe manner. +func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.Db.ApplyChannelStatus(c, status) +} + +// ClearChanStatus allows the caller to clear a particular channel status from +// the primary channel status bit field. After this method returns, a call to +// HasChanStatus(status) should return false. +func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.Db.ClearChannelStatus(c, status) +} + +// HasChanStatus returns true if the internal bitfield channel status of the +// target channel has the specified status bit set. +func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { + c.RLock() + defer c.RUnlock() + + return c.hasChanStatus(status) +} + +func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { + // Special case ChanStatusDefualt since it isn't actually flag, but a + // particular combination (or lack-there-of) of flags. + if status == ChanStatusDefault { + return c.chanStatus == ChanStatusDefault + } + + return c.chanStatus&status == status +} + +// HasChanStatusForStore returns true if the internal bitfield channel status +// has the specified status bit set, without taking the channel mutex. +// +// NOTE: This is a preliminary migration hook for KV-backed store code that +// still lives in channeldb during this refactor. Callers are responsible for +// synchronization. Normal callers should use HasChanStatus. +func (c *OpenChannel) HasChanStatusForStore(status ChannelStatus) bool { + return c.hasChanStatus(status) +} + +// ConfirmedScidForStore returns the in-memory confirmed SCID without taking +// the channel mutex. +// +// NOTE: This is a preliminary migration hook for KV-backed store code that +// still lives in channeldb during this refactor. Callers are responsible for +// synchronization. Normal callers should use ZeroConfRealScid. +func (c *OpenChannel) ConfirmedScidForStore() lnwire.ShortChannelID { + return c.confirmedScid +} + +// SetConfirmedScidForStore updates the in-memory confirmed SCID without taking +// the channel mutex. +// +// NOTE: This is a preliminary migration hook for KV-backed store code that +// still lives in channeldb during this refactor. Callers are responsible for +// synchronization. +func (c *OpenChannel) SetConfirmedScidForStore(scid lnwire.ShortChannelID) { + c.confirmedScid = scid +} + +// BroadcastHeight returns the height at which the funding tx was broadcast. +func (c *OpenChannel) BroadcastHeight() uint32 { + c.RLock() + defer c.RUnlock() + + return c.FundingBroadcastHeight +} + +// FundingTxPresent returns true if expect the funding transcation to be found +// on disk or already populated within the passed open channel struct. +func (c *OpenChannel) FundingTxPresent() bool { + chanType := c.ChanType + + return chanType.IsSingleFunder() && chanType.HasFundingTx() && + c.IsInitiator && + !c.HasChanStatusForStore(ChanStatusRestored) +} + +// SetBroadcastHeight sets the FundingBroadcastHeight. +func (c *OpenChannel) SetBroadcastHeight(height uint32) { + c.Lock() + defer c.Unlock() + + c.FundingBroadcastHeight = height +} + +// Refresh updates the in-memory channel state using the latest state observed +// on disk. +func (c *OpenChannel) Refresh() error { + c.Lock() + defer c.Unlock() + + return c.Db.RefreshChannel(c) +} + +// MarkConfirmationHeight updates the channel's confirmation height once the +// channel opening transaction receives one confirmation. +func (c *OpenChannel) MarkConfirmationHeight(height uint32) error { + c.Lock() + defer c.Unlock() + + if err := c.Db.MarkChannelConfirmationHeight(c, height); err != nil { + return err + } + + c.ConfirmationHeight = height + + return nil +} + +// ResetCloseConfirmationHeight clears the channel's close confirmation height +// when the spending transaction is reorged out. +func (c *OpenChannel) ResetCloseConfirmationHeight() error { + return c.MarkCloseConfirmationHeight(fn.None[uint32]()) +} + +// MarkCloseConfirmationHeight updates the channel's close confirmation height +// when the closing transaction is first detected in a block (spend height). +func (c *OpenChannel) MarkCloseConfirmationHeight( + height fn.Option[uint32]) error { + + c.Lock() + defer c.Unlock() + + err := c.Db.MarkChannelCloseConfirmationHeight(c, height) + if err != nil { + return err + } + + c.CloseConfirmationHeight = height + + return nil +} + +// MarkAsOpen marks a channel as fully open given a locator that uniquely +// describes its location within the chain. +func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { + c.Lock() + defer c.Unlock() + + if err := c.Db.MarkChannelOpen(c, openLoc); err != nil { + return err + } + + c.IsPending = false + c.ShortChannelID = openLoc + + return nil +} + +// MarkRealScid marks the zero-conf channel's confirmed ShortChannelID. This +// should only be done if IsZeroConf returns true. +func (c *OpenChannel) MarkRealScid(realScid lnwire.ShortChannelID) error { + c.Lock() + defer c.Unlock() + + if err := c.Db.MarkChannelRealScid(c, realScid); err != nil { + return err + } + + c.confirmedScid = realScid + + return nil +} + +// MarkScidAliasNegotiated adds ScidAliasFeatureBit to ChanType in-memory and +// in the database. +func (c *OpenChannel) MarkScidAliasNegotiated() error { + c.Lock() + defer c.Unlock() + + if err := c.Db.MarkChannelScidAliasNegotiated(c); err != nil { + return err + } + + c.ChanType |= ScidAliasFeatureBit + + return nil +} + +// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the +// passed commitPoint for use to retrieve funds in case the remote force closes +// the channel. +func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + return c.Db.MarkChannelDataLoss(c, commitPoint) +} + +// DataLossCommitPoint retrieves the stored commit point set during +// MarkDataLoss. If not found ErrNoCommitPoint is returned. +func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { + return c.Db.FetchChannelDataLossCommitPoint(c) +} + +// MarkBorked marks the event when the channel as reached an irreconcilable +// state, such as a channel breach or state desynchronization. Borked channels +// should never be added to the switch. +func (c *OpenChannel) MarkBorked() error { + c.Lock() + defer c.Unlock() + + return c.Db.MarkChannelBorked(c) +} + +// SecondCommitmentPoint returns the second per-commitment-point for use in the +// channel_ready message. +func (c *OpenChannel) SecondCommitmentPoint() (*btcec.PublicKey, error) { + c.RLock() + defer c.RUnlock() + + // Since we start at commitment height = 0, the second per commitment + // point is actually at the 1st index. + revocation, err := c.RevocationProducer.AtIndex(1) + if err != nil { + return nil, err + } + + return input.ComputeCommitmentPoint(revocation[:]), nil +} + +// ChanSyncMsg returns the ChannelReestablish message that should be sent upon +// reconnection with the remote peer that we're maintaining this channel with. +// The information contained within this message is necessary to re-sync our +// commitment chains in the case of a last or only partially processed message. +// When the remote party receives this message one of three things may happen: +// +// 1. We're fully synced and no messages need to be sent. +// 2. We didn't get the last CommitSig message they sent, so they'll re-send +// it. +// 3. We didn't get the last RevokeAndAck message they sent, so they'll +// re-send it. +// +// If this is a restored channel, having status ChanStatusRestored, then we'll +// modify our typical chan sync message to ensure they force close even if +// we're on the very first state. +func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { + c.Lock() + defer c.Unlock() + + // The remote commitment height that we'll send in the + // ChannelReestablish message is our current commitment height plus + // one. If the receiver thinks that our commitment height is actually + // *equal* to this value, then they'll re-send the last commitment that + // they sent but we never fully processed. + localHeight := c.LocalCommitment.CommitHeight + nextLocalCommitHeight := localHeight + 1 + + // The second value we'll send is the height of the remote commitment + // from our PoV. If the receiver thinks that their height is actually + // *one plus* this value, then they'll re-send their last revocation. + remoteChainTipHeight := c.RemoteCommitment.CommitHeight + + // If this channel has undergone a commitment update, then in order to + // prove to the remote party our knowledge of their prior commitment + // state, we'll also send over the last commitment secret that the + // remote party sent. + var lastCommitSecret [32]byte + if remoteChainTipHeight != 0 { + remoteSecret, err := c.RevocationStore.LookUp( + remoteChainTipHeight - 1, + ) + if err != nil { + return nil, err + } + lastCommitSecret = [32]byte(*remoteSecret) + } + + // Additionally, we'll send over the current unrevoked commitment on + // our local commitment transaction. + currentCommitSecret, err := c.RevocationProducer.AtIndex( + localHeight, + ) + if err != nil { + return nil, err + } + + // If we've restored this channel, then we'll purposefully give them an + // invalid LocalUnrevokedCommitPoint so they'll force close the channel + // allowing us to sweep our funds. + if c.hasChanStatus(ChanStatusRestored) { + currentCommitSecret[0] ^= 1 + + // If this is a tweakless channel, then we'll purposefully send + // a next local height taht's invalid to trigger a force close + // on their end. We do this as tweakless channels don't require + // that the commitment point is valid, only that it's present. + if c.ChanType.IsTweakless() { + nextLocalCommitHeight = 0 + } + } + + // If this is a taproot channel, then we'll need to generate our next + // verification nonce to send to the remote party. They'll use this to + // sign the next update to our commitment transaction. + var ( + nextTaprootNonce lnwire.OptMusig2NonceTLV + nextLocalNonces lnwire.OptLocalNonces + ) + if c.ChanType.IsTaproot() { + taprootRevProducer, err := DeriveMusig2Shachain( + c.RevocationProducer, + ) + if err != nil { + return nil, err + } + + nextNonce, err := NewMusigVerificationNonce( + c.LocalChanCfg.MultiSigKey.PubKey, + nextLocalCommitHeight, taprootRevProducer, + ) + if err != nil { + return nil, fmt.Errorf("unable to gen next "+ + "nonce: %w", err) + } + + fundingTxid := c.FundingOutpoint.Hash + nonce := nextNonce.PubNonce + + // Final taproot channels use the map-based LocalNonces + // field keyed by funding TXID. Staging channels use the + // legacy single LocalNonce field. + if c.ChanType.IsTaprootFinal() { + noncesMap := make(map[chainhash.Hash]lnwire.Musig2Nonce) + noncesMap[fundingTxid] = nonce + nextLocalNonces = lnwire.SomeLocalNonces( + lnwire.LocalNoncesData{NoncesMap: noncesMap}, + ) + } else { + nextTaprootNonce = lnwire.SomeMusig2Nonce(nonce) + } + } + + return &lnwire.ChannelReestablish{ + ChanID: lnwire.NewChanIDFromOutPoint( + c.FundingOutpoint, + ), + NextLocalCommitHeight: nextLocalCommitHeight, + RemoteCommitTailHeight: remoteChainTipHeight, + LastRemoteCommitSecret: lastCommitSecret, + LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( + currentCommitSecret[:], + ), + LocalNonce: nextTaprootNonce, + LocalNonces: nextLocalNonces, + }, nil +} + +// MarkShutdownSent serialises and persist the given ShutdownInfo for this +// channel. Persisting this info represents the fact that we have sent the +// Shutdown message to the remote side and hence that we should re-transmit the +// same Shutdown message on re-establish. +func (c *OpenChannel) MarkShutdownSent(info *ShutdownInfo) error { + c.Lock() + defer c.Unlock() + + return c.Db.StoreChannelShutdownInfo(c, info) +} + +// ShutdownInfo decodes the shutdown info stored for this channel and returns +// the result. If no shutdown info has been persisted for this channel then the +// ErrNoShutdownInfo error is returned. +func (c *OpenChannel) ShutdownInfo() (fn.Option[ShutdownInfo], error) { + c.RLock() + defer c.RUnlock() + + return c.Db.FetchChannelShutdownInfo(c) +} + +// MarkCommitmentBroadcasted marks the channel as a commitment transaction has +// been broadcast, either our own or the remote, and we should watch the chain +// for it to confirm before taking any further action. It takes as argument the +// closing tx _we believe_ will appear in the chain. This is only used to +// republish this tx at startup to ensure propagation, and we should still +// handle the case where a different tx actually hits the chain. +func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + return c.Db.MarkChannelCommitmentBroadcasted(c, closeTx, closer) +} + +// MarkCoopBroadcasted marks the channel to indicate that a cooperative close +// transaction has been broadcast, either our own or the remote, and that we +// should watch the chain for it to confirm before taking further action. It +// takes as argument a cooperative close tx that could appear on chain, and +// should be rebroadcast upon startup. This is only used to republish and +// ensure propagation, and we should still handle the case where a different tx +// actually hits the chain. +func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, + closer lntypes.ChannelParty) error { + + return c.Db.MarkChannelCoopBroadcasted(c, closeTx, closer) +} + +// BroadcastedCommitment retrieves the stored unilateral closing tx set during +// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. +func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { + return c.Db.FetchChannelBroadcastedCommitment(c) +} + +// BroadcastedCooperative retrieves the stored cooperative closing tx set during +// MarkCoopBroadcasted. If not found ErrNoCloseTx is returned. +func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) { + return c.Db.FetchChannelBroadcastedCooperative(c) +} + +// SyncPending writes the contents of the channel to the database while it's in +// the pending (waiting for funding confirmation) state. The IsPending flag +// will be set to true. When the channel's funding transaction is confirmed, +// the channel should be marked as "open" and the IsPending flag set to false. +// Note that this function also creates a LinkNode relationship between this +// newly created channel and a new LinkNode instance. This allows listing all +// channels in the database globally, or according to the LinkNode they were +// created with. +// +// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type +// that includes service bits. +func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { + c.Lock() + defer c.Unlock() + + return c.Db.SyncPendingChannel(c, addr, pendingHeight) +} + +// UpdateCommitment updates the local commitment state. It locks in the pending +// local updates that were received by us from the remote party. The commitment +// state completely describes the balance state at this point in the commitment +// chain. In addition to that, it persists all the remote log updates that we +// have acked, but not signed a remote commitment for yet. These need to be +// persisted to be able to produce a valid commit signature if a restart would +// occur. This method its to be called when we revoke our prior commitment +// state. +// +// A map is returned of all the htlc resolutions that were locked in this +// commitment. Keys correspond to htlc indices and values indicate whether the +// htlc was settled or failed. +func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, + unsignedAckedUpdates []LogUpdate) (map[uint64]bool, error) { + + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state as all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return nil, ErrNoRestoredChannelMutation + } + + finalHtlcs, err := c.Db.UpdateChannelCommitment( + c, newCommitment, unsignedAckedUpdates, + ) + if err != nil { + return nil, err + } + + c.LocalCommitment = *newCommitment + + return finalHtlcs, nil +} + +// ActiveHtlcs returns a slice of HTLC's which are currently active on *both* +// commitment transactions. +func (c *OpenChannel) ActiveHtlcs() []HTLC { + c.RLock() + defer c.RUnlock() + + // We'll only return HTLC's that are locked into *both* commitment + // transactions. So we'll iterate through their set of HTLC's to note + // which ones are present on their commitment. + remoteHtlcs := make(map[[32]byte]struct{}) + for _, htlc := range c.RemoteCommitment.Htlcs { + log.Tracef("RemoteCommitment has htlc: id=%v, update=%v "+ + "incoming=%v", htlc.HtlcIndex, htlc.LogIndex, + htlc.Incoming) + + onionHash := sha256.Sum256(htlc.OnionBlob[:]) + remoteHtlcs[onionHash] = struct{}{} + } + + // Now that we know which HTLC's they have, we'll only mark the HTLC's + // as active if *we* know them as well. + activeHtlcs := make([]HTLC, 0, len(remoteHtlcs)) + for _, htlc := range c.LocalCommitment.Htlcs { + log.Tracef("LocalCommitment has htlc: id=%v, update=%v "+ + "incoming=%v", htlc.HtlcIndex, htlc.LogIndex, + htlc.Incoming) + + onionHash := sha256.Sum256(htlc.OnionBlob[:]) + if _, ok := remoteHtlcs[onionHash]; !ok { + log.Tracef("Skipped htlc due to onion mismatched: "+ + "id=%v, update=%v incoming=%v", + htlc.HtlcIndex, htlc.LogIndex, htlc.Incoming) + + continue + } + + activeHtlcs = append(activeHtlcs, htlc) + } + + return activeHtlcs +} + +// AppendRemoteCommitChain appends a new CommitDiff to the end of the +// commitment chain for the remote party. This method is to be used once we +// have prepared a new commitment state for the remote party, but before we +// transmit it to the remote party. The contents of the argument should be +// sufficient to retransmit the updates and signature needed to reconstruct the +// state in full, in the case that we need to retransmit. +func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + return c.Db.AppendRemoteCommitChain(c, diff) +} + +// RemoteCommitChainTip returns the "tip" of the current remote commitment +// chain. This value will be non-nil iff, we've created a new commitment for +// the remote party that they haven't yet ACK'd. In this case, their commitment +// chain will have a length of two: their current unrevoked commitment, and +// this new pending commitment. Once they revoked their prior state, we'll swap +// these pointers, causing the tip and the tail to point to the same entry. +func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { + return c.Db.RemoteCommitChainTip(c) +} + +// UnsignedAckedUpdates retrieves the persisted unsigned acked remote log +// updates that still need to be signed for. +func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { + return c.Db.UnsignedAckedUpdates(c) +} + +// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log +// updates that the remote still needs to sign for. +func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { + return c.Db.RemoteUnsignedLocalUpdates(c) +} + +// InsertNextRevocation inserts the _next_ commitment point (revocation) into +// the database, and also modifies the internal RemoteNextRevocation attribute +// to point to the passed key. This method is to be using during final channel +// set up, _after_ the channel has been fully confirmed. +// +// NOTE: If this method isn't called, then the target channel won't be able to +// propose new states for the commitment state of the remote party. +func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + return c.Db.InsertNextRevocation(c, revKey) +} + +// AdvanceCommitChainTail records the new state transition within an on-disk +// append-only log which records all state transitions by the remote peer. In +// the case of an uncooperative broadcast of a prior state by the remote peer, +// this log can be consulted in order to reconstruct the state needed to +// rectify the situation. This method will add the current commitment for the +// remote party to the revocation log, and promote the current pending +// commitment to the current remote commitment. The updates parameter is the +// set of local updates that the peer still needs to send us a signature for. +// We store this set of updates in case we go down. +func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, + updates []LogUpdate, ourOutputIndex, theirOutputIndex uint32) error { + + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + return c.Db.AdvanceCommitChainTail( + c, fwdPkg, updates, ourOutputIndex, theirOutputIndex, + ) +} + +// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure +// this always returns the next index that has been not been allocated, this +// will first try to examine any pending commitments, before falling back to the +// last locked-in remote commitment. +func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { + // First, load the most recent commit diff that we initiated for the + // remote party. If no pending commit is found, this is not treated as + // a critical error, since we can always fall back. + pendingRemoteCommit, err := c.RemoteCommitChainTip() + if err != nil && !errors.Is(err, ErrNoPendingCommit) { + return 0, err + } + + // If a pending commit was found, its local htlc index will be at least + // as large as the one on our local commitment. + if pendingRemoteCommit != nil { + return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil + } + + // Otherwise, fallback to using the local htlc index of their + // commitment. + return c.RemoteCommitment.LocalHtlcIndex, nil +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + + return c.Db.LoadFwdPkgs(c) +} + +// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs +// indicating that a response to this Add has been committed to the remote +// party. Doing so will prevent these Add HTLCs from being reforwarded +// internally. +func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.AckAddHtlcs(c, addRefs...) +} + +// AckSettleFails updates the SettleFailFilter containing any of the provided +// SettleFailRefs, indicating that the response has been delivered to the +// incoming link, corresponding to a particular AddRef. Doing so will prevent +// the responses from being retransmitted internally. +func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.AckSettleFails(c, settleFailRefs...) +} + +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + + return c.Db.SetFwdFilter(c, height, fwdFilter) +} + +// RemoveFwdPkgs atomically removes forwarding packages specified by the +// remote commitment heights. If one of the intermediate RemovePkg calls fails, +// then the later packages won't be removed. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error { + c.Lock() + defer c.Unlock() + + return c.Db.RemoveFwdPkgs(c, heights...) +} + +// CommitmentHeight returns the current commitment height. The commitment +// height represents the number of updates to the commitment state to date. +// This value is always monotonically increasing. This method is provided in +// order to allow multiple instances of a particular open channel to obtain a +// consistent view of the number of channel updates to date. +func (c *OpenChannel) CommitmentHeight() (uint64, error) { + c.RLock() + defer c.RUnlock() + + return c.Db.CommitmentHeight(c) +} + +// FindPreviousState scans through the append-only log in an attempt to recover +// the previous channel state indicated by the update number. This method is +// intended to be used for obtaining the relevant data needed to claim all +// funds rightfully spendable in the case of an on-chain broadcast of the +// commitment transaction. +func (c *OpenChannel) FindPreviousState( + updateNum uint64) (*RevocationLog, *ChannelCommitment, error) { + + c.RLock() + defer c.RUnlock() + + return c.Db.FindPreviousState(c, updateNum) +} + +// CloseChannel closes a previously active Lightning channel. Closing a +// channel entails persisting a record of the close while either purging the +// nested per-channel state inline (synchronous backends like bbolt and etcd) +// or skipping the cascading delete on tombstone-enabled backends, where the +// outpoint-index flip to outpointClosed is the authoritative marker. The +// compact summary written to closedChannelBucket and the historical record +// under historicalChannelBucket are populated identically across both paths, +// so historical reads remain uniform regardless of backend. The optional set +// of channel statuses is OR'd into the chanStatus written to the historical +// bucket and is used to record close initiators. +func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, + statuses ...ChannelStatus) error { + + c.Lock() + defer c.Unlock() + + return c.Db.CloseChannel(c, summary, statuses...) +} + +// Snapshot returns a read-only snapshot of the current channel state. This +// snapshot includes information concerning the current settled balance within +// the channel, metadata detailing total flows, and any outstanding HTLCs. +func (c *OpenChannel) Snapshot() *ChannelSnapshot { + c.RLock() + defer c.RUnlock() + + localCommit := c.LocalCommitment + snapshot := &ChannelSnapshot{ + RemoteIdentity: *c.IdentityPub, + ChannelPoint: c.FundingOutpoint, + Capacity: c.Capacity, + TotalMSatSent: c.TotalMSatSent, + TotalMSatReceived: c.TotalMSatReceived, + ChainHash: c.ChainHash, + ChannelCommitment: ChannelCommitment{ + LocalBalance: localCommit.LocalBalance, + RemoteBalance: localCommit.RemoteBalance, + CommitHeight: localCommit.CommitHeight, + CommitFee: localCommit.CommitFee, + }, + } + + localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + + snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) + }) + + // Copy over the current set of HTLCs to ensure the caller can't mutate + // our internal state. + snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) + for i, h := range localCommit.Htlcs { + snapshot.Htlcs[i] = h.Copy() + } + + return snapshot +} + +// Copy returns a deep copy of the channel state. +func (c *OpenChannel) Copy() *OpenChannel { + c.RLock() + defer c.RUnlock() + + clone := &OpenChannel{ + ChanType: c.ChanType, + ChainHash: c.ChainHash, + FundingOutpoint: c.FundingOutpoint, + ShortChannelID: c.ShortChannelID, + IsPending: c.IsPending, + IsInitiator: c.IsInitiator, + chanStatus: c.chanStatus, + FundingBroadcastHeight: c.FundingBroadcastHeight, + ConfirmationHeight: c.ConfirmationHeight, + NumConfsRequired: c.NumConfsRequired, + ChannelFlags: c.ChannelFlags, + IdentityPub: c.IdentityPub, + Capacity: c.Capacity, + TotalMSatSent: c.TotalMSatSent, + TotalMSatReceived: c.TotalMSatReceived, + InitialLocalBalance: c.InitialLocalBalance, + InitialRemoteBalance: c.InitialRemoteBalance, + LocalChanCfg: c.LocalChanCfg, + RemoteChanCfg: c.RemoteChanCfg, + LocalCommitment: c.LocalCommitment.Copy(), + RemoteCommitment: c.RemoteCommitment.Copy(), + RemoteCurrentRevocation: c.RemoteCurrentRevocation, + RemoteNextRevocation: c.RemoteNextRevocation, + RevocationProducer: c.RevocationProducer, + RevocationStore: c.RevocationStore, + ThawHeight: c.ThawHeight, + LastWasRevoke: c.LastWasRevoke, + RevocationKeyLocator: c.RevocationKeyLocator, + confirmedScid: c.confirmedScid, + TapscriptRoot: c.TapscriptRoot, + } + + if c.FundingTxn != nil { + clone.FundingTxn = c.FundingTxn.Copy() + } + + if len(c.LocalShutdownScript) > 0 { + clone.LocalShutdownScript = make( + lnwire.DeliveryAddress, + len(c.LocalShutdownScript), + ) + copy(clone.LocalShutdownScript, c.LocalShutdownScript) + } + if len(c.RemoteShutdownScript) > 0 { + clone.RemoteShutdownScript = make( + lnwire.DeliveryAddress, + len(c.RemoteShutdownScript), + ) + copy(clone.RemoteShutdownScript, c.RemoteShutdownScript) + } + + if len(c.Memo) > 0 { + clone.Memo = make([]byte, len(c.Memo)) + copy(clone.Memo, c.Memo) + } + + c.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + clone.CustomBlob = fn.Some(blobCopy) + }) + + return clone +} + +// LatestCommitments returns the two latest commitments for both the local and +// remote party. These commitments are read from disk to ensure that only the +// latest fully committed state is returned. The first commitment returned is +// the local commitment, and the second returned is the remote commitment. +func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, + *ChannelCommitment, error) { + + return c.Db.LatestCommitments(c) +} + +// RemoteRevocationStore returns the most up to date commitment version of the +// revocation storage tree for the remote party. This method can be used when +// acting on a possible contract breach to ensure, that the caller has the most +// up to date information required to deliver justice. +func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { + return c.Db.RemoteRevocationStore(c) +} + +// AbsoluteThawHeight determines a frozen channel's absolute thaw height. If the +// channel is not frozen, then 0 is returned. +func (c *OpenChannel) AbsoluteThawHeight() (uint32, error) { + // Only frozen channels have a thaw height. + if !c.ChanType.IsFrozen() && !c.ChanType.HasLeaseExpiration() { + return 0, nil + } + + // If the channel has the frozen bit set and it's thaw height is below + // the absolute threshold, then it's interpreted as a relative height to + // the chain's current height. + if c.ChanType.IsFrozen() && c.ThawHeight < AbsoluteThawHeightThreshold { + // We'll only known of the channel's short ID once it's + // confirmed. + if c.IsPending { + return 0, errors.New("cannot use relative thaw " + + "height for unconfirmed channel") + } + + // For non-zero-conf channels, this is the base height to use. + blockHeightBase := c.ShortChannelID.BlockHeight + + // If this is a zero-conf channel, the ShortChannelID will be + // an alias. + if c.IsZeroConf() { + if !c.ZeroConfConfirmed() { + return 0, errors.New("cannot use relative " + + "height for unconfirmed zero-conf " + + "channel") + } + + // Use the confirmed SCID's BlockHeight. + blockHeightBase = c.confirmedScid.BlockHeight + } + + return blockHeightBase + c.ThawHeight, nil + } + + return c.ThawHeight, nil +} + +// DeriveHeightHint derives the block height for the channel opening. +func (c *OpenChannel) DeriveHeightHint() uint32 { + // As a height hint, we'll try to use the opening height, but if the + // channel isn't yet open, then we'll use the height it was broadcast + // at. This may be an unconfirmed zero-conf channel. + heightHint := c.ShortChanID().BlockHeight + if heightHint == 0 { + heightHint = c.BroadcastHeight() + } + + // Since no zero-conf state is stored in a channel backup, the below + // logic will not be triggered for restored, zero-conf channels. Set + // the height hint for zero-conf channels. + if c.IsZeroConf() { + if c.ZeroConfConfirmed() { + // If the zero-conf channel is confirmed, we'll use the + // confirmed SCID's block height. + heightHint = c.ZeroConfRealScid().BlockHeight + } else { + // The zero-conf channel is unconfirmed. We'll need to + // use the FundingBroadcastHeight. + heightHint = c.BroadcastHeight() + } + } + + return heightHint +} diff --git a/chanstate/open_channel_types.go b/chanstate/open_channel_types.go new file mode 100644 index 00000000000..90ab06bc629 --- /dev/null +++ b/chanstate/open_channel_types.go @@ -0,0 +1,16 @@ +package chanstate + +import "net" + +// ChannelShell is a shell of a channel that is meant to be used for channel +// recovery purposes. It contains a minimal OpenChannel instance along with +// addresses for that target node. +type ChannelShell struct { + // NodeAddrs the set of addresses that this node has known to be + // reachable at in the past. + NodeAddrs []net.Addr + + // Chan is a shell of an OpenChannel, it contains only the items + // required to restore the channel on disk. + Chan *OpenChannel +} diff --git a/chanstate/revocation_log.go b/chanstate/revocation_log.go new file mode 100644 index 00000000000..1bbea81407f --- /dev/null +++ b/chanstate/revocation_log.go @@ -0,0 +1,200 @@ +package chanstate + +import ( + "math" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // OutputIndexEmpty is used when the output index doesn't exist. + OutputIndexEmpty = math.MaxUint16 +) + +type ( + // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. + BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + + // BigSizeMilliSatoshi is a type alias for a TLV record of a + // lnwire.MilliSatoshi. + BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] +) + +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the +// historical HTLCs, which is useful for constructing RevocationLog when a +// breach is detected. +// The actual size of each HTLCEntry varies based on its RHash and Amt(sat), +// summarized as follows, +// +// | RHash | Amt<=252 | Amt<=65,535 | Amt<=4,294,967,295 | otherwise | +// |:-----:|:--------:|:-----------:|:------------------:|:---------:| +// | true | 19 | 21 | 23 | 26 | +// | false | 51 | 53 | 55 | 58 | +// +// So the size varies from 19 bytes to 58 bytes, where most likely to be 23 or +// 55 bytes. +// +// NOTE: all the fields saved to disk use the primitive go types so they can be +// made into tlv records without further conversion. +type HTLCEntry struct { + // RHash is the payment hash of the HTLC. + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] + + // RefundTimeout is the absolute timeout on the HTLC that the sender + // must wait before reclaiming the funds in limbo. + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] + + // OutputIndex is the output index for this particular HTLC output + // within the commitment transaction. + // + // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which + // gives us a max number of HTLCs of 65K. + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] + + // Incoming denotes whether we're the receiver or the sender of this + // HTLC. + Incoming tlv.RecordT[tlv.TlvType3, bool] + + // Amt is the amount of satoshis this HTLC escrows. + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] + + // CustomBlob is an optional blob that can be used to store information + // specific to revocation handling for a custom channel type. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] + + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, tlv.BigSizeT[uint64]] +} + +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { + h := &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), + ), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + HtlcIndex: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType6]( + tlv.NewBigSizeT(htlc.HtlcIndex), + )), + } + + if len(htlc.CustomRecords) != 0 { + blob, err := htlc.CustomRecords.Serialize() + if err != nil { + return nil, err + } + + h.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + } + + return h, nil +} + +// RevocationLog stores the info needed to construct a breach retribution. Its +// fields can be viewed as a subset of a ChannelCommitment's. In the database, +// all historical versions of the RevocationLog are saved using the +// CommitHeight as the key. +type RevocationLog struct { + // OurOutputIndex specifies our output index in this commitment. In a + // remote commitment transaction, this is the to remote output index. + OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] + + // TheirOutputIndex specifies their output index in this commitment. In + // a remote commitment transaction, this is the to local output index. + TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] + + // CommitTxHash is the hash of the latest version of the commitment + // state, broadcast able by us. + CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] + + // HTLCEntries is the set of HTLCEntry's that are pending at this + // particular commitment height. + HTLCEntries []*HTLCEntry + + // OurBalance is the current available balance within the channel + // directly spendable by us. In other words, it is the value of the + // to_remote output on the remote parties' commitment transaction. + // + // NOTE: this is an option so that it is clear if the value is zero or + // nil. Since migration 30 of the channeldb initially did not include + // this field, it could be the case that the field is not present for + // all revocation logs. + OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] + + // TheirBalance is the current available balance within the channel + // directly spendable by the remote node. In other words, it is the + // value of the to_local output on the remote parties' commitment. + // + // NOTE: this is an option so that it is clear if the value is zero or + // nil. Since migration 30 of the channeldb initially did not include + // this field, it could be the case that the field is not present for + // all revocation logs. + TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] +} + +// NewRevocationLog creates a new RevocationLog from the given parameters. +func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, + commitHash [32]byte, ourBalance, + theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, + customBlob fn.Option[tlv.Blob]) RevocationLog { + + rl := RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + ourOutputIndex, + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + theirOutputIndex, + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), + HTLCEntries: htlcs, + } + + ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(balance), + )) + }) + + theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(balance), + )) + }) + + customBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + + return rl +} diff --git a/chanstate/shutdown.go b/chanstate/shutdown.go new file mode 100644 index 00000000000..4c1dca31a51 --- /dev/null +++ b/chanstate/shutdown.go @@ -0,0 +1,41 @@ +package chanstate + +import ( + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// ShutdownInfo contains various info about the shutdown initiation of a +// channel. +type ShutdownInfo struct { + // DeliveryScript is the address that we have included in any previous + // Shutdown message for a particular channel and so should include in + // any future re-sends of the Shutdown message. + DeliveryScript tlv.RecordT[tlv.TlvType0, lnwire.DeliveryAddress] + + // LocalInitiator is true if we sent a Shutdown message before ever + // receiving a Shutdown message from the remote peer. + LocalInitiator tlv.RecordT[tlv.TlvType1, bool] +} + +// NewShutdownInfo constructs a new ShutdownInfo object. +func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress, + locallyInitiated bool) *ShutdownInfo { + + return &ShutdownInfo{ + DeliveryScript: tlv.NewRecordT[tlv.TlvType0](deliveryScript), + LocalInitiator: tlv.NewPrimitiveRecord[tlv.TlvType1]( + locallyInitiated, + ), + } +} + +// Closer identifies the ChannelParty that initiated the coop-closure process. +func (s ShutdownInfo) Closer() lntypes.ChannelParty { + if s.LocalInitiator.Val { + return lntypes.Local + } + + return lntypes.Remote +} diff --git a/chanstate/snapshot.go b/chanstate/snapshot.go new file mode 100644 index 00000000000..b6c8e8238d0 --- /dev/null +++ b/chanstate/snapshot.go @@ -0,0 +1,44 @@ +package chanstate + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ChannelSnapshot is a frozen snapshot of the current channel state. A +// snapshot is detached from the original channel that generated it, providing +// read-only access to the current or prior state of an active channel. +// +// TODO(roasbeef): remove all together? pretty much just commitment. +type ChannelSnapshot struct { + // RemoteIdentity is the identity public key of the remote node that we + // are maintaining the open channel with. + RemoteIdentity btcec.PublicKey + + // ChanPoint is the outpoint that created the channel. This output is + // found within the funding transaction and uniquely identified the + // channel on the resident chain. + ChannelPoint wire.OutPoint + + // ChainHash is the genesis hash of the chain that the channel resides + // within. + ChainHash chainhash.Hash + + // Capacity is the total capacity of the channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // ChannelCommitment is the current up-to-date commitment for the + // target channel. + ChannelCommitment +} diff --git a/chanstate/taproot.go b/chanstate/taproot.go new file mode 100644 index 00000000000..3bb575a0801 --- /dev/null +++ b/chanstate/taproot.go @@ -0,0 +1,79 @@ +package chanstate + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "fmt" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/shachain" +) + +const ( + // AbsoluteThawHeightThreshold is the threshold at which a thaw height + // begins to be interpreted as an absolute block height, rather than a + // relative one. + AbsoluteThawHeightThreshold uint32 = 500000 +) + +var ( + // taprootRevRootKey is the key used to derive the revocation root for + // the taproot nonces. This is done via HMAC of the existing revocation + // root. + taprootRevRootKey = []byte("taproot-rev-root") +) + +// DeriveMusig2Shachain derives a shachain producer for the taproot channel +// from normal shachain revocation root. +func DeriveMusig2Shachain(revRoot shachain.Producer) (shachain.Producer, error) { //nolint:ll + // In order to obtain the revocation root hash to create the taproot + // revocation, we'll encode the producer into a buffer, then use that + // to derive the shachain root needed. + var rootHashBuf bytes.Buffer + if err := revRoot.Encode(&rootHashBuf); err != nil { + return nil, fmt.Errorf("unable to encode producer: %w", err) + } + + revRootHash := chainhash.HashH(rootHashBuf.Bytes()) + + // For taproot channel types, we'll also generate a distinct shachain + // root using the same seed information. We'll use this to generate + // verification nonces for the channel. We'll bind with this a simple + // hmac. + taprootRevHmac := hmac.New(sha256.New, taprootRevRootKey) + if _, err := taprootRevHmac.Write(revRootHash[:]); err != nil { + return nil, err + } + + taprootRevRoot := taprootRevHmac.Sum(nil) + + // Once we have the root, we can then generate our shachain producer + // and from that generate the per-commitment point. + return shachain.NewRevocationProducerFromBytes( + taprootRevRoot, + ) +} + +// NewMusigVerificationNonce generates the local or verification nonce for +// another musig2 session. In order to permit our implementation to not have to +// write any secret nonce state to disk, we'll use the _next_ shachain +// pre-image as our primary randomness source. When used to generate the nonce +// again to broadcast our commitment hte current height will be used. +func NewMusigVerificationNonce(pubKey *btcec.PublicKey, targetHeight uint64, + shaGen shachain.Producer) (*musig2.Nonces, error) { + + // Now that we know what height we need, we'll grab the shachain + // pre-image at the target destination. + nextPreimage, err := shaGen.AtIndex(targetHeight) + if err != nil { + return nil, err + } + + shaChainRand := musig2.WithCustomRand(bytes.NewBuffer(nextPreimage[:])) + pubKeyOpt := musig2.WithPublicKey(pubKey) + + return musig2.GenNonces(pubKeyOpt, shaChainRand) +} diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index 7e267678240..2a825113166 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/sweep" @@ -159,7 +160,7 @@ func (c *anchorResolver) Stop() { // state required for the proper resolution of a contract. // // NOTE: Part of the ContractResolver interface. -func (c *anchorResolver) SupplementState(state *channeldb.OpenChannel) { +func (c *anchorResolver) SupplementState(state *chanstate.OpenChannel) { c.chanType = state.ChanType } diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 869a0093e01..e4d887348d5 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -952,7 +953,8 @@ func initBreachedState(t *testing.T) (*BreachArbitrator, contractBreaches := make(chan *ContractBreachEvent) brar, err := createTestArbiter( - t, contractBreaches, alice.State().Db.GetParentDB(), + t, contractBreaches, + testChannelStateDB(t, alice.State()).GetParentDB(), ) require.NoError(t, err, "unable to initialize test breach arbiter") @@ -1118,7 +1120,8 @@ func TestBreachHandoffFail(t *testing.T) { assertNotPendingClosed(t, alice) brar, err := createTestArbiter( - t, contractBreaches, alice.State().Db.GetParentDB(), + t, contractBreaches, + testChannelStateDB(t, alice.State()).GetParentDB(), ) require.NoError(t, err, "unable to initialize test breach arbiter") @@ -1763,7 +1766,9 @@ func testBreachSpends(t *testing.T, test breachTest) { } // Assert that the channel is fully resolved. - assertBrarCleanup(t, brar, &chanPoint, alice.State().Db) + assertBrarCleanup( + t, brar, &chanPoint, testChannelStateDB(t, alice.State()), + ) } // TestBreachDelayedJusticeConfirmation tests that the breach arbiter will @@ -1968,7 +1973,9 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { } // Assert that the channel is fully resolved. - assertBrarCleanup(t, brar, &chanPoint, alice.State().Db) + assertBrarCleanup( + t, brar, &chanPoint, testChannelStateDB(t, alice.State()), + ) } // findInputIndex returns the index of the input that spends from the given @@ -2080,7 +2087,9 @@ func assertBrarCleanup(t *testing.T, brar *BreachArbitrator, func assertPendingClosed(t *testing.T, c *lnwallet.LightningChannel) { t.Helper() - closedChans, err := c.State().Db.FetchClosedChannels(true) + closedChans, err := testChannelStateDB( + t, c.State(), + ).FetchClosedChannels(true) require.NoError(t, err, "unable to load pending closed channels") for _, chanSummary := range closedChans { @@ -2097,7 +2106,9 @@ func assertPendingClosed(t *testing.T, c *lnwallet.LightningChannel) { func assertNotPendingClosed(t *testing.T, c *lnwallet.LightningChannel) { t.Helper() - closedChans, err := c.State().Db.FetchClosedChannels(true) + closedChans, err := testChannelStateDB( + t, c.State(), + ).FetchClosedChannels(true) require.NoError(t, err, "unable to load pending closed channels") for _, chanSummary := range closedChans { @@ -2305,7 +2316,7 @@ func createInitChannels(t *testing.T) ( binary.BigEndian.Uint64(chanIDBytes[:]), ) - aliceChannelState := &channeldb.OpenChannel{ + aliceChannelState := &chanstate.OpenChannel{ LocalChanCfg: aliceCfg, RemoteChanCfg: bobCfg, IdentityPub: aliceKeyPub, @@ -2320,10 +2331,9 @@ func createInitChannels(t *testing.T) ( LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, Db: dbAlice.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } - bobChannelState := &channeldb.OpenChannel{ + bobChannelState := &chanstate.OpenChannel{ LocalChanCfg: bobCfg, RemoteChanCfg: aliceCfg, IdentityPub: bobKeyPub, @@ -2338,7 +2348,6 @@ func createInitChannels(t *testing.T) ( LocalCommitment: bobCommit, RemoteCommitment: bobCommit, Db: dbBob.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), } aliceSigner := input.NewMockSigner( diff --git a/contractcourt/breach_resolver.go b/contractcourt/breach_resolver.go index f341128006c..29a7f6bac96 100644 --- a/contractcourt/breach_resolver.go +++ b/contractcourt/breach_resolver.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" ) // breachResolver is a resolver that will handle breached closes. In the @@ -88,7 +88,7 @@ func (b *breachResolver) Stop() { } // SupplementState adds additional state to the breachResolver. -func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) { +func (b *breachResolver) SupplementState(_ *chanstate.OpenChannel) { } // Encode encodes the breachResolver to the passed writer. diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index eac63cb32d1..5aa0c269320 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -317,7 +318,7 @@ var _ chainio.Consumer = (*ChainArbitrator)(nil) // interact with. type arbChannel struct { // channel is the in-memory channel state. - channel *channeldb.OpenChannel + channel *chanstate.OpenChannel // c references the chain arbitrator and is used by arbChannel // internally. @@ -426,7 +427,7 @@ func (a *arbChannel) ForceCloseChan() (*wire.MsgTx, error) { // newActiveChannelArbitrator creates a new instance of an active channel // arbitrator given the state of the target channel. -func newActiveChannelArbitrator(channel *channeldb.OpenChannel, +func newActiveChannelArbitrator(channel *chanstate.OpenChannel, c *ChainArbitrator, chanEvents *ChainEventSubscription) (*ChannelArbitrator, error) { // TODO(roasbeef): fetch best height (or pass in) so can ensure block @@ -464,7 +465,7 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, tx, c.cfg.ChainHash, &chanPoint, report, ) }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { + FetchHistoricalChannel: func() (*chanstate.OpenChannel, error) { chanStateDB := c.chanSource.ChannelStateDB() return chanStateDB.FetchHistoricalChannel(&chanPoint) }, @@ -516,7 +517,7 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, // getArbChannel returns an open channel wrapper for use by channel arbitrators. func (c *ChainArbitrator) getArbChannel( - channel *channeldb.OpenChannel) *arbChannel { + channel *chanstate.OpenChannel) *arbChannel { return &arbChannel{ channel: channel, @@ -824,7 +825,7 @@ func (c *ChainArbitrator) notifyChannelResolved(cp wire.OutPoint) { // transactions and republish them. This helps ensure propagation of the // transactions in the event that prior publications failed. func (c *ChainArbitrator) republishClosingTxs( - channel *channeldb.OpenChannel) error { + channel *chanstate.OpenChannel) error { // If the channel has had its unilateral close broadcasted already, // republish it in case it didn't propagate. @@ -856,7 +857,7 @@ func (c *ChainArbitrator) republishClosingTxs( // // NOTE: There is no risk to calling this method if the channel isn't in either // CommitmentBroadcasted or CoopBroadcasted, but the logs will be misleading. -func (c *ChainArbitrator) rebroadcast(channel *channeldb.OpenChannel, +func (c *ChainArbitrator) rebroadcast(channel *chanstate.OpenChannel, state channeldb.ChannelStatus) error { chanPoint := channel.FundingOutpoint @@ -1115,7 +1116,9 @@ func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.Msg // ChannelArbitrator tasked with watching over a new channel. Once a new // channel has finished its final funding flow, it should be registered with // the ChainArbitrator so we can properly react to any on-chain events. -func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error { +func (c *ChainArbitrator) WatchNewChannel( + newChan *chanstate.OpenChannel) error { + c.Lock() defer c.Unlock() @@ -1398,7 +1401,7 @@ func (c *ChainArbitrator) loadPendingCloseChannels() error { tx, c.cfg.ChainHash, &chanPoint, report, ) }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { + FetchHistoricalChannel: func() (*chanstate.OpenChannel, error) { return chanStateDB.FetchHistoricalChannel(&chanPoint) }, FindOutgoingHTLCDeadline: func( diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 622686f76c4..d19390c730f 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntest/mock" @@ -26,7 +27,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { // Create 10 test channels and sync them to the database. const numChans = 10 - var channels []*channeldb.OpenChannel + var channels []*chanstate.OpenChannel for i := 0; i < numChans; i++ { lChannel, _, err := lnwallet.CreateTestChannels( t, channeldb.SingleFunderTweaklessBit, diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index e45bb3dc99b..7cae7dfce0a 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -239,7 +240,7 @@ type chainWatcherConfig struct { // chanState is a snapshot of the persistent state of the channel that // we're watching. In the event of an on-chain event, we'll query the // database to ensure that we act using the most up to date state. - chanState *channeldb.OpenChannel + chanState *chanstate.OpenChannel // notifier is a reference to the channel notifier that we'll use to be // notified of output spends and when transactions are confirmed. @@ -627,7 +628,7 @@ type chainSet struct { // newChainSet creates a new chainSet given the current up to date channel // state. -func newChainSet(chanState *channeldb.OpenChannel) (*chainSet, error) { +func newChainSet(chanState *chanstate.OpenChannel) (*chainSet, error) { // First, we'll grab the current unrevoked commitments for ourselves // and the remote party. localCommit, remoteCommit, err := chanState.LatestCommitments() @@ -1698,7 +1699,7 @@ func (c *chainWatcher) waitForCommitmentPoint() *btcec.PublicKey { } // deriveFundingPkScript derives the script used in the funding output. -func deriveFundingPkScript(chanState *channeldb.OpenChannel) ([]byte, error) { +func deriveFundingPkScript(chanState *chanstate.OpenChannel) ([]byte, error) { localKey := chanState.LocalChanCfg.MultiSigKey.PubKey remoteKey := chanState.RemoteChanCfg.MultiSigKey.PubKey diff --git a/contractcourt/chain_watcher_test.go b/contractcourt/chain_watcher_test.go index 8275886a140..a38763ff7f6 100644 --- a/contractcourt/chain_watcher_test.go +++ b/contractcourt/chain_watcher_test.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" lnmock "github.com/lightningnetwork/lnd/lntest/mock" @@ -264,11 +265,11 @@ type dlpTestCase struct { // state) are returned. func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, aliceChannel, bobChannel *lnwallet.LightningChannel, - numUpdates uint8) ([]*channeldb.OpenChannel, error) { + numUpdates uint8) ([]*chanstate.OpenChannel, error) { // We'll make a copy of the channel state before each transition. var ( - chanStates []*channeldb.OpenChannel + chanStates []*chanstate.OpenChannel ) state, err := copyChannelState(t, aliceChannel.State()) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 458a8a0d254..eddb1219b8e 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -166,7 +167,7 @@ type ChannelArbitratorConfig struct { // FetchHistoricalChannel retrieves the historical state of a channel. // This is mostly used to supplement the ContractResolvers with // additional information required for proper contract resolution. - FetchHistoricalChannel func() (*channeldb.OpenChannel, error) + FetchHistoricalChannel func() (*chanstate.OpenChannel, error) // FindOutgoingHTLCDeadline returns the deadline in absolute block // height for the specified outgoing HTLC. For an outgoing HTLC, its @@ -735,7 +736,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // We'll also fetch the historical state of this channel, as it should // have been marked as closed by now, and supplement it to each resolver // such that we can properly resolve our pending contracts. - var chanState *channeldb.OpenChannel + var chanState *chanstate.OpenChannel chanState, err = c.cfg.FetchHistoricalChannel() switch { // If we don't find this channel, then it may be the case that it @@ -2364,7 +2365,7 @@ func (c *ChannelArbitrator) prepContractResolutions( // We'll also fetch the historical state of this channel, as it should // have been marked as closed by now, and supplement it to each resolver // such that we can properly resolve our pending contracts. - var chanState *channeldb.OpenChannel + var chanState *chanstate.OpenChannel chanState, err := c.cfg.FetchHistoricalChannel() switch { // If we don't find this channel, then it may be the case that it diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 37b9310399e..257d5655402 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -447,8 +448,8 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, return nil }, - FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) { - return &channeldb.OpenChannel{}, nil + FetchHistoricalChannel: func() (*chanstate.OpenChannel, error) { + return &chanstate.OpenChannel{}, nil }, FindOutgoingHTLCDeadline: func( htlc channeldb.HTLC) fn.Option[int32] { @@ -2163,7 +2164,9 @@ func TestChannelArbitratorPendingExpiredHTLC(t *testing.T) { func TestRemoteCloseInitiator(t *testing.T) { // getCloseSummary returns a unilateral close summary for the channel // provided. - getCloseSummary := func(channel *channeldb.OpenChannel) *RemoteUnilateralCloseInfo { + getCloseSummary := func( + channel *chanstate.OpenChannel) *RemoteUnilateralCloseInfo { + return &RemoteUnilateralCloseInfo{ UnilateralCloseSummary: &lnwallet.UnilateralCloseSummary{ SpendDetail: &chainntnfs.SpendDetail{ @@ -2193,7 +2196,7 @@ func TestRemoteCloseInitiator(t *testing.T) { // is expected to be buffered, as is the default for test // channel arbitrators. notifyClose func(sub *ChainEventSubscription, - channel *channeldb.OpenChannel) + channel *chanstate.OpenChannel) // expectedStates is the set of states we expect the arbitrator // to progress through. @@ -2202,7 +2205,7 @@ func TestRemoteCloseInitiator(t *testing.T) { { name: "force close", notifyClose: func(sub *ChainEventSubscription, - channel *channeldb.OpenChannel) { + channel *chanstate.OpenChannel) { s := getCloseSummary(channel) sub.RemoteUnilateralClosure <- s diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index dd02e848b4a..21722b3c232 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" @@ -210,7 +211,7 @@ func (c *commitSweepResolver) Stop() { // state required for the proper resolution of a contract. // // NOTE: Part of the ContractResolver interface. -func (c *commitSweepResolver) SupplementState(state *channeldb.OpenChannel) { +func (c *commitSweepResolver) SupplementState(state *chanstate.OpenChannel) { if state.ChanType.HasLeaseExpiration() { c.leaseExpiry = state.ThawHeight } diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index d11bd2f597a..dc05069dafe 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/sweep" ) @@ -59,7 +60,7 @@ type ContractResolver interface { // SupplementState allows the user of a ContractResolver to supplement // it with state required for the proper resolution of a contract. - SupplementState(*channeldb.OpenChannel) + SupplementState(*chanstate.OpenChannel) // IsResolved returns true if the stored state in the resolve is fully // resolved. In this case the target output can be forgotten. diff --git a/contractcourt/htlc_lease_resolver.go b/contractcourt/htlc_lease_resolver.go index 3002cec0b75..944eeb85a40 100644 --- a/contractcourt/htlc_lease_resolver.go +++ b/contractcourt/htlc_lease_resolver.go @@ -3,7 +3,7 @@ package contractcourt import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/tlv" @@ -76,7 +76,7 @@ func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint, // state required for the proper resolution of a contract. // // NOTE: Part of the ContractResolver interface. -func (h *htlcLeaseResolver) SupplementState(state *channeldb.OpenChannel) { +func (h *htlcLeaseResolver) SupplementState(state *chanstate.OpenChannel) { if state.ChanType.HasLeaseExpiration() { h.leaseExpiry = state.ThawHeight } diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 1770c214a45..82527d54773 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" @@ -384,7 +385,7 @@ func (h *htlcSuccessResolver) HtlcPoint() wire.OutPoint { // production taproot channels after restart. // // NOTE: Part of the ContractResolver interface. -func (h *htlcSuccessResolver) SupplementState(state *channeldb.OpenChannel) { +func (h *htlcSuccessResolver) SupplementState(state *chanstate.OpenChannel) { h.htlcLeaseResolver.SupplementState(state) h.chanType = state.ChanType } diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index c2cbb133beb..eed83510baf 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -775,7 +776,7 @@ func (h *htlcTimeoutResolver) HtlcPoint() wire.OutPoint { // production taproot channels after restart. // // NOTE: Part of the ContractResolver interface. -func (h *htlcTimeoutResolver) SupplementState(state *channeldb.OpenChannel) { +func (h *htlcTimeoutResolver) SupplementState(state *chanstate.OpenChannel) { h.htlcLeaseResolver.SupplementState(state) h.chanType = state.ChanType } diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 994bc57a88c..915e4c3ff70 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -10,8 +10,22 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" ) +func testChannelStateDB(t testing.TB, + state *chanstate.OpenChannel) *channeldb.ChannelStateDB { + + t.Helper() + + cdb, ok := state.Db.(*channeldb.ChannelStateDB) + if !ok { + t.Fatalf("expected ChannelStateDB, got %T", state.Db) + } + + return cdb +} + // timeout implements a test level timeout. func timeout() func() { done := make(chan struct{}) @@ -52,11 +66,13 @@ func copyFile(dest, src string) error { // copyChannelState copies the OpenChannel state by copying the database and // creating a new struct from it. The copied state is returned. -func copyChannelState(t *testing.T, state *channeldb.OpenChannel) ( - *channeldb.OpenChannel, error) { +func copyChannelState(t *testing.T, state *chanstate.OpenChannel) ( + *chanstate.OpenChannel, error) { // Make a copy of the DB. - dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db") + dbFile := filepath.Join( + testChannelStateDB(t, state).GetParentDB().Path(), "channel.db", + ) tempDbPath := t.TempDir() tempDbFile := filepath.Join(tempDbPath, "channel.db") diff --git a/discovery/ban.go b/discovery/ban.go index 0425948cbb1..80f5f8f7e14 100644 --- a/discovery/ban.go +++ b/discovery/ban.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightninglabs/neutrino/cache" "github.com/lightninglabs/neutrino/cache/lru" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/lnwire" ) @@ -67,7 +67,7 @@ type GraphCloser interface { type NodeInfoInquirer interface { // FetchOpenChannels returns the set of channels that we have with the // peer identified by the passed-in public key. - FetchOpenChannels(*btcec.PublicKey) ([]*channeldb.OpenChannel, error) + FetchOpenChannels(*btcec.PublicKey) ([]*chanstate.OpenChannel, error) } // ScidCloserMan helps the gossiper handle closed channels that are in the diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 0c533d49213..7eccc5f70e2 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" @@ -383,7 +384,7 @@ type Config struct { // FindChannel allows the gossiper to find a channel that we're party // to without iterating over the entire set of open channels. FindChannel func(node *btcec.PublicKey, chanID lnwire.ChannelID) ( - *channeldb.OpenChannel, error) + *chanstate.OpenChannel, error) // IsStillZombieChannel returns true if the channel described by info // should still be considered a zombie. diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 199532642d1..aba84c2ac04 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -27,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/graph" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" @@ -891,7 +892,7 @@ func (ctx *testCtx) createChannelAnnouncement(blockHeight uint32, key1, } func mockFindChannel(node *btcec.PublicKey, chanID lnwire.ChannelID) ( - *channeldb.OpenChannel, error) { + *chanstate.OpenChannel, error) { return nil, nil } diff --git a/funding/manager.go b/funding/manager.go index 2dcd4f1b737..77655e0eb79 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -425,7 +425,7 @@ type Config struct { // channel ID. Providing the node's public key is an optimization that // prevents deserializing and scanning through all possible channels. FindChannel func(node *btcec.PublicKey, - chanID lnwire.ChannelID) (*channeldb.OpenChannel, error) + chanID lnwire.ChannelID) (*chanstate.OpenChannel, error) // TempChanIDSeed is a cryptographically random string of bytes that's // used as a seed to generate pending channel ID's. @@ -475,7 +475,7 @@ type Config struct { // the channel to the ChainArbitrator so it can watch for any on-chain // events related to the channel. We also provide the public key of the // node we're establishing a channel with for reconnection purposes. - WatchNewChannel func(*channeldb.OpenChannel, *btcec.PublicKey) error + WatchNewChannel func(*chanstate.OpenChannel, *btcec.PublicKey) error // ReportShortChanID allows the funding manager to report the confirmed // short channel ID of a formerly pending zero-conf channel to outside @@ -525,7 +525,7 @@ type Config struct { // NotifyPendingOpenChannelEvent informs the ChannelNotifier when // channels enter a pending state. NotifyPendingOpenChannelEvent func(wire.OutPoint, - *channeldb.OpenChannel, *btcec.PublicKey) + *chanstate.OpenChannel, *btcec.PublicKey) // NotifyFundingTimeout informs the ChannelNotifier when a pending-open // channel times out because the funding transaction hasn't confirmed. @@ -811,7 +811,7 @@ func (f *Manager) Stop() error { // rebroadcastFundingTx publishes the funding tx on startup for each // unconfirmed channel. -func (f *Manager) rebroadcastFundingTx(c *channeldb.OpenChannel) { +func (f *Manager) rebroadcastFundingTx(c *chanstate.OpenChannel) { var fundingTxBuf bytes.Buffer err := c.FundingTxn.Serialize(&fundingTxBuf) if err != nil { @@ -1090,7 +1090,7 @@ func (f *Manager) reservationCoordinator() { // OpenStatusUpdates. // // NOTE: This MUST be run as a goroutine. -func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, +func (f *Manager) advanceFundingState(channel *chanstate.OpenChannel, pendingChanID PendingChanID, updateChan chan<- *lnrpc.OpenStatusUpdate) { @@ -1171,7 +1171,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, // machine. This method is synchronous and the new channel opening state will // have been written to the database when it successfully returns. The // updateChan can be set non-nil to get OpenStatusUpdates. -func (f *Manager) stateStep(channel *channeldb.OpenChannel, +func (f *Manager) stateStep(channel *chanstate.OpenChannel, lnChannel *lnwallet.LightningChannel, shortChanID *lnwire.ShortChannelID, pendingChanID PendingChanID, channelState channelOpeningState, @@ -1296,7 +1296,7 @@ func (f *Manager) stateStep(channel *channeldb.OpenChannel, // advancePendingChannelState waits for a pending channel's funding tx to // confirm, and marks it open in the database when that happens. -func (f *Manager) advancePendingChannelState(channel *channeldb.OpenChannel, +func (f *Manager) advancePendingChannelState(channel *chanstate.OpenChannel, pendingChanID PendingChanID) error { if channel.IsZeroConf() { @@ -2962,7 +2962,7 @@ type confirmedChannel struct { // an ErrConfirmationTimeout. It is used to clean-up channel state and mark the // channel as closed. The error is only returned for the responder of the // channel flow. -func (f *Manager) fundingTimeout(c *channeldb.OpenChannel, +func (f *Manager) fundingTimeout(c *chanstate.OpenChannel, pendingID PendingChanID) error { // We'll get a timeout if the number of blocks mined since the channel @@ -3039,7 +3039,7 @@ func (f *Manager) fundingTimeout(c *channeldb.OpenChannel, // funding broadcast height. In case of confirmation, the short channel ID of // the channel and the funding transaction will be returned. func (f *Manager) waitForFundingWithTimeout( - ch *channeldb.OpenChannel) (*confirmedChannel, error) { + ch *chanstate.OpenChannel) (*confirmedChannel, error) { confChan := make(chan *confirmedChannel) timeoutChan := make(chan error, 1) @@ -3080,7 +3080,7 @@ func (f *Manager) waitForFundingWithTimeout( // MakeFundingScript re-creates the funding script for the funding transaction // of the target channel. -func MakeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) { +func MakeFundingScript(channel *chanstate.OpenChannel) ([]byte, error) { localKey := channel.LocalChanCfg.MultiSigKey.PubKey remoteKey := channel.RemoteChanCfg.MultiSigKey.PubKey @@ -3118,7 +3118,7 @@ func MakeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) { // // NOTE: This MUST be run as a goroutine. func (f *Manager) waitForFundingConfirmation( - completeChan *channeldb.OpenChannel, cancelChan <-chan struct{}, + completeChan *chanstate.OpenChannel, cancelChan <-chan struct{}, confChan chan<- *confirmedChannel) { defer f.wg.Done() @@ -3283,7 +3283,7 @@ func (f *Manager) waitForFundingConfirmation( // based on the confirmation details and sends this information, along with the // funding transaction, to the provided confirmation channel. func (f *Manager) handleConfirmation(confDetails *chainntnfs.TxConfirmation, - completeChan *channeldb.OpenChannel, + completeChan *chanstate.OpenChannel, confChan chan<- *confirmedChannel) error { fundingPoint := completeChan.FundingOutpoint @@ -3318,7 +3318,7 @@ func (f *Manager) handleConfirmation(confDetails *chainntnfs.TxConfirmation, // // NOTE: timeoutChan MUST be buffered. // NOTE: This MUST be run as a goroutine. -func (f *Manager) waitForTimeout(completeChan *channeldb.OpenChannel, +func (f *Manager) waitForTimeout(completeChan *chanstate.OpenChannel, cancelChan <-chan struct{}, timeoutChan chan<- error) { defer f.wg.Done() @@ -3390,7 +3390,7 @@ func (f *Manager) waitForTimeout(completeChan *channeldb.OpenChannel, // our short channel ID, which is known now that our funding transaction has // confirmed. We do not label transactions we did not publish, because our // wallet has no knowledge of them. -func (f *Manager) makeLabelForTx(c *channeldb.OpenChannel) { +func (f *Manager) makeLabelForTx(c *chanstate.OpenChannel) { if c.IsInitiator && c.ChanType.HasFundingTx() { shortChanID := c.ShortChanID() @@ -3416,7 +3416,7 @@ func (f *Manager) makeLabelForTx(c *channeldb.OpenChannel) { // decided short channel ID to the switch, and close the local discovery signal // for this channel. func (f *Manager) handleFundingConfirmation( - completeChan *channeldb.OpenChannel, + completeChan *chanstate.OpenChannel, confChannel *confirmedChannel) error { fundingPoint := completeChan.FundingOutpoint @@ -3495,7 +3495,7 @@ func (f *Manager) handleFundingConfirmation( // sendChannelReady creates and sends the channelReady message. // This should be called after the funding transaction has been confirmed, // and the channelState is 'markedOpen'. -func (f *Manager) sendChannelReady(completeChan *channeldb.OpenChannel, +func (f *Manager) sendChannelReady(completeChan *chanstate.OpenChannel, channel *lnwallet.LightningChannel) error { chanID := lnwire.NewChanIDFromOutPoint(completeChan.FundingOutpoint) @@ -3685,7 +3685,7 @@ func (f *Manager) receivedChannelReady(node *btcec.PublicKey, // extractAnnounceParams extracts the various channel announcement and update // parameters that will be needed to construct a ChannelAnnouncement and a // ChannelUpdate. -func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( +func (f *Manager) extractAnnounceParams(c *chanstate.OpenChannel) ( lnwire.MilliSatoshi, lnwire.MilliSatoshi) { // We'll obtain the min HTLC value we can forward in our direction, as @@ -3744,7 +3744,7 @@ func mapGossipError(err error, msgType string) error { // The peerAlias is used for zero-conf channels to give the counter-party a // ChannelUpdate they understand. ourPolicy may be set for various // option-scid-alias channels to re-use the same policy. -func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, +func (f *Manager) addToGraph(completeChan *chanstate.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, ourPolicy *models.ChannelEdgePolicy) error { @@ -3804,7 +3804,7 @@ func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, // 'addedToGraph') and the channel is ready to be used. This is the last // step in the channel opening process, and the opening state will be deleted // from the database if successful. -func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, +func (f *Manager) annAfterSixConfs(completeChan *chanstate.OpenChannel, shortChanID *lnwire.ShortChannelID) error { // If this channel is not meant to be announced to the greater network, @@ -3954,7 +3954,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, // waitForZeroConfChannel is called when the state is addedToGraph with // a zero-conf channel. This will wait for the real confirmation, add the // confirmed SCID to the router graph, and then announce after six confs. -func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { +func (f *Manager) waitForZeroConfChannel(c *chanstate.OpenChannel) error { // First we'll check whether the channel is confirmed on-chain. If it // is already confirmed, the chainntnfs subsystem will return with the // confirmed tx. Otherwise, we'll wait here until confirmation occurs. @@ -4045,7 +4045,7 @@ func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel) error { // genFirstStateMusigNonce generates a nonces for the "first" local state. This // is the verification nonce for the state created for us after the initial // commitment transaction signed as part of the funding flow. -func genFirstStateMusigNonce(channel *channeldb.OpenChannel, +func genFirstStateMusigNonce(channel *chanstate.OpenChannel, ) (*musig2.Nonces, error) { musig2ShaChain, err := channeldb.DeriveMusig2Shachain( @@ -4420,7 +4420,7 @@ func (f *Manager) processChannelReady(peer lnpeer.Peer, // channelReady message, once the remote's channelReady is processed, the // channel is now active, thus we change its state to `addedToGraph` to // let the channel start handling routing. -func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, +func (f *Manager) handleChannelReadyReceived(channel *chanstate.OpenChannel, scid *lnwire.ShortChannelID, pendingChanID PendingChanID, updateChan chan<- *lnrpc.OpenStatusUpdate) error { @@ -4516,7 +4516,7 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, // policy set for the given channel. If we don't, we'll fall back to the default // values. func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, - channel *channeldb.OpenChannel) error { + channel *chanstate.OpenChannel) error { // Before we can add the channel to the peer, we'll need to ensure that // we have an initial forwarding policy set. This should always be the diff --git a/funding/manager_test.go b/funding/manager_test.go index 0dd9f472fc5..1d5fffca6d7 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -31,6 +31,7 @@ import ( acpt "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" @@ -250,7 +251,7 @@ func (m *mockChanEvent) NotifyOpenChannelEvent(outpoint wire.OutPoint, } func (m *mockChanEvent) NotifyPendingOpenChannelEvent(outpoint wire.OutPoint, - pendingChannel *channeldb.OpenChannel, + pendingChannel *chanstate.OpenChannel, remotePub *btcec.PublicKey) { m.pendingOpenEvent <- channelnotifier.PendingOpenChannelEvent{ @@ -499,7 +500,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, }, TempChanIDSeed: chanIDSeed, FindChannel: func(node *btcec.PublicKey, - chanID lnwire.ChannelID) (*channeldb.OpenChannel, + chanID lnwire.ChannelID) (*chanstate.OpenChannel, error) { nodeChans, err := cdb.FetchOpenChannels(node) @@ -549,7 +550,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, RequiredRemoteMaxHTLCs: func(chanAmt btcutil.Amount) uint16 { return uint16(input.MaxHTLCNumber / 2) }, - WatchNewChannel: func(*channeldb.OpenChannel, + WatchNewChannel: func(*chanstate.OpenChannel, *btcec.PublicKey) error { return nil @@ -5289,7 +5290,7 @@ func TestChannelReadyUnknownChannelID(t *testing.T) { cfg.FindChannel = func( node *btcec.PublicKey, chanID lnwire.ChannelID, - ) (*channeldb.OpenChannel, error) { + ) (*chanstate.OpenChannel, error) { findChannelCalls.Add(1) diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 15d4b5ffca4..299abffcd0a 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnutils" @@ -203,12 +203,12 @@ type CircuitMapConfig struct { // FetchAllOpenChannels is a function that fetches all currently open // channels from the channel database. - FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + FetchAllOpenChannels func() ([]*chanstate.OpenChannel, error) // FetchClosedChannels is a function that fetches all closed channels // from the channel database. FetchClosedChannels func( - pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) + pendingOnly bool) ([]*chanstate.ChannelCloseSummary, error) // ExtractErrorEncrypter derives the shared secret used to encrypt // errors from the obfuscator's ephemeral public key. diff --git a/htlcswitch/circuit_map_test.go b/htlcswitch/circuit_map_test.go index 9bbb2c051a7..28d81ea2223 100644 --- a/htlcswitch/circuit_map_test.go +++ b/htlcswitch/circuit_map_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -362,7 +363,7 @@ func createTestCloseChannelSummery(tx kvdb.RwTx, isPending bool, } outputPoint := wire.OutPoint{Hash: hash1, Index: 1} - ccs := &channeldb.ChannelCloseSummary{ + ccs := &chanstate.ChannelCloseSummary{ ChanPoint: outputPoint, ShortChanID: chanID, ChainHash: hash1, @@ -371,7 +372,7 @@ func createTestCloseChannelSummery(tx kvdb.RwTx, isPending bool, RemotePub: testEphemeralKey, Capacity: btcutil.Amount(10000), SettledBalance: btcutil.Amount(50000), - CloseType: channeldb.RemoteForceClose, + CloseType: chanstate.RemoteForceClose, IsPending: isPending, } var b bytes.Buffer @@ -389,7 +390,7 @@ func createTestCloseChannelSummery(tx kvdb.RwTx, isPending bool, func serializeChannelCloseSummary( w io.Writer, - cs *channeldb.ChannelCloseSummary) error { + cs *chanstate.ChannelCloseSummary) error { err := channeldb.WriteElements( w, diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4739afff6ec..6bf5ff7a575 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" @@ -354,7 +355,7 @@ type TowerClient interface { // parameters within the client. This should be called during link // startup to ensure that the client is able to support the link during // operation. - RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error + RegisterChannel(lnwire.ChannelID, chanstate.ChannelType) error // BackupState initiates a request to back up a particular revoked // state. If the method returns nil, the backup is guaranteed to be diff --git a/htlcswitch/link.go b/htlcswitch/link.go index dd0f5d37ac2..7cde9459520 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -258,7 +259,7 @@ type ChannelLinkConfig struct { // NotifyChannelUpdate allows the link to tell the ChannelNotifier when // a channel's state has been updated. - NotifyChannelUpdate func(*channeldb.OpenChannel) + NotifyChannelUpdate func(*chanstate.OpenChannel) // HtlcNotifier is an instance of a htlcNotifier which we will pipe htlc // events through. @@ -2372,7 +2373,7 @@ type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool, whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool // dustHelper is used to construct the dustClosure. -func dustHelper(chantype channeldb.ChannelType, localDustLimit, +func dustHelper(chantype chanstate.ChannelType, localDustLimit, remoteDustLimit btcutil.Amount) dustClosure { isDust := func(feerate chainfee.SatPerKWeight, incoming bool, diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 29b4f902d0b..9b65faa3c18 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -25,6 +25,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" + cstate "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -32,7 +33,6 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" invpkg "github.com/lightningnetwork/lnd/invoices" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" @@ -2174,7 +2174,7 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, pCache := newMockPreimageCache() - aliceDb := aliceLc.channel.State().Db.GetParentDB() + aliceDb := testChannelStateDB(t, aliceLc.channel).GetParentDB() aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return singleLinkTestHarness{}, err @@ -2241,7 +2241,7 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, MaxFeeAllocation: DefaultMaxLinkFeeAllocation, NotifyActiveLink: func(wire.OutPoint) {}, NotifyActiveChannel: func(wire.OutPoint) {}, - NotifyChannelUpdate: func(*channeldb.OpenChannel) {}, + NotifyChannelUpdate: func(*cstate.OpenChannel) {}, NotifyInactiveChannel: func(wire.OutPoint) {}, NotifyInactiveLinkEvent: func(wire.OutPoint) {}, HtlcNotifier: aliceSwitch.cfg.HtlcNotifier, @@ -4854,7 +4854,7 @@ func (h *persistentLinkHarness) restartLink( pCache = newMockPreimageCache() ) - aliceDb := aliceChannel.State().Db.GetParentDB() + aliceDb := testChannelStateDB(t, aliceChannel).GetParentDB() if restartSwitch { var err error h.hSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb) @@ -4932,7 +4932,7 @@ func (h *persistentLinkHarness) restartLink( NotifyActiveChannel: func(wire.OutPoint) {}, NotifyInactiveChannel: func(wire.OutPoint) {}, NotifyInactiveLinkEvent: func(wire.OutPoint) {}, - NotifyChannelUpdate: func(*channeldb.OpenChannel) {}, + NotifyChannelUpdate: func(*cstate.OpenChannel) {}, HtlcNotifier: h.hSwitch.cfg.HtlcNotifier, SyncStates: syncStates, GetAliases: getAliases, @@ -5769,42 +5769,14 @@ func TestChannelLinkCleanupSpuriousResponses(t *testing.T) { } } -type mockPackager struct { - failLoadFwdPkgs bool +type mockFailLoadFwdPkgStore struct { + cstate.Store } -func (*mockPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *channeldb.FwdPkg) error { - return nil -} - -func (*mockPackager) SetFwdFilter(tx kvdb.RwTx, height uint64, - fwdFilter *channeldb.PkgFilter) error { - return nil -} - -func (*mockPackager) AckAddHtlcs(tx kvdb.RwTx, - addRefs ...channeldb.AddRef) error { - return nil -} - -func (m *mockPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*channeldb.FwdPkg, error) { - if m.failLoadFwdPkgs { - return nil, fmt.Errorf("failing LoadFwdPkgs") - } - return nil, nil -} - -func (*mockPackager) RemovePkg(tx kvdb.RwTx, height uint64) error { - return nil -} +func (m *mockFailLoadFwdPkgStore) LoadFwdPkgs( + *cstate.OpenChannel) ([]*channeldb.FwdPkg, error) { -func (*mockPackager) Wipe(tx kvdb.RwTx) error { - return nil -} - -func (*mockPackager) AckSettleFails(tx kvdb.RwTx, - settleFailRefs ...channeldb.SettleFailRef) error { - return nil + return nil, fmt.Errorf("failing LoadFwdPkgs") } // TestChannelLinkFail tests that we will fail the channel, and force close the @@ -5880,10 +5852,10 @@ func TestChannelLinkFail(t *testing.T) { func(c *channelLink) { // We make the call to resolveFwdPkgs fail by // making the underlying forwarder fail. - pkg := &mockPackager{ - failLoadFwdPkgs: true, + state := c.channel.State() + state.Db = &mockFailLoadFwdPkgStore{ + Store: state.Db, } - c.channel.State().Packager = pkg }, func(*testing.T, *Switch, *channelLink, *lnwallet.LightningChannel) { diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index 57a581c4b14..1aa74bba2e4 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -586,7 +587,7 @@ func TestMailBoxDustHandling(t *testing.T) { }) } -func testMailBoxDust(t *testing.T, chantype channeldb.ChannelType) { +func testMailBoxDust(t *testing.T, chantype chanstate.ChannelType) { t.Parallel() ctx := newMailboxContext(t, time.Now(), testExpiry) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 70bd73c37d2..9912ad648a8 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -22,6 +22,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" @@ -74,7 +75,7 @@ func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error { } func (m *mockPreimageCache) SubscribeUpdates( - chanID lnwire.ShortChannelID, htlc *channeldb.HTLC, + chanID lnwire.ShortChannelID, htlc *chanstate.HTLC, payload *hop.Payload, nextHopOnionBlob []byte) (*contractcourt.WitnessSubscription, error) { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index a3aae809b93..2bbf1cbc535 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" @@ -150,16 +151,16 @@ type Config struct { // FetchAllOpenChannels is a function that fetches all currently open // channels from the channel database. - FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + FetchAllOpenChannels func() ([]*chanstate.OpenChannel, error) // FetchAllChannels is a function that fetches all pending open, open, // and waiting close channels from the database. - FetchAllChannels func() ([]*channeldb.OpenChannel, error) + FetchAllChannels func() ([]*chanstate.OpenChannel, error) // FetchClosedChannels is a function that fetches all closed channels // from the channel database. FetchClosedChannels func( - pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) + pendingOnly bool) ([]*chanstate.ChannelCloseSummary, error) // SwitchPackager provides access to the forwarding packages of all // active channels. This gives the switch the ability to read arbitrary diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2e084250943..05524cc69bd 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -24,6 +24,7 @@ import ( "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -43,6 +44,19 @@ import ( "github.com/stretchr/testify/require" ) +func testChannelStateDB(t testing.TB, + channel *lnwallet.LightningChannel) *channeldb.ChannelStateDB { + + t.Helper() + + cdb, ok := channel.State().Db.(*channeldb.ChannelStateDB) + if !ok { + t.Fatalf("expected ChannelStateDB, got %T", channel.State().Db) + } + + return cdb +} + // maxInflightHtlcs specifies the max number of inflight HTLCs. This number is // chosen to be smaller than the default 483 so the test can run faster. const maxInflightHtlcs = 50 @@ -291,7 +305,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, CommitSig: bytes.Repeat([]byte{1}, 71), } - aliceChannelState := &channeldb.OpenChannel{ + aliceChannelState := &chanstate.OpenChannel{ LocalChanCfg: aliceCfg, RemoteChanCfg: bobCfg, IdentityPub: aliceKeyPub, @@ -306,11 +320,10 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, RemoteCommitment: aliceCommit, ShortChannelID: chanID, Db: dbAlice.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, } - bobChannelState := &channeldb.OpenChannel{ + bobChannelState := &chanstate.OpenChannel{ LocalChanCfg: bobCfg, RemoteChanCfg: aliceCfg, IdentityPub: bobKeyPub, @@ -325,7 +338,6 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, RemoteCommitment: bobCommit, ShortChannelID: chanID, Db: dbBob.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(chanID), } if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil { @@ -403,7 +415,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, "channel: %w", err) } - var aliceStoredChannel *channeldb.OpenChannel + var aliceStoredChannel *chanstate.OpenChannel for _, channel := range aliceStoredChannels { if channel.FundingOutpoint.String() == prevOut.String() { aliceStoredChannel = channel @@ -451,7 +463,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, "%w", err) } - var bobStoredChannel *channeldb.OpenChannel + var bobStoredChannel *chanstate.OpenChannel for _, channel := range bobStoredChannels { if channel.FundingOutpoint.String() == prevOut.String() { bobStoredChannel = channel @@ -954,9 +966,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, secondBobChannel, carolChannel *lnwallet.LightningChannel, startingHeight uint32, opts ...serverOption) *threeHopNetwork { - aliceDb := aliceChannel.State().Db.GetParentDB() - bobDb := firstBobChannel.State().Db.GetParentDB() - carolDb := carolChannel.State().Db.GetParentDB() + aliceDb := testChannelStateDB(t, aliceChannel).GetParentDB() + bobDb := testChannelStateDB(t, firstBobChannel).GetParentDB() + carolDb := testChannelStateDB(t, carolChannel).GetParentDB() hopNetwork := newHopNetwork() @@ -1175,7 +1187,7 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, NotifyActiveChannel: func(wire.OutPoint) {}, NotifyInactiveChannel: func(wire.OutPoint) {}, NotifyInactiveLinkEvent: func(wire.OutPoint) {}, - NotifyChannelUpdate: func(*channeldb.OpenChannel) {}, + NotifyChannelUpdate: func(*chanstate.OpenChannel) {}, HtlcNotifier: server.htlcSwitch.cfg.HtlcNotifier, GetAliases: getAliases, ShouldFwdExpAccountability: func() bool { return true }, @@ -1233,8 +1245,8 @@ func newTwoHopNetwork(t testing.TB, aliceChannel, bobChannel *lnwallet.LightningChannel, startingHeight uint32) *twoHopNetwork { - aliceDb := aliceChannel.State().Db.GetParentDB() - bobDb := bobChannel.State().Db.GetParentDB() + aliceDb := testChannelStateDB(t, aliceChannel).GetParentDB() + bobDb := testChannelStateDB(t, bobChannel).GetParentDB() hopNetwork := newHopNetwork() diff --git a/lnpeer/peer.go b/lnpeer/peer.go index cb6bc9867a1..4bd990a46f6 100644 --- a/lnpeer/peer.go +++ b/lnpeer/peer.go @@ -5,7 +5,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -14,7 +14,7 @@ import ( // with the set of channel options that may change how the channel is created. // This can be used to pass along the nonce state needed for taproot channels. type NewChannel struct { - *channeldb.OpenChannel + *chanstate.OpenChannel // ChanOpts can be used to change how the channel is created. ChanOpts []lnwallet.ChannelOpt diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index b4d39a99c58..e7224a4519a 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -17,7 +17,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" @@ -730,7 +729,7 @@ type HopHintInfo struct { ScidAliasFeature bool } -func newHopHintInfo(c *channeldb.OpenChannel, isActive bool) *HopHintInfo { +func newHopHintInfo(c *chanstate.OpenChannel, isActive bool) *HopHintInfo { isPublic := c.ChannelFlags&lnwire.FFAnnounceChannel != 0 return &HopHintInfo{ @@ -781,7 +780,7 @@ type SelectHopHintsCfg struct { // FetchAllChannels retrieves all open channels currently stored // within the database. - FetchAllChannels func() ([]*channeldb.OpenChannel, error) + FetchAllChannels func() ([]*chanstate.OpenChannel, error) // IsChannelActive checks whether the channel identified by the provided // ChannelID is considered active. @@ -844,7 +843,7 @@ func sufficientHints(nHintsLeft int, currentAmount, // getPotentialHints returns a slice of open channels that should be considered // for the hopHint list in an invoice. The slice is sorted in descending order // based on the remote balance. -func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, +func getPotentialHints(cfg *SelectHopHintsCfg) ([]*chanstate.OpenChannel, error) { // TODO(positiveblue): get the channels slice already filtered by @@ -854,7 +853,7 @@ func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, return nil, err } - privateChannels := make([]*channeldb.OpenChannel, 0, len(openChannels)) + privateChannels := make([]*chanstate.OpenChannel, 0, len(openChannels)) for _, oc := range openChannels { isPublic := oc.ChannelFlags&lnwire.FFAnnounceChannel != 0 if !isPublic { @@ -876,7 +875,7 @@ func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, // shouldIncludeChannel returns true if the channel passes all the checks to // be a hopHint in a given invoice. func shouldIncludeChannel(cfg *SelectHopHintsCfg, - channel *channeldb.OpenChannel, + channel *chanstate.OpenChannel, alreadyIncluded map[uint64]bool) (zpay32.HopHint, lnwire.MilliSatoshi, bool) { @@ -922,7 +921,7 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, // descending priority. func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, targetBandwidth lnwire.MilliSatoshi, - potentialHints []*channeldb.OpenChannel, + potentialHints []*chanstate.OpenChannel, alreadyIncluded map[uint64]bool) [][]zpay32.HopHint { currentBandwidth := lnwire.MilliSatoshi(0) diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 104b2873dde..ec0e4d1def7 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -59,11 +59,15 @@ func (h *hopHintsConfigMock) GetAlias( // FetchAllChannels retrieves all open channels currently stored // within the database. -func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, +func (h *hopHintsConfigMock) FetchAllChannels() ([]*chanstate.OpenChannel, error) { args := h.Mock.Called() - return args.Get(0).([]*channeldb.OpenChannel), args.Error(1) + + channels, ok := args.Get(0).([]*chanstate.OpenChannel) + require.True(h.t, ok) + + return channels, args.Error(1) } // FetchChannelEdgesByID attempts to lookup the two directed edges for @@ -102,7 +106,7 @@ func getTestPubKey() *btcec.PublicKey { var shouldIncludeChannelTestCases = []struct { name string setupMock func(*hopHintsConfigMock) - channel *channeldb.OpenChannel + channel *chanstate.OpenChannel alreadyIncluded map[uint64]bool cfg *SelectHopHintsCfg hopHint zpay32.HopHint @@ -112,7 +116,7 @@ var shouldIncludeChannelTestCases = []struct { name: "already included channels should not be included " + "again", alreadyIncluded: map[uint64]bool{1: true}, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ ShortChannelID: lnwire.NewShortChanIDFromInt(1), }, include: false, @@ -127,7 +131,7 @@ var shouldIncludeChannelTestCases = []struct { "IsChannelActive", chanID, ).Once().Return(true) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, @@ -144,7 +148,7 @@ var shouldIncludeChannelTestCases = []struct { "IsChannelActive", chanID, ).Once().Return(false) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, @@ -166,7 +170,7 @@ var shouldIncludeChannelTestCases = []struct { "IsPublicNode", mock.Anything, ).Once().Return(false, nil) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, @@ -201,7 +205,7 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return(nil, nil, nil, fmt.Errorf("no edge")) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, @@ -237,12 +241,12 @@ var shouldIncludeChannelTestCases = []struct { "GetAlias", mock.Anything, ).Once().Return(lnwire.ShortChannelID{}, nil) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, IdentityPub: getTestPubKey(), - ChanType: channeldb.ScidAliasFeatureBit, + ChanType: chanstate.ScidAliasFeatureBit, }, include: false, }, { @@ -275,12 +279,12 @@ var shouldIncludeChannelTestCases = []struct { "GetAlias", mock.Anything, ).Once().Return(alias, nil) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 0, }, IdentityPub: getTestPubKey(), - ChanType: channeldb.ScidAliasFeatureBit, + ChanType: chanstate.ScidAliasFeatureBit, }, include: false, }, { @@ -328,7 +332,7 @@ var shouldIncludeChannelTestCases = []struct { nil, ) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 1, }, @@ -375,7 +379,7 @@ var shouldIncludeChannelTestCases = []struct { }, nil, ) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 1, }, @@ -428,13 +432,13 @@ var shouldIncludeChannelTestCases = []struct { "GetAlias", mock.Anything, ).Once().Return(aliasSCID, nil) }, - channel: &channeldb.OpenChannel{ + channel: &chanstate.OpenChannel{ FundingOutpoint: wire.OutPoint{ Index: 1, }, IdentityPub: getTestPubKey(), ShortChannelID: lnwire.NewShortChanIDFromInt(12), - ChanType: channeldb.ScidAliasFeatureBit, + ChanType: chanstate.ScidAliasFeatureBit, }, hopHint: zpay32.HopHint{ NodeID: getTestPubKey(), @@ -554,7 +558,7 @@ var populateHopHintsTestCases = []struct { setupMock: func(h *hopHintsConfigMock) { fundingOutpoint := wire.OutPoint{Index: 9} chanID := lnwire.NewChanIDFromOutPoint(fundingOutpoint) - allChannels := []*channeldb.OpenChannel{ + allChannels := []*chanstate.OpenChannel{ { FundingOutpoint: fundingOutpoint, ShortChannelID: lnwire.NewShortChanIDFromInt(9), @@ -601,9 +605,9 @@ var populateHopHintsTestCases = []struct { fundingOutpoint := wire.OutPoint{Index: 9} chanID := lnwire.NewChanIDFromOutPoint(fundingOutpoint) remoteBalance := lnwire.MilliSatoshi(10_000_000) - allChannels := []*channeldb.OpenChannel{ + allChannels := []*chanstate.OpenChannel{ { - LocalCommitment: channeldb.ChannelCommitment{ + LocalCommitment: chanstate.ChannelCommitment{ RemoteBalance: remoteBalance, }, FundingOutpoint: fundingOutpoint, @@ -652,12 +656,12 @@ var populateHopHintsTestCases = []struct { fundingOutpoint := wire.OutPoint{Index: 9} chanID := lnwire.NewChanIDFromOutPoint(fundingOutpoint) remoteBalance := lnwire.MilliSatoshi(10_000_000) - allChannels := []*channeldb.OpenChannel{ + allChannels := []*chanstate.OpenChannel{ // Because the channels with higher remote balance have // enough bandwidth we should never use this one. {}, { - LocalCommitment: channeldb.ChannelCommitment{ + LocalCommitment: chanstate.ChannelCommitment{ RemoteBalance: remoteBalance, }, FundingOutpoint: fundingOutpoint, @@ -851,11 +855,11 @@ func setupMockTwoChannels(h *hopHintsConfigMock) (lnwire.ChannelID, chanID2 := lnwire.NewChanIDFromOutPoint(fundingOutpoint2) remoteBalance2 := lnwire.MilliSatoshi(1_000_000) - allChannels := []*channeldb.OpenChannel{ + allChannels := []*chanstate.OpenChannel{ // After sorting we will first process chanID1 and then // chanID2. { - LocalCommitment: channeldb.ChannelCommitment{ + LocalCommitment: chanstate.ChannelCommitment{ RemoteBalance: remoteBalance2, }, FundingOutpoint: fundingOutpoint2, @@ -863,7 +867,7 @@ func setupMockTwoChannels(h *hopHintsConfigMock) (lnwire.ChannelID, IdentityPub: getTestPubKey(), }, { - LocalCommitment: channeldb.ChannelCommitment{ + LocalCommitment: chanstate.ChannelCommitment{ RemoteBalance: remoteBalance1, }, FundingOutpoint: fundingOutpoint1, diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index bece6001d6a..d7f8fb41bb8 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -32,6 +32,7 @@ import ( "github.com/btcsuite/btcwallet/wtxmgr" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" @@ -1184,7 +1185,7 @@ func (w *WalletKit) BumpFee(ctx context.Context, // getWaitingCloseChannel returns the waiting close channel in case it does // exist in the underlying channel state database. func (w *WalletKit) getWaitingCloseChannel( - chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) { + chanPoint wire.OutPoint) (*chanstate.OpenChannel, error) { // Fetch all channels, which still have their commitment transaction not // confirmed (waiting close channels). @@ -1193,7 +1194,7 @@ func (w *WalletKit) getWaitingCloseChannel( return nil, err } - channel := fn.Find(chans, func(c *channeldb.OpenChannel) bool { + channel := fn.Find(chans, func(c *chanstate.OpenChannel) bool { return c.FundingOutpoint == chanPoint }) diff --git a/lnwallet/aux_leaf_store.go b/lnwallet/aux_leaf_store.go index 0a850503780..a682a32e540 100644 --- a/lnwallet/aux_leaf_store.go +++ b/lnwallet/aux_leaf_store.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -55,7 +56,7 @@ type CommitAuxLeaves struct { } // AuxChanState is a struct that holds certain fields of the -// channeldb.OpenChannel struct that are used by the aux components. The data +// chanstate.OpenChannel struct that are used by the aux components. The data // is copied over to prevent accidental mutation of the original channel state. type AuxChanState struct { // ChanType denotes which type of channel this is. @@ -110,7 +111,7 @@ type AuxChanState struct { } // NewAuxChanState creates a new AuxChanState from the given channel state. -func NewAuxChanState(chanState *channeldb.OpenChannel) AuxChanState { +func NewAuxChanState(chanState *chanstate.OpenChannel) AuxChanState { peerPub := chanState.IdentityPub.SerializeCompressed() return AuxChanState{ @@ -202,7 +203,7 @@ type AuxLeafStore interface { // auxLeavesFromView is used to derive the set of commit aux leaves (if any), // that are needed to create a new commitment transaction using the original // (unfiltered) htlc view. -func auxLeavesFromView(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, +func auxLeavesFromView(leafStore AuxLeafStore, chanState *chanstate.OpenChannel, prevBlob fn.Option[tlv.Blob], originalView *HtlcView, whoseCommit lntypes.ChannelParty, ourBalance, theirBalance lnwire.MilliSatoshi, @@ -225,7 +226,7 @@ func auxLeavesFromView(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, // updateAuxBlob is a helper function that attempts to update the aux blob // given the prior and current state information. -func updateAuxBlob(leafStore AuxLeafStore, chanState *channeldb.OpenChannel, +func updateAuxBlob(leafStore AuxLeafStore, chanState *chanstate.OpenChannel, prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView, whoseCommit lntypes.ChannelParty, ourBalance, theirBalance lnwire.MilliSatoshi, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 78ca895655a..d659bb83e0d 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -24,6 +24,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" @@ -795,7 +796,7 @@ type LightningChannel struct { // state, which we are able to broadcast safely. commitChains lntypes.Dual[*commitmentChain] - channelState *channeldb.OpenChannel + channelState *chanstate.OpenChannel commitBuilder *CommitmentBuilder @@ -953,7 +954,7 @@ func defaultChannelOpts() *channelOpts { // automatically persist pertinent state to the database in an efficient // manner. func NewLightningChannel(signer input.Signer, - state *channeldb.OpenChannel, + state *chanstate.OpenChannel, sigPool *SigPool, chanOpts ...ChannelOpt) (*LightningChannel, error) { opts := defaultChannelOpts() @@ -2098,7 +2099,9 @@ type BreachRetribution struct { // nil, then the revocation log will be checked to see if it contains the info // required to construct the BreachRetribution. If the revocation log is missing // the required fields then ErrRevLogDataMissing will be returned. -func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, +// +//nolint:funlen +func NewBreachRetribution(chanState *chanstate.OpenChannel, stateNum uint64, breachHeight uint32, spendTx *wire.MsgTx, leafStore fn.Option[AuxLeafStore], auxResolver fn.Option[AuxContractResolver]) (*BreachRetribution, @@ -2398,7 +2401,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, // createHtlcRetribution is a helper function to construct an HtlcRetribution // based on the passed params. -func createHtlcRetribution(chanState *channeldb.OpenChannel, +func createHtlcRetribution(chanState *chanstate.OpenChannel, keyRing *CommitmentKeyRing, commitHash chainhash.Hash, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, htlc *channeldb.HTLCEntry, @@ -2525,7 +2528,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // see if these fields are present there. If they are not, then // ErrRevLogDataMissing is returned. func createBreachRetribution(revokedLog *channeldb.RevocationLog, - spendTx *wire.MsgTx, chanState *channeldb.OpenChannel, + spendTx *wire.MsgTx, chanState *chanstate.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64, @@ -2642,7 +2645,7 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // BreachRetribution using a ChannelCommitment. Returns the constructed // retribution, our amount, their amount, and a possible non-nil error. func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, - chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, + chanState *chanstate.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, ourScript, theirScript input.ScriptDescriptor, leaseExpiry uint32) (*BreachRetribution, int64, int64, error) { @@ -2995,7 +2998,7 @@ func (lc *LightningChannel) fetchCommitmentView( // fundingTxIn returns the funding output as a transaction input. The input // returned by this function uses a max sequence number, so it isn't able to be // used with RBF by default. -func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { +func fundingTxIn(chanState *chanstate.OpenChannel) wire.TxIn { return *wire.NewTxIn(&chanState.FundingOutpoint, nil, nil) } @@ -3251,7 +3254,7 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, // configured reserve. It also uses the balance delta for the party, to account // for entry amounts that have been processed already. func balanceAboveReserve(party lntypes.ChannelParty, delta int64, - channel *channeldb.OpenChannel) bool { + channel *chanstate.OpenChannel) bool { // We're going to access the channel state, so let's make sure we're // holding the lock. @@ -3340,7 +3343,7 @@ func (lc *LightningChannel) evaluateNoOpHtlc(entry *paymentDescriptor, // signature can be submitted to the sigPool to generate all the signatures // asynchronously and in parallel. func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, - chanState *channeldb.OpenChannel, leaseExpiry uint32, + chanState *chanstate.OpenChannel, leaseExpiry uint32, remoteCommitView *commitment, leafStore fn.Option[AuxLeafStore]) ([]SignJob, []AuxSigJob, chan struct{}, error) { @@ -4970,7 +4973,7 @@ func (lc *LightningChannel) recordSettlement( // directly into the pool of workers. // //nolint:funlen -func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, +func genHtlcSigValidationJobs(chanState *chanstate.OpenChannel, localCommitmentView *commitment, keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, leaseExpiry uint32, leafStore fn.Option[AuxLeafStore], auxSigner fn.Option[AuxSigner], @@ -6761,10 +6764,10 @@ func (lc *LightningChannel) ChannelPoint() wire.OutPoint { return lc.channelState.FundingOutpoint } -// ChannelState returns a copy of the internal channeldb.OpenChannel state +// ChannelState returns a copy of the internal chanstate.OpenChannel state // struct. Modifications to the returned struct will not be reflected within // the LightningChannel. -func (lc *LightningChannel) ChannelState() *channeldb.OpenChannel { +func (lc *LightningChannel) ChannelState() *chanstate.OpenChannel { return lc.channelState.Copy() } @@ -7074,7 +7077,7 @@ type UnilateralCloseSummary struct { // happen in case we have lost state) it should be set to an empty struct, in // which case we will attempt to sweep the non-HTLC output using the passed // commitPoint. -func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen +func NewUnilateralCloseSummary(chanState *chanstate.OpenChannel, signer input.Signer, commitSpend *chainntnfs.SpendDetail, remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, leafStore fn.Option[AuxLeafStore], @@ -7415,7 +7418,7 @@ func newOutgoingHtlcResolution(signer input.Signer, commitTxHeight uint32, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType, chanState *channeldb.OpenChannel, + chanType channeldb.ChannelType, chanState *chanstate.OpenChannel, auxLeaves fn.Option[CommitAuxLeaves], auxResolver fn.Option[AuxContractResolver], ) (*OutgoingHtlcResolution, error) { @@ -7789,7 +7792,7 @@ func newIncomingHtlcResolution(signer input.Signer, commitTxHeight uint32, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, - chanType channeldb.ChannelType, chanState *channeldb.OpenChannel, + chanType channeldb.ChannelType, chanState *chanstate.OpenChannel, auxLeaves fn.Option[CommitAuxLeaves], auxResolver fn.Option[AuxContractResolver], ) (*IncomingHtlcResolution, error) { @@ -8174,7 +8177,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, commitTxHeight uint32, chanType channeldb.ChannelType, isCommitFromInitiator bool, - leaseExpiry uint32, chanState *channeldb.OpenChannel, + leaseExpiry uint32, chanState *chanstate.OpenChannel, auxLeaves fn.Option[CommitAuxLeaves], auxResolver fn.Option[AuxContractResolver]) (*HtlcResolutions, error) { @@ -8389,7 +8392,7 @@ func (lc *LightningChannel) ForceClose(opts ...ForceCloseOpt) ( // NewLocalForceCloseSummary generates a LocalForceCloseSummary from the given // channel state. The passed commitTx must be a fully signed commitment // transaction corresponding to localCommit. -func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, +func NewLocalForceCloseSummary(chanState *chanstate.OpenChannel, signer input.Signer, commitTx *wire.MsgTx, commitTxHeight uint32, stateNum uint64, leafStore fn.Option[AuxLeafStore], auxResolver fn.Option[AuxContractResolver]) (*LocalForceCloseSummary, @@ -9044,7 +9047,7 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, // NewAnchorResolution returns the information that is required to sweep the // local anchor. -func NewAnchorResolution(chanState *channeldb.OpenChannel, +func NewAnchorResolution(chanState *chanstate.OpenChannel, commitTx *wire.MsgTx, keyRing *CommitmentKeyRing, whoseCommit lntypes.ChannelParty) (*AnchorResolution, error) { @@ -10088,7 +10091,7 @@ func (lc *LightningChannel) IsPending() bool { } // State provides access to the channel's internal state. -func (lc *LightningChannel) State() *channeldb.OpenChannel { +func (lc *LightningChannel) State() *chanstate.OpenChannel { return lc.channelState } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index ab96d339749..b93b66b084d 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -26,6 +26,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" @@ -9279,7 +9280,7 @@ func TestEvaluateView(t *testing.T) { t.Run(test.name, func(t *testing.T) { isInitiator := test.channelInitiator == lntypes.Local lc := LightningChannel{ - channelState: &channeldb.OpenChannel{ + channelState: &chanstate.OpenChannel{ IsInitiator: isInitiator, TotalMSatSent: 0, TotalMSatReceived: 0, @@ -10071,7 +10072,7 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) { // deriveDummyRetributionParams is a helper function that derives a list of // dummy params to assist retribution creation related tests. -func deriveDummyRetributionParams(chanState *channeldb.OpenChannel) (uint32, +func deriveDummyRetributionParams(chanState *chanstate.OpenChannel) (uint32, *CommitmentKeyRing, chainhash.Hash) { config := chanState.RemoteChanCfg diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index f715b9336cf..500b5fd3a0d 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -635,7 +636,7 @@ type CommitmentBuilder struct { // chanState is the underlying channel's state struct, used to // determine the type of channel we are dealing with, and relevant // parameters. - chanState *channeldb.OpenChannel + chanState *chanstate.OpenChannel // obfuscator is a 48-bit state hint that's used to obfuscate the // current state number on the commitment transactions. @@ -647,7 +648,7 @@ type CommitmentBuilder struct { } // NewCommitmentBuilder creates a new CommitmentBuilder from chanState. -func NewCommitmentBuilder(chanState *channeldb.OpenChannel, +func NewCommitmentBuilder(chanState *chanstate.OpenChannel, leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder { // The anchor channel type MUST be tweakless. @@ -665,7 +666,9 @@ func NewCommitmentBuilder(chanState *channeldb.OpenChannel, // createStateHintObfuscator derives and assigns the state hint obfuscator for // the channel, which is used to encode the commitment height in the sequence // number of commitment transaction inputs. -func createStateHintObfuscator(state *channeldb.OpenChannel) [StateHintSize]byte { +func createStateHintObfuscator( + state *chanstate.OpenChannel) [StateHintSize]byte { + if state.IsInitiator { return DeriveStateHintObfuscator( state.LocalChanCfg.PaymentBasePoint.PubKey, @@ -1320,7 +1323,7 @@ func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, // output scripts and compares them against the outputs inside the commitment // to find the match. func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, - chanState *channeldb.OpenChannel, + chanState *chanstate.OpenChannel, leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) { // Init the output indexes as empty. diff --git a/lnwallet/reservation.go b/lnwallet/reservation.go index 83a0829afaf..a804888ad7b 100644 --- a/lnwallet/reservation.go +++ b/lnwallet/reservation.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -248,7 +249,7 @@ type ChannelReservation struct { ourContribution *ChannelContribution theirContribution *ChannelContribution - partialState *channeldb.OpenChannel + partialState *chanstate.OpenChannel nodeAddr net.Addr // The ID of this reservation, used to uniquely track the reservation @@ -494,7 +495,7 @@ func NewChannelReservation(capacity, localFundingAmt btcutil.Amount, FundingAmount: theirBalance.ToSatoshis(), ChannelConfig: &channeldb.ChannelConfig{}, }, - partialState: &channeldb.OpenChannel{ + partialState: &chanstate.OpenChannel{ ChanType: chanType, ChainHash: *chainHash, IsPending: true, @@ -777,11 +778,11 @@ func (r *ChannelReservation) OurSignatures() ([]*input.Script, // confirmations. Once the method unblocks, a LightningChannel instance is // returned, marking the channel available for updates. func (r *ChannelReservation) CompleteReservation(fundingInputScripts []*input.Script, - commitmentSig input.Signature) (*channeldb.OpenChannel, error) { + commitmentSig input.Signature) (*chanstate.OpenChannel, error) { // TODO(roasbeef): add flag for watch or not? errChan := make(chan error, 1) - completeChan := make(chan *channeldb.OpenChannel, 1) + completeChan := make(chan *chanstate.OpenChannel, 1) r.wallet.msgChan <- &addCounterPartySigsMsg{ pendingFundingID: r.reservationID, @@ -805,11 +806,11 @@ func (r *ChannelReservation) CompleteReservation(fundingInputScripts []*input.Sc // will be populated. func (r *ChannelReservation) CompleteReservationSingle( fundingPoint *wire.OutPoint, commitSig input.Signature, - auxFundingDesc fn.Option[AuxFundingDesc]) (*channeldb.OpenChannel, + auxFundingDesc fn.Option[AuxFundingDesc]) (*chanstate.OpenChannel, error) { errChan := make(chan error, 1) - completeChan := make(chan *channeldb.OpenChannel, 1) + completeChan := make(chan *chanstate.OpenChannel, 1) r.wallet.msgChan <- &addSingleFunderSigsMsg{ pendingFundingID: r.reservationID, @@ -903,7 +904,7 @@ func (r *ChannelReservation) Cancel() error { } // ChanState the current open channel state. -func (r *ChannelReservation) ChanState() *channeldb.OpenChannel { +func (r *ChannelReservation) ChanState() *chanstate.OpenChannel { r.RLock() defer r.RUnlock() diff --git a/lnwallet/taproot_test_vectors_test.go b/lnwallet/taproot_test_vectors_test.go index 7e3dcf8e6a0..34474939390 100644 --- a/lnwallet/taproot_test_vectors_test.go +++ b/lnwallet/taproot_test_vectors_test.go @@ -20,6 +20,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -891,7 +892,7 @@ func createTaprootTestChannelsForVectors(tc *taprootTestContext, shortChanID := lnwire.NewShortChanIDFromInt(0xdeadbeef) - remoteChannelState := &channeldb.OpenChannel{ + remoteChannelState := &chanstate.OpenChannel{ LocalChanCfg: remoteCfg, RemoteChanCfg: localCfg, IdentityPub: tc.remoteFundingPrivkey.PubKey(), @@ -906,12 +907,9 @@ func createTaprootTestChannelsForVectors(tc *taprootTestContext, LocalCommitment: remoteCommit, RemoteCommitment: remoteCommit, Db: dbRemote.ChannelStateDB(), - Packager: channeldb.NewChannelPackager( - shortChanID, - ), - FundingTxn: fundingTx, + FundingTxn: fundingTx, } - localChannelState := &channeldb.OpenChannel{ + localChannelState := &chanstate.OpenChannel{ LocalChanCfg: localCfg, RemoteChanCfg: remoteCfg, IdentityPub: tc.localFundingPrivkey.PubKey(), @@ -926,10 +924,7 @@ func createTaprootTestChannelsForVectors(tc *taprootTestContext, LocalCommitment: localCommit, RemoteCommitment: localCommit, Db: dbLocal.ChannelStateDB(), - Packager: channeldb.NewChannelPackager( - shortChanID, - ), - FundingTxn: fundingTx, + FundingTxn: fundingTx, } // Create mock signers with all deterministic keys. The funding key must diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 738558e224f..da0c69cba81 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -308,7 +309,7 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, binary.BigEndian.Uint64(chanIDBytes[:]), ) - aliceChannelState := &channeldb.OpenChannel{ + aliceChannelState := &chanstate.OpenChannel{ LocalChanCfg: aliceCfg, RemoteChanCfg: bobCfg, IdentityPub: aliceKeys[0].PubKey(), @@ -323,10 +324,9 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, LocalCommitment: aliceLocalCommit, RemoteCommitment: aliceRemoteCommit, Db: dbAlice.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: testTx, } - bobChannelState := &channeldb.OpenChannel{ + bobChannelState := &chanstate.OpenChannel{ LocalChanCfg: bobCfg, RemoteChanCfg: aliceCfg, IdentityPub: bobKeys[0].PubKey(), @@ -341,7 +341,6 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, LocalCommitment: bobLocalCommit, RemoteCommitment: bobRemoteCommit, Db: dbBob.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), } // If the channel type has a tapscript root, then we'll also specify diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 38131eaa724..fb3142c4f36 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -970,7 +971,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp binary.BigEndian.Uint64(chanIDBytes[:]), ) - remoteChannelState := &channeldb.OpenChannel{ + remoteChannelState := &chanstate.OpenChannel{ LocalChanCfg: remoteCfg, RemoteChanCfg: localCfg, IdentityPub: remoteDummy2.PubKey(), @@ -985,10 +986,9 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp LocalCommitment: remoteCommit, RemoteCommitment: remoteCommit, Db: dbRemote.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } - localChannelState := &channeldb.OpenChannel{ + localChannelState := &chanstate.OpenChannel{ LocalChanCfg: localCfg, RemoteChanCfg: remoteCfg, IdentityPub: localDummy2.PubKey(), @@ -1003,7 +1003,6 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp LocalCommitment: localCommit, RemoteCommitment: localCommit, Db: dbLocal.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index daba0992577..df9586d6d3f 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -334,7 +335,7 @@ type addCounterPartySigsMsg struct { // This channel is used to return the completed channel after the wallet // has completed all of its stages in the funding process. - completeChan chan *channeldb.OpenChannel + completeChan chan *chanstate.OpenChannel // NOTE: In order to avoid deadlocks, this channel MUST be buffered. err chan error @@ -363,7 +364,7 @@ type addSingleFunderSigsMsg struct { // This channel is used to return the completed channel after the wallet // has completed all of its stages in the funding process. - completeChan chan *channeldb.OpenChannel + completeChan chan *chanstate.OpenChannel // NOTE: In order to avoid deadlocks, this channel MUST be buffered. err chan error @@ -1152,7 +1153,7 @@ func (l *LightningWallet) CurrentNumAnchorChans() (int, error) { } var numAnchors int - cntChannel := func(c *channeldb.OpenChannel) { + cntChannel := func(c *chanstate.OpenChannel) { // We skip private channels, as we assume they won't be used // for routing. if c.ChannelFlags&lnwire.FFAnnounceChannel == 0 { @@ -2601,7 +2602,7 @@ func initStateHints(commit1, commit2 *wire.MsgTx, // ValidateChannel will attempt to fully validate a newly mined channel, given // its funding transaction and existing channel state. If this method returns // an error, then the mined channel is invalid, and shouldn't be used. -func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, +func (l *LightningWallet) ValidateChannel(channelState *chanstate.OpenChannel, fundingTx *wire.MsgTx) error { var chanOpts []ChannelOpt diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index b21aeb18a58..5a2fd5a2107 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -600,14 +600,14 @@ func (m *ChanStatusManager) disableInactiveChannels() { // fetchChannels returns the working set of channels managed by the // ChanStatusManager. The returned channels are filtered to only contain public // channels. -func (m *ChanStatusManager) fetchChannels() ([]*channeldb.OpenChannel, error) { +func (m *ChanStatusManager) fetchChannels() ([]*chanstate.OpenChannel, error) { allChannels, err := m.cfg.DB.FetchAllOpenChannels() if err != nil { return nil, err } // Filter out private channels. - var channels []*channeldb.OpenChannel + var channels []*chanstate.OpenChannel for _, c := range allChannels { // We'll skip any private channels, as they aren't used for // routing within the network by other nodes. diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 02669ce4adc..f103051cd58 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -15,7 +15,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" @@ -51,14 +51,14 @@ func randOutpoint(t *testing.T) wire.OutPoint { var shortChanIDs uint64 -// createChannel generates a channeldb.OpenChannel with a random chanpoint and +// createChannel generates a chanstate.OpenChannel with a random chanpoint and // short channel id. -func createChannel(t *testing.T) *channeldb.OpenChannel { +func createChannel(t *testing.T) *chanstate.OpenChannel { t.Helper() sid := atomic.AddUint64(&shortChanIDs, 1) - return &channeldb.OpenChannel{ + return &chanstate.OpenChannel{ ShortChannelID: lnwire.NewShortChanIDFromInt(sid), ChannelFlags: lnwire.FFAnnounceChannel, FundingOutpoint: randOutpoint(t), @@ -69,7 +69,7 @@ func createChannel(t *testing.T) *channeldb.OpenChannel { // The remote party's public key is generated randomly, and then sorted against // our `pubkey` with the direction bit set appropriately in the policies. Our // update will be created with the disabled bit set if startEnabled is false. -func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, +func createEdgePolicies(t *testing.T, channel *chanstate.OpenChannel, pubkey *btcec.PublicKey, startEnabled bool) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { @@ -134,7 +134,7 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, type mockGraph struct { mu sync.Mutex - channels []*channeldb.OpenChannel + channels []*chanstate.OpenChannel chanInfos map[wire.OutPoint]*models.ChannelEdgeInfo chanPols1 map[wire.OutPoint]*models.ChannelEdgePolicy chanPols2 map[wire.OutPoint]*models.ChannelEdgePolicy @@ -147,7 +147,7 @@ func newMockGraph(t *testing.T, numChannels int, startEnabled bool, pubKey *btcec.PublicKey) *mockGraph { g := &mockGraph{ - channels: make([]*channeldb.OpenChannel, 0, numChannels), + channels: make([]*chanstate.OpenChannel, 0, numChannels), chanInfos: make(map[wire.OutPoint]*models.ChannelEdgeInfo), chanPols1: make(map[wire.OutPoint]*models.ChannelEdgePolicy), chanPols2: make(map[wire.OutPoint]*models.ChannelEdgePolicy), @@ -169,7 +169,7 @@ func newMockGraph(t *testing.T, numChannels int, startEnabled bool, return g } -func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { +func (g *mockGraph) FetchAllOpenChannels() ([]*chanstate.OpenChannel, error) { return g.chans(), nil } @@ -246,24 +246,24 @@ func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, return nil } -func (g *mockGraph) chans() []*channeldb.OpenChannel { +func (g *mockGraph) chans() []*chanstate.OpenChannel { g.mu.Lock() defer g.mu.Unlock() - channels := make([]*channeldb.OpenChannel, 0, len(g.channels)) + channels := make([]*chanstate.OpenChannel, 0, len(g.channels)) channels = append(channels, g.channels...) return channels } -func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) { +func (g *mockGraph) addChannel(channel *chanstate.OpenChannel) { g.mu.Lock() defer g.mu.Unlock() g.channels = append(g.channels, channel) } -func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, +func (g *mockGraph) addEdgePolicy(c *chanstate.OpenChannel, info *models.ChannelEdgeInfo, pol1, pol2 *models.ChannelEdgePolicy) { @@ -276,7 +276,7 @@ func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, g.sidToCid[c.ShortChanID()] = c.FundingOutpoint } -func (g *mockGraph) removeChannel(channel *channeldb.OpenChannel) { +func (g *mockGraph) removeChannel(channel *chanstate.OpenChannel) { g.mu.Lock() defer g.mu.Unlock() @@ -401,7 +401,7 @@ func newHarness(t *testing.T, numChannels int, // markActive updates the active status of the passed channels within the mock // switch to active. -func (h *testHarness) markActive(channels []*channeldb.OpenChannel) { +func (h *testHarness) markActive(channels []*chanstate.OpenChannel) { h.t.Helper() for _, channel := range channels { @@ -412,7 +412,7 @@ func (h *testHarness) markActive(channels []*channeldb.OpenChannel) { // markInactive updates the active status of the passed channels within the mock // switch to inactive. -func (h *testHarness) markInactive(channels []*channeldb.OpenChannel) { +func (h *testHarness) markInactive(channels []*chanstate.OpenChannel) { h.t.Helper() for _, channel := range channels { @@ -423,8 +423,8 @@ func (h *testHarness) markInactive(channels []*channeldb.OpenChannel) { // assertEnables requests enables for all of the passed channels, and asserts // that the errors returned from RequestEnable matches expErr. -func (h *testHarness) assertEnables(channels []*channeldb.OpenChannel, expErr error, - manual bool) { +func (h *testHarness) assertEnables(channels []*chanstate.OpenChannel, + expErr error, manual bool) { h.t.Helper() @@ -435,8 +435,8 @@ func (h *testHarness) assertEnables(channels []*channeldb.OpenChannel, expErr er // assertDisables requests disables for all of the passed channels, and asserts // that the errors returned from RequestDisable matches expErr. -func (h *testHarness) assertDisables(channels []*channeldb.OpenChannel, expErr error, - manual bool) { +func (h *testHarness) assertDisables(channels []*chanstate.OpenChannel, + expErr error, manual bool) { h.t.Helper() @@ -447,7 +447,7 @@ func (h *testHarness) assertDisables(channels []*channeldb.OpenChannel, expErr e // assertAutos requests auto state management for all of the passed channels, and // asserts that the errors returned from RequestAuto matches expErr. -func (h *testHarness) assertAutos(channels []*channeldb.OpenChannel, +func (h *testHarness) assertAutos(channels []*chanstate.OpenChannel, expErr error) { h.t.Helper() @@ -506,7 +506,7 @@ func (h *testHarness) assertNoUpdates(duration time.Duration) { // are receive on the network for each of the passed OpenChannels, and that all // of their disable bits are set to match expEnabled. The expEnabled parameter // is ignored if channels is nil. -func (h *testHarness) assertUpdates(channels []*channeldb.OpenChannel, +func (h *testHarness) assertUpdates(channels []*chanstate.OpenChannel, expEnabled bool, duration time.Duration) { h.t.Helper() @@ -554,7 +554,7 @@ func (h *testHarness) assertUpdates(channels []*channeldb.OpenChannel, // sidsFromChans returns an index contain the short channel ids of each channel // provided in the list of OpenChannels. func sidsFromChans( - channels []*channeldb.OpenChannel) map[lnwire.ShortChannelID]struct{} { + channels []*chanstate.OpenChannel) map[lnwire.ShortChannelID]struct{} { sids := make(map[lnwire.ShortChannelID]struct{}) for _, channel := range channels { @@ -703,7 +703,7 @@ var stateMachineTests = []stateMachineTest{ startEnabled: false, fn: func(h testHarness) { // Create channels unknown to the graph. - unknownChans := []*channeldb.OpenChannel{ + unknownChans := []*chanstate.OpenChannel{ createChannel(h.t), createChannel(h.t), createChannel(h.t), @@ -723,7 +723,7 @@ var stateMachineTests = []stateMachineTest{ startEnabled: false, fn: func(h testHarness) { // Create channels unknown to the graph. - unknownChans := []*channeldb.OpenChannel{ + unknownChans := []*chanstate.OpenChannel{ createChannel(h.t), createChannel(h.t), createChannel(h.t), @@ -749,7 +749,7 @@ var stateMachineTests = []stateMachineTest{ // Add a new channels to the graph, but don't yet add // the edge policies. We should see no updates sent // since the manager can't access the policies. - newChans := []*channeldb.OpenChannel{ + newChans := []*chanstate.OpenChannel{ createChannel(h.t), createChannel(h.t), createChannel(h.t), diff --git a/netann/interface.go b/netann/interface.go index 78acc24cdd8..246edb5bf5a 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -4,7 +4,7 @@ import ( "context" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/graph/db/models" ) @@ -13,7 +13,7 @@ import ( type DB interface { // FetchAllOpenChannels returns a slice of all open channels known to // the daemon. This may include private or pending channels. - FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) + FetchAllOpenChannels() ([]*chanstate.OpenChannel, error) } // ChannelGraph abstracts the required channel graph queries used by the diff --git a/peer/brontide.go b/peer/brontide.go index f7a01cd11f5..fb6ad8c1f9c 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -123,7 +123,7 @@ type outgoingMsg struct { errChan chan error // MUST be buffered. } -// newChannelMsg packages a channeldb.OpenChannel with a channel that allows +// newChannelMsg packages a chanstate.OpenChannel with a channel that allows // the receiver of the request to report when the channel creation process has // completed. type newChannelMsg struct { @@ -1142,7 +1142,9 @@ func (p *Brontide) addrWithInternalKey( // channels returned by the database. It returns a slice of channel reestablish // messages that should be sent to the peer immediately, in case we have borked // channels that haven't been closed yet. -func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( +// +//nolint:funlen +func (p *Brontide) loadActiveChannels(chans []*chanstate.OpenChannel) ( []lnwire.Message, error) { // Return a slice of messages to send to the peers in case the channel @@ -1592,7 +1594,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, // maybeSendNodeAnn sends our node announcement to the remote peer if at least // one confirmed public channel exists with them. -func (p *Brontide) maybeSendNodeAnn(channels []*channeldb.OpenChannel) { +func (p *Brontide) maybeSendNodeAnn(channels []*chanstate.OpenChannel) { defer p.cg.WgDone() hasConfirmedPublicChan := false @@ -5495,7 +5497,7 @@ func (p *Brontide) attachChannelEventSubscription() error { // updateNextRevocation updates the existing channel's next revocation if it's // nil. -func (p *Brontide) updateNextRevocation(c *channeldb.OpenChannel) error { +func (p *Brontide) updateNextRevocation(c *chanstate.OpenChannel) error { chanPoint := c.FundingOutpoint chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -5537,7 +5539,7 @@ func (p *Brontide) updateNextRevocation(c *channeldb.OpenChannel) error { } // addActiveChannel adds a new active channel to the `activeChannels` map. It -// takes a `channeldb.OpenChannel`, creates a `lnwallet.LightningChannel` from +// takes a `chanstate.OpenChannel`, creates a `lnwallet.LightningChannel` from // it and assembles it with a channel link. func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanPoint := c.FundingOutpoint @@ -5796,7 +5798,7 @@ func (p *Brontide) scaleTimeout(timeout time.Duration) time.Duration { // bandwidth against the traffic shaper. type auxHtlcValidator struct { peer *Brontide - dbChan *channeldb.OpenChannel + dbChan *chanstate.OpenChannel ts htlcswitch.AuxTrafficShaper } @@ -5872,7 +5874,7 @@ func (v *auxHtlcValidator) ValidateHtlc(amount, // createHtlcValidator creates an HTLC validator that performs final aux balance // validation before HTLCs are added to the channel state. -func (p *Brontide) createHtlcValidator(dbChan *channeldb.OpenChannel, +func (p *Brontide) createHtlcValidator(dbChan *chanstate.OpenChannel, ts htlcswitch.AuxTrafficShaper) lnwallet.AuxHtlcValidator { return &auxHtlcValidator{ diff --git a/peer/brontide_test.go b/peer/brontide_test.go index f4bb661ea38..0ce7fb7b33f 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" @@ -765,7 +765,7 @@ func TestCustomShutdownScript(t *testing.T) { // setShutdown is a function which sets the upfront shutdown address for // the local channel. - setShutdown := func(a, b *channeldb.OpenChannel) { + setShutdown := func(a, b *chanstate.OpenChannel) { a.LocalShutdownScript = script b.RemoteShutdownScript = script } @@ -775,7 +775,7 @@ func TestCustomShutdownScript(t *testing.T) { // update is a function used to set values on the channel set up for the // test. It is used to set values for upfront shutdown addresses. - update func(a, b *channeldb.OpenChannel) + update func(a, b *chanstate.OpenChannel) // userCloseScript is the address specified by the user. userCloseScript lnwire.DeliveryAddress @@ -1225,8 +1225,8 @@ func assertMsgSent(t *testing.T, conn *mockMessageConn, func TestAlwaysSendChannelUpdate(t *testing.T) { require := require.New(t) - var channel *channeldb.OpenChannel - channelIntercept := func(a, b *channeldb.OpenChannel) { + var channel *chanstate.OpenChannel + channelIntercept := func(a, b *chanstate.OpenChannel) { channel = a } @@ -1437,8 +1437,8 @@ func TestStartupWriteMessageRace(t *testing.T) { // createTestPeerWithChannel, so we can mark it borked below. // We can't mark it borked within the callback, since the channel hasn't // been saved to the DB yet when the callback executes. - var channel *channeldb.OpenChannel - getChannels := func(a, b *channeldb.OpenChannel) { + var channel *chanstate.OpenChannel + getChannels := func(a, b *chanstate.OpenChannel) { channel = a } @@ -1638,7 +1638,7 @@ func TestCreateHtlcValidator(t *testing.T) { } // Create a mock channel with minimal required fields. - dbChan := &channeldb.OpenChannel{ + dbChan := &chanstate.OpenChannel{ ShortChannelID: lnwire.NewShortChanIDFromInt(123), } diff --git a/peer/test_utils.go b/peer/test_utils.go index 670af094605..4746120c437 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -18,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" @@ -55,7 +56,7 @@ var ( // noUpdate is a function which can be used as a parameter in // createTestPeerWithChannel to call the setup code with no custom values on // the channels set up. -var noUpdate = func(a, b *channeldb.OpenChannel) {} +var noUpdate = func(a, b *chanstate.OpenChannel) {} type peerTestCtx struct { peer *Brontide @@ -75,7 +76,7 @@ type peerTestCtx struct { // It takes an updateChan function which can be used to modify the default // values on the channel states for each peer. func createTestPeerWithChannel(t *testing.T, updateChan func(a, - b *channeldb.OpenChannel)) (*peerTestCtx, error) { + b *chanstate.OpenChannel)) (*peerTestCtx, error) { params := createTestPeer(t) @@ -238,7 +239,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, binary.BigEndian.Uint64(chanIDBytes[:]), ) - aliceChannelState := &channeldb.OpenChannel{ + aliceChannelState := &chanstate.OpenChannel{ LocalChanCfg: aliceCfg, RemoteChanCfg: bobCfg, IdentityPub: aliceKeyPub, @@ -253,10 +254,9 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, Db: dbAlice.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } - bobChannelState := &channeldb.OpenChannel{ + bobChannelState := &chanstate.OpenChannel{ LocalChanCfg: bobCfg, RemoteChanCfg: aliceCfg, IdentityPub: bobKeyPub, @@ -270,7 +270,6 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, LocalCommitment: bobCommit, RemoteCommitment: bobCommit, Db: dbBob.ChannelStateDB(), - Packager: channeldb.NewChannelPackager(shortChanID), } // Set custom values on the channel states. diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index ce5a1420c2e..9cacaa1ed2e 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -46,7 +46,7 @@ type BuildBlindedPathCfg struct { *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) // FetchOurOpenChannels fetches this node's set of open channels. - FetchOurOpenChannels func() ([]*channeldb.OpenChannel, error) + FetchOurOpenChannels func() ([]*chanstate.OpenChannel, error) // BestHeight can be used to fetch the best block height that this node // is aware of. @@ -529,7 +529,7 @@ func buildDummyRouteData(node route.Vertex, relayInfo *record.PaymentRelayInfo, // we use the provided default policy values, and we get the average capacity of // this node's channels to compute a MaxHTLC value. func computeDummyHopPolicy(defaultPolicy *BlindedHopPolicy, - fetchOurChannels func() ([]*channeldb.OpenChannel, error), + fetchOurChannels func() ([]*chanstate.OpenChannel, error), policies map[uint64]*BlindedHopPolicy) (*BlindedHopPolicy, error) { numPolicies := len(policies) diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 6dd30b0ae3c..8e313ddd081 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/funding" @@ -48,7 +49,7 @@ type Manager struct { // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. - FetchChannel func(chanPoint wire.OutPoint) (*channeldb.OpenChannel, + FetchChannel func(chanPoint wire.OutPoint) (*chanstate.OpenChannel, error) // AddEdge is used to add edge/channel to the topology of the router. @@ -247,7 +248,7 @@ func (r *Manager) UpdatePolicy(ctx context.Context, } func (r *Manager) createMissingEdge(ctx context.Context, - channel *channeldb.OpenChannel, + channel *chanstate.OpenChannel, newSchema routing.ChannelPolicy) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *lnrpc.FailedUpdate) { @@ -294,7 +295,7 @@ func (r *Manager) createMissingEdge(ctx context.Context, } // createEdge recreates an edge and policy from an open channel in-memory. -func (r *Manager) createEdge(channel *channeldb.OpenChannel, +func (r *Manager) createEdge(channel *chanstate.OpenChannel, timestamp time.Time) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, error) { @@ -475,7 +476,7 @@ func (r *Manager) updateEdge(chanPoint wire.OutPoint, // getHtlcAmtLimits retrieves the negotiated channel min and max htlc amount // constraints. -func (r *Manager) getHtlcAmtLimits(ch *channeldb.OpenChannel) ( +func (r *Manager) getHtlcAmtLimits(ch *chanstate.OpenChannel) ( lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { // The max htlc policy field must be less than or equal to the channel diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index c48e6164160..b68585dad5c 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph/db/models" @@ -138,28 +139,29 @@ func TestManager(t *testing.T) { return nil } - fetchChannel := func(chanPoint wire.OutPoint) (*channeldb.OpenChannel, + fetchChannel := func(chanPoint wire.OutPoint) (*chanstate.OpenChannel, error) { if chanPoint == chanPointMissing { - return &channeldb.OpenChannel{}, channeldb.ErrChannelNotFound + return &chanstate.OpenChannel{}, + channeldb.ErrChannelNotFound } - bounds := channeldb.ChannelStateBounds{ + bounds := chanstate.ChannelStateBounds{ MaxPendingAmount: maxPendingAmount, MinHTLC: minHTLC, } - return &channeldb.OpenChannel{ + return &chanstate.OpenChannel{ FundingOutpoint: chanPointValid, IdentityPub: remotepub, - LocalChanCfg: channeldb.ChannelConfig{ + LocalChanCfg: chanstate.ChannelConfig{ ChannelStateBounds: bounds, MultiSigKey: keychain.KeyDescriptor{ PubKey: localMultisigKey, }, }, - RemoteChanCfg: channeldb.ChannelConfig{ + RemoteChanCfg: chanstate.ChannelConfig{ ChannelStateBounds: bounds, MultiSigKey: keychain.KeyDescriptor{ PubKey: remoteMultisigKey, @@ -414,14 +416,14 @@ func TestCreateEdgeLower(t *testing.T) { TimeLockDelta: 7, } - channel := &channeldb.OpenChannel{ + channel := &chanstate.OpenChannel{ IdentityPub: remotepub, - LocalChanCfg: channeldb.ChannelConfig{ + LocalChanCfg: chanstate.ChannelConfig{ MultiSigKey: keychain.KeyDescriptor{ PubKey: localMultisigKey, }, }, - RemoteChanCfg: channeldb.ChannelConfig{ + RemoteChanCfg: chanstate.ChannelConfig{ MultiSigKey: keychain.KeyDescriptor{ PubKey: remoteMultisigKey, }, @@ -505,14 +507,14 @@ func TestCreateEdgeHigher(t *testing.T) { TimeLockDelta: 7, } - channel := &channeldb.OpenChannel{ + channel := &chanstate.OpenChannel{ IdentityPub: remotepub, - LocalChanCfg: channeldb.ChannelConfig{ + LocalChanCfg: chanstate.ChannelConfig{ MultiSigKey: keychain.KeyDescriptor{ PubKey: localMultisigKey, }, }, - RemoteChanCfg: channeldb.ChannelConfig{ + RemoteChanCfg: chanstate.ChannelConfig{ MultiSigKey: keychain.KeyDescriptor{ PubKey: remoteMultisigKey, }, diff --git a/rpcserver.go b/rpcserver.go index 491bd8a1426..f00c277211c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -43,6 +43,7 @@ import ( "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" @@ -4027,7 +4028,7 @@ type ( // 1. The current blockchain height // 2. The block height at which the funding transaction was first confirmed // 3. The total number of confirmations required for the channel. -func calcRemainingConfs(pendingChan *channeldb.OpenChannel, +func calcRemainingConfs(pendingChan *chanstate.OpenChannel, currentHeight uint32) uint32 { // If the funding transaction hasn't been confirmed yet, @@ -4332,7 +4333,7 @@ func (r *rpcServer) fetchWaitingCloseChannels( // getClosingTx is a helper closure that tries to find the closing tx of // a given waiting close channel. Notice that if the remote closes the // channel, we may not have the closing tx. - getClosingTx := func(c *channeldb.OpenChannel) (*wire.MsgTx, error) { + getClosingTx := func(c *chanstate.OpenChannel) (*wire.MsgTx, error) { var ( tx *wire.MsgTx err error @@ -4972,7 +4973,7 @@ func createChannelConstraint( // isPrivate evaluates the ChannelFlags of the db channel to determine if the // channel is private or not. -func isPrivate(dbChannel *channeldb.OpenChannel) bool { +func isPrivate(dbChannel *chanstate.OpenChannel) bool { if dbChannel == nil { return false } @@ -4981,7 +4982,7 @@ func isPrivate(dbChannel *channeldb.OpenChannel) bool { // encodeCustomChanData encodes the custom channel data for the open channel. // It encodes that data as a pair of var bytes blobs. -func encodeCustomChanData(lnChan *channeldb.OpenChannel) ([]byte, error) { +func encodeCustomChanData(lnChan *chanstate.OpenChannel) ([]byte, error) { customOpenChanData := lnChan.CustomBlob.UnwrapOr(nil) customLocalCommitData := lnChan.LocalCommitment.CustomBlob.UnwrapOr(nil) @@ -5012,7 +5013,7 @@ func encodeCustomChanData(lnChan *channeldb.OpenChannel) ([]byte, error) { // //nolint:funlen func createRPCOpenChannel(ctx context.Context, r *rpcServer, - dbChannel *channeldb.OpenChannel, + dbChannel *chanstate.OpenChannel, isActive, peerAliasLookup bool) (*lnrpc.Channel, error) { nodePub := dbChannel.IdentityPub diff --git a/server.go b/server.go index 45992c464cb..c740a192e70 100644 --- a/server.go +++ b/server.go @@ -1605,7 +1605,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, } return delay }, - WatchNewChannel: func(channel *channeldb.OpenChannel, + WatchNewChannel: func(channel *chanstate.OpenChannel, peerKey *btcec.PublicKey) error { // First, we'll mark this new peer as a persistent peer @@ -3456,7 +3456,7 @@ func (s *server) createNewHiddenService(ctx context.Context) error { // optimization that is quicker than seeking for a channel given only the // ChannelID. func (s *server) findChannel(node *btcec.PublicKey, chanID lnwire.ChannelID) ( - *channeldb.OpenChannel, error) { + *chanstate.OpenChannel, error) { nodeChans, err := s.chanStateDB.FetchOpenChannels(node) if err != nil { @@ -4374,7 +4374,7 @@ func (s *server) notifyOpenChannelPeerEvent(op wire.OutPoint, // notifyPendingOpenChannelPeerEvent updates the access manager's maps and then // calls the channelNotifier's NotifyPendingOpenChannelEvent. func (s *server) notifyPendingOpenChannelPeerEvent(op wire.OutPoint, - pendingChan *channeldb.OpenChannel, remotePub *btcec.PublicKey) { + pendingChan *chanstate.OpenChannel, remotePub *btcec.PublicKey) { // Call newPendingOpenChan to update the access manager's maps for this // peer. diff --git a/watchtower/blob/type.go b/watchtower/blob/type.go index 00415afc917..df04769eaf1 100644 --- a/watchtower/blob/type.go +++ b/watchtower/blob/type.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" ) // Flag represents a specify option that can be present in a Type. @@ -97,7 +97,7 @@ const ( // TypeFromChannel returns the appropriate blob Type for the given channel // type. -func TypeFromChannel(chanType channeldb.ChannelType) Type { +func TypeFromChannel(chanType chanstate.ChannelType) Type { switch { case chanType.IsTaprootFinal(): return TypeAltruistTaprootFinalCommit @@ -130,7 +130,7 @@ func (t Type) Identifier() (string, error) { // CommitmentType returns the appropriate CommitmentType for the given blob Type // and channel type. -func (t Type) CommitmentType(chanType *channeldb.ChannelType) (CommitmentType, +func (t Type) CommitmentType(chanType *chanstate.ChannelType) (CommitmentType, error) { switch { diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 62d7609469e..777ef8dfd6f 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -64,7 +64,7 @@ type backupTaskTest struct { bindErr error expSweepScript []byte signer input.Signer - chanType channeldb.ChannelType + chanType chanstate.ChannelType commitType blob.CommitmentType } @@ -84,7 +84,7 @@ func genTaskTest( expSweepAmt int64, expRewardAmt int64, bindErr error, - chanType channeldb.ChannelType) backupTaskTest { + chanType chanstate.ChannelType) backupTaskTest { // Set the anchor or taproot flag in the blob type if the session needs // to support anchor or taproot channels. @@ -330,11 +330,11 @@ var ( func TestBackupTask(t *testing.T) { t.Parallel() - chanTypes := []channeldb.ChannelType{ - channeldb.SingleFunderBit, - channeldb.SingleFunderTweaklessBit, - channeldb.AnchorOutputsBit, - channeldb.SimpleTaprootFeatureBit, + chanTypes := []chanstate.ChannelType{ + chanstate.SingleFunderBit, + chanstate.SingleFunderTweaklessBit, + chanstate.AnchorOutputsBit, + chanstate.SimpleTaprootFeatureBit, } var backupTaskTests []backupTaskTest @@ -573,7 +573,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // getBreachInfo is a helper closure that returns the breach retribution // info and channel type for the given channel and commit height. getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) ( - *lnwallet.BreachRetribution, channeldb.ChannelType, error) { + *lnwallet.BreachRetribution, chanstate.ChannelType, error) { return test.breachInfo, test.chanType, nil } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 8d74e9d1f3d..a0c4d6c8921 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btclog/v2" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -94,7 +94,7 @@ type RegisteredTower struct { // BreachRetribution from a channel ID and a commitment height. type BreachRetributionBuilder func(id lnwire.ChannelID, commitHeight uint64) (*lnwallet.BreachRetribution, - channeldb.ChannelType, error) + chanstate.ChannelType, error) // newTowerMsg is an internal message we'll use within the client to signal // that a new tower can be considered. diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 1b50600ae89..cd8007c966a 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -514,7 +515,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { }) fetchChannel := func(id lnwire.ChannelID) ( - *channeldb.ChannelCloseSummary, error) { + *chanstate.ChannelCloseSummary, error) { h.mu.Lock() defer h.mu.Unlock() @@ -524,7 +525,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return nil, channeldb.ErrClosedChannelNotFound } - return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil + return &chanstate.ChannelCloseSummary{CloseHeight: height}, nil } h.clientPolicy = cfg.policy @@ -552,11 +553,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { h.clientCfg.BuildBreachRetribution = func(id lnwire.ChannelID, commitHeight uint64) (*lnwallet.BreachRetribution, - channeldb.ChannelType, error) { + chanstate.ChannelType, error) { _, retribution := h.channelFromID(id).getState(commitHeight) - return retribution, channeldb.SimpleTaprootFeatureBit, nil + return retribution, chanstate.SimpleTaprootFeatureBit, nil } if !cfg.noServerStart { @@ -689,7 +690,7 @@ func (h *testHarness) closeChannel(id uint64, height uint32) { } h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{ - CloseSummary: &channeldb.ChannelCloseSummary{ + CloseSummary: &chanstate.ChannelCloseSummary{ ChanPoint: wire.OutPoint{ Hash: *chanPointHash, Index: 0, @@ -705,7 +706,7 @@ func (h *testHarness) registerChannel(id uint64) { chanID := chanIDFromInt(id) err := h.clientMgr.RegisterChannel( - chanID, channeldb.SimpleTaprootFeatureBit, + chanID, chanstate.SimpleTaprootFeatureBit, ) require.NoError(h.t, err) } diff --git a/watchtower/wtclient/manager.go b/watchtower/wtclient/manager.go index 7a39c8ff73e..0890a048605 100644 --- a/watchtower/wtclient/manager.go +++ b/watchtower/wtclient/manager.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" @@ -67,7 +68,7 @@ type ClientManager interface { // parameters within the client. This should be called during link // startup to ensure that the client is able to support the link during // operation. - RegisterChannel(lnwire.ChannelID, channeldb.ChannelType) error + RegisterChannel(lnwire.ChannelID, chanstate.ChannelType) error // BackupState initiates a request to back up a particular revoked // state. If the method returns nil, the backup is guaranteed to be @@ -93,7 +94,7 @@ type Config struct { // channel. If the channel is not found or not yet closed then // channeldb.ErrClosedChannelNotFound will be returned. FetchClosedChannel func(cid lnwire.ChannelID) ( - *channeldb.ChannelCloseSummary, error) + *chanstate.ChannelCloseSummary, error) // ChainNotifier can be used to subscribe to block notifications. ChainNotifier chainntnfs.ChainNotifier @@ -597,7 +598,7 @@ func (m *Manager) Policy(blobType blob.Type) (wtpolicy.Policy, error) { // within the client. This should be called during link startup to ensure that // the client is able to support the link during operation. func (m *Manager) RegisterChannel(id lnwire.ChannelID, - chanType channeldb.ChannelType) error { + chanType chanstate.ChannelType) error { blobType := blob.TypeFromChannel(chanType) diff --git a/witness_beacon.go b/witness_beacon.go index 6c315d0c18f..4b4d08c49c9 100644 --- a/witness_beacon.go +++ b/witness_beacon.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" @@ -59,7 +60,7 @@ func newPreimageBeacon(wCache witnessCache, // SubscribeUpdates returns a channel that will be sent upon *each* time a new // preimage is discovered. func (p *preimageBeacon) SubscribeUpdates( - chanID lnwire.ShortChannelID, htlc *channeldb.HTLC, + chanID lnwire.ShortChannelID, htlc *chanstate.HTLC, payload *hop.Payload, nextHopOnionBlob []byte) (*contractcourt.WitnessSubscription, error) { diff --git a/witness_beacon_test.go b/witness_beacon_test.go index d98c276f523..b65c5a5d0ae 100644 --- a/witness_beacon_test.go +++ b/witness_beacon_test.go @@ -3,7 +3,7 @@ package lnd import ( "testing" - "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/chanstate" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" @@ -30,7 +30,7 @@ func TestWitnessBeaconIntercept(t *testing.T) { subscription, err := p.SubscribeUpdates( lnwire.NewShortChanIDFromInt(1), - &channeldb.HTLC{ + &chanstate.HTLC{ RHash: hash, }, &hop.Payload{},