tdaemon/wallet/network: make stop() methods async - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 3c019c2f9c4d2fdefe52d84444632b54d421e140
 (DIR) parent ce88b36e81a533810b08ccf8796120951265da8a
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Tue,  9 Mar 2021 17:52:36 +0100
       
       daemon/wallet/network: make stop() methods async
       
       Diffstat:
         M electrum/address_synchronizer.py    |      19 ++++++++++++-------
         M electrum/daemon.py                  |      43 ++++++++++++++++++++-----------
         M electrum/gui/__init__.py            |       6 ++++++
         M electrum/gui/kivy/main_window.py    |       3 ++-
         M electrum/gui/qt/settings_dialog.py  |       3 ++-
         M electrum/interface.py               |       2 +-
         M electrum/lnwatcher.py               |       4 ++--
         M electrum/lnworker.py                |      12 ++++++------
         M electrum/network.py                 |      31 +++++++++++++++----------------
         M electrum/sql_db.py                  |       3 +++
         M electrum/tests/test_storage_upgradā€¦ |      15 +++++++++++++--
         M electrum/tests/test_wallet.py       |      15 +++++++++++++--
         M electrum/util.py                    |      10 ++++------
         M electrum/wallet.py                  |      26 ++++++++++++++++----------
         M run_electrum                        |       1 -
       
       15 files changed, 123 insertions(+), 70 deletions(-)
       ---
 (DIR) diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py
       t@@ -28,6 +28,8 @@ import itertools
        from collections import defaultdict
        from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
        
       +from aiorpcx import TaskGroup
       +
        from . import bitcoin, util
        from .bitcoin import COINBASE_MATURITY
        from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException
       t@@ -197,16 +199,19 @@ class AddressSynchronizer(Logger):
            def on_blockchain_updated(self, event, *args):
                self._get_addr_balance_cache = {}  # invalidate cache
        
       -    def stop(self):
       +    async def stop(self):
                if self.network:
       -            if self.synchronizer:
       -                asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop)
       +            try:
       +                async with TaskGroup() as group:
       +                    if self.synchronizer:
       +                        await group.spawn(self.synchronizer.stop())
       +                    if self.verifier:
       +                        await group.spawn(self.verifier.stop())
       +            finally:  # even if we get cancelled
                        self.synchronizer = None
       -            if self.verifier:
       -                asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
                        self.verifier = None
       -            util.unregister_callback(self.on_blockchain_updated)
       -            self.db.put('stored_height', self.get_local_height())
       +                util.unregister_callback(self.on_blockchain_updated)
       +                self.db.put('stored_height', self.get_local_height())
        
            def add_address(self, address):
                if not self.db.get_addr_history(address):
 (DIR) diff --git a/electrum/daemon.py b/electrum/daemon.py
       t@@ -29,7 +29,7 @@ import time
        import traceback
        import sys
        import threading
       -from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping
       +from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping, TYPE_CHECKING
        from base64 import b64decode, b64encode
        from collections import defaultdict
        import concurrent
       t@@ -38,7 +38,7 @@ import json
        
        import aiohttp
        from aiohttp import web, client_exceptions
       -from aiorpcx import TaskGroup
       +from aiorpcx import TaskGroup, timeout_after, TaskTimeout
        
        from . import util
        from .network import Network
       t@@ -53,6 +53,9 @@ from .simple_config import SimpleConfig
        from .exchange_rate import FxThread
        from .logging import get_logger, Logger
        
       +if TYPE_CHECKING:
       +    from electrum import gui
       +
        
        _logger = get_logger(__name__)
        
       t@@ -407,6 +410,7 @@ class PayServer(Logger):
        class Daemon(Logger):
        
            network: Optional[Network]
       +    gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]
        
            @profiler
            def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
       t@@ -523,7 +527,8 @@ class Daemon(Logger):
                wallet = self._wallets.pop(path, None)
                if not wallet:
                    return False
       -        wallet.stop()
       +        fut = asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop)
       +        fut.result()
                return True
        
            def run_daemon(self):
       t@@ -544,20 +549,28 @@ class Daemon(Logger):
                    self.running = False
        
            def on_stop(self):
       +        self.logger.info("on_stop() entered. initiating shutdown")
                if self.gui_object:
                    self.gui_object.stop()
       -        # stop network/wallets
       -        for k, wallet in self._wallets.items():
       -            wallet.stop()
       -        if self.network:
       -            self.logger.info("shutting down network")
       -            self.network.stop()
       -        self.logger.info("stopping taskgroup")
       -        fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop)
       -        try:
       -            fut.result(timeout=2)
       -        except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError):
       -            pass
       +
       +        @log_exceptions
       +        async def stop_async():
       +            self.logger.info("stopping all wallets")
       +            async with TaskGroup() as group:
       +                for k, wallet in self._wallets.items():
       +                    await group.spawn(wallet.stop())
       +            self.logger.info("stopping network and taskgroup")
       +            try:
       +                async with timeout_after(2):
       +                    async with TaskGroup() as group:
       +                        if self.network:
       +                            await group.spawn(self.network.stop(full_shutdown=True))
       +                        await group.spawn(self.taskgroup.cancel_remaining())
       +            except TaskTimeout:
       +                pass
       +
       +        fut = asyncio.run_coroutine_threadsafe(stop_async(), self.asyncio_loop)
       +        fut.result()
                self.logger.info("removing lockfile")
                remove_lockfile(get_lockfile(self.config))
                self.logger.info("stopped")
 (DIR) diff --git a/electrum/gui/__init__.py b/electrum/gui/__init__.py
       t@@ -3,3 +3,9 @@
        # The Wallet object is instantiated by the GUI
        
        # Notifications about network events are sent to the GUI by using network.register_callback()
       +
       +from typing import TYPE_CHECKING
       +
       +if TYPE_CHECKING:
       +    from . import qt
       +    from . import kivy
 (DIR) diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py
       t@@ -190,7 +190,8 @@ class ElectrumWindow(App, Logger):
                if self.use_gossip:
                    self.network.start_gossip()
                else:
       -            self.network.stop_gossip()
       +            self.network.run_from_another_thread(
       +                self.network.stop_gossip())
        
            android_backups = BooleanProperty(False)
            def on_android_backups(self, instance, x):
 (DIR) diff --git a/electrum/gui/qt/settings_dialog.py b/electrum/gui/qt/settings_dialog.py
       t@@ -141,7 +141,8 @@ channels graph and compute payment path locally, instead of using trampoline pay
                    if use_gossip:
                        self.window.network.start_gossip()
                    else:
       -                self.window.network.stop_gossip()
       +                self.window.network.run_from_another_thread(
       +                    self.window.network.stop_gossip())
                    util.trigger_callback('ln_gossip_sync_progress')
                    # FIXME: update all wallet windows
                    util.trigger_callback('channels_updated', self.wallet)
 (DIR) diff --git a/electrum/interface.py b/electrum/interface.py
       t@@ -695,7 +695,7 @@ class Interface(Logger):
                    # We give up after a while and just abort the connection.
                    # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
                    #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
       -            force_after = 2  # seconds
       +            force_after = 1  # seconds
                if self.session:
                    await self.session.close(force_after=force_after)
                # monitor_connection will cancel tasks
 (DIR) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
       t@@ -147,8 +147,8 @@ class LNWatcher(AddressSynchronizer):
                # status gets populated when we run
                self.channel_status = {}
        
       -    def stop(self):
       -        super().stop()
       +    async def stop(self):
       +        await super().stop()
                util.unregister_callback(self.on_network_update)
        
            def get_channel_status(self, outpoint):
 (DIR) diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -311,11 +311,11 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
                self._add_peers_from_config()
                asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
        
       -    def stop(self):
       +    async def stop(self):
                if self.listen_server:
       -            self.network.asyncio_loop.call_soon_threadsafe(self.listen_server.close)
       -        asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.network.asyncio_loop)
       +            self.listen_server.close()
                util.unregister_callback(self.on_proxy_changed)
       +        await self.taskgroup.cancel_remaining()
        
            def _add_peers_from_config(self):
                peer_list = self.config.get('lightning_peers', [])
       t@@ -704,9 +704,9 @@ class LNWallet(LNWorker):
                    tg_coro = self.taskgroup.spawn(coro)
                    asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
        
       -    def stop(self):
       -        super().stop()
       -        self.lnwatcher.stop()
       +    async def stop(self):
       +        await super().stop()
       +        await self.lnwatcher.stop()
                self.lnwatcher = None
        
            def peer_closed(self, peer):
 (DIR) diff --git a/electrum/network.py b/electrum/network.py
       t@@ -252,6 +252,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
            default_server: ServerAddr
            _recent_servers: List[ServerAddr]
        
       +    channel_blacklist: 'ChannelBlackList'
       +    channel_db: Optional['ChannelDB'] = None
       +    lngossip: Optional['LNGossip'] = None
       +    local_watchtower: Optional['WatchTower'] = None
       +
            def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
                global _INSTANCE
                assert _INSTANCE is None, "Network is a singleton!"
       t@@ -344,9 +349,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
        
                # lightning network
                self.channel_blacklist = ChannelBlackList()
       -        self.channel_db = None  # type: Optional[ChannelDB]
       -        self.lngossip = None  # type: Optional[LNGossip]
       -        self.local_watchtower = None  # type: Optional[WatchTower]
                if self.config.get('run_local_watchtower', False):
                    from . import lnwatcher
                    self.local_watchtower = lnwatcher.WatchTower(self)
       t@@ -373,11 +375,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                    self.lngossip = lnworker.LNGossip()
                    self.lngossip.start_network(self)
        
       -    def stop_gossip(self):
       +    async def stop_gossip(self, *, full_shutdown: bool = False):
                if self.lngossip:
       -            self.lngossip.stop()
       +            await self.lngossip.stop()
                    self.lngossip = None
                    self.channel_db.stop()
       +            if full_shutdown:
       +                await self.channel_db.stopped_event.wait()
                    self.channel_db = None
        
            def run_from_another_thread(self, coro, *, timeout=None):
       t@@ -623,7 +627,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                    self.auto_connect = net_params.auto_connect
                    if self.proxy != proxy or self.oneserver != net_params.oneserver:
                        # Restart the network defaulting to the given server
       -                await self._stop()
       +                await self.stop(full_shutdown=False)
                        self.default_server = server
                        await self._start()
                    elif self.default_server != server:
       t@@ -1217,13 +1221,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                asyncio.run_coroutine_threadsafe(self._start(), self.asyncio_loop)
        
            @log_exceptions
       -    async def _stop(self, full_shutdown=False):
       +    async def stop(self, *, full_shutdown: bool = True):
                self.logger.info("stopping network")
                try:
                    # note: cancel_remaining ~cannot be cancelled, it suppresses CancelledError
       -            await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
       +            await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=1)
                except (asyncio.TimeoutError, asyncio.CancelledError) as e:
       -            self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
       +            self.logger.info(f"exc during taskgroup cancellation: {repr(e)}")
                self.taskgroup = None
                self.interface = None
                self.interfaces = {}
       t@@ -1231,13 +1235,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
                self._closing_ifaces.clear()
                if not full_shutdown:
                    util.trigger_callback('network_updated')
       -
       -    def stop(self):
       -        assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
       -        fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
       -        try:
       -            fut.result(timeout=2)
       -        except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError): pass
       +        if full_shutdown:
       +            await self.stop_gossip(full_shutdown=full_shutdown)
        
            async def _ensure_there_is_a_main_interface(self):
                if self.is_connected():
 (DIR) diff --git a/electrum/sql_db.py b/electrum/sql_db.py
       t@@ -25,6 +25,7 @@ class SqlDB(Logger):
                Logger.__init__(self)
                self.asyncio_loop = asyncio_loop
                self.stopping = False
       +        self.stopped_event = asyncio.Event()
                self.path = path
                test_read_write_permissions(path)
                self.commit_interval = commit_interval
       t@@ -65,6 +66,8 @@ class SqlDB(Logger):
                # write
                self.conn.commit()
                self.conn.close()
       +
       +        self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)
                self.logger.info("SQL thread terminated")
        
            def create_database(self):
 (DIR) diff --git a/electrum/tests/test_storage_upgrade.py b/electrum/tests/test_storage_upgrade.py
       t@@ -3,10 +3,12 @@ import tempfile
        import os
        import json
        from typing import Optional
       +import asyncio
        
        from electrum.wallet_db import WalletDB
        from electrum.wallet import Wallet
        from electrum import constants
       +from electrum import util
        
        from .test_wallet import WalletTestCase
        
       t@@ -15,6 +17,15 @@ from .test_wallet import WalletTestCase
        # TODO hw wallet with client version 2.6.x (single-, and multiacc)
        class TestStorageUpgrade(WalletTestCase):
        
       +    def setUp(self):
       +        super().setUp()
       +        self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
       +
       +    def tearDown(self):
       +        super().tearDown()
       +        self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
       +        self._loop_thread.join(timeout=1)
       +
            def testnet_wallet(func):
                # note: it's ok to modify global network constants in subclasses of SequentialTestCase
                def wrapper(self, *args, **kwargs):
       t@@ -281,7 +292,7 @@ class TestStorageUpgrade(WalletTestCase):
                # to simulate ks.opportunistically_fill_in_missing_info_from_device():
                ks._root_fingerprint = "deadbeef"
                ks.is_requesting_to_be_rewritten_to_wallet_file = True
       -        wallet.stop()
       +        asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
        
            def test_upgrade_from_client_2_9_3_importedkeys_keystore_changes(self):
                # see #6401
       t@@ -292,7 +303,7 @@ class TestStorageUpgrade(WalletTestCase):
                    ["p2wpkh:L1cgMEnShp73r9iCukoPE3MogLeueNYRD9JVsfT1zVHyPBR3KqBY"],
                    password=None
                )
       -        wallet.stop()
       +        asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
        
            @testnet_wallet
            def test_upgrade_from_client_3_3_8_xpub_with_realistic_history(self):
 (DIR) diff --git a/electrum/tests/test_wallet.py b/electrum/tests/test_wallet.py
       t@@ -5,8 +5,9 @@ import os
        import json
        from decimal import Decimal
        import time
       -
        from io import StringIO
       +import asyncio
       +
        from electrum.storage import WalletStorage
        from electrum.wallet_db import FINAL_SEED_VERSION
        from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet,
       t@@ -16,6 +17,7 @@ from electrum.util import TxMinedInfo, InvalidPassword
        from electrum.bitcoin import COIN
        from electrum.wallet_db import WalletDB
        from electrum.simple_config import SimpleConfig
       +from electrum import util
        
        from . import ElectrumTestCase
        
       t@@ -237,6 +239,15 @@ class TestCreateRestoreWallet(WalletTestCase):
        
        class TestWalletPassword(WalletTestCase):
        
       +    def setUp(self):
       +        super().setUp()
       +        self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
       +
       +    def tearDown(self):
       +        super().tearDown()
       +        self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
       +        self._loop_thread.join(timeout=1)
       +
            def test_update_password_of_imported_wallet(self):
                wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}'
                db = WalletDB(wallet_str, manual_upgrades=False)
       t@@ -273,7 +284,7 @@ class TestWalletPassword(WalletTestCase):
                db = WalletDB(wallet_str, manual_upgrades=False)
                storage = WalletStorage(self.wallet_path)
                wallet = Wallet(db, storage, config=self.config)
       -        wallet.stop()
       +        asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
        
                storage = WalletStorage(self.wallet_path)
                # if storage.is_encrypted():
 (DIR) diff --git a/electrum/util.py b/electrum/util.py
       t@@ -1205,11 +1205,9 @@ class NetworkJobOnDefaultServer(Logger, ABC):
                if taskgroup != self.taskgroup:
                    raise asyncio.CancelledError()
        
       -    async def stop(self):
       -        unregister_callback(self._restart)
       -        await self._stop()
       -
       -    async def _stop(self):
       +    async def stop(self, *, full_shutdown: bool = True):
       +        if full_shutdown:
       +            unregister_callback(self._restart)
                await self.taskgroup.cancel_remaining()
        
            @log_exceptions
       t@@ -1219,7 +1217,7 @@ class NetworkJobOnDefaultServer(Logger, ABC):
                    return  # we should get called again soon
        
                async with self._restart_lock:
       -            await self._stop()
       +            await self.stop(full_shutdown=False)
                    self._reset()
                    await self._start(interface)
        
 (DIR) diff --git a/electrum/wallet.py b/electrum/wallet.py
       t@@ -46,7 +46,7 @@ import itertools
        import threading
        import enum
        
       -from aiorpcx import TaskGroup
       +from aiorpcx import TaskGroup, timeout_after, TaskTimeout
        
        from .i18n import _
        from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32
       t@@ -353,15 +353,21 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
                ln_xprv = node.to_xprv()
                self.db.put('lightning_privkey2', ln_xprv)
        
       -    def stop(self):
       -        super().stop()
       -        if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
       -            self.save_keystore()
       -        if self.network:
       -            if self.lnworker:
       -                self.lnworker.stop()
       -                self.lnworker = None
       -        self.save_db()
       +    async def stop(self):
       +        """Stop all networking and save DB to disk."""
       +        try:
       +            async with timeout_after(5):
       +                await super().stop()
       +                if self.network:
       +                    if self.lnworker:
       +                        await self.lnworker.stop()
       +                        self.lnworker = None
       +        except TaskTimeout:
       +            pass
       +        finally:  # even if we get cancelled
       +            if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
       +                self.save_keystore()
       +            self.save_db()
        
            def set_up_to_date(self, b):
                super().set_up_to_date(b)
 (DIR) diff --git a/run_electrum b/run_electrum
       t@@ -345,7 +345,6 @@ def main():
                    print_stderr('unknown command:', uri)
                    sys.exit(1)
        
       -    # singleton
            config = SimpleConfig(config_options)
        
            if config.get('testnet'):