fix numpy>=1.20 typing check (#323)

Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch, and list is automatically converted to np.ndarray (if the conversion fails, the exception is raised instead of converting each element of the list individually).
n+e authored on 2021-03-30 16:06:03 +08:00, committed by GitHub
parent 6426a39796
commit 09692c84fe
28 changed files with 212 additions and 265 deletions
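For illustration, a minimal sketch of the new contract (exercising the same public API as the tests below):

    import numpy as np
    import torch
    from tianshou.data import Batch, to_numpy, to_torch

    assert isinstance(to_numpy({"a": np.zeros(3)}), Batch)             # dict -> Batch
    assert isinstance(to_torch([np.zeros((3, 3))] * 2), torch.Tensor)  # list -> Tensor

    # a ragged list is no longer converted element by element;
    # the underlying exception now propagates
    try:
        to_torch([np.zeros((2, 3)), np.zeros((3, 3))])
    except TypeError:
        pass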

View File

@ -47,7 +47,7 @@ setup(
install_requires=[
"gym>=0.15.4",
"tqdm",
"numpy!=1.16.0,<1.20.0", # https://github.com/numpy/numpy/issues/12793
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",

View File

@ -20,9 +20,9 @@ def test_batch():
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
assert b.c.dtype == np.object
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
assert b.d.dtype == b.e.dtype == np.object and b.f == Batch
assert b.d.dtype == b.e.dtype == object and b.f == Batch
b = Batch()
b.update()
assert b.is_empty()
@ -153,10 +153,10 @@ def test_batch():
batch3[0] = Batch(a={"c": 2, "e": 1})
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == np.object # auto convert to np.object
assert batch4.a.dtype == object # auto convert to object
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == np.object # auto convert to np.object
assert batch4.a.dtype == object # auto convert to object
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
@ -405,21 +405,23 @@ def test_utils_to_torch_numpy():
assert data_list_2_torch.shape == (2, 3, 3)
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
data_list_3_torch = to_torch(data_list_3)
assert isinstance(data_list_3_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_3)
with pytest.raises(TypeError):
to_numpy(data_list_3_torch)
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
data_list_4_torch = to_torch(data_list_4)
assert isinstance(data_list_4_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_4)
with pytest.raises(TypeError):
to_numpy(data_list_4_torch)
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
data_list_5_torch = to_torch(data_list_5)
assert isinstance(data_list_5_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_5)
with pytest.raises(TypeError):
to_numpy(data_list_5_torch)
data_array = np.random.rand(3, 2, 2)
data_empty_tensor = to_torch(data_array[[]])
assert isinstance(data_empty_tensor, torch.Tensor)
@ -508,10 +510,10 @@ def test_batch_empty():
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': np.array([2., 'st'], dtype=np.object),
b={'c': np.array([2., 'st'], dtype=object),
'd': [1, None],
'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
c=np.array([1, 3, 4], dtype=int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])

View File

@ -33,7 +33,7 @@ def test_replaybuffer(size=10, bufsize=20):
done=done, obs_next=obs_next, info=info))
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert buf.act.dtype == np.int
assert buf.act.dtype == int
assert buf.act.shape == (bufsize, 1)
data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all()
@ -50,9 +50,9 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert b.info.a[0] == 3 and b.info.a.dtype == int
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
assert np.all(b.info.b.c[1:] == 0.0)
assert ptr.shape == (1,) and ptr[0] == 0
assert ep_rew.shape == (1,) and ep_rew[0] == 1
@ -180,8 +180,8 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
assert len(buf2) == min(bufsize, 3 * (i + 1))
# check single buffer's data
assert buf.info.key.shape == (buf.maxsize,)
assert buf.rew.dtype == np.float
assert buf.done.dtype == np.bool_
assert buf.rew.dtype == float
assert buf.done.dtype == bool
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight / 2)
assert np.allclose(buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
@ -273,7 +273,7 @@ def test_segtree():
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, np.int)
naive = np.ones(actual_len, int)
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
@ -485,7 +485,7 @@ def test_replaybuffermanager():
buf.set_batch(batch)
assert np.allclose(buf.buffers[-1].info, [1] * 5)
assert buf.sample_index(-1).tolist() == []
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object
def test_cachedbuffer():

View File

@ -7,16 +7,18 @@ from numbers import Number
from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence
IndexType = Union[slice, int, np.ndarray, List[int]]
def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with np.object type,
# or 1-D np.ndarray with object type,
# where each element is a dict/Batch object
if isinstance(data, np.ndarray): # most often case
# "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
return data.dtype == np.object and all(
return data.dtype == object and all(
isinstance(e, (dict, Batch)) for e in data)
elif isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
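As a reminder of what a "batch set" enables (a sketch against the public Batch API):

    import numpy as np
    from tianshou.data import Batch

    b = Batch([{"a": 1}, {"a": 2}])  # a list of dicts is stacked into one Batch
    assert np.allclose(b.a, [1, 2])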
@ -50,13 +52,13 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)):
return v # most often case
# convert the value to np.ndarray
# convert to np.object data type if neither bool nor number
# convert to object data type if neither bool nor number
# raises an exception if the array's elements are tensors themselves
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
if v.dtype == np.object:
# scalar ndarray with np.object data type is very annoying
v = v.astype(object)
if v.dtype == object:
# scalar ndarray with object data type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
@ -87,13 +89,11 @@ def _create_value(
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
target_type = inst.dtype.type if issubclass(
inst.dtype.type, (np.bool_, np.number)) else object
return np.full(
shape,
fill_value=None if target_type == np.object else 0,
fill_value=None if target_type == object else 0,
dtype=target_type
)
elif isinstance(inst, torch.Tensor):
@ -105,8 +105,8 @@ def _create_value(
return zero_batch
elif is_scalar:
return _create_value(np.asarray(inst), size, stack=stack)
else: # fall back to np.object
return np.array([None for _ in range(size)])
else: # fall back to object
return np.array([None for _ in range(size)], object)
def _assert_type_keys(keys: Iterable[str]) -> None:
@ -187,7 +187,7 @@ class Batch:
for k, v in batch_dict.items():
self.__dict__[k] = _parse_value(v)
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
self.stack_(batch_dict) # type: ignore
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy) # type: ignore
@ -223,9 +223,7 @@ class Batch:
"""
self.__init__(**state) # type: ignore
def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Any:
def __getitem__(self, index: Union[str, IndexType]) -> Any:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
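The new IndexType alias covers every non-string index the old signatures spelled out; a sketch:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.arange(5))
    assert b[0].a == 0                      # int
    assert len(b[1:3].a) == 2               # slice
    assert len(b[[0, 2, 4]].a) == 3         # List[int]
    assert len(b[np.array([1, 2])].a) == 2  # np.ndarray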
@ -241,11 +239,7 @@ class Batch:
else:
raise IndexError("Cannot access item from empty Batch object.")
def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any,
) -> None:
def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
"""Assign value to self[index]."""
value = _parse_value(value)
if isinstance(index, str):
@ -530,8 +524,7 @@ class Batch:
elif all(isinstance(e, (Batch, dict)) for e in v): # third often
self.__dict__[k] = Batch.stack(v, axis)
else: # most often case is np.ndarray
v = np.stack(v, axis)
self.__dict__[k] = _to_array_with_correct_type(v)
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
# keys that are reserved in all batches
@ -587,9 +580,7 @@ class Batch:
batch.stack_(batches, axis)
return batch
def empty_(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> "Batch":
def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch":
"""Return an empty Batch object with 0 or None filled.
If "index" is specified, it will only reset the specific indexed-data.
@ -620,7 +611,7 @@ class Batch:
elif v is None:
continue
elif isinstance(v, np.ndarray):
if v.dtype == np.object:
if v.dtype == object:
self.__dict__[k][index] = None
else:
self.__dict__[k][index] = 0
@ -636,10 +627,7 @@ class Batch:
return self
@staticmethod
def empty(
batch: "Batch",
index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None,
) -> "Batch":
def empty(batch: "Batch", index: Optional[IndexType] = None) -> "Batch":
"""Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
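A sketch of the fill rule (numeric entries become 0, object entries become None):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=np.array([1.0, 2.0]), s=np.array(["x", "y"], dtype=object))
    b.empty_()  # in place; Batch.empty(b) is the out-of-place variant
    assert np.allclose(b.a, [0.0, 0.0]) and list(b.s) == [None, None]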

View File

@ -115,9 +115,9 @@ class ReplayBuffer:
def unfinished_index(self) -> np.ndarray:
"""Return the index of unfinished episode."""
last = (self._index - 1) % self._size if self._size else 0
return np.array([last] if not self.done[last] and self._size else [], np.int)
return np.array([last] if not self.done[last] and self._size else [], int)
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
"""Return the index of previous transition.
The index won't be modified if it is the beginning of an episode.
@ -126,7 +126,7 @@ class ReplayBuffer:
end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
"""Return the index of next transition.
The index won't be modified if it is the end of an episode.
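A sketch of the episode-boundary behavior (single-transition add, as in the tests above):

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=10)
    for i in range(3):
        buf.add(Batch(obs=i, act=0, rew=0.0, done=(i == 2),
                      obs_next=i + 1, info={}))
    assert buf.prev(np.array([0, 1, 2])).tolist() == [0, 0, 1]  # stops at episode start
    assert buf.next(np.array([0, 1, 2])).tolist() == [1, 2, 2]  # stops at episode end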
@ -140,12 +140,12 @@ class ReplayBuffer:
Return the updated indices. If update fails, return an empty array.
"""
if len(buffer) == 0 or self.maxsize == 0:
return np.array([], np.int)
return np.array([], int)
stack_num, buffer.stack_num = buffer.stack_num, 1
from_indices = buffer.sample_index(0) # get all available indices
buffer.stack_num = stack_num
if len(from_indices) == 0:
return np.array([], np.int)
return np.array([], int)
to_indices = []
for _ in range(len(from_indices)):
to_indices.append(self._index)
@ -224,8 +224,8 @@ class ReplayBuffer:
self._meta[ptr] = batch
except ValueError:
stack = not stacked_batch
batch.rew = batch.rew.astype(np.float)
batch.done = batch.done.astype(np.bool_)
batch.rew = batch.rew.astype(float)
batch.done = batch.done.astype(bool)
if self._meta.is_empty():
self._meta = _create_value( # type: ignore
batch, self.maxsize, stack)
@ -248,10 +248,10 @@ class ReplayBuffer:
[np.arange(self._index, self._size), np.arange(self._index)]
)
else:
return np.array([], np.int)
return np.array([], int)
else:
if batch_size < 0:
return np.array([], np.int)
return np.array([], int)
all_indices = prev_indices = np.concatenate(
[np.arange(self._index, self._size), np.arange(self._index)]
)
@ -275,9 +275,9 @@ class ReplayBuffer:
def get(
self,
index: Union[int, np.integer, np.ndarray],
index: Union[int, List[int], np.ndarray],
key: str,
default_value: Optional[Any] = None,
default_value: Any = None,
stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
"""Return the stacked result.
@ -303,7 +303,7 @@ class ReplayBuffer:
if isinstance(index, list):
indice = np.array(index)
else:
indice = index
indice = index # type: ignore
for _ in range(stack_num):
stack = [val[indice]] + stack
indice = self.prev(indice)
@ -316,30 +316,31 @@ class ReplayBuffer:
raise e # val != Batch()
return Batch()
def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
"""Return a data batch: self[index].
If stack_num is larger than 1, return the stacked obs and obs_next with shape
(batch, len, ...).
"""
if isinstance(index, slice): # change slice to np array
if index == slice(None): # buffer[:] will get all available data
index = self.sample_index(0)
else:
index = self._indices[:len(self)][index]
# buffer[:] will get all available data
indice = self.sample_index(0) if index == slice(None) \
else self._indices[:len(self)][index]
else:
indice = index
# raise KeyError first instead of AttributeError,
# to support np.array([ReplayBuffer()])
obs = self.get(index, "obs")
obs = self.get(indice, "obs")
if self._save_obs_next:
obs_next = self.get(index, "obs_next", Batch())
obs_next = self.get(indice, "obs_next", Batch())
else:
obs_next = self.get(self.next(index), "obs", Batch())
obs_next = self.get(self.next(indice), "obs", Batch())
return Batch(
obs=obs,
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
act=self.act[indice],
rew=self.rew[indice],
done=self.done[indice],
obs_next=obs_next,
info=self.get(index, "info", Batch()),
policy=self.get(index, "policy", Batch()),
info=self.get(indice, "info", Batch()),
policy=self.get(indice, "policy", Batch()),
)
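buffer[:] therefore returns all available transitions in oldest-to-newest order, even after the circular storage wraps; a sketch:

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=5)
    for i in range(7):  # wraps around the 5-slot circular storage
        buf.add(Batch(obs=i, act=0, rew=0.0, done=False, obs_next=i + 1, info={}))
    assert buf[:].obs.tolist() == [2, 3, 4, 5, 6]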

View File

@ -58,14 +58,14 @@ class CachedReplayBuffer(ReplayBufferManager):
cached_buffer_ids[i]th cached buffer's corresponding episode result.
"""
if buffer_ids is None:
buffer_ids = np.arange(1, 1 + self.cached_buffer_num)
buf_arr = np.arange(1, 1 + self.cached_buffer_num)
else: # make sure it is np.ndarray
buffer_ids = np.asarray(buffer_ids) + 1
ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buffer_ids)
buf_arr = np.asarray(buffer_ids) + 1
ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr)
# find the terminated episode, move data from cached buf to main buf
updated_ptr, updated_ep_idx = [], []
done = batch.done.astype(np.bool_)
for buffer_idx in buffer_ids[done]:
done = batch.done.astype(bool)
for buffer_idx in buf_arr[done]:
index = self.main_buffer.update(self.buffers[buffer_idx])
if len(index) == 0: # unsuccessful move, replace with -1
index = [-1]

View File

@ -22,7 +22,7 @@ class ReplayBufferManager(ReplayBuffer):
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
self.buffer_num = len(buffer_list)
self.buffers = np.array(buffer_list, dtype=np.object)
self.buffers = np.array(buffer_list, dtype=object)
offset, size = [], 0
buffer_type = type(self.buffers[0])
kwargs = self.buffers[0].options
@ -46,7 +46,7 @@ class ReplayBufferManager(ReplayBuffer):
_next_index(index, offset, done, last, lens)
def __len__(self) -> int:
return self._lengths.sum()
return int(self._lengths.sum())
def reset(self, keep_statistics: bool = False) -> None:
self.last_index = self._offset.copy()
@ -68,7 +68,7 @@ class ReplayBufferManager(ReplayBuffer):
for offset, buf in zip(self._offset, self.buffers)
])
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
if isinstance(index, (list, np.ndarray)):
return _prev_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
@ -76,7 +76,7 @@ class ReplayBufferManager(ReplayBuffer):
return _prev_index(np.array([index]), self._extend_offset,
self.done, self.last_index, self._lengths)[0]
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
if isinstance(index, (list, np.ndarray)):
return _next_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
@ -130,8 +130,8 @@ class ReplayBufferManager(ReplayBuffer):
try:
self._meta[ptrs] = batch
except ValueError:
batch.rew = batch.rew.astype(np.float)
batch.done = batch.done.astype(np.bool_)
batch.rew = batch.rew.astype(float)
batch.done = batch.done.astype(bool)
if self._meta.is_empty():
self._meta = _create_value( # type: ignore
batch, self.maxsize, stack=False)
@ -143,7 +143,7 @@ class ReplayBufferManager(ReplayBuffer):
def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
return np.array([], np.int)
return np.array([], int)
if self._sample_avail and self.stack_num > 1:
all_indices = np.concatenate([
buf.sample_index(0) + offset
@ -154,7 +154,7 @@ class ReplayBufferManager(ReplayBuffer):
else:
return np.random.choice(all_indices, batch_size)
if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int)
sample_num = np.zeros(self.buffer_num, int)
else:
buffer_idx = np.random.choice(
self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()

View File

@ -34,6 +34,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def update(self, buffer: ReplayBuffer) -> np.ndarray:
indices = super().update(buffer)
self.init_weight(indices)
return indices
def add(
self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
@ -45,13 +46,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size > 0 and len(self) > 0:
scalar = np.random.rand(batch_size) * self.weight.reduce()
return self.weight.get_prefix_sum_idx(scalar)
return self.weight.get_prefix_sum_idx(scalar) # type: ignore
else:
return super().sample_index(batch_size)
def get_weight(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> np.ndarray:
def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
"""Get the importance sampling weight.
The "weight" in the returned Batch is the weight on loss function to de-bias
@ -76,7 +75,13 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
batch = super().__getitem__(index)
batch.weight = self.get_weight(index)
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
if isinstance(index, slice): # change slice to np array
# buffer[:] will get all available data
indice = self.sample_index(0) if index == slice(None) \
else self._indices[:len(self)][index]
else:
indice = index
batch = super().__getitem__(indice)
batch.weight = self.get_weight(indice)
return batch
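With the same slice handling, prioritized buffers now attach importance-sampling weights for any index form; a sketch:

    import numpy as np
    from tianshou.data import Batch, PrioritizedReplayBuffer

    buf = PrioritizedReplayBuffer(size=4, alpha=0.6, beta=0.4)
    for i in range(4):
        buf.add(Batch(obs=i, act=0, rew=0.0, done=False, obs_next=i + 1, info={}))
    batch = buf[:]  # slice resolved exactly like ReplayBuffer.__getitem__
    assert batch.weight.shape == (4,)
    buf.update_weight(np.arange(4), np.abs(np.random.randn(4)) + 0.1)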

View File

@ -123,7 +123,7 @@ class Collector(object):
if isinstance(state, torch.Tensor):
state[id].zero_()
elif isinstance(state, np.ndarray):
state[id] = None if state.dtype == np.object else 0
state[id] = None if state.dtype == object else 0
elif isinstance(state, Batch):
state.empty_(id)
@ -266,7 +266,7 @@ class Collector(object):
if n_episode:
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
if surplus_env_num > 0:
mask = np.ones_like(ready_env_ids, np.bool)
mask = np.ones_like(ready_env_ids, dtype=bool)
mask[env_ind_local[:surplus_env_num]] = False
ready_env_ids = ready_env_ids[mask]
self.data = self.data[mask]
@ -291,7 +291,7 @@ class Collector(object):
rews, lens, idxs = list(map(
np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
else:
rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int)
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
return {
"n/ep": episode_count,
@ -493,7 +493,7 @@ class AsyncCollector(Collector):
rews, lens, idxs = list(map(
np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
else:
rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int)
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
return {
"n/ep": episode_count,

View File

@ -4,15 +4,12 @@ import pickle
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Dict, Union, Optional
from typing import Any, Dict, Union, Optional
from tianshou.data.batch import _parse_value, Batch
def to_numpy(
x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number,
np.ndarray, torch.Tensor]]
) -> Union[Batch, dict, list, tuple, np.ndarray]:
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor): # most often case
return x.detach().cpu().numpy()
@ -21,28 +18,22 @@ def to_numpy(
elif isinstance(x, (np.number, np.bool_, Number)):
return np.asanyarray(x)
elif x is None:
return np.array(None, dtype=np.object)
elif isinstance(x, Batch):
x = deepcopy(x)
return np.array(None, dtype=object)
elif isinstance(x, (dict, Batch)):
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
x.to_numpy()
return x
elif isinstance(x, dict):
return {k: to_numpy(v) for k, v in x.items()}
elif isinstance(x, (list, tuple)):
try:
return to_numpy(_parse_value(x))
except TypeError:
return [to_numpy(e) for e in x]
return to_numpy(_parse_value(x))
else: # fallback
return np.asanyarray(x)
def to_torch(
x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray,
torch.Tensor],
x: Any,
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = "cpu",
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
) -> Union[Batch, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
@ -57,25 +48,17 @@ def to_torch(
return x.to(device) # type: ignore
elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, dict):
return {k: to_torch(v, dtype, device) for k, v in x.items()}
elif isinstance(x, Batch):
x = deepcopy(x)
elif isinstance(x, (dict, Batch)):
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
x.to_torch(dtype, device)
return x
elif isinstance(x, (list, tuple)):
try:
return to_torch(_parse_value(x), dtype, device)
except TypeError:
return [to_torch(e, dtype, device) for e in x]
return to_torch(_parse_value(x), dtype, device)
else: # fallback
raise TypeError(f"object {x} cannot be converted to torch.")
def to_torch_as(
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
y: torch.Tensor,
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
"""Return an object without np.ndarray.
Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
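For example (a sketch):

    import numpy as np
    import torch
    from tianshou.data import to_torch_as

    y = torch.zeros(3, dtype=torch.float64)
    x = to_torch_as(np.arange(3), y)  # inherits y's dtype and device
    assert x.dtype == torch.float64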
@ -147,25 +130,20 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
y[k].attrs["__data_type__"] = v.__class__.__name__
def from_hdf5(
x: h5py.Group, device: Optional[str] = None
) -> Hdf5ConvertibleType:
def from_hdf5(x: h5py.Group, device: Optional[str] = None) -> Hdf5ConvertibleValues:
"""Restore object from HDF5 group."""
if isinstance(x, h5py.Dataset):
# handle datasets
if x.attrs["__data_type__"] == "ndarray":
y = np.array(x)
return np.array(x)
elif x.attrs["__data_type__"] == "Tensor":
y = torch.tensor(x, device=device)
return torch.tensor(x, device=device)
else:
y = pickle.loads(x[()])
return pickle.loads(x[()])
else:
# handle groups representing a dict or a Batch
y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
y = dict(x.attrs.items())
data_type = y.pop("__data_type__", None)
for k, v in x.items():
y[k] = from_hdf5(v, device)
if "__data_type__" in x.attrs:
# if dictionary represents Batch, convert to Batch
if x.attrs["__data_type__"] == "Batch":
y = Batch(y)
return y
return Batch(y) if data_type == "Batch" else y
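A round-trip sketch (assuming to_hdf5/from_hdf5 live in tianshou.data.utils.converter, the module shown here):

    import h5py
    import numpy as np
    import torch
    from tianshou.data import Batch
    from tianshou.data.utils.converter import to_hdf5, from_hdf5

    b = Batch(obs=np.zeros((3, 4)), act=torch.zeros(3))
    with h5py.File("demo.h5", "w") as f:
        to_hdf5({"data": b}, f)  # top level is a dict of convertibles
    with h5py.File("demo.h5", "r") as f:
        restored = from_hdf5(f, device="cpu")
    assert isinstance(restored["data"], Batch)  # subgroup tagged __data_type__ == "Batch"
    assert isinstance(restored["data"].act, torch.Tensor)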

tianshou/env/venvs.py
View File

@ -140,12 +140,10 @@ class BaseVectorEnv(gym.Env):
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> Union[List[int], np.ndarray]:
if id is None:
id = list(range(self.env_num))
elif np.isscalar(id):
id = [id]
return id
return list(range(self.env_num))
return [id] if np.isscalar(id) else id # type: ignore
def _assert_id(self, id: List[int]) -> None:
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
for i in id:
assert i not in self.waiting_id, \
f"Cannot interact with environment {i} which is stepping now."
@ -291,7 +289,7 @@ class BaseVectorEnv(gym.Env):
clip_max = 10.0 # this magic number is from openai baselines
# see baselines/common/vec_env/vec_normalize.py#L10
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max)
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
return obs
def __del__(self) -> None:

View File

@ -25,9 +25,7 @@ class EnvWorker(ABC):
def send_action(self, action: np.ndarray) -> None:
pass
def get_result(
self,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
return self.result
def step(
@ -45,9 +43,7 @@ class EnvWorker(ABC):
@staticmethod
def wait(
workers: List["EnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
workers: List["EnvWorker"], wait_num: int, timeout: Optional[float] = None
) -> List["EnvWorker"]:
"""Given a list of workers, return those ready ones."""
raise NotImplementedError

View File

@ -20,9 +20,7 @@ class DummyEnvWorker(EnvWorker):
@staticmethod
def wait( # type: ignore
workers: List["DummyEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None
) -> List["DummyEnvWorker"]:
# Sequential EnvWorker objects are always ready
return workers

View File

@ -25,9 +25,7 @@ class RayEnvWorker(EnvWorker):
@staticmethod
def wait( # type: ignore
workers: List["RayEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None
) -> List["RayEnvWorker"]:
results = [x.result for x in workers]
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)

View File

@ -12,7 +12,6 @@ from tianshou.env.utils import CloudpickleWrapper
_NP_TO_CT = {
np.bool: ctypes.c_bool,
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
@ -31,7 +30,7 @@ class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
self.dtype = dtype
self.shape = shape
@ -64,8 +63,7 @@ def _worker(
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
def _encode_obs(
obs: Union[dict, tuple, np.ndarray],
buffer: Union[dict, tuple, ShArray],
obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
) -> None:
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
buffer.save(obs)

View File

@ -68,9 +68,7 @@ class OUNoise(BaseNoise):
"""Reset to the initial state."""
self._x = self._x0
def __call__(
self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray:
def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarray:
"""Generate new noise.
Return a numpy array whose size is equal to ``size``.
@ -82,4 +80,4 @@ class OUNoise(BaseNoise):
mu = self._mu
r = self._beta * np.random.normal(size=size)
self._x = self._x + self._alpha * (mu - self._x) + r
return self._x
return self._x # type: ignore
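Usage sketch:

    import numpy as np
    from tianshou.exploration import OUNoise

    noise = OUNoise()  # Ornstein-Uhlenbeck process with default parameters
    sample = noise((4,))
    assert isinstance(sample, np.ndarray) and sample.shape == (4,)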

View File

@ -142,14 +142,14 @@ class BasePolicy(ABC, nn.Module):
isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
if self.action_bound_method == "clip":
act = np.clip(act, -1.0, 1.0)
act = np.clip(act, -1.0, 1.0) # type: ignore
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:
assert np.all(act >= -1.0) and np.all(act <= 1.0), \
assert np.min(act) >= -1.0 and np.max(act) <= 1.0, \
"action scaling only accepts raw action range = [-1, 1]"
low, high = self.action_space.low, self.action_space.high
act = low + (high - low) * (act + 1.0) / 2.0
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
return act
def process_fn(
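The mapping composes a bound method with an affine rescale; numerically (a pure-numpy sketch, with hypothetical bounds):

    import numpy as np

    low, high = np.array([0.0]), np.array([4.0])     # hypothetical action_space bounds
    act = np.clip(np.array([1.5]), -1.0, 1.0)        # action_bound_method == "clip"
    scaled = low + (high - low) * (act + 1.0) / 2.0  # maps [-1, 1] onto [low, high]
    assert np.allclose(scaled, [4.0])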
@ -241,9 +241,9 @@ class BasePolicy(ABC, nn.Module):
:return: A bool-typed numpy.ndarray with the same shape as indice. "True" means
"obs_next" of that buffer[indice] is valid.
"""
mask = ~buffer.done[indice].astype(np.bool)
# info['TimeLimit.truncated'] will be set to True if 'done' flag is generated
# because of timelimit of environments. Checkout gym.wrappers.TimeLimit.
mask = ~buffer.done[indice]
# info["TimeLimit.truncated"] will be True if "done" flag is generated by
# timelimit of environments. Checkout gym.wrappers.TimeLimit.
if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
mask = mask | buffer.info['TimeLimit.truncated'][indice]
return mask
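The mask logic in isolation (a sketch):

    import numpy as np

    done = np.array([False, True, True])
    truncated = np.array([False, False, True])  # info["TimeLimit.truncated"]
    mask = ~done | truncated                    # bootstrap unless truly terminal
    assert mask.tolist() == [True, False, True]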
@ -281,7 +281,8 @@ class BasePolicy(ABC, nn.Module):
assert np.isclose(gae_lambda, 1.0)
v_s_ = np.zeros_like(rew)
else:
v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
v_s_ = to_numpy(v_s_.flatten()) # type: ignore
v_s_ = v_s_ * BasePolicy.value_mask(buffer, indice)
v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
end_flag = batch.done.copy()

View File

@ -58,7 +58,7 @@ class DiscreteBCQPolicy(DQNPolicy):
else:
self._log_tau = -np.inf
assert 0.0 <= eval_eps < 1.0
self._eps = eval_eps
self.eps = eval_eps
self._weight_reg = imitation_logits_penalty
def train(self, mode: bool = True) -> "DiscreteBCQPolicy":
@ -96,15 +96,6 @@ class DiscreteBCQPolicy(DQNPolicy):
return Batch(act=action, state=state, q_value=q_value,
imitation_logits=imitation_logits)
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
# add eps to act
if not np.isclose(self._eps, 0.0):
bsz = len(act)
mask = np.random.rand(bsz) < self._eps
act_rand = np.random.randint(self.max_action_num, size=[bsz])
act[mask] = act_rand[mask]
return act
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._iter % self._freq == 0:
self.sync_weight()

View File

@ -1,6 +1,6 @@
import torch
import numpy as np
from typing import Any, Dict, Union, Optional
from typing import Any, Dict, Tuple, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
@ -100,7 +100,7 @@ class PSRLModel(object):
discount_factor: float,
eps: float,
value: np.ndarray,
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray]:
"""Value iteration solver for MDPs.
:param np.ndarray trans_prob: transition probabilities, with shape
@ -126,7 +126,7 @@ class PSRLModel(object):
def __call__(
self,
obs: np.ndarray,
state: Optional[Any] = None,
state: Any = None,
info: Dict[str, Any] = {},
) -> np.ndarray:
if not self.updated:
@ -215,6 +215,6 @@ class PSRLPolicy(BasePolicy):
rew_count[obs_next, :] += 1
self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
return {
"psrl/rew_mean": self.model.rew_mean.mean(),
"psrl/rew_std": self.model.rew_std.mean(),
"psrl/rew_mean": float(self.model.rew_mean.mean()),
"psrl/rew_std": float(self.model.rew_std.mean()),
}

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from typing import Any, Dict, List, Type, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class A2CPolicy(PGPolicy):
@ -84,8 +84,8 @@ class A2CPolicy(PGPolicy):
v_s.append(self.critic(b.obs))
v_s_.append(self.critic(b.obs_next))
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
v_s = to_numpy(batch.v_s)
v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
v_s = batch.v_s.cpu().numpy()
v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
# when normalizing values, we do not subtract self.ret_rms.mean, to stay numerically
# consistent with OpenAI Baselines' value normalization pipeline. Empirical
# study also shows that subtracting the mean harms performance a tiny bit

View File

@ -1,4 +1,5 @@
import torch
import warnings
import numpy as np
from copy import deepcopy
from typing import Any, Dict, Tuple, Union, Optional
@ -167,7 +168,12 @@ class DDPGPolicy(BasePolicy):
"loss/critic": critic_loss.item(),
}
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
if self._noise:
act = act + self._noise(act.shape)
def exploration_noise(
self, act: Union[np.ndarray, Batch], batch: Batch
) -> Union[np.ndarray, Batch]:
if self._noise is None:
return act
if isinstance(act, np.ndarray):
return act + self._noise(act.shape)
warnings.warn("Cannot add exploration noise to non-numpy_array action.")
return act

View File

@ -168,8 +168,10 @@ class DQNPolicy(BasePolicy):
self._iter += 1
return {"loss": loss.item()}
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
if not np.isclose(self.eps, 0.0):
def exploration_noise(
self, act: Union[np.ndarray, Batch], batch: Batch
) -> Union[np.ndarray, Batch]:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
q = np.random.rand(bsz, self.max_action_num) # [0, 1]
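The epsilon-greedy replacement pattern in full (a pure-numpy sketch outside the class):

    import numpy as np

    eps, bsz, num_actions = 0.1, 5, 4
    act = np.zeros(bsz, dtype=int)                # greedy actions
    rand_mask = np.random.rand(bsz) < eps
    q = np.random.rand(bsz, num_actions)          # random scores in [0, 1]
    act[rand_mask] = q.argmax(axis=1)[rand_mask]  # replace only the masked subset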

View File

@ -1,5 +1,5 @@
import numpy as np
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer
@ -71,7 +71,7 @@ class MultiAgentPolicyManager(BasePolicy):
act[agent_index], batch[agent_index])
return act
def forward(
def forward( # type: ignore
self,
batch: Batch,
state: Optional[Union[dict, Batch]] = None,
@ -100,7 +100,8 @@ class MultiAgentPolicyManager(BasePolicy):
"agent_n": xxx}
}
"""
results = []
results: List[Tuple[bool, np.ndarray, Batch,
Union[np.ndarray, Batch], Batch]] = []
for policy in self.policies:
# This part of the code is difficult to understand.
# Let's follow an example with two agents
@ -112,7 +113,7 @@ class MultiAgentPolicyManager(BasePolicy):
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
if len(agent_index) == 0:
# (has_data, agent_index, out, act, state)
results.append((False, None, Batch(), None, Batch()))
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray):

View File

@ -14,16 +14,12 @@ class BaseLogger(ABC):
@abstractmethod
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
"""Specify how the writer is used to log data.
:param key: namespace which the input data tuple belongs to.
:param x: stands for the ordinate of the input data tuple.
:param str key: namespace to which the input data tuple belongs.
:param int x: stands for the abscissa (global step) of the input data tuple.
:param y: stands for the ordinate (logged value) of the input data tuple.
"""
pass
@ -84,11 +80,7 @@ class BasicLogger(BaseLogger):
self.last_log_update_step = -1
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
self.writer.add_scalar(key, y, global_step=x)
@ -149,11 +141,7 @@ class LazyLogger(BasicLogger):
super().__init__(None) # type: ignore
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
"""The LazyLogger writes nothing."""
pass
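Usage of the tightened signature (a sketch; BasicLogger wraps a tensorboard SummaryWriter):

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import BasicLogger

    logger = BasicLogger(SummaryWriter("log/demo"))
    logger.write("train/reward", 5, 1.5)  # x is the int global step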

View File

@ -50,8 +50,7 @@ class MLP(nn.Module):
output_dim: int = 0,
hidden_sizes: Sequence[int] = (),
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
activation: Optional[Union[ModuleType, Sequence[ModuleType]]]
= nn.ReLU,
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
device: Optional[Union[str, int, torch.device]] = None,
) -> None:
super().__init__()
@ -139,7 +138,7 @@ class Net(nn.Module):
def __init__(
self,
state_shape: Union[int, Sequence[int]],
action_shape: Optional[Union[int, Sequence[int]]] = 0,
action_shape: Union[int, Sequence[int]] = 0,
hidden_sizes: Sequence[int] = (),
norm_layer: Optional[ModuleType] = None,
activation: Optional[ModuleType] = nn.ReLU,
@ -153,8 +152,8 @@ class Net(nn.Module):
self.device = device
self.softmax = softmax
self.num_atoms = num_atoms
input_dim = np.prod(state_shape)
action_dim = np.prod(action_shape) * num_atoms
input_dim = int(np.prod(state_shape))
action_dim = int(np.prod(action_shape)) * num_atoms
if concat:
input_dim += action_dim
self.use_dueling = dueling_param is not None
@ -179,7 +178,7 @@ class Net(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> flatten (inside MLP)-> logits."""
@ -221,8 +220,8 @@ class Recurrent(nn.Module):
num_layers=layer_num,
batch_first=True,
)
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size)
self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape)))
def forward(
self,

View File

@ -46,7 +46,7 @@ class Actor(nn.Module):
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = np.prod(action_shape)
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim,
@ -56,7 +56,7 @@ class Actor(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> logits -> action."""
@ -162,7 +162,7 @@ class ActorProb(nn.Module):
super().__init__()
self.preprocess = preprocess_net
self.device = device
self.output_dim = np.prod(action_shape)
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim,
@ -179,7 +179,7 @@ class ActorProb(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
"""Mapping: s -> logits -> (mu, sigma)."""
@ -219,12 +219,12 @@ class RecurrentActorProb(nn.Module):
super().__init__()
self.device = device
self.nn = nn.LSTM(
input_size=np.prod(state_shape),
input_size=int(np.prod(state_shape)),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
output_dim = np.prod(action_shape)
output_dim = int(np.prod(action_shape))
self.mu = nn.Linear(hidden_layer_size, output_dim)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
@ -293,12 +293,12 @@ class RecurrentCritic(nn.Module):
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(
input_size=np.prod(state_shape),
input_size=int(np.prod(state_shape)),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
def forward(
self,

View File

@ -45,7 +45,7 @@ class Actor(nn.Module):
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = np.prod(action_shape)
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim,
@ -55,7 +55,7 @@ class Actor(nn.Module):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""

View File

@ -3,8 +3,6 @@ import numpy as np
from numbers import Number
from typing import List, Union
from tianshou.data import to_numpy
class MovAvg(object):
"""Class for moving average.
@ -28,44 +26,43 @@ class MovAvg(object):
def __init__(self, size: int = 100) -> None:
super().__init__()
self.size = size
self.cache: List[Union[Number, np.number]] = []
self.cache: List[np.number] = []
self.banned = [np.inf, np.nan, -np.inf]
def add(
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
) -> np.number:
) -> float:
"""Add a scalar into :class:`MovAvg`.
You can add a ``torch.Tensor`` with only one element, a python scalar, or
a list of python scalars.
"""
if isinstance(x, torch.Tensor):
x = to_numpy(x.flatten())
if isinstance(x, list) or isinstance(x, np.ndarray):
for i in x:
if i not in self.banned:
self.cache.append(i)
elif x not in self.banned:
self.cache.append(x)
x = x.flatten().cpu().numpy()
if np.isscalar(x):
x = [x]
for i in x: # type: ignore
if i not in self.banned:
self.cache.append(i)
if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size:]
return self.get()
def get(self) -> np.number:
def get(self) -> float:
"""Get the average."""
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
return 0.0
return float(np.mean(self.cache))
def mean(self) -> np.number:
def mean(self) -> float:
"""Get the average. Same as :meth:`get`."""
return self.get()
def std(self) -> np.number:
def std(self) -> float:
"""Get the standard deviation."""
if len(self.cache) == 0:
return 0
return np.std(self.cache)
return 0.0
return float(np.std(self.cache))
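A MovAvg usage sketch under the new float-returning API:

    from tianshou.utils import MovAvg

    stat = MovAvg(size=3)
    stat.add(1.0)
    stat.add([2.0, 3.0])      # lists, ndarrays and one-element tensors all work
    assert stat.get() == 2.0  # mean over the last `size` entries
    assert stat.mean() == stat.get()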
class RunningMeanStd(object):
@ -74,8 +71,10 @@ class RunningMeanStd(object):
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
"""
def __init__(self) -> None:
self.mean, self.var = 0.0, 1.0
def __init__(
self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0
) -> None:
self.mean, self.var = mean, std
self.count = 0
def update(self, x: np.ndarray) -> None:
@ -92,5 +91,5 @@ class RunningMeanStd(object):
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
new_var = m_2 / total_count
self.mean, self.var = new_mean, new_var
self.mean, self.var = new_mean, new_var # type: ignore
self.count = total_count
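RunningMeanStd can now be seeded with an initial mean/std (as used by the vector-env observation normalization above); a sketch, assuming it is exported from tianshou.utils alongside MovAvg:

    import numpy as np
    from tianshou.utils import RunningMeanStd

    rms = RunningMeanStd()
    rms.update(np.random.randn(100, 4))  # parallel-algorithm mean/variance update
    rms.update(np.random.randn(50, 4))
    obs = (np.zeros(4) - rms.mean) / np.sqrt(rms.var + 1e-8)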