1
0

swaps: code style clean-up, add type hints, force kwargs

no intended functional changes
This commit is contained in:
SomberNight
2023-11-22 17:50:29 +00:00
parent 9d5d582752
commit 9f1b8613d0
8 changed files with 120 additions and 66 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import os
from typing import TYPE_CHECKING, Optional, Dict, Union
from typing import TYPE_CHECKING, Optional, Dict, Union, Sequence, Tuple
from decimal import Decimal
import math
import time
@@ -9,6 +9,7 @@ import time
import attr
import aiohttp
from . import lnutil
from .crypto import sha256, hash_160
from .ecc import ECPrivkey
from .bitcoin import (script_to_p2wsh, opcodes, p2wsh_nested_script, push_script,
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
from .wallet import Abstract_Wallet
from .lnwatcher import LNWalletWatcher
from .lnworker import LNWallet
from .lnchannel import Channel
from .simple_config import SimpleConfig
@@ -76,7 +78,16 @@ WITNESS_TEMPLATE_REVERSE_SWAP = [
opcodes.OP_CHECKSIG
]
def check_reverse_redeem_script(redeem_script, lockup_address, payment_hash, locktime, *, refund_pubkey=None, claim_pubkey=None):
def check_reverse_redeem_script(
*,
redeem_script: str,
lockup_address: str,
payment_hash: bytes,
locktime: int,
refund_pubkey: bytes = None,
claim_pubkey: bytes = None,
) -> None:
redeem_script = bytes.fromhex(redeem_script)
parsed_script = [x for x in script_GetOp(redeem_script)]
if not match_script_against_template(redeem_script, WITNESS_TEMPLATE_REVERSE_SWAP):
@@ -91,7 +102,6 @@ def check_reverse_redeem_script(redeem_script, lockup_address, payment_hash, loc
raise Exception("rswap check failed: our pubkey not in script")
if locktime != int.from_bytes(parsed_script[10][1], byteorder='little'):
raise Exception("rswap check failed: inconsistent locktime and script")
return parsed_script[7][1], parsed_script[13][1]
class SwapServerError(Exception):
@@ -109,7 +119,7 @@ class SwapData(StoredObject):
onchain_amount = attr.ib(type=int) # in sats
lightning_amount = attr.ib(type=int) # in sats
redeem_script = attr.ib(type=bytes, converter=hex_to_bytes)
preimage = attr.ib(type=bytes, converter=hex_to_bytes)
preimage = attr.ib(type=Optional[bytes], converter=hex_to_bytes)
prepay_hash = attr.ib(type=Optional[bytes], converter=hex_to_bytes)
privkey = attr.ib(type=bytes, converter=hex_to_bytes)
lockup_address = attr.ib(type=str)
@@ -349,6 +359,7 @@ class SwapManager(Logger):
return self.get_fee(CLAIM_FEE_SIZE)
def get_fee(self, size):
# note: 'size' is in vbytes
return self._get_fee(size=size, config=self.wallet.config)
@classmethod
@@ -376,15 +387,16 @@ class SwapManager(Logger):
if swap.funding_txid is None:
password = self.wallet.get_unlocked_password()
for batch_rbf in [True, False]:
tx = self.create_funding_tx(swap, None, password, batch_rbf=batch_rbf)
tx = self.create_funding_tx(swap, None, password=password, batch_rbf=batch_rbf)
try:
await self.broadcast_funding_tx(swap, tx)
except TxBroadcastServerReturnedError:
continue
break
def create_normal_swap(self, *, lightning_amount_sat=None, payment_hash: bytes=None, their_pubkey=None):
def create_normal_swap(self, *, lightning_amount_sat: int, payment_hash: bytes, their_pubkey: bytes = None):
""" server method """
assert lightning_amount_sat
locktime = self.network.get_local_height() + LOCKTIME_DELTA_REFUND
our_privkey = os.urandom(32)
our_pubkey = ECPrivkey(our_privkey).get_public_key_bytes(compressed=True)
@@ -400,8 +412,6 @@ class SwapManager(Logger):
lightning_amount_sat=lightning_amount_sat,
payment_hash=payment_hash,
our_privkey=our_privkey,
their_pubkey=their_pubkey,
invoice=None,
prepay=True,
)
self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback)
@@ -409,35 +419,32 @@ class SwapManager(Logger):
def add_normal_swap(
self, *,
redeem_script=None,
locktime=None,
onchain_amount_sat=None,
lightning_amount_sat=None,
payment_hash=None,
our_privkey=None,
their_pubkey=None,
invoice=None,
prepay=None,
channels=None,
):
""" if invoice is None, create a hold invoice """
redeem_script: str,
locktime: int, # onchain
onchain_amount_sat: int,
lightning_amount_sat: int,
payment_hash: bytes,
our_privkey: bytes,
prepay: bool,
channels: Optional[Sequence['Channel']] = None,
) -> Tuple[SwapData, str, str]:
"""creates a hold invoice"""
if prepay:
prepay_amount_sat = self.get_claim_fee() * 2
invoice_amount_sat = lightning_amount_sat - prepay_amount_sat
else:
invoice_amount_sat = lightning_amount_sat
if not invoice:
_, invoice = self.lnworker.get_bolt11_invoice(
payment_hash=payment_hash,
amount_msat=invoice_amount_sat * 1000,
message='Submarine swap',
expiry=300,
fallback_address=None,
channels=channels,
)
# add payment info to lnworker
self.lnworker.add_payment_info_for_hold_invoice(payment_hash, invoice_amount_sat)
_, invoice = self.lnworker.get_bolt11_invoice(
payment_hash=payment_hash,
amount_msat=invoice_amount_sat * 1000,
message='Submarine swap',
expiry=300,
fallback_address=None,
channels=channels,
)
# add payment info to lnworker
self.lnworker.add_payment_info_for_hold_invoice(payment_hash, invoice_amount_sat)
if prepay:
prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000)
@@ -477,14 +484,14 @@ class SwapManager(Logger):
self.add_lnwatcher_callback(swap)
return swap, invoice, prepay_invoice
def create_reverse_swap(self, *, lightning_amount_sat=None, their_pubkey=None):
def create_reverse_swap(self, *, lightning_amount_sat: int, their_pubkey: bytes) -> SwapData:
""" server method. """
assert lightning_amount_sat is not None
locktime = self.network.get_local_height() + LOCKTIME_DELTA_REFUND
privkey = os.urandom(32)
our_pubkey = ECPrivkey(privkey).get_public_key_bytes(compressed=True)
onchain_amount_sat = self._get_send_amount(lightning_amount_sat, is_reverse=False)
preimage = os.urandom(32)
assert lightning_amount_sat is not None
payment_hash = sha256(preimage)
redeem_script = construct_script(
WITNESS_TEMPLATE_REVERSE_SWAP,
@@ -501,7 +508,18 @@ class SwapManager(Logger):
lightning_amount_sat=lightning_amount_sat)
return swap
def add_reverse_swap(self, *, redeem_script=None, locktime=None, privkey=None, lightning_amount_sat=None, onchain_amount_sat=None, preimage=None, payment_hash=None, prepay_hash=None):
def add_reverse_swap(
self,
*,
redeem_script: str,
locktime: int, # onchain
privkey: bytes,
lightning_amount_sat: int,
onchain_amount_sat: int,
preimage: bytes,
payment_hash: bytes,
prepay_hash: Optional[bytes] = None,
) -> SwapData:
lockup_address = script_to_p2wsh(redeem_script)
receive_address = self.wallet.get_receiving_address()
swap = SwapData(
@@ -526,7 +544,7 @@ class SwapManager(Logger):
self.add_lnwatcher_callback(swap)
return swap
def add_invoice(self, invoice, pay_now=False):
def add_invoice(self, invoice: str, pay_now: bool = False) -> None:
invoice = Invoice.from_bech32(invoice)
key = invoice.rhash
payment_hash = bytes.fromhex(key)
@@ -548,28 +566,41 @@ class SwapManager(Logger):
password,
tx: PartialTransaction = None,
channels = None,
) -> str:
) -> Optional[str]:
"""send on-chain BTC, receive on Lightning
Old (removed) flow:
- User generates an LN invoice with RHASH, and knows preimage.
- User creates on-chain output locked to RHASH.
- Server pays LN invoice. User reveals preimage.
- Server spends the on-chain output using preimage.
New flow:
- user requests swap
- server creates preimage, sends RHASH to user
- user creates hold invoice, sends it to server
- User requests swap
- Server creates preimage, sends RHASH to user
- User creates hold invoice, sends it to server
- Server sends HTLC, user holds it
- User creates on-chain output locked to RHASH
- Server spends the on-chain output using preimage (revealing the preimage)
- User fulfills HTLC using preimage
"""
assert self.network
assert self.lnwatcher
swap, invoice = await self.request_normal_swap(lightning_amount_sat, expected_onchain_amount_sat, channels=channels)
tx = self.create_funding_tx(swap, tx, password)
return await self.wait_for_htlcs_and_broadcast(swap, invoice, tx)
swap, invoice = await self.request_normal_swap(
lightning_amount_sat=lightning_amount_sat,
expected_onchain_amount_sat=expected_onchain_amount_sat,
channels=channels,
)
tx = self.create_funding_tx(swap, tx, password=password)
return await self.wait_for_htlcs_and_broadcast(swap=swap, invoice=invoice, tx=tx)
async def request_normal_swap(self, lightning_amount_sat, expected_onchain_amount_sat, channels=None):
amount_msat = lightning_amount_sat * 1000
async def request_normal_swap(
self,
*,
lightning_amount_sat: int,
expected_onchain_amount_sat: int,
channels: Optional[Sequence['Channel']] = None,
) -> Tuple[SwapData, str]:
refund_privkey = os.urandom(32)
refund_pubkey = ECPrivkey(refund_privkey).get_public_key_bytes(compressed=True)
@@ -585,8 +616,6 @@ class SwapManager(Logger):
timeout=30)
data = json.loads(response)
payment_hash = bytes.fromhex(data["preimageHash"])
preimage = None
invoice = None
zeroconf = data["acceptZeroConf"]
onchain_amount = data["expectedAmount"]
@@ -594,7 +623,13 @@ class SwapManager(Logger):
lockup_address = data["address"]
redeem_script = data["redeemScript"]
# verify redeem_script is built with our pubkey and preimage
claim_pubkey, _ = check_reverse_redeem_script(redeem_script, lockup_address, payment_hash, locktime, refund_pubkey=refund_pubkey)
check_reverse_redeem_script(
redeem_script=redeem_script,
lockup_address=lockup_address,
payment_hash=payment_hash,
locktime=locktime,
refund_pubkey=refund_pubkey,
)
# check that onchain_amount is not more than what we estimated
if onchain_amount > expected_onchain_amount_sat:
@@ -611,14 +646,18 @@ class SwapManager(Logger):
onchain_amount_sat=onchain_amount,
payment_hash=payment_hash,
our_privkey=refund_privkey,
their_pubkey=claim_pubkey,
invoice=invoice,
prepay=False,
channels=channels,
)
return swap, invoice
async def wait_for_htlcs_and_broadcast(self, swap, invoice, tx):
async def wait_for_htlcs_and_broadcast(
self,
*,
swap: SwapData,
invoice: str,
tx: Transaction,
) -> Optional[str]:
payment_hash = swap.payment_hash
refund_pubkey = ECPrivkey(swap.privkey).get_public_key_bytes(compressed=True)
async def callback(payment_hash):
@@ -644,7 +683,14 @@ class SwapManager(Logger):
await asyncio.sleep(0.1)
return swap.funding_txid
def create_funding_tx(self, swap, tx, password, *, batch_rbf: Optional[bool] = None):
def create_funding_tx(
self,
swap: SwapData,
tx: Optional[PartialTransaction],
*,
password,
batch_rbf: Optional[bool] = None,
) -> PartialTransaction:
# create funding tx
# note: rbf must not decrease payment
# this is taken care of in wallet._is_rbf_allowed_to_touch_tx_output
@@ -663,7 +709,7 @@ class SwapManager(Logger):
return tx
@log_exceptions
async def request_swap_for_tx(self, tx: 'PartialTransaction'):
async def request_swap_for_tx(self, tx: 'PartialTransaction') -> Optional[Tuple[SwapData, str, PartialTransaction]]:
for o in tx.outputs():
if o.address == self.dummy_address:
change_amount = o.value
@@ -679,7 +725,7 @@ class SwapManager(Logger):
return swap, invoice, tx
@log_exceptions
async def broadcast_funding_tx(self, swap, tx):
async def broadcast_funding_tx(self, swap: SwapData, tx: Transaction) -> None:
swap.funding_txid = tx.txid()
await self.network.broadcast_transaction(tx)
@@ -688,7 +734,7 @@ class SwapManager(Logger):
*,
lightning_amount_sat: int,
expected_onchain_amount_sat: int,
channels = None,
channels: Optional[Sequence['Channel']] = None,
) -> Optional[str]:
"""send on Lightning, receive on-chain
@@ -729,7 +775,14 @@ class SwapManager(Logger):
onchain_amount = data["onchainAmount"]
response_id = data['id']
# verify redeem_script is built with our pubkey and preimage
check_reverse_redeem_script(redeem_script, lockup_address, payment_hash, locktime, refund_pubkey=None, claim_pubkey=our_pubkey)
check_reverse_redeem_script(
redeem_script=redeem_script,
lockup_address=lockup_address,
payment_hash=payment_hash,
locktime=locktime,
refund_pubkey=None,
claim_pubkey=our_pubkey,
)
# check that the onchain amount is what we expected
if onchain_amount < expected_onchain_amount_sat:
raise Exception(f"rswap check failed: onchain_amount is less than what we expected: "