lnworker: fix silent TypeError in _calc_routing_hints_for_invoice
This commit is contained in:
@@ -545,7 +545,7 @@ class ChannelDB(SqlDB):
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
||||
return self._policies.get((node_id, short_channel_id))
|
||||
|
||||
def get_channel_info(self, channel_id: bytes):
|
||||
def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
||||
|
||||
@@ -814,7 +814,7 @@ class LNWallet(LNWorker):
|
||||
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||
# use 'r' field from invoice
|
||||
route = None # type: List[RouteEdge]
|
||||
route = None # type: Optional[List[RouteEdge]]
|
||||
# only want 'r' tags
|
||||
r_tags = list(filter(lambda x: x[0] == 'r', decoded_invoice.tags))
|
||||
# strip the tag type, it's implicitly 'r' now
|
||||
@@ -979,7 +979,7 @@ class LNWallet(LNWorker):
|
||||
cltv_expiry_delta = 1 # lnd won't even try with zero
|
||||
missing_info = True
|
||||
if channel_info:
|
||||
policy = self.channel_db.get_policy_for_node(channel_info, chan.node_id)
|
||||
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
|
||||
if policy:
|
||||
fee_base_msat = policy.fee_base_msat
|
||||
fee_proportional_millionths = policy.fee_proportional_millionths
|
||||
|
||||
@@ -33,7 +33,7 @@ import json
|
||||
import sys
|
||||
import ipaddress
|
||||
import asyncio
|
||||
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple
|
||||
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING
|
||||
import traceback
|
||||
|
||||
import dns
|
||||
@@ -60,6 +60,9 @@ from .simple_config import SimpleConfig
|
||||
from .i18n import _
|
||||
from .logging import get_logger, Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .channel_db import ChannelDB
|
||||
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
@@ -307,7 +310,7 @@ class Network(Logger):
|
||||
self.lngossip = lnworker.LNGossip(self)
|
||||
self.local_watchtower = lnwatcher.WatchTower(self) if self.config.get('local_watchtower', True) else None
|
||||
else:
|
||||
self.channel_db = None
|
||||
self.channel_db = None # type: Optional[ChannelDB]
|
||||
self.lngossip = None
|
||||
self.local_watchtower = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user