lnutil: change ReceivedMPPStatus.htlcs to frozenset, i.e. immutable
As ThomasV says: > ReceivedMPPStatus is a Namedtuple, which is immutable, but it contains > a mutable field. Since ReceivedMPPStatus is not a StoredObject, > no patch will be created when the htlcs list is modified, and we may > end up not saving the change to disk if partial writes are enabled. patch taken from https://github.com/spesmilo/electrum/pull/10395#pullrequestreview-3634244541 closes https://github.com/spesmilo/electrum/pull/10395 Co-authored-by: f321x <f@f321x.com>
This commit is contained in:
@@ -2200,7 +2200,7 @@ class Peer(Logger, EventListener):
|
|||||||
# get payment hash of any htlc in the set (they are all the same)
|
# get payment hash of any htlc in the set (they are all the same)
|
||||||
payment_hash = htlc_set.get_payment_hash()
|
payment_hash = htlc_set.get_payment_hash()
|
||||||
assert payment_hash is not None, htlc_set
|
assert payment_hash is not None, htlc_set
|
||||||
assert payment_hash not in self.lnworker.dont_settle_htlcs
|
assert payment_hash.hex() not in self.lnworker.dont_settle_htlcs
|
||||||
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore
|
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore
|
||||||
for mpp_htlc in list(htlc_set.htlcs):
|
for mpp_htlc in list(htlc_set.htlcs):
|
||||||
htlc_id = mpp_htlc.htlc.htlc_id
|
htlc_id = mpp_htlc.htlc.htlc_id
|
||||||
@@ -2214,10 +2214,12 @@ class Peer(Logger, EventListener):
|
|||||||
if chan.hm.was_htlc_preimage_released(htlc_id=htlc_id, htlc_proposer=REMOTE):
|
if chan.hm.was_htlc_preimage_released(htlc_id=htlc_id, htlc_proposer=REMOTE):
|
||||||
# this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash
|
# this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash
|
||||||
self.logger.debug(f"{mpp_htlc=} was already settled before, dropping it.")
|
self.logger.debug(f"{mpp_htlc=} was already settled before, dropping it.")
|
||||||
htlc_set.htlcs.remove(mpp_htlc)
|
htlc_set = htlc_set._replace(htlcs=htlc_set.htlcs - {mpp_htlc})
|
||||||
|
self.lnworker.received_mpp_htlcs[payment_key] = htlc_set
|
||||||
continue
|
continue
|
||||||
self._fulfill_htlc(chan, htlc_id, preimage)
|
self._fulfill_htlc(chan, htlc_id, preimage)
|
||||||
htlc_set.htlcs.remove(mpp_htlc)
|
htlc_set = htlc_set._replace(htlcs=htlc_set.htlcs - {mpp_htlc})
|
||||||
|
self.lnworker.received_mpp_htlcs[payment_key] = htlc_set
|
||||||
# reset just-in-time opening fee of channel
|
# reset just-in-time opening fee of channel
|
||||||
chan.jit_opening_fee = None
|
chan.jit_opening_fee = None
|
||||||
|
|
||||||
@@ -2255,7 +2257,8 @@ class Peer(Logger, EventListener):
|
|||||||
if chan.hm.was_htlc_failed(htlc_id=htlc_id, htlc_proposer=REMOTE):
|
if chan.hm.was_htlc_failed(htlc_id=htlc_id, htlc_proposer=REMOTE):
|
||||||
# this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash
|
# this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash
|
||||||
self.logger.debug(f"{mpp_htlc=} was already failed before, dropping it.")
|
self.logger.debug(f"{mpp_htlc=} was already failed before, dropping it.")
|
||||||
htlc_set.htlcs.remove(mpp_htlc)
|
htlc_set = htlc_set._replace(htlcs=htlc_set.htlcs - {mpp_htlc})
|
||||||
|
self.lnworker.received_mpp_htlcs[payment_key] = htlc_set
|
||||||
continue
|
continue
|
||||||
onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion)
|
onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion)
|
||||||
processed_onion_packet = self._process_incoming_onion_packet(
|
processed_onion_packet = self._process_incoming_onion_packet(
|
||||||
@@ -2286,7 +2289,8 @@ class Peer(Logger, EventListener):
|
|||||||
htlc_id=htlc_id,
|
htlc_id=htlc_id,
|
||||||
error_bytes=error_bytes,
|
error_bytes=error_bytes,
|
||||||
)
|
)
|
||||||
htlc_set.htlcs.remove(mpp_htlc)
|
htlc_set = htlc_set._replace(htlcs=htlc_set.htlcs - {mpp_htlc})
|
||||||
|
self.lnworker.received_mpp_htlcs[payment_key] = htlc_set
|
||||||
|
|
||||||
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
|
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
|
||||||
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
|
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
|
||||||
@@ -3148,11 +3152,12 @@ class Peer(Logger, EventListener):
|
|||||||
if not parent:
|
if not parent:
|
||||||
parent = ReceivedMPPStatus(
|
parent = ReceivedMPPStatus(
|
||||||
resolution=RecvMPPResolution.WAITING,
|
resolution=RecvMPPResolution.WAITING,
|
||||||
htlcs=set(),
|
htlcs=frozenset(),
|
||||||
)
|
)
|
||||||
self.lnworker.received_mpp_htlcs[mpp_set.parent_set_key] = parent
|
self.lnworker.received_mpp_htlcs[mpp_set.parent_set_key] = parent._replace(
|
||||||
parent.htlcs.update(mpp_set.htlcs)
|
htlcs=parent.htlcs | mpp_set.htlcs
|
||||||
mpp_set.htlcs.clear()
|
)
|
||||||
|
self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace(htlcs=frozenset())
|
||||||
return None, None, None # this set will get deleted as there are no htlcs in it anymore
|
return None, None, None # this set will get deleted as there are no htlcs in it anymore
|
||||||
|
|
||||||
assert not mpp_set.parent_set_key
|
assert not mpp_set.parent_set_key
|
||||||
|
|||||||
@@ -1980,7 +1980,7 @@ class ReceivedMPPHtlc(NamedTuple):
|
|||||||
|
|
||||||
class ReceivedMPPStatus(NamedTuple):
|
class ReceivedMPPStatus(NamedTuple):
|
||||||
resolution: RecvMPPResolution
|
resolution: RecvMPPResolution
|
||||||
htlcs: set[ReceivedMPPHtlc]
|
htlcs: frozenset[ReceivedMPPHtlc]
|
||||||
# parent_set_key is needed as trampoline allows MPP to be nested, the parent_set_key is the
|
# parent_set_key is needed as trampoline allows MPP to be nested, the parent_set_key is the
|
||||||
# payment key of the final mpp set (derived from inner trampoline onion payment secret)
|
# payment key of the final mpp set (derived from inner trampoline onion payment secret)
|
||||||
# to which the separate trampoline sets htlcs get added once they are complete.
|
# to which the separate trampoline sets htlcs get added once they are complete.
|
||||||
@@ -2005,7 +2005,7 @@ class ReceivedMPPStatus(NamedTuple):
|
|||||||
@stored_in('received_mpp_htlcs', tuple)
|
@stored_in('received_mpp_htlcs', tuple)
|
||||||
def from_tuple(resolution, htlc_list, parent_set_key=None) -> 'ReceivedMPPStatus':
|
def from_tuple(resolution, htlc_list, parent_set_key=None) -> 'ReceivedMPPStatus':
|
||||||
assert isinstance(resolution, int)
|
assert isinstance(resolution, int)
|
||||||
htlc_set = set(ReceivedMPPHtlc.from_tuple(*htlc_data) for htlc_data in htlc_list)
|
htlc_set = frozenset(ReceivedMPPHtlc.from_tuple(*htlc_data) for htlc_data in htlc_list)
|
||||||
return ReceivedMPPStatus(
|
return ReceivedMPPStatus(
|
||||||
resolution=RecvMPPResolution(resolution),
|
resolution=RecvMPPResolution(resolution),
|
||||||
htlcs=htlc_set,
|
htlcs=htlc_set,
|
||||||
|
|||||||
@@ -2807,7 +2807,7 @@ class LNWallet(Logger):
|
|||||||
self.logger.debug(f"creating new mpp set for {payment_key=}")
|
self.logger.debug(f"creating new mpp set for {payment_key=}")
|
||||||
mpp_status = ReceivedMPPStatus(
|
mpp_status = ReceivedMPPStatus(
|
||||||
resolution=RecvMPPResolution.WAITING,
|
resolution=RecvMPPResolution.WAITING,
|
||||||
htlcs=set(),
|
htlcs=frozenset(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if mpp_status.resolution > RecvMPPResolution.WAITING:
|
if mpp_status.resolution > RecvMPPResolution.WAITING:
|
||||||
@@ -2827,8 +2827,9 @@ class LNWallet(Logger):
|
|||||||
)
|
)
|
||||||
assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?"
|
assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?"
|
||||||
assert isinstance(unprocessed_onion_packet, str)
|
assert isinstance(unprocessed_onion_packet, str)
|
||||||
mpp_status.htlcs.add(new_htlc) # side-effecting htlc_set
|
new_htlcs = set(mpp_status.htlcs)
|
||||||
self.received_mpp_htlcs[payment_key] = mpp_status
|
new_htlcs.add(new_htlc)
|
||||||
|
self.received_mpp_htlcs[payment_key] = mpp_status._replace(htlcs=frozenset(new_htlcs))
|
||||||
|
|
||||||
def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus:
|
def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus:
|
||||||
mpp_status = self.received_mpp_htlcs[payment_key]
|
mpp_status = self.received_mpp_htlcs[payment_key]
|
||||||
@@ -2910,10 +2911,14 @@ class LNWallet(Logger):
|
|||||||
assert chan._state == ChannelState.REDEEMED
|
assert chan._state == ChannelState.REDEEMED
|
||||||
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
|
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
|
||||||
htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.channel_id == chan.channel_id]
|
htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.channel_id == chan.channel_id]
|
||||||
|
new_htlcs = set(mpp_status.htlcs)
|
||||||
for stale_mpp_htlc in htlcs_to_remove:
|
for stale_mpp_htlc in htlcs_to_remove:
|
||||||
assert mpp_status.resolution != RecvMPPResolution.WAITING
|
assert mpp_status.resolution != RecvMPPResolution.WAITING
|
||||||
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
|
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
|
||||||
mpp_status.htlcs.remove(stale_mpp_htlc) # side-effecting htlc_set
|
new_htlcs.remove(stale_mpp_htlc)
|
||||||
|
if htlcs_to_remove:
|
||||||
|
mpp_status = mpp_status._replace(htlcs=frozenset(new_htlcs))
|
||||||
|
self.received_mpp_htlcs[payment_key_hex] = mpp_status # save changes to db
|
||||||
if len(mpp_status.htlcs) == 0:
|
if len(mpp_status.htlcs) == 0:
|
||||||
self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
|
self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
|
||||||
del self.received_mpp_htlcs[payment_key_hex]
|
del self.received_mpp_htlcs[payment_key_hex]
|
||||||
|
|||||||
@@ -331,7 +331,7 @@ class MyEncoder(json.JSONEncoder):
|
|||||||
if isinstance(obj, datetime):
|
if isinstance(obj, datetime):
|
||||||
# note: if there is a timezone specified, this will include the offset
|
# note: if there is a timezone specified, this will include the offset
|
||||||
return obj.isoformat(' ', timespec="minutes")
|
return obj.isoformat(' ', timespec="minutes")
|
||||||
if isinstance(obj, set):
|
if isinstance(obj, (set, frozenset)):
|
||||||
return list(obj)
|
return list(obj)
|
||||||
if isinstance(obj, bytes): # for nametuples in lnchannel
|
if isinstance(obj, bytes): # for nametuples in lnchannel
|
||||||
return obj.hex()
|
return obj.hex()
|
||||||
|
|||||||
Reference in New Issue
Block a user