tChannelDB: add self.lock and make it thread-safe - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit fd56fb918982c7afbf96267e03d7cfb2ca7fc858
 (DIR) parent 1ca6f6f306a1f9f376c7519b9b80662aa39768ee
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Sat, 29 Feb 2020 18:32:47 +0100
       
       ChannelDB: add self.lock and make it thread-safe
       
       Diffstat:
         M electrum/channel_db.py              |      71 ++++++++++++++++++++-----------
       
       1 file changed, 46 insertions(+), 25 deletions(-)
       ---
 (DIR) diff --git a/electrum/channel_db.py b/electrum/channel_db.py
       t@@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
        import binascii
        import base64
        import asyncio
       +import threading
        
        
        from .sql_db import SqlDB, sql
       t@@ -247,17 +248,21 @@ class ChannelDB(SqlDB):
            def __init__(self, network: 'Network'):
                path = os.path.join(get_headers_dir(network.config), 'gossip_db')
                super().__init__(network, path, commit_interval=100)
       +        self.lock = threading.RLock()
                self.num_nodes = 0
                self.num_channels = 0
                self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
                self.ca_verifier = LNChannelVerifier(network, self)
       +
                # initialized in load_data
       +        # note: modify/iterate needs self.lock
                self._channels = {}  # type: Dict[bytes, ChannelInfo]
                self._policies = {}  # type: Dict[Tuple[bytes, bytes], Policy]  # (node_id, scid) -> Policy
                self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
                # node_id -> (host, port, ts)
                self._addresses = defaultdict(set)  # type: Dict[bytes, Set[Tuple[str, int, int]]]
                self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
       +
                self.data_loaded = asyncio.Event()
                self.network = network # only for callback
        
       t@@ -268,16 +273,19 @@ class ChannelDB(SqlDB):
                self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        
            def get_channel_ids(self):
       -        return set(self._channels.keys())
       +        with self.lock:
       +            return set(self._channels.keys())
        
            def add_recent_peer(self, peer: LNPeerAddr):
                now = int(time.time())
                node_id = peer.pubkey
       -        self._addresses[node_id].add((peer.host, peer.port, now))
       +        with self.lock:
       +            self._addresses[node_id].add((peer.host, peer.port, now))
                self.save_node_address(node_id, peer, now)
        
            def get_200_randomly_sorted_nodes_not_in(self, node_ids):
       -        unshuffled = set(self._nodes.keys()) - node_ids
       +        with self.lock:
       +            unshuffled = set(self._nodes.keys()) - node_ids
                return random.sample(unshuffled, min(200, len(unshuffled)))
        
            def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
       t@@ -296,8 +304,10 @@ class ChannelDB(SqlDB):
                # FIXME this does not reliably return "recent" peers...
                #       Also, the list() cast over the whole dict (thousands of elements),
                #       is really inefficient.
       +        with self.lock:
       +            _addresses_keys = list(self._addresses.keys())
                r = [self.get_last_good_address(node_id)
       -             for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
       +             for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]]
                return list(reversed(r))
        
            # note: currently channel announcements are trusted by default (trusted=True);
       t@@ -336,9 +346,10 @@ class ChannelDB(SqlDB):
                except UnknownEvenFeatureBits:
                    return
                channel_info = channel_info._replace(capacity_sat=capacity_sat)
       -        self._channels[channel_info.short_channel_id] = channel_info
       -        self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
       -        self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
       +        with self.lock:
       +            self._channels[channel_info.short_channel_id] = channel_info
       +            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
       +            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
                if 'raw' in msg:
                    self.save_channel(channel_info.short_channel_id, msg['raw'])
        
       t@@ -397,7 +408,8 @@ class ChannelDB(SqlDB):
                    if verify:
                        self.verify_channel_update(payload)
                    policy = Policy.from_msg(payload)
       -            self._policies[key] = policy
       +            with self.lock:
       +                self._policies[key] = policy
                    if 'raw' in payload:
                        self.save_policy(policy.key, payload['raw'])
                #
       t@@ -492,32 +504,38 @@ class ChannelDB(SqlDB):
                    if node and node.timestamp >= node_info.timestamp:
                        continue
                    # save
       -            self._nodes[node_id] = node_info
       +            with self.lock:
       +                self._nodes[node_id] = node_info
                    if 'raw' in msg_payload:
                        self.save_node_info(node_id, msg_payload['raw'])
       -            for addr in node_addresses:
       -                self._addresses[node_id].add((addr.host, addr.port, 0))
       +            with self.lock:
       +                for addr in node_addresses:
       +                    self._addresses[node_id].add((addr.host, addr.port, 0))
                    self.save_node_addresses(node_id, node_addresses)
        
                self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
                self.update_counts()
        
            def get_old_policies(self, delta):
       +        with self.lock:
       +            _policies = self._policies.copy()
                now = int(time.time())
       -        return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
       +        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)
        
            def prune_old_policies(self, delta):
                l = self.get_old_policies(delta)
                if l:
                    for k in l:
       -                self._policies.pop(k)
       +                with self.lock:
       +                    self._policies.pop(k)
                        self.delete_policy(*k)
                    self.update_counts()
                    self.logger.info(f'Deleting {len(l)} old policies')
        
            def get_orphaned_channels(self):
       -        ids = set(x[1] for x in self._policies.keys())
       -        return list(x for x in self._channels.keys() if x not in ids)
       +        with self.lock:
       +            ids = set(x[1] for x in self._policies.keys())
       +            return list(x for x in self._channels.keys() if x not in ids)
        
            def prune_orphaned_channels(self):
                l = self.get_orphaned_channels()
       t@@ -535,10 +553,11 @@ class ChannelDB(SqlDB):
                self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
        
            def remove_channel(self, short_channel_id: ShortChannelID):
       -        channel_info = self._channels.pop(short_channel_id, None)
       -        if channel_info:
       -            self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
       -            self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
       +        with self.lock:
       +            channel_info = self._channels.pop(short_channel_id, None)
       +            if channel_info:
       +                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
       +                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
                # delete from database
                self.delete_channel(short_channel_id)
        
       t@@ -571,17 +590,19 @@ class ChannelDB(SqlDB):
                    self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
                self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
                self.update_counts()
       -        self.count_incomplete_channels()
       +        self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}')
                self.data_loaded.set()
        
       -    def count_incomplete_channels(self):
       -        out = set()
       -        for short_channel_id, ci in self._channels.items():
       +    def get_num_incomplete_channels(self) -> int:
       +        found = set()
       +        with self.lock:
       +            _channels = self._channels.copy()
       +        for short_channel_id, ci in _channels.items():
                    p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
                    p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
                    if p1 is None or p2 is not None:
       -                out.add(short_channel_id)
       -        self.logger.info(f'semi-orphaned: {len(out)}')
       +                found.add(short_channel_id)
       +        return len(found)
        
            def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
                                    my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: