1
0

lnworker: run create_route_for_payment end-to-end, incl private edges

We pass the private edges to lnrouter, and let it find routes end-to-end.
Previously the edge_cost heuristics didn't apply to the private edges
and we were just randomly picking one of the route hints and use that.
So e.g. cheaper private edges were not preferred, but they are now.

PathEdge now stores both start_node and end_node; not just end_node.
This commit is contained in:
SomberNight
2021-03-02 18:00:31 +01:00
parent 4445cef033
commit 750d8cfab5
6 changed files with 321 additions and 157 deletions

View File

@@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
@attr.s(slots=True)
class PathEdge:
"""if you travel through short_channel_id, you will reach node_id"""
node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))
@property
def node_id(self) -> bytes:
# legacy compat # TODO rm
return self.end_node
@attr.s
class RouteEdge(PathEdge):
@@ -73,17 +77,26 @@ class RouteEdge(PathEdge):
fee_proportional_millionths=self.fee_proportional_millionths)
@classmethod
def from_channel_policy(cls, channel_policy: 'Policy',
short_channel_id: bytes, end_node: bytes, *,
node_info: Optional[NodeInfo]) -> 'RouteEdge':
def from_channel_policy(
cls,
*,
channel_policy: 'Policy',
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
node_info: Optional[NodeInfo], # for end_node
) -> 'RouteEdge':
assert isinstance(short_channel_id, bytes)
assert type(start_node) is bytes
assert type(end_node) is bytes
return RouteEdge(node_id=end_node,
short_channel_id=ShortChannelID.normalize(short_channel_id),
fee_base_msat=channel_policy.fee_base_msat,
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
node_features=node_info.features if node_info else 0)
return RouteEdge(
start_node=start_node,
end_node=end_node,
short_channel_id=ShortChannelID.normalize(short_channel_id),
fee_base_msat=channel_policy.fee_base_msat,
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
node_features=node_info.features if node_info else 0)
def is_sane_to_use(self, amount_msat: int) -> bool:
# TODO revise ad-hoc heuristics
@@ -155,21 +168,37 @@ class LNPathFinder(Logger):
Logger.__init__(self)
self.channel_db = channel_db
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
def _edge_cost(
self,
*,
short_channel_id: bytes,
start_node: bytes,
end_node: bytes,
payment_amt_msat: int,
ignore_costs=False,
is_mine=False,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Tuple[float, int]:
"""Heuristic cost (distance metric) of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat).
"""
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
if private_route_edges is None:
private_route_edges = {}
channel_info = self.channel_db.get_channel_info(
short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_info is None:
return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
if channel_policy is None:
return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
and not is_mine:
channel_policy_backwards = self.channel_db.get_policy_for_node(
short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
if (channel_policy_backwards is None
and not is_mine
and short_channel_id not in private_route_edges):
return float('inf'), 0
if channel_policy.is_disabled():
return float('inf'), 0
@@ -181,9 +210,15 @@ class LNPathFinder(Logger):
if channel_policy.htlc_maximum_msat is not None and \
payment_amt_msat > channel_policy.htlc_maximum_msat:
return float('inf'), 0 # payment amount too large
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node,
node_info=node_info)
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge is None:
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
route_edge = RouteEdge.from_channel_policy(
channel_policy=channel_policy,
short_channel_id=short_channel_id,
start_node=start_node,
end_node=end_node,
node_info=node_info)
if not route_edge.is_sane_to_use(payment_amt_msat):
return float('inf'), 0 # thanks but no thanks
@@ -201,9 +236,16 @@ class LNPathFinder(Logger):
overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost, fee_msat
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
def get_distances(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Dict[bytes, PathEdge]:
# note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?)
@@ -226,11 +268,13 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
for edge_channel_id in self.channel_db.get_channels_for_node(
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
assert isinstance(edge_channel_id, bytes)
if blacklist and edge_channel_id in blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
channel_info = self.channel_db.get_channel_info(
edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
is_mine = edge_channel_id in my_channels
if is_mine:
@@ -242,29 +286,37 @@ class LNPathFinder(Logger):
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
continue
edge_cost, fee_for_edge_msat = self._edge_cost(
edge_channel_id,
short_channel_id=edge_channel_id,
start_node=edge_startnode,
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA),
is_mine=is_mine,
my_channels=my_channels)
my_channels=my_channels,
private_route_edges=private_route_edges)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
prev_node[edge_startnode] = PathEdge(node_id=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id))
prev_node[edge_startnode] = PathEdge(
start_node=edge_startnode,
end_node=edge_endnode,
short_channel_id=ShortChannelID(edge_channel_id))
amount_to_forward_msat = amount_msat + fee_for_edge_msat
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
return prev_node
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) \
-> Optional[LNPaymentPath]:
def find_path_for_payment(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentPath]:
"""Return a path from nodeA to nodeB."""
assert type(nodeA) is bytes
assert type(nodeB) is bytes
@@ -272,7 +324,13 @@ class LNPathFinder(Logger):
if my_channels is None:
my_channels = {}
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
prev_node = self.get_distances(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if nodeA not in prev_node:
return None # no path found
@@ -287,34 +345,66 @@ class LNPathFinder(Logger):
edge_startnode = edge.node_id
return path
def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
assert isinstance(from_node_id, bytes)
def create_route_from_path(
self,
path: Optional[LNPaymentPath],
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> LNPaymentRoute:
if path is None:
raise Exception('cannot create route from None path')
if private_route_edges is None:
private_route_edges = {}
route = []
prev_node_id = from_node_id
for edge in path:
node_id = edge.node_id
short_channel_id = edge.short_channel_id
prev_end_node = path[0].start_node
for path_edge in path:
short_channel_id = path_edge.short_channel_id
_endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]):
if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
if path_edge.start_node != prev_end_node:
raise LNPathInconsistent("edges do not chain together")
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id,
node_info=node_info))
prev_node_id = node_id
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge is None:
channel_policy = self.channel_db.get_policy_for_node(
short_channel_id=short_channel_id,
node_id=path_edge.start_node,
my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
route_edge = RouteEdge.from_channel_policy(
channel_policy=channel_policy,
short_channel_id=short_channel_id,
start_node=path_edge.start_node,
end_node=path_edge.end_node,
node_info=node_info)
route.append(route_edge)
prev_end_node = path_edge.end_node
return route
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
def find_route(
self,
*,
nodeA: bytes,
nodeB: bytes,
invoice_amount_msat: int,
path = None,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None,
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
) -> Optional[LNPaymentRoute]:
route = None
if not path:
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
path = self.find_path_for_payment(
nodeA=nodeA,
nodeB=nodeB,
invoice_amount_msat=invoice_amount_msat,
my_channels=my_channels,
blacklist=blacklist,
private_route_edges=private_route_edges)
if path:
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
route = self.create_route_from_path(
path, my_channels=my_channels, private_route_edges=private_route_edges)
return route