pass blacklist to lnrouter.find_route, so that lnrouter is stateless (see #6778)
This commit is contained in:
@@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
BLACKLIST_DURATION = 3600
|
|
||||||
|
|
||||||
class LNPathFinder(Logger):
|
class LNPathFinder(Logger):
|
||||||
|
|
||||||
def __init__(self, channel_db: ChannelDB):
|
def __init__(self, channel_db: ChannelDB):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
self.channel_db = channel_db
|
self.channel_db = channel_db
|
||||||
self.blacklist = dict() # short_chan_id -> timestamp
|
|
||||||
|
|
||||||
def add_to_blacklist(self, short_channel_id: ShortChannelID):
|
|
||||||
self.logger.info(f'blacklisting channel {short_channel_id}')
|
|
||||||
now = int(time.time())
|
|
||||||
self.blacklist[short_channel_id] = now
|
|
||||||
|
|
||||||
def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool:
|
|
||||||
now = int(time.time())
|
|
||||||
t = self.blacklist.get(short_channel_id, 0)
|
|
||||||
return now - t < BLACKLIST_DURATION
|
|
||||||
|
|
||||||
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
|
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
|
||||||
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
|
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
|
||||||
@@ -200,10 +188,9 @@ class LNPathFinder(Logger):
|
|||||||
overall_cost = base_cost + fee_msat + cltv_cost
|
overall_cost = base_cost + fee_msat + cltv_cost
|
||||||
return overall_cost, fee_msat
|
return overall_cost, fee_msat
|
||||||
|
|
||||||
def get_distances(self, nodeA: bytes, nodeB: bytes,
|
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
|
||||||
invoice_amount_msat: int, *,
|
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||||
my_channels: Dict[ShortChannelID, 'Channel'] = None
|
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
|
||||||
) -> Dict[bytes, PathEdge]:
|
|
||||||
# note: we don't lock self.channel_db, so while the path finding runs,
|
# note: we don't lock self.channel_db, so while the path finding runs,
|
||||||
# the underlying graph could potentially change... (not good but maybe ~OK?)
|
# the underlying graph could potentially change... (not good but maybe ~OK?)
|
||||||
|
|
||||||
@@ -216,7 +203,6 @@ class LNPathFinder(Logger):
|
|||||||
nodes_to_explore = queue.PriorityQueue()
|
nodes_to_explore = queue.PriorityQueue()
|
||||||
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
|
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!
|
||||||
|
|
||||||
|
|
||||||
# main loop of search
|
# main loop of search
|
||||||
while nodes_to_explore.qsize() > 0:
|
while nodes_to_explore.qsize() > 0:
|
||||||
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
|
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
|
||||||
@@ -229,7 +215,7 @@ class LNPathFinder(Logger):
|
|||||||
continue
|
continue
|
||||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
|
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
|
||||||
assert isinstance(edge_channel_id, bytes)
|
assert isinstance(edge_channel_id, bytes)
|
||||||
if self.is_blacklisted(edge_channel_id):
|
if blacklist and edge_channel_id in blacklist:
|
||||||
continue
|
continue
|
||||||
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
|
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
|
||||||
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
|
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
|
||||||
@@ -263,7 +249,8 @@ class LNPathFinder(Logger):
|
|||||||
@profiler
|
@profiler
|
||||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||||
invoice_amount_msat: int, *,
|
invoice_amount_msat: int, *,
|
||||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
|
my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||||
|
blacklist: Set[ShortChannelID] = None) \
|
||||||
-> Optional[LNPaymentPath]:
|
-> Optional[LNPaymentPath]:
|
||||||
"""Return a path from nodeA to nodeB."""
|
"""Return a path from nodeA to nodeB."""
|
||||||
assert type(nodeA) is bytes
|
assert type(nodeA) is bytes
|
||||||
@@ -272,7 +259,7 @@ class LNPathFinder(Logger):
|
|||||||
if my_channels is None:
|
if my_channels is None:
|
||||||
my_channels = {}
|
my_channels = {}
|
||||||
|
|
||||||
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
|
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
|
||||||
|
|
||||||
if nodeA not in prev_node:
|
if nodeA not in prev_node:
|
||||||
return None # no path found
|
return None # no path found
|
||||||
@@ -312,8 +299,9 @@ class LNPathFinder(Logger):
|
|||||||
return route
|
return route
|
||||||
|
|
||||||
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
|
def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
|
||||||
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]:
|
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
|
||||||
|
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
|
||||||
if not path:
|
if not path:
|
||||||
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
|
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
|
||||||
if path:
|
if path:
|
||||||
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
|
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import json
|
|||||||
from collections import namedtuple, defaultdict
|
from collections import namedtuple, defaultdict
|
||||||
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
|
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
import attr
|
import attr
|
||||||
from aiorpcx import NetAddress
|
from aiorpcx import NetAddress
|
||||||
|
|
||||||
@@ -1313,3 +1313,17 @@ class OnionFailureCodeMetaFlag(IntFlag):
|
|||||||
NODE = 0x2000
|
NODE = 0x2000
|
||||||
UPDATE = 0x1000
|
UPDATE = 0x1000
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelBlackList:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.blacklist = dict() # short_chan_id -> timestamp
|
||||||
|
|
||||||
|
def add(self, short_channel_id: ShortChannelID):
|
||||||
|
now = int(time.time())
|
||||||
|
self.blacklist[short_channel_id] = now
|
||||||
|
|
||||||
|
def get_current_list(self) -> Set[ShortChannelID]:
|
||||||
|
BLACKLIST_DURATION = 3600
|
||||||
|
now = int(time.time())
|
||||||
|
return set(k for k, t in self.blacklist.items() if now - t < BLACKLIST_DURATION)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import os
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
|
from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
|
||||||
import threading
|
import threading
|
||||||
import socket
|
import socket
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -540,6 +540,7 @@ class LNGossip(LNWorker):
|
|||||||
if categorized_chan_upds.good:
|
if categorized_chan_upds.good:
|
||||||
self.logger.debug(f'on_channel_update: {len(categorized_chan_upds.good)}/{len(chan_upds_chunk)}')
|
self.logger.debug(f'on_channel_update: {len(categorized_chan_upds.good)}/{len(chan_upds_chunk)}')
|
||||||
|
|
||||||
|
|
||||||
class LNWallet(LNWorker):
|
class LNWallet(LNWorker):
|
||||||
|
|
||||||
lnwatcher: Optional['LNWalletWatcher']
|
lnwatcher: Optional['LNWalletWatcher']
|
||||||
@@ -1014,7 +1015,8 @@ class LNWallet(LNWorker):
|
|||||||
except IndexError:
|
except IndexError:
|
||||||
self.logger.info("payment destination reported error")
|
self.logger.info("payment destination reported error")
|
||||||
else:
|
else:
|
||||||
self.network.path_finder.add_to_blacklist(short_chan_id)
|
self.logger.info(f'blacklisting channel {short_channel_id}')
|
||||||
|
self.network.channel_blacklist.add(short_chan_id)
|
||||||
else:
|
else:
|
||||||
# probably got "update_fail_malformed_htlc". well... who to penalise now?
|
# probably got "update_fail_malformed_htlc". well... who to penalise now?
|
||||||
assert payment_attempt.failure_message is not None
|
assert payment_attempt.failure_message is not None
|
||||||
@@ -1127,6 +1129,7 @@ class LNWallet(LNWorker):
|
|||||||
channels = list(self.channels.values())
|
channels = list(self.channels.values())
|
||||||
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
|
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
|
||||||
if chan.short_channel_id is not None}
|
if chan.short_channel_id is not None}
|
||||||
|
blacklist = self.network.channel_blacklist.get_current_list()
|
||||||
for private_route in r_tags:
|
for private_route in r_tags:
|
||||||
if len(private_route) == 0:
|
if len(private_route) == 0:
|
||||||
continue
|
continue
|
||||||
@@ -1144,7 +1147,7 @@ class LNWallet(LNWorker):
|
|||||||
try:
|
try:
|
||||||
route = self.network.path_finder.find_route(
|
route = self.network.path_finder.find_route(
|
||||||
self.node_keypair.pubkey, border_node_pubkey, amount_msat,
|
self.node_keypair.pubkey, border_node_pubkey, amount_msat,
|
||||||
path=path, my_channels=scid_to_my_channels)
|
path=path, my_channels=scid_to_my_channels, blacklist=blacklist)
|
||||||
except NoChannelPolicy:
|
except NoChannelPolicy:
|
||||||
continue
|
continue
|
||||||
if not route:
|
if not route:
|
||||||
@@ -1186,7 +1189,7 @@ class LNWallet(LNWorker):
|
|||||||
if route is None:
|
if route is None:
|
||||||
route = self.network.path_finder.find_route(
|
route = self.network.path_finder.find_route(
|
||||||
self.node_keypair.pubkey, invoice_pubkey, amount_msat,
|
self.node_keypair.pubkey, invoice_pubkey, amount_msat,
|
||||||
path=full_path, my_channels=scid_to_my_channels)
|
path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist)
|
||||||
if not route:
|
if not route:
|
||||||
raise NoPathFound()
|
raise NoPathFound()
|
||||||
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
|
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ from . import util
|
|||||||
from .util import (log_exceptions, ignore_exceptions,
|
from .util import (log_exceptions, ignore_exceptions,
|
||||||
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
|
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
|
||||||
is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
|
is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)
|
||||||
|
|
||||||
from .bitcoin import COIN
|
from .bitcoin import COIN
|
||||||
from . import constants
|
from . import constants
|
||||||
from . import blockchain
|
from . import blockchain
|
||||||
@@ -60,6 +59,7 @@ from .version import PROTOCOL_VERSION
|
|||||||
from .simple_config import SimpleConfig
|
from .simple_config import SimpleConfig
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
from .logging import get_logger, Logger
|
from .logging import get_logger, Logger
|
||||||
|
from .lnutil import ChannelBlackList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .channel_db import ChannelDB
|
from .channel_db import ChannelDB
|
||||||
@@ -335,6 +335,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||||||
self._has_ever_managed_to_connect_to_server = False
|
self._has_ever_managed_to_connect_to_server = False
|
||||||
|
|
||||||
# lightning network
|
# lightning network
|
||||||
|
self.channel_blacklist = ChannelBlackList()
|
||||||
self.channel_db = None # type: Optional[ChannelDB]
|
self.channel_db = None # type: Optional[ChannelDB]
|
||||||
self.lngossip = None # type: Optional[LNGossip]
|
self.lngossip = None # type: Optional[LNGossip]
|
||||||
self.local_watchtower = None # type: Optional[WatchTower]
|
self.local_watchtower = None # type: Optional[WatchTower]
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from electrum.lnmsg import encode_msg, decode_msg
|
|||||||
from electrum.logging import console_stderr_handler, Logger
|
from electrum.logging import console_stderr_handler, Logger
|
||||||
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
|
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
|
||||||
from electrum.lnonion import OnionFailureCode
|
from electrum.lnonion import OnionFailureCode
|
||||||
|
from electrum.lnutil import ChannelBlackList
|
||||||
|
|
||||||
from .test_lnchannel import create_test_channels
|
from .test_lnchannel import create_test_channels
|
||||||
from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
||||||
@@ -62,6 +63,7 @@ class MockNetwork:
|
|||||||
self.path_finder = LNPathFinder(self.channel_db)
|
self.path_finder = LNPathFinder(self.channel_db)
|
||||||
self.tx_queue = tx_queue
|
self.tx_queue = tx_queue
|
||||||
self._blockchain = MockBlockchain()
|
self._blockchain = MockBlockchain()
|
||||||
|
self.channel_blacklist = ChannelBlackList()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def callback_lock(self):
|
def callback_lock(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user