1
0

plugin commands:

- make plugin commands start with plugin name + underscore
 - plugin_name must be passed to the plugin_command decorator
 - fixes:
    - remove plugin_commands (unneeded)
    - func_wrapper must await func()
    - setattr(Commands, name, func_wrapper)
 - add push/pull commands to labels plugin
This commit is contained in:
ThomasV
2025-03-16 11:12:04 +01:00
parent cb39737a39
commit a474b8674d
3 changed files with 43 additions and 34 deletions

View File

@@ -25,12 +25,10 @@
import io import io
import sys import sys
import datetime import datetime
import copy
import argparse import argparse
import json import json
import ast import ast
import base64 import base64
import operator
import asyncio import asyncio
import inspect import inspect
from collections import defaultdict from collections import defaultdict
@@ -43,7 +41,6 @@ import os
import electrum_ecc as ecc import electrum_ecc as ecc
from . import util from . import util
from . import keystore
from .lnmsg import OnionWireSerializer from .lnmsg import OnionWireSerializer
from .logging import Logger from .logging import Logger
from .onion_message import create_blinded_path, send_onion_message_to from .onion_message import create_blinded_path, send_onion_message_to
@@ -71,7 +68,6 @@ from .version import ELECTRUM_VERSION
from .simple_config import SimpleConfig from .simple_config import SimpleConfig
from .invoices import Invoice from .invoices import Invoice
from .fee_policy import FeePolicy from .fee_policy import FeePolicy
from . import submarine_swaps
from . import GuiImportError from . import GuiImportError
from . import crypto from . import crypto
from . import constants from . import constants
@@ -83,7 +79,6 @@ if TYPE_CHECKING:
known_commands = {} # type: Dict[str, Command] known_commands = {} # type: Dict[str, Command]
plugin_commands = defaultdict(set) # type: Dict[str, set[str]] # plugin_name -> set(command_name)
class NotSynchronizedException(UserFacingException): class NotSynchronizedException(UserFacingException):
@@ -104,8 +99,8 @@ def format_satoshis(x):
class Command: class Command:
def __init__(self, func, s): def __init__(self, func, name, s):
self.name = func.__name__ self.name = name
self.requires_network = 'n' in s self.requires_network = 'n' in s
self.requires_wallet = 'w' in s self.requires_wallet = 'w' in s
self.requires_password = 'p' in s self.requires_password = 'p' in s
@@ -137,17 +132,20 @@ class Command:
def command(s): def command(s):
def decorator(func): def decorator(func):
global known_commands global known_commands
name = func.__name__
if hasattr(func, '__wrapped__'): # plugin command function if hasattr(func, '__wrapped__'):
known_commands[name] = Command(func.__wrapped__, s) # plugin command function
else: # regular command function name = func.plugin_name + '_' + func.__name__
known_commands[name] = Command(func, s) known_commands[name] = Command(func.__wrapped__, name, s)
else:
# regular command function
name = func.__name__
known_commands[name] = Command(func, name, s)
@wraps(func) @wraps(func)
async def func_wrapper(*args, **kwargs): async def func_wrapper(*args, **kwargs):
cmd_runner = args[0] # type: Commands cmd_runner = args[0] # type: Commands
cmd = known_commands[func.__name__] # type: Command cmd = known_commands[name] # type: Command
password = kwargs.get('password') password = kwargs.get('password')
daemon = cmd_runner.daemon daemon = cmd_runner.daemon
if daemon: if daemon:
@@ -1563,34 +1561,21 @@ class Commands(Logger):
return encoded_blinded_path.hex() return encoded_blinded_path.hex()
def plugin_command(s, plugin_name = None): def plugin_command(s, plugin_name):
"""Decorator to register a cli command inside a plugin. To be used within a commands.py file """Decorator to register a cli command inside a plugin. To be used within a commands.py file
in the plugins root.""" in the plugins root."""
def decorator(func): def decorator(func):
global known_commands global known_commands
global plugin_commands func.plugin_name = plugin_name
name = func.__name__ name = plugin_name + '_' + func.__name__
if name in known_commands or hasattr(Commands, name): if name in known_commands or hasattr(Commands, name):
raise Exception(f"Plugins should not override other commands: {name}") raise Exception(f"Plugins should not override other commands: {name}")
assert name.startswith('plugin_'), f"Plugin command names should start with 'plugin_': {name}"
assert asyncio.iscoroutinefunction(func), f"Plugin commands must be a coroutine: {name}" assert asyncio.iscoroutinefunction(func), f"Plugin commands must be a coroutine: {name}"
if not plugin_name:
# this is way slower than providing the plugin name, so it should only be considered a fallback
caller_frame = sys._getframe(1)
module_name = caller_frame.f_globals.get('__name__')
plugin_name_from_frame = module_name.rsplit('.', 2)[-2]
# reassigning to plugin_name doesn't work here
plugin_commands[plugin_name_from_frame].add(name)
else:
plugin_commands[plugin_name].add(name)
setattr(Commands, name, func)
@command(s) @command(s)
@wraps(func) @wraps(func)
async def func_wrapper(*args, **kwargs): async def func_wrapper(*args, **kwargs):
return func(*args, **kwargs) return await func(*args, **kwargs)
setattr(Commands, name, func_wrapper)
return func_wrapper return func_wrapper
return decorator return decorator

View File

@@ -0,0 +1,22 @@
from electrum.commands import plugin_command
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .labels import LabelsPlugin
from electrum.commands import Commands
plugin_name = "labels"
@plugin_command('w', plugin_name)
async def push(self: 'Commands', wallet=None) -> int:
""" push labels to server """
plugin: 'LabelsPlugin' = self.daemon._plugins.get_plugin(plugin_name)
return await plugin.push_thread(wallet)
@plugin_command('w', plugin_name)
async def pull(self: 'Commands', wallet=None) -> int:
""" pull labels from server """
assert wallet is not None
plugin: 'LabelsPlugin' = self.daemon._plugins.get_plugin(plugin_name)
return await plugin.pull_thread(wallet, force=False)

View File

@@ -108,7 +108,7 @@ class LabelsPlugin(BasePlugin):
except Exception as e: except Exception as e:
raise Exception('Could not decode: ' + await result.text()) from e raise Exception('Could not decode: ' + await result.text()) from e
async def push_thread(self, wallet: 'Abstract_Wallet'): async def push_thread(self, wallet: 'Abstract_Wallet') -> int:
wallet_data = self.wallets.get(wallet, None) wallet_data = self.wallets.get(wallet, None)
if not wallet_data: if not wallet_data:
raise Exception('Wallet {} not loaded'.format(wallet)) raise Exception('Wallet {} not loaded'.format(wallet))
@@ -126,8 +126,9 @@ class LabelsPlugin(BasePlugin):
bundle["labels"].append({'encryptedLabel': encoded_value, bundle["labels"].append({'encryptedLabel': encoded_value,
'externalId': encoded_key}) 'externalId': encoded_key})
await self.do_post("/labels", bundle) await self.do_post("/labels", bundle)
return len(bundle['labels'])
async def pull_thread(self, wallet: 'Abstract_Wallet', force: bool): async def pull_thread(self, wallet: 'Abstract_Wallet', force: bool) -> int:
wallet_data = self.wallets.get(wallet, None) wallet_data = self.wallets.get(wallet, None)
if not wallet_data: if not wallet_data:
raise Exception('Wallet {} not loaded'.format(wallet)) raise Exception('Wallet {} not loaded'.format(wallet))
@@ -140,7 +141,7 @@ class LabelsPlugin(BasePlugin):
raise ErrorConnectingServer(e) from e raise ErrorConnectingServer(e) from e
if response["labels"] is None or len(response["labels"]) == 0: if response["labels"] is None or len(response["labels"]) == 0:
self.logger.info('no new labels') self.logger.info('no new labels')
return return 0
self.logger.info(f'received {len(response["labels"])} labels') self.logger.info(f'received {len(response["labels"])} labels')
result = {} result = {}
@@ -165,6 +166,7 @@ class LabelsPlugin(BasePlugin):
self.set_nonce(wallet, response["nonce"] + 1) self.set_nonce(wallet, response["nonce"] + 1)
util.trigger_callback('labels_received', wallet, result) util.trigger_callback('labels_received', wallet, result)
self.on_pulled(wallet) self.on_pulled(wallet)
return len(result)
def on_pulled(self, wallet: 'Abstract_Wallet') -> None: def on_pulled(self, wallet: 'Abstract_Wallet') -> None:
pass pass