1
0

lnworker: run create_route_for_payment end-to-end, incl private edges

We pass the private edges to lnrouter, and let it find routes end-to-end.
Previously the edge_cost heuristics didn't apply to the private edges
and we were just randomly picking one of the route hints and use that.
So e.g. cheaper private edges were not preferred, but they are now.

PathEdge now stores both start_node and end_node; not just end_node.
This commit is contained in:
SomberNight
2021-03-02 18:00:31 +01:00
parent 4445cef033
commit 750d8cfab5
6 changed files with 321 additions and 157 deletions

View File

@@ -48,6 +48,7 @@ from .lnmsg import decode_msg
if TYPE_CHECKING:
from .network import Network
from .lnchannel import Channel
from .lnrouter import RouteEdge
FLAG_DISABLE = 1 << 1
@@ -81,6 +82,16 @@ class ChannelInfo(NamedTuple):
payload_dict = decode_msg(raw)[1]
return ChannelInfo.from_msg(payload_dict)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
return ChannelInfo(
short_channel_id=route_edge.short_channel_id,
node1_id=node1_id,
node2_id=node2_id,
capacity_sat=None,
)
class Policy(NamedTuple):
key: bytes
@@ -113,6 +124,20 @@ class Policy(NamedTuple):
payload['start_node'] = key[8:]
return Policy.from_msg(payload)
@staticmethod
def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
return Policy(
key=route_edge.short_channel_id + route_edge.start_node,
cltv_expiry_delta=route_edge.cltv_expiry_delta,
htlc_minimum_msat=0,
htlc_maximum_msat=None,
fee_base_msat=route_edge.fee_base_msat,
fee_proportional_millionths=route_edge.fee_proportional_millionths,
channel_flags=0,
message_flags=0,
timestamp=0,
)
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
@@ -216,6 +241,8 @@ class CategorizedChannelUpdates(NamedTuple):
def get_mychannel_info(short_channel_id: ShortChannelID,
my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
chan = my_channels.get(short_channel_id)
if not chan:
return
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
return ci._replace(capacity_sat=chan.constraints.capacity)
@@ -724,8 +751,14 @@ class ChannelDB(SqlDB):
nchans_with_2p = len(self._chans_with_2_policies)
return nchans_with_0p, nchans_with_1p, nchans_with_2p
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
def get_policy_for_node(
self,
short_channel_id: bytes,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id))
@@ -737,28 +770,56 @@ class ChannelDB(SqlDB):
return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels
if my_channels:
return get_mychannel_policy(short_channel_id, node_id, my_channels)
policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
if policy:
return policy
if private_route_edges:
route_edge = private_route_edges.get(short_channel_id, None)
if route_edge:
return Policy.from_route_edge(route_edge)
def get_channel_info(self, short_channel_id: ShortChannelID, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
def get_channel_info(
self,
short_channel_id: ShortChannelID,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Optional[ChannelInfo]:
ret = self._channels.get(short_channel_id)
if ret:
return ret
# check if it's one of our own channels
if my_channels:
return get_mychannel_info(short_channel_id, my_channels)
channel_info = get_mychannel_info(short_channel_id, my_channels)
if channel_info:
return channel_info
if private_route_edges:
route_edge = private_route_edges.get(short_channel_id)
if route_edge:
return ChannelInfo.from_route_edge(route_edge)
def get_channels_for_node(self, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
def get_channels_for_node(
self,
node_id: bytes,
*,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
) -> Set[bytes]:
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
if not self.data_loaded.is_set():
raise Exception("channelDB data not loaded yet!")
relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow?
for chan in (my_channels.values() or []):
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
if my_channels:
for chan in my_channels.values():
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
# add private channels # TODO maybe slow?
if private_route_edges:
for route_edge in private_route_edges.values():
if node_id in (route_edge.start_node, route_edge.end_node):
relevant_channels.add(route_edge.short_channel_id)
return relevant_channels
def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,