1
0

lnworker/lnpeer: don't use lnworker.channels.get(chan_id)

- lnworker.channels takes a copy of the whole dict, to make it thread-safe
- in LNWallet class, can just use self._channels.get(chan_id)
- otherwise there is lnworker.get_channel_by_id
- same for lnpeer.channels.get and lnpeer.get_channel_by_id
This commit is contained in:
SomberNight
2025-12-10 16:14:31 +00:00
parent 6ceb4ad71f
commit c465f7c3e0
3 changed files with 18 additions and 18 deletions

View File

@@ -207,8 +207,8 @@ class ChannelsList(MyTreeView):
idx2 = selected[1] idx2 = selected[1]
channel_id1 = idx1.sibling(idx1.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID) channel_id1 = idx1.sibling(idx1.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID)
channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID) channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID)
chan1 = self.lnworker.channels.get(channel_id1) chan1 = self.lnworker.get_channel_by_id(channel_id1)
chan2 = self.lnworker.channels.get(channel_id2) chan2 = self.lnworker.get_channel_by_id(channel_id2)
if chan1 and chan2 and (not self.lnworker.uses_trampoline() or chan1.node_id != chan2.node_id): if chan1 and chan2 and (not self.lnworker.uses_trampoline() or chan1.node_id != chan2.node_id):
return chan1, chan2 return chan1, chan2
return None, None return None, None

View File

@@ -289,7 +289,7 @@ class Peer(Logger, EventListener):
self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: " self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: "
f"{error_text_bytes_to_safe_str(err_bytes, max_len=None)}. chan_id={chan_id.hex()}. " f"{error_text_bytes_to_safe_str(err_bytes, max_len=None)}. chan_id={chan_id.hex()}. "
f"{is_known_chan_id=}") f"{is_known_chan_id=}")
if chan := self.channels.get(chan_id): if chan := self.get_channel_by_id(chan_id):
self.schedule_force_closing(chan_id) self.schedule_force_closing(chan_id)
self.ordered_message_queues[chan_id].put_nowait((None, {'error': err_bytes})) self.ordered_message_queues[chan_id].put_nowait((None, {'error': err_bytes}))
chan.save_remote_peer_sent_error(err_bytes) chan.save_remote_peer_sent_error(err_bytes)
@@ -1499,7 +1499,7 @@ class Peer(Logger, EventListener):
channels_with_peer.extend(self.temp_id_to_id.values()) channels_with_peer.extend(self.temp_id_to_id.values())
if channel_id not in channels_with_peer: if channel_id not in channels_with_peer:
raise ValueError(f"channel {channel_id.hex()} does not belong to this peer") raise ValueError(f"channel {channel_id.hex()} does not belong to this peer")
chan = self.channels.get(channel_id) chan = self.get_channel_by_id(channel_id)
if not chan: if not chan:
self.logger.warning(f"tried to force-close channel {channel_id.hex()} but it is not in self.channels yet") self.logger.warning(f"tried to force-close channel {channel_id.hex()} but it is not in self.channels yet")
if ChanCloseOption.LOCAL_FCLOSE in chan.get_close_options(): if ChanCloseOption.LOCAL_FCLOSE in chan.get_close_options():
@@ -2275,8 +2275,8 @@ class Peer(Logger, EventListener):
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore
for mpp_htlc in list(htlc_set.htlcs): for mpp_htlc in list(htlc_set.htlcs):
htlc_id = mpp_htlc.htlc.htlc_id htlc_id = mpp_htlc.htlc.htlc_id
chan = self.lnworker.channels[mpp_htlc.channel_id] chan = self.get_channel_by_id(mpp_htlc.channel_id)
if chan.channel_id not in self.channels: if chan is None:
# this htlc belongs to another peer and has to be settled in their htlc_switch # this htlc belongs to another peer and has to be settled in their htlc_switch
continue continue
if not chan.can_update_ctx(proposer=LOCAL): if not chan.can_update_ctx(proposer=LOCAL):
@@ -2317,9 +2317,9 @@ class Peer(Logger, EventListener):
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None)
self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed
for mpp_htlc in list(htlc_set.htlcs): for mpp_htlc in list(htlc_set.htlcs):
chan = self.lnworker.channels[mpp_htlc.channel_id] chan = self.get_channel_by_id(mpp_htlc.channel_id)
htlc_id = mpp_htlc.htlc.htlc_id htlc_id = mpp_htlc.htlc.htlc_id
if chan.channel_id not in self.channels: if chan is None:
# this htlc belongs to another peer and has to be settled in their htlc_switch # this htlc belongs to another peer and has to be settled in their htlc_switch
continue continue
if not chan.can_update_ctx(proposer=LOCAL): if not chan.can_update_ctx(proposer=LOCAL):
@@ -2489,7 +2489,8 @@ class Peer(Logger, EventListener):
@log_exceptions @log_exceptions
async def close_channel(self, chan_id: bytes): async def close_channel(self, chan_id: bytes):
chan = self.channels[chan_id] chan = self.get_channel_by_id(chan_id)
assert chan
self.shutdown_received[chan_id] = self.asyncio_loop.create_future() self.shutdown_received[chan_id] = self.asyncio_loop.create_future()
await self.send_shutdown(chan) await self.send_shutdown(chan)
payload = await self.shutdown_received[chan_id] payload = await self.shutdown_received[chan_id]
@@ -2949,7 +2950,7 @@ class Peer(Logger, EventListener):
onion_payload: dict onion_payload: dict
) -> Callable[[str], None]: ) -> Callable[[str], None]:
def _log_fail_reason(reason: str) -> None: def _log_fail_reason(reason: str) -> None:
scid = self.lnworker.channels[channel_id].short_channel_id scid = self.lnworker.get_channel_by_id(channel_id).short_channel_id
self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}") self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}")
return _log_fail_reason return _log_fail_reason
@@ -3082,7 +3083,7 @@ class Peer(Logger, EventListener):
if mpp_set.resolution == RecvMPPResolution.WAITING: if mpp_set.resolution == RecvMPPResolution.WAITING:
# calculate the sum of just in time channel opening fees, note jit only supports # calculate the sum of just in time channel opening fees, note jit only supports
# single part payments for now, this is enforced by checking against the invoice features # single part payments for now, this is enforced by checking against the invoice features
htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)] htlc_channels = [self.lnworker.get_channel_by_id(channel_id) for channel_id in set(h.channel_id for h in mpp_set.htlcs)]
jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels) jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels)
# check if set is first stage multi-trampoline payment to us # check if set is first stage multi-trampoline payment to us

View File

@@ -2149,7 +2149,7 @@ class LNWallet(LNWorker):
per_trampoline_channel_amounts = defaultdict(list) per_trampoline_channel_amounts = defaultdict(list)
# categorize by trampoline nodes for trampoline mpp construction # categorize by trampoline nodes for trampoline mpp construction
for (chan_id, _), part_amounts_msat in sc.config.items(): for (chan_id, _), part_amounts_msat in sc.config.items():
chan = self.channels[chan_id] chan = self._channels[chan_id]
for part_amount_msat in part_amounts_msat: for part_amount_msat in part_amounts_msat:
per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat)) per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
# for each trampoline forwarder, construct mpp trampoline # for each trampoline forwarder, construct mpp trampoline
@@ -2179,7 +2179,7 @@ class LNWallet(LNWorker):
self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}') self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
self.logger.info(f'per trampoline fees: {per_trampoline_fees}') self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
for chan_id, part_amount_msat in trampoline_parts: for chan_id, part_amount_msat in trampoline_parts:
chan = self.channels[chan_id] chan = self._channels[chan_id]
margin = chan.available_to_spend(LOCAL) - part_amount_msat margin = chan.available_to_spend(LOCAL) - part_amount_msat
delta_fee = min(per_trampoline_fees, margin) delta_fee = min(per_trampoline_fees, margin)
# TODO: distribute trampoline fee over several channels? # TODO: distribute trampoline fee over several channels?
@@ -2216,7 +2216,7 @@ class LNWallet(LNWorker):
# a failure to find a path for a single part, we try the next configuration # a failure to find a path for a single part, we try the next configuration
for (chan_id, _), part_amounts_msat in sc.config.items(): for (chan_id, _), part_amounts_msat in sc.config.items():
for part_amount_msat in part_amounts_msat: for part_amount_msat in part_amounts_msat:
channel = self.channels[chan_id] channel = self._channels[chan_id]
route = await run_in_thread( route = await run_in_thread(
partial( partial(
self.create_route_for_single_htlc, self.create_route_for_single_htlc,
@@ -3277,7 +3277,7 @@ class LNWallet(LNWorker):
return asyncio.create_task(self.network.try_broadcasting(tx, 'force-close')) return asyncio.create_task(self.network.try_broadcasting(tx, 'force-close'))
def remove_channel(self, chan_id): def remove_channel(self, chan_id):
chan = self.channels[chan_id] chan = self._channels[chan_id]
assert chan.can_be_deleted() assert chan.can_be_deleted()
with self.lock: with self.lock:
self._channels.pop(chan_id) self._channels.pop(chan_id)
@@ -3394,8 +3394,7 @@ class LNWallet(LNWorker):
return 'channel_backup:' + encrypted return 'channel_backup:' + encrypted
async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None: async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
if channel_id in self.channels: if chan := self.get_channel_by_id(channel_id):
chan = self.channels[channel_id]
peer = self._peers.get(chan.node_id) peer = self._peers.get(chan.node_id)
chan.should_request_force_close = True chan.should_request_force_close = True
if peer: if peer:
@@ -3558,7 +3557,7 @@ class LNWallet(LNWorker):
assert not any_outer_onion.are_we_final assert not any_outer_onion.are_we_final
assert len(processed_htlc_set) == 1, processed_htlc_set assert len(processed_htlc_set) == 1, processed_htlc_set
forward_htlc = any_mpp_htlc.htlc forward_htlc = any_mpp_htlc.htlc
incoming_chan = self.channels[any_mpp_htlc.channel_id] incoming_chan = self._channels[any_mpp_htlc.channel_id]
next_htlc = await self._maybe_forward_htlc( next_htlc = await self._maybe_forward_htlc(
incoming_chan=incoming_chan, incoming_chan=incoming_chan,
htlc=forward_htlc, htlc=forward_htlc,