lnworker: add/fix some type hints, add some comments
follow-up recent refactor
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable
|
||||
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
import functools
|
||||
|
||||
@@ -1693,7 +1693,8 @@ class Peer(Logger):
|
||||
self, *,
|
||||
incoming_chan: Channel,
|
||||
htlc: UpdateAddHtlc,
|
||||
processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]:
|
||||
processed_onion: ProcessedOnionPacket,
|
||||
) -> str:
|
||||
|
||||
# Forward HTLC
|
||||
# FIXME: there are critical safety checks MISSING here
|
||||
@@ -1744,11 +1745,11 @@ class Peer(Logger):
|
||||
break
|
||||
else:
|
||||
return await self.lnworker.open_channel_just_in_time(
|
||||
next_peer,
|
||||
next_amount_msat_htlc,
|
||||
next_cltv_abs,
|
||||
htlc.payment_hash,
|
||||
processed_onion.next_packet)
|
||||
next_peer=next_peer,
|
||||
next_amount_msat_htlc=next_amount_msat_htlc,
|
||||
next_cltv_abs=next_cltv_abs,
|
||||
payment_hash=htlc.payment_hash,
|
||||
next_onion=processed_onion.next_packet)
|
||||
|
||||
local_height = chain.height()
|
||||
if next_chan is None:
|
||||
@@ -1815,7 +1816,8 @@ class Peer(Logger):
|
||||
inc_cltv_abs: int,
|
||||
outer_onion: ProcessedOnionPacket,
|
||||
trampoline_onion: ProcessedOnionPacket,
|
||||
fw_payment_key: str):
|
||||
fw_payment_key: str,
|
||||
) -> None:
|
||||
|
||||
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
|
||||
forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
|
||||
@@ -1905,11 +1907,11 @@ class Peer(Logger):
|
||||
trampoline_onion=next_trampoline_onion,
|
||||
)
|
||||
await self.lnworker.open_channel_just_in_time(
|
||||
next_peer,
|
||||
amt_to_forward,
|
||||
cltv_abs,
|
||||
payment_hash,
|
||||
next_onion)
|
||||
next_peer=next_peer,
|
||||
next_amount_msat_htlc=amt_to_forward,
|
||||
next_cltv_abs=cltv_abs,
|
||||
payment_hash=payment_hash,
|
||||
next_onion=next_onion)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -1957,8 +1959,8 @@ class Peer(Logger):
|
||||
htlc: UpdateAddHtlc,
|
||||
processed_onion: ProcessedOnionPacket,
|
||||
onion_packet_bytes: bytes,
|
||||
already_forwarded = False,
|
||||
) -> Tuple[Optional[bytes], Optional[Callable]]:
|
||||
already_forwarded: bool = False,
|
||||
) -> Tuple[Optional[str], Optional[bytes], Optional[Callable[[], Awaitable[Optional[str]]]]]:
|
||||
"""
|
||||
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
|
||||
Return (payment_key, preimage, callback) with at most a single element of the last two not None
|
||||
@@ -2637,6 +2639,7 @@ class Peer(Logger):
|
||||
# HTLC we are supposed to forward, but haven't forwarded yet
|
||||
if not self.lnworker.enable_htlc_forwarding:
|
||||
return None, None, None
|
||||
assert payment_key
|
||||
if payment_key not in self.lnworker.active_forwardings:
|
||||
async def wrapped_callback():
|
||||
forwarding_coro = forwarding_callback()
|
||||
@@ -2649,6 +2652,7 @@ class Peer(Logger):
|
||||
assert len(self.lnworker.active_forwardings[payment_key]) == 0
|
||||
self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
|
||||
# add to list
|
||||
assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0
|
||||
self.lnworker.active_forwardings[payment_key] = []
|
||||
fut = asyncio.ensure_future(wrapped_callback())
|
||||
# return payment_key so this branch will not be executed again
|
||||
|
||||
@@ -66,13 +66,15 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is
|
||||
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
|
||||
|
||||
|
||||
def serialize_htlc_key(scid:bytes, htlc_id: int):
|
||||
def serialize_htlc_key(scid: bytes, htlc_id: int) -> str:
|
||||
return scid.hex() + ':%d'%htlc_id
|
||||
|
||||
def deserialize_htlc_key(htlc_key:str):
|
||||
|
||||
def deserialize_htlc_key(htlc_key: str) -> Tuple[bytes, int]:
|
||||
scid, htlc_id = htlc_key.split(':')
|
||||
return bytes.fromhex(scid), int(htlc_id)
|
||||
|
||||
|
||||
@attr.s
|
||||
class OnlyPubkeyKeypair(StoredObject):
|
||||
pubkey = attr.ib(type=bytes, converter=hex_to_bytes)
|
||||
|
||||
@@ -691,7 +691,7 @@ class PaySession(Logger):
|
||||
|
||||
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
||||
self._nhtlcs_inflight = 0
|
||||
self.is_active = True
|
||||
self.is_active = True # is still trying to send new htlcs?
|
||||
|
||||
def diagnostic_name(self):
|
||||
pkey = sha256(self.payment_key)
|
||||
@@ -779,6 +779,7 @@ class PaySession(Logger):
|
||||
return self.amount_to_pay - self._amount_inflight
|
||||
|
||||
def can_be_deleted(self) -> bool:
|
||||
"""Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
|
||||
if self.is_active:
|
||||
return False
|
||||
# note: no one is consuming from sent_htlcs_q anymore
|
||||
@@ -842,9 +843,9 @@ class LNWallet(LNWorker):
|
||||
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
|
||||
|
||||
# payment forwarding
|
||||
self.active_forwardings = self.db.get_dict('active_forwardings') # Dict: payment_key -> list of htlc_keys
|
||||
self.forwarding_failures = self.db.get_dict('forwarding_failures') # Dict: payment_key -> (error_bytes, error_message)
|
||||
self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted)
|
||||
self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys
|
||||
self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message)
|
||||
self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted)
|
||||
|
||||
# payment_hash -> callback:
|
||||
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
|
||||
@@ -1222,20 +1223,28 @@ class LNWallet(LNWorker):
|
||||
self.logger.info('REBROADCASTING CLOSING TX')
|
||||
await self.network.try_broadcasting(force_close_tx, 'force-close')
|
||||
|
||||
def get_peer_by_scid_alias(self, scid_alias):
|
||||
def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
|
||||
for nodeid, peer in self.peers.items():
|
||||
if scid_alias == self._scid_alias_of_node(nodeid):
|
||||
return peer
|
||||
|
||||
def _scid_alias_of_node(self, nodeid):
|
||||
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
|
||||
# scid alias for just-in-time channels
|
||||
return sha256(b'Electrum' + nodeid)[0:8]
|
||||
|
||||
def get_scid_alias(self):
|
||||
def get_scid_alias(self) -> bytes:
|
||||
return self._scid_alias_of_node(self.node_keypair.pubkey)
|
||||
|
||||
@log_exceptions
|
||||
async def open_channel_just_in_time(self, next_peer, next_amount_msat_htlc, next_cltv_abs, payment_hash, next_onion):
|
||||
async def open_channel_just_in_time( # FIXME xxxxx kwargs
|
||||
self,
|
||||
*,
|
||||
next_peer: Peer,
|
||||
next_amount_msat_htlc: int,
|
||||
next_cltv_abs: int,
|
||||
payment_hash: bytes,
|
||||
next_onion: OnionPacket,
|
||||
) -> str:
|
||||
# if an exception is raised during negotiation, we raise an OnionRoutingFailure.
|
||||
# this will cancel the incoming HTLC
|
||||
try:
|
||||
@@ -2396,7 +2405,7 @@ class LNWallet(LNWorker):
|
||||
if htlc_key in htlcs:
|
||||
return payment_key
|
||||
|
||||
def notify_upstream_peer(self, htlc_key):
|
||||
def notify_upstream_peer(self, htlc_key: str) -> None:
|
||||
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
|
||||
If we find this was a forwarded HTLC, the upstream peer is notified.
|
||||
"""
|
||||
@@ -2510,7 +2519,7 @@ class LNWallet(LNWorker):
|
||||
if fw_key:
|
||||
paysession_active = False
|
||||
else:
|
||||
self.logger.info(f"received unknown htlc_failed, probably from previous session")
|
||||
self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
|
||||
key = payment_hash.hex()
|
||||
self.set_invoice_status(key, PR_UNPAID)
|
||||
util.trigger_callback('payment_failed', self.wallet, key, '')
|
||||
@@ -2522,7 +2531,7 @@ class LNWallet(LNWorker):
|
||||
self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
|
||||
self.notify_upstream_peer(htlc_key)
|
||||
else:
|
||||
self.logger.info(f"waiting for other htlcs to fail")
|
||||
self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
|
||||
|
||||
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
|
||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||
@@ -3094,7 +3103,7 @@ class LNWallet(LNWorker):
|
||||
failure_hex = failure_message.to_bytes().hex() if failure_message else None
|
||||
self.forwarding_failures[payment_key] = (error_hex, failure_hex)
|
||||
|
||||
def get_forwarding_failure(self, payment_key: str):
|
||||
def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
|
||||
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
|
||||
error_bytes = bytes.fromhex(error_hex) if error_hex else None
|
||||
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
|
||||
|
||||
Reference in New Issue
Block a user