1
0

hardware devices: run all device communication on dedicated thread (#6561)

hidapi/libusb etc are not thread-safe.

related: #6554
This commit is contained in:
ghost43
2020-09-08 15:52:53 +00:00
committed by GitHub
parent 53a5a21ee8
commit 21c3572600
12 changed files with 195 additions and 97 deletions

View File

@@ -29,9 +29,10 @@ import time
import threading
import sys
from typing import (NamedTuple, Any, Union, TYPE_CHECKING, Optional, Tuple,
Dict, Iterable, List, Sequence)
Dict, Iterable, List, Sequence, Callable, TypeVar)
import concurrent
from concurrent import futures
from functools import wraps, partial
from .i18n import _
from .util import (profiler, DaemonThread, UserCancelled, ThreadJob, UserFacingException)
@@ -334,11 +335,37 @@ PLACEHOLDER_HW_CLIENT_LABELS = {None, "", " "}
# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238
# It is not entirely clear to me, exactly what is safe and what isn't, when
# using multiple threads...
# For now, we use a dedicated thread to enumerate devices (_hid_executor),
# and we synchronize all device opens/closes/enumeration (_hid_lock).
# FIXME there are still probably threading issues with how we use hidapi...
_hid_executor = None # type: Optional[concurrent.futures.Executor]
_hid_lock = threading.Lock()
# Hence, we use a single thread for all device communications, including
# enumeration. Everything that uses hidapi, libusb, etc, MUST run on
# the following thread:
_hwd_comms_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1,
thread_name_prefix='hwd_comms_thread'
)
T = TypeVar('T')
def run_in_hwd_thread(func: Callable[[], T]) -> T:
if threading.current_thread().name.startswith("hwd_comms_thread"):
return func()
else:
fut = _hwd_comms_executor.submit(func)
return fut.result()
#except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
def runs_in_hwd_thread(func):
@wraps(func)
def wrapper(*args, **kwargs):
return run_in_hwd_thread(partial(func, *args, **kwargs))
return wrapper
def assert_runs_in_hwd_thread():
if not threading.current_thread().name.startswith("hwd_comms_thread"):
raise Exception("must only be called from HWD communication thread")
class DeviceMgr(ThreadJob):
@@ -384,24 +411,11 @@ class DeviceMgr(ThreadJob):
self._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase]
# Custom enumerate functions for devices we don't know about.
self._enumerate_func = set() # Needs self.lock.
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
self._scan_lock = threading.RLock()
self.lock = threading.RLock()
self.hid_lock = _hid_lock
self.config = config
global _hid_executor
if _hid_executor is None:
_hid_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1,
thread_name_prefix='hid_enumerate_thread')
def with_scan_lock(func):
def func_wrapper(self: 'DeviceMgr', *args, **kwargs):
with self._scan_lock:
return func(self, *args, **kwargs)
return func_wrapper
def thread_jobs(self):
# Thread job to handle device timeouts
return [self]
@@ -423,6 +437,7 @@ class DeviceMgr(ThreadJob):
with self.lock:
self._enumerate_func.add(func)
@runs_in_hwd_thread
def create_client(self, device: 'Device', handler: Optional['HardwareHandlerBase'],
plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']:
# Get from cache first
@@ -452,7 +467,7 @@ class DeviceMgr(ThreadJob):
if xpub not in self.xpub_ids:
return
_id = self.xpub_ids.pop(xpub)
self._close_client(_id)
self._close_client(_id)
def unpair_id(self, id_):
xpub = self.xpub_by_id(id_)
@@ -462,8 +477,9 @@ class DeviceMgr(ThreadJob):
self._close_client(id_)
def _close_client(self, id_):
client = self._client_by_id(id_)
self.clients.pop(client, None)
with self.lock:
client = self._client_by_id(id_)
self.clients.pop(client, None)
if client:
client.close()
@@ -486,7 +502,7 @@ class DeviceMgr(ThreadJob):
self.scan_devices()
return self._client_by_id(id_)
@with_scan_lock
@runs_in_hwd_thread
def client_for_keystore(self, plugin: 'HW_PluginBase', handler: Optional['HardwareHandlerBase'],
keystore: 'Hardware_KeyStore',
force_pair: bool, *,
@@ -655,25 +671,15 @@ class DeviceMgr(ThreadJob):
# note: updated label/soft_device_id will be saved after pairing succeeds
return info
@with_scan_lock
@runs_in_hwd_thread
def _scan_devices_with_hid(self) -> List['Device']:
try:
import hid
except ImportError:
return []
def hid_enumerate():
with self.hid_lock:
return hid.enumerate(0, 0)
hid_list_fut = _hid_executor.submit(hid_enumerate)
try:
hid_list = hid_list_fut.result()
except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
return []
devices = []
for d in hid_list:
for d in hid.enumerate(0, 0):
product_key = (d['vendor_id'], d['product_id'])
if product_key in self._recognised_hardware:
plugin = self._recognised_hardware[product_key]
@@ -681,7 +687,7 @@ class DeviceMgr(ThreadJob):
devices.append(device)
return devices
@with_scan_lock
@runs_in_hwd_thread
@profiler
def scan_devices(self) -> Sequence['Device']:
self.logger.info("scanning devices...")
@@ -693,10 +699,8 @@ class DeviceMgr(ThreadJob):
with self.lock:
enumerate_funcs = list(self._enumerate_func)
for f in enumerate_funcs:
# custom enumerate functions might use hidapi, so use hid thread to be safe
new_devices_fut = _hid_executor.submit(f)
try:
new_devices = new_devices_fut.result()
new_devices = f()
except BaseException as e:
self.logger.error('custom device enum failed. func {}, error {}'
.format(str(f), repr(e)))