mpp_split: make SplitConfig a subclass of dict, not just a type-hint
This commit is contained in:
@@ -1841,7 +1841,7 @@ class LNWallet(LNWorker):
|
||||
)
|
||||
for sc in split_configurations:
|
||||
is_multichan_mpp = len(sc.config.items()) > 1
|
||||
is_mpp = sum(len(x) for x in list(sc.config.values())) > 1
|
||||
is_mpp = sc.config.number_parts() > 1
|
||||
if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
|
||||
continue
|
||||
if not is_mpp and self.config.TEST_FORCE_MPP:
|
||||
|
||||
@@ -15,12 +15,33 @@ CANDIDATES_PER_LEVEL = 20
|
||||
MAX_PARTS = 5 # maximum number of parts for splitting
|
||||
|
||||
|
||||
# maps a channel (channel_id, node_id) to a list of amounts
|
||||
SplitConfig = Dict[Tuple[bytes, bytes], List[int]]
|
||||
# maps a channel (channel_id, node_id) to the funds it has available
|
||||
ChannelsFundsInfo = Dict[Tuple[bytes, bytes], int]
|
||||
|
||||
|
||||
class SplitConfig(dict, Dict[Tuple[bytes, bytes], List[int]]):
|
||||
"""maps a channel (channel_id, node_id) to a list of amounts"""
|
||||
def number_parts(self) -> int:
|
||||
return sum([len(v) for v in self.values() if sum(v)])
|
||||
|
||||
def number_nonzero_channels(self) -> int:
|
||||
return len([v for v in self.values() if sum(v)])
|
||||
|
||||
def number_nonzero_nodes(self) -> int:
|
||||
# using a set comprehension
|
||||
return len({nodeid for (_, nodeid), amounts in self.items() if sum(amounts)})
|
||||
|
||||
def total_config_amount(self) -> int:
|
||||
return sum([sum(c) for c in self.values()])
|
||||
|
||||
def is_any_amount_smaller_than_min_part_size(self) -> bool:
|
||||
smaller = False
|
||||
for amounts in self.values():
|
||||
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
|
||||
smaller |= True
|
||||
return smaller
|
||||
|
||||
|
||||
class SplitConfigRating(NamedTuple):
|
||||
config: SplitConfig
|
||||
rating: float
|
||||
@@ -41,31 +62,6 @@ def split_amount_normal(total_amount: int, num_parts: int) -> List[int]:
|
||||
return parts
|
||||
|
||||
|
||||
def number_parts(config: SplitConfig) -> int:
|
||||
return sum([len(v) for v in config.values() if sum(v)])
|
||||
|
||||
|
||||
def number_nonzero_channels(config: SplitConfig) -> int:
|
||||
return len([v for v in config.values() if sum(v)])
|
||||
|
||||
|
||||
def number_nonzero_nodes(config: SplitConfig) -> int:
|
||||
# using a set comprehension
|
||||
return len({nodeid for (_, nodeid), amounts in config.items() if sum(amounts)})
|
||||
|
||||
|
||||
def total_config_amount(config: SplitConfig) -> int:
|
||||
return sum([sum(c) for c in config.values()])
|
||||
|
||||
|
||||
def is_any_amount_smaller_than_min_part_size(config: SplitConfig) -> bool:
|
||||
smaller = False
|
||||
for amounts in config.values():
|
||||
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
|
||||
smaller |= True
|
||||
return smaller
|
||||
|
||||
|
||||
def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
unique_configs = set()
|
||||
for config in configs:
|
||||
@@ -74,16 +70,16 @@ def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
config_sorted_keys = {k: config_sorted_values[k] for k in sorted(config_sorted_values.keys())}
|
||||
hashable_config = tuple((c, tuple(sorted(config[c]))) for c in config_sorted_keys)
|
||||
unique_configs.add(hashable_config)
|
||||
unique_configs = [{c[0]: list(c[1]) for c in config} for config in unique_configs]
|
||||
unique_configs = [SplitConfig({c[0]: list(c[1]) for c in config}) for config in unique_configs]
|
||||
return unique_configs
|
||||
|
||||
|
||||
def remove_multiple_nodes(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
return [config for config in configs if number_nonzero_nodes(config) == 1]
|
||||
return [config for config in configs if config.number_nonzero_nodes() == 1]
|
||||
|
||||
|
||||
def remove_single_part_configs(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
return [config for config in configs if number_parts(config) != 1]
|
||||
return [config for config in configs if config.number_parts() != 1]
|
||||
|
||||
|
||||
def remove_single_channel_splits(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
@@ -107,7 +103,7 @@ def rate_config(
|
||||
lowest (best). A penalty depending on the total amount sent over a channel
|
||||
counteracts channel exhaustion."""
|
||||
rating = 0
|
||||
total_amount = total_config_amount(config)
|
||||
total_amount = config.total_config_amount()
|
||||
|
||||
for channel, amounts in config.items():
|
||||
funds = channels_with_funds[channel]
|
||||
@@ -143,7 +139,7 @@ def suggest_splits(
|
||||
for _ in range(CANDIDATES_PER_LEVEL):
|
||||
# we want to have configurations with no splitting to many splittings
|
||||
for target_parts in range(1, MAX_PARTS):
|
||||
config = defaultdict(list) # type: SplitConfig
|
||||
config = SplitConfig()
|
||||
|
||||
# randomly split amount into target_parts chunks
|
||||
split_amounts = split_amount_normal(amount_msat, target_parts)
|
||||
@@ -152,6 +148,8 @@ def suggest_splits(
|
||||
random.shuffle(channels_order)
|
||||
# we check each channel and try to put the funds inside, break if we succeed
|
||||
for c in channels_order:
|
||||
if c not in config:
|
||||
config[c] = []
|
||||
if sum(config[c]) + amount <= channels_with_funds[c]:
|
||||
config[c].append(amount)
|
||||
break
|
||||
@@ -167,11 +165,11 @@ def suggest_splits(
|
||||
distribute_amount -= add_amount
|
||||
if distribute_amount == 0:
|
||||
break
|
||||
if total_config_amount(config) != amount_msat:
|
||||
if config.total_config_amount() != amount_msat:
|
||||
raise NoPathFound('Cannot distribute payment over channels.')
|
||||
if target_parts > 1 and is_any_amount_smaller_than_min_part_size(config):
|
||||
if target_parts > 1 and config.is_any_amount_smaller_than_min_part_size():
|
||||
continue
|
||||
assert total_config_amount(config) == amount_msat
|
||||
assert config.total_config_amount() == amount_msat
|
||||
configs.append(config)
|
||||
|
||||
configs = remove_duplicates(configs)
|
||||
|
||||
@@ -15,10 +15,10 @@ class TestMppSplit(ElectrumTestCase):
|
||||
random.seed(0)
|
||||
# key tuple denotes (channel_id, node_id)
|
||||
self.channels_with_funds = {
|
||||
(0, 0): 1_000_000_000,
|
||||
(1, 1): 500_000_000,
|
||||
(2, 0): 302_000_000,
|
||||
(3, 2): 101_000_000,
|
||||
(b"0", b"0"): 1_000_000_000,
|
||||
(b"1", b"1"): 500_000_000,
|
||||
(b"2", b"0"): 302_000_000,
|
||||
(b"3", b"2"): 101_000_000,
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
@@ -30,52 +30,52 @@ class TestMppSplit(ElectrumTestCase):
|
||||
with self.subTest(msg="do a payment with the maximal amount spendable over a single channel"):
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=True)
|
||||
self.assertEqual({
|
||||
(0, 0): [671_020_676],
|
||||
(1, 1): [328_979_324],
|
||||
(2, 0): [],
|
||||
(3, 2): []},
|
||||
(b"0", b"0"): [671_020_676],
|
||||
(b"1", b"1"): [328_979_324],
|
||||
(b"2", b"0"): [],
|
||||
(b"3", b"2"): []},
|
||||
splits[0].config
|
||||
)
|
||||
|
||||
with self.subTest(msg="payment amount that does not require to be split"):
|
||||
splits = mpp_split.suggest_splits(50_000_000, self.channels_with_funds, exclude_single_part_payments=False)
|
||||
self.assertEqual({(0, 0): [50_000_000]}, splits[0].config)
|
||||
self.assertEqual({(1, 1): [50_000_000]}, splits[1].config)
|
||||
self.assertEqual({(2, 0): [50_000_000]}, splits[2].config)
|
||||
self.assertEqual({(3, 2): [50_000_000]}, splits[3].config)
|
||||
self.assertEqual(2, mpp_split.number_parts(splits[4].config))
|
||||
self.assertEqual({(b"0", b"0"): [50_000_000]}, splits[0].config)
|
||||
self.assertEqual({(b"1", b"1"): [50_000_000]}, splits[1].config)
|
||||
self.assertEqual({(b"2", b"0"): [50_000_000]}, splits[2].config)
|
||||
self.assertEqual({(b"3", b"2"): [50_000_000]}, splits[3].config)
|
||||
self.assertEqual(2, splits[4].config.number_parts())
|
||||
|
||||
with self.subTest(msg="do a payment with a larger amount than what is supported by a single channel"):
|
||||
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds, exclude_single_part_payments=False)
|
||||
self.assertEqual(2, mpp_split.number_parts(splits[0].config))
|
||||
self.assertEqual(2, splits[0].config.number_parts())
|
||||
|
||||
with self.subTest(msg="do a payment with the maximal amount spendable over all channels"):
|
||||
splits = mpp_split.suggest_splits(
|
||||
sum(self.channels_with_funds.values()), self.channels_with_funds, exclude_single_part_payments=True)
|
||||
self.assertEqual({
|
||||
(0, 0): [1_000_000_000],
|
||||
(1, 1): [500_000_000],
|
||||
(2, 0): [302_000_000],
|
||||
(3, 2): [101_000_000]},
|
||||
(b"0", b"0"): [1_000_000_000],
|
||||
(b"1", b"1"): [500_000_000],
|
||||
(b"2", b"0"): [302_000_000],
|
||||
(b"3", b"2"): [101_000_000]},
|
||||
splits[0].config
|
||||
)
|
||||
|
||||
with self.subTest(msg="do a payment with the amount supported by all channels"):
|
||||
splits = mpp_split.suggest_splits(101_000_000, self.channels_with_funds, exclude_single_part_payments=False)
|
||||
for split in splits[:3]:
|
||||
self.assertEqual(1, mpp_split.number_nonzero_channels(split.config))
|
||||
self.assertEqual(1, split.config.number_nonzero_channels())
|
||||
# due to exhaustion of the smallest channel, the algorithm favors
|
||||
# a splitting of the parts into two
|
||||
self.assertEqual(2, mpp_split.number_parts(splits[4].config))
|
||||
self.assertEqual(2, splits[4].config.number_parts())
|
||||
|
||||
def test_send_to_single_node(self):
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=False, exclude_multinode_payments=True)
|
||||
for split in splits:
|
||||
assert mpp_split.number_nonzero_nodes(split.config) == 1
|
||||
assert split.config.number_nonzero_nodes() == 1
|
||||
|
||||
def test_saturation(self):
|
||||
"""Split configurations which spend the full amount in a channel should be avoided."""
|
||||
channels_with_funds = {(0, 0): 159_799_733_076, (1, 1): 499_986_152_000}
|
||||
channels_with_funds = {(b"0", b"0"): 159_799_733_076, (b"1", b"1"): 499_986_152_000}
|
||||
splits = mpp_split.suggest_splits(600_000_000_000, channels_with_funds, exclude_single_part_payments=True)
|
||||
|
||||
uses_full_amount = False
|
||||
@@ -100,37 +100,37 @@ class TestMppSplit(ElectrumTestCase):
|
||||
with self.subTest(msg="split payments with intermediate part penalty"):
|
||||
mpp_split.PART_PENALTY = 1.0
|
||||
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
|
||||
self.assertEqual(2, mpp_split.number_parts(splits[0].config))
|
||||
self.assertEqual(2, splits[0].config.number_parts())
|
||||
|
||||
with self.subTest(msg="split payments with intermediate part penalty"):
|
||||
mpp_split.PART_PENALTY = 0.3
|
||||
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
|
||||
self.assertEqual(4, mpp_split.number_parts(splits[0].config))
|
||||
self.assertEqual(4, splits[0].config.number_parts())
|
||||
|
||||
with self.subTest(msg="split payments with no part penalty"):
|
||||
mpp_split.PART_PENALTY = 0.0
|
||||
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
|
||||
self.assertEqual(5, mpp_split.number_parts(splits[0].config))
|
||||
self.assertEqual(5, splits[0].config.number_parts())
|
||||
|
||||
def test_suggest_splits_single_channel(self):
|
||||
channels_with_funds = {
|
||||
(0, 0): 1_000_000_000,
|
||||
(b"0", b"0"): 1_000_000_000,
|
||||
}
|
||||
|
||||
with self.subTest(msg="do a payment with the maximal amount spendable on a single channel"):
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
|
||||
self.assertEqual({(0, 0): [1_000_000_000]}, splits[0].config)
|
||||
self.assertEqual({(b"0", b"0"): [1_000_000_000]}, splits[0].config)
|
||||
with self.subTest(msg="test sending an amount greater than what we have available"):
|
||||
self.assertRaises(NoPathFound, mpp_split.suggest_splits, *(1_100_000_000, channels_with_funds))
|
||||
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
|
||||
mpp_split.PART_PENALTY = 0.5
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
|
||||
self.assertEqual(2, len(splits[0].config[(0, 0)]))
|
||||
self.assertEqual(2, len(splits[0].config[(b"0", b"0")]))
|
||||
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
|
||||
mpp_split.PART_PENALTY = 0.3
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
|
||||
self.assertEqual(3, len(splits[0].config[(0, 0)]))
|
||||
self.assertEqual(3, len(splits[0].config[(b"0", b"0")]))
|
||||
with self.subTest(msg="exclude all single channel splits"):
|
||||
mpp_split.PART_PENALTY = 0.3
|
||||
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_channel_splits=True)
|
||||
self.assertEqual(1, len(splits[0].config[(0, 0)]))
|
||||
self.assertEqual(1, len(splits[0].config[(b"0", b"0")]))
|
||||
|
||||
Reference in New Issue
Block a user