tpersist nodes in channel_db on disk - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit bc06ded4b929641214299bbae7eb404b307ccdca
 (DIR) parent 5a05a92b3d631277318cb015b56f98ee36a084ee
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Thu, 26 Jul 2018 21:08:25 +0200
       
       persist nodes in channel_db on disk
       
       Diffstat:
         M electrum/gui/qt/channels_list.py    |      19 +++++++++++--------
         M electrum/lnbase.py                  |      35 ++-----------------------------
         M electrum/lnrouter.py                |     135 +++++++++++++++++++++++++++++--
         M electrum/lnutil.py                  |      18 +++++++++++++++++-
         M electrum/lnworker.py                |       9 +++++----
         M electrum/network.py                 |       1 -
         M electrum/tests/test_util.py         |      21 ++++++++++++++++++++-
         M electrum/util.py                    |      20 +++++++++++++++++++-
       
       8 files changed, 204 insertions(+), 54 deletions(-)
       ---
 (DIR) diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py
       t@@ -64,10 +64,12 @@ class ChannelsList(MyTreeWidget):
                return h
        
            def update_status(self):
       -        n = len(self.parent.network.lightning_nodes)
       -        nc = len(self.parent.network.channel_db)
       -        np = len(self.parent.wallet.lnworker.peers)
       -        self.status.setText(_('{} peers, {} nodes, {} channels').format(np, n, nc))
       +        channel_db = self.parent.network.channel_db
       +        num_nodes = len(channel_db.nodes)
       +        num_channels = len(channel_db)
       +        num_peers = len(self.parent.wallet.lnworker.peers)
       +        self.status.setText(_('{} peers, {} nodes, {} channels')
       +                            .format(num_peers, num_nodes, num_channels))
        
            def new_channel_dialog(self):
                lnworker = self.parent.wallet.lnworker
       t@@ -116,15 +118,16 @@ class ChannelsList(MyTreeWidget):
        
                peer = lnworker.peers.get(node_id)
                if not peer:
       -            known = node_id in self.parent.network.lightning_nodes
       +            all_nodes = self.parent.network.channel_db.nodes
       +            node_info = all_nodes.get(node_id, None)
                    if rest is not None:
                        try:
                            host, port = rest.split(":")
                        except ValueError:
                            self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
       -            elif known:
       -                node = self.network.lightning_nodes.get(node_id)
       -                host, port = node['addresses'][0]
       +                    return
       +            elif node_info:
       +                host, port = node_info.addresses[0]
                    else:
                        self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
                        return
 (DIR) diff --git a/electrum/lnbase.py b/electrum/lnbase.py
       t@@ -29,7 +29,7 @@ from . import crypto
        from .crypto import sha256
        from . import constants
        from . import transaction
       -from .util import PrintError, bh2u, print_error, bfh, profiler, xor_bytes
       +from .util import PrintError, bh2u, print_error, bfh
        from .transaction import opcodes, Transaction
        from .lnonion import new_onion_packet, OnionHopsDataSingle, OnionPerHop, decode_onion_error
        from .lnaddr import lndecode
       t@@ -428,38 +428,7 @@ class Peer(PrintError):
                self.funding_signed[channel_id].put_nowait(payload)
        
            def on_node_announcement(self, payload):
       -        pubkey = payload['node_id']
       -        signature = payload['signature']
       -        h = bitcoin.Hash(payload['raw'][66:])
       -        if not ecc.verify_signature(pubkey, signature, h):
       -            return False
       -        self.s = payload['addresses']
       -        def read(n):
       -            data, self.s = self.s[0:n], self.s[n:]
       -            return data
       -        addresses = []
       -        while self.s:
       -            atype = ord(read(1))
       -            if atype == 0:
       -                pass
       -            elif atype == 1:
       -                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
       -                port = int.from_bytes(read(2), 'big')
       -                x = ipv4_addr, port, binascii.hexlify(pubkey)
       -                addresses.append((ipv4_addr, port))
       -            elif atype == 2:
       -                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(4)])
       -                port = int.from_bytes(read(2), 'big')
       -                addresses.append((ipv6_addr, port))
       -            else:
       -                pass
       -            continue
       -        alias = payload['alias'].rstrip(b'\x00')
       -        self.network.lightning_nodes[pubkey] = {
       -            'alias': alias,
       -            'addresses': addresses
       -        }
       -        #self.print_error('node announcement', binascii.hexlify(pubkey), alias, addresses)
       +        self.channel_db.on_node_announcement(payload)
                self.network.trigger_callback('ln_status')
        
            def on_init(self, payload):
 (DIR) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
       t@@ -29,17 +29,31 @@ import json
        import threading
        from collections import namedtuple, defaultdict
        from typing import Sequence, Union, Tuple, Optional
       -
       +import binascii
       +import base64
        
        from . import constants
       -from .util import PrintError, bh2u, profiler, get_headers_dir, bfh
       +from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
        from .storage import JsonDB
        from .lnchanannverifier import LNChanAnnVerifier, verify_sig_for_channel_update
       +from .crypto import Hash
       +from . import ecc
       +from .lnutil import LN_GLOBAL_FEATURE_BITS
       +
       +
       +class UnknownEvenFeatureBits(Exception): pass
        
        
        class ChannelInfo(PrintError):
        
            def __init__(self, channel_announcement_payload):
       +        self.features_len = channel_announcement_payload['len']
       +        self.features = channel_announcement_payload['features']
       +        enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
       +        for fbit in enabled_features:
       +            if fbit not in LN_GLOBAL_FEATURE_BITS and fbit % 2 == 0:
       +                raise UnknownEvenFeatureBits()
       +
                self.channel_id = channel_announcement_payload['short_channel_id']
                self.node_id_1 = channel_announcement_payload['node_id_1']
                self.node_id_2 = channel_announcement_payload['node_id_2']
       t@@ -47,8 +61,6 @@ class ChannelInfo(PrintError):
                assert type(self.node_id_2) is bytes
                assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
        
       -        self.features_len = channel_announcement_payload['len']
       -        self.features = channel_announcement_payload['features']
                self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
                self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
        
       t@@ -162,6 +174,86 @@ class ChannelInfoDirectedPolicy:
                return ChannelInfoDirectedPolicy(d2)
        
        
       +class NodeInfo(PrintError):
       +
       +    def __init__(self, node_announcement_payload, addresses_already_parsed=False):
       +        self.pubkey = node_announcement_payload['node_id']
       +        self.features_len = node_announcement_payload['flen']
       +        self.features = node_announcement_payload['features']
       +        enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
       +        for fbit in enabled_features:
       +            if fbit not in LN_GLOBAL_FEATURE_BITS and fbit % 2 == 0:
       +                raise UnknownEvenFeatureBits()
       +        if not addresses_already_parsed:
       +            self.addresses = self.parse_addresses_field(node_announcement_payload['addresses'])
       +        else:
       +            self.addresses = node_announcement_payload['addresses']
       +        self.alias = node_announcement_payload['alias'].rstrip(b'\x00')
       +        self.timestamp = int.from_bytes(node_announcement_payload['timestamp'], "big")
       +
       +    @classmethod
       +    def parse_addresses_field(cls, addresses_field):
       +        buf = addresses_field
       +        def read(n):
       +            nonlocal buf
       +            data, buf = buf[0:n], buf[n:]
       +            return data
       +        addresses = []
       +        while buf:
       +            atype = ord(read(1))
       +            if atype == 0:
       +                pass
       +            elif atype == 1:  # IPv4
       +                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
       +                port = int.from_bytes(read(2), 'big')
       +                if is_ip_address(ipv4_addr) and port != 0:
       +                    addresses.append((ipv4_addr, port))
       +            elif atype == 2:  # IPv6
       +                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
       +                ipv6_addr = ipv6_addr.decode('ascii')
       +                port = int.from_bytes(read(2), 'big')
       +                if is_ip_address(ipv6_addr) and port != 0:
       +                    addresses.append((ipv6_addr, port))
       +            elif atype == 3:  # onion v2
       +                host = base64.b32encode(read(10)) + b'.onion'
       +                host = host.decode('ascii').lower()
       +                port = int.from_bytes(read(2), 'big')
       +                addresses.append((host, port))
       +            elif atype == 4:  # onion v3
       +                host = base64.b32encode(read(35)) + b'.onion'
       +                host = host.decode('ascii').lower()
       +                port = int.from_bytes(read(2), 'big')
       +                addresses.append((host, port))
       +            else:
       +                # unknown address type
       +                # we don't know how long it is -> have to escape
       +                # if there are other addresses we could have parsed later, they are lost.
       +                break
       +        return addresses
       +
       +    def to_json(self) -> dict:
       +        d = {}
       +        d['node_id'] = bh2u(self.pubkey)
       +        d['flen'] = bh2u(self.features_len)
       +        d['features'] = bh2u(self.features)
       +        d['addresses'] = self.addresses
       +        d['alias'] = bh2u(self.alias)
       +        d['timestamp'] = self.timestamp
       +        return d
       +
       +    @classmethod
       +    def from_json(cls, d: dict):
       +        if d is None: return None
       +        d2 = {}
       +        d2['node_id'] = bfh(d['node_id'])
       +        d2['flen'] = bfh(d['flen'])
       +        d2['features'] = bfh(d['features'])
       +        d2['addresses'] = d['addresses']
       +        d2['alias'] = bfh(d['alias'])
       +        d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
       +        return NodeInfo(d2, addresses_already_parsed=True)
       +
       +
        class ChannelDB(JsonDB):
        
            def __init__(self, network):
       t@@ -173,6 +265,7 @@ class ChannelDB(JsonDB):
                self.lock = threading.Lock()
                self._id_to_channel_info = {}
                self._channels_for_node = defaultdict(set)  # node -> set(short_channel_id)
       +        self.nodes = {}  # node_id -> NodeInfo
        
                self.ca_verifier = LNChanAnnVerifier(network, self)
                self.network.add_jobs([self.ca_verifier])
       t@@ -184,21 +277,35 @@ class ChannelDB(JsonDB):
                    with open(self.path, "r", encoding='utf-8') as f:
                        raw = f.read()
                        self.data = json.loads(raw)
       +        # channels
                channel_infos = self.get('channel_infos', {})
                for short_channel_id, channel_info_d in channel_infos.items():
                    channel_info = ChannelInfo.from_json(channel_info_d)
                    short_channel_id = bfh(short_channel_id)
                    self.add_verified_channel_info(short_channel_id, channel_info)
       +        # nodes
       +        node_infos = self.get('node_infos', {})
       +        for node_id, node_info_d in node_infos.items():
       +            node_info = NodeInfo.from_json(node_info_d)
       +            node_id = bfh(node_id)
       +            self.nodes[node_id] = node_info
        
            def save_data(self):
                with self.lock:
       +            # channels
                    channel_infos = {}
                    for short_channel_id, channel_info in self._id_to_channel_info.items():
                        channel_infos[bh2u(short_channel_id)] = channel_info
                    self.put('channel_infos', channel_infos)
       +            # nodes
       +            node_infos = {}
       +            for node_id, node_info in self.nodes.items():
       +                node_infos[bh2u(node_id)] = node_info
       +            self.put('node_infos', node_infos)
                self.write()
        
            def __len__(self):
       +        # number of channels
                return len(self._id_to_channel_info)
        
            def get_channel_info(self, channel_id) -> Optional[ChannelInfo]:
       t@@ -220,7 +327,10 @@ class ChannelDB(JsonDB):
                    return
                if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
                    return
       -        channel_info = ChannelInfo(msg_payload)
       +        try:
       +            channel_info = ChannelInfo(msg_payload)
       +        except UnknownEvenFeatureBits:
       +            return
                if trusted:
                    self.add_verified_channel_info(short_channel_id, channel_info)
                else:
       t@@ -244,6 +354,21 @@ class ChannelDB(JsonDB):
                    return
                channel_info.on_channel_update(msg_payload, trusted=trusted)
        
       +    def on_node_announcement(self, msg_payload):
       +        pubkey = msg_payload['node_id']
       +        signature = msg_payload['signature']
       +        h = Hash(msg_payload['raw'][66:])
       +        if not ecc.verify_signature(pubkey, signature, h):
       +            return
       +        old_node_info = self.nodes.get(pubkey, None)
       +        try:
       +            new_node_info = NodeInfo(msg_payload)
       +        except UnknownEvenFeatureBits:
       +            return
       +        if old_node_info and old_node_info.timestamp >= new_node_info.timestamp:
       +            return  # ignore
       +        self.nodes[pubkey] = new_node_info
       +
            def remove_channel(self, short_channel_id):
                try:
                    channel_info = self._id_to_channel_info[short_channel_id]
 (DIR) diff --git a/electrum/lnutil.py b/electrum/lnutil.py
       t@@ -1,4 +1,4 @@
       -from .util import bfh, bh2u
       +from .util import bfh, bh2u, inv_dict
        from .crypto import sha256
        import json
        from collections import namedtuple
       t@@ -380,3 +380,19 @@ def overall_weight(num_htlc):
        def get_ecdh(priv: bytes, pub: bytes) -> bytes:
            pt = ECPubkey(pub) * string_to_number(priv)
            return sha256(pt.get_public_key_bytes())
       +
       +
       +LN_LOCAL_FEATURE_BITS = {
       +    0: 'option_data_loss_protect_req',
       +    1: 'option_data_loss_protect_opt',
       +    3: 'initial_routing_sync',
       +    4: 'option_upfront_shutdown_script_req',
       +    5: 'option_upfront_shutdown_script_opt',
       +    6: 'gossip_queries_req',
       +    7: 'gossip_queries_opt',
       +}
       +LN_LOCAL_FEATURE_BITS_INV = inv_dict(LN_LOCAL_FEATURE_BITS)
       +
       +LN_GLOBAL_FEATURE_BITS = {}
       +LN_GLOBAL_FEATURE_BITS_INV = inv_dict(LN_GLOBAL_FEATURE_BITS)
       +
 (DIR) diff --git a/electrum/lnworker.py b/electrum/lnworker.py
       t@@ -228,11 +228,12 @@ class LNWorker(PrintError):
                            self.peers.pop(k)
                    if len(self.peers) > 3:
                        continue
       -            if not self.network.lightning_nodes:
       +            if not self.network.channel_db.nodes:
                        continue
       -            node_id = random.choice(list(self.network.lightning_nodes.keys()))
       -            node = self.network.lightning_nodes.get(node_id)
       -            addresses = node.get('addresses')
       +            all_nodes = self.network.channel_db.nodes
       +            node_id = random.choice(list(all_nodes))
       +            node = all_nodes.get(node_id)
       +            addresses = node.addresses
                    if addresses:
                        host, port = addresses[0]
                        self.print_error("trying node", bh2u(node_id))
 (DIR) diff --git a/electrum/network.py b/electrum/network.py
       t@@ -300,7 +300,6 @@ class Network(Logger):
                self._set_status('disconnected')
        
                # lightning network
       -        self.lightning_nodes = {}
                self.channel_db = lnrouter.ChannelDB(self)
                self.path_finder = lnrouter.LNPathFinder(self.channel_db)
                self.lnwatcher = lnwatcher.LNWatcher(self)
 (DIR) diff --git a/electrum/tests/test_util.py b/electrum/tests/test_util.py
       t@@ -1,7 +1,7 @@
        from decimal import Decimal
        
        from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI,
       -                           is_hash256_str, chunks)
       +                           is_hash256_str, chunks, is_ip_address, list_enabled_bits)
        
        from . import SequentialTestCase
        
       t@@ -110,3 +110,22 @@ class TestUtil(SequentialTestCase):
                                 list(chunks([1, 2, 3, 4, 5], 2)))
                with self.assertRaises(ValueError):
                    list(chunks([1, 2, 3], 0))
       +
       +    def test_list_enabled_bits(self):
       +        self.assertEqual((0, 2, 3, 6), list_enabled_bits(77))
       +        self.assertEqual((), list_enabled_bits(0))
       +
       +    def test_is_ip_address(self):
       +        self.assertTrue(is_ip_address("127.0.0.1"))
       +        self.assertTrue(is_ip_address("127.000.000.1"))
       +        self.assertTrue(is_ip_address("255.255.255.255"))
       +        self.assertFalse(is_ip_address("255.255.256.255"))
       +        self.assertFalse(is_ip_address("123.456.789.000"))
       +        self.assertTrue(is_ip_address("2001:0db8:0000:0000:0000:ff00:0042:8329"))
       +        self.assertTrue(is_ip_address("2001:db8:0:0:0:ff00:42:8329"))
       +        self.assertTrue(is_ip_address("2001:db8::ff00:42:8329"))
       +        self.assertFalse(is_ip_address("2001:::db8::ff00:42:8329"))
       +        self.assertTrue(is_ip_address("::1"))
       +        self.assertFalse(is_ip_address("2001:db8:0:0:g:ff00:42:8329"))
       +        self.assertFalse(is_ip_address("lol"))
       +        self.assertFalse(is_ip_address(":@ASD:@AS\x77\x22\xff¬!"))
 (DIR) diff --git a/electrum/util.py b/electrum/util.py
       t@@ -23,7 +23,7 @@
        import binascii
        import os, sys, re, json
        from collections import defaultdict, OrderedDict
       -from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any
       +from typing import NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence
        from datetime import datetime
        import decimal
        from decimal import Decimal
       t@@ -40,6 +40,7 @@ import json
        import time
        from typing import NamedTuple, Optional
        import ssl
       +import ipaddress
        
        import aiohttp
        from aiohttp_socks import SocksConnector, SocksVer
       t@@ -1156,3 +1157,20 @@ def multisig_type(wallet_type):
            if match:
                match = [int(x) for x in match.group(1, 2)]
            return match
       +
       +
       +def is_ip_address(x: Union[str, bytes]) -> bool:
       +    if isinstance(x, bytes):
       +        x = x.decode("utf-8")
       +    try:
       +        ipaddress.ip_address(x)
       +        return True
       +    except ValueError:
       +        return False
       +
       +
       +def list_enabled_bits(x: int) -> Sequence[int]:
       +    """e.g. 77 (0b1001101) --> (0, 2, 3, 6)"""
       +    binary = bin(x)[2:]
       +    rev_bin = reversed(binary)
       +    return tuple(i for i, b in enumerate(rev_bin) if b == '1')