fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
This commit is contained in:
parent
6426a39796
commit
09692c84fe
2
setup.py
2
setup.py
@ -47,7 +47,7 @@ setup(
|
||||
install_requires=[
|
||||
"gym>=0.15.4",
|
||||
"tqdm",
|
||||
"numpy!=1.16.0,<1.20.0", # https://github.com/numpy/numpy/issues/12793
|
||||
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
|
||||
"tensorboard",
|
||||
"torch>=1.4.0",
|
||||
"numba>=0.51.0",
|
||||
|
||||
@ -20,9 +20,9 @@ def test_batch():
|
||||
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
|
||||
assert b.c.dtype == np.object
|
||||
assert b.c.dtype == object
|
||||
b = Batch(d=[None], e=[starmap], f=Batch)
|
||||
assert b.d.dtype == b.e.dtype == np.object and b.f == Batch
|
||||
assert b.d.dtype == b.e.dtype == object and b.f == Batch
|
||||
b = Batch()
|
||||
b.update()
|
||||
assert b.is_empty()
|
||||
@ -153,10 +153,10 @@ def test_batch():
|
||||
batch3[0] = Batch(a={"c": 2, "e": 1})
|
||||
# auto convert
|
||||
batch4 = Batch(a=np.array(['a', 'b']))
|
||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||
assert batch4.a.dtype == object # auto convert to object
|
||||
batch4.update(a=np.array(['c', 'd']))
|
||||
assert list(batch4.a) == ['c', 'd']
|
||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||
assert batch4.a.dtype == object # auto convert to object
|
||||
batch5 = Batch(a=np.array([{'index': 0}]))
|
||||
assert isinstance(batch5.a, Batch)
|
||||
assert np.allclose(batch5.a.index, [0])
|
||||
@ -405,21 +405,23 @@ def test_utils_to_torch_numpy():
|
||||
assert data_list_2_torch.shape == (2, 3, 3)
|
||||
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
|
||||
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
|
||||
data_list_3_torch = to_torch(data_list_3)
|
||||
assert isinstance(data_list_3_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
|
||||
assert all(starmap(np.allclose,
|
||||
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
|
||||
data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))]
|
||||
with pytest.raises(TypeError):
|
||||
to_torch(data_list_3)
|
||||
with pytest.raises(TypeError):
|
||||
to_numpy(data_list_3_torch)
|
||||
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
|
||||
data_list_4_torch = to_torch(data_list_4)
|
||||
assert isinstance(data_list_4_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
|
||||
assert all(starmap(np.allclose,
|
||||
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
|
||||
data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))]
|
||||
with pytest.raises(TypeError):
|
||||
to_torch(data_list_4)
|
||||
with pytest.raises(TypeError):
|
||||
to_numpy(data_list_4_torch)
|
||||
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
|
||||
data_list_5_torch = to_torch(data_list_5)
|
||||
assert isinstance(data_list_5_torch, list)
|
||||
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
|
||||
data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))]
|
||||
with pytest.raises(TypeError):
|
||||
to_torch(data_list_5)
|
||||
with pytest.raises(TypeError):
|
||||
to_numpy(data_list_5_torch)
|
||||
data_array = np.random.rand(3, 2, 2)
|
||||
data_empty_tensor = to_torch(data_array[[]])
|
||||
assert isinstance(data_empty_tensor, torch.Tensor)
|
||||
@ -508,10 +510,10 @@ def test_batch_empty():
|
||||
assert np.allclose(b5.b.c, [2, 0])
|
||||
assert np.allclose(b5.b.d, [1, 0])
|
||||
data = Batch(a=[False, True],
|
||||
b={'c': np.array([2., 'st'], dtype=np.object),
|
||||
b={'c': np.array([2., 'st'], dtype=object),
|
||||
'd': [1, None],
|
||||
'e': [2., float('nan')]},
|
||||
c=np.array([1, 3, 4], dtype=np.int),
|
||||
c=np.array([1, 3, 4], dtype=int),
|
||||
t=torch.tensor([4, 5, 6, 7.]))
|
||||
data[-1] = Batch.empty(data[1])
|
||||
assert np.allclose(data.c, [1, 3, 0])
|
||||
|
||||
@ -33,7 +33,7 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
done=done, obs_next=obs_next, info=info))
|
||||
obs = obs_next
|
||||
assert len(buf) == min(bufsize, i + 1)
|
||||
assert buf.act.dtype == np.int
|
||||
assert buf.act.dtype == int
|
||||
assert buf.act.shape == (bufsize, 1)
|
||||
data, indice = buf.sample(bufsize * 2)
|
||||
assert (indice < len(buf)).all()
|
||||
@ -50,9 +50,9 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert b.obs_next[0] == 'str'
|
||||
assert np.all(b.obs[1:] == 0)
|
||||
assert np.all(b.obs_next[1:] == np.array(None))
|
||||
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
|
||||
assert b.info.a[0] == 3 and b.info.a.dtype == int
|
||||
assert np.all(b.info.a[1:] == 0)
|
||||
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
|
||||
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
|
||||
assert np.all(b.info.b.c[1:] == 0.0)
|
||||
assert ptr.shape == (1,) and ptr[0] == 0
|
||||
assert ep_rew.shape == (1,) and ep_rew[0] == 1
|
||||
@ -180,8 +180,8 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
assert len(buf2) == min(bufsize, 3 * (i + 1))
|
||||
# check single buffer's data
|
||||
assert buf.info.key.shape == (buf.maxsize,)
|
||||
assert buf.rew.dtype == np.float
|
||||
assert buf.done.dtype == np.bool_
|
||||
assert buf.rew.dtype == float
|
||||
assert buf.done.dtype == bool
|
||||
data, indice = buf.sample(len(buf) // 2)
|
||||
buf.update_weight(indice, -data.weight / 2)
|
||||
assert np.allclose(buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
|
||||
@ -273,7 +273,7 @@ def test_segtree():
|
||||
index = tree.get_prefix_sum_idx(scalar)
|
||||
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
|
||||
# corner case here
|
||||
naive = np.ones(actual_len, np.int)
|
||||
naive = np.ones(actual_len, int)
|
||||
tree[np.arange(actual_len)] = naive
|
||||
for scalar in range(actual_len):
|
||||
index = tree.get_prefix_sum_idx(scalar * 1.)
|
||||
@ -485,7 +485,7 @@ def test_replaybuffermanager():
|
||||
buf.set_batch(batch)
|
||||
assert np.allclose(buf.buffers[-1].info, [1] * 5)
|
||||
assert buf.sample_index(-1).tolist() == []
|
||||
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
|
||||
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object
|
||||
|
||||
|
||||
def test_cachedbuffer():
|
||||
|
||||
@ -7,16 +7,18 @@ from numbers import Number
|
||||
from collections.abc import Collection
|
||||
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence
|
||||
|
||||
IndexType = Union[slice, int, np.ndarray, List[int]]
|
||||
|
||||
|
||||
def _is_batch_set(data: Any) -> bool:
|
||||
# Batch set is a list/tuple of dict/Batch objects,
|
||||
# or 1-D np.ndarray with np.object type,
|
||||
# or 1-D np.ndarray with object type,
|
||||
# where each element is a dict/Batch object
|
||||
if isinstance(data, np.ndarray): # most often case
|
||||
# "for e in data" will just unpack the first dimension,
|
||||
# but data.tolist() will flatten ndarray of objects
|
||||
# so do not use data.tolist()
|
||||
return data.dtype == np.object and all(
|
||||
return data.dtype == object and all(
|
||||
isinstance(e, (dict, Batch)) for e in data)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
|
||||
@ -50,13 +52,13 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
|
||||
if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
return v # most often case
|
||||
# convert the value to np.ndarray
|
||||
# convert to np.object data type if neither bool nor number
|
||||
# convert to object data type if neither bool nor number
|
||||
# raises an exception if array's elements are tensors themself
|
||||
v = np.asanyarray(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
if v.dtype == np.object:
|
||||
# scalar ndarray with np.object data type is very annoying
|
||||
v = v.astype(object)
|
||||
if v.dtype == object:
|
||||
# scalar ndarray with object data type is very annoying
|
||||
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
|
||||
# a is not array([{}, {}], dtype=object), and a[0]={} results in
|
||||
# something very strange:
|
||||
@ -87,13 +89,11 @@ def _create_value(
|
||||
if has_shape:
|
||||
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
|
||||
if isinstance(inst, np.ndarray):
|
||||
if issubclass(inst.dtype.type, (np.bool_, np.number)):
|
||||
target_type = inst.dtype.type
|
||||
else:
|
||||
target_type = np.object
|
||||
target_type = inst.dtype.type if issubclass(
|
||||
inst.dtype.type, (np.bool_, np.number)) else object
|
||||
return np.full(
|
||||
shape,
|
||||
fill_value=None if target_type == np.object else 0,
|
||||
fill_value=None if target_type == object else 0,
|
||||
dtype=target_type
|
||||
)
|
||||
elif isinstance(inst, torch.Tensor):
|
||||
@ -105,8 +105,8 @@ def _create_value(
|
||||
return zero_batch
|
||||
elif is_scalar:
|
||||
return _create_value(np.asarray(inst), size, stack=stack)
|
||||
else: # fall back to np.object
|
||||
return np.array([None for _ in range(size)])
|
||||
else: # fall back to object
|
||||
return np.array([None for _ in range(size)], object)
|
||||
|
||||
|
||||
def _assert_type_keys(keys: Iterable[str]) -> None:
|
||||
@ -187,7 +187,7 @@ class Batch:
|
||||
for k, v in batch_dict.items():
|
||||
self.__dict__[k] = _parse_value(v)
|
||||
elif _is_batch_set(batch_dict):
|
||||
self.stack_(batch_dict)
|
||||
self.stack_(batch_dict) # type: ignore
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs, copy=copy) # type: ignore
|
||||
|
||||
@ -223,9 +223,7 @@ class Batch:
|
||||
"""
|
||||
self.__init__(**state) # type: ignore
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
|
||||
) -> Any:
|
||||
def __getitem__(self, index: Union[str, IndexType]) -> Any:
|
||||
"""Return self[index]."""
|
||||
if isinstance(index, str):
|
||||
return self.__dict__[index]
|
||||
@ -241,11 +239,7 @@ class Batch:
|
||||
else:
|
||||
raise IndexError("Cannot access item from empty Batch object.")
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
value: Any,
|
||||
) -> None:
|
||||
def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
|
||||
"""Assign value to self[index]."""
|
||||
value = _parse_value(value)
|
||||
if isinstance(index, str):
|
||||
@ -530,8 +524,7 @@ class Batch:
|
||||
elif all(isinstance(e, (Batch, dict)) for e in v): # third often
|
||||
self.__dict__[k] = Batch.stack(v, axis)
|
||||
else: # most often case is np.ndarray
|
||||
v = np.stack(v, axis)
|
||||
self.__dict__[k] = _to_array_with_correct_type(v)
|
||||
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
|
||||
# all the keys
|
||||
keys_total = set.union(*[set(b.keys()) for b in batches])
|
||||
# keys that are reserved in all batches
|
||||
@ -587,9 +580,7 @@ class Batch:
|
||||
batch.stack_(batches, axis)
|
||||
return batch
|
||||
|
||||
def empty_(
|
||||
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None
|
||||
) -> "Batch":
|
||||
def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch":
|
||||
"""Return an empty Batch object with 0 or None filled.
|
||||
|
||||
If "index" is specified, it will only reset the specific indexed-data.
|
||||
@ -620,7 +611,7 @@ class Batch:
|
||||
elif v is None:
|
||||
continue
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.dtype == np.object:
|
||||
if v.dtype == object:
|
||||
self.__dict__[k][index] = None
|
||||
else:
|
||||
self.__dict__[k][index] = 0
|
||||
@ -636,10 +627,7 @@ class Batch:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def empty(
|
||||
batch: "Batch",
|
||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None,
|
||||
) -> "Batch":
|
||||
def empty(batch: "Batch", index: Optional[IndexType] = None) -> "Batch":
|
||||
"""Return an empty Batch object with 0 or None filled.
|
||||
|
||||
The shape is the same as the given Batch.
|
||||
|
||||
@ -115,9 +115,9 @@ class ReplayBuffer:
|
||||
def unfinished_index(self) -> np.ndarray:
|
||||
"""Return the index of unfinished episode."""
|
||||
last = (self._index - 1) % self._size if self._size else 0
|
||||
return np.array([last] if not self.done[last] and self._size else [], np.int)
|
||||
return np.array([last] if not self.done[last] and self._size else [], int)
|
||||
|
||||
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
|
||||
"""Return the index of previous transition.
|
||||
|
||||
The index won't be modified if it is the beginning of an episode.
|
||||
@ -126,7 +126,7 @@ class ReplayBuffer:
|
||||
end_flag = self.done[index] | (index == self.last_index[0])
|
||||
return (index + end_flag) % self._size
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
|
||||
"""Return the index of next transition.
|
||||
|
||||
The index won't be modified if it is the end of an episode.
|
||||
@ -140,12 +140,12 @@ class ReplayBuffer:
|
||||
Return the updated indices. If update fails, return an empty array.
|
||||
"""
|
||||
if len(buffer) == 0 or self.maxsize == 0:
|
||||
return np.array([], np.int)
|
||||
return np.array([], int)
|
||||
stack_num, buffer.stack_num = buffer.stack_num, 1
|
||||
from_indices = buffer.sample_index(0) # get all available indices
|
||||
buffer.stack_num = stack_num
|
||||
if len(from_indices) == 0:
|
||||
return np.array([], np.int)
|
||||
return np.array([], int)
|
||||
to_indices = []
|
||||
for _ in range(len(from_indices)):
|
||||
to_indices.append(self._index)
|
||||
@ -224,8 +224,8 @@ class ReplayBuffer:
|
||||
self._meta[ptr] = batch
|
||||
except ValueError:
|
||||
stack = not stacked_batch
|
||||
batch.rew = batch.rew.astype(np.float)
|
||||
batch.done = batch.done.astype(np.bool_)
|
||||
batch.rew = batch.rew.astype(float)
|
||||
batch.done = batch.done.astype(bool)
|
||||
if self._meta.is_empty():
|
||||
self._meta = _create_value( # type: ignore
|
||||
batch, self.maxsize, stack)
|
||||
@ -248,10 +248,10 @@ class ReplayBuffer:
|
||||
[np.arange(self._index, self._size), np.arange(self._index)]
|
||||
)
|
||||
else:
|
||||
return np.array([], np.int)
|
||||
return np.array([], int)
|
||||
else:
|
||||
if batch_size < 0:
|
||||
return np.array([], np.int)
|
||||
return np.array([], int)
|
||||
all_indices = prev_indices = np.concatenate(
|
||||
[np.arange(self._index, self._size), np.arange(self._index)]
|
||||
)
|
||||
@ -275,9 +275,9 @@ class ReplayBuffer:
|
||||
|
||||
def get(
|
||||
self,
|
||||
index: Union[int, np.integer, np.ndarray],
|
||||
index: Union[int, List[int], np.ndarray],
|
||||
key: str,
|
||||
default_value: Optional[Any] = None,
|
||||
default_value: Any = None,
|
||||
stack_num: Optional[int] = None,
|
||||
) -> Union[Batch, np.ndarray]:
|
||||
"""Return the stacked result.
|
||||
@ -303,7 +303,7 @@ class ReplayBuffer:
|
||||
if isinstance(index, list):
|
||||
indice = np.array(index)
|
||||
else:
|
||||
indice = index
|
||||
indice = index # type: ignore
|
||||
for _ in range(stack_num):
|
||||
stack = [val[indice]] + stack
|
||||
indice = self.prev(indice)
|
||||
@ -316,30 +316,31 @@ class ReplayBuffer:
|
||||
raise e # val != Batch()
|
||||
return Batch()
|
||||
|
||||
def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
|
||||
"""Return a data batch: self[index].
|
||||
|
||||
If stack_num is larger than 1, return the stacked obs and obs_next with shape
|
||||
(batch, len, ...).
|
||||
"""
|
||||
if isinstance(index, slice): # change slice to np array
|
||||
if index == slice(None): # buffer[:] will get all available data
|
||||
index = self.sample_index(0)
|
||||
else:
|
||||
index = self._indices[:len(self)][index]
|
||||
# buffer[:] will get all available data
|
||||
indice = self.sample_index(0) if index == slice(None) \
|
||||
else self._indices[:len(self)][index]
|
||||
else:
|
||||
indice = index
|
||||
# raise KeyError first instead of AttributeError,
|
||||
# to support np.array([ReplayBuffer()])
|
||||
obs = self.get(index, "obs")
|
||||
obs = self.get(indice, "obs")
|
||||
if self._save_obs_next:
|
||||
obs_next = self.get(index, "obs_next", Batch())
|
||||
obs_next = self.get(indice, "obs_next", Batch())
|
||||
else:
|
||||
obs_next = self.get(self.next(index), "obs", Batch())
|
||||
obs_next = self.get(self.next(indice), "obs", Batch())
|
||||
return Batch(
|
||||
obs=obs,
|
||||
act=self.act[index],
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
act=self.act[indice],
|
||||
rew=self.rew[indice],
|
||||
done=self.done[indice],
|
||||
obs_next=obs_next,
|
||||
info=self.get(index, "info", Batch()),
|
||||
policy=self.get(index, "policy", Batch()),
|
||||
info=self.get(indice, "info", Batch()),
|
||||
policy=self.get(indice, "policy", Batch()),
|
||||
)
|
||||
|
||||
@ -58,14 +58,14 @@ class CachedReplayBuffer(ReplayBufferManager):
|
||||
cached_buffer_ids[i]th cached buffer's corresponding episode result.
|
||||
"""
|
||||
if buffer_ids is None:
|
||||
buffer_ids = np.arange(1, 1 + self.cached_buffer_num)
|
||||
buf_arr = np.arange(1, 1 + self.cached_buffer_num)
|
||||
else: # make sure it is np.ndarray
|
||||
buffer_ids = np.asarray(buffer_ids) + 1
|
||||
ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buffer_ids)
|
||||
buf_arr = np.asarray(buffer_ids) + 1
|
||||
ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr)
|
||||
# find the terminated episode, move data from cached buf to main buf
|
||||
updated_ptr, updated_ep_idx = [], []
|
||||
done = batch.done.astype(np.bool_)
|
||||
for buffer_idx in buffer_ids[done]:
|
||||
done = batch.done.astype(bool)
|
||||
for buffer_idx in buf_arr[done]:
|
||||
index = self.main_buffer.update(self.buffers[buffer_idx])
|
||||
if len(index) == 0: # unsuccessful move, replace with -1
|
||||
index = [-1]
|
||||
|
||||
@ -22,7 +22,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
|
||||
def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
|
||||
self.buffer_num = len(buffer_list)
|
||||
self.buffers = np.array(buffer_list, dtype=np.object)
|
||||
self.buffers = np.array(buffer_list, dtype=object)
|
||||
offset, size = [], 0
|
||||
buffer_type = type(self.buffers[0])
|
||||
kwargs = self.buffers[0].options
|
||||
@ -46,7 +46,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
_next_index(index, offset, done, last, lens)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._lengths.sum()
|
||||
return int(self._lengths.sum())
|
||||
|
||||
def reset(self, keep_statistics: bool = False) -> None:
|
||||
self.last_index = self._offset.copy()
|
||||
@ -68,7 +68,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
for offset, buf in zip(self._offset, self.buffers)
|
||||
])
|
||||
|
||||
def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
|
||||
if isinstance(index, (list, np.ndarray)):
|
||||
return _prev_index(np.asarray(index), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)
|
||||
@ -76,7 +76,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
return _prev_index(np.array([index]), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)[0]
|
||||
|
||||
def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
|
||||
def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
|
||||
if isinstance(index, (list, np.ndarray)):
|
||||
return _next_index(np.asarray(index), self._extend_offset,
|
||||
self.done, self.last_index, self._lengths)
|
||||
@ -130,8 +130,8 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
try:
|
||||
self._meta[ptrs] = batch
|
||||
except ValueError:
|
||||
batch.rew = batch.rew.astype(np.float)
|
||||
batch.done = batch.done.astype(np.bool_)
|
||||
batch.rew = batch.rew.astype(float)
|
||||
batch.done = batch.done.astype(bool)
|
||||
if self._meta.is_empty():
|
||||
self._meta = _create_value( # type: ignore
|
||||
batch, self.maxsize, stack=False)
|
||||
@ -143,7 +143,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
|
||||
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||
if batch_size < 0:
|
||||
return np.array([], np.int)
|
||||
return np.array([], int)
|
||||
if self._sample_avail and self.stack_num > 1:
|
||||
all_indices = np.concatenate([
|
||||
buf.sample_index(0) + offset
|
||||
@ -154,7 +154,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
else:
|
||||
return np.random.choice(all_indices, batch_size)
|
||||
if batch_size == 0: # get all available indices
|
||||
sample_num = np.zeros(self.buffer_num, np.int)
|
||||
sample_num = np.zeros(self.buffer_num, int)
|
||||
else:
|
||||
buffer_idx = np.random.choice(
|
||||
self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
|
||||
|
||||
@ -34,6 +34,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def update(self, buffer: ReplayBuffer) -> np.ndarray:
|
||||
indices = super().update(buffer)
|
||||
self.init_weight(indices)
|
||||
return indices
|
||||
|
||||
def add(
|
||||
self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
|
||||
@ -45,13 +46,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def sample_index(self, batch_size: int) -> np.ndarray:
|
||||
if batch_size > 0 and len(self) > 0:
|
||||
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
||||
return self.weight.get_prefix_sum_idx(scalar)
|
||||
return self.weight.get_prefix_sum_idx(scalar) # type: ignore
|
||||
else:
|
||||
return super().sample_index(batch_size)
|
||||
|
||||
def get_weight(
|
||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||
) -> np.ndarray:
|
||||
def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
|
||||
"""Get the importance sampling weight.
|
||||
|
||||
The "weight" in the returned Batch is the weight on loss function to de-bias
|
||||
@ -76,7 +75,13 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self._max_prio = max(self._max_prio, weight.max())
|
||||
self._min_prio = min(self._min_prio, weight.min())
|
||||
|
||||
def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
batch = super().__getitem__(index)
|
||||
batch.weight = self.get_weight(index)
|
||||
def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
|
||||
if isinstance(index, slice): # change slice to np array
|
||||
# buffer[:] will get all available data
|
||||
indice = self.sample_index(0) if index == slice(None) \
|
||||
else self._indices[:len(self)][index]
|
||||
else:
|
||||
indice = index
|
||||
batch = super().__getitem__(indice)
|
||||
batch.weight = self.get_weight(indice)
|
||||
return batch
|
||||
|
||||
@ -123,7 +123,7 @@ class Collector(object):
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
elif isinstance(state, np.ndarray):
|
||||
state[id] = None if state.dtype == np.object else 0
|
||||
state[id] = None if state.dtype == object else 0
|
||||
elif isinstance(state, Batch):
|
||||
state.empty_(id)
|
||||
|
||||
@ -266,7 +266,7 @@ class Collector(object):
|
||||
if n_episode:
|
||||
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
|
||||
if surplus_env_num > 0:
|
||||
mask = np.ones_like(ready_env_ids, np.bool)
|
||||
mask = np.ones_like(ready_env_ids, dtype=bool)
|
||||
mask[env_ind_local[:surplus_env_num]] = False
|
||||
ready_env_ids = ready_env_ids[mask]
|
||||
self.data = self.data[mask]
|
||||
@ -291,7 +291,7 @@ class Collector(object):
|
||||
rews, lens, idxs = list(map(
|
||||
np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
|
||||
else:
|
||||
rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int)
|
||||
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
|
||||
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
@ -493,7 +493,7 @@ class AsyncCollector(Collector):
|
||||
rews, lens, idxs = list(map(
|
||||
np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
|
||||
else:
|
||||
rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int)
|
||||
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
|
||||
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
|
||||
@ -4,15 +4,12 @@ import pickle
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Dict, Union, Optional
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
from tianshou.data.batch import _parse_value, Batch
|
||||
|
||||
|
||||
def to_numpy(
|
||||
x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number,
|
||||
np.ndarray, torch.Tensor]]
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray]:
|
||||
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
|
||||
"""Return an object without torch.Tensor."""
|
||||
if isinstance(x, torch.Tensor): # most often case
|
||||
return x.detach().cpu().numpy()
|
||||
@ -21,28 +18,22 @@ def to_numpy(
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
return np.asanyarray(x)
|
||||
elif x is None:
|
||||
return np.array(None, dtype=np.object)
|
||||
elif isinstance(x, Batch):
|
||||
x = deepcopy(x)
|
||||
return np.array(None, dtype=object)
|
||||
elif isinstance(x, (dict, Batch)):
|
||||
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
|
||||
x.to_numpy()
|
||||
return x
|
||||
elif isinstance(x, dict):
|
||||
return {k: to_numpy(v) for k, v in x.items()}
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
return to_numpy(_parse_value(x))
|
||||
except TypeError:
|
||||
return [to_numpy(e) for e in x]
|
||||
return to_numpy(_parse_value(x))
|
||||
else: # fallback
|
||||
return np.asanyarray(x)
|
||||
|
||||
|
||||
def to_torch(
|
||||
x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray,
|
||||
torch.Tensor],
|
||||
x: Any,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
|
||||
) -> Union[Batch, torch.Tensor]:
|
||||
"""Return an object without np.ndarray."""
|
||||
if isinstance(x, np.ndarray) and issubclass(
|
||||
x.dtype.type, (np.bool_, np.number)
|
||||
@ -57,25 +48,17 @@ def to_torch(
|
||||
return x.to(device) # type: ignore
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
return to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, dict):
|
||||
return {k: to_torch(v, dtype, device) for k, v in x.items()}
|
||||
elif isinstance(x, Batch):
|
||||
x = deepcopy(x)
|
||||
elif isinstance(x, (dict, Batch)):
|
||||
x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
|
||||
x.to_torch(dtype, device)
|
||||
return x
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
return to_torch(_parse_value(x), dtype, device)
|
||||
except TypeError:
|
||||
return [to_torch(e, dtype, device) for e in x]
|
||||
return to_torch(_parse_value(x), dtype, device)
|
||||
else: # fallback
|
||||
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||
|
||||
|
||||
def to_torch_as(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||
y: torch.Tensor,
|
||||
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
|
||||
def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
|
||||
"""Return an object without np.ndarray.
|
||||
|
||||
Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
|
||||
@ -147,25 +130,20 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
|
||||
y[k].attrs["__data_type__"] = v.__class__.__name__
|
||||
|
||||
|
||||
def from_hdf5(
|
||||
x: h5py.Group, device: Optional[str] = None
|
||||
) -> Hdf5ConvertibleType:
|
||||
def from_hdf5(x: h5py.Group, device: Optional[str] = None) -> Hdf5ConvertibleValues:
|
||||
"""Restore object from HDF5 group."""
|
||||
if isinstance(x, h5py.Dataset):
|
||||
# handle datasets
|
||||
if x.attrs["__data_type__"] == "ndarray":
|
||||
y = np.array(x)
|
||||
return np.array(x)
|
||||
elif x.attrs["__data_type__"] == "Tensor":
|
||||
y = torch.tensor(x, device=device)
|
||||
return torch.tensor(x, device=device)
|
||||
else:
|
||||
y = pickle.loads(x[()])
|
||||
return pickle.loads(x[()])
|
||||
else:
|
||||
# handle groups representing a dict or a Batch
|
||||
y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
|
||||
y = dict(x.attrs.items())
|
||||
data_type = y.pop("__data_type__", None)
|
||||
for k, v in x.items():
|
||||
y[k] = from_hdf5(v, device)
|
||||
if "__data_type__" in x.attrs:
|
||||
# if dictionary represents Batch, convert to Batch
|
||||
if x.attrs["__data_type__"] == "Batch":
|
||||
y = Batch(y)
|
||||
return y
|
||||
return Batch(y) if data_type == "Batch" else y
|
||||
|
||||
10
tianshou/env/venvs.py
vendored
10
tianshou/env/venvs.py
vendored
@ -140,12 +140,10 @@ class BaseVectorEnv(gym.Env):
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> Union[List[int], np.ndarray]:
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
return id
|
||||
return list(range(self.env_num))
|
||||
return [id] if np.isscalar(id) else id # type: ignore
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, \
|
||||
f"Cannot interact with environment {i} which is stepping now."
|
||||
@ -291,7 +289,7 @@ class BaseVectorEnv(gym.Env):
|
||||
clip_max = 10.0 # this magic number is from openai baselines
|
||||
# see baselines/common/vec_env/vec_normalize.py#L10
|
||||
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
|
||||
obs = np.clip(obs, -clip_max, clip_max)
|
||||
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
|
||||
return obs
|
||||
|
||||
def __del__(self) -> None:
|
||||
|
||||
8
tianshou/env/worker/base.py
vendored
8
tianshou/env/worker/base.py
vendored
@ -25,9 +25,7 @@ class EnvWorker(ABC):
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def get_result(
|
||||
self,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
return self.result
|
||||
|
||||
def step(
|
||||
@ -45,9 +43,7 @@ class EnvWorker(ABC):
|
||||
|
||||
@staticmethod
|
||||
def wait(
|
||||
workers: List["EnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
workers: List["EnvWorker"], wait_num: int, timeout: Optional[float] = None
|
||||
) -> List["EnvWorker"]:
|
||||
"""Given a list of workers, return those ready ones."""
|
||||
raise NotImplementedError
|
||||
|
||||
4
tianshou/env/worker/dummy.py
vendored
4
tianshou/env/worker/dummy.py
vendored
@ -20,9 +20,7 @@ class DummyEnvWorker(EnvWorker):
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["DummyEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None
|
||||
) -> List["DummyEnvWorker"]:
|
||||
# Sequential EnvWorker objects are always ready
|
||||
return workers
|
||||
|
||||
4
tianshou/env/worker/ray.py
vendored
4
tianshou/env/worker/ray.py
vendored
@ -25,9 +25,7 @@ class RayEnvWorker(EnvWorker):
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["RayEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None
|
||||
) -> List["RayEnvWorker"]:
|
||||
results = [x.result for x in workers]
|
||||
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
|
||||
|
||||
6
tianshou/env/worker/subproc.py
vendored
6
tianshou/env/worker/subproc.py
vendored
@ -12,7 +12,6 @@ from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
@ -31,7 +30,7 @@ class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
@ -64,8 +63,7 @@ def _worker(
|
||||
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||
) -> None:
|
||||
def _encode_obs(
|
||||
obs: Union[dict, tuple, np.ndarray],
|
||||
buffer: Union[dict, tuple, ShArray],
|
||||
obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
|
||||
) -> None:
|
||||
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||
buffer.save(obs)
|
||||
|
||||
@ -68,9 +68,7 @@ class OUNoise(BaseNoise):
|
||||
"""Reset to the initial state."""
|
||||
self._x = self._x0
|
||||
|
||||
def __call__(
|
||||
self, size: Sequence[int], mu: Optional[float] = None
|
||||
) -> np.ndarray:
|
||||
def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarray:
|
||||
"""Generate new noise.
|
||||
|
||||
Return an numpy array which size is equal to ``size``.
|
||||
@ -82,4 +80,4 @@ class OUNoise(BaseNoise):
|
||||
mu = self._mu
|
||||
r = self._beta * np.random.normal(size=size)
|
||||
self._x = self._x + self._alpha * (mu - self._x) + r
|
||||
return self._x
|
||||
return self._x # type: ignore
|
||||
|
||||
@ -142,14 +142,14 @@ class BasePolicy(ABC, nn.Module):
|
||||
isinstance(act, np.ndarray):
|
||||
# currently this action mapping only supports np.ndarray action
|
||||
if self.action_bound_method == "clip":
|
||||
act = np.clip(act, -1.0, 1.0)
|
||||
act = np.clip(act, -1.0, 1.0) # type: ignore
|
||||
elif self.action_bound_method == "tanh":
|
||||
act = np.tanh(act)
|
||||
if self.action_scaling:
|
||||
assert np.all(act >= -1.0) and np.all(act <= 1.0), \
|
||||
assert np.min(act) >= -1.0 and np.max(act) <= 1.0, \
|
||||
"action scaling only accepts raw action range = [-1, 1]"
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
act = low + (high - low) * (act + 1.0) / 2.0
|
||||
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
|
||||
return act
|
||||
|
||||
def process_fn(
|
||||
@ -241,9 +241,9 @@ class BasePolicy(ABC, nn.Module):
|
||||
:return: A bool type numpy.ndarray in the same shape with indice. "True" means
|
||||
"obs_next" of that buffer[indice] is valid.
|
||||
"""
|
||||
mask = ~buffer.done[indice].astype(np.bool)
|
||||
# info['TimeLimit.truncated'] will be set to True if 'done' flag is generated
|
||||
# because of timelimit of environments. Checkout gym.wrappers.TimeLimit.
|
||||
mask = ~buffer.done[indice]
|
||||
# info["TimeLimit.truncated"] will be True if "done" flag is generated by
|
||||
# timelimit of environments. Checkout gym.wrappers.TimeLimit.
|
||||
if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
|
||||
mask = mask | buffer.info['TimeLimit.truncated'][indice]
|
||||
return mask
|
||||
@ -281,7 +281,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
assert np.isclose(gae_lambda, 1.0)
|
||||
v_s_ = np.zeros_like(rew)
|
||||
else:
|
||||
v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
|
||||
v_s_ = to_numpy(v_s_.flatten()) # type: ignore
|
||||
v_s_ = v_s_ * BasePolicy.value_mask(buffer, indice)
|
||||
v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
|
||||
|
||||
end_flag = batch.done.copy()
|
||||
|
||||
@ -58,7 +58,7 @@ class DiscreteBCQPolicy(DQNPolicy):
|
||||
else:
|
||||
self._log_tau = -np.inf
|
||||
assert 0.0 <= eval_eps < 1.0
|
||||
self._eps = eval_eps
|
||||
self.eps = eval_eps
|
||||
self._weight_reg = imitation_logits_penalty
|
||||
|
||||
def train(self, mode: bool = True) -> "DiscreteBCQPolicy":
|
||||
@ -96,15 +96,6 @@ class DiscreteBCQPolicy(DQNPolicy):
|
||||
return Batch(act=action, state=state, q_value=q_value,
|
||||
imitation_logits=imitation_logits)
|
||||
|
||||
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
|
||||
# add eps to act
|
||||
if not np.isclose(self._eps, 0.0):
|
||||
bsz = len(act)
|
||||
mask = np.random.rand(bsz) < self._eps
|
||||
act_rand = np.random.randint(self.max_action_num, size=[bsz])
|
||||
act[mask] = act_rand[mask]
|
||||
return act
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
if self._iter % self._freq == 0:
|
||||
self.sync_weight()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Union, Optional
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
@ -100,7 +100,7 @@ class PSRLModel(object):
|
||||
discount_factor: float,
|
||||
eps: float,
|
||||
value: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Value iteration solver for MDPs.
|
||||
|
||||
:param np.ndarray trans_prob: transition probabilities, with shape
|
||||
@ -126,7 +126,7 @@ class PSRLModel(object):
|
||||
def __call__(
|
||||
self,
|
||||
obs: np.ndarray,
|
||||
state: Optional[Any] = None,
|
||||
state: Any = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> np.ndarray:
|
||||
if not self.updated:
|
||||
@ -215,6 +215,6 @@ class PSRLPolicy(BasePolicy):
|
||||
rew_count[obs_next, :] += 1
|
||||
self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
|
||||
return {
|
||||
"psrl/rew_mean": self.model.rew_mean.mean(),
|
||||
"psrl/rew_std": self.model.rew_std.mean(),
|
||||
"psrl/rew_mean": float(self.model.rew_mean.mean()),
|
||||
"psrl/rew_std": float(self.model.rew_std.mean()),
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
from typing import Any, Dict, List, Type, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
|
||||
|
||||
class A2CPolicy(PGPolicy):
|
||||
@ -84,8 +84,8 @@ class A2CPolicy(PGPolicy):
|
||||
v_s.append(self.critic(b.obs))
|
||||
v_s_.append(self.critic(b.obs_next))
|
||||
batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
|
||||
v_s = to_numpy(batch.v_s)
|
||||
v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
|
||||
v_s = batch.v_s.cpu().numpy()
|
||||
v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
|
||||
# when normalizing values, we do not minus self.ret_rms.mean to be numerically
|
||||
# consistent with OPENAI baselines' value normalization pipeline. Emperical
|
||||
# study also shows that "minus mean" will harm performances a tiny little bit
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
@ -167,7 +168,12 @@ class DDPGPolicy(BasePolicy):
|
||||
"loss/critic": critic_loss.item(),
|
||||
}
|
||||
|
||||
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
|
||||
if self._noise:
|
||||
act = act + self._noise(act.shape)
|
||||
def exploration_noise(
|
||||
self, act: Union[np.ndarray, Batch], batch: Batch
|
||||
) -> Union[np.ndarray, Batch]:
|
||||
if self._noise is None:
|
||||
return act
|
||||
if isinstance(act, np.ndarray):
|
||||
return act + self._noise(act.shape)
|
||||
warnings.warn("Cannot add exploration noise to non-numpy_array action.")
|
||||
return act
|
||||
|
||||
@ -168,8 +168,10 @@ class DQNPolicy(BasePolicy):
|
||||
self._iter += 1
|
||||
return {"loss": loss.item()}
|
||||
|
||||
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
|
||||
if not np.isclose(self.eps, 0.0):
|
||||
def exploration_noise(
|
||||
self, act: Union[np.ndarray, Batch], batch: Batch
|
||||
) -> Union[np.ndarray, Batch]:
|
||||
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
|
||||
bsz = len(act)
|
||||
rand_mask = np.random.rand(bsz) < self.eps
|
||||
q = np.random.rand(bsz, self.max_action_num) # [0, 1]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
@ -71,7 +71,7 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
act[agent_index], batch[agent_index])
|
||||
return act
|
||||
|
||||
def forward(
|
||||
def forward( # type: ignore
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch]] = None,
|
||||
@ -100,7 +100,8 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
"agent_n": xxx}
|
||||
}
|
||||
"""
|
||||
results = []
|
||||
results: List[Tuple[bool, np.ndarray, Batch,
|
||||
Union[np.ndarray, Batch], Batch]] = []
|
||||
for policy in self.policies:
|
||||
# This part of code is difficult to understand.
|
||||
# Let's follow an example with two agents
|
||||
@ -112,7 +113,7 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
|
||||
if len(agent_index) == 0:
|
||||
# (has_data, agent_index, out, act, state)
|
||||
results.append((False, None, Batch(), None, Batch()))
|
||||
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
|
||||
continue
|
||||
tmp_batch = batch[agent_index]
|
||||
if isinstance(tmp_batch.rew, np.ndarray):
|
||||
|
||||
@ -14,16 +14,12 @@ class BaseLogger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def write(
|
||||
self,
|
||||
key: str,
|
||||
x: Union[Number, np.number, np.ndarray],
|
||||
y: Union[Number, np.number, np.ndarray],
|
||||
**kwargs: Any,
|
||||
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
||||
) -> None:
|
||||
"""Specify how the writer is used to log data.
|
||||
|
||||
:param key: namespace which the input data tuple belongs to.
|
||||
:param x: stands for the ordinate of the input data tuple.
|
||||
:param str key: namespace which the input data tuple belongs to.
|
||||
:param int x: stands for the ordinate of the input data tuple.
|
||||
:param y: stands for the abscissa of the input data tuple.
|
||||
"""
|
||||
pass
|
||||
@ -84,11 +80,7 @@ class BasicLogger(BaseLogger):
|
||||
self.last_log_update_step = -1
|
||||
|
||||
def write(
|
||||
self,
|
||||
key: str,
|
||||
x: Union[Number, np.number, np.ndarray],
|
||||
y: Union[Number, np.number, np.ndarray],
|
||||
**kwargs: Any,
|
||||
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
||||
) -> None:
|
||||
self.writer.add_scalar(key, y, global_step=x)
|
||||
|
||||
@ -149,11 +141,7 @@ class LazyLogger(BasicLogger):
|
||||
super().__init__(None) # type: ignore
|
||||
|
||||
def write(
|
||||
self,
|
||||
key: str,
|
||||
x: Union[Number, np.number, np.ndarray],
|
||||
y: Union[Number, np.number, np.ndarray],
|
||||
**kwargs: Any,
|
||||
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
||||
) -> None:
|
||||
"""The LazyLogger writes nothing."""
|
||||
pass
|
||||
|
||||
@ -50,8 +50,7 @@ class MLP(nn.Module):
|
||||
output_dim: int = 0,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
|
||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]]
|
||||
= nn.ReLU,
|
||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
||||
device: Optional[Union[str, int, torch.device]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -139,7 +138,7 @@ class Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_shape: Union[int, Sequence[int]],
|
||||
action_shape: Optional[Union[int, Sequence[int]]] = 0,
|
||||
action_shape: Union[int, Sequence[int]] = 0,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
norm_layer: Optional[ModuleType] = None,
|
||||
activation: Optional[ModuleType] = nn.ReLU,
|
||||
@ -153,8 +152,8 @@ class Net(nn.Module):
|
||||
self.device = device
|
||||
self.softmax = softmax
|
||||
self.num_atoms = num_atoms
|
||||
input_dim = np.prod(state_shape)
|
||||
action_dim = np.prod(action_shape) * num_atoms
|
||||
input_dim = int(np.prod(state_shape))
|
||||
action_dim = int(np.prod(action_shape)) * num_atoms
|
||||
if concat:
|
||||
input_dim += action_dim
|
||||
self.use_dueling = dueling_param is not None
|
||||
@ -179,7 +178,7 @@ class Net(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
state: Any = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Mapping: s -> flatten (inside MLP)-> logits."""
|
||||
@ -221,8 +220,8 @@ class Recurrent(nn.Module):
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
|
||||
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size)
|
||||
self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -46,7 +46,7 @@ class Actor(nn.Module):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = np.prod(action_shape)
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim",
|
||||
preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, self.output_dim,
|
||||
@ -56,7 +56,7 @@ class Actor(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
state: Any = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Mapping: s -> logits -> action."""
|
||||
@ -162,7 +162,7 @@ class ActorProb(nn.Module):
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.device = device
|
||||
self.output_dim = np.prod(action_shape)
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim",
|
||||
preprocess_net_output_dim)
|
||||
self.mu = MLP(input_dim, self.output_dim,
|
||||
@ -179,7 +179,7 @@ class ActorProb(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
state: Any = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
|
||||
"""Mapping: s -> logits -> (mu, sigma)."""
|
||||
@ -219,12 +219,12 @@ class RecurrentActorProb(nn.Module):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(
|
||||
input_size=np.prod(state_shape),
|
||||
input_size=int(np.prod(state_shape)),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
output_dim = np.prod(action_shape)
|
||||
output_dim = int(np.prod(action_shape))
|
||||
self.mu = nn.Linear(hidden_layer_size, output_dim)
|
||||
self._c_sigma = conditioned_sigma
|
||||
if conditioned_sigma:
|
||||
@ -293,12 +293,12 @@ class RecurrentCritic(nn.Module):
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(
|
||||
input_size=np.prod(state_shape),
|
||||
input_size=int(np.prod(state_shape)),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
|
||||
self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -45,7 +45,7 @@ class Actor(nn.Module):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = np.prod(action_shape)
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim",
|
||||
preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, self.output_dim,
|
||||
@ -55,7 +55,7 @@ class Actor(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
state: Any = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
r"""Mapping: s -> Q(s, \*)."""
|
||||
|
||||
@ -3,8 +3,6 @@ import numpy as np
|
||||
from numbers import Number
|
||||
from typing import List, Union
|
||||
|
||||
from tianshou.data import to_numpy
|
||||
|
||||
|
||||
class MovAvg(object):
|
||||
"""Class for moving average.
|
||||
@ -28,44 +26,43 @@ class MovAvg(object):
|
||||
def __init__(self, size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.cache: List[Union[Number, np.number]] = []
|
||||
self.cache: List[np.number] = []
|
||||
self.banned = [np.inf, np.nan, -np.inf]
|
||||
|
||||
def add(
|
||||
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
|
||||
) -> np.number:
|
||||
) -> float:
|
||||
"""Add a scalar into :class:`MovAvg`.
|
||||
|
||||
You can add ``torch.Tensor`` with only one element, a python scalar, or
|
||||
a list of python scalar.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = to_numpy(x.flatten())
|
||||
if isinstance(x, list) or isinstance(x, np.ndarray):
|
||||
for i in x:
|
||||
if i not in self.banned:
|
||||
self.cache.append(i)
|
||||
elif x not in self.banned:
|
||||
self.cache.append(x)
|
||||
x = x.flatten().cpu().numpy()
|
||||
if np.isscalar(x):
|
||||
x = [x]
|
||||
for i in x: # type: ignore
|
||||
if i not in self.banned:
|
||||
self.cache.append(i)
|
||||
if self.size > 0 and len(self.cache) > self.size:
|
||||
self.cache = self.cache[-self.size:]
|
||||
return self.get()
|
||||
|
||||
def get(self) -> np.number:
|
||||
def get(self) -> float:
|
||||
"""Get the average."""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.mean(self.cache)
|
||||
return 0.0
|
||||
return float(np.mean(self.cache))
|
||||
|
||||
def mean(self) -> np.number:
|
||||
def mean(self) -> float:
|
||||
"""Get the average. Same as :meth:`get`."""
|
||||
return self.get()
|
||||
|
||||
def std(self) -> np.number:
|
||||
def std(self) -> float:
|
||||
"""Get the standard deviation."""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.std(self.cache)
|
||||
return 0.0
|
||||
return float(np.std(self.cache))
|
||||
|
||||
|
||||
class RunningMeanStd(object):
|
||||
@ -74,8 +71,10 @@ class RunningMeanStd(object):
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.mean, self.var = 0.0, 1.0
|
||||
def __init__(
|
||||
self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0
|
||||
) -> None:
|
||||
self.mean, self.var = mean, std
|
||||
self.count = 0
|
||||
|
||||
def update(self, x: np.ndarray) -> None:
|
||||
@ -92,5 +91,5 @@ class RunningMeanStd(object):
|
||||
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
|
||||
new_var = m_2 / total_count
|
||||
|
||||
self.mean, self.var = new_mean, new_var
|
||||
self.mean, self.var = new_mean, new_var # type: ignore
|
||||
self.count = total_count
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user