tutil: add NetworkRetryManager, a baseclass for LNWorker and Network - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 76f0ad3271a611457a88e4bfe213334afcf1a5ae
 (DIR) parent 90cb032721ffd8fad883762b3942de7a08b1949a
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Wed, 15 Apr 2020 17:17:11 +0200
       
       util: add NetworkRetryManager, a baseclass for LNWorker and Network
       
       Diffstat:
         M electrum/lnworker.py                |      48 +++++++++++--------------------
         M electrum/network.py                 |      46 +++++++++++--------------------
         M electrum/util.py                    |      53 ++++++++++++++++++++++++++++++-
       
       3 files changed, 85 insertions(+), 62 deletions(-)
       ---
 (DIR) diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -25,7 +25,7 @@ from . import constants, util
        from . import keystore
        from .util import profiler
        from .util import PR_UNPAID, PR_EXPIRED, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING
       -from .util import PR_TYPE_LN
       +from .util import PR_TYPE_LN, NetworkRetryManager
        from .lnutil import LN_MAX_FUNDING_SAT
        from .keystore import BIP32_KeyStore
        from .bitcoin import COIN
       t@@ -78,10 +78,6 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID, PR_INFLIGHT] # status that are persisted
        
        NUM_PEERS_TARGET = 4
        
       -MAX_RETRY_DELAY_FOR_PEERS = 3600  # sec
       -INIT_RETRY_DELAY_FOR_PEERS = 600  # sec
       -MAX_RETRY_DELAY_FOR_CHANNEL_PEERS = 300  # sec
       -INIT_RETRY_DELAY_FOR_CHANNEL_PEERS = 4  # sec
        
        FALLBACK_NODE_LIST_TESTNET = (
            LNPeerAddr(host='203.132.95.10', port=9735, pubkey=bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')),
       t@@ -143,10 +139,17 @@ class NoPathFound(PaymentFailure):
                return _('No path found')
        
        
       -class LNWorker(Logger):
       +class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
        
            def __init__(self, xprv):
                Logger.__init__(self)
       +        NetworkRetryManager.__init__(
       +            self,
       +            max_retry_delay_normal=3600,
       +            init_retry_delay_normal=600,
       +            max_retry_delay_urgent=300,
       +            init_retry_delay_urgent=4,
       +        )
                self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
                self.peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer
                self.taskgroup = SilentTaskGroup()
       t@@ -158,8 +161,6 @@ class LNWorker(Logger):
                self.features |= LnFeatures.VAR_ONION_OPT
                self.features |= LnFeatures.PAYMENT_SECRET_OPT
        
       -        self._last_tried_peer = {}  # type: Dict[LNPeerAddr, Tuple[float, int]]  # LNPeerAddr -> (unix ts, num_attempts)
       -
            def channels_for_peer(self, node_id):
                return {}
        
       t@@ -208,17 +209,16 @@ class LNWorker(Logger):
                        continue
                    peers = await self._get_next_peers_to_try()
                    for peer in peers:
       -                if self._can_retry_peer(peer, now=now):
       +                if self._can_retry_addr(peer, now=now):
                            await self._add_peer(peer.host, peer.port, peer.pubkey)
        
       -    async def _add_peer(self, host, port, node_id) -> Peer:
       +    async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
                if node_id in self.peers:
                    return self.peers[node_id]
                port = int(port)
                peer_addr = LNPeerAddr(host, port, node_id)
                transport = LNTransport(self.node_keypair.privkey, peer_addr)
       -        last_time, num_attempts = self._last_tried_peer.get(peer_addr, (0, 0))
       -        self._last_tried_peer[peer_addr] = time.time(), num_attempts + 1
       +        self._trying_addr_now(peer_addr)
                self.logger.info(f"adding peer {peer_addr}")
                peer = Peer(self, node_id, transport)
                await self.taskgroup.spawn(peer.main_loop())
       t@@ -266,7 +266,7 @@ class LNWorker(Logger):
                if isinstance(peer.transport, LNTransport):
                    peer_addr = peer.transport.peer_addr
                    # reset connection attempt count
       -            self._last_tried_peer[peer_addr] = time.time(), 0
       +            self._on_connection_successfully_established(peer_addr)
                    # add into channel db
                    if self.channel_db:
                        self.channel_db.add_recent_peer(peer_addr)
       t@@ -274,20 +274,6 @@ class LNWorker(Logger):
                    for chan in peer.channels.values():
                        chan.add_or_update_peer_addr(peer_addr)
        
       -    def _can_retry_peer(self, peer: LNPeerAddr, *,
       -                        now: float = None, for_channel: bool = False) -> bool:
       -        if now is None:
       -            now = time.time()
       -        last_time, num_attempts = self._last_tried_peer.get(peer, (0, 0))
       -        if for_channel:
       -            delay = min(MAX_RETRY_DELAY_FOR_CHANNEL_PEERS,
       -                        INIT_RETRY_DELAY_FOR_CHANNEL_PEERS * 2 ** num_attempts)
       -        else:
       -            delay = min(MAX_RETRY_DELAY_FOR_PEERS,
       -                        INIT_RETRY_DELAY_FOR_PEERS * 2 ** num_attempts)
       -        next_time = last_time + delay
       -        return next_time < now
       -
            async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
                now = time.time()
                await self.channel_db.data_loaded.wait()
       t@@ -298,7 +284,7 @@ class LNWorker(Logger):
                        continue
                    if peer.pubkey in self.peers:
                        continue
       -            if not self._can_retry_peer(peer, now=now):
       +            if not self._can_retry_addr(peer, now=now):
                        continue
                    if not self.is_good_peer(peer):
                        continue
       t@@ -315,7 +301,7 @@ class LNWorker(Logger):
                            peer = LNPeerAddr(host, port, node_id)
                        except ValueError:
                            continue
       -                if not self._can_retry_peer(peer, now=now):
       +                if not self._can_retry_addr(peer, now=now):
                            continue
                        if not self.is_good_peer(peer):
                            continue
       t@@ -330,7 +316,7 @@ class LNWorker(Logger):
                else:
                    return []  # regtest??
        
       -        fallback_list = [peer for peer in fallback_list if self._can_retry_peer(peer, now=now)]
       +        fallback_list = [peer for peer in fallback_list if self._can_retry_addr(peer, now=now)]
                if fallback_list:
                    return [random.choice(fallback_list)]
        
       t@@ -1298,7 +1284,7 @@ class LNWallet(LNWorker):
                # Done gathering addresses.
                # Now select first one that has not failed recently.
                for peer in peer_addresses:
       -            if self._can_retry_peer(peer, for_channel=True, now=now):
       +            if self._can_retry_addr(peer, urgent=True, now=now):
                        await self._add_peer(peer.host, peer.port, peer.pubkey)
                        return
        
 (DIR) diff --git a/electrum/network.py b/electrum/network.py
       t@@ -44,7 +44,7 @@ from aiohttp import ClientResponse
        from . import util
        from .util import (log_exceptions, ignore_exceptions,
                           bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
       -                   is_hash256_str, is_non_negative_integer, MyEncoder)
       +                   is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
        
        from .bitcoin import COIN
        from . import constants
       t@@ -74,10 +74,6 @@ _logger = get_logger(__name__)
        NUM_TARGET_CONNECTED_SERVERS = 10
        NUM_STICKY_SERVERS = 4
        NUM_RECENT_SERVERS = 20
       -MAX_RETRY_DELAY_FOR_SERVERS = 600  # sec
       -INIT_RETRY_DELAY_FOR_SERVERS = 15  # sec
       -MAX_RETRY_DELAY_FOR_MAIN_SERVER = 10  # sec
       -INIT_RETRY_DELAY_FOR_MAIN_SERVER = 1  # sec
        
        
        def parse_servers(result: Sequence[Tuple[str, str, List[str]]]) -> Dict[str, dict]:
       t@@ -235,7 +231,7 @@ class UntrustedServerReturnedError(NetworkException):
        _INSTANCE = None
        
        
       -class Network(Logger):
       +class Network(Logger, NetworkRetryManager[ServerAddr]):
            """The Network class manages a set of connections to remote electrum
            servers, each connected socket is handled by an Interface() object.
            """
       t@@ -255,6 +251,13 @@ class Network(Logger):
                _INSTANCE = self
        
                Logger.__init__(self)
       +        NetworkRetryManager.__init__(
       +            self,
       +            max_retry_delay_normal=600,
       +            init_retry_delay_normal=15,
       +            max_retry_delay_urgent=10,
       +            init_retry_delay_urgent=1,
       +        )
        
                self.asyncio_loop = asyncio.get_event_loop()
                assert self.asyncio_loop.is_running(), "event loop not running"
       t@@ -301,8 +304,6 @@ class Network(Logger):
                dir_path = os.path.join(self.config.path, 'certs')
                util.make_dir(dir_path)
        
       -        # retry times
       -        self._last_tried_server = {}  # type: Dict[ServerAddr, Tuple[float, int]]  # unix ts, num_attempts
                # the main server we are currently communicating with
                self.interface = None
                self.default_server_changed_event = asyncio.Event()
       t@@ -536,19 +537,6 @@ class Network(Logger):
                    out = filter_noonion(out)
                return out
        
       -    def _can_retry_server(self, server: ServerAddr, *, now: float = None) -> bool:
       -        if now is None:
       -            now = time.time()
       -        last_time, num_attempts = self._last_tried_server.get(server, (0, 0))
       -        if server == self.default_server:
       -            delay = min(MAX_RETRY_DELAY_FOR_MAIN_SERVER,
       -                        INIT_RETRY_DELAY_FOR_MAIN_SERVER * 2 ** num_attempts)
       -        else:
       -            delay = min(MAX_RETRY_DELAY_FOR_SERVERS,
       -                        INIT_RETRY_DELAY_FOR_SERVERS * 2 ** num_attempts)
       -        next_time = last_time + delay
       -        return next_time < now
       -
            def _get_next_server_to_try(self) -> Optional[ServerAddr]:
                now = time.time()
                with self.interfaces_lock:
       t@@ -566,7 +554,7 @@ class Network(Logger):
                    for server in recent_servers:
                        if server in connected_servers:
                            continue
       -                if not self._can_retry_server(server, now=now):
       +                if not self._can_retry_addr(server, now=now):
                            continue
                        return server
                # try all servers we know about, pick one at random
       t@@ -574,7 +562,7 @@ class Network(Logger):
                servers = list(set(filter_protocol(hostmap, self.protocol)) - connected_servers)
                random.shuffle(servers)
                for server in servers:
       -            if not self._can_retry_server(server, now=now):
       +            if not self._can_retry_addr(server, now=now):
                        continue
                    return server
                return None
       t@@ -726,8 +714,8 @@ class Network(Logger):
                    await interface.close()
        
            @with_recent_servers_lock
       -    def _add_recent_server(self, server):
       -        self._last_tried_server[server] = time.time(), 0
       +    def _add_recent_server(self, server: ServerAddr) -> None:
       +        self._on_connection_successfully_established(server)
                # list is ordered
                if server in self._recent_servers:
                    self._recent_servers.remove(server)
       t@@ -761,9 +749,7 @@ class Network(Logger):
                if server == self.default_server:
                    self.logger.info(f"connecting to {server} as new interface")
                    self._set_status('connecting')
       -        # update _last_tried_server
       -        last_time, num_attempts = self._last_tried_server.get(server, (0, 0))
       -        self._last_tried_server[server] = time.time(), num_attempts + 1
       +        self._trying_addr_now(server)
        
                interface = Interface(network=self, server=server, proxy=self.proxy)
                # note: using longer timeouts here as DNS can sometimes be slow!
       t@@ -1151,7 +1137,7 @@ class Network(Logger):
                assert not self.interface and not self.interfaces
                assert not self._connecting
                self.logger.info('starting network')
       -        self._last_tried_server.clear()
       +        self._clear_addr_retry_times()
                self.protocol = self.default_server.protocol
                self._set_proxy(deserialize_proxy(self.config.get('proxy')))
                self._set_oneserver(self.config.get('oneserver', False))
       t@@ -1213,7 +1199,7 @@ class Network(Logger):
                    await self._switch_to_random_interface()
                # if auto_connect is not set, or still no main interface, retry current
                if not self.is_connected() and not self.is_connecting():
       -            if self._can_retry_server(self.default_server):
       +            if self._can_retry_addr(self.default_server, urgent=True):
                        await self.switch_to_interface(self.default_server)
        
            async def _maintain_sessions(self):
 (DIR) diff --git a/electrum/util.py b/electrum/util.py
       t@@ -23,7 +23,8 @@
        import binascii
        import os, sys, re, json
        from collections import defaultdict, OrderedDict
       -from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence
       +from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any,
       +                    Sequence, Dict, Generic, TypeVar)
        from datetime import datetime
        import decimal
        from decimal import Decimal
       t@@ -1342,3 +1343,53 @@ callback_mgr = CallbackManager()
        trigger_callback = callback_mgr.trigger_callback
        register_callback = callback_mgr.register_callback
        unregister_callback = callback_mgr.unregister_callback
       +
       +
       +_NetAddrType = TypeVar("_NetAddrType")
       +
       +
       +class NetworkRetryManager(Generic[_NetAddrType]):
       +    """Truncated Exponential Backoff for network connections."""
       +
       +    def __init__(
       +            self, *,
       +            max_retry_delay_normal: float,
       +            init_retry_delay_normal: float,
       +            max_retry_delay_urgent: float = None,
       +            init_retry_delay_urgent: float = None,
       +    ):
       +        self._last_tried_addr = {}  # type: Dict[_NetAddrType, Tuple[float, int]]  # (unix ts, num_attempts)
       +
       +        # note: these all use "seconds" as unit
       +        if max_retry_delay_urgent is None:
       +            max_retry_delay_urgent = max_retry_delay_normal
       +        if init_retry_delay_urgent is None:
       +            init_retry_delay_urgent = init_retry_delay_normal
       +        self._max_retry_delay_normal = max_retry_delay_normal
       +        self._init_retry_delay_normal = init_retry_delay_normal
       +        self._max_retry_delay_urgent = max_retry_delay_urgent
       +        self._init_retry_delay_urgent = init_retry_delay_urgent
       +
       +    def _trying_addr_now(self, addr: _NetAddrType) -> None:
       +        last_time, num_attempts = self._last_tried_addr.get(addr, (0, 0))
       +        self._last_tried_addr[addr] = time.time(), num_attempts + 1
       +
       +    def _on_connection_successfully_established(self, addr: _NetAddrType) -> None:
       +        self._last_tried_addr[addr] = time.time(), 0
       +
       +    def _can_retry_addr(self, peer: _NetAddrType, *,
       +                        now: float = None, urgent: bool = False) -> bool:
       +        if now is None:
       +            now = time.time()
       +        last_time, num_attempts = self._last_tried_addr.get(peer, (0, 0))
       +        if urgent:
       +            delay = min(self._max_retry_delay_urgent,
       +                        self._init_retry_delay_urgent * 2 ** num_attempts)
       +        else:
       +            delay = min(self._max_retry_delay_normal,
       +                        self._init_retry_delay_normal * 2 ** num_attempts)
       +        next_time = last_time + delay
       +        return next_time < now
       +
       +    def _clear_addr_retry_times(self) -> None:
       +        self._last_tried_addr.clear()