exchange_rate: add some type hints
This commit is contained in:
@@ -8,7 +8,7 @@ import time
|
|||||||
import csv
|
import csv
|
||||||
import decimal
|
import decimal
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Sequence, Optional
|
from typing import Sequence, Optional, Mapping, Dict, Union
|
||||||
|
|
||||||
from aiorpcx.curio import timeout_after, TaskTimeout
|
from aiorpcx.curio import timeout_after, TaskTimeout
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -41,8 +41,8 @@ class ExchangeBase(Logger):
|
|||||||
|
|
||||||
def __init__(self, on_quotes, on_history):
|
def __init__(self, on_quotes, on_history):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
self.history = {}
|
self.history = {} # type: Dict[str, Dict[str, Union[str, float, Decimal]]]
|
||||||
self.quotes = {}
|
self.quotes = {} # type: Dict[str, Union[str, float, Decimal, None]]
|
||||||
self.on_quotes = on_quotes
|
self.on_quotes = on_quotes
|
||||||
self.on_history = on_history
|
self.on_history = on_history
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ class ExchangeBase(Logger):
|
|||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
async def update_safe(self, ccy):
|
async def update_safe(self, ccy: str) -> None:
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"getting fx quotes for {ccy}")
|
self.logger.info(f"getting fx quotes for {ccy}")
|
||||||
self.quotes = await self.get_rates(ccy)
|
self.quotes = await self.get_rates(ccy)
|
||||||
@@ -88,7 +88,7 @@ class ExchangeBase(Logger):
|
|||||||
self.quotes = {}
|
self.quotes = {}
|
||||||
self.on_quotes()
|
self.on_quotes()
|
||||||
|
|
||||||
def read_historical_rates(self, ccy, cache_dir) -> Optional[dict]:
|
def read_historical_rates(self, ccy: str, cache_dir: str) -> Optional[dict]:
|
||||||
filename = os.path.join(cache_dir, self.name() + '_'+ ccy)
|
filename = os.path.join(cache_dir, self.name() + '_'+ ccy)
|
||||||
if not os.path.exists(filename):
|
if not os.path.exists(filename):
|
||||||
return None
|
return None
|
||||||
@@ -106,7 +106,7 @@ class ExchangeBase(Logger):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
async def get_historical_rates_safe(self, ccy, cache_dir):
|
async def get_historical_rates_safe(self, ccy: str, cache_dir: str) -> None:
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"requesting fx history for {ccy}")
|
self.logger.info(f"requesting fx history for {ccy}")
|
||||||
h = await self.request_history(ccy)
|
h = await self.request_history(ccy)
|
||||||
@@ -124,7 +124,7 @@ class ExchangeBase(Logger):
|
|||||||
self.history[ccy] = h
|
self.history[ccy] = h
|
||||||
self.on_history()
|
self.on_history()
|
||||||
|
|
||||||
def get_historical_rates(self, ccy, cache_dir):
|
def get_historical_rates(self, ccy: str, cache_dir: str) -> None:
|
||||||
if ccy not in self.history_ccys():
|
if ccy not in self.history_ccys():
|
||||||
return
|
return
|
||||||
h = self.history.get(ccy)
|
h = self.history.get(ccy)
|
||||||
@@ -133,19 +133,19 @@ class ExchangeBase(Logger):
|
|||||||
if h is None or h['timestamp'] < time.time() - 24*3600:
|
if h is None or h['timestamp'] < time.time() - 24*3600:
|
||||||
asyncio.get_event_loop().create_task(self.get_historical_rates_safe(ccy, cache_dir))
|
asyncio.get_event_loop().create_task(self.get_historical_rates_safe(ccy, cache_dir))
|
||||||
|
|
||||||
def history_ccys(self):
|
def history_ccys(self) -> Sequence[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def historical_rate(self, ccy, d_t):
|
def historical_rate(self, ccy: str, d_t: datetime) -> Union[str, float, Decimal]:
|
||||||
return self.history.get(ccy, {}).get(d_t.strftime('%Y-%m-%d'), 'NaN')
|
return self.history.get(ccy, {}).get(d_t.strftime('%Y-%m-%d'), 'NaN')
|
||||||
|
|
||||||
async def request_history(self, ccy):
|
async def request_history(self, ccy: str) -> Dict[str, Union[str, float, Decimal]]:
|
||||||
raise NotImplementedError() # implemented by subclasses
|
raise NotImplementedError() # implemented by subclasses
|
||||||
|
|
||||||
async def get_rates(self, ccy):
|
async def get_rates(self, ccy: str) -> Mapping[str, Union[str, float, Decimal, None]]:
|
||||||
raise NotImplementedError() # implemented by subclasses
|
raise NotImplementedError() # implemented by subclasses
|
||||||
|
|
||||||
async def get_currencies(self):
|
async def get_currencies(self) -> Sequence[str]:
|
||||||
rates = await self.get_rates('')
|
rates = await self.get_rates('')
|
||||||
return sorted([str(a) for (a, b) in rates.items() if b is not None and len(a)==3])
|
return sorted([str(a) for (a, b) in rates.items() if b is not None and len(a)==3])
|
||||||
|
|
||||||
@@ -489,7 +489,7 @@ class FxThread(ThreadJob):
|
|||||||
self.history_used_spot = False
|
self.history_used_spot = False
|
||||||
self.ccy_combo = None
|
self.ccy_combo = None
|
||||||
self.hist_checkbox = None
|
self.hist_checkbox = None
|
||||||
self.cache_dir = os.path.join(config.path, 'cache')
|
self.cache_dir = os.path.join(config.path, 'cache') # type: str
|
||||||
self._trigger = asyncio.Event()
|
self._trigger = asyncio.Event()
|
||||||
self._trigger.set()
|
self._trigger.set()
|
||||||
self.set_exchange(self.config_exchange())
|
self.set_exchange(self.config_exchange())
|
||||||
|
|||||||
Reference in New Issue
Block a user