1
0

tests: lnpeer: mostly unify prepare_peers and prepare_graph

This commit is contained in:
SomberNight
2025-12-18 18:57:55 +00:00
parent 024f9b988d
commit ec65c53de3

View File

@@ -12,7 +12,7 @@ import concurrent
from concurrent import futures
from functools import lru_cache
from unittest import mock
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence, Mapping
from types import MappingProxyType
import time
import statistics
@@ -323,6 +323,24 @@ depleted_channel = {
}
_GRAPH_DEFINITIONS = {
# A -- B
'single_chan' : {
'alice': {
'channels': {
'bob': {
'local_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
'remote_balance_msat': 10 * bitcoin.COIN * 1000 // 2,
},
},
},
'bob': {
},
},
# A
# high fee / \ low fee
# B C
# high fee \ / low fee
# D
'square_graph': {
'alice': {
'channels': {
@@ -353,6 +371,7 @@ _GRAPH_DEFINITIONS = {
'dave': {
},
},
# A -- B -- C -- D -- E
'line_graph': {
'alice': {
'channels': {
@@ -515,7 +534,12 @@ class TestPeer(ElectrumTestCase):
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
w2.register_hold_invoice(payment_hash, cb)
def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph:
def prepare_chans_and_peers_in_graph(
self,
graph_definition,
*,
channels: Mapping[Tuple[str, str], Channel] = None,
) -> Graph:
workers = {} # type: Dict[str, MockLNWallet]
# create workers
@@ -524,32 +548,42 @@ class TestPeer(ElectrumTestCase):
self._lnworkers_created.extend(list(workers.values()))
keys = {name: w.node_keypair for name, w in workers.items()}
channels = {} # type: Dict[Tuple[str, str], Channel]
if channels is None:
channels = {} # type: Dict[Tuple[str, str], Channel]
transports = {}
peers = {}
# create channels
for a, definition in graph_definition.items():
for b, channel_def in definition.get('channels', {}).items():
channel_ab, channel_ba = create_test_channels(
alice_name=a,
bob_name=b,
alice_pubkey=keys[a].pubkey,
bob_pubkey=keys[b].pubkey,
local_msat=channel_def['local_balance_msat'],
remote_msat=channel_def['remote_balance_msat'],
anchor_outputs=self.TEST_ANCHOR_CHANNELS
)
channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
if ((a, b) in channels) or ((b, a) in channels):
# if either chan direction is present, both must be present
channel_ab = channels[(a, b)]
channel_ba = channels[(b, a)]
else: # create new chans now
channel_ab, channel_ba = create_test_channels(
alice_name=a,
bob_name=b,
alice_pubkey=keys[a].pubkey,
bob_pubkey=keys[b].pubkey,
local_msat=channel_def['local_balance_msat'],
remote_msat=channel_def['remote_balance_msat'],
anchor_outputs=self.TEST_ANCHOR_CHANNELS
)
channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
workers[a]._add_channel(channel_ab)
workers[b]._add_channel(channel_ba)
transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name)
transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba
# set fees
channel_ab.forwarding_fee_proportional_millionths = channel_def['local_fee_rate_millionths']
channel_ab.forwarding_fee_base_msat = channel_def['local_base_fee_msat']
channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
if 'local_fee_rate_millionths' in channel_def:
channel_ab.forwarding_fee_proportional_millionths = channel_def['local_fee_rate_millionths']
if 'local_base_fee_msat' in channel_def:
channel_ab.forwarding_fee_base_msat = channel_def['local_base_fee_msat']
if 'remote_fee_rate_millionths' in channel_def:
channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
if 'remote_base_fee_msat' in channel_def:
channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
# create peers
for ab in channels.keys():
@@ -630,34 +664,27 @@ class TestPeerDirect(TestPeer):
def prepare_peers(
self, alice_channel: Channel, bob_channel: Channel,
):
w1 = MockLNWallet(name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w2 = MockLNWallet(name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
k1 = w1.node_keypair
k2 = w2.node_keypair
alice_channel.node_id = k2.pubkey
bob_channel.node_id = k1.pubkey
alice_channel.storage['node_id'] = alice_channel.node_id
bob_channel.storage['node_id'] = bob_channel.node_id
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
w1._add_channel(alice_channel)
w2._add_channel(bob_channel)
self._lnworkers_created.extend([w1, w2])
p1 = PeerInTests(w1, k2.pubkey, t1)
p2 = PeerInTests(w2, k1.pubkey, t2)
w1.lnpeermgr._peers[p1.pubkey] = p1
w2.lnpeermgr._peers[p2.pubkey] = p2
# mark_open won't work if state is already OPEN.
# so set it to FUNDED
alice_channel._state = ChannelState.FUNDED
bob_channel._state = ChannelState.FUNDED
# this populates the channel graph:
p1.mark_open(alice_channel)
p2.mark_open(bob_channel)
graph = self.prepare_chans_and_peers_in_graph(
self.GRAPH_DEFINITIONS['single_chan'],
channels={('alice', 'bob'): alice_channel, ('bob', 'alice'): bob_channel},
)
c1, c2 = graph.channels.values()
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
# FIXME xxxxx horrible ugly hack:
c1.node_id = w2.node_keypair.pubkey
c2.node_id = w1.node_keypair.pubkey
c1.storage['node_id'] = c1.node_id
c2.storage['node_id'] = c2.node_id
return p1, p2, w1, w2
async def test_reestablish(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
alice_channel, bob_channel = graph.channels.values()
for chan in (alice_channel, bob_channel):
chan.peer_state = PeerState.DISCONNECTED
async def reestablish():
@@ -852,8 +879,9 @@ class TestPeerDirect(TestPeer):
test_bundle_timeout=False
):
"""Alice pays Bob a single HTLC via direct channel."""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
results = {}
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
@@ -938,8 +966,9 @@ class TestPeerDirect(TestPeer):
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True)
async def test_check_invoice_before_payment(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
async def try_paying_some_invoices():
# feature bits: unknown even fbit
invoice_features = w2.features.for_invoice() | (1 << 990) # add undefined even fbit
@@ -975,8 +1004,9 @@ class TestPeerDirect(TestPeer):
rejected immediately upon receiving them.
"""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
@@ -1018,8 +1048,9 @@ class TestPeerDirect(TestPeer):
async def test_reject_payment_for_expired_invoice(self):
"""Tests that new htlcs paying an invoice that has already been expired will get rejected."""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
# create lightning invoice in the past, so it is expired
with mock.patch('time.time', return_value=int(time.time()) - 10000):
@@ -1061,8 +1092,9 @@ class TestPeerDirect(TestPeer):
async def test_reject_mpp_for_non_mpp_invoice(self):
"""Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice"""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
w1.config.TEST_FORCE_MPP = True # force alice to send mpp
if test_trampoline:
@@ -1101,8 +1133,9 @@ class TestPeerDirect(TestPeer):
async def test_reject_multiple_payments_of_same_invoice(self):
"""Tests that new htlcs paying an invoice that has already been paid will get rejected."""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
lnaddr, _pay_req = self.prepare_invoice(w2)
@@ -1145,8 +1178,10 @@ class TestPeerDirect(TestPeer):
before sending 'commitment_signed'. Neither party should fulfill
the respective HTLCs until those are irrevocably committed to.
"""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
async def pay():
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
@@ -1220,8 +1255,10 @@ class TestPeerDirect(TestPeer):
#@unittest.skip("too expensive")
async def test_payments_stresstest(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL)
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
num_payments = 50
@@ -1252,8 +1289,10 @@ class TestPeerDirect(TestPeer):
# - Alice sends htlc1: 0.1 BTC, H1, S1 (total_msat=1 BTC)
# - Alice sends htlc2: 0.9 BTC, H2, S1 (total_msat=1 BTC)
# - Bob(victim) reveals preimage for H1 and fulfills htlc1 (fails other)
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED))
@@ -1325,8 +1364,10 @@ class TestPeerDirect(TestPeer):
# - Alice sends htlc1: 0.1 BTC (total_msat=0.2 BTC)
# - Alice sends htlc2: 0.1 BTC (total_msat=1 BTC)
# - Bob(victim) reveals preimage and fulfills htlc2 (fails other)
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
@@ -1392,8 +1433,10 @@ class TestPeerDirect(TestPeer):
"""Alice gets two htlcs as part of a mpp, one has a cltv too close to expiry and will get failed.
Test that the other htlc won't get settled if the mpp isn't complete anymore after failing the other htlc.
"""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
async def pay():
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
@@ -1477,8 +1520,10 @@ class TestPeerDirect(TestPeer):
and the sender gets a second chance to pay the same invoice.
"""
async def run_test(test_trampoline: bool):
alice_channel, bob_channel = create_test_channels()
alice_peer, bob_peer, alice_wallet, bob_wallet = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
alice_peer, bob_peer = graph.peers.values()
alice_wallet, bob_wallet = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
bob_wallet.features |= LnFeatures.BASIC_MPP_OPT
lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000)
@@ -1591,8 +1636,10 @@ class TestPeerDirect(TestPeer):
# ))
async def _test_shutdown(self, alice_fee, bob_fee, alice_fee_range=None, bob_fee_range=None):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
w1.network.config.TEST_SHUTDOWN_FEE = alice_fee
w2.network.config.TEST_SHUTDOWN_FEE = bob_fee
if alice_fee_range is not None:
@@ -1628,8 +1675,9 @@ class TestPeerDirect(TestPeer):
await gath
async def test_warning(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
alice_channel, bob_channel = graph.channels.values()
async def action():
await util.wait_for2(p1.initialized, 1)
@@ -1640,8 +1688,9 @@ class TestPeerDirect(TestPeer):
await gath
async def test_error(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
alice_channel, bob_channel = graph.channels.values()
async def action():
await util.wait_for2(p1.initialized, 1)
@@ -1730,8 +1779,10 @@ class TestPeerDirect(TestPeer):
self.assertEqual(1, len(closing_tx.get_output_idxs_from_address(bob_uss_addr)))
async def test_channel_usage_after_closing(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
alice_channel, bob_channel = graph.channels.values()
lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_bolt11_invoice(pay_req.lightning_invoice)
@@ -1771,8 +1822,8 @@ class TestPeerDirect(TestPeer):
await f()
async def test_sending_weird_messages_that_should_be_ignored(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
async def send_weird_messages():
await util.wait_for2(p1.initialized, 1)
@@ -1802,8 +1853,8 @@ class TestPeerDirect(TestPeer):
await f()
async def test_sending_weird_messages__unknown_even_type(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
async def send_weird_messages():
await util.wait_for2(p1.initialized, 1)
@@ -1831,8 +1882,8 @@ class TestPeerDirect(TestPeer):
self.assertTrue(isinstance(failing_task.exception().__cause__, lnmsg.UnknownMandatoryMsgType))
async def test_sending_weird_messages__known_msg_with_insufficient_length(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
async def send_weird_messages():
await util.wait_for2(p1.initialized, 1)
@@ -1871,8 +1922,9 @@ class TestPeerDirect(TestPeer):
which behave differently and use the persisted `LNWallet.dont_expire_htlcs` dict.
"""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
alice_p, bob_p, alice_w, bob_w = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
alice_p, bob_p = graph.peers.values()
alice_w, bob_w = graph.workers.values()
lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150)
del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle
@@ -2011,8 +2063,9 @@ class TestPeerDirect(TestPeer):
preimage is available.
"""
async def run_test(test_trampoline, test_expiry):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel)
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan'])
p1, p2 = graph.peers.values()
w1, w2 = graph.workers.values()
if test_trampoline:
await self._activate_trampoline(w1)
# declare bob as trampoline node