1
0

lnaddr: make min_cltv logic less error-prone

round-tripping the value behaved unexpectedly before...
This commit is contained in:
SomberNight
2023-06-16 16:59:13 +00:00
parent f7a8e55a6a
commit a3b0e97c88
2 changed files with 6 additions and 3 deletions

View File

@@ -272,7 +272,6 @@ class LnAddr(object):
self.pubkey = None self.pubkey = None
self.net = constants.net if net is None else net # type: Type[AbstractNet] self.net = constants.net if net is None else net # type: Type[AbstractNet]
self._amount = amount # type: Optional[Decimal] # in bitcoins self._amount = amount # type: Optional[Decimal] # in bitcoins
self._min_final_cltv_expiry = 18
@property @property
def amount(self) -> Optional[Decimal]: def amount(self) -> Optional[Decimal]:
@@ -326,7 +325,10 @@ class LnAddr(object):
) )
def get_min_final_cltv_expiry(self) -> int: def get_min_final_cltv_expiry(self) -> int:
return self._min_final_cltv_expiry cltv = self.get_tag('c')
if cltv is None:
return 18
return int(cltv)
def get_tag(self, tag): def get_tag(self, tag):
for k, v in self.tags: for k, v in self.tags:
@@ -482,7 +484,7 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
addr.pubkey = pubkeybytes addr.pubkey = pubkeybytes
elif tag == 'c': elif tag == 'c':
addr._min_final_cltv_expiry = tagdata.uint addr.tags.append(('c', tagdata.uint))
elif tag == '9': elif tag == '9':
features = tagdata.uint features = tagdata.uint

View File

@@ -146,6 +146,7 @@ class TestBolt11(ElectrumTestCase):
def test_min_final_cltv_expiry_roundtrip(self): def test_min_final_cltv_expiry_roundtrip(self):
for cltv in (1, 15, 16, 31, 32, 33, 150, 511, 512, 513, 1023, 1024, 1025): for cltv in (1, 15, 16, 31, 32, 33, 150, 511, 512, 513, 1023, 1024, 1025):
lnaddr = LnAddr(paymenthash=RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', cltv)]) lnaddr = LnAddr(paymenthash=RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', cltv)])
self.assertEqual(cltv, lnaddr.get_min_final_cltv_expiry())
invoice = lnencode(lnaddr, PRIVKEY) invoice = lnencode(lnaddr, PRIVKEY)
self.assertEqual(cltv, lndecode(invoice).get_min_final_cltv_expiry()) self.assertEqual(cltv, lndecode(invoice).get_min_final_cltv_expiry())