tMerge pull request #6003 from spesmilo/htlc_switch - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 367d30d6c0585e9e270ad312c5f79105d3d77dc5
 (DIR) parent c81335fb44e78ae4a2c7ef963e65384b18ca18d6
 (HTM) Author: ThomasV <thomasv@electrum.org>
       Date:   Mon,  2 Mar 2020 22:14:09 +0100
       
       Merge pull request #6003 from spesmilo/htlc_switch
       
       Htlc switch
       Diffstat:
         M electrum/lnchannel.py               |       7 ++++++-
         M electrum/lnhtlc.py                  |       1 +
         M electrum/lnpeer.py                  |     232 +++++++++++++++----------------
         M electrum/lnworker.py                |       2 +-
         M electrum/tests/test_lnpeer.py       |      12 ++++++------
       
       5 files changed, 126 insertions(+), 128 deletions(-)
       ---
 (DIR) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
       t@@ -408,7 +408,7 @@ class Channel(Logger):
                self.logger.info("add_htlc")
                return htlc
        
       -    def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
       +    def receive_htlc(self, htlc: UpdateAddHtlc, onion_packet:bytes = None) -> UpdateAddHtlc:
                """
                ReceiveHTLC adds an HTLC to the state machine's remote update log. This
                method should be called in response to receiving a new HTLC from the remote
       t@@ -427,6 +427,11 @@ class Channel(Logger):
                            f' HTLC amount: {htlc.amount_msat}')
                with self.db_lock:
                    self.hm.recv_htlc(htlc)
       +            local_ctn = self.get_latest_ctn(LOCAL)
       +            remote_ctn = self.get_latest_ctn(REMOTE)
       +            if onion_packet:
       +                self.hm.log['unfulfilled_htlcs'][htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex(), False
       +
                self.logger.info("receive_htlc")
                return htlc
        
 (DIR) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
       t@@ -25,6 +25,7 @@ class HTLCManager:
                    log[LOCAL] = deepcopy(initial)
                    log[REMOTE] = deepcopy(initial)
                    log['unacked_local_updates2'] = {}
       +            log['unfulfilled_htlcs'] = {}  # htlc_id -> onion_packet
        
                # maybe bootstrap fee_updates if initial_feerate was provided
                if initial_feerate is not None:
 (DIR) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
       t@@ -249,6 +249,7 @@ class Peer(Logger):
            async def main_loop(self):
                async with self.taskgroup as group:
                    await group.spawn(self._message_loop())
       +            await group.spawn(self.htlc_switch())
                    await group.spawn(self.query_gossip())
                    await group.spawn(self.process_gossip())
        
       t@@ -1131,195 +1132,137 @@ class Peer(Logger):
                self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
                cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big')
                amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big')
       -        onion_packet = OnionPacket.from_bytes(payload["onion_routing_packet"])
       -        processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey)
       +        onion_packet = payload["onion_routing_packet"]
                if chan.get_state() != channel_states.OPEN:
                    raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}")
                if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX:
                    asyncio.ensure_future(self.lnworker.force_close_channel(channel_id))
                    raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}")
                # add htlc
       -        htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc,
       -                             payment_hash=payment_hash,
       -                             cltv_expiry=cltv_expiry,
       -                             timestamp=int(time.time()),
       -                             htlc_id=htlc_id)
       -        htlc = chan.receive_htlc(htlc)
       -        # TODO: fulfilling/failing/forwarding of htlcs should be robust to going offline.
       -        #       instead of storing state implicitly in coroutines, we could decouple it from receiving the htlc.
       -        #       maybe persist the required details, and have a long-running task that makes these decisions.
       -        local_ctn = chan.get_latest_ctn(LOCAL)
       -        remote_ctn = chan.get_latest_ctn(REMOTE)
       -        if processed_onion.are_we_final:
       -            asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan,
       -                                                           htlc=htlc,
       -                                                           local_ctn=local_ctn,
       -                                                           remote_ctn=remote_ctn,
       -                                                           onion_packet=onion_packet,
       -                                                           processed_onion=processed_onion))
       -        else:
       -            asyncio.ensure_future(self._maybe_forward_htlc(chan=chan,
       -                                                           htlc=htlc,
       -                                                           local_ctn=local_ctn,
       -                                                           remote_ctn=remote_ctn,
       -                                                           onion_packet=onion_packet,
       -                                                           processed_onion=processed_onion))
       -
       -    @log_exceptions
       -    async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
       -                                  onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
       -        await self.await_local(chan, local_ctn)
       -        await self.await_remote(chan, remote_ctn)
       +        htlc = UpdateAddHtlc(
       +            amount_msat=amount_msat_htlc,
       +            payment_hash=payment_hash,
       +            cltv_expiry=cltv_expiry,
       +            timestamp=int(time.time()),
       +            htlc_id=htlc_id)
       +        chan.receive_htlc(htlc, onion_packet)
       +
       +    def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
       +                           onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
                # Forward HTLC
       -        # FIXME: this is not robust to us going offline before payment is fulfilled
                # FIXME: there are critical safety checks MISSING here
                forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
                if not forwarding_enabled:
                    self.logger.info(f"forwarding is disabled. failing htlc.")
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
                dph = processed_onion.hop_data.per_hop
                next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id)
                next_chan_scid = dph.short_channel_id
       -        next_peer = self.lnworker.peers[next_chan.node_id]
                local_height = self.network.get_local_height()
                if next_chan is None:
                    self.logger.info(f"cannot forward htlc. cannot find next_chan {next_chan_scid}")
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
                outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
                outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
                if not next_chan.can_send_update_add_htlc():
                    self.logger.info(f"cannot forward htlc. next_chan {next_chan_scid} cannot send ctx updates. "
                                     f"chan state {next_chan.get_state()}, peer state: {next_chan.peer_state}")
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE,
       -                                                data=outgoing_chan_upd_len+outgoing_chan_upd)
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            data = outgoing_chan_upd_len + outgoing_chan_upd
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data)
                next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big')
                if htlc.cltv_expiry - next_cltv_expiry < NBLOCK_OUR_CLTV_EXPIRY_DELTA:
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY,
       -                                                data=(htlc.cltv_expiry.to_bytes(4, byteorder="big")
       -                                                      + outgoing_chan_upd_len + outgoing_chan_upd))
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            data = htlc.cltv_expiry.to_bytes(4, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_CLTV_EXPIRY, data=data)
                if htlc.cltv_expiry - lnutil.NBLOCK_DEADLINE_BEFORE_EXPIRY_FOR_RECEIVED_HTLCS <= local_height \
                        or next_cltv_expiry <= local_height:
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_SOON,
       -                                                data=outgoing_chan_upd_len+outgoing_chan_upd)
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            data = outgoing_chan_upd_len + outgoing_chan_upd
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_SOON, data=data)
                if max(htlc.cltv_expiry, next_cltv_expiry) > local_height + lnutil.NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.EXPIRY_TOO_FAR, data=b'')
                next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big')
       -        forwarding_fees = fee_for_edge_msat(forwarded_amount_msat=next_amount_msat_htlc,
       -                                            fee_base_msat=lnutil.OUR_FEE_BASE_MSAT,
       -                                            fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS)
       +        forwarding_fees = fee_for_edge_msat(
       +            forwarded_amount_msat=next_amount_msat_htlc,
       +            fee_base_msat=lnutil.OUR_FEE_BASE_MSAT,
       +            fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS)
                if htlc.amount_msat - next_amount_msat_htlc < forwarding_fees:
       -            reason = OnionRoutingFailureMessage(code=OnionFailureCode.FEE_INSUFFICIENT,
       -                                                data=(next_amount_msat_htlc.to_bytes(8, byteorder="big")
       -                                                      + outgoing_chan_upd_len + outgoing_chan_upd))
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       -
       +            data = next_amount_msat_htlc.to_bytes(8, byteorder="big") + outgoing_chan_upd_len + outgoing_chan_upd
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.FEE_INSUFFICIENT, data=data)
                self.logger.info(f'forwarding htlc to {next_chan.node_id}')
       -        next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry, timestamp=int(time.time()))
       +        next_htlc = UpdateAddHtlc(
       +            amount_msat=next_amount_msat_htlc,
       +            payment_hash=htlc.payment_hash,
       +            cltv_expiry=next_cltv_expiry,
       +            timestamp=int(time.time()))
                next_htlc = next_chan.add_htlc(next_htlc)
       -        next_remote_ctn = next_chan.get_latest_ctn(REMOTE)
       -        next_peer.send_message(
       -            "update_add_htlc",
       -            channel_id=next_chan.channel_id,
       -            id=next_htlc.htlc_id,
       -            cltv_expiry=dph.outgoing_cltv_value,
       -            amount_msat=dph.amt_to_forward,
       -            payment_hash=next_htlc.payment_hash,
       -            onion_routing_packet=processed_onion.next_packet.to_bytes()
       -        )
       -        await next_peer.await_remote(next_chan, next_remote_ctn)
       -        success, preimage, reason = await self.lnworker.await_payment(next_htlc.payment_hash)
       -        if success:
       -            await self._fulfill_htlc(chan, htlc.htlc_id, preimage)
       -            self.logger.info("htlc forwarded successfully")
       -        else:
       -            # TODO: test this
       -            self.logger.info(f"forwarded htlc has failed, {reason}")
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -
       -    @log_exceptions
       -    async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
       -                                  onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
       -        await self.await_local(chan, local_ctn)
       -        await self.await_remote(chan, remote_ctn)
       +        next_peer = self.lnworker.peers[next_chan.node_id]
       +        try:
       +            next_peer.send_message(
       +                "update_add_htlc",
       +                channel_id=next_chan.channel_id,
       +                id=next_htlc.htlc_id,
       +                cltv_expiry=dph.outgoing_cltv_value,
       +                amount_msat=dph.amt_to_forward,
       +                payment_hash=next_htlc.payment_hash,
       +                onion_routing_packet=processed_onion.next_packet.to_bytes()
       +            )
       +        except BaseException as e:
       +            self.logger.info(f"failed to forward htlc: error sending message. {e}")
       +            data = outgoing_chan_upd_len + outgoing_chan_upd
       +            return OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=data)
       +        return None
       +
       +    def maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
       +                          onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
                try:
                    info = self.lnworker.get_payment_info(htlc.payment_hash)
                    preimage = self.lnworker.get_preimage(htlc.payment_hash)
                except UnknownPaymentHash:
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return False, reason
                expected_received_msat = int(info.amount * 1000) if info.amount is not None else None
                if expected_received_msat is not None and \
                        not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return False, reason
                local_height = self.network.get_local_height()
                if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > htlc.cltv_expiry:
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_EXPIRY_TOO_SOON, data=b'')
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return False, reason
                cltv_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.outgoing_cltv_value, byteorder="big")
                if cltv_from_onion != htlc.cltv_expiry:
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
                                                        data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       +            return False, reason
                amount_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.amt_to_forward, byteorder="big")
                if amount_from_onion > htlc.amount_msat:
                    reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
                                                        data=htlc.amount_msat.to_bytes(8, byteorder="big"))
       -            await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
       -            return
       -        #self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED)
       -        await self.lnworker.enable_htlc_settle.wait()
       -        await self._fulfill_htlc(chan, htlc.htlc_id, preimage)
       +            return False, reason
       +        # all good
       +        return preimage, None
        
       -    async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
       +    def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
                self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
       -        if not chan.can_send_ctx_updates():
       -            self.logger.info(f"dropping chan update (fulfill htlc {htlc_id}) for {chan.short_channel_id}. "
       -                             f"cannot send updates")
       -            return
       +        assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
                chan.settle_htlc(preimage, htlc_id)
                payment_hash = sha256(preimage)
                self.lnworker.payment_received(payment_hash)
       -        remote_ctn = chan.get_latest_ctn(REMOTE)
                self.send_message("update_fulfill_htlc",
                                  channel_id=chan.channel_id,
                                  id=htlc_id,
                                  payment_preimage=preimage)
       -        await self.await_remote(chan, remote_ctn)
        
       -    async def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket,
       -                        reason: OnionRoutingFailureMessage):
       +    def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket,
       +                  reason: OnionRoutingFailureMessage):
                self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}. reason: {reason}")
       -        if not chan.can_send_ctx_updates():
       -            self.logger.info(f"dropping chan update (fail htlc {htlc_id}) for {chan.short_channel_id}. "
       -                             f"cannot send updates")
       -            return
       +        assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
                chan.fail_htlc(htlc_id)
       -        remote_ctn = chan.get_latest_ctn(REMOTE)
                error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey)
                self.send_message("update_fail_htlc",
                                  channel_id=chan.channel_id,
                                  id=htlc_id,
                                  len=len(error_packet),
                                  reason=error_packet)
       -        await self.await_remote(chan, remote_ctn)
        
            def on_revoke_and_ack(self, payload):
                channel_id = payload["channel_id"]
       t@@ -1484,3 +1427,52 @@ class Peer(Logger):
                # broadcast
                await self.network.try_broadcasting(closing_tx, 'closing')
                return closing_tx.txid()
       +
       +    async def htlc_switch(self):
       +        while True:
       +            await asyncio.sleep(0.1)
       +            for chan_id, chan in self.channels.items():
       +                if not chan.can_send_ctx_updates():
       +                    continue
       +                self.maybe_send_commitment(chan)
       +                done = set()
       +                unfulfilled = chan.hm.log.get('unfulfilled_htlcs', {})
       +                for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarded) in unfulfilled.items():
       +                    if chan.get_oldest_unrevoked_ctn(LOCAL) <= local_ctn:
       +                        continue
       +                    if chan.get_oldest_unrevoked_ctn(REMOTE) <= remote_ctn:
       +                        continue
       +                    chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
       +                    onion_packet = OnionPacket.from_bytes(bytes.fromhex(onion_packet_hex))
       +                    htlc = chan.hm.log[REMOTE]['adds'][htlc_id]
       +                    payment_hash = htlc.payment_hash
       +                    processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey)
       +                    preimage, error = None, None
       +                    if processed_onion.are_we_final:
       +                        preimage, error = self.maybe_fulfill_htlc(
       +                            chan=chan,
       +                            htlc=htlc,
       +                            onion_packet=onion_packet,
       +                            processed_onion=processed_onion)
       +                    elif not forwarded:
       +                        error = self.maybe_forward_htlc(
       +                            chan=chan,
       +                            htlc=htlc,
       +                            onion_packet=onion_packet,
       +                            processed_onion=processed_onion)
       +                        if not error:
       +                            unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, True
       +                    else:
       +                        f = self.lnworker.pending_payments[payment_hash]
       +                        if f.done():
       +                            success, preimage, error = f.result()
       +                    if preimage:
       +                        await self.lnworker.enable_htlc_settle.wait()
       +                        self.fulfill_htlc(chan, htlc.htlc_id, preimage)
       +                        done.add(htlc_id)
       +                    if error:
       +                        self.fail_htlc(chan, htlc.htlc_id, onion_packet, error)
       +                        done.add(htlc_id)
       +                # cleanup
       +                for htlc_id in done:
       +                    unfulfilled.pop(htlc_id)
 (DIR) diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -55,7 +55,7 @@ from .lnutil import (Outpoint, LNPeerAddr,
                             ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails)
        from .lnutil import ln_dummy_address, ln_compare_features
        from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
       -from .lnonion import OnionFailureCode
       +from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket
        from .lnmsg import decode_msg
        from .i18n import _
        from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use
 (DIR) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
       t@@ -238,7 +238,7 @@ class TestPeer(ElectrumTestCase):
                    self.assertEqual(alice_channel.peer_state, peer_states.GOOD)
                    self.assertEqual(bob_channel.peer_state, peer_states.GOOD)
                    gath.cancel()
       -        gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop())
       +        gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p1.htlc_switch())
                async def f():
                    await gath
                with self.assertRaises(concurrent.futures.CancelledError):
       t@@ -253,7 +253,7 @@ class TestPeer(ElectrumTestCase):
                    result = await LNWallet._pay(w1, pay_req)
                    self.assertEqual(result, True)
                    gath.cancel()
       -        gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
       +        gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                async def f():
                    await gath
                with self.assertRaises(concurrent.futures.CancelledError):
       t@@ -271,7 +271,7 @@ class TestPeer(ElectrumTestCase):
                    # wait so that pending messages are processed
                    #await asyncio.sleep(1)
                    gath.cancel()
       -        gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop())
       +        gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                async def f():
                    await gath
                with self.assertRaises(concurrent.futures.CancelledError):
       t@@ -285,7 +285,7 @@ class TestPeer(ElectrumTestCase):
                    result = await LNWallet._pay(w1, pay_req)
                    self.assertTrue(result)
                    gath.cancel()
       -        gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
       +        gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                async def f():
                    await gath
                with self.assertRaises(concurrent.futures.CancelledError):
       t@@ -313,7 +313,7 @@ class TestPeer(ElectrumTestCase):
                async def set_settle():
                    await asyncio.sleep(0.1)
                    w2.enable_htlc_settle.set()
       -        gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop())
       +        gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                async def f():
                    await gath
                with self.assertRaises(concurrent.futures.CancelledError):
       t@@ -338,7 +338,7 @@ class TestPeer(ElectrumTestCase):
                # AssertionError is ok since we shouldn't use old routes, and the
                # route finding should fail when channel is closed
                async def f():
       -            await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop())
       +            await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
                with self.assertRaises(PaymentFailure):
                    run(f())