tlnbase: add RevocationStore test, remove unnecessary lnd helper functions - electrum - Electrum Bitcoin wallet
 (HTM) git clone https://git.parazyd.org/electrum
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) Submodules
       ---
 (DIR) commit 3a20c8ce000aa8f270d407da9d2ad03f7d980d4b
 (DIR) parent cf82150aab0b288a2c24f27812ba88a0d7a37793
 (HTM) Author: Janus <ysangkok@gmail.com>
       Date:   Fri, 11 May 2018 15:43:47 +0200
       
       lnbase: add RevocationStore test, remove unnecessary lnd helper functions
       
       Diffstat:
         M lib/lnbase.py                       |      47 ++++---------------------------
         M lib/tests/test_lnbase.py            |      49 ++++++++++---------------------
       
       2 files changed, 22 insertions(+), 74 deletions(-)
       ---
 (DIR) diff --git a/lib/lnbase.py b/lib/lnbase.py
       t@@ -320,30 +320,18 @@ def derive_blinded_pubkey(basepoint, per_commitment_point):
            return point_to_ser(k1 + k2)
        
        def shachain_derive(element, toIndex):
       -    """ compact per-commitment secret storage, taken from lnd """
       -    fromIndex = element.index
       -    positions = derive_bit_transformations(fromIndex, toIndex)
       -    buf = bytearray(element.secret)
       -    for position in positions:
       -        byteNumber = position // 8
       -        bitNumber = position % 8
       -        buf[byteNumber] ^= 1 << bitNumber
       -        h = bitcoin.sha256(buf)
       -        buf = bytearray(h)
       +    return ShachainElement(get_per_commitment_secret_from_seed(element.secret, toIndex, count_trailing_zeros(element.index)), toIndex)
        
       -    return ShachainElement(index=toIndex, secret=bytes(buf))
        
       -
       -def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 47) -> bytes:
       +def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes:
            """Generate per commitment secret."""
            per_commitment_secret = bytearray(seed)
       -    for bitindex in range(bits, -1, -1):
       +    for bitindex in range(bits - 1, -1, -1):
                mask = 1 << bitindex
                if i & mask:
                    per_commitment_secret[bitindex // 8] ^= 1 << (bitindex % 8)
                    per_commitment_secret = bytearray(bitcoin.sha256(per_commitment_secret))
            bajts = bytes(per_commitment_secret)
       -    assert shachain_derive(ShachainElement(index=0, secret=seed), i).secret == bajts
            return bajts
        
        
       t@@ -1569,44 +1557,21 @@ def count_trailing_zeros(index):
                return 48
        
        ShachainElement = namedtuple("ShachainElement", ["secret", "index"])
       +ShachainElement.__str__ = lambda self: "ShachainElement(" + bh2u(self.secret) + "," + str(self.index) + ")"
        
        class RevocationStore:
            """ taken from lnd """
            def __init__(self):
       -        self.buckets = {}
       +        self.buckets = [None] * 48
                self.index = 2**48 - 1
       -    def set_index(self, index):
       -        self.index = index
            def add_next_entry(self, hsh):
                new_element = ShachainElement(index=self.index, secret=hsh)
                bucket = count_trailing_zeros(self.index)
                for i in range(0, bucket):
       -            if i not in self.buckets: return
                    this_bucket = self.buckets[i]
                    e = shachain_derive(new_element, this_bucket.index)
        
                    if e != this_bucket:
       -                return "hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index)
       +                raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
                self.buckets[bucket] = new_element
                self.index -= 1
       -        return
       -
       -def get_prefix(index, position):
       -    """ taken from lnd """
       -    mask = (1<<64)-1 - ((1<<position)-1)
       -    return index & mask
       -
       -def derive_bit_transformations(fromm, to):
       -    """ taken from lnd """
       -    positions = []
       -    if fromm == to: return positions
       -
       -    zeros = count_trailing_zeros(fromm)
       -    if fromm > (1<<64)-1: raise Exception("fromm too big")
       -    if fromm != get_prefix(to, zeros):
       -        raise Exception("prefixes are different, indexes are not derivable")
       -
       -    for position in range(zeros, -1, -1):
       -        if to >> position & 1 == 1:
       -            positions.append(position)
       -    return positions
 (DIR) diff --git a/lib/tests/test_lnbase.py b/lib/tests/test_lnbase.py
       t@@ -7,7 +7,7 @@ from lib.lnbase import make_commitment, get_obscured_ctn, Peer, make_offered_htl
        from lib.lnbase import secret_to_pubkey, derive_pubkey, derive_privkey, derive_blinded_pubkey, overall_weight
        from lib.lnbase import make_htlc_tx_output, make_htlc_tx_inputs, get_per_commitment_secret_from_seed
        from lib.lnbase import make_htlc_tx_witness, OnionHopsDataSingle, new_onion_packet, OnionPerHop
       -from lib.lnbase import RevocationStore, derive_bit_transformations
       +from lib.lnbase import RevocationStore, ShachainElement, shachain_derive
        from lib.transaction import Transaction
        from lib import bitcoin
        import ecdsa.ellipticcurve
       t@@ -385,27 +385,6 @@ class Test_LNBase(unittest.TestCase):
                    self.assertEqual(hops_data[i].per_hop.to_bytes(), processed_packet.hop_data.per_hop.to_bytes())
                    packet = processed_packet.next_packet
        
       -    def test_shachain_producer(self):
       -        from collections import namedtuple
       -        tests = []
       -        DeriveTest = namedtuple("DeriveTest", ["name", "fromm", "to", "position", "should_fail"])
       -        tests.append(DeriveTest("zero 'from' 'to'", 0, 0, [], False))
       -        tests.append(DeriveTest("same indexes #1", 0b100, 0b100, [], False))
       -        tests.append(DeriveTest("same indexes #2", 0b1, 0b0, None, True))
       -        tests.append(DeriveTest("test seed 'from'", 0b0, 0b10, [1], False))
       -        tests.append(DeriveTest("not the same indexes", 0b1100, 0b0100, None, True))
       -        tests.append(DeriveTest("'from' index greater than 'to' index", 0b1010, 0b1000, None, True))
       -        tests.append(DeriveTest("zero number trailing zeros", 0b1, 0b1, [], False))
       -        for test in tests:
       -            try:
       -                pos = derive_bit_transformations(test.fromm, test.to)
       -                if test.should_fail:
       -                    raise Exception("test did not fail")
       -                self.assertEqual(test.position, pos)
       -            except:
       -                if not test.should_fail:
       -                    raise Exception(test.name)
       -
            def test_shachain_store(self):
                tests = [
                    {
       t@@ -792,23 +771,27 @@ class Test_LNBase(unittest.TestCase):
                ]
        
                for test in tests:
       -            old_receiver = None
       -            receiver = None
       +            receiver = RevocationStore()
                    for insert in test["inserts"]:
       -                old_receiver = receiver
       -                receiver = RevocationStore()
                        secret = bytes.fromhex(insert["secret"])
       -                if not insert["successful"]:
       -                    receiver.set_index(old_receiver.index)
       -                    receiver.buckets = old_receiver.buckets
       -                secret = secret[::-1]
        
       -                err = receiver.add_next_entry(secret)
       -                if isinstance(err, str):
       +                try:
       +                    receiver.add_next_entry(secret)
       +                except Exception as e:
                            if insert["successful"]:
       -                        raise Exception("Failed ({}): error was received but it shouldn't: {}".format(test["name"], err))
       +                        raise Exception("Failed ({}): error was received but it shouldn't: {}".format(test["name"], e))
                        else:
                            if not insert["successful"]:
                                raise Exception("Failed ({}): error wasn't received".format(test["name"]))
        
                    print("Passed ({})".format(test["name"]))
       +
       +    def test_shachain_produce_consume(self):
       +        seed = bitcoin.sha256(b"shachaintest")
       +        consumer = RevocationStore()
       +        for i in range(10000):
       +            secret = shachain_derive(ShachainElement(seed, 0), 2**48 - i - 1).secret
       +            try:
       +                consumer.add_next_entry(secret)
       +            except Exception as e:
       +                raise Exception("iteration " + str(i) + ": " + str(e))