1
0

Better support for USB devices

Benefits of this rewrite include:

- support of disconnecting / reconnecting a device without having
  to close the wallet, even in a different USB socket
- support of multiple keepkey / trezor devices, both during wallet
  creation and general use
- wallet is watching-only dynamically according to whether the
  associated device is currently plugged in or not
This commit is contained in:
Neil Booth
2016-01-02 09:43:56 +09:00
parent 187b4dc9c1
commit 21bf5a8a84
11 changed files with 345 additions and 225 deletions

View File

@@ -1,4 +1,6 @@
import re
import time
from binascii import unhexlify
from struct import pack
from unicodedata import normalize
@@ -12,6 +14,9 @@ from electrum.transaction import (deserialize, is_extended_pubkey,
Transaction, x_to_xpub)
from electrum.wallet import BIP32_HD_Wallet, BIP44_Wallet
class DeviceDisconnectedError(Exception):
pass
class TrezorCompatibleWallet(BIP44_Wallet):
# Extend BIP44 Wallet as required by hardware implementation.
# Derived classes must set:
@@ -22,11 +27,21 @@ class TrezorCompatibleWallet(BIP44_Wallet):
def __init__(self, storage):
BIP44_Wallet.__init__(self, storage)
self.proper_device = False
# This is set when paired with a device, and used to re-pair
# a device that is disconnected and re-connected
self.device_id = None
# Errors and other user interaction is done through the wallet's
# handler. The handler is per-window and preserved across
# device reconnects
self.handler = None
def give_error(self, message):
self.print_error(message)
raise Exception(message)
def disconnected(self):
self.print_error("disconnected")
self.handler.watching_only_changed()
def connected(self):
self.print_error("connected")
self.handler.watching_only_changed()
def get_action(self):
pass
@@ -35,29 +50,29 @@ class TrezorCompatibleWallet(BIP44_Wallet):
return False
def is_watching_only(self):
'''The wallet is watching-only if its trezor device is not
connected. This result is dynamic and changes over time.'''
assert not self.has_seed()
return not self.proper_device
return self.plugin.lookup_client(self) is None
def can_change_password(self):
return False
def get_client(self):
return self.plugin.get_client(self)
def check_proper_device(self):
return self.get_client().check_proper_device(self)
def client(self):
return self.plugin.client(self)
def derive_xkeys(self, root, derivation, password):
if self.master_public_keys.get(root):
return BIP44_wallet.derive_xkeys(self, root, derivation, password)
# Happens when creating a wallet
# When creating a wallet we need to ask the device for the
# master public key
derivation = derivation.replace(self.root_name, self.prefix() + "/")
xpub = self.get_public_key(derivation)
return xpub, None
def get_public_key(self, bip32_path):
client = self.get_client()
client = self.client()
address_n = client.expand_path(bip32_path)
node = client.get_public_node(address_n).node
xpub = ("0488B21E".decode('hex') + chr(node.depth)
@@ -72,25 +87,15 @@ class TrezorCompatibleWallet(BIP44_Wallet):
raise RuntimeError(_('Decrypt method is not implemented'))
def sign_message(self, address, message, password):
client = self.get_client()
self.check_proper_device()
try:
address_path = self.address_id(address)
address_n = client.expand_path(address_path)
except Exception as e:
self.give_error(e)
try:
msg_sig = client.sign_message('Bitcoin', address_n, message)
except Exception as e:
self.give_error(e)
finally:
self.plugin.get_handler(self).stop()
client = self.client()
address_path = self.address_id(address)
address_n = client.expand_path(address_path)
msg_sig = client.sign_message('Bitcoin', address_n, message)
return msg_sig.signature
def sign_transaction(self, tx, password):
if tx.is_complete() or self.is_watching_only():
return
self.check_proper_device()
# previous transactions used as inputs
prev_tx = {}
# path of the xpubs that are involved
@@ -123,51 +128,172 @@ class TrezorCompatiblePlugin(BasePlugin):
# libraries_available, libraries_URL, minimum_firmware,
# wallet_class, ckd_public, types, HidTransport
# This plugin automatically keeps track of attached devices, and
# connects to anything attached creating a new Client instance.
# When disconnected, the client is informed via a callback.
# As a device can be disconnected and/or reconnected in a different
# USB port (giving it a new path), the wallet must be dynamic in
# asking for its client.
# If a wallet is successfully paired with a given device, the plugin
# stores its serial number in the wallet so it can be automatically
# re-paired if the same device is connected elsewhere.
# Approaching things this way permits several devices to be connected
# simultaneously and handled smoothly.
def __init__(self, parent, config, name):
BasePlugin.__init__(self, parent, config, name)
self.device = self.wallet_class.device
self.client = None
self.wallet_class.plugin = self
# A set of client instances to USB paths
self.clients = set()
# The device wallets we have seen to inform on reconnection
self.paired_wallets = set()
# Do an initial scan
self.last_scan = 0
self.timer_actions()
def give_error(self, message):
self.print_error(message)
raise Exception(message)
@hook
def timer_actions(self):
if self.libraries_available:
# Scan connected devices every second
now = time.time()
if now > self.last_scan + 1:
self.last_scan = now
self.scan_devices()
def scan_devices(self):
paths = self.HidTransport.enumerate()
connected = set([c for c in self.clients if c.path in paths])
disconnected = self.clients - connected
# Inform clients and wallets they were disconnected
for client in disconnected:
self.print_error("device disconnected:", client)
if client.wallet:
client.wallet.disconnected()
for path in paths:
# Look for new paths
if any(c.path == path for c in connected):
continue
try:
transport = self.HidTransport(path)
except BaseException as e:
# We were probably just disconnected; never mind
self.print_error("cannot connect at", path, str(e))
continue
self.print_error("connected to device at", path[0])
try:
client = self.client_class(transport, path, self)
except BaseException as e:
self.print_error("cannot create client for", path, str(e))
else:
connected.add(client)
self.print_error("new device:", client)
# Inform reconnected wallets
for wallet in self.paired_wallets:
if wallet.device_id == client.features.device_id:
client.wallet = wallet
wallet.connected()
self.clients = connected
def clear_session(self, client):
# Clearing the session forces pin re-entry
self.print_error("clear session:", client)
client.clear_session()
def select_device(self, wallet, wizard):
'''Called when creating a new wallet. Select the device
to use.'''
clients = list(self.clients)
if not len(clients):
return
if len(clients) > 1:
labels = [client.label() for client in clients]
msg = _("Please select which %s device to use:") % self.device
client = clients[wizard.query_choice(msg, labels)]
else:
client = clients[0]
self.pair_wallet(wallet, client)
def pair_wallet(self, wallet, client):
self.print_error("pairing wallet %s to device %s" % (wallet, client))
self.paired_wallets.add(wallet)
wallet.device_id = client.features.device_id
client.wallet = wallet
wallet.connected()
def try_to_pair_wallet(self, wallet):
'''Call this when loading an existing wallet to find if the
associated device is connected.'''
account = '0'
if not account in wallet.accounts:
self.print_error("try pair_wallet: wallet has no accounts")
return None
first_address = wallet.accounts[account].first_address()[0]
derivation = wallet.address_derivation(account, 0, 0)
for client in self.clients:
if client.wallet:
continue
if not client.atleast_version(*self.minimum_firmware):
wallet.handler.show_error(
_('Outdated %s firmware for device labelled %s. Please '
'download the updated firmware from %s') %
(self.device, client.label(), self.firmware_URL))
continue
# This gives us a handler
client.wallet = wallet
device_address = None
try:
device_address = client.address_from_derivation(derivation)
finally:
client.wallet = None
if first_address == device_address:
self.pair_wallet(wallet, client)
return client
return None
def lookup_client(self, wallet):
for client in self.clients:
if client.features.device_id == wallet.device_id:
return client
return None
def client(self, wallet):
'''Returns a wrapped client which handles cleanup in case of
thrown exceptions, etc.'''
assert isinstance(wallet, self.wallet_class)
assert wallet.handler != None
if wallet.device_id is None:
client = self.try_to_pair_wallet(wallet)
else:
client = self.lookup_client(wallet)
if not client:
msg = (_('Could not connect to your %s. Verify the '
'cable is connected and that no other app is '
'using it.\nContinuing in watching-only mode '
'until the device is re-connected.') % self.device)
if not self.clients:
wallet.handler.show_error(msg)
raise DeviceDisconnectedError(msg)
return client
def is_enabled(self):
return self.libraries_available
def create_client(self):
if not self.libraries_available:
self.give_error(_('please install the %s libraries from %s')
% (self.device, self.libraries_URL))
devices = self.HidTransport.enumerate()
if not devices:
self.give_error(_('Could not connect to your %s. Verify the '
'cable is connected and that no other app is '
'using it.\nContinuing in watching-only mode.'
% self.device))
transport = self.HidTransport(devices[0])
client = self.client_class(transport, self)
if not client.atleast_version(*self.minimum_firmware):
self.give_error(_('Outdated %s firmware. Please update the '
'firmware from %s')
% (self.device, self.firmware_URL))
return client
def get_handler(self, wallet):
return self.get_client(wallet).handler
def get_client(self, wallet=None):
if not self.client or self.client.bad:
self.client = self.create_client()
return self.client
def atleast_version(self, major, minor=0, patch=0):
return self.get_client().atleast_version(major, minor, patch)
@staticmethod
def normalize_passphrase(self, passphrase):
return normalize('NFKD', unicode(passphrase or ''))
@@ -192,41 +318,33 @@ class TrezorCompatiblePlugin(BasePlugin):
@hook
def close_wallet(self, wallet):
if self.client:
self.print_error("clear session")
self.client.clear_session()
self.client.transport.close()
self.client = None
# Don't retain references to a closed wallet
self.paired_wallets.discard(wallet)
client = self.lookup_client(wallet)
if client:
self.clear_session(client)
# Release the device
self.clients.discard(client)
client.transport.close()
def sign_transaction(self, wallet, tx, prev_tx, xpub_path):
self.prev_tx = prev_tx
self.xpub_path = xpub_path
client = self.get_client()
client = self.client(wallet)
inputs = self.tx_inputs(tx, True)
outputs = self.tx_outputs(wallet, tx)
try:
signed_tx = client.sign_tx('Bitcoin', inputs, outputs)[1]
except Exception as e:
self.give_error(e)
finally:
self.get_handler(wallet).stop()
signed_tx = client.sign_tx('Bitcoin', inputs, outputs)[1]
raw = signed_tx.encode('hex')
tx.update_signatures(raw)
def show_address(self, wallet, address):
client = self.get_client()
wallet.check_proper_device()
try:
address_path = wallet.address_id(address)
address_n = self.client_class.expand_path(address_path)
except Exception as e:
self.give_error(e)
try:
client.get_address('Bitcoin', address_n, True)
except Exception as e:
self.give_error(e)
finally:
self.get_handler(wallet).stop()
client = self.client(wallet)
if not client.atleast_version(1, 3):
wallet.handler.show_error(_("Your device firmware is too old"))
return
address_path = wallet.address_id(address)
address_n = client.expand_path(address_path)
client.get_address('Bitcoin', address_n, True)
def tx_inputs(self, tx, for_sig=False):
inputs = []