mpp_split: make SplitConfig a subclass of dict, not just a type-hint
This commit is contained in:
@@ -15,12 +15,33 @@ CANDIDATES_PER_LEVEL = 20
|
||||
MAX_PARTS = 5 # maximum number of parts for splitting
|
||||
|
||||
|
||||
# maps a channel (channel_id, node_id) to a list of amounts
|
||||
SplitConfig = Dict[Tuple[bytes, bytes], List[int]]
|
||||
# maps a channel (channel_id, node_id) to the funds it has available
|
||||
ChannelsFundsInfo = Dict[Tuple[bytes, bytes], int]
|
||||
|
||||
|
||||
class SplitConfig(dict, Dict[Tuple[bytes, bytes], List[int]]):
|
||||
"""maps a channel (channel_id, node_id) to a list of amounts"""
|
||||
def number_parts(self) -> int:
|
||||
return sum([len(v) for v in self.values() if sum(v)])
|
||||
|
||||
def number_nonzero_channels(self) -> int:
|
||||
return len([v for v in self.values() if sum(v)])
|
||||
|
||||
def number_nonzero_nodes(self) -> int:
|
||||
# using a set comprehension
|
||||
return len({nodeid for (_, nodeid), amounts in self.items() if sum(amounts)})
|
||||
|
||||
def total_config_amount(self) -> int:
|
||||
return sum([sum(c) for c in self.values()])
|
||||
|
||||
def is_any_amount_smaller_than_min_part_size(self) -> bool:
|
||||
smaller = False
|
||||
for amounts in self.values():
|
||||
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
|
||||
smaller |= True
|
||||
return smaller
|
||||
|
||||
|
||||
class SplitConfigRating(NamedTuple):
|
||||
config: SplitConfig
|
||||
rating: float
|
||||
@@ -41,31 +62,6 @@ def split_amount_normal(total_amount: int, num_parts: int) -> List[int]:
|
||||
return parts
|
||||
|
||||
|
||||
def number_parts(config: SplitConfig) -> int:
|
||||
return sum([len(v) for v in config.values() if sum(v)])
|
||||
|
||||
|
||||
def number_nonzero_channels(config: SplitConfig) -> int:
|
||||
return len([v for v in config.values() if sum(v)])
|
||||
|
||||
|
||||
def number_nonzero_nodes(config: SplitConfig) -> int:
|
||||
# using a set comprehension
|
||||
return len({nodeid for (_, nodeid), amounts in config.items() if sum(amounts)})
|
||||
|
||||
|
||||
def total_config_amount(config: SplitConfig) -> int:
|
||||
return sum([sum(c) for c in config.values()])
|
||||
|
||||
|
||||
def is_any_amount_smaller_than_min_part_size(config: SplitConfig) -> bool:
|
||||
smaller = False
|
||||
for amounts in config.values():
|
||||
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
|
||||
smaller |= True
|
||||
return smaller
|
||||
|
||||
|
||||
def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
unique_configs = set()
|
||||
for config in configs:
|
||||
@@ -74,16 +70,16 @@ def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
config_sorted_keys = {k: config_sorted_values[k] for k in sorted(config_sorted_values.keys())}
|
||||
hashable_config = tuple((c, tuple(sorted(config[c]))) for c in config_sorted_keys)
|
||||
unique_configs.add(hashable_config)
|
||||
unique_configs = [{c[0]: list(c[1]) for c in config} for config in unique_configs]
|
||||
unique_configs = [SplitConfig({c[0]: list(c[1]) for c in config}) for config in unique_configs]
|
||||
return unique_configs
|
||||
|
||||
|
||||
def remove_multiple_nodes(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
return [config for config in configs if number_nonzero_nodes(config) == 1]
|
||||
return [config for config in configs if config.number_nonzero_nodes() == 1]
|
||||
|
||||
|
||||
def remove_single_part_configs(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
return [config for config in configs if number_parts(config) != 1]
|
||||
return [config for config in configs if config.number_parts() != 1]
|
||||
|
||||
|
||||
def remove_single_channel_splits(configs: List[SplitConfig]) -> List[SplitConfig]:
|
||||
@@ -107,7 +103,7 @@ def rate_config(
|
||||
lowest (best). A penalty depending on the total amount sent over a channel
|
||||
counteracts channel exhaustion."""
|
||||
rating = 0
|
||||
total_amount = total_config_amount(config)
|
||||
total_amount = config.total_config_amount()
|
||||
|
||||
for channel, amounts in config.items():
|
||||
funds = channels_with_funds[channel]
|
||||
@@ -143,7 +139,7 @@ def suggest_splits(
|
||||
for _ in range(CANDIDATES_PER_LEVEL):
|
||||
# we want to have configurations with no splitting to many splittings
|
||||
for target_parts in range(1, MAX_PARTS):
|
||||
config = defaultdict(list) # type: SplitConfig
|
||||
config = SplitConfig()
|
||||
|
||||
# randomly split amount into target_parts chunks
|
||||
split_amounts = split_amount_normal(amount_msat, target_parts)
|
||||
@@ -152,6 +148,8 @@ def suggest_splits(
|
||||
random.shuffle(channels_order)
|
||||
# we check each channel and try to put the funds inside, break if we succeed
|
||||
for c in channels_order:
|
||||
if c not in config:
|
||||
config[c] = []
|
||||
if sum(config[c]) + amount <= channels_with_funds[c]:
|
||||
config[c].append(amount)
|
||||
break
|
||||
@@ -167,11 +165,11 @@ def suggest_splits(
|
||||
distribute_amount -= add_amount
|
||||
if distribute_amount == 0:
|
||||
break
|
||||
if total_config_amount(config) != amount_msat:
|
||||
if config.total_config_amount() != amount_msat:
|
||||
raise NoPathFound('Cannot distribute payment over channels.')
|
||||
if target_parts > 1 and is_any_amount_smaller_than_min_part_size(config):
|
||||
if target_parts > 1 and config.is_any_amount_smaller_than_min_part_size():
|
||||
continue
|
||||
assert total_config_amount(config) == amount_msat
|
||||
assert config.total_config_amount() == amount_msat
|
||||
configs.append(config)
|
||||
|
||||
configs = remove_duplicates(configs)
|
||||
|
||||
Reference in New Issue
Block a user