lnrouter: add PathEdge/LNPaymentPath for (node_id, scid)
This commit is contained in:
@@ -50,11 +50,15 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
|
|||||||
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
|
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(slots=True)
|
||||||
class RouteEdge:
|
class PathEdge:
|
||||||
"""if you travel through short_channel_id, you will reach node_id"""
|
"""if you travel through short_channel_id, you will reach node_id"""
|
||||||
node_id = attr.ib(type=bytes, kw_only=True)
|
node_id = attr.ib(type=bytes, kw_only=True)
|
||||||
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True)
|
short_channel_id = attr.ib(type=ShortChannelID, kw_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class RouteEdge(PathEdge):
|
||||||
fee_base_msat = attr.ib(type=int, kw_only=True)
|
fee_base_msat = attr.ib(type=int, kw_only=True)
|
||||||
fee_proportional_millionths = attr.ib(type=int, kw_only=True)
|
fee_proportional_millionths = attr.ib(type=int, kw_only=True)
|
||||||
cltv_expiry_delta = attr.ib(type=int, kw_only=True)
|
cltv_expiry_delta = attr.ib(type=int, kw_only=True)
|
||||||
@@ -93,6 +97,7 @@ class RouteEdge:
|
|||||||
return bool(features & LnFeatures.VAR_ONION_REQ or features & LnFeatures.VAR_ONION_OPT)
|
return bool(features & LnFeatures.VAR_ONION_REQ or features & LnFeatures.VAR_ONION_OPT)
|
||||||
|
|
||||||
|
|
||||||
|
LNPaymentPath = Sequence[PathEdge]
|
||||||
LNPaymentRoute = Sequence[RouteEdge]
|
LNPaymentRoute = Sequence[RouteEdge]
|
||||||
|
|
||||||
|
|
||||||
@@ -186,8 +191,8 @@ class LNPathFinder(Logger):
|
|||||||
|
|
||||||
def get_distances(self, nodeA: bytes, nodeB: bytes,
|
def get_distances(self, nodeA: bytes, nodeB: bytes,
|
||||||
invoice_amount_msat: int, *,
|
invoice_amount_msat: int, *,
|
||||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
|
my_channels: Dict[ShortChannelID, 'Channel'] = None
|
||||||
-> Optional[Sequence[Tuple[bytes, bytes]]]:
|
) -> Dict[bytes, PathEdge]:
|
||||||
# note: we don't lock self.channel_db, so while the path finding runs,
|
# note: we don't lock self.channel_db, so while the path finding runs,
|
||||||
# the underlying graph could potentially change... (not good but maybe ~OK?)
|
# the underlying graph could potentially change... (not good but maybe ~OK?)
|
||||||
|
|
||||||
@@ -196,7 +201,7 @@ class LNPathFinder(Logger):
|
|||||||
# to properly calculate compound routing fees.
|
# to properly calculate compound routing fees.
|
||||||
distance_from_start = defaultdict(lambda: float('inf'))
|
distance_from_start = defaultdict(lambda: float('inf'))
|
||||||
distance_from_start[nodeB] = 0
|
distance_from_start[nodeB] = 0
|
||||||
prev_node = {}
|
prev_node = {} # type: Dict[bytes, PathEdge]
|
||||||
nodes_to_explore = queue.PriorityQueue()
|
nodes_to_explore = queue.PriorityQueue()
|
||||||
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
|
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
|
||||||
|
|
||||||
@@ -237,7 +242,8 @@ class LNPathFinder(Logger):
|
|||||||
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
|
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
|
||||||
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
|
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
|
||||||
distance_from_start[edge_startnode] = alt_dist_to_neighbour
|
distance_from_start[edge_startnode] = alt_dist_to_neighbour
|
||||||
prev_node[edge_startnode] = edge_endnode, edge_channel_id
|
prev_node[edge_startnode] = PathEdge(node_id=edge_endnode,
|
||||||
|
short_channel_id=ShortChannelID(edge_channel_id))
|
||||||
amount_to_forward_msat = amount_msat + fee_for_edge_msat
|
amount_to_forward_msat = amount_msat + fee_for_edge_msat
|
||||||
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
|
nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
|
||||||
|
|
||||||
@@ -247,13 +253,8 @@ class LNPathFinder(Logger):
|
|||||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||||
invoice_amount_msat: int, *,
|
invoice_amount_msat: int, *,
|
||||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
|
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
|
||||||
-> Optional[Sequence[Tuple[bytes, bytes]]]:
|
-> Optional[LNPaymentPath]:
|
||||||
"""Return a path from nodeA to nodeB.
|
"""Return a path from nodeA to nodeB."""
|
||||||
|
|
||||||
Returns a list of (node_id, short_channel_id) representing a path.
|
|
||||||
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
|
|
||||||
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
|
|
||||||
"""
|
|
||||||
assert type(nodeA) is bytes
|
assert type(nodeA) is bytes
|
||||||
assert type(nodeB) is bytes
|
assert type(nodeB) is bytes
|
||||||
assert type(invoice_amount_msat) is int
|
assert type(invoice_amount_msat) is int
|
||||||
@@ -270,19 +271,21 @@ class LNPathFinder(Logger):
|
|||||||
edge_startnode = nodeA
|
edge_startnode = nodeA
|
||||||
path = []
|
path = []
|
||||||
while edge_startnode != nodeB:
|
while edge_startnode != nodeB:
|
||||||
edge_endnode, edge_taken = prev_node[edge_startnode]
|
edge = prev_node[edge_startnode]
|
||||||
path += [(edge_endnode, edge_taken)]
|
path += [edge]
|
||||||
edge_startnode = edge_endnode
|
edge_startnode = edge.node_id
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def create_route_from_path(self, path, from_node_id: bytes, *,
|
def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *,
|
||||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
|
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
|
||||||
assert isinstance(from_node_id, bytes)
|
assert isinstance(from_node_id, bytes)
|
||||||
if path is None:
|
if path is None:
|
||||||
raise Exception('cannot create route from None path')
|
raise Exception('cannot create route from None path')
|
||||||
route = []
|
route = []
|
||||||
prev_node_id = from_node_id
|
prev_node_id = from_node_id
|
||||||
for node_id, short_channel_id in path:
|
for edge in path:
|
||||||
|
node_id = edge.node_id
|
||||||
|
short_channel_id = edge.short_channel_id
|
||||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
||||||
node_id=prev_node_id,
|
node_id=prev_node_id,
|
||||||
my_channels=my_channels)
|
my_channels=my_channels)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet,
|
|||||||
from electrum import bitcoin, lnrouter
|
from electrum import bitcoin, lnrouter
|
||||||
from electrum.constants import BitcoinTestnet
|
from electrum.constants import BitcoinTestnet
|
||||||
from electrum.simple_config import SimpleConfig
|
from electrum.simple_config import SimpleConfig
|
||||||
|
from electrum.lnrouter import PathEdge
|
||||||
|
|
||||||
from . import TestCaseForTestnet
|
from . import TestCaseForTestnet
|
||||||
from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
||||||
@@ -17,20 +18,6 @@ from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
|||||||
|
|
||||||
class Test_LNRouter(TestCaseForTestnet):
|
class Test_LNRouter(TestCaseForTestnet):
|
||||||
|
|
||||||
#@staticmethod
|
|
||||||
#def parse_witness_list(witness_bytes):
|
|
||||||
# amount_witnesses = witness_bytes[0]
|
|
||||||
# witness_bytes = witness_bytes[1:]
|
|
||||||
# res = []
|
|
||||||
# for i in range(amount_witnesses):
|
|
||||||
# witness_length = witness_bytes[0]
|
|
||||||
# this_witness = witness_bytes[1:witness_length+1]
|
|
||||||
# assert len(this_witness) == witness_length
|
|
||||||
# witness_bytes = witness_bytes[witness_length+1:]
|
|
||||||
# res += [bytes(this_witness)]
|
|
||||||
# assert witness_bytes == b"", witness_bytes
|
|
||||||
# return res
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
|
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
|
||||||
@@ -97,13 +84,13 @@ class Test_LNRouter(TestCaseForTestnet):
|
|||||||
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
|
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
|
||||||
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
|
cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0})
|
||||||
path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)
|
path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)
|
||||||
self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
|
self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')),
|
||||||
(b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x02'),
|
PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')),
|
||||||
], path)
|
], path)
|
||||||
start_node = b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb'
|
start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
|
||||||
route = path_finder.create_route_from_path(path, start_node)
|
route = path_finder.create_route_from_path(path, start_node)
|
||||||
self.assertEqual(route[0].node_id, start_node)
|
self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id)
|
||||||
self.assertEqual(route[0].short_channel_id, bfh('0000000000000003'))
|
self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id)
|
||||||
|
|
||||||
# need to duplicate tear_down here, as we also need to wait for the sql thread to stop
|
# need to duplicate tear_down here, as we also need to wait for the sql thread to stop
|
||||||
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user