1
0

pass blacklist to lnrouter.find_route, so that lnrouter is stateless (see #6778)

This commit is contained in:
ThomasV
2021-01-11 15:19:50 +01:00
parent 9d7a317404
commit ad91257729
5 changed files with 36 additions and 28 deletions

View File

@@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
return False
BLACKLIST_DURATION = 3600
class LNPathFinder(Logger):
def __init__(self, channel_db: ChannelDB):
Logger.__init__(self)
self.channel_db = channel_db
self.blacklist = dict() # short_chan_id -> timestamp
def add_to_blacklist(self, short_channel_id: ShortChannelID):
self.logger.info(f'blacklisting channel {short_channel_id}')
now = int(time.time())
self.blacklist[short_channel_id] = now
def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool:
now = int(time.time())
t = self.blacklist.get(short_channel_id, 0)
return now - t < BLACKLIST_DURATION
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
@@ -200,10 +188,9 @@ class LNPathFinder(Logger):
overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost, fee_msat
def get_distances(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None
) -> Dict[bytes, PathEdge]:
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
# note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?)
@@ -216,7 +203,6 @@ class LNPathFinder(Logger):
nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
# main loop of search
while nodes_to_explore.qsize() > 0:
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
@@ -229,7 +215,7 @@ class LNPathFinder(Logger):
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
assert isinstance(edge_channel_id, bytes)
if self.is_blacklisted(edge_channel_id):
if blacklist and edge_channel_id in blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
@@ -263,7 +249,8 @@ class LNPathFinder(Logger):
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) \
-> Optional[LNPaymentPath]:
"""Return a path from nodeA to nodeB."""
assert type(nodeA) is bytes
@@ -272,7 +259,7 @@ class LNPathFinder(Logger):
if my_channels is None:
my_channels = {}
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
if nodeA not in prev_node:
return None # no path found
@@ -312,8 +299,9 @@ class LNPathFinder(Logger):
return route
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]:
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
if not path:
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
if path:
return self.create_route_from_path(path, nodeA, my_channels=my_channels)