1
0

integrate PSBT support natively. WIP

This commit is contained in:
SomberNight
2019-10-23 17:09:41 +02:00
parent 6d12ebabbb
commit bafe8a2fff
61 changed files with 3405 additions and 3310 deletions

View File

@@ -29,9 +29,9 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence
from . import bitcoin
from .bitcoin import COINBASE_MATURITY, TYPE_ADDRESS, TYPE_PUBKEY
from .bitcoin import COINBASE_MATURITY
from .util import profiler, bfh, TxMinedInfo
from .transaction import Transaction, TxOutput
from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint
from .synchronizer import Synchronizer
from .verifier import SPV
from .blockchain import hash_header
@@ -125,12 +125,12 @@ class AddressSynchronizer(Logger):
"""Return number of transactions where address is involved."""
return len(self._history_local.get(addr, ()))
def get_txin_address(self, txi) -> Optional[str]:
addr = txi.get('address')
if addr and addr != "(pubkey)":
return addr
prevout_hash = txi.get('prevout_hash')
prevout_n = txi.get('prevout_n')
def get_txin_address(self, txin: TxInput) -> Optional[str]:
if isinstance(txin, PartialTxInput):
if txin.address:
return txin.address
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
for addr in self.db.get_txo_addresses(prevout_hash):
l = self.db.get_txo_addr(prevout_hash, addr)
for n, v, is_cb in l:
@@ -138,14 +138,8 @@ class AddressSynchronizer(Logger):
return addr
return None
def get_txout_address(self, txo: TxOutput):
if txo.type == TYPE_ADDRESS:
addr = txo.address
elif txo.type == TYPE_PUBKEY:
addr = bitcoin.public_key_to_p2pkh(bfh(txo.address))
else:
addr = None
return addr
def get_txout_address(self, txo: TxOutput) -> Optional[str]:
return txo.address
def load_unverified_transactions(self):
# review transactions that are in the history
@@ -183,7 +177,7 @@ class AddressSynchronizer(Logger):
if self.synchronizer:
self.synchronizer.add(address)
def get_conflicting_transactions(self, tx_hash, tx, include_self=False):
def get_conflicting_transactions(self, tx_hash, tx: Transaction, include_self=False):
"""Returns a set of transaction hashes from the wallet history that are
directly conflicting with tx, i.e. they have common outpoints being
spent with tx.
@@ -194,10 +188,10 @@ class AddressSynchronizer(Logger):
conflicting_txns = set()
with self.transaction_lock:
for txin in tx.inputs():
if txin['type'] == 'coinbase':
if txin.is_coinbase():
continue
prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n']
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
if spending_tx_hash is None:
continue
@@ -213,7 +207,7 @@ class AddressSynchronizer(Logger):
conflicting_txns -= {tx_hash}
return conflicting_txns
def add_transaction(self, tx_hash, tx, allow_unrelated=False) -> bool:
def add_transaction(self, tx_hash, tx: Transaction, allow_unrelated=False) -> bool:
"""Returns whether the tx was successfully added to the wallet history."""
assert tx_hash, tx_hash
assert tx, tx
@@ -226,7 +220,7 @@ class AddressSynchronizer(Logger):
# BUT we track is_mine inputs in a txn, and during subsequent calls
# of add_transaction tx, we might learn of more-and-more inputs of
# being is_mine, as we roll the gap_limit forward
is_coinbase = tx.inputs()[0]['type'] == 'coinbase'
is_coinbase = tx.inputs()[0].is_coinbase()
tx_height = self.get_tx_height(tx_hash).height
if not allow_unrelated:
# note that during sync, if the transactions are not properly sorted,
@@ -277,11 +271,11 @@ class AddressSynchronizer(Logger):
self._get_addr_balance_cache.pop(addr, None) # invalidate cache
return
for txi in tx.inputs():
if txi['type'] == 'coinbase':
if txi.is_coinbase():
continue
prevout_hash = txi['prevout_hash']
prevout_n = txi['prevout_n']
ser = prevout_hash + ':%d' % prevout_n
prevout_hash = txi.prevout.txid.hex()
prevout_n = txi.prevout.out_idx
ser = txi.prevout.to_str()
self.db.set_spent_outpoint(prevout_hash, prevout_n, tx_hash)
add_value_from_prev_output()
# add outputs
@@ -310,10 +304,10 @@ class AddressSynchronizer(Logger):
if tx is not None:
# if we have the tx, this branch is faster
for txin in tx.inputs():
if txin['type'] == 'coinbase':
if txin.is_coinbase():
continue
prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n']
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
else:
# expensive but always works
@@ -572,7 +566,7 @@ class AddressSynchronizer(Logger):
return cached_local_height
return self.network.get_local_height() if self.network else self.db.get('stored_height', 0)
def add_future_tx(self, tx, num_blocks):
def add_future_tx(self, tx: Transaction, num_blocks):
with self.lock:
self.add_transaction(tx.txid(), tx)
self.future_tx[tx.txid()] = num_blocks
@@ -649,9 +643,9 @@ class AddressSynchronizer(Logger):
if self.is_mine(addr):
is_mine = True
is_relevant = True
d = self.db.get_txo_addr(txin['prevout_hash'], addr)
d = self.db.get_txo_addr(txin.prevout.txid.hex(), addr)
for n, v, cb in d:
if n == txin['prevout_n']:
if n == txin.prevout.out_idx:
value = v
break
else:
@@ -736,23 +730,19 @@ class AddressSynchronizer(Logger):
sent[txi] = height
return received, sent
def get_addr_utxo(self, address):
def get_addr_utxo(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:
coins, spent = self.get_addr_io(address)
for txi in spent:
coins.pop(txi)
out = {}
for txo, v in coins.items():
for prevout_str, v in coins.items():
tx_height, value, is_cb = v
prevout_hash, prevout_n = txo.split(':')
x = {
'address':address,
'value':value,
'prevout_n':int(prevout_n),
'prevout_hash':prevout_hash,
'height':tx_height,
'coinbase':is_cb
}
out[txo] = x
prevout = TxOutpoint.from_str(prevout_str)
utxo = PartialTxInput(prevout=prevout)
utxo._trusted_address = address
utxo._trusted_value_sats = value
utxo.block_height = tx_height
out[prevout] = utxo
return out
# return the total amount ever received by an address
@@ -799,7 +789,8 @@ class AddressSynchronizer(Logger):
@with_local_height_cached
def get_utxos(self, domain=None, *, excluded_addresses=None,
mature_only: bool = False, confirmed_only: bool = False, nonlocal_only: bool = False):
mature_only: bool = False, confirmed_only: bool = False,
nonlocal_only: bool = False) -> Sequence[PartialTxInput]:
coins = []
if domain is None:
domain = self.get_addresses()
@@ -809,14 +800,15 @@ class AddressSynchronizer(Logger):
mempool_height = self.get_local_height() + 1 # height of next block
for addr in domain:
utxos = self.get_addr_utxo(addr)
for x in utxos.values():
if confirmed_only and x['height'] <= 0:
for utxo in utxos.values():
if confirmed_only and utxo.block_height <= 0:
continue
if nonlocal_only and x['height'] == TX_HEIGHT_LOCAL:
if nonlocal_only and utxo.block_height == TX_HEIGHT_LOCAL:
continue
if mature_only and x['coinbase'] and x['height'] + COINBASE_MATURITY > mempool_height:
if (mature_only and utxo.prevout.is_coinbase()
and utxo.block_height + COINBASE_MATURITY > mempool_height):
continue
coins.append(x)
coins.append(utxo)
continue
return coins