1
0

if payment fails with UPDATE onion error, also utilise channel_update for private channels

This commit is contained in:
SomberNight
2018-10-16 21:35:30 +02:00
committed by ThomasV
parent 962f70c7da
commit 2e5552816c
4 changed files with 58 additions and 22 deletions

View File

@@ -129,7 +129,7 @@ class ChannelInfo(PrintError):
else:
self.policy_node2 = new_policy
def get_policy_for_node(self, node_id: bytes) -> 'ChannelInfoDirectedPolicy':
def get_policy_for_node(self, node_id: bytes) -> Optional['ChannelInfoDirectedPolicy']:
if node_id == self.node_id_1:
return self.policy_node1
elif node_id == self.node_id_2:
@@ -285,6 +285,9 @@ class ChannelDB(JsonDB):
self._recent_peers = []
self._last_good_address = {} # node_id -> LNPeerAddr
# (intentionally not persisted)
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], ChannelInfoDirectedPolicy]
self.ca_verifier = LNChannelVerifier(network, self)
self.load_data()
@@ -425,6 +428,21 @@ class ChannelDB(JsonDB):
return # ignore
self.nodes[pubkey] = new_node_info
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[ChannelInfoDirectedPolicy]:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
return channel_info.get_policy_for_node(start_node_id)
return self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
policy = ChannelInfoDirectedPolicy(msg_payload)
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = policy
def remove_channel(self, short_channel_id):
try:
channel_info = self._id_to_channel_info[short_channel_id]
@@ -488,7 +506,7 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
class LNPathFinder(PrintError):
def __init__(self, channel_db):
def __init__(self, channel_db: ChannelDB):
self.channel_db = channel_db
self.blacklist = set()
@@ -590,12 +608,9 @@ class LNPathFinder(PrintError):
route = []
prev_node_id = from_node_id
for node_id, short_channel_id in path:
channel_info = self.channel_db.get_channel_info(short_channel_id)
if channel_info is None:
raise Exception('cannot find channel info for short_channel_id: {}'.format(bh2u(short_channel_id)))
channel_policy = channel_info.get_policy_for_node(prev_node_id)
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
if channel_policy is None:
raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id)))
raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
route.append(RouteEdge(node_id,
short_channel_id,
channel_policy.fee_base_msat,