1
0

create parent class for sql databases

This commit is contained in:
ThomasV
2019-03-06 09:56:22 +01:00
parent b861e2e955
commit d8e9a9a49e
3 changed files with 71 additions and 81 deletions

View File

@@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
from .sql_db import SqlDB, sql
from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@@ -212,50 +210,25 @@ class Address(Base):
last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB(PrintError):
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base)
print(Base)
self.num_nodes = 0
self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
self.update_counts()
def sql_thread(self):
self.sql_thread = threading.currentThread()
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
@sql
def update_counts(self):
self._update_counts()
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count()