lnworker/lnpeer: add some type hints, force some kwargs
This commit is contained in:
@@ -437,9 +437,12 @@ class OnionRoutingFailure(Exception):
|
||||
return str(self.code.name)
|
||||
return f"Unknown error ({self.code!r})"
|
||||
|
||||
def construct_onion_error(reason: OnionRoutingFailure,
|
||||
onion_packet: OnionPacket,
|
||||
our_onion_private_key: bytes) -> bytes:
|
||||
|
||||
def construct_onion_error(
|
||||
reason: OnionRoutingFailure,
|
||||
onion_packet: OnionPacket,
|
||||
our_onion_private_key: bytes,
|
||||
) -> bytes:
|
||||
# create payload
|
||||
failure_msg = reason.to_bytes()
|
||||
failure_len = len(failure_msg)
|
||||
|
||||
@@ -1373,9 +1373,12 @@ class Peer(Logger):
|
||||
chan.receive_htlc(htlc, onion_packet)
|
||||
util.trigger_callback('htlc_added', chan, htlc, RECEIVED)
|
||||
|
||||
def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
|
||||
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket
|
||||
) -> Tuple[Optional[bytes], Optional[int], Optional[OnionRoutingFailure]]:
|
||||
def maybe_forward_htlc(
|
||||
self,
|
||||
*,
|
||||
htlc: UpdateAddHtlc,
|
||||
processed_onion: ProcessedOnionPacket,
|
||||
) -> Tuple[bytes, int]:
|
||||
# Forward HTLC
|
||||
# FIXME: there are critical safety checks MISSING here
|
||||
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
|
||||
@@ -1662,7 +1665,7 @@ class Peer(Logger):
|
||||
self.shutdown_received[chan_id] = asyncio.Future()
|
||||
await self.send_shutdown(chan)
|
||||
payload = await self.shutdown_received[chan_id]
|
||||
txid = await self._shutdown(chan, payload, True)
|
||||
txid = await self._shutdown(chan, payload, is_local=True)
|
||||
self.logger.info(f'({chan.get_id_for_log()}) Channel closed {txid}')
|
||||
return txid
|
||||
|
||||
@@ -1686,10 +1689,10 @@ class Peer(Logger):
|
||||
else:
|
||||
chan = self.channels[chan_id]
|
||||
await self.send_shutdown(chan)
|
||||
txid = await self._shutdown(chan, payload, False)
|
||||
txid = await self._shutdown(chan, payload, is_local=False)
|
||||
self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}')
|
||||
|
||||
def can_send_shutdown(self, chan):
|
||||
def can_send_shutdown(self, chan: Channel):
|
||||
if chan.get_state() >= ChannelState.OPENING:
|
||||
return True
|
||||
if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
|
||||
@@ -1718,7 +1721,7 @@ class Peer(Logger):
|
||||
chan.set_can_send_ctx_updates(True)
|
||||
|
||||
@log_exceptions
|
||||
async def _shutdown(self, chan: Channel, payload, is_local):
|
||||
async def _shutdown(self, chan: Channel, payload, *, is_local: bool):
|
||||
# wait until no HTLCs remain in either commitment transaction
|
||||
while len(chan.hm.htlcs(LOCAL)) + len(chan.hm.htlcs(REMOTE)) > 0:
|
||||
self.logger.info(f'(chan: {chan.short_channel_id}) waiting for htlcs to settle...')
|
||||
@@ -1826,7 +1829,12 @@ class Peer(Logger):
|
||||
error_reason = e
|
||||
else:
|
||||
try:
|
||||
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet)
|
||||
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
|
||||
chan=chan,
|
||||
htlc=htlc,
|
||||
forwarding_info=forwarding_info,
|
||||
onion_packet_bytes=onion_packet_bytes,
|
||||
onion_packet=onion_packet)
|
||||
except OnionRoutingFailure as e:
|
||||
error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
|
||||
if fw_info:
|
||||
@@ -1850,13 +1858,24 @@ class Peer(Logger):
|
||||
for htlc_id in done:
|
||||
unfulfilled.pop(htlc_id)
|
||||
|
||||
def process_unfulfilled_htlc(self, chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet):
|
||||
def process_unfulfilled_htlc(
|
||||
self,
|
||||
*,
|
||||
chan: Channel,
|
||||
htlc: UpdateAddHtlc,
|
||||
forwarding_info: Tuple[str, int],
|
||||
onion_packet_bytes: bytes,
|
||||
onion_packet: OnionPacket,
|
||||
) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]:
|
||||
"""
|
||||
returns either preimage or fw_info or error_bytes or (None, None, None)
|
||||
raise an OnionRoutingFailure if we need to fail the htlc
|
||||
"""
|
||||
payment_hash = htlc.payment_hash
|
||||
processed_onion = self.process_onion_packet(onion_packet, payment_hash, onion_packet_bytes)
|
||||
processed_onion = self.process_onion_packet(
|
||||
onion_packet,
|
||||
payment_hash=payment_hash,
|
||||
onion_packet_bytes=onion_packet_bytes)
|
||||
if processed_onion.are_we_final:
|
||||
preimage = self.maybe_fulfill_htlc(
|
||||
chan=chan,
|
||||
@@ -1867,8 +1886,8 @@ class Peer(Logger):
|
||||
if not forwarding_info:
|
||||
trampoline_onion = self.process_onion_packet(
|
||||
processed_onion.trampoline_onion_packet,
|
||||
htlc.payment_hash,
|
||||
onion_packet_bytes,
|
||||
payment_hash=htlc.payment_hash,
|
||||
onion_packet_bytes=onion_packet_bytes,
|
||||
is_trampoline=True)
|
||||
if trampoline_onion.are_we_final:
|
||||
preimage = self.maybe_fulfill_htlc(
|
||||
@@ -1892,13 +1911,10 @@ class Peer(Logger):
|
||||
|
||||
elif not forwarding_info:
|
||||
next_chan_id, next_htlc_id = self.maybe_forward_htlc(
|
||||
chan=chan,
|
||||
htlc=htlc,
|
||||
onion_packet=onion_packet,
|
||||
processed_onion=processed_onion)
|
||||
if next_chan_id:
|
||||
fw_info = (next_chan_id.hex(), next_htlc_id)
|
||||
return None, fw_info, None
|
||||
fw_info = (next_chan_id.hex(), next_htlc_id)
|
||||
return None, fw_info, None
|
||||
else:
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
next_chan_id_hex, htlc_id = forwarding_info
|
||||
@@ -1913,7 +1929,14 @@ class Peer(Logger):
|
||||
return preimage, None, None
|
||||
return None, None, None
|
||||
|
||||
def process_onion_packet(self, onion_packet, payment_hash, onion_packet_bytes, is_trampoline=False):
|
||||
def process_onion_packet(
|
||||
self,
|
||||
onion_packet: OnionPacket,
|
||||
*,
|
||||
payment_hash: bytes,
|
||||
onion_packet_bytes: bytes,
|
||||
is_trampoline: bool = False,
|
||||
) -> ProcessedOnionPacket:
|
||||
failure_data = sha256(onion_packet_bytes)
|
||||
try:
|
||||
processed_onion = process_onion_packet(
|
||||
|
||||
@@ -268,7 +268,10 @@ class LNRater(Logger):
|
||||
|
||||
return pk, self._node_stats[pk]
|
||||
|
||||
def suggest_peer(self):
|
||||
def suggest_peer(self) -> Optional[bytes]:
|
||||
"""Suggests a LN node to open a channel with.
|
||||
Returns a node ID (pubkey).
|
||||
"""
|
||||
self.maybe_analyze_graph()
|
||||
if self._node_ratings:
|
||||
return self.suggest_node_channel_open()[0]
|
||||
|
||||
@@ -7,7 +7,8 @@ import os
|
||||
from decimal import Decimal
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
|
||||
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
|
||||
NamedTuple, Union, Mapping, Any, Iterable)
|
||||
import threading
|
||||
import socket
|
||||
import aiohttp
|
||||
@@ -266,10 +267,10 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
with self.lock:
|
||||
return self._peers.copy()
|
||||
|
||||
def channels_for_peer(self, node_id):
|
||||
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
|
||||
return {}
|
||||
|
||||
def get_node_alias(self, node_id):
|
||||
def get_node_alias(self, node_id: bytes) -> str:
|
||||
if self.channel_db:
|
||||
node_info = self.channel_db.get_node_info_for_node_id(node_id)
|
||||
node_alias = (node_info.alias if node_info else '') or node_id.hex()
|
||||
@@ -380,7 +381,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self._add_peer(host, int(port), bfh(pubkey)),
|
||||
self.network.asyncio_loop)
|
||||
|
||||
def is_good_peer(self, peer):
|
||||
def is_good_peer(self, peer: LNPeerAddr) -> bool:
|
||||
# the purpose of this method is to filter peers that advertise the desired feature bits
|
||||
# it is disabled for now, because feature bits published in node announcements seem to be unreliable
|
||||
return True
|
||||
@@ -566,7 +567,7 @@ class LNGossip(LNWorker):
|
||||
self.channel_db.prune_orphaned_channels()
|
||||
await asyncio.sleep(120)
|
||||
|
||||
async def add_new_ids(self, ids):
|
||||
async def add_new_ids(self, ids: Iterable[bytes]):
|
||||
known = self.channel_db.get_channel_ids()
|
||||
new = set(ids) - set(known)
|
||||
self.unknown_ids.update(new)
|
||||
@@ -574,7 +575,7 @@ class LNGossip(LNWorker):
|
||||
util.trigger_callback('gossip_peers', self.num_peers())
|
||||
util.trigger_callback('ln_gossip_sync_progress')
|
||||
|
||||
def get_ids_to_query(self):
|
||||
def get_ids_to_query(self) -> Sequence[bytes]:
|
||||
N = 500
|
||||
l = list(self.unknown_ids)
|
||||
self.unknown_ids = set(l[N:])
|
||||
@@ -910,7 +911,7 @@ class LNWallet(LNWorker):
|
||||
if chan.funding_outpoint.to_str() == txo:
|
||||
return chan
|
||||
|
||||
async def on_channel_update(self, chan):
|
||||
async def on_channel_update(self, chan: Channel):
|
||||
|
||||
if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()):
|
||||
self.logger.info(f"force-closing due to expiring htlcs")
|
||||
@@ -938,10 +939,14 @@ class LNWallet(LNWorker):
|
||||
|
||||
@log_exceptions
|
||||
async def _open_channel_coroutine(
|
||||
self, *, connect_str: str,
|
||||
self,
|
||||
*,
|
||||
connect_str: str,
|
||||
funding_tx: PartialTransaction,
|
||||
funding_sat: int, push_sat: int,
|
||||
password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
|
||||
funding_sat: int,
|
||||
push_sat: int,
|
||||
password: Optional[str],
|
||||
) -> Tuple[Channel, PartialTransaction]:
|
||||
peer = await self.add_peer(connect_str)
|
||||
coro = peer.channel_establishment_flow(
|
||||
funding_tx=funding_tx,
|
||||
@@ -1006,7 +1011,7 @@ class LNWallet(LNWorker):
|
||||
if chan.short_channel_id == short_channel_id:
|
||||
return chan
|
||||
|
||||
def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None):
|
||||
def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
|
||||
return self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
@@ -1051,9 +1056,16 @@ class LNWallet(LNWorker):
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
try:
|
||||
await self.pay_to_node(
|
||||
invoice_pubkey, payment_hash, payment_secret, amount_to_pay,
|
||||
min_cltv_expiry, r_tags, t_tags, invoice_features,
|
||||
attempts=attempts, full_path=full_path)
|
||||
node_pubkey=invoice_pubkey,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
amount_to_pay=amount_to_pay,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
t_tags=t_tags,
|
||||
invoice_features=invoice_features,
|
||||
attempts=attempts,
|
||||
full_path=full_path)
|
||||
success = True
|
||||
except PaymentFailure as e:
|
||||
self.logger.exception('')
|
||||
@@ -1068,12 +1080,23 @@ class LNWallet(LNWorker):
|
||||
log = self.logs[key]
|
||||
return success, log
|
||||
|
||||
|
||||
async def pay_to_node(
|
||||
self, node_pubkey, payment_hash, payment_secret, amount_to_pay,
|
||||
min_cltv_expiry, r_tags, t_tags, invoice_features, *,
|
||||
attempts: int = 1, full_path: LNPaymentPath=None,
|
||||
trampoline_onion=None, trampoline_fee=None, trampoline_cltv_delta=None):
|
||||
self,
|
||||
*,
|
||||
node_pubkey: bytes,
|
||||
payment_hash: bytes,
|
||||
payment_secret: Optional[bytes],
|
||||
amount_to_pay: int, # in msat
|
||||
min_cltv_expiry: int,
|
||||
r_tags,
|
||||
t_tags,
|
||||
invoice_features: int,
|
||||
attempts: int = 1,
|
||||
full_path: LNPaymentPath = None,
|
||||
trampoline_onion=None,
|
||||
trampoline_fee=None,
|
||||
trampoline_cltv_delta=None,
|
||||
) -> None:
|
||||
|
||||
if trampoline_onion:
|
||||
# todo: compare to the fee of the actual route we found
|
||||
@@ -1095,7 +1118,14 @@ class LNWallet(LNWorker):
|
||||
min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path))
|
||||
# 2. send htlcs
|
||||
for route, amount_msat in routes:
|
||||
await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion)
|
||||
await self.pay_to_route(
|
||||
route,
|
||||
amount_msat=amount_msat,
|
||||
total_msat=amount_to_pay,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
trampoline_onion=trampoline_onion)
|
||||
amount_inflight += amount_msat
|
||||
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
|
||||
# 3. await a queue
|
||||
@@ -1111,9 +1141,17 @@ class LNWallet(LNWorker):
|
||||
# if we get a channel update, we might retry the same route and amount
|
||||
self.handle_error_code_from_failed_htlc(htlc_log)
|
||||
|
||||
async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int,
|
||||
total_msat: int, payment_hash: bytes, payment_secret: bytes,
|
||||
min_cltv_expiry: int, trampoline_onion: bytes=None):
|
||||
async def pay_to_route(
|
||||
self,
|
||||
route: LNPaymentRoute,
|
||||
*,
|
||||
amount_msat: int,
|
||||
total_msat: int,
|
||||
payment_hash: bytes,
|
||||
payment_secret: Optional[bytes],
|
||||
min_cltv_expiry: int,
|
||||
trampoline_onion: bytes = None,
|
||||
) -> None:
|
||||
# send a single htlc
|
||||
short_channel_id = route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
@@ -1267,7 +1305,7 @@ class LNWallet(LNWorker):
|
||||
result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
|
||||
return result.tobytes()
|
||||
|
||||
def is_trampoline_peer(self, node_id):
|
||||
def is_trampoline_peer(self, node_id: bytes) -> bool:
|
||||
# until trampoline is advertised in lnfeatures, check against hardcoded list
|
||||
if is_hardcoded_trampoline(node_id):
|
||||
return True
|
||||
@@ -1276,8 +1314,11 @@ class LNWallet(LNWorker):
|
||||
return True
|
||||
return False
|
||||
|
||||
def suggest_peer(self):
|
||||
return self.lnrater.suggest_peer() if self.channel_db else random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
|
||||
def suggest_peer(self) -> Optional[bytes]:
|
||||
if self.channel_db:
|
||||
return self.lnrater.suggest_peer()
|
||||
else:
|
||||
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
|
||||
|
||||
def create_trampoline_route(
|
||||
self, amount_msat:int,
|
||||
@@ -1400,8 +1441,10 @@ class LNWallet(LNWorker):
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags, t_tags,
|
||||
invoice_features,
|
||||
*, full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
|
||||
invoice_features: int,
|
||||
*,
|
||||
full_path: LNPaymentPath = None,
|
||||
) -> Sequence[Tuple[LNPaymentRoute, int]]:
|
||||
"""Creates multiple routes for splitting a payment over the available
|
||||
private channels.
|
||||
|
||||
@@ -1411,13 +1454,14 @@ class LNWallet(LNWorker):
|
||||
# try to send over a single channel
|
||||
try:
|
||||
routes = [self.create_route_for_payment(
|
||||
amount_msat,
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags, t_tags,
|
||||
invoice_features,
|
||||
None,
|
||||
full_path=full_path
|
||||
amount_msat=amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
t_tags=t_tags,
|
||||
invoice_features=invoice_features,
|
||||
outgoing_channel=None,
|
||||
full_path=full_path,
|
||||
)]
|
||||
except NoPathFound:
|
||||
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
|
||||
@@ -1439,12 +1483,13 @@ class LNWallet(LNWorker):
|
||||
# its capacity. This could be dealt with by temporarily
|
||||
# iteratively blacklisting channels for this mpp attempt.
|
||||
route, amt = self.create_route_for_payment(
|
||||
part_amount_msat,
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags, t_tags,
|
||||
invoice_features,
|
||||
channel,
|
||||
amount_msat=part_amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
t_tags=t_tags,
|
||||
invoice_features=invoice_features,
|
||||
outgoing_channel=channel,
|
||||
full_path=None)
|
||||
routes.append((route, amt))
|
||||
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
|
||||
@@ -1457,13 +1502,16 @@ class LNWallet(LNWorker):
|
||||
|
||||
def create_route_for_payment(
|
||||
self,
|
||||
*,
|
||||
amount_msat: int,
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags, t_tags,
|
||||
invoice_features,
|
||||
invoice_pubkey: bytes,
|
||||
min_cltv_expiry: int,
|
||||
r_tags,
|
||||
t_tags,
|
||||
invoice_features: int,
|
||||
outgoing_channel: Channel = None,
|
||||
*, full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
|
||||
full_path: Optional[LNPaymentPath],
|
||||
) -> Tuple[LNPaymentRoute, int]:
|
||||
|
||||
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
|
||||
if not self.channel_db:
|
||||
@@ -1554,7 +1602,13 @@ class LNWallet(LNWorker):
|
||||
raise Exception(_("add invoice timed out"))
|
||||
|
||||
@log_exceptions
|
||||
async def create_invoice(self, *, amount_msat: Optional[int], message, expiry: int):
|
||||
async def create_invoice(
|
||||
self,
|
||||
*,
|
||||
amount_msat: Optional[int],
|
||||
message,
|
||||
expiry: int,
|
||||
) -> Tuple[LnAddr, str]:
|
||||
timestamp = int(time.time())
|
||||
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
|
||||
if not routing_hints:
|
||||
@@ -1628,7 +1682,7 @@ class LNWallet(LNWorker):
|
||||
self.payments[key] = info.amount_msat, info.direction, info.status
|
||||
self.wallet.save_db()
|
||||
|
||||
def htlc_received(self, short_channel_id, htlc, expected_msat):
|
||||
def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int):
|
||||
status = self.get_payment_status(htlc.payment_hash)
|
||||
if status == PR_PAID:
|
||||
return True, None
|
||||
|
||||
@@ -775,7 +775,13 @@ class TestPeer(ElectrumTestCase):
|
||||
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
|
||||
payment_hash = lnaddr.paymenthash
|
||||
payment_secret = lnaddr.payment_secret
|
||||
pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
|
||||
pay = w1.pay_to_route(
|
||||
route,
|
||||
amount_msat=amount_msat,
|
||||
total_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
min_cltv_expiry=min_cltv_expiry)
|
||||
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
with self.assertRaises(PaymentFailure):
|
||||
run(f())
|
||||
|
||||
Reference in New Issue
Block a user