test_lnpeer: some clean-up, make it easier to add "num_node>2" tests
This commit is contained in:
@@ -44,12 +44,12 @@ class PRNG:
|
|||||||
self.sha = sha256(seed)
|
self.sha = sha256(seed)
|
||||||
self.pool = bytearray()
|
self.pool = bytearray()
|
||||||
|
|
||||||
def get_bytes(self, n):
|
def get_bytes(self, n: int) -> bytes:
|
||||||
while len(self.pool) < n:
|
while len(self.pool) < n:
|
||||||
self.pool.extend(self.sha)
|
self.pool.extend(self.sha)
|
||||||
self.sha = sha256(self.sha)
|
self.sha = sha256(self.sha)
|
||||||
result, self.pool = self.pool[:n], self.pool[n:]
|
result, self.pool = self.pool[:n], self.pool[n:]
|
||||||
return result
|
return bytes(result)
|
||||||
|
|
||||||
def randint(self, start, end):
|
def randint(self, start, end):
|
||||||
# Returns random integer in [start, end)
|
# Returns random integer in [start, end)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from electrum.ecc import sig_string_from_der_sig
|
|||||||
from electrum.logging import console_stderr_handler
|
from electrum.logging import console_stderr_handler
|
||||||
from electrum.lnchannel import ChannelState
|
from electrum.lnchannel import ChannelState
|
||||||
from electrum.json_db import StoredDict
|
from electrum.json_db import StoredDict
|
||||||
|
from electrum.coinchooser import PRNG
|
||||||
|
|
||||||
from . import ElectrumTestCase
|
from . import ElectrumTestCase
|
||||||
|
|
||||||
@@ -110,8 +111,13 @@ def bip32(sequence):
|
|||||||
assert type(k) is bytes
|
assert type(k) is bytes
|
||||||
return k
|
return k
|
||||||
|
|
||||||
def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None):
|
def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
|
||||||
funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii")
|
alice_name="alice", bob_name="bob",
|
||||||
|
alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None):
|
||||||
|
if random_seed is None: # needed for deterministic randomness
|
||||||
|
random_seed = os.urandom(32)
|
||||||
|
random_gen = PRNG(random_seed)
|
||||||
|
funding_txid = binascii.hexlify(random_gen.get_bytes(32)).decode("ascii")
|
||||||
funding_index = 0
|
funding_index = 0
|
||||||
funding_sat = ((local_msat + remote_msat) // 1000) if local_msat is not None and remote_msat is not None else (bitcoin.COIN * 10)
|
funding_sat = ((local_msat + remote_msat) // 1000) if local_msat is not None and remote_msat is not None else (bitcoin.COIN * 10)
|
||||||
local_amount = local_msat if local_msat is not None else (funding_sat * 1000 // 2)
|
local_amount = local_msat if local_msat is not None else (funding_sat * 1000 // 2)
|
||||||
@@ -123,20 +129,20 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None):
|
|||||||
alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
|
alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
|
||||||
bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
|
bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
|
||||||
|
|
||||||
alice_seed = b"\x01" * 32
|
alice_seed = random_gen.get_bytes(32)
|
||||||
bob_seed = b"\x02" * 32
|
bob_seed = random_gen.get_bytes(32)
|
||||||
|
|
||||||
alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
|
alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
|
||||||
bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
|
bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
|
||||||
|
|
||||||
alice, bob = (
|
alice, bob = (
|
||||||
lnchannel.Channel(
|
lnchannel.Channel(
|
||||||
create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
|
create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
|
||||||
name="alice",
|
name=bob_name,
|
||||||
initial_feerate=feerate),
|
initial_feerate=feerate),
|
||||||
lnchannel.Channel(
|
lnchannel.Channel(
|
||||||
create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
|
create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
|
||||||
name="bob",
|
name=alice_name,
|
||||||
initial_feerate=feerate)
|
initial_feerate=feerate)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import logging
|
|||||||
import concurrent
|
import concurrent
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from aiorpcx import TaskGroup
|
from aiorpcx import TaskGroup
|
||||||
|
|
||||||
@@ -96,21 +97,23 @@ class MockWallet:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||||
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
|
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
|
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
|
||||||
self.remote_keypair = remote_keypair
|
|
||||||
self.node_keypair = local_keypair
|
self.node_keypair = local_keypair
|
||||||
self.network = MockNetwork(tx_queue)
|
self.network = MockNetwork(tx_queue)
|
||||||
self._channels = {chan.channel_id: chan}
|
self.channel_db = self.network.channel_db
|
||||||
|
self._channels = {chan.channel_id: chan
|
||||||
|
for chan in chans}
|
||||||
self.payments = {}
|
self.payments = {}
|
||||||
self.logs = defaultdict(list)
|
self.logs = defaultdict(list)
|
||||||
self.wallet = MockWallet()
|
self.wallet = MockWallet()
|
||||||
self.features = LnFeatures(0)
|
self.features = LnFeatures(0)
|
||||||
self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
||||||
self.pending_payments = defaultdict(asyncio.Future)
|
self.pending_payments = defaultdict(asyncio.Future)
|
||||||
chan.lnworker = self
|
for chan in chans:
|
||||||
chan.node_id = remote_keypair.pubkey
|
chan.lnworker = self
|
||||||
|
self._peers = {} # bytes -> Peer
|
||||||
# used in tests
|
# used in tests
|
||||||
self.enable_htlc_settle = asyncio.Event()
|
self.enable_htlc_settle = asyncio.Event()
|
||||||
self.enable_htlc_settle.set()
|
self.enable_htlc_settle.set()
|
||||||
@@ -130,13 +133,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
|||||||
def peers(self):
|
def peers(self):
|
||||||
return self._peers
|
return self._peers
|
||||||
|
|
||||||
@property
|
|
||||||
def _peers(self):
|
|
||||||
return {self.remote_keypair.pubkey: self.peer}
|
|
||||||
|
|
||||||
def channels_for_peer(self, pubkey):
|
|
||||||
return self._channels
|
|
||||||
|
|
||||||
def get_channel_by_short_id(self, short_channel_id):
|
def get_channel_by_short_id(self, short_channel_id):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
for chan in self._channels.values():
|
for chan in self._channels.values():
|
||||||
@@ -171,6 +167,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
|||||||
get_first_timestamp = lambda self: 0
|
get_first_timestamp = lambda self: 0
|
||||||
on_peer_successfully_established = LNWallet.on_peer_successfully_established
|
on_peer_successfully_established = LNWallet.on_peer_successfully_established
|
||||||
get_channel_by_id = LNWallet.get_channel_by_id
|
get_channel_by_id = LNWallet.get_channel_by_id
|
||||||
|
channels_for_peer = LNWallet.channels_for_peer
|
||||||
|
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
|
||||||
|
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
|
||||||
|
|
||||||
|
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
@@ -206,12 +205,16 @@ class PutIntoOthersQueueTransport(MockTransport):
|
|||||||
self.other_mock_transport.queue.put_nowait(data)
|
self.other_mock_transport.queue.put_nowait(data)
|
||||||
|
|
||||||
def transport_pair(k1, k2, name1, name2):
|
def transport_pair(k1, k2, name1, name2):
|
||||||
t1 = PutIntoOthersQueueTransport(k1, name1)
|
t1 = PutIntoOthersQueueTransport(k1, name2)
|
||||||
t2 = PutIntoOthersQueueTransport(k2, name2)
|
t2 = PutIntoOthersQueueTransport(k2, name1)
|
||||||
t1.other_mock_transport = t2
|
t1.other_mock_transport = t2
|
||||||
t2.other_mock_transport = t1
|
t2.other_mock_transport = t1
|
||||||
return t1, t2
|
return t1, t2
|
||||||
|
|
||||||
|
|
||||||
|
class PaymentDone(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
class TestPeer(ElectrumTestCase):
|
class TestPeer(ElectrumTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -230,14 +233,16 @@ class TestPeer(ElectrumTestCase):
|
|||||||
|
|
||||||
def prepare_peers(self, alice_channel, bob_channel):
|
def prepare_peers(self, alice_channel, bob_channel):
|
||||||
k1, k2 = keypair(), keypair()
|
k1, k2 = keypair(), keypair()
|
||||||
t1, t2 = transport_pair(k2, k1, alice_channel.name, bob_channel.name)
|
alice_channel.node_id = k2.pubkey
|
||||||
|
bob_channel.node_id = k1.pubkey
|
||||||
|
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
|
||||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||||
w1 = MockLNWallet(k1, k2, alice_channel, tx_queue=q1)
|
w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1)
|
||||||
w2 = MockLNWallet(k2, k1, bob_channel, tx_queue=q2)
|
w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2)
|
||||||
p1 = Peer(w1, k1.pubkey, t1)
|
p1 = Peer(w1, k2.pubkey, t1)
|
||||||
p2 = Peer(w2, k2.pubkey, t2)
|
p2 = Peer(w2, k1.pubkey, t2)
|
||||||
w1.peer = p1
|
w1._peers[p1.pubkey] = p1
|
||||||
w2.peer = p2
|
w2._peers[p2.pubkey] = p2
|
||||||
# mark_open won't work if state is already OPEN.
|
# mark_open won't work if state is already OPEN.
|
||||||
# so set it to FUNDED
|
# so set it to FUNDED
|
||||||
alice_channel._state = ChannelState.FUNDED
|
alice_channel._state = ChannelState.FUNDED
|
||||||
@@ -248,10 +253,11 @@ class TestPeer(ElectrumTestCase):
|
|||||||
return p1, p2, w1, w2, q1, q2
|
return p1, p2, w1, w2, q1, q2
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_invoice(
|
async def prepare_invoice(
|
||||||
w2, # receiver
|
w2: MockLNWallet, # receiver
|
||||||
*,
|
*,
|
||||||
amount_sat=100_000,
|
amount_sat=100_000,
|
||||||
|
include_routing_hints=False,
|
||||||
):
|
):
|
||||||
amount_btc = amount_sat/Decimal(COIN)
|
amount_btc = amount_sat/Decimal(COIN)
|
||||||
payment_preimage = os.urandom(32)
|
payment_preimage = os.urandom(32)
|
||||||
@@ -259,12 +265,16 @@ class TestPeer(ElectrumTestCase):
|
|||||||
info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID)
|
info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID)
|
||||||
w2.save_preimage(RHASH, payment_preimage)
|
w2.save_preimage(RHASH, payment_preimage)
|
||||||
w2.save_payment_info(info)
|
w2.save_payment_info(info)
|
||||||
|
if include_routing_hints:
|
||||||
|
routing_hints = await w2._calc_routing_hints_for_invoice(amount_sat)
|
||||||
|
else:
|
||||||
|
routing_hints = []
|
||||||
lnaddr = LnAddr(
|
lnaddr = LnAddr(
|
||||||
paymenthash=RHASH,
|
paymenthash=RHASH,
|
||||||
amount=amount_btc,
|
amount=amount_btc,
|
||||||
tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
|
tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
|
||||||
('d', 'coffee')
|
('d', 'coffee')
|
||||||
])
|
] + routing_hints)
|
||||||
return lnencode(lnaddr, w2.node_keypair.privkey)
|
return lnencode(lnaddr, w2.node_keypair.privkey)
|
||||||
|
|
||||||
def test_reestablish(self):
|
def test_reestablish(self):
|
||||||
@@ -287,10 +297,11 @@ class TestPeer(ElectrumTestCase):
|
|||||||
|
|
||||||
@needs_test_with_all_chacha20_implementations
|
@needs_test_with_all_chacha20_implementations
|
||||||
def test_reestablish_with_old_state(self):
|
def test_reestablish_with_old_state(self):
|
||||||
alice_channel, bob_channel = create_test_channels()
|
random_seed = os.urandom(32)
|
||||||
alice_channel_0, bob_channel_0 = create_test_channels() # these are identical
|
alice_channel, bob_channel = create_test_channels(random_seed=random_seed)
|
||||||
|
alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical
|
||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
pay_req = self.prepare_invoice(w2)
|
pay_req = run(self.prepare_invoice(w2))
|
||||||
async def pay():
|
async def pay():
|
||||||
result, log = await w1._pay(pay_req)
|
result, log = await w1._pay(pay_req)
|
||||||
self.assertEqual(result, True)
|
self.assertEqual(result, True)
|
||||||
@@ -323,15 +334,20 @@ class TestPeer(ElectrumTestCase):
|
|||||||
def test_payment(self):
|
def test_payment(self):
|
||||||
alice_channel, bob_channel = create_test_channels()
|
alice_channel, bob_channel = create_test_channels()
|
||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
pay_req = self.prepare_invoice(w2)
|
async def pay(pay_req):
|
||||||
async def pay():
|
|
||||||
result, log = await w1._pay(pay_req)
|
result, log = await w1._pay(pay_req)
|
||||||
self.assertTrue(result)
|
self.assertTrue(result)
|
||||||
gath.cancel()
|
raise PaymentDone()
|
||||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
|
||||||
async def f():
|
async def f():
|
||||||
await gath
|
async with TaskGroup() as group:
|
||||||
with self.assertRaises(concurrent.futures.CancelledError):
|
await group.spawn(p1._message_loop())
|
||||||
|
await group.spawn(p1.htlc_switch())
|
||||||
|
await group.spawn(p2._message_loop())
|
||||||
|
await group.spawn(p2.htlc_switch())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
pay_req = await self.prepare_invoice(w2)
|
||||||
|
await group.spawn(pay(pay_req))
|
||||||
|
with self.assertRaises(PaymentDone):
|
||||||
run(f())
|
run(f())
|
||||||
|
|
||||||
#@unittest.skip("too expensive")
|
#@unittest.skip("too expensive")
|
||||||
@@ -343,15 +359,17 @@ class TestPeer(ElectrumTestCase):
|
|||||||
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
|
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
|
||||||
num_payments = 50
|
num_payments = 50
|
||||||
payment_value_sat = 10000 # make it large enough so that there are actually HTLCs on the ctx
|
payment_value_sat = 10000 # make it large enough so that there are actually HTLCs on the ctx
|
||||||
#pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)]
|
|
||||||
pay_reqs2 = [self.prepare_invoice(w2, amount_sat=payment_value_sat) for i in range(num_payments)]
|
|
||||||
max_htlcs_in_flight = asyncio.Semaphore(5)
|
max_htlcs_in_flight = asyncio.Semaphore(5)
|
||||||
async def single_payment(pay_req):
|
async def single_payment(pay_req):
|
||||||
async with max_htlcs_in_flight:
|
async with max_htlcs_in_flight:
|
||||||
await w1._pay(pay_req)
|
await w1._pay(pay_req)
|
||||||
async def many_payments():
|
async def many_payments():
|
||||||
async with TaskGroup() as group:
|
async with TaskGroup() as group:
|
||||||
for pay_req in pay_reqs2:
|
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_sat=payment_value_sat))
|
||||||
|
for i in range(num_payments)]
|
||||||
|
async with TaskGroup() as group:
|
||||||
|
for pay_req_task in pay_reqs_tasks:
|
||||||
|
pay_req = pay_req_task.result()
|
||||||
await group.spawn(single_payment(pay_req))
|
await group.spawn(single_payment(pay_req))
|
||||||
gath.cancel()
|
gath.cancel()
|
||||||
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||||
@@ -373,7 +391,7 @@ class TestPeer(ElectrumTestCase):
|
|||||||
w1.network.config.set_key('fee_per_kb', 5000)
|
w1.network.config.set_key('fee_per_kb', 5000)
|
||||||
w2.network.config.set_key('fee_per_kb', 1000)
|
w2.network.config.set_key('fee_per_kb', 1000)
|
||||||
w2.enable_htlc_settle.clear()
|
w2.enable_htlc_settle.clear()
|
||||||
pay_req = self.prepare_invoice(w2)
|
pay_req = run(self.prepare_invoice(w2))
|
||||||
lnaddr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
|
lnaddr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
|
||||||
async def pay():
|
async def pay():
|
||||||
await asyncio.wait_for(p1.initialized, 1)
|
await asyncio.wait_for(p1.initialized, 1)
|
||||||
@@ -401,7 +419,7 @@ class TestPeer(ElectrumTestCase):
|
|||||||
def test_channel_usage_after_closing(self):
|
def test_channel_usage_after_closing(self):
|
||||||
alice_channel, bob_channel = create_test_channels()
|
alice_channel, bob_channel = create_test_channels()
|
||||||
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
pay_req = self.prepare_invoice(w2)
|
pay_req = run(self.prepare_invoice(w2))
|
||||||
|
|
||||||
addr = w1._check_invoice(pay_req)
|
addr = w1._check_invoice(pay_req)
|
||||||
route = w1._create_route_from_invoice(decoded_invoice=addr)
|
route = w1._create_route_from_invoice(decoded_invoice=addr)
|
||||||
|
|||||||
Reference in New Issue
Block a user