qt TaskThread: implement cancellation of tasks, for cleaner shutdown
fixes https://github.com/spesmilo/electrum/issues/7750 Each task we schedule on `TaskThread` can provide an optional `cancel` method. When stopping `TaskThread`, we call this `cancel` method on all tasks in the queue. If the currently running task does not implement `cancel`, `TaskThread.stop` will block until that task finishes. Note that there is a significant change in behaviour here: `ElectrumWindow.run_coroutine_from_thread` and `ElectrumWindow.pay_lightning_invoice` previously serialised the execution of their coroutines via wallet.thread. This is no longer the case: they will now schedule coroutines immediately. So for example, the GUI now allows trying to pay multiple LN invoices "concurrently".
This commit is contained in:
@@ -2,6 +2,9 @@
|
|||||||
# Distributed under the MIT software license, see the accompanying
|
# Distributed under the MIT software license, see the accompanying
|
||||||
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
|
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
from PyQt5.QtCore import Qt
|
from PyQt5.QtCore import Qt
|
||||||
from PyQt5.QtWidgets import QWidget, QVBoxLayout, QGridLayout, QLabel, QListWidget, QListWidgetItem
|
from PyQt5.QtWidgets import QWidget, QVBoxLayout, QGridLayout, QLabel, QListWidget, QListWidgetItem
|
||||||
|
|
||||||
@@ -29,15 +32,27 @@ class Bip39RecoveryDialog(WindowModalDialog):
|
|||||||
self.content = QVBoxLayout()
|
self.content = QVBoxLayout()
|
||||||
self.content.addWidget(QLabel(_('Scanning common paths for existing accounts...')))
|
self.content.addWidget(QLabel(_('Scanning common paths for existing accounts...')))
|
||||||
vbox.addLayout(self.content)
|
vbox.addLayout(self.content)
|
||||||
|
|
||||||
|
self.thread = TaskThread(self)
|
||||||
|
self.thread.finished.connect(self.deleteLater) # see #3956
|
||||||
|
network = Network.get_instance()
|
||||||
|
coro = account_discovery(network, self.get_account_xpub)
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(coro, network.asyncio_loop)
|
||||||
|
self.thread.add(
|
||||||
|
fut.result,
|
||||||
|
on_success=self.on_recovery_success,
|
||||||
|
on_error=self.on_recovery_error,
|
||||||
|
cancel=fut.cancel,
|
||||||
|
)
|
||||||
|
|
||||||
self.ok_button = OkButton(self)
|
self.ok_button = OkButton(self)
|
||||||
self.ok_button.clicked.connect(self.on_ok_button_click)
|
self.ok_button.clicked.connect(self.on_ok_button_click)
|
||||||
self.ok_button.setEnabled(False)
|
self.ok_button.setEnabled(False)
|
||||||
vbox.addLayout(Buttons(CancelButton(self), self.ok_button))
|
cancel_button = CancelButton(self)
|
||||||
|
cancel_button.clicked.connect(fut.cancel)
|
||||||
|
vbox.addLayout(Buttons(cancel_button, self.ok_button))
|
||||||
self.finished.connect(self.on_finished)
|
self.finished.connect(self.on_finished)
|
||||||
self.show()
|
self.show()
|
||||||
self.thread = TaskThread(self)
|
|
||||||
self.thread.finished.connect(self.deleteLater) # see #3956
|
|
||||||
self.thread.add(self.recovery, self.on_recovery_success, None, self.on_recovery_error)
|
|
||||||
|
|
||||||
def on_finished(self):
|
def on_finished(self):
|
||||||
self.thread.stop()
|
self.thread.stop()
|
||||||
@@ -47,11 +62,6 @@ class Bip39RecoveryDialog(WindowModalDialog):
|
|||||||
account = item.data(self.ROLE_ACCOUNT)
|
account = item.data(self.ROLE_ACCOUNT)
|
||||||
self.on_account_select(account)
|
self.on_account_select(account)
|
||||||
|
|
||||||
def recovery(self):
|
|
||||||
network = Network.get_instance()
|
|
||||||
coroutine = account_discovery(network, self.get_account_xpub)
|
|
||||||
return network.run_from_another_thread(coroutine)
|
|
||||||
|
|
||||||
def on_recovery_success(self, accounts):
|
def on_recovery_success(self, accounts):
|
||||||
self.clear_content()
|
self.clear_content()
|
||||||
if len(accounts) == 0:
|
if len(accounts) == 0:
|
||||||
@@ -67,6 +77,9 @@ class Bip39RecoveryDialog(WindowModalDialog):
|
|||||||
self.content.addWidget(self.list)
|
self.content.addWidget(self.list)
|
||||||
|
|
||||||
def on_recovery_error(self, exc_info):
|
def on_recovery_error(self, exc_info):
|
||||||
|
e = exc_info[1]
|
||||||
|
if isinstance(e, concurrent.futures.CancelledError):
|
||||||
|
return
|
||||||
self.clear_content()
|
self.clear_content()
|
||||||
self.content.addWidget(QLabel(_('Error: Account discovery failed.')))
|
self.content.addWidget(QLabel(_('Error: Account discovery failed.')))
|
||||||
_logger.error(f"recovery error", exc_info=exc_info)
|
_logger.error(f"recovery error", exc_info=exc_info)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from functools import partial
|
|||||||
import queue
|
import queue
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, TYPE_CHECKING, Sequence, List, Union
|
from typing import Optional, TYPE_CHECKING, Sequence, List, Union
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
from PyQt5.QtGui import QPixmap, QKeySequence, QIcon, QCursor, QFont
|
from PyQt5.QtGui import QPixmap, QKeySequence, QIcon, QCursor, QFont
|
||||||
from PyQt5.QtCore import Qt, QRect, QStringListModel, QSize, pyqtSignal, QPoint
|
from PyQt5.QtCore import Qt, QRect, QStringListModel, QSize, pyqtSignal, QPoint
|
||||||
@@ -318,16 +319,18 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
|
|||||||
self._update_check_thread.start()
|
self._update_check_thread.start()
|
||||||
|
|
||||||
def run_coroutine_from_thread(self, coro, on_result=None):
|
def run_coroutine_from_thread(self, coro, on_result=None):
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||||
def task():
|
def task():
|
||||||
try:
|
try:
|
||||||
f = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
r = fut.result()
|
||||||
r = f.result()
|
|
||||||
if on_result:
|
if on_result:
|
||||||
on_result(r)
|
on_result(r)
|
||||||
|
except concurrent.futures.CancelledError:
|
||||||
|
self.logger.info(f"wallet.thread coro got cancelled: {coro}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.exception("exception in coro scheduled via window.wallet")
|
self.logger.exception("exception in coro scheduled via window.wallet")
|
||||||
self.show_error_signal.emit(str(e))
|
self.show_error_signal.emit(repr(e))
|
||||||
self.wallet.thread.add(task)
|
self.wallet.thread.add(task, cancel=fut.cancel)
|
||||||
|
|
||||||
def on_fx_history(self):
|
def on_fx_history(self):
|
||||||
self.history_model.refresh('fx_history')
|
self.history_model.refresh('fx_history')
|
||||||
@@ -400,7 +403,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
|
|||||||
|
|
||||||
def on_error(self, exc_info):
|
def on_error(self, exc_info):
|
||||||
e = exc_info[1]
|
e = exc_info[1]
|
||||||
if isinstance(e, UserCancelled):
|
if isinstance(e, (UserCancelled, concurrent.futures.CancelledError)):
|
||||||
pass
|
pass
|
||||||
elif isinstance(e, UserFacingException):
|
elif isinstance(e, UserFacingException):
|
||||||
self.show_error(str(e))
|
self.show_error(str(e))
|
||||||
@@ -1592,11 +1595,8 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
|
|||||||
if not self.question(msg):
|
if not self.question(msg):
|
||||||
return
|
return
|
||||||
self.save_pending_invoice()
|
self.save_pending_invoice()
|
||||||
def task():
|
coro = self.wallet.lnworker.pay_invoice(invoice, amount_msat=amount_msat, attempts=LN_NUM_PAYMENT_ATTEMPTS)
|
||||||
coro = self.wallet.lnworker.pay_invoice(invoice, amount_msat=amount_msat, attempts=LN_NUM_PAYMENT_ATTEMPTS)
|
self.run_coroutine_from_thread(coro)
|
||||||
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
|
||||||
return fut.result()
|
|
||||||
self.wallet.thread.add(task)
|
|
||||||
|
|
||||||
def on_request_status(self, wallet, key, status):
|
def on_request_status(self, wallet, key, status):
|
||||||
if wallet != self.wallet:
|
if wallet != self.wallet:
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from PyQt5.QtWidgets import (QPushButton, QLabel, QMessageBox, QHBoxLayout,
|
|||||||
from electrum.i18n import _, languages
|
from electrum.i18n import _, languages
|
||||||
from electrum.util import FileImportFailed, FileExportFailed, make_aiohttp_session, resource_path
|
from electrum.util import FileImportFailed, FileExportFailed, make_aiohttp_session, resource_path
|
||||||
from electrum.invoices import PR_UNPAID, PR_PAID, PR_EXPIRED, PR_INFLIGHT, PR_UNKNOWN, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED
|
from electrum.invoices import PR_UNPAID, PR_PAID, PR_EXPIRED, PR_INFLIGHT, PR_UNKNOWN, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED
|
||||||
|
from electrum.logging import Logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .main_window import ElectrumWindow
|
from .main_window import ElectrumWindow
|
||||||
@@ -902,7 +903,7 @@ class PasswordLineEdit(QLineEdit):
|
|||||||
super().clear()
|
super().clear()
|
||||||
|
|
||||||
|
|
||||||
class TaskThread(QThread):
|
class TaskThread(QThread, Logger):
|
||||||
'''Thread that runs background tasks. Callbacks are guaranteed
|
'''Thread that runs background tasks. Callbacks are guaranteed
|
||||||
to happen in the context of its parent.'''
|
to happen in the context of its parent.'''
|
||||||
|
|
||||||
@@ -911,24 +912,35 @@ class TaskThread(QThread):
|
|||||||
cb_success: Optional[Callable]
|
cb_success: Optional[Callable]
|
||||||
cb_done: Optional[Callable]
|
cb_done: Optional[Callable]
|
||||||
cb_error: Optional[Callable]
|
cb_error: Optional[Callable]
|
||||||
|
cancel: Optional[Callable] = None
|
||||||
|
|
||||||
doneSig = pyqtSignal(object, object, object)
|
doneSig = pyqtSignal(object, object, object)
|
||||||
|
|
||||||
def __init__(self, parent, on_error=None):
|
def __init__(self, parent, on_error=None):
|
||||||
super(TaskThread, self).__init__(parent)
|
QThread.__init__(self, parent)
|
||||||
|
Logger.__init__(self)
|
||||||
self.on_error = on_error
|
self.on_error = on_error
|
||||||
self.tasks = queue.Queue()
|
self.tasks = queue.Queue()
|
||||||
|
self._cur_task = None # type: Optional[TaskThread.Task]
|
||||||
|
self._stopping = False
|
||||||
self.doneSig.connect(self.on_done)
|
self.doneSig.connect(self.on_done)
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
def add(self, task, on_success=None, on_done=None, on_error=None):
|
def add(self, task, on_success=None, on_done=None, on_error=None, *, cancel=None):
|
||||||
|
if self._stopping:
|
||||||
|
self.logger.warning(f"stopping or already stopped but tried to add new task.")
|
||||||
|
return
|
||||||
on_error = on_error or self.on_error
|
on_error = on_error or self.on_error
|
||||||
self.tasks.put(TaskThread.Task(task, on_success, on_done, on_error))
|
task_ = TaskThread.Task(task, on_success, on_done, on_error, cancel=cancel)
|
||||||
|
self.tasks.put(task_)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while True:
|
while True:
|
||||||
|
if self._stopping:
|
||||||
|
break
|
||||||
task = self.tasks.get() # type: TaskThread.Task
|
task = self.tasks.get() # type: TaskThread.Task
|
||||||
if not task:
|
self._cur_task = task
|
||||||
|
if not task or self._stopping:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
result = task.task()
|
result = task.task()
|
||||||
@@ -944,7 +956,21 @@ class TaskThread(QThread):
|
|||||||
cb_result(result)
|
cb_result(result)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.tasks.put(None)
|
self._stopping = True
|
||||||
|
# try to cancel currently running task now.
|
||||||
|
# if the task does not implement "cancel", we will have to wait until it finishes.
|
||||||
|
task = self._cur_task
|
||||||
|
if task and task.cancel:
|
||||||
|
task.cancel()
|
||||||
|
# cancel the remaining tasks in the queue
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
task = self.tasks.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
if task and task.cancel:
|
||||||
|
task.cancel()
|
||||||
|
self.tasks.put(None) # in case the thread is still waiting on the queue
|
||||||
self.exit()
|
self.exit()
|
||||||
self.wait()
|
self.wait()
|
||||||
|
|
||||||
|
|||||||
@@ -1106,7 +1106,9 @@ class LNWallet(LNWorker):
|
|||||||
self.save_payment_info(info)
|
self.save_payment_info(info)
|
||||||
self.wallet.set_label(key, lnaddr.get_description())
|
self.wallet.set_label(key, lnaddr.get_description())
|
||||||
|
|
||||||
|
self.logger.info(f"pay_invoice starting session for RHASH={payment_hash.hex()}")
|
||||||
self.set_invoice_status(key, PR_INFLIGHT)
|
self.set_invoice_status(key, PR_INFLIGHT)
|
||||||
|
success = False
|
||||||
try:
|
try:
|
||||||
await self.pay_to_node(
|
await self.pay_to_node(
|
||||||
node_pubkey=invoice_pubkey,
|
node_pubkey=invoice_pubkey,
|
||||||
@@ -1121,8 +1123,9 @@ class LNWallet(LNWorker):
|
|||||||
success = True
|
success = True
|
||||||
except PaymentFailure as e:
|
except PaymentFailure as e:
|
||||||
self.logger.info(f'payment failure: {e!r}')
|
self.logger.info(f'payment failure: {e!r}')
|
||||||
success = False
|
|
||||||
reason = str(e)
|
reason = str(e)
|
||||||
|
finally:
|
||||||
|
self.logger.info(f"pay_invoice ending session for RHASH={payment_hash.hex()}. {success=}")
|
||||||
if success:
|
if success:
|
||||||
self.set_invoice_status(key, PR_PAID)
|
self.set_invoice_status(key, PR_PAID)
|
||||||
util.trigger_callback('payment_succeeded', self.wallet, key)
|
util.trigger_callback('payment_succeeded', self.wallet, key)
|
||||||
|
|||||||
Reference in New Issue
Block a user