1
0

create parent class for sql databases

This commit is contained in:
ThomasV
2019-03-06 09:56:22 +01:00
parent b861e2e955
commit d8e9a9a49e
3 changed files with 71 additions and 81 deletions

View File

@@ -11,9 +11,14 @@ from collections import defaultdict
import asyncio
from enum import IntEnum, auto
from typing import NamedTuple, Dict
import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
@@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base()
class SweepTx(Base):
@@ -60,42 +57,11 @@ class ChannelInfo(Base):
outpoint = Column(String(34))
class SweepStore(PrintError):
class SweepStore(SqlDB):
def __init__(self, path, network):
PrintError.__init__(self)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
super().__init__(network, path, Base)
@sql
def get_sweep_tx(self, funding_outpoint, prev_txid):