if payment fails with UPDATE onion error, also utilise channel_update for private channels
This commit is contained in:
@@ -129,7 +129,7 @@ class ChannelInfo(PrintError):
|
||||
else:
|
||||
self.policy_node2 = new_policy
|
||||
|
||||
def get_policy_for_node(self, node_id: bytes) -> 'ChannelInfoDirectedPolicy':
|
||||
def get_policy_for_node(self, node_id: bytes) -> Optional['ChannelInfoDirectedPolicy']:
|
||||
if node_id == self.node_id_1:
|
||||
return self.policy_node1
|
||||
elif node_id == self.node_id_2:
|
||||
@@ -285,6 +285,9 @@ class ChannelDB(JsonDB):
|
||||
self._recent_peers = []
|
||||
self._last_good_address = {} # node_id -> LNPeerAddr
|
||||
|
||||
# (intentionally not persisted)
|
||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], ChannelInfoDirectedPolicy]
|
||||
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
|
||||
self.load_data()
|
||||
@@ -425,6 +428,21 @@ class ChannelDB(JsonDB):
|
||||
return # ignore
|
||||
self.nodes[pubkey] = new_node_info
|
||||
|
||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||
short_channel_id: bytes) -> Optional[ChannelInfoDirectedPolicy]:
|
||||
if not start_node_id or not short_channel_id: return None
|
||||
channel_info = self.get_channel_info(short_channel_id)
|
||||
if channel_info is not None:
|
||||
return channel_info.get_policy_for_node(start_node_id)
|
||||
return self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||
|
||||
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
||||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||
return # ignore
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
policy = ChannelInfoDirectedPolicy(msg_payload)
|
||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = policy
|
||||
|
||||
def remove_channel(self, short_channel_id):
|
||||
try:
|
||||
channel_info = self._id_to_channel_info[short_channel_id]
|
||||
@@ -488,7 +506,7 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
|
||||
|
||||
class LNPathFinder(PrintError):
|
||||
|
||||
def __init__(self, channel_db):
|
||||
def __init__(self, channel_db: ChannelDB):
|
||||
self.channel_db = channel_db
|
||||
self.blacklist = set()
|
||||
|
||||
@@ -590,12 +608,9 @@ class LNPathFinder(PrintError):
|
||||
route = []
|
||||
prev_node_id = from_node_id
|
||||
for node_id, short_channel_id in path:
|
||||
channel_info = self.channel_db.get_channel_info(short_channel_id)
|
||||
if channel_info is None:
|
||||
raise Exception('cannot find channel info for short_channel_id: {}'.format(bh2u(short_channel_id)))
|
||||
channel_policy = channel_info.get_policy_for_node(prev_node_id)
|
||||
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
|
||||
if channel_policy is None:
|
||||
raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id)))
|
||||
raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
|
||||
route.append(RouteEdge(node_id,
|
||||
short_channel_id,
|
||||
channel_policy.fee_base_msat,
|
||||
|
||||
Reference in New Issue
Block a user