1
0

lnworker: fix silent TypeError in _calc_routing_hints_for_invoice

This commit is contained in:
SomberNight
2019-08-16 22:03:20 +02:00
committed by ThomasV
parent 02681c6664
commit ba431495db
3 changed files with 8 additions and 5 deletions

View File

@@ -545,7 +545,7 @@ class ChannelDB(SqlDB):
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
return self._policies.get((node_id, short_channel_id)) return self._policies.get((node_id, short_channel_id))
def get_channel_info(self, channel_id: bytes): def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
return self._channels.get(channel_id) return self._channels.get(channel_id)
def get_channels_for_node(self, node_id) -> Set[bytes]: def get_channels_for_node(self, node_id) -> Set[bytes]:

View File

@@ -814,7 +814,7 @@ class LNWallet(LNWorker):
amount_msat = int(decoded_invoice.amount * COIN * 1000) amount_msat = int(decoded_invoice.amount * COIN * 1000)
invoice_pubkey = decoded_invoice.pubkey.serialize() invoice_pubkey = decoded_invoice.pubkey.serialize()
# use 'r' field from invoice # use 'r' field from invoice
route = None # type: List[RouteEdge] route = None # type: Optional[List[RouteEdge]]
# only want 'r' tags # only want 'r' tags
r_tags = list(filter(lambda x: x[0] == 'r', decoded_invoice.tags)) r_tags = list(filter(lambda x: x[0] == 'r', decoded_invoice.tags))
# strip the tag type, it's implicitly 'r' now # strip the tag type, it's implicitly 'r' now
@@ -979,7 +979,7 @@ class LNWallet(LNWorker):
cltv_expiry_delta = 1 # lnd won't even try with zero cltv_expiry_delta = 1 # lnd won't even try with zero
missing_info = True missing_info = True
if channel_info: if channel_info:
policy = self.channel_db.get_policy_for_node(channel_info, chan.node_id) policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
if policy: if policy:
fee_base_msat = policy.fee_base_msat fee_base_msat = policy.fee_base_msat
fee_proportional_millionths = policy.fee_proportional_millionths fee_proportional_millionths = policy.fee_proportional_millionths

View File

@@ -33,7 +33,7 @@ import json
import sys import sys
import ipaddress import ipaddress
import asyncio import asyncio
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING
import traceback import traceback
import dns import dns
@@ -60,6 +60,9 @@ from .simple_config import SimpleConfig
from .i18n import _ from .i18n import _
from .logging import get_logger, Logger from .logging import get_logger, Logger
if TYPE_CHECKING:
from .channel_db import ChannelDB
_logger = get_logger(__name__) _logger = get_logger(__name__)
@@ -307,7 +310,7 @@ class Network(Logger):
self.lngossip = lnworker.LNGossip(self) self.lngossip = lnworker.LNGossip(self)
self.local_watchtower = lnwatcher.WatchTower(self) if self.config.get('local_watchtower', True) else None self.local_watchtower = lnwatcher.WatchTower(self) if self.config.get('local_watchtower', True) else None
else: else:
self.channel_db = None self.channel_db = None # type: Optional[ChannelDB]
self.lngossip = None self.lngossip = None
self.local_watchtower = None self.local_watchtower = None