lnwatcher: save sweepstore in sqlite database
This commit is contained in:
@@ -2,9 +2,11 @@
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import threading
|
||||
from typing import NamedTuple, Iterable, TYPE_CHECKING
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import concurrent
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
from enum import IntEnum, auto
|
||||
@@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
|
||||
FREE = auto()
|
||||
|
||||
|
||||
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import not_, or_
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class SweepTx(Base):
|
||||
__tablename__ = 'sweep_txs'
|
||||
funding_outpoint = Column(String(34))
|
||||
prev_txid = Column(String(32))
|
||||
tx = Column(String())
|
||||
txid = Column(String(32), primary_key=True) # txid of tx
|
||||
|
||||
class ChannelInfo(Base):
|
||||
__tablename__ = 'channel_info'
|
||||
address = Column(String(32), primary_key=True)
|
||||
outpoint = Column(String(34))
|
||||
|
||||
|
||||
class SweepStore(PrintError):
|
||||
|
||||
def __init__(self, path, network):
|
||||
PrintError.__init__(self)
|
||||
self.path = path
|
||||
self.network = network
|
||||
self.db_requests = queue.Queue()
|
||||
threading.Thread(target=self.sql_thread).start()
|
||||
|
||||
def sql_thread(self):
|
||||
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
|
||||
DBSession = sessionmaker(bind=engine, autoflush=False)
|
||||
self.DBSession = DBSession()
|
||||
if not os.path.exists(self.path):
|
||||
Base.metadata.create_all(engine)
|
||||
while self.network.asyncio_loop.is_running():
|
||||
try:
|
||||
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
try:
|
||||
result = func(self, *args, **kwargs)
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
continue
|
||||
future.set_result(result)
|
||||
# write
|
||||
self.DBSession.commit()
|
||||
self.print_error("SQL thread terminated")
|
||||
|
||||
def sql(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
f = concurrent.futures.Future()
|
||||
self.db_requests.put((f, func, args, kwargs))
|
||||
return f.result(timeout=10)
|
||||
return wrapper
|
||||
|
||||
@sql
|
||||
def get_sweep_tx(self, funding_outpoint, prev_txid):
|
||||
return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
|
||||
|
||||
@sql
|
||||
def list_sweep_tx(self):
|
||||
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
|
||||
|
||||
@sql
|
||||
def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
|
||||
self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
def num_sweep_tx(self, funding_outpoint):
|
||||
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
|
||||
|
||||
@sql
|
||||
def remove_sweep_tx(self, funding_outpoint):
|
||||
v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
|
||||
self.DBSession.delete(v)
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
def add_channel_info(self, address, outpoint):
|
||||
self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
def remove_channel_info(self, address):
|
||||
v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
|
||||
self.DBSession.delete(v)
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
def has_channel_info(self, address):
|
||||
return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
|
||||
|
||||
@sql
|
||||
def get_channel_info(self, address):
|
||||
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
|
||||
return r.outpoint if r else None
|
||||
|
||||
@sql
|
||||
def list_channel_info(self):
|
||||
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
|
||||
|
||||
|
||||
class LNWatcher(AddressSynchronizer):
|
||||
verbosity_filter = 'W'
|
||||
|
||||
def __init__(self, network: 'Network'):
|
||||
path = os.path.join(network.config.path, "watcher_db")
|
||||
path = os.path.join(network.config.path, "watchtower_wallet")
|
||||
storage = WalletStorage(path)
|
||||
AddressSynchronizer.__init__(self, storage)
|
||||
self.config = network.config
|
||||
self.start_network(network)
|
||||
self.lock = threading.RLock()
|
||||
self.channel_info = storage.get('channel_info', {}) # access with 'lock'
|
||||
# [funding_outpoint_str][prev_txid] -> set of Transaction
|
||||
# prev_txid is the txid of a tx that is watched for confirmations
|
||||
# access with 'lock'
|
||||
self.sweepstore = defaultdict(lambda: defaultdict(set))
|
||||
for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
|
||||
for txid, set_of_txns in ctxs.items():
|
||||
for tx in set_of_txns:
|
||||
tx2 = Transaction.from_dict(tx)
|
||||
self.sweepstore[funding_outpoint][txid].add(tx2)
|
||||
|
||||
self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
|
||||
self.network.register_callback(self.on_network_update,
|
||||
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
|
||||
self.set_remote_watchtower()
|
||||
@@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
|
||||
await asyncio.sleep(5)
|
||||
await self.watchtower_queue.put((name, args, kwargs))
|
||||
|
||||
def write_to_disk(self):
|
||||
# FIXME: json => every update takes linear instead of constant disk write
|
||||
with self.lock:
|
||||
storage = self.storage
|
||||
storage.put('channel_info', self.channel_info)
|
||||
# self.sweepstore
|
||||
sweepstore = {}
|
||||
for funding_outpoint, ctxs in self.sweepstore.items():
|
||||
sweepstore[funding_outpoint] = {}
|
||||
for prev_txid, set_of_txns in ctxs.items():
|
||||
sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
|
||||
storage.put('sweepstore', sweepstore)
|
||||
storage.write()
|
||||
|
||||
@with_watchtower
|
||||
def watch_channel(self, address, outpoint):
|
||||
self.add_address(address)
|
||||
with self.lock:
|
||||
if address not in self.channel_info:
|
||||
self.channel_info[address] = outpoint
|
||||
self.write_to_disk()
|
||||
if not self.sweepstore.has_channel_info(address):
|
||||
self.sweepstore.add_channel_info(address, outpoint)
|
||||
|
||||
def unwatch_channel(self, address, funding_outpoint):
|
||||
self.print_error('unwatching', funding_outpoint)
|
||||
with self.lock:
|
||||
self.channel_info.pop(address)
|
||||
self.sweepstore.pop(funding_outpoint)
|
||||
self.write_to_disk()
|
||||
self.sweepstore.remove_sweep_tx(funding_outpoint)
|
||||
self.sweepstore.remove_channel_info(address)
|
||||
if funding_outpoint in self.tx_progress:
|
||||
self.tx_progress[funding_outpoint].all_done.set()
|
||||
|
||||
@@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
|
||||
return
|
||||
if not self.synchronizer.is_up_to_date():
|
||||
return
|
||||
with self.lock:
|
||||
channel_info_items = list(self.channel_info.items())
|
||||
for address, outpoint in channel_info_items:
|
||||
for address, outpoint in self.sweepstore.list_channel_info():
|
||||
await self.check_onchain_situation(address, outpoint)
|
||||
|
||||
async def check_onchain_situation(self, address, funding_outpoint):
|
||||
@@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
|
||||
if spender is not None:
|
||||
continue
|
||||
prev_txid, prev_n = prevout.split(':')
|
||||
with self.lock:
|
||||
sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
|
||||
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
|
||||
for tx in sweep_txns:
|
||||
if not await self.broadcast_or_log(funding_outpoint, tx):
|
||||
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
|
||||
@@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
|
||||
@with_watchtower
|
||||
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
|
||||
tx = Transaction.from_dict(tx_dict)
|
||||
with self.lock:
|
||||
self.sweepstore[funding_outpoint][prev_txid].add(tx)
|
||||
self.write_to_disk()
|
||||
self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
|
||||
|
||||
def get_tx_mined_depth(self, txid: str):
|
||||
if not txid:
|
||||
|
||||
Reference in New Issue
Block a user