diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 7be4673a1..17a1282dd 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1915,20 +1915,24 @@ class LNWallet(LNWorker): else: return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey - def suggest_splits( + def suggest_payment_splits( self, *, amount_msat: int, final_total_msat: int, my_active_channels: Sequence[Channel], invoice_features: LnFeatures, - r_tags, + r_tags: Sequence[Sequence[Sequence[Any]]], + receiver_pubkey: bytes, ) -> List['SplitConfigRating']: channels_with_funds = { (chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL)) for chan in my_active_channels } - self.logger.info(f"channels_with_funds: {channels_with_funds}") + # if we have a direct channel it's preferrable to send a single part directly through this + # channel, so this bool will disable excluding single part payments + have_direct_channel = any(chan.node_id == receiver_pubkey for chan in my_active_channels) + self.logger.info(f"channels_with_funds: {channels_with_funds}, {have_direct_channel=}") exclude_single_part_payments = False if self.uses_trampoline(): # in the case of a legacy payment, we don't allow splitting via different @@ -1944,22 +1948,18 @@ class LNWallet(LNWorker): if invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and not self.config.TEST_FORCE_DISABLE_MPP: # if amt is still large compared to total_msat, split it: if (amount_msat / final_total_msat > self.MPP_SPLIT_PART_FRACTION - and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT): + and amount_msat > self.MPP_SPLIT_PART_MINAMT_MSAT + and not have_direct_channel): exclude_single_part_payments = True - def get_splits(): - return suggest_splits( - amount_msat, - channels_with_funds, - exclude_single_part_payments=exclude_single_part_payments, - exclude_multinode_payments=exclude_multinode_payments, - exclude_single_channel_splits=exclude_single_channel_splits - ) + split_configurations = suggest_splits( + amount_msat, + channels_with_funds, + exclude_single_part_payments=exclude_single_part_payments, + exclude_multinode_payments=exclude_multinode_payments, + exclude_single_channel_splits=exclude_single_channel_splits + ) - split_configurations = get_splits() - if not split_configurations and exclude_single_part_payments: - exclude_single_part_payments = False - split_configurations = get_splits() self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations') return split_configurations @@ -1989,12 +1989,13 @@ class LNWallet(LNWorker): chan.is_active() and not chan.is_frozen_for_sending()] # try random order random.shuffle(my_active_channels) - split_configurations = self.suggest_splits( + split_configurations = self.suggest_payment_splits( amount_msat=amount_msat, final_total_msat=paysession.amount_to_pay, my_active_channels=my_active_channels, invoice_features=paysession.invoice_features, r_tags=paysession.r_tags, + receiver_pubkey=paysession.invoice_pubkey, ) for sc in split_configurations: is_multichan_mpp = len(sc.config.items()) > 1 diff --git a/electrum/mpp_split.py b/electrum/mpp_split.py index 084ad612a..c7ea3663a 100644 --- a/electrum/mpp_split.py +++ b/electrum/mpp_split.py @@ -167,6 +167,11 @@ def suggest_splits( if config.total_config_amount() != amount_msat: raise NoPathFound('Cannot distribute payment over channels.') if target_parts > 1 and config.is_any_amount_smaller_than_min_part_size(): + if target_parts == 2: + # if there are already too small parts at the first split excluding single + # part payments may return only few configurations, this will allow single part + # payments for more payments, if they are too small to split + exclude_single_part_payments = False continue assert config.total_config_amount() == amount_msat configs.append(config) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index b5624167f..d934f1611 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -9,7 +9,7 @@ from collections import defaultdict import logging import concurrent from concurrent import futures -import unittest +from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict from aiorpcx import timeout_after, TaskTimeout @@ -45,6 +45,7 @@ from electrum.invoices import PR_PAID, PR_UNPAID, Invoice from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS +from electrum.mpp_split import split_amount_normal from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -323,7 +324,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): is_forwarded_htlc = LNWallet.is_forwarded_htlc notify_upstream_peer = LNWallet.notify_upstream_peer _force_close_channel = LNWallet._force_close_channel - suggest_splits = LNWallet.suggest_splits + suggest_payment_splits = LNWallet.suggest_payment_splits register_hold_invoice = LNWallet.register_hold_invoice unregister_hold_invoice = LNWallet.unregister_hold_invoice add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice @@ -1556,6 +1557,55 @@ class TestPeerForwarding(TestPeer): print(f" {keys[a].pubkey.hex()}") return graph + async def test_payment_in_graph_with_direct_channel(self): + """Test payment over a direct channel where sender has multiple available channels.""" + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['line_graph']) + peers = graph.peers.values() + # use same MPP_SPLIT_PART_FRACTION as in regular LNWallet + graph.workers['bob'].MPP_SPLIT_PART_FRACTION = LNWallet.MPP_SPLIT_PART_FRACTION + + # mock split_amount_normal so it's possible to test both cases, the amount getting sorted + # out because one part is below the min size and the other case of both parts being just + # above the min size, so no part is getting sorted out + def mocked_split_amount_normal(total_amount: int, num_parts: int) -> List[int]: + if num_parts == 2 and total_amount == 21_000_000: # test amount 21k sat + # this will not get sorted out by suggest_splits + return [10_500_000, 10_500_000] + elif num_parts == 2 and total_amount == 21_000_001: # 2nd test case + # this will get sorted out by suggest_splits + return [11_000_002, 9_999_999] + else: + return split_amount_normal(total_amount, num_parts) + + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + with mock.patch('electrum.mpp_split.split_amount_normal', + side_effect=mocked_split_amount_normal): + result, log = await graph.workers['bob'].pay_invoice(pay_req) + self.assertTrue(result) + self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + + async def f(): + async with OldTaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + for peer in peers: + await peer.initialized + for test in [21_000_000, 21_000_001]: + lnaddr, pay_req = self.prepare_invoice( + graph.workers['alice'], + amount_msat=test, + include_routing_hints=True, + invoice_features=LnFeatures.BASIC_MPP_OPT + | LnFeatures.PAYMENT_SECRET_REQ + | LnFeatures.VAR_ONION_REQ + ) + await pay(lnaddr, pay_req) + raise PaymentDone() + with self.assertRaises(PaymentDone): + await f() + async def test_payment_multihop(self): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) peers = graph.peers.values()