hardware devices: run all device communication on dedicated thread (#6561)
hidapi/libusb etc are not thread-safe. related: #6554
This commit is contained in:
@@ -29,9 +29,10 @@ import time
|
||||
import threading
|
||||
import sys
|
||||
from typing import (NamedTuple, Any, Union, TYPE_CHECKING, Optional, Tuple,
|
||||
Dict, Iterable, List, Sequence)
|
||||
Dict, Iterable, List, Sequence, Callable, TypeVar)
|
||||
import concurrent
|
||||
from concurrent import futures
|
||||
from functools import wraps, partial
|
||||
|
||||
from .i18n import _
|
||||
from .util import (profiler, DaemonThread, UserCancelled, ThreadJob, UserFacingException)
|
||||
@@ -334,11 +335,37 @@ PLACEHOLDER_HW_CLIENT_LABELS = {None, "", " "}
|
||||
# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238
|
||||
# It is not entirely clear to me, exactly what is safe and what isn't, when
|
||||
# using multiple threads...
|
||||
# For now, we use a dedicated thread to enumerate devices (_hid_executor),
|
||||
# and we synchronize all device opens/closes/enumeration (_hid_lock).
|
||||
# FIXME there are still probably threading issues with how we use hidapi...
|
||||
_hid_executor = None # type: Optional[concurrent.futures.Executor]
|
||||
_hid_lock = threading.Lock()
|
||||
# Hence, we use a single thread for all device communications, including
|
||||
# enumeration. Everything that uses hidapi, libusb, etc, MUST run on
|
||||
# the following thread:
|
||||
_hwd_comms_executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix='hwd_comms_thread'
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def run_in_hwd_thread(func: Callable[[], T]) -> T:
|
||||
if threading.current_thread().name.startswith("hwd_comms_thread"):
|
||||
return func()
|
||||
else:
|
||||
fut = _hwd_comms_executor.submit(func)
|
||||
return fut.result()
|
||||
#except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
|
||||
|
||||
|
||||
def runs_in_hwd_thread(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return run_in_hwd_thread(partial(func, *args, **kwargs))
|
||||
return wrapper
|
||||
|
||||
|
||||
def assert_runs_in_hwd_thread():
|
||||
if not threading.current_thread().name.startswith("hwd_comms_thread"):
|
||||
raise Exception("must only be called from HWD communication thread")
|
||||
|
||||
|
||||
class DeviceMgr(ThreadJob):
|
||||
@@ -384,24 +411,11 @@ class DeviceMgr(ThreadJob):
|
||||
self._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase]
|
||||
# Custom enumerate functions for devices we don't know about.
|
||||
self._enumerate_func = set() # Needs self.lock.
|
||||
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
|
||||
self._scan_lock = threading.RLock()
|
||||
|
||||
self.lock = threading.RLock()
|
||||
self.hid_lock = _hid_lock
|
||||
|
||||
self.config = config
|
||||
|
||||
global _hid_executor
|
||||
if _hid_executor is None:
|
||||
_hid_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1,
|
||||
thread_name_prefix='hid_enumerate_thread')
|
||||
|
||||
def with_scan_lock(func):
|
||||
def func_wrapper(self: 'DeviceMgr', *args, **kwargs):
|
||||
with self._scan_lock:
|
||||
return func(self, *args, **kwargs)
|
||||
return func_wrapper
|
||||
|
||||
def thread_jobs(self):
|
||||
# Thread job to handle device timeouts
|
||||
return [self]
|
||||
@@ -423,6 +437,7 @@ class DeviceMgr(ThreadJob):
|
||||
with self.lock:
|
||||
self._enumerate_func.add(func)
|
||||
|
||||
@runs_in_hwd_thread
|
||||
def create_client(self, device: 'Device', handler: Optional['HardwareHandlerBase'],
|
||||
plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']:
|
||||
# Get from cache first
|
||||
@@ -452,7 +467,7 @@ class DeviceMgr(ThreadJob):
|
||||
if xpub not in self.xpub_ids:
|
||||
return
|
||||
_id = self.xpub_ids.pop(xpub)
|
||||
self._close_client(_id)
|
||||
self._close_client(_id)
|
||||
|
||||
def unpair_id(self, id_):
|
||||
xpub = self.xpub_by_id(id_)
|
||||
@@ -462,8 +477,9 @@ class DeviceMgr(ThreadJob):
|
||||
self._close_client(id_)
|
||||
|
||||
def _close_client(self, id_):
|
||||
client = self._client_by_id(id_)
|
||||
self.clients.pop(client, None)
|
||||
with self.lock:
|
||||
client = self._client_by_id(id_)
|
||||
self.clients.pop(client, None)
|
||||
if client:
|
||||
client.close()
|
||||
|
||||
@@ -486,7 +502,7 @@ class DeviceMgr(ThreadJob):
|
||||
self.scan_devices()
|
||||
return self._client_by_id(id_)
|
||||
|
||||
@with_scan_lock
|
||||
@runs_in_hwd_thread
|
||||
def client_for_keystore(self, plugin: 'HW_PluginBase', handler: Optional['HardwareHandlerBase'],
|
||||
keystore: 'Hardware_KeyStore',
|
||||
force_pair: bool, *,
|
||||
@@ -655,25 +671,15 @@ class DeviceMgr(ThreadJob):
|
||||
# note: updated label/soft_device_id will be saved after pairing succeeds
|
||||
return info
|
||||
|
||||
@with_scan_lock
|
||||
@runs_in_hwd_thread
|
||||
def _scan_devices_with_hid(self) -> List['Device']:
|
||||
try:
|
||||
import hid
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
def hid_enumerate():
|
||||
with self.hid_lock:
|
||||
return hid.enumerate(0, 0)
|
||||
|
||||
hid_list_fut = _hid_executor.submit(hid_enumerate)
|
||||
try:
|
||||
hid_list = hid_list_fut.result()
|
||||
except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
|
||||
return []
|
||||
|
||||
devices = []
|
||||
for d in hid_list:
|
||||
for d in hid.enumerate(0, 0):
|
||||
product_key = (d['vendor_id'], d['product_id'])
|
||||
if product_key in self._recognised_hardware:
|
||||
plugin = self._recognised_hardware[product_key]
|
||||
@@ -681,7 +687,7 @@ class DeviceMgr(ThreadJob):
|
||||
devices.append(device)
|
||||
return devices
|
||||
|
||||
@with_scan_lock
|
||||
@runs_in_hwd_thread
|
||||
@profiler
|
||||
def scan_devices(self) -> Sequence['Device']:
|
||||
self.logger.info("scanning devices...")
|
||||
@@ -693,10 +699,8 @@ class DeviceMgr(ThreadJob):
|
||||
with self.lock:
|
||||
enumerate_funcs = list(self._enumerate_func)
|
||||
for f in enumerate_funcs:
|
||||
# custom enumerate functions might use hidapi, so use hid thread to be safe
|
||||
new_devices_fut = _hid_executor.submit(f)
|
||||
try:
|
||||
new_devices = new_devices_fut.result()
|
||||
new_devices = f()
|
||||
except BaseException as e:
|
||||
self.logger.error('custom device enum failed. func {}, error {}'
|
||||
.format(str(f), repr(e)))
|
||||
|
||||
Reference in New Issue
Block a user