1
0

dependencies: remove bitstring

- `bitstring` started depending on `bitarray` in version 4.1 [0]
  - that would mean one additional dependency for us (from yet another maintainer), which is not even pure python
- we only use bitstring for bolt11-parsing
- hence this PR rewrites the bolt11-parsing and removes `bitstring` as dependency
- note: I benchmarked lndecode using [1], and the new code performs better,
  taking around 80% time needed for old code (when using bitstring 3.1.9, pure python).
  Though the variance is quite large in both cases.

[0]: 95ee533ee4/release_notes.txt (L108)
[1]: d7597d96d0
This commit is contained in:
SomberNight
2024-04-24 14:10:01 +00:00
parent 20d7543b53
commit cf2ed509b4
6 changed files with 184 additions and 170 deletions

View File

@@ -1,18 +1,17 @@
#! /usr/bin/env python3
# This was forked from https://github.com/rustyrussell/lightning-payencode/tree/acc16ec13a3fa1dc16c07af6ec67c261bd8aff23
import io
import re
import time
from hashlib import sha256
from binascii import hexlify
from decimal import Decimal
from typing import Optional, TYPE_CHECKING, Type, Dict, Any
from typing import Optional, TYPE_CHECKING, Type, Dict, Any, Union, Sequence, List, Tuple
import random
import bitstring
from .bitcoin import hash160_to_b58_address, b58_address_to_hash160, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC
from .segwit_addr import bech32_encode, bech32_decode, CHARSET
from .segwit_addr import bech32_encode, bech32_decode, CHARSET, CHARSET_INVERSE, convertbits
from . import segwit_addr
from . import constants
from .constants import AbstractNet
@@ -75,25 +74,9 @@ def unshorten_amount(amount) -> Decimal:
else:
return Decimal(amount)
_INT_TO_BINSTR = {a: '0' * (5-len(bin(a)[2:])) + bin(a)[2:] for a in range(32)}
# Bech32 spits out array of 5-bit values. Shim here.
def u5_to_bitarray(arr):
b = ''.join(_INT_TO_BINSTR[a] for a in arr)
return bitstring.BitArray(bin=b)
def bitarray_to_u5(barr):
assert barr.len % 5 == 0
ret = []
s = bitstring.ConstBitStream(barr)
while s.pos != s.len:
ret.append(s.read(5).uint)
return ret
def encode_fallback(fallback: str, net: Type[AbstractNet]):
""" Encode all supported fallback addresses.
"""
def encode_fallback_addr(fallback: str, net: Type[AbstractNet]) -> Sequence[int]:
"""Encode all supported fallback addresses."""
wver, wprog_ints = segwit_addr.decode_segwit_address(net.SEGWIT_HRP, fallback)
if wver is not None:
wprog = bytes(wprog_ints)
@@ -106,20 +89,20 @@ def encode_fallback(fallback: str, net: Type[AbstractNet]):
else:
raise LnEncodeException(f"Unknown address type {addrtype} for {net}")
wprog = addr
return tagged('f', bitstring.pack("uint:5", wver) + wprog)
data5 = convertbits(wprog, 8, 5)
assert data5 is not None
return tagged5('f', [wver] + list(data5))
def parse_fallback(fallback, net: Type[AbstractNet]):
wver = fallback[0:5].uint
def parse_fallback_addr(data5: Sequence[int], net: Type[AbstractNet]) -> Optional[str]:
wver = data5[0]
data8 = bytes(convertbits(data5[1:], 5, 8, False))
if wver == 17:
addr = hash160_to_b58_address(fallback[5:].tobytes(), net.ADDRTYPE_P2PKH)
addr = hash160_to_b58_address(data8, net.ADDRTYPE_P2PKH)
elif wver == 18:
addr = hash160_to_b58_address(fallback[5:].tobytes(), net.ADDRTYPE_P2SH)
addr = hash160_to_b58_address(data8, net.ADDRTYPE_P2SH)
elif wver <= 16:
witprog = fallback[5:] # cut witver
witprog = witprog[:len(witprog) // 8 * 8] # can only be full bytes
witprog = witprog.tobytes()
addr = segwit_addr.encode_segwit_address(net.SEGWIT_HRP, wver, witprog)
addr = segwit_addr.encode_segwit_address(net.SEGWIT_HRP, wver, data8)
else:
return None
return addr
@@ -128,47 +111,52 @@ def parse_fallback(fallback, net: Type[AbstractNet]):
BOLT11_HRP_INV_DICT = {net.BOLT11_HRP: net for net in constants.NETS_LIST}
# Tagged field containing BitArray
def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits.
while l.len % 5 != 0:
l.append('0b0')
return bitstring.pack("uint:5, uint:5, uint:5",
CHARSET.find(char),
(l.len / 5) / 32, (l.len / 5) % 32) + l
def tagged5(char: str, data5: Sequence[int]) -> Sequence[int]:
assert len(data5) < (1 << 10)
return [CHARSET_INVERSE[char], len(data5) >> 5, len(data5) & 31] + data5
# Tagged field containing bytes
def tagged_bytes(char, l):
return tagged(char, bitstring.BitArray(l))
def trim_to_min_length(bits):
"""Ensures 'bits' have min number of leading zeroes.
Assumes 'bits' is big-endian, and that it needs to be encoded in 5 bit blocks.
def tagged8(char: str, data8: Sequence[int]) -> Sequence[int]:
return tagged5(char, convertbits(data8, 8, 5))
def int_to_data5(val: int, *, bit_len: int = None) -> Sequence[int]:
"""Represent big-endian number with as many 0-31 values as it takes.
If `bit_len` is set, use exactly bit_len//5 values (left-padded with zeroes).
"""
bits = bits[:] # copy
# make sure we can be split into 5 bit blocks
while bits.len % 5 != 0:
bits.prepend('0b0')
# Get minimal length by trimming leading 5 bits at a time.
while bits.startswith('0b00000'):
if len(bits) == 5:
break # v == 0
bits = bits[5:]
return bits
if bit_len is not None:
assert bit_len % 5 == 0, bit_len
if val.bit_length() > bit_len:
raise ValueError(f"{val=} too big for {bit_len=!r}")
ret = []
while val != 0:
ret.append(val % 32)
val //= 32
if bit_len is not None:
ret.extend([0] * (len(ret) - bit_len // 5))
ret.reverse()
return ret
# Discard trailing bits, convert to bytes.
def trim_to_bytes(barr):
# Adds a byte if necessary.
b = barr.tobytes()
if barr.len % 8 != 0:
return b[:-1]
return b
# Try to pull out tagged data: returns tag, tagged data and remainder.
def pull_tagged(stream):
tag = stream.read(5).uint
length = stream.read(5).uint * 32 + stream.read(5).uint
return (CHARSET[tag], stream.read(length * 5), stream)
def int_from_data5(data5: Sequence[int]) -> int:
total = 0
for v in data5:
total = 32 * total + v
return total
def pull_tagged(data5: bytearray) -> Tuple[str, Sequence[int]]:
"""Try to pull out tagged data: returns tag, tagged data. Mutates data in-place."""
if len(data5) < 3:
raise ValueError("Truncated field")
length = data5[1] * 32 + data5[2]
if length > len(data5) - 3:
raise ValueError(
"Truncated {} field: expected {} values".format(CHARSET[data5[0]], length))
ret = (CHARSET[data5[0]], data5[3:3+length])
del data5[:3 + length] # much faster than: data5=data5[offset:]
return ret
def lnencode(addr: 'LnAddr', privkey) -> str:
if addr.amount:
@@ -179,17 +167,17 @@ def lnencode(addr: 'LnAddr', privkey) -> str:
hrp = 'ln' + amount
# Start with the timestamp
data = bitstring.pack('uint:35', addr.date)
data5 = int_to_data5(addr.date, bit_len=35)
tags_set = set()
# Payment hash
assert addr.paymenthash is not None
data += tagged_bytes('p', addr.paymenthash)
data5 += tagged8('p', addr.paymenthash)
tags_set.add('p')
if addr.payment_secret is not None:
data += tagged_bytes('s', addr.payment_secret)
data5 += tagged8('s', addr.payment_secret)
tags_set.add('s')
for k, v in addr.tags:
@@ -202,39 +190,44 @@ def lnencode(addr: 'LnAddr', privkey) -> str:
raise LnEncodeException("Duplicate '{}' tag".format(k))
if k == 'r':
route = bitstring.BitArray()
route = bytearray()
for step in v:
pubkey, channel, feebase, feerate, cltv = step
route.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
data += tagged('r', route)
pubkey, scid, feebase, feerate, cltv = step
route += pubkey
route += scid
route += int.to_bytes(feebase, length=4, byteorder="big", signed=False)
route += int.to_bytes(feerate, length=4, byteorder="big", signed=False)
route += int.to_bytes(cltv, length=2, byteorder="big", signed=False)
data5 += tagged8('r', route)
elif k == 't':
pubkey, feebase, feerate, cltv = v
route = bitstring.BitArray(pubkey) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv)
data += tagged('t', route)
route = bytearray()
route += pubkey
route += int.to_bytes(feebase, length=4, byteorder="big", signed=False)
route += int.to_bytes(feerate, length=4, byteorder="big", signed=False)
route += int.to_bytes(cltv, length=2, byteorder="big", signed=False)
data5 += tagged8('t', route)
elif k == 'f':
if v is not None:
data += encode_fallback(v, addr.net)
data5 += encode_fallback_addr(v, addr.net)
elif k == 'd':
# truncate to max length: 1024*5 bits = 639 bytes
data += tagged_bytes('d', v.encode()[0:639])
data5 += tagged8('d', v.encode()[0:639])
elif k == 'x':
expirybits = bitstring.pack('intbe:64', v)
expirybits = trim_to_min_length(expirybits)
data += tagged('x', expirybits)
expirybits = int_to_data5(v)
data5 += tagged5('x', expirybits)
elif k == 'h':
data += tagged_bytes('h', sha256(v.encode('utf-8')).digest())
data5 += tagged8('h', sha256(v.encode('utf-8')).digest())
elif k == 'n':
data += tagged_bytes('n', v)
data5 += tagged8('n', v)
elif k == 'c':
finalcltvbits = bitstring.pack('intbe:64', v)
finalcltvbits = trim_to_min_length(finalcltvbits)
data += tagged('c', finalcltvbits)
finalcltvbits = int_to_data5(v)
data5 += tagged5('c', finalcltvbits)
elif k == '9':
if v == 0:
continue
feature_bits = bitstring.BitArray(uint=v, length=v.bit_length())
feature_bits = trim_to_min_length(feature_bits)
data += tagged('9', feature_bits)
feature_bits = int_to_data5(v)
data5 += tagged5('9', feature_bits)
else:
# FIXME: Support unknown tags?
raise LnEncodeException("Unknown tag {}".format(k))
@@ -251,15 +244,16 @@ def lnencode(addr: 'LnAddr', privkey) -> str:
raise ValueError("Must include either 'd' or 'h'")
# We actually sign the hrp, then data (padded to 8 bits with zeroes).
msg = hrp.encode("ascii") + data.tobytes()
msg = hrp.encode("ascii") + bytes(convertbits(data5, 5, 8))
msg32 = sha256(msg).digest()
privkey = ecc.ECPrivkey(privkey)
sig = privkey.ecdsa_sign_recoverable(msg32, is_compressed=False)
recovery_flag = bytes([sig[0] - 27])
sig = bytes(sig[1:]) + recovery_flag
data += sig
sig = bytes(convertbits(sig, 8, 5, False))
data5 += sig
return bech32_encode(segwit_addr.Encoding.BECH32, hrp, bitarray_to_u5(data))
return bech32_encode(segwit_addr.Encoding.BECH32, hrp, data5)
class LnAddr(object):
@@ -393,6 +387,7 @@ class SerializableKey:
def serialize(self):
return self.pubkey.get_public_key_bytes(True)
def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
"""Parses a string into an LnAddr object.
Can raise LnDecodeException or IncompatibleOrInsaneFeatures.
@@ -401,7 +396,7 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
net = constants.net
decoded_bech32 = bech32_decode(invoice, ignore_long_length=True)
hrp = decoded_bech32.hrp
data = decoded_bech32.data
data5 = decoded_bech32.data # "5" as in list of 5-bit integers
if decoded_bech32.encoding is None:
raise LnDecodeException("Bad bech32 checksum")
if decoded_bech32.encoding != segwit_addr.Encoding.BECH32:
@@ -416,13 +411,12 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
if not hrp[2:].startswith(net.BOLT11_HRP):
raise LnDecodeException(f"Wrong Lightning invoice HRP {hrp[2:]}, should be {net.BOLT11_HRP}")
data = u5_to_bitarray(data)
# Final signature 65 bytes, split it off.
if len(data) < 65*8:
if len(data5) < 65*8//5:
raise LnDecodeException("Too short to contain signature")
sigdecoded = data[-65*8:].tobytes()
data = bitstring.ConstBitStream(data[:-65*8])
sigdecoded = bytes(convertbits(data5[-65*8//5:], 5, 8, False))
data5 = data5[:-65*8//5]
data5_remaining = bytearray(data5) # note: bytearray is faster than list of ints
addr = LnAddr()
addr.pubkey = None
@@ -439,17 +433,18 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
if amountstr != '':
addr.amount = unshorten_amount(amountstr)
addr.date = data.read(35).uint
addr.date = int_from_data5(data5_remaining[:7])
data5_remaining = data5_remaining[7:]
while data.pos != data.len:
tag, tagdata, data = pull_tagged(data)
while data5_remaining:
tag, tagdata = pull_tagged(data5_remaining) # mutates arg
# BOLT #11:
#
# A reader MUST skip over unknown fields, an `f` field with unknown
# `version`, or a `p`, `h`, or `n` field which does not have
# `data_length` 52, 52, or 53 respectively.
data_length = len(tagdata) / 5
data_length = len(tagdata)
if tag == 'r':
# BOLT #11:
@@ -462,24 +457,43 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
# * `feebase` (32 bits, big-endian)
# * `feerate` (32 bits, big-endian)
# * `cltv_expiry_delta` (16 bits, big-endian)
route=[]
s = bitstring.ConstBitStream(tagdata)
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
route.append((s.read(264).tobytes(),
s.read(64).tobytes(),
s.read(32).uintbe,
s.read(32).uintbe,
s.read(16).uintbe))
addr.tags.append(('r',route))
tagdata = convertbits(tagdata, 5, 8, False)
if not tagdata:
continue
route = []
with io.BytesIO(bytes(tagdata)) as s:
while True:
pubkey = s.read(33)
scid = s.read(8)
feebase = s.read(4)
feerate = s.read(4)
cltv = s.read(2)
if len(cltv) != 2:
break # EOF
feebase = int.from_bytes(feebase, byteorder="big")
feerate = int.from_bytes(feerate, byteorder="big")
cltv = int.from_bytes(cltv, byteorder="big")
route.append((pubkey, scid, feebase, feerate, cltv))
if route:
addr.tags.append(('r',route))
elif tag == 't':
s = bitstring.ConstBitStream(tagdata)
e = (s.read(264).tobytes(),
s.read(32).uintbe,
s.read(32).uintbe,
s.read(16).uintbe)
addr.tags.append(('t', e))
tagdata = convertbits(tagdata, 5, 8, False)
if not tagdata:
continue
route = []
with io.BytesIO(bytes(tagdata)) as s:
pubkey = s.read(33)
feebase = s.read(4)
feerate = s.read(4)
cltv = s.read(2)
if len(cltv) == 2: # no EOF
feebase = int.from_bytes(feebase, byteorder="big")
feerate = int.from_bytes(feerate, byteorder="big")
cltv = int.from_bytes(cltv, byteorder="big")
route.append((pubkey, feebase, feerate, cltv))
addr.tags.append(('t', route))
elif tag == 'f':
fallback = parse_fallback(tagdata, addr.net)
fallback = parse_fallback_addr(tagdata, addr.net)
if fallback:
addr.tags.append(('f', fallback))
else:
@@ -488,41 +502,41 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
continue
elif tag == 'd':
addr.tags.append(('d', trim_to_bytes(tagdata).decode('utf-8')))
addr.tags.append(('d', bytes(convertbits(tagdata, 5, 8, False)).decode('utf-8')))
elif tag == 'h':
if data_length != 52:
addr.unknown_tags.append((tag, tagdata))
continue
addr.tags.append(('h', trim_to_bytes(tagdata)))
addr.tags.append(('h', bytes(convertbits(tagdata, 5, 8, False))))
elif tag == 'x':
addr.tags.append(('x', tagdata.uint))
addr.tags.append(('x', int_from_data5(tagdata)))
elif tag == 'p':
if data_length != 52:
addr.unknown_tags.append((tag, tagdata))
continue
addr.paymenthash = trim_to_bytes(tagdata)
addr.paymenthash = bytes(convertbits(tagdata, 5, 8, False))
elif tag == 's':
if data_length != 52:
addr.unknown_tags.append((tag, tagdata))
continue
addr.payment_secret = trim_to_bytes(tagdata)
addr.payment_secret = bytes(convertbits(tagdata, 5, 8, False))
elif tag == 'n':
if data_length != 53:
addr.unknown_tags.append((tag, tagdata))
continue
pubkeybytes = trim_to_bytes(tagdata)
pubkeybytes = bytes(convertbits(tagdata, 5, 8, False))
addr.pubkey = pubkeybytes
elif tag == 'c':
addr.tags.append(('c', tagdata.uint))
addr.tags.append(('c', int_from_data5(tagdata)))
elif tag == '9':
features = tagdata.uint
features = int_from_data5(tagdata)
addr.tags.append(('9', features))
# note: The features are not validated here in the parser,
# instead, validation is done just before we try paying the invoice (in lnworker._check_invoice).
@@ -536,16 +550,17 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr:
print('hex of signature data (32 byte r, 32 byte s): {}'
.format(hexlify(sigdecoded[0:64])))
print('recovery flag: {}'.format(sigdecoded[64]))
data8 = bytes(convertbits(data5, 5, 8, True))
print('hex of data for signing: {}'
.format(hexlify(hrp.encode("ascii") + data.tobytes())))
print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data.tobytes()).hexdigest()))
.format(hexlify(hrp.encode("ascii") + data8)))
print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data8).hexdigest()))
# BOLT #11:
#
# A reader MUST check that the `signature` is valid (see the `n` tagged
# field specified below).
addr.signature = sigdecoded[:65]
hrp_hash = sha256(hrp.encode("ascii") + data.tobytes()).digest()
hrp_hash = sha256(hrp.encode("ascii") + bytes(convertbits(data5, 5, 8, True))).digest()
if addr.pubkey: # Specified by `n`
# BOLT #11:
#