diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index a04eaee7a..1dcccfc36 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -45,8 +45,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf LOCAL, REMOTE, HTLCOwner, ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED, RemoteMisbehaving, ShortChannelID, - IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, - ChannelType, LNProtocolWarning, validate_features, + IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features, IncompatibleOrInsaneFeatures, FeeBudgetExceeded, GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx, PaymentFeeBudget, serialize_htlc_key, Keypair) @@ -2591,11 +2590,9 @@ class Peer(Logger, EventListener): raise exc_incorrect_or_unknown_pd preimage = self.lnworker.get_preimage(payment_hash) - expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)] - if preimage: - expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices - if payment_secret_from_onion not in expected_payment_secrets: - log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secrets[0].hex()}') + expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash) + if payment_secret_from_onion != expected_payment_secret: + log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}') raise exc_incorrect_or_unknown_pd invoice_msat = info.amount_msat if channel_opening_fee: diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 11cee641c..80e74de30 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1827,19 +1827,6 @@ def validate_features(features: int) -> LnFeatures: return features -def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> bytes: - """Returns secret to be put into invoice. - Derivation is deterministic, based on the preimage. - Crucially the payment_hash must be derived in an independent way from this. - """ - # Note that this could be random data too, but then we would need to store it. - # We derive it identically to clightning, so that we cannot be distinguished: - # https://github.com/ElementsProject/lightning/blob/faac4b28adee5221e83787d64cd5d30b16b62097/lightningd/invoice.c#L115 - modified = bytearray(payment_preimage) - modified[0] ^= 1 - return sha256(bytes(modified)) - - def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes: decoded_bech32 = segwit_addr.bech32_decode(bech32_pubkey) hrp = decoded_bech32.hrp diff --git a/tests/test_bolt11.py b/tests/test_bolt11.py index c4f756b1b..fd1884ff3 100644 --- a/tests/test_bolt11.py +++ b/tests/test_bolt11.py @@ -7,7 +7,7 @@ import unittest from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode from electrum.segwit_addr import bech32_encode, bech32_decode from electrum import segwit_addr -from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage, LnFeatures, IncompatibleLightningFeatures +from electrum.lnutil import UnknownEvenFeatureBits, LnFeatures, IncompatibleLightningFeatures from electrum import constants from . import ElectrumTestCase @@ -164,11 +164,6 @@ class TestBolt11(ElectrumTestCase): self.assertEqual((1 << 9) + (1 << 15) + (1 << 99), lnaddr.get_tag('9')) self.assertEqual(b"\x11" * 32, lnaddr.payment_secret) - def test_derive_payment_secret_from_payment_preimage(self): - preimage = bytes.fromhex("cc3fc000bdeff545acee53ada12ff96060834be263f77d645abbebc3a8d53b92") - self.assertEqual("bfd660b559b3f452c6bb05b8d2906f520c151c107b733863ed0cc53fc77021a8", - derive_payment_secret_from_payment_preimage(preimage).hex()) - def test_validate_and_compare_features(self): lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqsp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygsdq5vdhkven9v5sxyetpdees9q5sqqqqqqqqqqqqqqqpqsqvvh7ut50r00p3pg34ea68k7zfw64f8yx9jcdk35lh5ft8qdr8g4r0xzsdcrmcy9hex8un8d8yraewvhqc9l0sh8l0e0yvmtxde2z0hgpzsje5l") lnaddr.validate_and_compare_features(LnFeatures((1 << 8) + (1 << 14) + (1 << 15)))