Merge pull request #4767 from SomberNight/auto_jump_forks
network: auto-switch servers to preferred fork (or longest chain)
This commit is contained in:
@@ -22,7 +22,7 @@
|
|||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
|
|
||||||
from . import util
|
from . import util
|
||||||
from .bitcoin import Hash, hash_encode, int_to_hex, rev_hex
|
from .bitcoin import Hash, hash_encode, int_to_hex, rev_hex
|
||||||
@@ -73,7 +73,7 @@ def hash_header(header: dict) -> str:
|
|||||||
return hash_encode(Hash(bfh(serialize_header(header))))
|
return hash_encode(Hash(bfh(serialize_header(header))))
|
||||||
|
|
||||||
|
|
||||||
blockchains = {}
|
blockchains = {} # type: Dict[int, Blockchain]
|
||||||
blockchains_lock = threading.Lock()
|
blockchains_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ class Blockchain(util.PrintError):
|
|||||||
Manages blockchain headers and their verification
|
Manages blockchain headers and their verification
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, forkpoint: int, parent_id: int):
|
def __init__(self, config, forkpoint: int, parent_id: Optional[int]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.forkpoint = forkpoint
|
self.forkpoint = forkpoint
|
||||||
self.checkpoints = constants.net.CHECKPOINTS
|
self.checkpoints = constants.net.CHECKPOINTS
|
||||||
@@ -124,22 +124,32 @@ class Blockchain(util.PrintError):
|
|||||||
children = list(filter(lambda y: y.parent_id==self.forkpoint, chains))
|
children = list(filter(lambda y: y.parent_id==self.forkpoint, chains))
|
||||||
return max([x.forkpoint for x in children]) if children else None
|
return max([x.forkpoint for x in children]) if children else None
|
||||||
|
|
||||||
def get_forkpoint(self) -> int:
|
def get_max_forkpoint(self) -> int:
|
||||||
|
"""Returns the max height where there is a fork
|
||||||
|
related to this chain.
|
||||||
|
"""
|
||||||
mc = self.get_max_child()
|
mc = self.get_max_child()
|
||||||
return mc if mc is not None else self.forkpoint
|
return mc if mc is not None else self.forkpoint
|
||||||
|
|
||||||
def get_branch_size(self) -> int:
|
def get_branch_size(self) -> int:
|
||||||
return self.height() - self.get_forkpoint() + 1
|
return self.height() - self.get_max_forkpoint() + 1
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
return self.get_hash(self.get_forkpoint()).lstrip('00')[0:10]
|
return self.get_hash(self.get_max_forkpoint()).lstrip('00')[0:10]
|
||||||
|
|
||||||
def check_header(self, header: dict) -> bool:
|
def check_header(self, header: dict) -> bool:
|
||||||
header_hash = hash_header(header)
|
header_hash = hash_header(header)
|
||||||
height = header.get('block_height')
|
height = header.get('block_height')
|
||||||
|
return self.check_hash(height, header_hash)
|
||||||
|
|
||||||
|
def check_hash(self, height: int, header_hash: str) -> bool:
|
||||||
|
"""Returns whether the hash of the block at given height
|
||||||
|
is the given hash.
|
||||||
|
"""
|
||||||
|
assert isinstance(header_hash, str) and len(header_hash) == 64, header_hash # hex
|
||||||
try:
|
try:
|
||||||
return header_hash == self.get_hash(height)
|
return header_hash == self.get_hash(height)
|
||||||
except MissingHeader:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def fork(parent, header: dict) -> 'Blockchain':
|
def fork(parent, header: dict) -> 'Blockchain':
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class ElectrumWindow(App):
|
|||||||
with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
|
with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items())
|
||||||
for index, b in blockchain_items:
|
for index, b in blockchain_items:
|
||||||
if name == b.get_name():
|
if name == b.get_name():
|
||||||
self.network.run_from_another_thread(self.network.follow_chain(index))
|
self.network.run_from_another_thread(self.network.follow_chain_given_id(index))
|
||||||
names = [blockchain.blockchains[b].get_name() for b in chains]
|
names = [blockchain.blockchains[b].get_name() for b in chains]
|
||||||
if len(names) > 1:
|
if len(names) > 1:
|
||||||
cur_chain = self.network.blockchain().get_name()
|
cur_chain = self.network.blockchain().get_name()
|
||||||
@@ -664,7 +664,7 @@ class ElectrumWindow(App):
|
|||||||
self.num_nodes = len(self.network.get_interfaces())
|
self.num_nodes = len(self.network.get_interfaces())
|
||||||
self.num_chains = len(self.network.get_blockchains())
|
self.num_chains = len(self.network.get_blockchains())
|
||||||
chain = self.network.blockchain()
|
chain = self.network.blockchain()
|
||||||
self.blockchain_forkpoint = chain.get_forkpoint()
|
self.blockchain_forkpoint = chain.get_max_forkpoint()
|
||||||
self.blockchain_name = chain.get_name()
|
self.blockchain_name = chain.get_name()
|
||||||
interface = self.network.interface
|
interface = self.network.interface
|
||||||
if interface:
|
if interface:
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class NodesListWidget(QTreeWidget):
|
|||||||
b = blockchain.blockchains[k]
|
b = blockchain.blockchains[k]
|
||||||
name = b.get_name()
|
name = b.get_name()
|
||||||
if n_chains >1:
|
if n_chains >1:
|
||||||
x = QTreeWidgetItem([name + '@%d'%b.get_forkpoint(), '%d'%b.height()])
|
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
|
||||||
x.setData(0, Qt.UserRole, 1)
|
x.setData(0, Qt.UserRole, 1)
|
||||||
x.setData(1, Qt.UserRole, b.forkpoint)
|
x.setData(1, Qt.UserRole, b.forkpoint)
|
||||||
else:
|
else:
|
||||||
@@ -364,7 +364,7 @@ class NetworkChoiceLayout(object):
|
|||||||
chains = self.network.get_blockchains()
|
chains = self.network.get_blockchains()
|
||||||
if len(chains) > 1:
|
if len(chains) > 1:
|
||||||
chain = self.network.blockchain()
|
chain = self.network.blockchain()
|
||||||
forkpoint = chain.get_forkpoint()
|
forkpoint = chain.get_max_forkpoint()
|
||||||
name = chain.get_name()
|
name = chain.get_name()
|
||||||
msg = _('Chain split detected at block {0}').format(forkpoint) + '\n'
|
msg = _('Chain split detected at block {0}').format(forkpoint) + '\n'
|
||||||
msg += (_('You are following branch') if auto_connect else _('Your server is on branch'))+ ' ' + name
|
msg += (_('You are following branch') if auto_connect else _('Your server is on branch'))+ ' ' + name
|
||||||
@@ -411,14 +411,11 @@ class NetworkChoiceLayout(object):
|
|||||||
self.set_server()
|
self.set_server()
|
||||||
|
|
||||||
def follow_branch(self, index):
|
def follow_branch(self, index):
|
||||||
self.network.run_from_another_thread(self.network.follow_chain(index))
|
self.network.run_from_another_thread(self.network.follow_chain_given_id(index))
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def follow_server(self, server):
|
def follow_server(self, server):
|
||||||
net_params = self.network.get_parameters()
|
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
|
||||||
host, port, protocol = deserialize_server(server)
|
|
||||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
|
||||||
self.network.run_from_another_thread(self.network.set_parameters(net_params))
|
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def server_changed(self, x):
|
def server_changed(self, x):
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ class Interface(PrintError):
|
|||||||
self.mark_ready()
|
self.mark_ready()
|
||||||
await self._process_header_at_tip()
|
await self._process_header_at_tip()
|
||||||
self.network.trigger_callback('network_updated')
|
self.network.trigger_callback('network_updated')
|
||||||
|
await self.network.switch_unwanted_fork_interface()
|
||||||
await self.network.switch_lagging_interface()
|
await self.network.switch_lagging_interface()
|
||||||
|
|
||||||
async def _process_header_at_tip(self):
|
async def _process_header_at_tip(self):
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import NamedTuple, Optional, Sequence, List
|
from typing import NamedTuple, Optional, Sequence, List, Dict
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import dns
|
import dns
|
||||||
@@ -172,10 +172,9 @@ class Network(PrintError):
|
|||||||
self.config = SimpleConfig(config) if isinstance(config, dict) else config
|
self.config = SimpleConfig(config) if isinstance(config, dict) else config
|
||||||
self.num_server = 10 if not self.config.get('oneserver') else 0
|
self.num_server = 10 if not self.config.get('oneserver') else 0
|
||||||
blockchain.blockchains = blockchain.read_blockchains(self.config)
|
blockchain.blockchains = blockchain.read_blockchains(self.config)
|
||||||
self.print_error("blockchains", list(blockchain.blockchains.keys()))
|
self.print_error("blockchains", list(blockchain.blockchains))
|
||||||
self.blockchain_index = config.get('blockchain_index', 0)
|
self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict]
|
||||||
if self.blockchain_index not in blockchain.blockchains.keys():
|
self._blockchain_index = 0
|
||||||
self.blockchain_index = 0
|
|
||||||
# Server for addresses and transactions
|
# Server for addresses and transactions
|
||||||
self.default_server = self.config.get('server', None)
|
self.default_server = self.config.get('server', None)
|
||||||
# Sanitize default server
|
# Sanitize default server
|
||||||
@@ -213,11 +212,10 @@ class Network(PrintError):
|
|||||||
# retry times
|
# retry times
|
||||||
self.server_retry_time = time.time()
|
self.server_retry_time = time.time()
|
||||||
self.nodes_retry_time = time.time()
|
self.nodes_retry_time = time.time()
|
||||||
# kick off the network. interface is the main server we are currently
|
# the main server we are currently communicating with
|
||||||
# communicating with. interfaces is the set of servers we are connecting
|
|
||||||
# to or have an ongoing connection with
|
|
||||||
self.interface = None # type: Interface
|
self.interface = None # type: Interface
|
||||||
self.interfaces = {}
|
# set of servers we have an ongoing connection with
|
||||||
|
self.interfaces = {} # type: Dict[str, Interface]
|
||||||
self.auto_connect = self.config.get('auto_connect', True)
|
self.auto_connect = self.config.get('auto_connect', True)
|
||||||
self.connecting = set()
|
self.connecting = set()
|
||||||
self.server_queue = None
|
self.server_queue = None
|
||||||
@@ -227,8 +225,8 @@ class Network(PrintError):
|
|||||||
#self.asyncio_loop.set_debug(1)
|
#self.asyncio_loop.set_debug(1)
|
||||||
self._run_forever = asyncio.Future()
|
self._run_forever = asyncio.Future()
|
||||||
self._thread = threading.Thread(target=self.asyncio_loop.run_until_complete,
|
self._thread = threading.Thread(target=self.asyncio_loop.run_until_complete,
|
||||||
args=(self._run_forever,),
|
args=(self._run_forever,),
|
||||||
name='Network')
|
name='Network')
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def run_from_another_thread(self, coro):
|
def run_from_another_thread(self, coro):
|
||||||
@@ -523,20 +521,40 @@ class Network(PrintError):
|
|||||||
|
|
||||||
async def switch_lagging_interface(self):
|
async def switch_lagging_interface(self):
|
||||||
'''If auto_connect and lagging, switch interface'''
|
'''If auto_connect and lagging, switch interface'''
|
||||||
if await self._server_is_lagging() and self.auto_connect:
|
if self.auto_connect and await self._server_is_lagging():
|
||||||
# switch to one that has the correct header (not height)
|
# switch to one that has the correct header (not height)
|
||||||
header = self.blockchain().read_header(self.get_local_height())
|
best_header = self.blockchain().read_header(self.get_local_height())
|
||||||
def filt(x):
|
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
||||||
a = x[1].tip_header
|
filtered = list(filter(lambda iface: iface.tip_header == best_header, interfaces))
|
||||||
b = header
|
|
||||||
assert type(a) is type(b)
|
|
||||||
return a == b
|
|
||||||
|
|
||||||
with self.interfaces_lock: interfaces_items = list(self.interfaces.items())
|
|
||||||
filtered = list(map(lambda x: x[0], filter(filt, interfaces_items)))
|
|
||||||
if filtered:
|
if filtered:
|
||||||
choice = random.choice(filtered)
|
chosen_iface = random.choice(filtered)
|
||||||
await self.switch_to_interface(choice)
|
await self.switch_to_interface(chosen_iface.server)
|
||||||
|
|
||||||
|
async def switch_unwanted_fork_interface(self):
|
||||||
|
"""If auto_connect and main interface is not on preferred fork,
|
||||||
|
try to switch to preferred fork.
|
||||||
|
"""
|
||||||
|
if not self.auto_connect:
|
||||||
|
return
|
||||||
|
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
||||||
|
# try to switch to preferred fork
|
||||||
|
if self._blockchain_preferred_block:
|
||||||
|
pref_height = self._blockchain_preferred_block['height']
|
||||||
|
pref_hash = self._blockchain_preferred_block['hash']
|
||||||
|
filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash),
|
||||||
|
interfaces))
|
||||||
|
if filtered:
|
||||||
|
chosen_iface = random.choice(filtered)
|
||||||
|
await self.switch_to_interface(chosen_iface.server)
|
||||||
|
return
|
||||||
|
# try to switch to longest chain
|
||||||
|
if self.blockchain().parent_id is None:
|
||||||
|
return # already on longest chain
|
||||||
|
filtered = list(filter(lambda iface: iface.blockchain.parent_id is None,
|
||||||
|
interfaces))
|
||||||
|
if filtered:
|
||||||
|
chosen_iface = random.choice(filtered)
|
||||||
|
await self.switch_to_interface(chosen_iface.server)
|
||||||
|
|
||||||
async def switch_to_interface(self, server: str):
|
async def switch_to_interface(self, server: str):
|
||||||
"""Switch to server as our main interface. If no connection exists,
|
"""Switch to server as our main interface. If no connection exists,
|
||||||
@@ -704,8 +722,8 @@ class Network(PrintError):
|
|||||||
def blockchain(self) -> Blockchain:
|
def blockchain(self) -> Blockchain:
|
||||||
interface = self.interface
|
interface = self.interface
|
||||||
if interface and interface.blockchain is not None:
|
if interface and interface.blockchain is not None:
|
||||||
self.blockchain_index = interface.blockchain.forkpoint
|
self._blockchain_index = interface.blockchain.forkpoint
|
||||||
return blockchain.blockchains[self.blockchain_index]
|
return blockchain.blockchains[self._blockchain_index]
|
||||||
|
|
||||||
def get_blockchains(self):
|
def get_blockchains(self):
|
||||||
out = {} # blockchain_id -> list(interfaces)
|
out = {} # blockchain_id -> list(interfaces)
|
||||||
@@ -724,24 +742,42 @@ class Network(PrintError):
|
|||||||
await self.connection_down(interface.server)
|
await self.connection_down(interface.server)
|
||||||
return ifaces
|
return ifaces
|
||||||
|
|
||||||
async def follow_chain(self, chain_id):
|
def _set_preferred_chain(self, chain: Blockchain):
|
||||||
bc = blockchain.blockchains.get(chain_id)
|
height = chain.get_max_forkpoint()
|
||||||
if bc:
|
header_hash = chain.get_hash(height)
|
||||||
self.blockchain_index = chain_id
|
self._blockchain_preferred_block = {
|
||||||
self.config.set_key('blockchain_index', chain_id)
|
'height': height,
|
||||||
with self.interfaces_lock: interfaces_values = list(self.interfaces.values())
|
'hash': header_hash,
|
||||||
for iface in interfaces_values:
|
}
|
||||||
if iface.blockchain == bc:
|
self.config.set_key('blockchain_preferred_block', self._blockchain_preferred_block)
|
||||||
await self.switch_to_interface(iface.server)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise Exception('blockchain not found', chain_id)
|
|
||||||
|
|
||||||
if self.interface:
|
async def follow_chain_given_id(self, chain_id: int) -> None:
|
||||||
net_params = self.get_parameters()
|
bc = blockchain.blockchains.get(chain_id)
|
||||||
host, port, protocol = deserialize_server(self.interface.server)
|
if not bc:
|
||||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
raise Exception('blockchain {} not found'.format(chain_id))
|
||||||
await self.set_parameters(net_params)
|
self._set_preferred_chain(bc)
|
||||||
|
# select server on this chain
|
||||||
|
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
||||||
|
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
|
||||||
|
if len(interfaces_on_selected_chain) == 0: return
|
||||||
|
chosen_iface = random.choice(interfaces_on_selected_chain)
|
||||||
|
# switch to server (and save to config)
|
||||||
|
net_params = self.get_parameters()
|
||||||
|
host, port, protocol = deserialize_server(chosen_iface.server)
|
||||||
|
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
||||||
|
await self.set_parameters(net_params)
|
||||||
|
|
||||||
|
async def follow_chain_given_server(self, server_str: str) -> None:
|
||||||
|
# note that server_str should correspond to a connected interface
|
||||||
|
iface = self.interfaces.get(server_str)
|
||||||
|
if iface is None:
|
||||||
|
return
|
||||||
|
self._set_preferred_chain(iface.blockchain)
|
||||||
|
# switch to server (and save to config)
|
||||||
|
net_params = self.get_parameters()
|
||||||
|
host, port, protocol = deserialize_server(server_str)
|
||||||
|
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
||||||
|
await self.set_parameters(net_params)
|
||||||
|
|
||||||
def get_local_height(self):
|
def get_local_height(self):
|
||||||
return self.blockchain().height()
|
return self.blockchain().height()
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ class SPV(NetworkJobOnDefaultServer):
|
|||||||
|
|
||||||
async def _maybe_undo_verifications(self):
|
async def _maybe_undo_verifications(self):
|
||||||
def undo_verifications():
|
def undo_verifications():
|
||||||
height = self.blockchain.get_forkpoint()
|
height = self.blockchain.get_max_forkpoint()
|
||||||
self.print_error("undoing verifications back to height {}".format(height))
|
self.print_error("undoing verifications back to height {}".format(height))
|
||||||
tx_hashes = self.wallet.undo_verifications(self.blockchain, height)
|
tx_hashes = self.wallet.undo_verifications(self.blockchain, height)
|
||||||
for tx_hash in tx_hashes:
|
for tx_hash in tx_hashes:
|
||||||
|
|||||||
Reference in New Issue
Block a user