ttest_lnpeer: some clean-up, make it easier to add "num_node>2" tests - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 04d018cd0f5fafb8abd4b2771fcc1c6206f80279
 (DIR) parent 7951f2ed3b652e12dfaf7f9685d1428a2a9e4deb
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Wed,  6 May 2020 10:44:38 +0200
       
       ttest_lnpeer: some clean-up, make it easier to add "num_node>2" tests
       
       Diffstat:
         M electrum/coinchooser.py             |       4 ++--
         M electrum/tests/test_lnchannel.py    |      22 ++++++++++++++--------
         M electrum/tests/test_lnpeer.py       |      94 ++++++++++++++++++-------------
       
       3 files changed, 72 insertions(+), 48 deletions(-)
       ---
 (DIR) diff --git a/electrum/coinchooser.py b/electrum/coinchooser.py
       t@@ -44,12 +44,12 @@ class PRNG:
                self.sha = sha256(seed)
                self.pool = bytearray()
        
       -    def get_bytes(self, n):
       +    def get_bytes(self, n: int) -> bytes:
                while len(self.pool) < n:
                    self.pool.extend(self.sha)
                    self.sha = sha256(self.sha)
                result, self.pool = self.pool[:n], self.pool[n:]
       -        return result
       +        return bytes(result)
        
            def randint(self, start, end):
                # Returns random integer in [start, end)
 (DIR) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py
       t@@ -39,6 +39,7 @@ from electrum.ecc import sig_string_from_der_sig
        from electrum.logging import console_stderr_handler
        from electrum.lnchannel import ChannelState
        from electrum.json_db import StoredDict
       +from electrum.coinchooser import PRNG
        
        from . import ElectrumTestCase
        
       t@@ -110,8 +111,13 @@ def bip32(sequence):
            assert type(k) is bytes
            return k
        
       -def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None):
       -    funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii")
       +def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
       +                         alice_name="alice", bob_name="bob",
       +                         alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None):
       +    if random_seed is None:  # needed for deterministic randomness
       +        random_seed = os.urandom(32)
       +    random_gen = PRNG(random_seed)
       +    funding_txid = binascii.hexlify(random_gen.get_bytes(32)).decode("ascii")
            funding_index = 0
            funding_sat = ((local_msat + remote_msat) // 1000) if local_msat is not None and remote_msat is not None else (bitcoin.COIN * 10)
            local_amount = local_msat if local_msat is not None else (funding_sat * 1000 // 2)
       t@@ -123,20 +129,20 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None):
            alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
            bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
        
       -    alice_seed = b"\x01" * 32
       -    bob_seed = b"\x02" * 32
       +    alice_seed = random_gen.get_bytes(32)
       +    bob_seed = random_gen.get_bytes(32)
        
            alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
            bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
        
            alice, bob = (
                lnchannel.Channel(
       -            create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
       -            name="alice",
       +            create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
       +            name=bob_name,
                    initial_feerate=feerate),
                lnchannel.Channel(
       -            create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
       -            name="bob",
       +            create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
       +            name=alice_name,
                    initial_feerate=feerate)
            )
        
 (DIR) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -8,6 +8,7 @@ import logging
        import concurrent
        from concurrent import futures
        import unittest
       +from typing import Iterable
        
        from aiorpcx import TaskGroup
        
       t@@ -96,21 +97,23 @@ class MockWallet:
                return False
        
        class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
       -    def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
       +    def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue):
                Logger.__init__(self)
                NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
       -        self.remote_keypair = remote_keypair
                self.node_keypair = local_keypair
                self.network = MockNetwork(tx_queue)
       -        self._channels = {chan.channel_id: chan}
       +        self.channel_db = self.network.channel_db
       +        self._channels = {chan.channel_id: chan
       +                          for chan in chans}
                self.payments = {}
                self.logs = defaultdict(list)
                self.wallet = MockWallet()
                self.features = LnFeatures(0)
                self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
                self.pending_payments = defaultdict(asyncio.Future)
       -        chan.lnworker = self
       -        chan.node_id = remote_keypair.pubkey
       +        for chan in chans:
       +            chan.lnworker = self
       +        self._peers = {}  # bytes -> Peer
                # used in tests
                self.enable_htlc_settle = asyncio.Event()
                self.enable_htlc_settle.set()
       t@@ -130,13 +133,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
            def peers(self):
                return self._peers
        
       -    @property
       -    def _peers(self):
       -        return {self.remote_keypair.pubkey: self.peer}
       -
       -    def channels_for_peer(self, pubkey):
       -        return self._channels
       -
            def get_channel_by_short_id(self, short_channel_id):
                with self.lock:
                    for chan in self._channels.values():
       t@@ -171,6 +167,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
            get_first_timestamp = lambda self: 0
            on_peer_successfully_established = LNWallet.on_peer_successfully_established
            get_channel_by_id = LNWallet.get_channel_by_id
       +    channels_for_peer = LNWallet.channels_for_peer
       +    _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
       +    handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
        
        
        class MockTransport:
       t@@ -206,12 +205,16 @@ class PutIntoOthersQueueTransport(MockTransport):
                self.other_mock_transport.queue.put_nowait(data)
        
        def transport_pair(k1, k2, name1, name2):
       -    t1 = PutIntoOthersQueueTransport(k1, name1)
       -    t2 = PutIntoOthersQueueTransport(k2, name2)
       +    t1 = PutIntoOthersQueueTransport(k1, name2)
       +    t2 = PutIntoOthersQueueTransport(k2, name1)
            t1.other_mock_transport = t2
            t2.other_mock_transport = t1
            return t1, t2
        
       +
       +class PaymentDone(Exception): pass
       +
       +
        class TestPeer(ElectrumTestCase):
        
            @classmethod
       t@@ -230,14 +233,16 @@ class TestPeer(ElectrumTestCase):
        
            def prepare_peers(self, alice_channel, bob_channel):
                k1, k2 = keypair(), keypair()
       -        t1, t2 = transport_pair(k2, k1, alice_channel.name, bob_channel.name)
       +        alice_channel.node_id = k2.pubkey
       +        bob_channel.node_id = k1.pubkey
       +        t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
                q1, q2 = asyncio.Queue(), asyncio.Queue()
       -        w1 = MockLNWallet(k1, k2, alice_channel, tx_queue=q1)
       -        w2 = MockLNWallet(k2, k1, bob_channel, tx_queue=q2)
       -        p1 = Peer(w1, k1.pubkey, t1)
       -        p2 = Peer(w2, k2.pubkey, t2)
       -        w1.peer = p1
       -        w2.peer = p2
       +        w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1)
       +        w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2)
       +        p1 = Peer(w1, k2.pubkey, t1)
       +        p2 = Peer(w2, k1.pubkey, t2)
       +        w1._peers[p1.pubkey] = p1
       +        w2._peers[p2.pubkey] = p2
                # mark_open won't work if state is already OPEN.
                # so set it to FUNDED
                alice_channel._state = ChannelState.FUNDED
       t@@ -248,10 +253,11 @@ class TestPeer(ElectrumTestCase):
                return p1, p2, w1, w2, q1, q2
        
            @staticmethod
       -    def prepare_invoice(
       -            w2,  # receiver
       +    async def prepare_invoice(
       +            w2: MockLNWallet,  # receiver
                    *,
                    amount_sat=100_000,
       +            include_routing_hints=False,
            ):
                amount_btc = amount_sat/Decimal(COIN)
                payment_preimage = os.urandom(32)
       t@@ -259,12 +265,16 @@ class TestPeer(ElectrumTestCase):
                info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID)
                w2.save_preimage(RHASH, payment_preimage)
                w2.save_payment_info(info)
       +        if include_routing_hints:
       +            routing_hints = await w2._calc_routing_hints_for_invoice(amount_sat)
       +        else:
       +            routing_hints = []
                lnaddr = LnAddr(
                            paymenthash=RHASH,
                            amount=amount_btc,
                            tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
                                  ('d', 'coffee')
       -                         ])
       +                         ] + routing_hints)
                return lnencode(lnaddr, w2.node_keypair.privkey)
        
            def test_reestablish(self):
       t@@ -287,10 +297,11 @@ class TestPeer(ElectrumTestCase):
        
            @needs_test_with_all_chacha20_implementations
            def test_reestablish_with_old_state(self):
       -        alice_channel, bob_channel = create_test_channels()
       -        alice_channel_0, bob_channel_0 = create_test_channels() # these are identical
       +        random_seed = os.urandom(32)
       +        alice_channel, bob_channel = create_test_channels(random_seed=random_seed)
       +        alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed)  # these are identical
                p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
       -        pay_req = self.prepare_invoice(w2)
       +        pay_req = run(self.prepare_invoice(w2))
                async def pay():
                    result, log = await w1._pay(pay_req)
                    self.assertEqual(result, True)
       t@@ -323,15 +334,20 @@ class TestPeer(ElectrumTestCase):
            def test_payment(self):
                alice_channel, bob_channel = create_test_channels()
                p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
       -        pay_req = self.prepare_invoice(w2)
       -        async def pay():
       +        async def pay(pay_req):
                    result, log = await w1._pay(pay_req)
                    self.assertTrue(result)
       -            gath.cancel()
       -        gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
       +            raise PaymentDone()
                async def f():
       -            await gath
       -        with self.assertRaises(concurrent.futures.CancelledError):
       +            async with TaskGroup() as group:
       +                await group.spawn(p1._message_loop())
       +                await group.spawn(p1.htlc_switch())
       +                await group.spawn(p2._message_loop())
       +                await group.spawn(p2.htlc_switch())
       +                await asyncio.sleep(0.01)
       +                pay_req = await self.prepare_invoice(w2)
       +                await group.spawn(pay(pay_req))
       +        with self.assertRaises(PaymentDone):
                    run(f())
        
            #@unittest.skip("too expensive")
       t@@ -343,15 +359,17 @@ class TestPeer(ElectrumTestCase):
                bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
                num_payments = 50
                payment_value_sat = 10000  # make it large enough so that there are actually HTLCs on the ctx
       -        #pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)]
       -        pay_reqs2 = [self.prepare_invoice(w2, amount_sat=payment_value_sat) for i in range(num_payments)]
                max_htlcs_in_flight = asyncio.Semaphore(5)
                async def single_payment(pay_req):
                    async with max_htlcs_in_flight:
                        await w1._pay(pay_req)
                async def many_payments():
                    async with TaskGroup() as group:
       -                for pay_req in pay_reqs2:
       +                pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_sat=payment_value_sat))
       +                                  for i in range(num_payments)]
       +            async with TaskGroup() as group:
       +                for pay_req_task in pay_reqs_tasks:
       +                    pay_req = pay_req_task.result()
                            await group.spawn(single_payment(pay_req))
                    gath.cancel()
                gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
       t@@ -373,7 +391,7 @@ class TestPeer(ElectrumTestCase):
                w1.network.config.set_key('fee_per_kb', 5000)
                w2.network.config.set_key('fee_per_kb', 1000)
                w2.enable_htlc_settle.clear()
       -        pay_req = self.prepare_invoice(w2)
       +        pay_req = run(self.prepare_invoice(w2))
                lnaddr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
                async def pay():
                    await asyncio.wait_for(p1.initialized, 1)
       t@@ -401,7 +419,7 @@ class TestPeer(ElectrumTestCase):
            def test_channel_usage_after_closing(self):
                alice_channel, bob_channel = create_test_channels()
                p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
       -        pay_req = self.prepare_invoice(w2)
       +        pay_req = run(self.prepare_invoice(w2))
        
                addr = w1._check_invoice(pay_req)
                route = w1._create_route_from_invoice(decoded_invoice=addr)