lnwatcher: save sweepstore in sqlite database
This commit is contained in:
@@ -52,9 +52,11 @@ class WatcherList(MyTreeView):
|
|||||||
def update(self):
|
def update(self):
|
||||||
self.model().clear()
|
self.model().clear()
|
||||||
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
|
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
|
||||||
for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items():
|
sweepstore = self.parent.lnwatcher.sweepstore
|
||||||
|
for outpoint in sweepstore.list_sweep_tx():
|
||||||
|
n = sweepstore.num_sweep_tx(outpoint)
|
||||||
status = self.parent.lnwatcher.get_channel_status(outpoint)
|
status = self.parent.lnwatcher.get_channel_status(outpoint)
|
||||||
items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]]
|
items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
|
||||||
self.model().insertRow(self.model().rowCount(), items)
|
self.model().insertRow(self.model().rowCount(), items)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,11 @@
|
|||||||
# Distributed under the MIT software license, see the accompanying
|
# Distributed under the MIT software license, see the accompanying
|
||||||
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
|
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
||||||
import threading
|
|
||||||
from typing import NamedTuple, Iterable, TYPE_CHECKING
|
from typing import NamedTuple, Iterable, TYPE_CHECKING
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import concurrent
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import asyncio
|
import asyncio
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
@@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
|
|||||||
FREE = auto()
|
FREE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.sql import not_, or_
|
||||||
|
from sqlalchemy.orm import scoped_session
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class SweepTx(Base):
|
||||||
|
__tablename__ = 'sweep_txs'
|
||||||
|
funding_outpoint = Column(String(34))
|
||||||
|
prev_txid = Column(String(32))
|
||||||
|
tx = Column(String())
|
||||||
|
txid = Column(String(32), primary_key=True) # txid of tx
|
||||||
|
|
||||||
|
class ChannelInfo(Base):
|
||||||
|
__tablename__ = 'channel_info'
|
||||||
|
address = Column(String(32), primary_key=True)
|
||||||
|
outpoint = Column(String(34))
|
||||||
|
|
||||||
|
|
||||||
|
class SweepStore(PrintError):
|
||||||
|
|
||||||
|
def __init__(self, path, network):
|
||||||
|
PrintError.__init__(self)
|
||||||
|
self.path = path
|
||||||
|
self.network = network
|
||||||
|
self.db_requests = queue.Queue()
|
||||||
|
threading.Thread(target=self.sql_thread).start()
|
||||||
|
|
||||||
|
def sql_thread(self):
|
||||||
|
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
|
||||||
|
DBSession = sessionmaker(bind=engine, autoflush=False)
|
||||||
|
self.DBSession = DBSession()
|
||||||
|
if not os.path.exists(self.path):
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
while self.network.asyncio_loop.is_running():
|
||||||
|
try:
|
||||||
|
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
except BaseException as e:
|
||||||
|
future.set_exception(e)
|
||||||
|
continue
|
||||||
|
future.set_result(result)
|
||||||
|
# write
|
||||||
|
self.DBSession.commit()
|
||||||
|
self.print_error("SQL thread terminated")
|
||||||
|
|
||||||
|
def sql(func):
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
f = concurrent.futures.Future()
|
||||||
|
self.db_requests.put((f, func, args, kwargs))
|
||||||
|
return f.result(timeout=10)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_sweep_tx(self, funding_outpoint, prev_txid):
|
||||||
|
return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def list_sweep_tx(self):
|
||||||
|
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
|
||||||
|
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def num_sweep_tx(self, funding_outpoint):
|
||||||
|
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def remove_sweep_tx(self, funding_outpoint):
|
||||||
|
v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
|
||||||
|
self.DBSession.delete(v)
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def add_channel_info(self, address, outpoint):
|
||||||
|
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def remove_channel_info(self, address):
|
||||||
|
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
|
||||||
|
self.DBSession.delete(v)
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def has_channel_info(self, address):
|
||||||
|
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_channel_info(self, address):
|
||||||
|
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
|
||||||
|
return r.outpoint if r else None
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def list_channel_info(self):
|
||||||
|
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
|
||||||
|
|
||||||
|
|
||||||
class LNWatcher(AddressSynchronizer):
|
class LNWatcher(AddressSynchronizer):
|
||||||
verbosity_filter = 'W'
|
verbosity_filter = 'W'
|
||||||
|
|
||||||
def __init__(self, network: 'Network'):
|
def __init__(self, network: 'Network'):
|
||||||
path = os.path.join(network.config.path, "watcher_db")
|
path = os.path.join(network.config.path, "watchtower_wallet")
|
||||||
storage = WalletStorage(path)
|
storage = WalletStorage(path)
|
||||||
AddressSynchronizer.__init__(self, storage)
|
AddressSynchronizer.__init__(self, storage)
|
||||||
self.config = network.config
|
self.config = network.config
|
||||||
self.start_network(network)
|
self.start_network(network)
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.channel_info = storage.get('channel_info', {}) # access with 'lock'
|
self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
|
||||||
# [funding_outpoint_str][prev_txid] -> set of Transaction
|
|
||||||
# prev_txid is the txid of a tx that is watched for confirmations
|
|
||||||
# access with 'lock'
|
|
||||||
self.sweepstore = defaultdict(lambda: defaultdict(set))
|
|
||||||
for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
|
|
||||||
for txid, set_of_txns in ctxs.items():
|
|
||||||
for tx in set_of_txns:
|
|
||||||
tx2 = Transaction.from_dict(tx)
|
|
||||||
self.sweepstore[funding_outpoint][txid].add(tx2)
|
|
||||||
|
|
||||||
self.network.register_callback(self.on_network_update,
|
self.network.register_callback(self.on_network_update,
|
||||||
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
|
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
|
||||||
self.set_remote_watchtower()
|
self.set_remote_watchtower()
|
||||||
@@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
|
|||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
await self.watchtower_queue.put((name, args, kwargs))
|
await self.watchtower_queue.put((name, args, kwargs))
|
||||||
|
|
||||||
def write_to_disk(self):
|
|
||||||
# FIXME: json => every update takes linear instead of constant disk write
|
|
||||||
with self.lock:
|
|
||||||
storage = self.storage
|
|
||||||
storage.put('channel_info', self.channel_info)
|
|
||||||
# self.sweepstore
|
|
||||||
sweepstore = {}
|
|
||||||
for funding_outpoint, ctxs in self.sweepstore.items():
|
|
||||||
sweepstore[funding_outpoint] = {}
|
|
||||||
for prev_txid, set_of_txns in ctxs.items():
|
|
||||||
sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
|
|
||||||
storage.put('sweepstore', sweepstore)
|
|
||||||
storage.write()
|
|
||||||
|
|
||||||
@with_watchtower
|
@with_watchtower
|
||||||
def watch_channel(self, address, outpoint):
|
def watch_channel(self, address, outpoint):
|
||||||
self.add_address(address)
|
self.add_address(address)
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if address not in self.channel_info:
|
if not self.sweepstore.has_channel_info(address):
|
||||||
self.channel_info[address] = outpoint
|
self.sweepstore.add_channel_info(address, outpoint)
|
||||||
self.write_to_disk()
|
|
||||||
|
|
||||||
def unwatch_channel(self, address, funding_outpoint):
|
def unwatch_channel(self, address, funding_outpoint):
|
||||||
self.print_error('unwatching', funding_outpoint)
|
self.print_error('unwatching', funding_outpoint)
|
||||||
with self.lock:
|
self.sweepstore.remove_sweep_tx(funding_outpoint)
|
||||||
self.channel_info.pop(address)
|
self.sweepstore.remove_channel_info(address)
|
||||||
self.sweepstore.pop(funding_outpoint)
|
|
||||||
self.write_to_disk()
|
|
||||||
if funding_outpoint in self.tx_progress:
|
if funding_outpoint in self.tx_progress:
|
||||||
self.tx_progress[funding_outpoint].all_done.set()
|
self.tx_progress[funding_outpoint].all_done.set()
|
||||||
|
|
||||||
@@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
|
|||||||
return
|
return
|
||||||
if not self.synchronizer.is_up_to_date():
|
if not self.synchronizer.is_up_to_date():
|
||||||
return
|
return
|
||||||
with self.lock:
|
for address, outpoint in self.sweepstore.list_channel_info():
|
||||||
channel_info_items = list(self.channel_info.items())
|
|
||||||
for address, outpoint in channel_info_items:
|
|
||||||
await self.check_onchain_situation(address, outpoint)
|
await self.check_onchain_situation(address, outpoint)
|
||||||
|
|
||||||
async def check_onchain_situation(self, address, funding_outpoint):
|
async def check_onchain_situation(self, address, funding_outpoint):
|
||||||
@@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
|
|||||||
if spender is not None:
|
if spender is not None:
|
||||||
continue
|
continue
|
||||||
prev_txid, prev_n = prevout.split(':')
|
prev_txid, prev_n = prevout.split(':')
|
||||||
with self.lock:
|
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
|
||||||
sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
|
|
||||||
for tx in sweep_txns:
|
for tx in sweep_txns:
|
||||||
if not await self.broadcast_or_log(funding_outpoint, tx):
|
if not await self.broadcast_or_log(funding_outpoint, tx):
|
||||||
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
|
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
|
||||||
@@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
|
|||||||
@with_watchtower
|
@with_watchtower
|
||||||
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
|
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
|
||||||
tx = Transaction.from_dict(tx_dict)
|
tx = Transaction.from_dict(tx_dict)
|
||||||
with self.lock:
|
self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
|
||||||
self.sweepstore[funding_outpoint][prev_txid].add(tx)
|
|
||||||
self.write_to_disk()
|
|
||||||
|
|
||||||
def get_tx_mined_depth(self, txid: str):
|
def get_tx_mined_depth(self, txid: str):
|
||||||
if not txid:
|
if not txid:
|
||||||
|
|||||||
Reference in New Issue
Block a user