trefactor storage of channels, path finding - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 815079efe025d33bbecc17ba0959f233f0e54376
 (DIR) parent 5b1a5e8786664eba0476453de166243073b62fd9
 (HTM) Author: SomberNight <somber.night@protonmail.com>
       Date:   Tue, 17 Apr 2018 20:01:51 +0200
       
       refactor storage of channels, path finding
       
       Diffstat:
         M lib/lnbase.py                       |     261 ++++++++++++++++++++-----------
         M lib/tests/test_lnbase.py            |       4 ++--
       
       2 files changed, 170 insertions(+), 95 deletions(-)
       ---
 (DIR) diff --git a/lib/lnbase.py b/lib/lnbase.py
       t@@ -419,10 +419,8 @@ class Peer(PrintError):
                self.localfeatures = (0x08 if request_initial_sync else 0)
                # view of the network
                self.nodes = {} # received node announcements
       -        self.channels = {} # received channel announcements
       -        self.channel_u_origin = {}
       -        self.channel_u_final = {}
       -        self.graph_of_payment_channels = defaultdict(set)  # node -> short_channel_id
       +        self.channel_db = ChannelDB()
       +        self.path_finder = LNPathFinder(self.channel_db)
        
            def diagnostic_name(self):
                return self.host
       t@@ -541,8 +539,8 @@ class Peer(PrintError):
            def on_funding_signed(self, payload):
                sig = payload['signature']
                channel_id = payload['channel_id']
       -        tx = self.channels[channel_id]
       -        self.network.broadcast(tx)
       +        #tx = self.channels[channel_id]  # FIXME
       +        #self.network.broadcast(tx)
        
            def on_funding_signed(self, payload):
                self.funding_signed[payload["temporary_channel_id"]].set_result(payload)
       t@@ -588,99 +586,14 @@ class Peer(PrintError):
                pass
        
            def on_channel_update(self, payload):
       -        flags = int.from_bytes(payload['flags'], byteorder="big")
       -        direction = bool(flags & 1)
       -        short_channel_id = payload['short_channel_id']
       -        if direction == 0:
       -            self.channel_u_origin[short_channel_id] = payload
       -        else:
       -            self.channel_u_final[short_channel_id] = payload
       -        self.print_error('channel update', binascii.hexlify(short_channel_id), flags)
       +        self.channel_db.on_channel_update(payload)
        
            def on_channel_announcement(self, payload):
       -        short_channel_id = payload['short_channel_id']
       -        self.print_error('channel announcement', binascii.hexlify(short_channel_id))
       -        self.channels[short_channel_id] = payload
       -        self.add_channel_to_graph(payload)
       -
       -    def add_channel_to_graph(self, payload):
       -        node1 = payload['node_id_1']
       -        node2 = payload['node_id_2']
       -        channel_id = payload['short_channel_id']
       -        self.graph_of_payment_channels[node1].add(channel_id)
       -        self.graph_of_payment_channels[node2].add(channel_id)
       +        self.channel_db.on_channel_announcement(payload)
        
            #def open_channel(self, funding_sat, push_msat):
            #    self.send_message(gen_msg('open_channel', funding_satoshis=funding_sat, push_msat=push_msat))
        
       -    @profiler
       -    def find_route_for_payment(self, from_node_id, to_node_id, amount_msat=None):
       -        """Return a route between from_node_id and to_node_id.
       -
       -        Returns a list of (node_id, short_channel_id) representing a path.
       -        To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1]
       -        """
       -        # TODO find multiple paths??
       -
       -        def edge_cost(short_channel_id, direction):
       -            """Heuristic cost of going through a channel.
       -            direction: 0 or 1. --- 0 means node_id_1 -> node_id_2
       -            """
       -            channel_updates = self.channel_u_origin if direction == 0 else self.channel_u_final
       -            try:
       -                cltv_expiry_delta           = channel_updates[short_channel_id]['cltv_expiry_delta']
       -                htlc_minimum_msat           = channel_updates[short_channel_id]['htlc_minimum_msat']
       -                fee_base_msat               = channel_updates[short_channel_id]['fee_base_msat']
       -                fee_proportional_millionths = channel_updates[short_channel_id]['fee_proportional_millionths']
       -            except KeyError:
       -                return float('inf')  # can't use this channel
       -            if amount_msat is not None and amount_msat < htlc_minimum_msat:
       -                return float('inf')  # can't use this channel
       -            amt = amount_msat or 50000 * 1000  # guess for typical payment amount
       -            fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000
       -            # TODO revise
       -            # paying 10 more satoshis ~ waiting one more block
       -            fee_cost = fee_msat / 1000 / 10
       -            cltv_cost = cltv_expiry_delta
       -            return cltv_cost + fee_cost + 1
       -
       -        # run Dijkstra
       -        distance_from_start = defaultdict(lambda: float('inf'))
       -        distance_from_start[from_node_id] = 0
       -        prev_node = {}
       -        nodes_to_explore = queue.PriorityQueue()
       -        nodes_to_explore.put((0, from_node_id))
       -
       -        while nodes_to_explore.qsize() > 0:
       -            dist_to_cur_node, cur_node = nodes_to_explore.get()
       -            if cur_node == to_node_id:
       -                break
       -            if dist_to_cur_node != distance_from_start[cur_node]:
       -                # queue.PriorityQueue does not implement decrease_priority,
       -                # so instead of decreasing priorities, we add items again into the queue.
       -                # so there are duplicates in the queue, that we discard now:
       -                continue
       -            for edge in self.graph_of_payment_channels[cur_node]:
       -                node1 = self.channels[edge]['node_id_1']
       -                node2 = self.channels[edge]['node_id_2']
       -                neighbour, direction = (node1, 1) if node1 != cur_node else (node2, 0)
       -                alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost(edge, direction)
       -                if alt_dist_to_neighbour < distance_from_start[neighbour]:
       -                    distance_from_start[neighbour] = alt_dist_to_neighbour
       -                    prev_node[neighbour] = cur_node, edge
       -                    nodes_to_explore.put((alt_dist_to_neighbour, neighbour))
       -        else:
       -            return None  # no path found
       -
       -        # backtrack from end to start
       -        cur_node = to_node_id
       -        path = [(cur_node, None)]
       -        while cur_node != from_node_id:
       -            cur_node, edge_taken = prev_node[cur_node]
       -            path += [(cur_node, edge_taken)]
       -        path.reverse()
       -        return path
       -
            @aiosafe
            async def main_loop(self):
                self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
       t@@ -792,3 +705,165 @@ class LNWorker:
                # todo: get utxo from wallet
                # submit coro to asyncio main loop
                self.peer.open_channel()
       +
       +
       +class ChannelInfo(PrintError):
       +
       +    def __init__(self, channel_announcement_payload):
       +        self.channel_id = channel_announcement_payload['short_channel_id']
       +        self.node_id_1 = channel_announcement_payload['node_id_1']
       +        self.node_id_2 = channel_announcement_payload['node_id_2']
       +
       +        self.capacity_sat = None
       +        self.policy_node1 = None
       +        self.policy_node2 = None
       +
       +    def set_capacity(self, capacity):
       +        # TODO call this after looking up UTXO for funding txn on chain
       +        self.capacity_sat = capacity
       +
       +    def on_channel_update(self, msg_payload):
       +        assert self.channel_id == msg_payload['short_channel_id']
       +        flags = int.from_bytes(msg_payload['flags'], byteorder="big")
       +        direction = bool(flags & 1)
       +        if direction == 0:
       +            self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload)
       +        else:
       +            self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload)
       +        self.print_error('channel update', binascii.hexlify(self.channel_id), flags)
       +
       +    def get_policy_for_node(self, node_id):
       +        if node_id == self.node_id_1:
       +            return self.policy_node1
       +        elif node_id == self.node_id_2:
       +            return self.policy_node2
       +        else:
       +            raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id))
       +
       +
       +class ChannelInfoDirectedPolicy:
       +
       +    def __init__(self, channel_update_payload):
       +        self.cltv_expiry_delta           = channel_update_payload['cltv_expiry_delta']
       +        self.htlc_minimum_msat           = channel_update_payload['htlc_minimum_msat']
       +        self.fee_base_msat               = channel_update_payload['fee_base_msat']
       +        self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
       +
       +
       +class ChannelDB(PrintError):
       +
       +    def __init__(self):
       +        self._id_to_channel_info = {}
       +        self._channels_for_node = defaultdict(set)  # node -> set(short_channel_id)
       +
       +    def get_channel_info(self, channel_id):
       +        return self._id_to_channel_info.get(channel_id, None)
       +
       +    def get_channels_for_node(self, node_id):
       +        """Returns the set of channels that have node_id as one of the endpoints."""
       +        return self._channels_for_node[node_id]
       +
       +    def on_channel_announcement(self, msg_payload):
       +        short_channel_id = msg_payload['short_channel_id']
       +        self.print_error('channel announcement', binascii.hexlify(short_channel_id))
       +        channel_info = ChannelInfo(msg_payload)
       +        self._id_to_channel_info[short_channel_id] = channel_info
       +        self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
       +        self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
       +
       +    def on_channel_update(self, msg_payload):
       +        short_channel_id = msg_payload['short_channel_id']
       +        self._id_to_channel_info[short_channel_id].on_channel_update(msg_payload)
       +
       +    def remove_channel(self, short_channel_id):
       +        try:
       +            channel_info = self._id_to_channel_info[short_channel_id]
       +        except KeyError:
       +            self.print_error('cannot find channel {}'.format(short_channel_id))
       +            return
       +        self._id_to_channel_info.pop(short_channel_id, None)
       +        for node in (channel_info.node_id_1, channel_info.node_id_2):
       +            try:
       +                self._channels_for_node[node].remove(short_channel_id)
       +            except KeyError:
       +                pass
       +
       +
       +class LNPathFinder(PrintError):
       +
       +    def __init__(self, channel_db):
       +        self.channel_db = channel_db
       +
       +    def _edge_cost(self, short_channel_id, start_node, payment_amt_msat):
       +        """Heuristic cost of going through a channel.
       +        direction: 0 or 1. --- 0 means node_id_1 -> node_id_2
       +        """
       +        channel_info = self.channel_db.get_channel_info(short_channel_id)
       +        if channel_info is None:
       +            return float('inf')
       +
       +        channel_policy = channel_info.get_policy_for_node(start_node)
       +        cltv_expiry_delta           = channel_policy.cltv_expiry_delta
       +        htlc_minimum_msat           = channel_policy.htlc_minimum_msat
       +        fee_base_msat               = channel_policy.fee_base_msat
       +        fee_proportional_millionths = channel_policy.fee_proportional_millionths
       +        if payment_amt_msat is not None:
       +            if payment_amt_msat < htlc_minimum_msat:
       +                return float('inf')  # payment amount too little
       +            if channel_info.capacity_sat is not None and \
       +                    payment_amt_msat // 1000 > channel_info.capacity_sat:
       +                return float('inf')  # payment amount too large
       +        amt = payment_amt_msat or 50000 * 1000  # guess for typical payment amount
       +        fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000
       +        # TODO revise
       +        # paying 10 more satoshis ~ waiting one more block
       +        fee_cost = fee_msat / 1000 / 10
       +        cltv_cost = cltv_expiry_delta
       +        return cltv_cost + fee_cost + 1
       +
       +    @profiler
       +    def find_path_for_payment(self, from_node_id, to_node_id, amount_msat=None):
       +        """Return a path between from_node_id and to_node_id.
       +
       +        Returns a list of (node_id, short_channel_id) representing a path.
       +        To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1]
       +        """
       +        # TODO find multiple paths??
       +
       +        # run Dijkstra
       +        distance_from_start = defaultdict(lambda: float('inf'))
       +        distance_from_start[from_node_id] = 0
       +        prev_node = {}
       +        nodes_to_explore = queue.PriorityQueue()
       +        nodes_to_explore.put((0, from_node_id))
       +
       +        while nodes_to_explore.qsize() > 0:
       +            dist_to_cur_node, cur_node = nodes_to_explore.get()
       +            if cur_node == to_node_id:
       +                break
       +            if dist_to_cur_node != distance_from_start[cur_node]:
       +                # queue.PriorityQueue does not implement decrease_priority,
       +                # so instead of decreasing priorities, we add items again into the queue.
       +                # so there are duplicates in the queue, that we discard now:
       +                continue
       +            for edge_channel_id in self.channel_db.get_channels_for_node(cur_node):
       +                channel_info = self.channel_db.get_channel_info(edge_channel_id)
       +                node1, node2 = channel_info.node_id_1, channel_info.node_id_2
       +                neighbour = node2 if node1 == cur_node else node1
       +                alt_dist_to_neighbour = distance_from_start[cur_node] \
       +                                        + self._edge_cost(edge_channel_id, cur_node, amount_msat)
       +                if alt_dist_to_neighbour < distance_from_start[neighbour]:
       +                    distance_from_start[neighbour] = alt_dist_to_neighbour
       +                    prev_node[neighbour] = cur_node, edge_channel_id
       +                    nodes_to_explore.put((alt_dist_to_neighbour, neighbour))
       +        else:
       +            return None  # no path found
       +
       +        # backtrack from end to start
       +        cur_node = to_node_id
       +        path = [(cur_node, None)]
       +        while cur_node != from_node_id:
       +            cur_node, edge_taken = prev_node[cur_node]
       +            path += [(cur_node, edge_taken)]
       +        path.reverse()
       +        return path
 (DIR) diff --git a/lib/tests/test_lnbase.py b/lib/tests/test_lnbase.py
       t@@ -181,7 +181,7 @@ class Test_LNBase(unittest.TestCase):
                # local_signature = 30440220549e80b4496803cbc4a1d09d46df50109f546d43fbbf86cd90b174b1484acd5402205f12a4f995cb9bded597eabfee195a285986aa6d93ae5bb72507ebc6a4e2349e
                output_htlc_success_tx_4 = "020000000001018154ecccf11a5fb56c39654c4deb4d2296f83c69268280b94d021370c94e219704000000000000000001a00f0000000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e050047304402207e0410e45454b0978a623f36a10626ef17b27d9ad44e2760f98cfa3efb37924f0220220bd8acd43ecaa916a80bd4f919c495a2c58982ce7c8625153f8596692a801d014730440220549e80b4496803cbc4a1d09d46df50109f546d43fbbf86cd90b174b1484acd5402205f12a4f995cb9bded597eabfee195a285986aa6d93ae5bb72507ebc6a4e2349e012004040404040404040404040404040404040404040404040404040404040404048a76a91414011f7254d96b819c76986c277d115efce6f7b58763ac67210394854aa6eab5b2a8122cc726e9dded053a2184d88256816826d6231c068d4a5b7c8201208763a91418bc1a114ccf9c052d3d23e28d3b0a9d1227434288527c21030d417a46946384f88d5f3337267c5e579765875dc4daca813e21734b140639e752ae677502f801b175ac686800000000"
        
       -    def test_find_route_for_payment(self):
       +    def test_find_path_for_payment(self):
                p = Peer('', 0, 'a')
                p.on_channel_announcement({'node_id_1': 'b', 'node_id_2': 'c', 'short_channel_id': bfh('0000000000000001')})
                p.on_channel_announcement({'node_id_1': 'b', 'node_id_2': 'e', 'short_channel_id': bfh('0000000000000002')})
       t@@ -201,7 +201,7 @@ class Test_LNBase(unittest.TestCase):
                p.on_channel_update({'short_channel_id': bfh('0000000000000005'), 'flags': b'1', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999})
                p.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'flags': b'0', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999})
                p.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'flags': b'1', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150})
       -        print(p.find_route_for_payment('a', 'e', 100000))
       +        print(p.path_finder.find_path_for_payment('a', 'e', 100000))
        
            def test_key_derivation(self):
                # BOLT3, Appendix E