tlnpeer: add unit test for upfront shutdown script - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit e555ea650ef49ddd1bd9f9144cac7237261c1c8c
 (DIR) parent 673f89f0d2d7b0c5563361a41195d29d68e477d8
 (HTM) Author: bitromortac <bitromortac@protonmail.com>
       Date:   Tue, 29 Dec 2020 17:40:01 +0100
       
       lnpeer: add unit test for upfront shutdown script
       
       Diffstat:
         M electrum/tests/test_lnchannel.py    |      30 +++++++++++++++++++++++-------
         M electrum/tests/test_lnpeer.py       |      76 ++++++++++++++++++++++++++++++-
       
       2 files changed, 97 insertions(+), 9 deletions(-)
       ---
 (DIR) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py
       t@@ -34,8 +34,6 @@ from electrum import lnchannel
        from electrum import lnutil
        from electrum import bip32 as bip32_utils
        from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
       -from electrum.lnutil import FeeUpdate
       -from electrum.ecc import sig_string_from_der_sig
        from electrum.logging import console_stderr_handler
        from electrum.lnchannel import ChannelState
        from electrum.json_db import StoredDict
       t@@ -46,7 +44,11 @@ from . import ElectrumTestCase
        
        one_bitcoin_in_msat = bitcoin.COIN * 1000
        
       -def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, r_csv):
       +
       +def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
       +                         local_amount, remote_amount, privkeys, other_pubkeys,
       +                         seed, cur, nex, other_node_id, l_dust, r_dust, l_csv,
       +                         r_csv):
            assert local_amount > 0
            assert remote_amount > 0
            channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index)
       t@@ -134,16 +136,30 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
            alice_seed = random_gen.get_bytes(32)
            bob_seed = random_gen.get_bytes(32)
        
       -    alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
       -    bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
       +    alice_first = lnutil.secret_to_pubkey(
       +        int.from_bytes(lnutil.get_per_commitment_secret_from_seed(
       +            alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
       +    bob_first = lnutil.secret_to_pubkey(
       +        int.from_bytes(lnutil.get_per_commitment_secret_from_seed(
       +            bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
        
            alice, bob = (
                lnchannel.Channel(
       -            create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
       +            create_channel_state(
       +                funding_txid, funding_index, funding_sat, True, local_amount,
       +                remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None,
       +                bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300,
       +                l_csv=5, r_csv=4
       +            ),
                    name=bob_name,
                    initial_feerate=feerate),
                lnchannel.Channel(
       -            create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
       +            create_channel_state(
       +                funding_txid, funding_index, funding_sat, False, remote_amount,
       +                local_amount, bob_privkeys, alice_pubkeys, bob_seed, None,
       +                alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200,
       +                l_csv=4, r_csv=5
       +            ),
                    name=alice_name,
                    initial_feerate=feerate)
            )
 (DIR) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -12,14 +12,15 @@ from typing import Iterable, NamedTuple
        
        from aiorpcx import TaskGroup
        
       +from electrum import bitcoin
        from electrum import constants
        from electrum.network import Network
        from electrum.ecc import ECPrivkey
        from electrum import simple_config, lnutil
        from electrum.lnaddr import lnencode, LnAddr, lndecode
        from electrum.bitcoin import COIN, sha256
       -from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager
       -from electrum.lnpeer import Peer
       +from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh
       +from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation
        from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
        from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
        from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
       t@@ -121,6 +122,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
                self.wallet = MockWallet()
                self.features = LnFeatures(0)
                self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
       +        self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
                self.pending_payments = defaultdict(asyncio.Future)
                for chan in chans:
                    chan.lnworker = self
       t@@ -605,6 +607,76 @@ class TestPeer(ElectrumTestCase):
                with self.assertRaises(concurrent.futures.CancelledError):
                    run(f())
        
       +    @needs_test_with_all_chacha20_implementations
       +    def test_close_upfront_shutdown_script(self):
       +        alice_channel, bob_channel = create_test_channels()
       +
       +        # create upfront shutdown script for bob, alice doesn't use upfront
       +        # shutdown script
       +        bob_uss_pub = lnutil.privkey_to_pubkey(os.urandom(32))
       +        bob_uss_addr = bitcoin.pubkey_to_address('p2wpkh', bh2u(bob_uss_pub))
       +        bob_uss = bfh(bitcoin.address_to_script(bob_uss_addr))
       +
       +        # bob commits to close to bob_uss
       +        alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss
       +        # but bob closes to some receiving address, which we achieve by not
       +        # setting the upfront shutdown script in the channel config
       +        bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = b''
       +
       +        p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
       +        w1.network.config.set_key('dynamic_fees', False)
       +        w2.network.config.set_key('dynamic_fees', False)
       +        w1.network.config.set_key('fee_per_kb', 5000)
       +        w2.network.config.set_key('fee_per_kb', 1000)
       +
       +        async def test():
       +            async def close():
       +                await asyncio.wait_for(p1.initialized, 1)
       +                await asyncio.wait_for(p2.initialized, 1)
       +                # bob closes channel with different shutdown script
       +                await p1.close_channel(alice_channel.channel_id)
       +                gath.cancel()
       +
       +            async def main_loop(peer):
       +                    async with peer.taskgroup as group:
       +                        await group.spawn(peer._message_loop())
       +                        await group.spawn(peer.htlc_switch())
       +
       +            coros = [close(), main_loop(p1), main_loop(p2)]
       +            gath = asyncio.gather(*coros)
       +            await gath
       +
       +        with self.assertRaises(UpfrontShutdownScriptViolation):
       +            run(test())
       +
       +        # bob sends the same upfront_shutdown_script has he announced
       +        alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss
       +        bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = bob_uss
       +
       +        p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
       +        w1.network.config.set_key('dynamic_fees', False)
       +        w2.network.config.set_key('dynamic_fees', False)
       +        w1.network.config.set_key('fee_per_kb', 5000)
       +        w2.network.config.set_key('fee_per_kb', 1000)
       +
       +        async def test():
       +            async def close():
       +                await asyncio.wait_for(p1.initialized, 1)
       +                await asyncio.wait_for(p2.initialized, 1)
       +                await p1.close_channel(alice_channel.channel_id)
       +                gath.cancel()
       +
       +            async def main_loop(peer):
       +                async with peer.taskgroup as group:
       +                    await group.spawn(peer._message_loop())
       +                    await group.spawn(peer.htlc_switch())
       +
       +            coros = [close(), main_loop(p1), main_loop(p2)]
       +            gath = asyncio.gather(*coros)
       +            await gath
       +        with self.assertRaises(asyncio.CancelledError):
       +            run(test())
       +
            def test_channel_usage_after_closing(self):
                alice_channel, bob_channel = create_test_channels()
                p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)