lnworker: run create_route_for_payment end-to-end, incl private edges
We pass the private edges to lnrouter, and let it find routes end-to-end. Previously the edge_cost heuristics didn't apply to the private edges and we were just randomly picking one of the route hints and use that. So e.g. cheaper private edges were not preferred, but they are now. PathEdge now stores both start_node and end_node; not just end_node.
This commit is contained in:
@@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
|
||||
|
||||
@attr.s(slots=True)
|
||||
class PathEdge:
|
||||
"""if you travel through short_channel_id, you will reach node_id"""
|
||||
node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
|
||||
start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
|
||||
end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
|
||||
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))
|
||||
|
||||
@property
|
||||
def node_id(self) -> bytes:
|
||||
# legacy compat # TODO rm
|
||||
return self.end_node
|
||||
|
||||
@attr.s
|
||||
class RouteEdge(PathEdge):
|
||||
@@ -73,17 +77,26 @@ class RouteEdge(PathEdge):
|
||||
fee_proportional_millionths=self.fee_proportional_millionths)
|
||||
|
||||
@classmethod
|
||||
def from_channel_policy(cls, channel_policy: 'Policy',
|
||||
short_channel_id: bytes, end_node: bytes, *,
|
||||
node_info: Optional[NodeInfo]) -> 'RouteEdge':
|
||||
def from_channel_policy(
|
||||
cls,
|
||||
*,
|
||||
channel_policy: 'Policy',
|
||||
short_channel_id: bytes,
|
||||
start_node: bytes,
|
||||
end_node: bytes,
|
||||
node_info: Optional[NodeInfo], # for end_node
|
||||
) -> 'RouteEdge':
|
||||
assert isinstance(short_channel_id, bytes)
|
||||
assert type(start_node) is bytes
|
||||
assert type(end_node) is bytes
|
||||
return RouteEdge(node_id=end_node,
|
||||
short_channel_id=ShortChannelID.normalize(short_channel_id),
|
||||
fee_base_msat=channel_policy.fee_base_msat,
|
||||
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
|
||||
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
|
||||
node_features=node_info.features if node_info else 0)
|
||||
return RouteEdge(
|
||||
start_node=start_node,
|
||||
end_node=end_node,
|
||||
short_channel_id=ShortChannelID.normalize(short_channel_id),
|
||||
fee_base_msat=channel_policy.fee_base_msat,
|
||||
fee_proportional_millionths=channel_policy.fee_proportional_millionths,
|
||||
cltv_expiry_delta=channel_policy.cltv_expiry_delta,
|
||||
node_features=node_info.features if node_info else 0)
|
||||
|
||||
def is_sane_to_use(self, amount_msat: int) -> bool:
|
||||
# TODO revise ad-hoc heuristics
|
||||
@@ -155,21 +168,37 @@ class LNPathFinder(Logger):
|
||||
Logger.__init__(self)
|
||||
self.channel_db = channel_db
|
||||
|
||||
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
|
||||
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
|
||||
def _edge_cost(
|
||||
self,
|
||||
*,
|
||||
short_channel_id: bytes,
|
||||
start_node: bytes,
|
||||
end_node: bytes,
|
||||
payment_amt_msat: int,
|
||||
ignore_costs=False,
|
||||
is_mine=False,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
|
||||
) -> Tuple[float, int]:
|
||||
"""Heuristic cost (distance metric) of going through a channel.
|
||||
Returns (heuristic_cost, fee_for_edge_msat).
|
||||
"""
|
||||
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
|
||||
if private_route_edges is None:
|
||||
private_route_edges = {}
|
||||
channel_info = self.channel_db.get_channel_info(
|
||||
short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
|
||||
if channel_info is None:
|
||||
return float('inf'), 0
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
|
||||
channel_policy = self.channel_db.get_policy_for_node(
|
||||
short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
|
||||
if channel_policy is None:
|
||||
return float('inf'), 0
|
||||
# channels that did not publish both policies often return temporary channel failure
|
||||
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
|
||||
and not is_mine:
|
||||
channel_policy_backwards = self.channel_db.get_policy_for_node(
|
||||
short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
|
||||
if (channel_policy_backwards is None
|
||||
and not is_mine
|
||||
and short_channel_id not in private_route_edges):
|
||||
return float('inf'), 0
|
||||
if channel_policy.is_disabled():
|
||||
return float('inf'), 0
|
||||
@@ -181,9 +210,15 @@ class LNPathFinder(Logger):
|
||||
if channel_policy.htlc_maximum_msat is not None and \
|
||||
payment_amt_msat > channel_policy.htlc_maximum_msat:
|
||||
return float('inf'), 0 # payment amount too large
|
||||
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
|
||||
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node,
|
||||
node_info=node_info)
|
||||
route_edge = private_route_edges.get(short_channel_id, None)
|
||||
if route_edge is None:
|
||||
node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
|
||||
route_edge = RouteEdge.from_channel_policy(
|
||||
channel_policy=channel_policy,
|
||||
short_channel_id=short_channel_id,
|
||||
start_node=start_node,
|
||||
end_node=end_node,
|
||||
node_info=node_info)
|
||||
if not route_edge.is_sane_to_use(payment_amt_msat):
|
||||
return float('inf'), 0 # thanks but no thanks
|
||||
|
||||
@@ -201,9 +236,16 @@ class LNPathFinder(Logger):
|
||||
overall_cost = base_cost + fee_msat + cltv_cost
|
||||
return overall_cost, fee_msat
|
||||
|
||||
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
|
||||
def get_distances(
|
||||
self,
|
||||
*,
|
||||
nodeA: bytes,
|
||||
nodeB: bytes,
|
||||
invoice_amount_msat: int,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None,
|
||||
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
|
||||
) -> Dict[bytes, PathEdge]:
|
||||
# note: we don't lock self.channel_db, so while the path finding runs,
|
||||
# the underlying graph could potentially change... (not good but maybe ~OK?)
|
||||
|
||||
@@ -226,11 +268,13 @@ class LNPathFinder(Logger):
|
||||
# so instead of decreasing priorities, we add items again into the queue.
|
||||
# so there are duplicates in the queue, that we discard now:
|
||||
continue
|
||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
|
||||
for edge_channel_id in self.channel_db.get_channels_for_node(
|
||||
edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
|
||||
assert isinstance(edge_channel_id, bytes)
|
||||
if blacklist and edge_channel_id in blacklist:
|
||||
continue
|
||||
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
|
||||
channel_info = self.channel_db.get_channel_info(
|
||||
edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
|
||||
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
|
||||
is_mine = edge_channel_id in my_channels
|
||||
if is_mine:
|
||||
@@ -242,29 +286,37 @@ class LNPathFinder(Logger):
|
||||
if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
|
||||
continue
|
||||
edge_cost, fee_for_edge_msat = self._edge_cost(
|
||||
edge_channel_id,
|
||||
short_channel_id=edge_channel_id,
|
||||
start_node=edge_startnode,
|
||||
end_node=edge_endnode,
|
||||
payment_amt_msat=amount_msat,
|
||||
ignore_costs=(edge_startnode == nodeA),
|
||||
is_mine=is_mine,
|
||||
my_channels=my_channels)
|
||||
my_channels=my_channels,
|
||||
private_route_edges=private_route_edges)
|
||||
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
|
||||
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
|
||||
distance_from_start[edge_startnode] = alt_dist_to_neighbour
|
||||
prev_node[edge_startnode] = PathEdge(node_id=edge_endnode,
|
||||
short_channel_id=ShortChannelID(edge_channel_id))
|
||||
prev_node[edge_startnode] = PathEdge(
|
||||
start_node=edge_startnode,
|
||||
end_node=edge_endnode,
|
||||
short_channel_id=ShortChannelID(edge_channel_id))
|
||||
amount_to_forward_msat = amount_msat + fee_for_edge_msat
|
||||
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
|
||||
|
||||
return prev_node
|
||||
|
||||
@profiler
|
||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||
invoice_amount_msat: int, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None) \
|
||||
-> Optional[LNPaymentPath]:
|
||||
def find_path_for_payment(
|
||||
self,
|
||||
*,
|
||||
nodeA: bytes,
|
||||
nodeB: bytes,
|
||||
invoice_amount_msat: int,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None,
|
||||
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
|
||||
) -> Optional[LNPaymentPath]:
|
||||
"""Return a path from nodeA to nodeB."""
|
||||
assert type(nodeA) is bytes
|
||||
assert type(nodeB) is bytes
|
||||
@@ -272,7 +324,13 @@ class LNPathFinder(Logger):
|
||||
if my_channels is None:
|
||||
my_channels = {}
|
||||
|
||||
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
|
||||
prev_node = self.get_distances(
|
||||
nodeA=nodeA,
|
||||
nodeB=nodeB,
|
||||
invoice_amount_msat=invoice_amount_msat,
|
||||
my_channels=my_channels,
|
||||
blacklist=blacklist,
|
||||
private_route_edges=private_route_edges)
|
||||
|
||||
if nodeA not in prev_node:
|
||||
return None # no path found
|
||||
@@ -287,34 +345,66 @@ class LNPathFinder(Logger):
|
||||
edge_startnode = edge.node_id
|
||||
return path
|
||||
|
||||
def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
|
||||
assert isinstance(from_node_id, bytes)
|
||||
def create_route_from_path(
|
||||
self,
|
||||
path: Optional[LNPaymentPath],
|
||||
*,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
|
||||
) -> LNPaymentRoute:
|
||||
if path is None:
|
||||
raise Exception('cannot create route from None path')
|
||||
if private_route_edges is None:
|
||||
private_route_edges = {}
|
||||
route = []
|
||||
prev_node_id = from_node_id
|
||||
for edge in path:
|
||||
node_id = edge.node_id
|
||||
short_channel_id = edge.short_channel_id
|
||||
prev_end_node = path[0].start_node
|
||||
for path_edge in path:
|
||||
short_channel_id = path_edge.short_channel_id
|
||||
_endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
|
||||
if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]):
|
||||
if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
|
||||
raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
|
||||
if path_edge.start_node != prev_end_node:
|
||||
raise LNPathInconsistent("edges do not chain together")
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
||||
node_id=prev_node_id,
|
||||
my_channels=my_channels)
|
||||
if channel_policy is None:
|
||||
raise NoChannelPolicy(short_channel_id)
|
||||
node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id)
|
||||
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id,
|
||||
node_info=node_info))
|
||||
prev_node_id = node_id
|
||||
route_edge = private_route_edges.get(short_channel_id, None)
|
||||
if route_edge is None:
|
||||
channel_policy = self.channel_db.get_policy_for_node(
|
||||
short_channel_id=short_channel_id,
|
||||
node_id=path_edge.start_node,
|
||||
my_channels=my_channels)
|
||||
if channel_policy is None:
|
||||
raise NoChannelPolicy(short_channel_id)
|
||||
node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
|
||||
route_edge = RouteEdge.from_channel_policy(
|
||||
channel_policy=channel_policy,
|
||||
short_channel_id=short_channel_id,
|
||||
start_node=path_edge.start_node,
|
||||
end_node=path_edge.end_node,
|
||||
node_info=node_info)
|
||||
route.append(route_edge)
|
||||
prev_end_node = path_edge.end_node
|
||||
return route
|
||||
|
||||
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
|
||||
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
|
||||
def find_route(
|
||||
self,
|
||||
*,
|
||||
nodeA: bytes,
|
||||
nodeB: bytes,
|
||||
invoice_amount_msat: int,
|
||||
path = None,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||
blacklist: Set[ShortChannelID] = None,
|
||||
private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
|
||||
) -> Optional[LNPaymentRoute]:
|
||||
route = None
|
||||
if not path:
|
||||
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
|
||||
path = self.find_path_for_payment(
|
||||
nodeA=nodeA,
|
||||
nodeB=nodeB,
|
||||
invoice_amount_msat=invoice_amount_msat,
|
||||
my_channels=my_channels,
|
||||
blacklist=blacklist,
|
||||
private_route_edges=private_route_edges)
|
||||
if path:
|
||||
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
|
||||
route = self.create_route_from_path(
|
||||
path, my_channels=my_channels, private_route_edges=private_route_edges)
|
||||
return route
|
||||
|
||||
Reference in New Issue
Block a user