lnworker: add/fix some type hints, add some comments
follow-up recent refactor
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable
|
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
@@ -1693,7 +1693,8 @@ class Peer(Logger):
|
|||||||
self, *,
|
self, *,
|
||||||
incoming_chan: Channel,
|
incoming_chan: Channel,
|
||||||
htlc: UpdateAddHtlc,
|
htlc: UpdateAddHtlc,
|
||||||
processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]:
|
processed_onion: ProcessedOnionPacket,
|
||||||
|
) -> str:
|
||||||
|
|
||||||
# Forward HTLC
|
# Forward HTLC
|
||||||
# FIXME: there are critical safety checks MISSING here
|
# FIXME: there are critical safety checks MISSING here
|
||||||
@@ -1744,11 +1745,11 @@ class Peer(Logger):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return await self.lnworker.open_channel_just_in_time(
|
return await self.lnworker.open_channel_just_in_time(
|
||||||
next_peer,
|
next_peer=next_peer,
|
||||||
next_amount_msat_htlc,
|
next_amount_msat_htlc=next_amount_msat_htlc,
|
||||||
next_cltv_abs,
|
next_cltv_abs=next_cltv_abs,
|
||||||
htlc.payment_hash,
|
payment_hash=htlc.payment_hash,
|
||||||
processed_onion.next_packet)
|
next_onion=processed_onion.next_packet)
|
||||||
|
|
||||||
local_height = chain.height()
|
local_height = chain.height()
|
||||||
if next_chan is None:
|
if next_chan is None:
|
||||||
@@ -1815,7 +1816,8 @@ class Peer(Logger):
|
|||||||
inc_cltv_abs: int,
|
inc_cltv_abs: int,
|
||||||
outer_onion: ProcessedOnionPacket,
|
outer_onion: ProcessedOnionPacket,
|
||||||
trampoline_onion: ProcessedOnionPacket,
|
trampoline_onion: ProcessedOnionPacket,
|
||||||
fw_payment_key: str):
|
fw_payment_key: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
|
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
|
||||||
forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
|
forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
|
||||||
@@ -1905,11 +1907,11 @@ class Peer(Logger):
|
|||||||
trampoline_onion=next_trampoline_onion,
|
trampoline_onion=next_trampoline_onion,
|
||||||
)
|
)
|
||||||
await self.lnworker.open_channel_just_in_time(
|
await self.lnworker.open_channel_just_in_time(
|
||||||
next_peer,
|
next_peer=next_peer,
|
||||||
amt_to_forward,
|
next_amount_msat_htlc=amt_to_forward,
|
||||||
cltv_abs,
|
next_cltv_abs=cltv_abs,
|
||||||
payment_hash,
|
payment_hash=payment_hash,
|
||||||
next_onion)
|
next_onion=next_onion)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1957,8 +1959,8 @@ class Peer(Logger):
|
|||||||
htlc: UpdateAddHtlc,
|
htlc: UpdateAddHtlc,
|
||||||
processed_onion: ProcessedOnionPacket,
|
processed_onion: ProcessedOnionPacket,
|
||||||
onion_packet_bytes: bytes,
|
onion_packet_bytes: bytes,
|
||||||
already_forwarded = False,
|
already_forwarded: bool = False,
|
||||||
) -> Tuple[Optional[bytes], Optional[Callable]]:
|
) -> Tuple[Optional[str], Optional[bytes], Optional[Callable[[], Awaitable[Optional[str]]]]]:
|
||||||
"""
|
"""
|
||||||
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
|
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
|
||||||
Return (payment_key, preimage, callback) with at most a single element of the last two not None
|
Return (payment_key, preimage, callback) with at most a single element of the last two not None
|
||||||
@@ -2637,6 +2639,7 @@ class Peer(Logger):
|
|||||||
# HTLC we are supposed to forward, but haven't forwarded yet
|
# HTLC we are supposed to forward, but haven't forwarded yet
|
||||||
if not self.lnworker.enable_htlc_forwarding:
|
if not self.lnworker.enable_htlc_forwarding:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
assert payment_key
|
||||||
if payment_key not in self.lnworker.active_forwardings:
|
if payment_key not in self.lnworker.active_forwardings:
|
||||||
async def wrapped_callback():
|
async def wrapped_callback():
|
||||||
forwarding_coro = forwarding_callback()
|
forwarding_coro = forwarding_callback()
|
||||||
@@ -2649,6 +2652,7 @@ class Peer(Logger):
|
|||||||
assert len(self.lnworker.active_forwardings[payment_key]) == 0
|
assert len(self.lnworker.active_forwardings[payment_key]) == 0
|
||||||
self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
|
self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
|
||||||
# add to list
|
# add to list
|
||||||
|
assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0
|
||||||
self.lnworker.active_forwardings[payment_key] = []
|
self.lnworker.active_forwardings[payment_key] = []
|
||||||
fut = asyncio.ensure_future(wrapped_callback())
|
fut = asyncio.ensure_future(wrapped_callback())
|
||||||
# return payment_key so this branch will not be executed again
|
# return payment_key so this branch will not be executed again
|
||||||
|
|||||||
@@ -66,13 +66,15 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is
|
|||||||
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
|
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
|
||||||
|
|
||||||
|
|
||||||
def serialize_htlc_key(scid:bytes, htlc_id: int):
|
def serialize_htlc_key(scid: bytes, htlc_id: int) -> str:
|
||||||
return scid.hex() + ':%d'%htlc_id
|
return scid.hex() + ':%d'%htlc_id
|
||||||
|
|
||||||
def deserialize_htlc_key(htlc_key:str):
|
|
||||||
|
def deserialize_htlc_key(htlc_key: str) -> Tuple[bytes, int]:
|
||||||
scid, htlc_id = htlc_key.split(':')
|
scid, htlc_id = htlc_key.split(':')
|
||||||
return bytes.fromhex(scid), int(htlc_id)
|
return bytes.fromhex(scid), int(htlc_id)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class OnlyPubkeyKeypair(StoredObject):
|
class OnlyPubkeyKeypair(StoredObject):
|
||||||
pubkey = attr.ib(type=bytes, converter=hex_to_bytes)
|
pubkey = attr.ib(type=bytes, converter=hex_to_bytes)
|
||||||
|
|||||||
@@ -691,7 +691,7 @@ class PaySession(Logger):
|
|||||||
|
|
||||||
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
||||||
self._nhtlcs_inflight = 0
|
self._nhtlcs_inflight = 0
|
||||||
self.is_active = True
|
self.is_active = True # is still trying to send new htlcs?
|
||||||
|
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
pkey = sha256(self.payment_key)
|
pkey = sha256(self.payment_key)
|
||||||
@@ -779,6 +779,7 @@ class PaySession(Logger):
|
|||||||
return self.amount_to_pay - self._amount_inflight
|
return self.amount_to_pay - self._amount_inflight
|
||||||
|
|
||||||
def can_be_deleted(self) -> bool:
|
def can_be_deleted(self) -> bool:
|
||||||
|
"""Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
|
||||||
if self.is_active:
|
if self.is_active:
|
||||||
return False
|
return False
|
||||||
# note: no one is consuming from sent_htlcs_q anymore
|
# note: no one is consuming from sent_htlcs_q anymore
|
||||||
@@ -842,9 +843,9 @@ class LNWallet(LNWorker):
|
|||||||
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
|
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
|
||||||
|
|
||||||
# payment forwarding
|
# payment forwarding
|
||||||
self.active_forwardings = self.db.get_dict('active_forwardings') # Dict: payment_key -> list of htlc_keys
|
self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys
|
||||||
self.forwarding_failures = self.db.get_dict('forwarding_failures') # Dict: payment_key -> (error_bytes, error_message)
|
self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message)
|
||||||
self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted)
|
self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted)
|
||||||
|
|
||||||
# payment_hash -> callback:
|
# payment_hash -> callback:
|
||||||
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
|
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
|
||||||
@@ -1222,20 +1223,28 @@ class LNWallet(LNWorker):
|
|||||||
self.logger.info('REBROADCASTING CLOSING TX')
|
self.logger.info('REBROADCASTING CLOSING TX')
|
||||||
await self.network.try_broadcasting(force_close_tx, 'force-close')
|
await self.network.try_broadcasting(force_close_tx, 'force-close')
|
||||||
|
|
||||||
def get_peer_by_scid_alias(self, scid_alias):
|
def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
|
||||||
for nodeid, peer in self.peers.items():
|
for nodeid, peer in self.peers.items():
|
||||||
if scid_alias == self._scid_alias_of_node(nodeid):
|
if scid_alias == self._scid_alias_of_node(nodeid):
|
||||||
return peer
|
return peer
|
||||||
|
|
||||||
def _scid_alias_of_node(self, nodeid):
|
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
|
||||||
# scid alias for just-in-time channels
|
# scid alias for just-in-time channels
|
||||||
return sha256(b'Electrum' + nodeid)[0:8]
|
return sha256(b'Electrum' + nodeid)[0:8]
|
||||||
|
|
||||||
def get_scid_alias(self):
|
def get_scid_alias(self) -> bytes:
|
||||||
return self._scid_alias_of_node(self.node_keypair.pubkey)
|
return self._scid_alias_of_node(self.node_keypair.pubkey)
|
||||||
|
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
async def open_channel_just_in_time(self, next_peer, next_amount_msat_htlc, next_cltv_abs, payment_hash, next_onion):
|
async def open_channel_just_in_time( # FIXME xxxxx kwargs
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
next_peer: Peer,
|
||||||
|
next_amount_msat_htlc: int,
|
||||||
|
next_cltv_abs: int,
|
||||||
|
payment_hash: bytes,
|
||||||
|
next_onion: OnionPacket,
|
||||||
|
) -> str:
|
||||||
# if an exception is raised during negotiation, we raise an OnionRoutingFailure.
|
# if an exception is raised during negotiation, we raise an OnionRoutingFailure.
|
||||||
# this will cancel the incoming HTLC
|
# this will cancel the incoming HTLC
|
||||||
try:
|
try:
|
||||||
@@ -2396,7 +2405,7 @@ class LNWallet(LNWorker):
|
|||||||
if htlc_key in htlcs:
|
if htlc_key in htlcs:
|
||||||
return payment_key
|
return payment_key
|
||||||
|
|
||||||
def notify_upstream_peer(self, htlc_key):
|
def notify_upstream_peer(self, htlc_key: str) -> None:
|
||||||
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
|
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
|
||||||
If we find this was a forwarded HTLC, the upstream peer is notified.
|
If we find this was a forwarded HTLC, the upstream peer is notified.
|
||||||
"""
|
"""
|
||||||
@@ -2510,7 +2519,7 @@ class LNWallet(LNWorker):
|
|||||||
if fw_key:
|
if fw_key:
|
||||||
paysession_active = False
|
paysession_active = False
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"received unknown htlc_failed, probably from previous session")
|
self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
|
||||||
key = payment_hash.hex()
|
key = payment_hash.hex()
|
||||||
self.set_invoice_status(key, PR_UNPAID)
|
self.set_invoice_status(key, PR_UNPAID)
|
||||||
util.trigger_callback('payment_failed', self.wallet, key, '')
|
util.trigger_callback('payment_failed', self.wallet, key, '')
|
||||||
@@ -2522,7 +2531,7 @@ class LNWallet(LNWorker):
|
|||||||
self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
|
self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
|
||||||
self.notify_upstream_peer(htlc_key)
|
self.notify_upstream_peer(htlc_key)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"waiting for other htlcs to fail")
|
self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
|
||||||
|
|
||||||
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
|
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
|
||||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||||
@@ -3094,7 +3103,7 @@ class LNWallet(LNWorker):
|
|||||||
failure_hex = failure_message.to_bytes().hex() if failure_message else None
|
failure_hex = failure_message.to_bytes().hex() if failure_message else None
|
||||||
self.forwarding_failures[payment_key] = (error_hex, failure_hex)
|
self.forwarding_failures[payment_key] = (error_hex, failure_hex)
|
||||||
|
|
||||||
def get_forwarding_failure(self, payment_key: str):
|
def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
|
||||||
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
|
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
|
||||||
error_bytes = bytes.fromhex(error_hex) if error_hex else None
|
error_bytes = bytes.fromhex(error_hex) if error_hex else None
|
||||||
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
|
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
|
||||||
|
|||||||
Reference in New Issue
Block a user