1
0

channel_db: don't wait for load_data to finish if stopping

ChannelDB.load_data() takes ~15 seconds. Previously if the user tried
to close the program while load_data is running, we would block until
load_data() finished. (e.g. consider starting and immediately stopping
Electrum)
Now instead we can abort load_data early.
This commit is contained in:
SomberNight
2023-08-30 11:49:42 +00:00
parent 78f0f788d6
commit 6557a21c45

View File

@@ -33,6 +33,7 @@ import base64
import asyncio
import threading
from enum import IntEnum
import functools
from aiorpcx import NetAddress
@@ -273,6 +274,9 @@ def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
return Policy.from_msg(local_update_decoded)
class _LoadDataAborted(Exception): pass
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
@@ -733,15 +737,30 @@ class ChannelDB(SqlDB):
return [(str(net_addr.host), net_addr.port, ts)
for net_addr, ts in addr_to_ts.items()]
def handle_abort(func):
@functools.wraps(func)
def wrapper(self: 'ChannelDB', *args, **kwargs):
try:
return func(self, *args, **kwargs)
except _LoadDataAborted:
return
return wrapper
@sql
@profiler
@handle_abort
def load_data(self):
if self.data_loaded.is_set():
return
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
def maybe_abort():
if self.stopping:
self.logger.info("load_data() was asked to stop. exiting early.")
raise _LoadDataAborted()
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
maybe_abort()
node_id, host, port, timestamp = x
try:
net_addr = NetAddress(host, port)
@@ -757,6 +776,7 @@ class ChannelDB(SqlDB):
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
c.execute("""SELECT * FROM channel_info""")
for short_channel_id, msg in c:
maybe_abort()
try:
ci = ChannelInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
@@ -766,6 +786,7 @@ class ChannelDB(SqlDB):
self._channels[ShortChannelID.normalize(short_channel_id)] = ci
c.execute("""SELECT * FROM node_info""")
for node_id, msg in c:
maybe_abort()
try:
node_info, node_addresses = NodeInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
@@ -776,6 +797,7 @@ class ChannelDB(SqlDB):
self._nodes[node_id] = node_info
c.execute("""SELECT * FROM policy""")
for key, msg in c:
maybe_abort()
try:
p = Policy.from_raw_msg(key, msg)
except FailedToParseMsg: