1
0

Merge the network and network_proxy

This commit is contained in:
Neil Booth
2015-08-30 21:18:10 +09:00
parent 4d6a0f29ee
commit 2d05e7d891
14 changed files with 158 additions and 319 deletions

View File

@@ -5,7 +5,8 @@ import sys
import random
import select
import traceback
from collections import deque
from collections import defaultdict, deque
from threading import Lock
import socks
import socket
@@ -129,20 +130,19 @@ class Network(util.DaemonThread):
Our external API:
- Member functions get_header(), get_parameters(), get_status_value(),
new_blockchain_height(), set_parameters(), start(),
stop()
- Member functions get_header(), get_interfaces(), get_local_height(),
get_parameters(), get_server_height(), get_status_value(),
is_connected(), new_blockchain_height(), set_parameters(), start(),
stop()
"""
def __init__(self, pipe, config=None):
def __init__(self, config=None):
if config is None:
config = {} # Do not use mutables as default values!
util.DaemonThread.__init__(self)
self.config = SimpleConfig(config) if type(config) == type({}) else config
self.num_server = 8 if not self.config.get('oneserver') else 0
self.blockchain = Blockchain(self.config, self)
self.requests_queue = pipe.send_queue
self.response_queue = pipe.get_queue
# A deque of interface header requests, processed left-to-right
self.bc_requests = deque()
# Server for addresses and transactions
@@ -155,6 +155,10 @@ class Network(util.DaemonThread):
if not self.default_server:
self.default_server = pick_random_server()
self.lock = Lock()
self.pending_sends = []
self.message_id = 0
self.debug = False
self.irc_servers = {} # returned by interface (list from irc)
self.recent_servers = self.read_recent_servers()
@@ -163,6 +167,8 @@ class Network(util.DaemonThread):
self.heights = {}
self.merkle_roots = {}
self.utxo_roots = {}
self.subscriptions = defaultdict(list)
self.callbacks = defaultdict(list)
dir_path = os.path.join( self.config.path, 'certs')
if not os.path.exists(dir_path):
@@ -188,6 +194,15 @@ class Network(util.DaemonThread):
self.start_network(deserialize_server(self.default_server)[2],
deserialize_proxy(self.config.get('proxy')))
def register_callback(self, event, callback):
with self.lock:
self.callbacks[event].append(callback)
def trigger_callback(self, event, params=()):
with self.lock:
callbacks = self.callbacks[event][:]
[callback(*params) for callback in callbacks]
def read_recent_servers(self):
if not self.config.path:
return []
@@ -231,6 +246,12 @@ class Network(util.DaemonThread):
def is_connected(self):
return self.interface is not None
def is_connecting(self):
return self.connection_status == 'connecting'
def is_up_to_date(self):
return self.unanswered_requests == {}
def queue_request(self, method, params):
self.interface.queue_request({'method': method, 'params': params})
@@ -263,7 +284,10 @@ class Network(util.DaemonThread):
def notify(self, key):
value = self.get_status_value(key)
self.response_queue.put({'method':'network.status', 'params':[key, value]})
if key in ['status', 'updated']:
self.trigger_callback(key)
else:
self.trigger_callback(key, (value,))
def get_parameters(self):
host, port, protocol = deserialize_server(self.default_server)
@@ -337,8 +361,16 @@ class Network(util.DaemonThread):
self.socket_queue = Queue.Queue()
def set_parameters(self, host, port, protocol, proxy, auto_connect):
self.auto_connect = auto_connect
proxy_str = serialize_proxy(proxy)
server = serialize_server(host, port, protocol)
self.config.set_key('auto_connect', auto_connect, False)
self.config.set_key("proxy", proxy_str, False)
self.config.set_key("server", server, True)
# abort if changes were not allowed by config
if self.config.get('server') != server_str or self.config.get('proxy') != proxy_str:
return
self.auto_connect = auto_connect
if self.proxy != proxy or self.protocol != protocol:
# Restart the network defaulting to the given server
self.stop_network()
@@ -405,7 +437,9 @@ class Network(util.DaemonThread):
self.switch_lagging_interface(i.server)
self.notify('updated')
def process_response(self, interface, response):
def process_response(self, interface, response, callback):
if self.debug:
self.print_error("<--", response)
error = response.get('error')
result = response.get('result')
method = response.get('method')
@@ -437,8 +471,19 @@ class Network(util.DaemonThread):
# Cache address subscription results
if method == 'blockchain.address.subscribe' and error is None:
addr = response['params'][0]
self.addr_responses[addr] = result
self.response_queue.put(response)
self.addr_responses[addr] = response
if callback is None:
params = response['params']
with self.lock:
for k,v in self.subscriptions.items():
if (method, params) in v:
callback = k
break
if callback is None:
self.print_error("received unexpected notification",
method, params)
else:
callback(response)
def process_responses(self, interface):
notifications, responses = interface.get_responses()
@@ -449,12 +494,14 @@ class Network(util.DaemonThread):
if client_id is not None:
if interface != self.interface:
continue
self.unanswered_requests.pop(client_id)
_req, callback = self.unanswered_requests.pop(client_id)
else:
callback = None
# Copy the request method and params to the response
response['method'] = request.get('method')
response['params'] = request.get('params')
response['id'] = client_id
self.process_response(interface, response)
self.process_response(interface, response, callback)
for response in notifications:
if not response: # Closed remotely
@@ -466,16 +513,42 @@ class Network(util.DaemonThread):
response['result'] = response['params'][0]
response['params'] = []
elif method == 'blockchain.address.subscribe':
params = response['params']
response['params'] = [params[0]] # addr
response['result'] = params[1]
self.process_response(interface, response)
self.process_response(interface, response, None)
def handle_incoming_requests(self):
while not self.requests_queue.empty():
self.process_request(self.requests_queue.get())
def send(self, messages, callback):
'''Messages is a list of (method, value) tuples'''
with self.lock:
self.pending_sends.append((messages, callback))
def process_request(self, request):
def process_pending_sends(self):
sends = self.pending_sends
self.pending_sends = []
for messages, callback in sends:
subs = filter(lambda (m,v): m.endswith('.subscribe'), messages)
with self.lock:
for sub in subs:
if sub not in self.subscriptions[callback]:
self.subscriptions[callback].append(sub)
_id = self.message_id
self.message_id += len(messages)
unsent = []
for message in messages:
method, params = message
request = {'id': _id, 'method': method, 'params': params}
if not self.process_request(request, callback):
unsent.append(message)
_id += 1
if unsent:
with self.lock:
self.pending_sends.append((unsent, callback))
# FIXME: inline this function
def process_request(self, request, callback):
'''Returns true if the request was processed.'''
method = request['method']
params = request['params']
@@ -492,14 +565,14 @@ class Network(util.DaemonThread):
out['error'] = str(e)
traceback.print_exc(file=sys.stdout)
self.print_error("network error", str(e))
self.response_queue.put(out)
callback(out)
return True
if method == 'blockchain.address.subscribe':
addr = params[0]
self.subscribed_addresses.add(addr)
if addr in self.addr_responses:
self.response_queue.put({'id':_id, 'result':self.addr_responses[addr]})
callback(self.addr_responses[addr])
return True
# This request needs connectivity. If we don't have an
@@ -507,7 +580,9 @@ class Network(util.DaemonThread):
if not self.interface:
return False
self.unanswered_requests[_id] = request
if self.debug:
self.print_error("-->", request)
self.unanswered_requests[_id] = request, callback
self.interface.queue_request(request)
return True
@@ -679,10 +754,12 @@ class Network(util.DaemonThread):
while self.is_running():
self.maintain_sockets()
self.wait_on_sockets()
self.handle_incoming_requests()
self.handle_bc_requests()
self.run_jobs() # Synchronizer and Verifier
self.process_pending_sends()
self.stop_network()
self.trigger_callback('stop')
self.print_error("stopped")
def on_header(self, i, header):
@@ -706,3 +783,11 @@ class Network(util.DaemonThread):
def get_local_height(self):
return self.blockchain.height()
def synchronous_get(self, request, timeout=100000000):
queue = Queue.Queue()
self.send([request], queue.put)
r = queue.get(True, timeout)
if r.get('error'):
raise BaseException(r.get('error'))
return r.get('result')