tlnhtlc: cleanup and save settled htlcs - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 02eca034866caa85d3a2f9f8d23c5f1044349f79
 (DIR) parent 6f5209ef8506fd889064cee9753ad8aaee02ad42
 (HTM) Author: Janus <ysangkok@gmail.com>
       Date:   Wed, 12 Sep 2018 23:37:45 +0200
       
       lnhtlc: cleanup and save settled htlcs
       
       Diffstat:
         M electrum/lnhtlc.py                  |     177 ++++++++++++++++---------------
         M electrum/tests/test_lnhtlc.py       |      28 +++++++++++++++-------------
       
       2 files changed, 104 insertions(+), 101 deletions(-)
       ---
 (DIR) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
       t@@ -2,6 +2,7 @@
        from collections import namedtuple
        import binascii
        import json
       +from enum import IntFlag
        
        from .util import bfh, PrintError, bh2u
        from .bitcoin import Hash
       t@@ -21,9 +22,22 @@ from .transaction import Transaction
        SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"])
        RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
        
       -FUNDEE_SIGNED = 1
       -FUNDEE_ACKED =  2
       -FUNDER_SIGNED = 4
       +class FeeUpdateProgress(IntFlag):
       +    FUNDEE_SIGNED = 1
       +    FUNDEE_ACKED =  2
       +    FUNDER_SIGNED = 4
       +
       +class HTLCOwner(IntFlag):
       +    LOCAL = 1
       +    REMOTE = -LOCAL
       +
       +    SENT = LOCAL
       +    RECEIVED = REMOTE
       +
       +SENT = HTLCOwner.SENT
       +RECEIVED = HTLCOwner.RECEIVED
       +LOCAL = HTLCOwner.LOCAL
       +REMOTE = HTLCOwner.REMOTE
        
        class FeeUpdate:
            def __init__(self, rate):
       t@@ -37,13 +51,14 @@ class UpdateAddHtlc:
                self.cltv_expiry = cltv_expiry
        
                # the height the htlc was locked in at, or None
       -        self.r_locked_in = None
       -        self.l_locked_in = None
       +        self.locked_in = {LOCAL: None, REMOTE: None}
       +
       +        self.settled = {LOCAL: None, REMOTE: None}
        
                self.htlc_id = None
        
            def as_tuple(self):
       -        return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.r_locked_in, self.l_locked_in)
       +        return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.locked_in[REMOTE], self.locked_in[LOCAL], self.settled)
        
            def __hash__(self):
                return hash(self.as_tuple())
       t@@ -77,7 +92,7 @@ class HTLCStateMachine(PrintError):
            @property
            def pending_remote_feerate(self):
                if self.pending_fee is not None:
       -            if self.constraints.is_initiator or (self.pending_fee.progress & FUNDEE_ACKED):
       +            if self.constraints.is_initiator or (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED):
                        return self.pending_fee.rate
                return self.remote_state.feerate
        
       t@@ -86,7 +101,7 @@ class HTLCStateMachine(PrintError):
                if self.pending_fee is not None:
                    if not self.constraints.is_initiator:
                        return self.pending_fee.rate
       -            if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED):
       +            if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED):
                        return self.pending_fee.rate
                return self.local_state.feerate
        
       t@@ -134,13 +149,10 @@ class HTLCStateMachine(PrintError):
                # any past commitment transaction and use that instead; until then...
                self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
        
       -        self.local_update_log = []
       -        self.remote_update_log = []
       +        self.log = {LOCAL: [], REMOTE: []}
        
                self.name = name
        
       -        self.total_msat_sent = 0
       -        self.total_msat_received = 0
                self.pending_fee = None
        
                self.local_commitment = self.pending_local_commitment
       t@@ -174,7 +186,7 @@ class HTLCStateMachine(PrintError):
                should be called when preparing to send an outgoing HTLC.
                """
                assert type(htlc) is UpdateAddHtlc
       -        self.local_update_log.append(htlc)
       +        self.log[LOCAL].append(htlc)
                self.print_error("add_htlc")
                htlc_id = self.local_state.next_htlc_id
                self.local_state=self.local_state._replace(next_htlc_id=htlc_id + 1)
       t@@ -189,7 +201,7 @@ class HTLCStateMachine(PrintError):
                """
                self.print_error("receive_htlc")
                assert type(htlc) is UpdateAddHtlc
       -        self.remote_update_log.append(htlc)
       +        self.log[REMOTE].append(htlc)
                htlc_id = self.remote_state.next_htlc_id
                self.remote_state=self.remote_state._replace(next_htlc_id=htlc_id + 1)
                htlc.htlc_id = htlc_id
       t@@ -208,9 +220,9 @@ class HTLCStateMachine(PrintError):
                any). The HTLC signatures are sorted according to the BIP 69 order of the
                HTLC's on the commitment transaction.
                """
       -        for htlc in self.local_update_log:
       +        for htlc in self.log[LOCAL]:
                    if not type(htlc) is UpdateAddHtlc: continue
       -            if htlc.l_locked_in is None: htlc.l_locked_in = self.local_state.ctn
       +            if htlc.locked_in[LOCAL] is None: htlc.locked_in[LOCAL] = self.local_state.ctn
                self.print_error("sign_next_commitment")
        
                pending_remote_commitment = self.pending_remote_commitment
       t@@ -243,9 +255,9 @@ class HTLCStateMachine(PrintError):
        
                if self.pending_fee:
                    if not self.constraints.is_initiator:
       -                self.pending_fee.progress |= FUNDEE_SIGNED
       -            if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED):
       -                self.pending_fee.progress |= FUNDER_SIGNED
       +                self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_SIGNED
       +            if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED):
       +                self.pending_fee.progress |= FeeUpdateProgress.FUNDER_SIGNED
        
                if self.lnwatcher:
                    self.lnwatcher.process_new_offchain_ctx(self, pending_remote_commitment, ours=False)
       t@@ -265,9 +277,9 @@ class HTLCStateMachine(PrintError):
                """
        
                self.print_error("receive_new_commitment")
       -        for htlc in self.remote_update_log:
       +        for htlc in self.log[REMOTE]:
                    if not type(htlc) is UpdateAddHtlc: continue
       -            if htlc.r_locked_in is None: htlc.r_locked_in = self.remote_state.ctn
       +            if htlc.locked_in[REMOTE] is None: htlc.locked_in[REMOTE] = self.remote_state.ctn
                assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
        
                pending_local_commitment = self.pending_local_commitment
       t@@ -294,9 +306,9 @@ class HTLCStateMachine(PrintError):
        
                if self.pending_fee:
                    if not self.constraints.is_initiator:
       -                self.pending_fee.progress |= FUNDEE_SIGNED
       -            if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED):
       -                self.pending_fee.progress |= FUNDER_SIGNED
       +                self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_SIGNED
       +            if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED):
       +                self.pending_fee.progress |= FeeUpdateProgress.FUNDER_SIGNED
        
                if self.lnwatcher:
                    self.lnwatcher.process_new_offchain_ctx(self, pending_local_commitment, ours=True)
       t@@ -321,11 +333,11 @@ class HTLCStateMachine(PrintError):
                new_remote_feerate = self.remote_state.feerate
        
                if self.pending_fee is not None:
       -            if not self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_SIGNED):
       +            if not self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_SIGNED):
                        new_local_feerate = new_remote_feerate = self.pending_fee.rate
                        self.pending_fee = None
                        print("FEERATE CHANGE COMPLETE (non-initiator)")
       -            if self.constraints.is_initiator and (self.pending_fee.progress & FUNDER_SIGNED):
       +            if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDER_SIGNED):
                        new_local_feerate = new_remote_feerate = self.pending_fee.rate
                        self.pending_fee = None
                        print("FEERATE CHANGE COMPLETE (initiator)")
       t@@ -382,41 +394,21 @@ class HTLCStateMachine(PrintError):
                if self.lnwatcher:
                    self.lnwatcher.process_new_revocation_secret(self, revocation.per_commitment_secret)
        
       -        settle_fails2 = []
       -        for x in self.remote_update_log:
       -            if type(x) is not SettleHtlc:
       -                continue
       -            settle_fails2.append(x)
       -
       -        sent_this_batch = 0
       +        def mark_settled(subject):
       +            """
       +            find settled htlcs for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
       +            """
       +            old_amount = self.htlcsum(self.gen_htlc_indices(subject, False))
        
       -        for x in settle_fails2:
       -            htlc = self.lookup_htlc(self.local_update_log, x.htlc_id)
       -            sent_this_batch += htlc.amount_msat
       +            for x in self.log[-subject]:
       +                if type(x) is not SettleHtlc: continue
       +                htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
       +                htlc.settled[subject] = self.current_height[subject]
        
       -        self.total_msat_sent += sent_this_batch
       +            return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False))
        
       -        # log compaction (remove entries relating to htlc's that have been settled)
       -
       -        to_remove = []
       -        for x in filter(lambda x: type(x) is SettleHtlc, self.remote_update_log):
       -            to_remove += [y for y in self.local_update_log if y.htlc_id == x.htlc_id]
       -
       -        # assert that we should have compacted the log earlier
       -        assert len(to_remove) <= 1, to_remove
       -        if len(to_remove) == 1:
       -            self.remote_update_log = [x for x in self.remote_update_log if x.htlc_id != to_remove[0].htlc_id]
       -            self.local_update_log = [x for x in self.local_update_log if x.htlc_id != to_remove[0].htlc_id]
       -
       -        to_remove = []
       -        for x in filter(lambda x: type(x) is SettleHtlc, self.local_update_log):
       -            to_remove += [y for y in self.remote_update_log if y.htlc_id == x.htlc_id]
       -        if len(to_remove) == 1:
       -            self.remote_update_log = [x for x in self.remote_update_log if x.htlc_id != to_remove[0].htlc_id]
       -            self.local_update_log = [x for x in self.local_update_log if x.htlc_id != to_remove[0].htlc_id]
       -        received_this_batch = sum(x.amount_msat for x in to_remove)
       -
       -        self.total_msat_received += received_this_batch
       +        sent_this_batch = mark_settled(LOCAL)
       +        received_this_batch = mark_settled(REMOTE)
        
                next_point = self.remote_state.next_per_commitment_point
        
       t@@ -434,7 +426,7 @@ class HTLCStateMachine(PrintError):
        
                if self.pending_fee:
                    if self.constraints.is_initiator:
       -                self.pending_fee.progress |= FUNDEE_ACKED
       +                self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_ACKED
        
                self.local_commitment = self.pending_local_commitment
                self.remote_commitment = self.pending_remote_commitment
       t@@ -449,14 +441,14 @@ class HTLCStateMachine(PrintError):
                return amount_unsettled
        
            def amounts(self):
       -        remote_settled_value = self.htlcsum(self.gen_htlc_indices("remote", False))
       -        local_settled_value = self.htlcsum(self.gen_htlc_indices("local", False))
       -        htlc_value_local = self.htlcsum(self.htlcs_in_local)
       -        htlc_value_remote = self.htlcsum(self.htlcs_in_remote)
       -        local_msat = self.local_state.amount_msat -\
       -          htlc_value_local + remote_settled_value - local_settled_value
       +        remote_settled= self.htlcsum(self.gen_htlc_indices(REMOTE, False))
       +        local_settled= self.htlcsum(self.gen_htlc_indices(LOCAL, False))
       +        unsettled_local = self.htlcsum(self.gen_htlc_indices(LOCAL, True))
       +        unsettled_remote = self.htlcsum(self.gen_htlc_indices(REMOTE, True))
                remote_msat = self.remote_state.amount_msat -\
       -          htlc_value_remote + local_settled_value - remote_settled_value
       +          unsettled_remote + local_settled - remote_settled
       +        local_msat = self.local_state.amount_msat -\
       +          unsettled_local + remote_settled - local_settled
                return remote_msat, local_msat
        
            @property
       t@@ -525,61 +517,70 @@ class HTLCStateMachine(PrintError):
                    local_msat, remote_msat, htlcs_in_local + htlcs_in_remote)
                return commit
        
       -    def gen_htlc_indices(self, subject, just_unsettled=True):
       -        assert subject in ["local", "remote"]
       -        update_log = (self.remote_update_log if subject == "remote" else self.local_update_log)
       -        other_log = (self.remote_update_log if subject != "remote" else self.local_update_log)
       +    @property
       +    def total_msat(self):
       +        return {LOCAL: self.htlcsum(self.gen_htlc_indices(LOCAL, False, True)), REMOTE: self.htlcsum(self.gen_htlc_indices(REMOTE, False, True))}
       +
       +    def gen_htlc_indices(self, subject, only_pending, include_settled=False):
       +        """
       +        only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
       +        include_settled: include settled (totally done with) htlcs
       +        """
       +        update_log = self.log[subject]
       +        other_log = self.log[-subject]
                res = []
                for htlc in update_log:
                    if type(htlc) is not UpdateAddHtlc:
                        continue
       -            height = (self.local_state.ctn if subject == "remote" else self.remote_state.ctn)
       -            locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in)
       +            height = self.current_height[-subject]
       +            locked_in = htlc.locked_in[subject]
        
       -            if locked_in is None or just_unsettled == (SettleHtlc(htlc.htlc_id) in other_log):
       +            if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log):
                        continue
       +
       +            settled_cutoff = self.local_state.ctn if subject == LOCAL else self.remote_state.ctn
       +
       +            if not include_settled and htlc.settled[subject] is not None and settled_cutoff >= htlc.settled[subject]:
       +                continue
       +
                    res.append(htlc)
                return res
        
            @property
            def htlcs_in_local(self):
                """in the local log. 'offered by us'"""
       -        return self.gen_htlc_indices("local")
       +        return self.gen_htlc_indices(LOCAL, True)
        
            @property
            def htlcs_in_remote(self):
                """in the remote log. 'offered by them'"""
       -        return self.gen_htlc_indices("remote")
       +        return self.gen_htlc_indices(REMOTE, True)
        
            def settle_htlc(self, preimage, htlc_id):
                """
                SettleHTLC attempts to settle an existing outstanding received HTLC.
                """
                self.print_error("settle_htlc")
       -        htlc = self.lookup_htlc(self.remote_update_log, htlc_id)
       +        htlc = self.lookup_htlc(self.log[REMOTE], htlc_id)
                assert htlc.payment_hash == sha256(preimage)
       -        self.local_update_log.append(SettleHtlc(htlc_id))
       +        self.log[LOCAL].append(SettleHtlc(htlc_id))
        
            def receive_htlc_settle(self, preimage, htlc_index):
                self.print_error("receive_htlc_settle")
       -        htlc = self.lookup_htlc(self.local_update_log, htlc_index)
       +        htlc = self.lookup_htlc(self.log[LOCAL], htlc_index)
                assert htlc.payment_hash == sha256(preimage)
       -        assert len([x.htlc_id == htlc_index for x in self.local_update_log]) == 1
       -        self.remote_update_log.append(SettleHtlc(htlc_index))
       +        assert len([x.htlc_id == htlc_index for x in self.log[LOCAL]]) == 1
       +        self.log[REMOTE].append(SettleHtlc(htlc_index))
        
            def fail_htlc(self, htlc):
                # TODO
       -        self.local_update_log = []
       -        self.remote_update_log = []
       +        self.log[LOCAL] = []
       +        self.log[REMOTE] = []
                self.print_error("fail_htlc (EMPTIED LOGS)")
        
            @property
       -    def l_current_height(self):
       -        return self.local_state.ctn
       -
       -    @property
       -    def r_current_height(self):
       -        return self.remote_state.ctn
       +    def current_height(self):
       +        return {LOCAL: self.local_state.ctn, REMOTE: self.remote_state.ctn}
        
            @property
            def pending_local_fee(self):
 (DIR) diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
       t@@ -9,6 +9,8 @@ import electrum.util as util
        import os
        import binascii
        
       +from electrum.lnhtlc import SENT, LOCAL, REMOTE, RECEIVED
       +
        def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, r_csv):
            assert local_amount > 0
            assert remote_amount > 0
       t@@ -195,10 +197,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
                aliceSent = 0
                bobSent = 0
        
       -        self.assertEqual(alice_channel.total_msat_sent, aliceSent, "alice has incorrect milli-satoshis sent")
       -        self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received")
       -        self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent")
       -        self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received")
       +        self.assertEqual(alice_channel.total_msat[SENT], aliceSent, "alice has incorrect milli-satoshis sent")
       +        self.assertEqual(alice_channel.total_msat[RECEIVED], bobSent, "alice has incorrect milli-satoshis received")
       +        self.assertEqual(bob_channel.total_msat[SENT], bobSent, "bob has incorrect milli-satoshis sent")
       +        self.assertEqual(bob_channel.total_msat[RECEIVED], aliceSent, "bob has incorrect milli-satoshis received")
                self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height")
                self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height")
        
       t@@ -236,18 +238,18 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
                # should show 1 BTC received. They should also be at commitment height
                # two, with the revocation window extended by 1 (5).
                mSatTransferred = one_bitcoin_in_msat
       -        self.assertEqual(alice_channel.total_msat_sent, mSatTransferred, "alice satoshis sent incorrect %s vs %s expected"% (alice_channel.total_msat_sent, mSatTransferred))
       -        self.assertEqual(alice_channel.total_msat_received, 0, "alice satoshis received incorrect %s vs %s expected"% (alice_channel.total_msat_received, 0))
       -        self.assertEqual(bob_channel.total_msat_received, mSatTransferred, "bob satoshis received incorrect %s vs %s expected"% (bob_channel.total_msat_received, mSatTransferred))
       -        self.assertEqual(bob_channel.total_msat_sent, 0, "bob satoshis sent incorrect %s vs %s expected"% (bob_channel.total_msat_sent, 0))
       -        self.assertEqual(bob_channel.l_current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 2))
       -        self.assertEqual(alice_channel.l_current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 2))
       +        self.assertEqual(alice_channel.total_msat[SENT], mSatTransferred, "alice satoshis sent incorrect")
       +        self.assertEqual(alice_channel.total_msat[RECEIVED], 0, "alice satoshis received incorrect")
       +        self.assertEqual(bob_channel.total_msat[RECEIVED], mSatTransferred, "bob satoshis received incorrect")
       +        self.assertEqual(bob_channel.total_msat[SENT], 0, "bob satoshis sent incorrect")
       +        self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height")
       +        self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height")
        
                # The logs of both sides should now be cleared since the entry adding
                # the HTLC should have been removed once both sides receive the
                # revocation.
       -        self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log))
       -        self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log))
       +        #self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log))
       +        #self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log))
        
            def alice_to_bob_fee_update(self):
                fee = 111
       t@@ -340,7 +342,7 @@ class TestLNHTLCDust(unittest.TestCase):
                alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
                force_state_transition(bob_channel, alice_channel)
                self.assertEqual(len(alice_channel.local_commitment.outputs()), 2)
       -        self.assertEqual(alice_channel.total_msat_sent // 1000, htlcAmt)
       +        self.assertEqual(alice_channel.total_msat[SENT] // 1000, htlcAmt)
        
        def force_state_transition(chanA, chanB):
            chanB.receive_new_commitment(*chanA.sign_next_commitment())