plugin commands:
- make plugin commands start with plugin name + underscore
- plugin_name must be passed to the plugin_command decorator
- fixes:
- remove plugin_commands (unneeded)
- func_wrapper must await func()
- setattr(Commands, name, func_wrapper)
- add push/pull commands to labels plugin
This commit is contained in:
@@ -25,12 +25,10 @@
|
||||
import io
|
||||
import sys
|
||||
import datetime
|
||||
import copy
|
||||
import argparse
|
||||
import json
|
||||
import ast
|
||||
import base64
|
||||
import operator
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
@@ -43,7 +41,6 @@ import os
|
||||
import electrum_ecc as ecc
|
||||
|
||||
from . import util
|
||||
from . import keystore
|
||||
from .lnmsg import OnionWireSerializer
|
||||
from .logging import Logger
|
||||
from .onion_message import create_blinded_path, send_onion_message_to
|
||||
@@ -71,7 +68,6 @@ from .version import ELECTRUM_VERSION
|
||||
from .simple_config import SimpleConfig
|
||||
from .invoices import Invoice
|
||||
from .fee_policy import FeePolicy
|
||||
from . import submarine_swaps
|
||||
from . import GuiImportError
|
||||
from . import crypto
|
||||
from . import constants
|
||||
@@ -83,7 +79,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
known_commands = {} # type: Dict[str, Command]
|
||||
plugin_commands = defaultdict(set) # type: Dict[str, set[str]] # plugin_name -> set(command_name)
|
||||
|
||||
|
||||
class NotSynchronizedException(UserFacingException):
|
||||
@@ -104,8 +99,8 @@ def format_satoshis(x):
|
||||
|
||||
|
||||
class Command:
|
||||
def __init__(self, func, s):
|
||||
self.name = func.__name__
|
||||
def __init__(self, func, name, s):
|
||||
self.name = name
|
||||
self.requires_network = 'n' in s
|
||||
self.requires_wallet = 'w' in s
|
||||
self.requires_password = 'p' in s
|
||||
@@ -137,17 +132,20 @@ class Command:
|
||||
def command(s):
|
||||
def decorator(func):
|
||||
global known_commands
|
||||
name = func.__name__
|
||||
|
||||
if hasattr(func, '__wrapped__'): # plugin command function
|
||||
known_commands[name] = Command(func.__wrapped__, s)
|
||||
else: # regular command function
|
||||
known_commands[name] = Command(func, s)
|
||||
if hasattr(func, '__wrapped__'):
|
||||
# plugin command function
|
||||
name = func.plugin_name + '_' + func.__name__
|
||||
known_commands[name] = Command(func.__wrapped__, name, s)
|
||||
else:
|
||||
# regular command function
|
||||
name = func.__name__
|
||||
known_commands[name] = Command(func, name, s)
|
||||
|
||||
@wraps(func)
|
||||
async def func_wrapper(*args, **kwargs):
|
||||
cmd_runner = args[0] # type: Commands
|
||||
cmd = known_commands[func.__name__] # type: Command
|
||||
cmd = known_commands[name] # type: Command
|
||||
password = kwargs.get('password')
|
||||
daemon = cmd_runner.daemon
|
||||
if daemon:
|
||||
@@ -1563,34 +1561,21 @@ class Commands(Logger):
|
||||
|
||||
return encoded_blinded_path.hex()
|
||||
|
||||
def plugin_command(s, plugin_name = None):
|
||||
def plugin_command(s, plugin_name):
|
||||
"""Decorator to register a cli command inside a plugin. To be used within a commands.py file
|
||||
in the plugins root."""
|
||||
def decorator(func):
|
||||
global known_commands
|
||||
global plugin_commands
|
||||
name = func.__name__
|
||||
func.plugin_name = plugin_name
|
||||
name = plugin_name + '_' + func.__name__
|
||||
if name in known_commands or hasattr(Commands, name):
|
||||
raise Exception(f"Plugins should not override other commands: {name}")
|
||||
assert name.startswith('plugin_'), f"Plugin command names should start with 'plugin_': {name}"
|
||||
assert asyncio.iscoroutinefunction(func), f"Plugin commands must be a coroutine: {name}"
|
||||
|
||||
if not plugin_name:
|
||||
# this is way slower than providing the plugin name, so it should only be considered a fallback
|
||||
caller_frame = sys._getframe(1)
|
||||
module_name = caller_frame.f_globals.get('__name__')
|
||||
plugin_name_from_frame = module_name.rsplit('.', 2)[-2]
|
||||
# reassigning to plugin_name doesn't work here
|
||||
plugin_commands[plugin_name_from_frame].add(name)
|
||||
else:
|
||||
plugin_commands[plugin_name].add(name)
|
||||
|
||||
setattr(Commands, name, func)
|
||||
|
||||
@command(s)
|
||||
@wraps(func)
|
||||
async def func_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return await func(*args, **kwargs)
|
||||
setattr(Commands, name, func_wrapper)
|
||||
return func_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
22
electrum/plugins/labels/commands.py
Normal file
22
electrum/plugins/labels/commands.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from electrum.commands import plugin_command
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .labels import LabelsPlugin
|
||||
from electrum.commands import Commands
|
||||
|
||||
plugin_name = "labels"
|
||||
|
||||
@plugin_command('w', plugin_name)
|
||||
async def push(self: 'Commands', wallet=None) -> int:
|
||||
""" push labels to server """
|
||||
plugin: 'LabelsPlugin' = self.daemon._plugins.get_plugin(plugin_name)
|
||||
return await plugin.push_thread(wallet)
|
||||
|
||||
|
||||
@plugin_command('w', plugin_name)
|
||||
async def pull(self: 'Commands', wallet=None) -> int:
|
||||
""" pull labels from server """
|
||||
assert wallet is not None
|
||||
plugin: 'LabelsPlugin' = self.daemon._plugins.get_plugin(plugin_name)
|
||||
return await plugin.pull_thread(wallet, force=False)
|
||||
@@ -108,7 +108,7 @@ class LabelsPlugin(BasePlugin):
|
||||
except Exception as e:
|
||||
raise Exception('Could not decode: ' + await result.text()) from e
|
||||
|
||||
async def push_thread(self, wallet: 'Abstract_Wallet'):
|
||||
async def push_thread(self, wallet: 'Abstract_Wallet') -> int:
|
||||
wallet_data = self.wallets.get(wallet, None)
|
||||
if not wallet_data:
|
||||
raise Exception('Wallet {} not loaded'.format(wallet))
|
||||
@@ -126,8 +126,9 @@ class LabelsPlugin(BasePlugin):
|
||||
bundle["labels"].append({'encryptedLabel': encoded_value,
|
||||
'externalId': encoded_key})
|
||||
await self.do_post("/labels", bundle)
|
||||
return len(bundle['labels'])
|
||||
|
||||
async def pull_thread(self, wallet: 'Abstract_Wallet', force: bool):
|
||||
async def pull_thread(self, wallet: 'Abstract_Wallet', force: bool) -> int:
|
||||
wallet_data = self.wallets.get(wallet, None)
|
||||
if not wallet_data:
|
||||
raise Exception('Wallet {} not loaded'.format(wallet))
|
||||
@@ -140,7 +141,7 @@ class LabelsPlugin(BasePlugin):
|
||||
raise ErrorConnectingServer(e) from e
|
||||
if response["labels"] is None or len(response["labels"]) == 0:
|
||||
self.logger.info('no new labels')
|
||||
return
|
||||
return 0
|
||||
|
||||
self.logger.info(f'received {len(response["labels"])} labels')
|
||||
result = {}
|
||||
@@ -165,6 +166,7 @@ class LabelsPlugin(BasePlugin):
|
||||
self.set_nonce(wallet, response["nonce"] + 1)
|
||||
util.trigger_callback('labels_received', wallet, result)
|
||||
self.on_pulled(wallet)
|
||||
return len(result)
|
||||
|
||||
def on_pulled(self, wallet: 'Abstract_Wallet') -> None:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user