1
0

plugins: structure plugin storage in wallet

store all plugin data by plugin name in a root dictionary `plugin_data`
inside the wallet db so that plugin data can get deleted again.
Prunes the data of plugins from the wallet db on wallet stop if the
plugin is not installed anymore.
This commit is contained in:
f321x
2025-05-05 18:16:29 +02:00
parent f25ddbc8f9
commit e80551192b
9 changed files with 47 additions and 23 deletions

View File

@@ -189,6 +189,11 @@ class StoredDict(dict):
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
return r
def setdefault(self, key, default = None, /):
if key not in self:
self.__setitem__(key, default)
return self[key]
class StoredList(list):

View File

@@ -654,6 +654,10 @@ class BasePlugin(Logger):
def read_file(self, filename: str) -> bytes:
return self.parent.read_file(self.name, filename)
def get_storage(self, wallet: 'Abstract_Wallet') -> dict:
"""Returns a dict which is persisted in the per-wallet database."""
plugin_storage = wallet.db.get_plugin_storage()
return plugin_storage.setdefault(self.name, {})
class DeviceUnpairableError(UserFacingException): pass
class HardwarePluginLibraryUnavailable(Exception): pass

View File

@@ -25,8 +25,6 @@ if TYPE_CHECKING:
from aiohttp_socks import ProxyConnector
STORAGE_NAME = 'nwc_plugin'
class NWCServerPlugin(BasePlugin):
URI_SCHEME = 'nostr+walletconnect://'
@@ -47,10 +45,10 @@ class NWCServerPlugin(BasePlugin):
if self.initialized:
# this might be called for several wallets. only use one.
return
storage = self.get_plugin_storage(wallet)
self.connections = storage['connections']
storage = self.get_storage(wallet)
self.connections = storage.setdefault('connections', {})
self.delete_expired_connections()
self.nwc_server = NWCServer(self.config, wallet, self.taskgroup)
self.nwc_server = NWCServer(self.config, wallet, self.taskgroup, self.connections)
asyncio.run_coroutine_threadsafe(self.taskgroup.spawn(self.nwc_server.run()), get_asyncio_loop())
self.initialized = True
@@ -67,13 +65,6 @@ class NWCServerPlugin(BasePlugin):
)
self.logger.debug(f"NWCServerPlugin closed, stopping taskgroup")
@staticmethod
def get_plugin_storage(wallet: 'Abstract_Wallet') -> dict:
storage = wallet.db.get_dict(STORAGE_NAME)
if 'connections' not in storage:
storage['connections'] = {}
return storage
def delete_expired_connections(self):
if self.connections is None:
return
@@ -167,12 +158,17 @@ class NWCServer(Logger, EventListener):
'notifications']
SUPPORTED_NOTIFICATIONS: list[str] = ["payment_sent", "payment_received"]
def __init__(self, config: 'SimpleConfig', wallet: 'Abstract_Wallet', taskgroup: 'OldTaskGroup'):
def __init__(
self,
config: 'SimpleConfig',
wallet: 'Abstract_Wallet',
taskgroup: 'OldTaskGroup',
connection_storage: dict,
):
Logger.__init__(self)
self.config = config # type: 'SimpleConfig'
self.wallet = wallet # type: 'Abstract_Wallet'
storage = wallet.db.get_dict(STORAGE_NAME) # type: dict
self.connections = storage['connections'] # type: dict[str, dict] # client hex pubkey -> connection data
self.connections = connection_storage # type: dict[str, dict] # client hex pubkey -> connection data
self.relays = config.NOSTR_RELAYS.split(",") or [] # type: List[str]
self.do_stop = False
self.taskgroup = taskgroup # type: 'OldTaskGroup'

View File

@@ -78,7 +78,7 @@ class CosignerWallet(Logger):
KEEP_DELAY = 24*60*60
def __init__(self, wallet: 'Multisig_Wallet'):
def __init__(self, wallet: 'Multisig_Wallet', db_storage: dict):
assert isinstance(wallet, Multisig_Wallet)
self.wallet = wallet
@@ -90,7 +90,7 @@ class CosignerWallet(Logger):
self.pending = asyncio.Event()
self.wallet_uptodate = asyncio.Event()
self.known_events = wallet.db.get_dict('cosigner_events')
self.known_events = db_storage.setdefault('cosigner_events', {})
for k, v in list(self.known_events.items()):
if v < now() - self.KEEP_DELAY:

View File

@@ -117,7 +117,8 @@ class Plugin(PsbtNostrPlugin):
class QmlCosignerWallet(EventListener, CosignerWallet):
def __init__(self, wallet: 'Multisig_Wallet', plugin: 'Plugin'):
CosignerWallet.__init__(self, wallet)
db_storage = plugin.get_storage(wallet)
CosignerWallet.__init__(self, wallet, db_storage)
self.plugin = plugin
self.register_callbacks()

View File

@@ -57,7 +57,7 @@ class Plugin(PsbtNostrPlugin):
return
if wallet.wallet_type == '2fa':
return
self.add_cosigner_wallet(wallet, QtCosignerWallet(wallet, window))
self.add_cosigner_wallet(wallet, QtCosignerWallet(wallet, window, self))
@hook
def on_close_window(self, window):
@@ -83,8 +83,9 @@ class Plugin(PsbtNostrPlugin):
class QtCosignerWallet(EventListener, CosignerWallet):
def __init__(self, wallet: 'Multisig_Wallet', window: 'ElectrumWindow'):
CosignerWallet.__init__(self, wallet)
def __init__(self, wallet: 'Multisig_Wallet', window: 'ElectrumWindow', plugin: 'Plugin'):
db_storage = plugin.get_storage(wallet)
CosignerWallet.__init__(self, wallet, db_storage)
self.window = window
self.obj = QReceiveSignalObject()
self.obj.cosignerReceivedPsbt.connect(self.on_receive)

View File

@@ -4,7 +4,7 @@ import time
import os
import stat
from decimal import Decimal
from typing import Union, Optional, Dict, Sequence, Tuple, Any, Set, Callable
from typing import Union, Optional, Dict, Sequence, Tuple, Any, Set, Callable, AbstractSet
from numbers import Real
from functools import cached_property
@@ -349,6 +349,10 @@ class SimpleConfig(Logger):
def is_plugin_enabled(self, name: str) -> bool:
return bool(self.get(f'plugins.{name}.enabled'))
def get_installed_plugins(self) -> AbstractSet[str]:
"""Returns all plugin names registered in the config."""
return self.get('plugins', {}).keys()
def enable_plugin(self, name: str):
self.set_key(f'plugins.{name}.enabled', True, save=True)

View File

@@ -556,6 +556,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
finally: # even if we get cancelled
if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
self.save_keystore()
self.db.prune_uninstalled_plugin_data(self.config.get_installed_plugins())
self.save_db()
def is_up_to_date(self) -> bool:

View File

@@ -29,7 +29,8 @@ import json
import copy
import threading
from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union
from typing import (Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING,
Union, AbstractSet)
import binascii
import time
from functools import partial
@@ -1725,5 +1726,16 @@ class WalletDB(JsonDB):
if wallet_type in plugin_loaders:
plugin_loaders[wallet_type]()
def get_plugin_storage(self) -> dict:
return self.get_dict('plugin_data')
def prune_uninstalled_plugin_data(self, installed_plugins: AbstractSet[str]) -> None:
"""Remove plugin data for plugins that are not installed anymore."""
plugin_storage = self.get_plugin_storage()
for name in list(plugin_storage.keys()):
if name not in installed_plugins:
plugin_storage.pop(name)
self.logger.info(f"deleting plugin data: {name=}")
def set_keystore_encryption(self, enable):
self.put('use_encryption', enable)