lnhtlc: speed-up methods for recent ctns
we maintain a set of interesting htlc_ids
This commit is contained in:
@@ -7,6 +7,9 @@ from collections import defaultdict
|
||||
import logging
|
||||
import concurrent
|
||||
from concurrent import futures
|
||||
import unittest
|
||||
|
||||
from aiorpcx import TaskGroup
|
||||
|
||||
from electrum import constants
|
||||
from electrum.network import Network
|
||||
@@ -18,13 +21,13 @@ from electrum.util import bh2u, create_and_start_event_loop
|
||||
from electrum.lnpeer import Peer
|
||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||
from electrum.lnutil import PaymentFailure, LnLocalFeatures
|
||||
from electrum.lnutil import PaymentFailure, LnLocalFeatures, HTLCOwner
|
||||
from electrum.lnchannel import channel_states, peer_states, Channel
|
||||
from electrum.lnrouter import LNPathFinder
|
||||
from electrum.channel_db import ChannelDB
|
||||
from electrum.lnworker import LNWallet, NoPathFound
|
||||
from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum.logging import console_stderr_handler
|
||||
from electrum.logging import console_stderr_handler, Logger
|
||||
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
|
||||
|
||||
from .test_lnchannel import create_test_channels
|
||||
@@ -81,8 +84,9 @@ class MockWallet:
|
||||
def is_lightning_backup(self):
|
||||
return False
|
||||
|
||||
class MockLNWallet:
|
||||
class MockLNWallet(Logger):
|
||||
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
|
||||
Logger.__init__(self)
|
||||
self.remote_keypair = remote_keypair
|
||||
self.node_keypair = local_keypair
|
||||
self.network = MockNetwork(tx_queue)
|
||||
@@ -216,9 +220,11 @@ class TestPeer(ElectrumTestCase):
|
||||
return p1, p2, w1, w2, q1, q2
|
||||
|
||||
@staticmethod
|
||||
def prepare_invoice(w2 # receiver
|
||||
):
|
||||
amount_sat = 100000
|
||||
def prepare_invoice(
|
||||
w2, # receiver
|
||||
*,
|
||||
amount_sat=100_000,
|
||||
):
|
||||
amount_btc = amount_sat/Decimal(COIN)
|
||||
payment_preimage = os.urandom(32)
|
||||
RHASH = sha256(payment_preimage)
|
||||
@@ -300,6 +306,35 @@ class TestPeer(ElectrumTestCase):
|
||||
with self.assertRaises(concurrent.futures.CancelledError):
|
||||
run(f())
|
||||
|
||||
@unittest.skip("too expensive")
|
||||
#@needs_test_with_all_chacha20_implementations
|
||||
def test_payments_stresstest(self):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL)
|
||||
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
|
||||
num_payments = 1000
|
||||
#pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)]
|
||||
pay_reqs2 = [self.prepare_invoice(w2, amount_sat=1) for i in range(num_payments)]
|
||||
max_htlcs_in_flight = asyncio.Semaphore(5)
|
||||
async def single_payment(pay_req):
|
||||
async with max_htlcs_in_flight:
|
||||
await w1._pay(pay_req)
|
||||
async def many_payments():
|
||||
async with TaskGroup() as group:
|
||||
for pay_req in pay_reqs2:
|
||||
await group.spawn(single_payment(pay_req))
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
async def f():
|
||||
await gath
|
||||
with self.assertRaises(concurrent.futures.CancelledError):
|
||||
run(f())
|
||||
self.assertEqual(alice_init_balance_msat - num_payments * 1000, alice_channel.balance(HTLCOwner.LOCAL))
|
||||
self.assertEqual(alice_init_balance_msat - num_payments * 1000, bob_channel.balance(HTLCOwner.REMOTE))
|
||||
self.assertEqual(bob_init_balance_msat + num_payments * 1000, bob_channel.balance(HTLCOwner.LOCAL))
|
||||
self.assertEqual(bob_init_balance_msat + num_payments * 1000, alice_channel.balance(HTLCOwner.REMOTE))
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
def test_close(self):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
@@ -352,5 +387,6 @@ class TestPeer(ElectrumTestCase):
|
||||
with self.assertRaises(PaymentFailure):
|
||||
run(f())
|
||||
|
||||
|
||||
def run(coro):
|
||||
return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()
|
||||
|
||||
Reference in New Issue
Block a user