diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index fe0516b8d..55a591c46 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -4,7 +4,6 @@ # Distributed under the MIT software license, see the accompanying # file LICENCE or http://www.opensource.org/licenses/mit-license.php -import zlib from collections import OrderedDict, defaultdict import asyncio import os @@ -786,13 +785,13 @@ class Peer(Logger, EventListener): first_blocknum=first_block, number_of_blocks=num_blocks) - def decode_short_ids(self, encoded): - if encoded[0] == 0: - decoded = encoded[1:] - elif encoded[0] == 1: - decoded = zlib.decompress(encoded[1:]) - else: + @staticmethod + def decode_short_ids(encoded): + if len(encoded) < 1 or (len(encoded) - 1) % 8 != 0: + raise Exception(f'decode_short_ids: invalid size: {len(encoded)=}') + elif encoded[0] != 0: raise Exception(f'decode_short_ids: unexpected first byte: {encoded[0]}') + decoded = encoded[1:] ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)] return ids diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 941a4a12f..1baf3b156 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -634,6 +634,45 @@ class TestPeer(ElectrumTestCase): w2.register_hold_invoice(payment_hash, cb) +class TestPeerUtils(TestPeer): + + def test_decode_short_ids(self): + """ + Test Peer.decode_short_ids() against some data from + https://github.com/lightning/bolts/commit/313c0f290eb87e96dc8195cad0c891418a826c2c + """ + # Test uncompressed encoding with three scids + encoded_uncompressed = bytes.fromhex("00" + "0000000000003043" + "00000000000778d6" + "000000000046e1c1") + result = Peer.decode_short_ids(encoded_uncompressed) + self.assertEqual(len(result), 3) + self.assertEqual(result[0], bytes.fromhex("0000000000003043")) # 0x0x12355 + self.assertEqual(result[1], bytes.fromhex("00000000000778d6")) # 0x7x30934 + self.assertEqual(result[2], bytes.fromhex("000000000046e1c1")) # 0x70x57793 + + # Test empty list + encoded_empty = bytes.fromhex("00") + result = Peer.decode_short_ids(encoded_empty) + self.assertEqual(result, []) + + # Test single scid + encoded_single = bytes.fromhex("00" + "000000000000008e") # 0x0x142 + result = Peer.decode_short_ids(encoded_single) + self.assertEqual(len(result), 1) + self.assertEqual(result[0], bytes.fromhex("000000000000008e")) + + # test invalid size raises exception + encoded_invalid = bytes.fromhex("00" + "00" * 9) + with self.assertRaises(Exception) as ctx: + Peer.decode_short_ids(encoded_invalid) + self.assertIn("invalid size", str(ctx.exception)) + + # Test unsupported encoding raises exception (considering it even passes the length check) + encoded_unsupported = bytes.fromhex("01" + "00" * 8) # 01 was zlib before removed + with self.assertRaises(Exception) as ctx: + Peer.decode_short_ids(encoded_unsupported) + self.assertIn("unexpected first byte", str(ctx.exception)) + + class TestPeerDirect(TestPeer): def prepare_peers(