Standardized behavior of Batch.cat and misc code refactor (#137)
* code refactor; remove unused kwargs; add reward_normalization for dqn * bugfix for __setitem__ with torch.Tensor; add Batch.condense * minor fix * support cat with empty Batch * remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases * support stack with empty Batch * remove condense * refactor code to reflect the shared / partial / reserved categories of keys * add is_empty(recursive=False) * doc fix * docfix and bugfix for _is_batch_set * add doc for key reservation * bugfix for algebra operators * fix cat with lens hint * code refactor * bugfix for storing None * use ValueError instead of exception * hide lens away from users * add comment for __cat * move the computation of the initial value of lens in cat_ itself. * change the place of doc string * doc fix for Batch doc string * change recursive to recurse * doc string fix * minor fix for batch doc
This commit is contained in:
parent
09e10e384f
commit
3a08e27ed4
@ -10,7 +10,12 @@ from tianshou.data import Batch, to_torch
|
||||
def test_batch():
|
||||
assert list(Batch()) == []
|
||||
assert Batch().is_empty()
|
||||
assert Batch(b={'c': {}}).is_empty()
|
||||
assert not Batch(b={'c': {}}).is_empty()
|
||||
assert Batch(b={'c': {}}).is_empty(recurse=True)
|
||||
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
|
||||
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
|
||||
assert not Batch(d=1).is_empty()
|
||||
assert not Batch(a=np.float64(1.0)).is_empty()
|
||||
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||
b = Batch()
|
||||
@ -109,6 +114,11 @@ def test_batch():
|
||||
assert isinstance(batch5.b, Batch)
|
||||
assert np.allclose(batch5.b.index, [1])
|
||||
|
||||
# None is a valid object and can be stored in Batch
|
||||
a = Batch.stack([Batch(a=None), Batch(b=None)])
|
||||
assert a.a[0] is None and a.a[1] is None
|
||||
assert a.b[0] is None and a.b[1] is None
|
||||
|
||||
|
||||
def test_batch_over_batch():
|
||||
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
|
||||
@ -162,6 +172,20 @@ def test_batch_cat_and_stack():
|
||||
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
|
||||
assert b12_cat_in.a.d.e.ndim == 1
|
||||
|
||||
a = Batch(a=Batch(a=np.random.randn(3, 4)))
|
||||
assert np.allclose(
|
||||
np.concatenate([a.a.a, a.a.a]),
|
||||
Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)
|
||||
|
||||
# test cat with lens infer
|
||||
a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
|
||||
b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
|
||||
ans = Batch.cat([a, b, a])
|
||||
assert np.allclose(ans.a.a,
|
||||
np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
|
||||
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
|
||||
assert ans.a.t.is_empty()
|
||||
|
||||
b12_stack = Batch.stack((b1, b2))
|
||||
assert isinstance(b12_stack.a.d.e, np.ndarray)
|
||||
assert b12_stack.a.d.e.ndim == 2
|
||||
@ -177,6 +201,32 @@ def test_batch_cat_and_stack():
|
||||
assert torch.allclose(test.b, ans.b)
|
||||
assert np.allclose(test.common.c, ans.common.c)
|
||||
|
||||
# test cat with reserved keys (values are Batch())
|
||||
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
|
||||
b2 = Batch(a=Batch(),
|
||||
b=torch.rand(4, 3),
|
||||
common=Batch(c=np.random.rand(4, 5)))
|
||||
test = Batch.cat([b1, b2])
|
||||
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
|
||||
b=torch.cat([torch.zeros(3, 3), b2.b]),
|
||||
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
|
||||
assert np.allclose(test.a, ans.a)
|
||||
assert torch.allclose(test.b, ans.b)
|
||||
assert np.allclose(test.common.c, ans.common.c)
|
||||
|
||||
# test cat with all reserved keys (values are Batch())
|
||||
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
|
||||
b2 = Batch(a=Batch(),
|
||||
b=torch.rand(4, 3),
|
||||
common=Batch(c=np.random.rand(4, 5)))
|
||||
test = Batch.cat([b1, b2])
|
||||
ans = Batch(a=Batch(),
|
||||
b=torch.cat([torch.zeros(3, 3), b2.b]),
|
||||
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
|
||||
assert ans.a.is_empty()
|
||||
assert torch.allclose(test.b, ans.b)
|
||||
assert np.allclose(test.common.c, ans.common.c)
|
||||
|
||||
# test stack with compatible keys
|
||||
b3 = Batch(a=np.zeros((3, 4)),
|
||||
b=torch.ones((2, 5)),
|
||||
@ -205,6 +255,25 @@ def test_batch_cat_and_stack():
|
||||
assert np.allclose(d.c, [3, 0, 7])
|
||||
assert np.allclose(d.d, [0, 6, 9])
|
||||
|
||||
# test stack with empty Batch()
|
||||
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
|
||||
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
|
||||
b = Batch(a=4, b=5, d=6, e=Batch())
|
||||
c = Batch(c=7, b=6, d=9, e=Batch())
|
||||
d = Batch.stack([a, b, c])
|
||||
assert np.allclose(d.a, [1, 4, 0])
|
||||
assert np.allclose(d.b, [2, 5, 6])
|
||||
assert np.allclose(d.c, [3, 0, 7])
|
||||
assert np.allclose(d.d, [0, 6, 9])
|
||||
assert d.e.is_empty()
|
||||
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
|
||||
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
|
||||
test = Batch.stack([b1, b2], axis=-1)
|
||||
assert test.a.is_empty()
|
||||
assert test.b.is_empty()
|
||||
assert np.allclose(test.common.c,
|
||||
np.stack([b1.common.c, b2.common.c], axis=-1))
|
||||
|
||||
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
|
||||
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
|
||||
test = Batch.stack([b1, b2])
|
||||
|
@ -14,11 +14,17 @@ warnings.filterwarnings(
|
||||
|
||||
|
||||
def _is_batch_set(data: Any) -> bool:
    """Tell whether ``data`` is a "batch set".

    A batch set is either

    * a non-empty list/tuple whose elements are all dict/Batch objects, or
    * an ``np.ndarray`` of object dtype whose elements (along the first
      dimension) are all dict/Batch objects.

    :param data: the object to inspect.
    :return: ``True`` if ``data`` is a batch set, ``False`` otherwise.
    """
    if isinstance(data, (list, tuple)):
        if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
            return True
    # ``np.object`` was only an alias of the builtin ``object`` and has been
    # removed from recent NumPy releases; comparing against ``object`` is
    # equivalent and works on every NumPy version.
    elif isinstance(data, np.ndarray) and data.dtype == object:
        # ``for e in data`` will just unpack the first dimension,
        # but data.tolist() will flatten ndarray of objects
        # so do not use data.tolist()
        if all(isinstance(e, (dict, Batch)) for e in data):
            return True
    return False
|
||||
|
||||
@ -39,7 +45,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
|
||||
# here we do not consider scalar types, following the
|
||||
# behavior of numpy which does not support concatenation
|
||||
# of zero-dimensional arrays (scalars)
|
||||
raise TypeError(f"cannot cat {inst} with which is scalar")
|
||||
raise TypeError(f"cannot concatenate with {inst} which is scalar")
|
||||
if has_shape:
|
||||
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
|
||||
if isinstance(inst, np.ndarray):
|
||||
@ -95,9 +101,9 @@ class Batch:
|
||||
|
||||
In short, you can define a :class:`Batch` with any key-value pair.
|
||||
|
||||
For Numpy arrays, only data types with ``np.object``, bool, and number
|
||||
are supported. For strings or other data types, however, they can be
|
||||
held in ``np.object`` arrays.
|
||||
For Numpy arrays, only data types with ``np.object``, bool, and number are
|
||||
supported. For strings or other data types, however, they can be held in
|
||||
``np.object`` arrays.
|
||||
|
||||
The current implementation of Tianshou typically uses 7 reserved keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
@ -108,9 +114,39 @@ class Batch:
|
||||
* ``done`` the done flag of step :math:`t` ;
|
||||
* ``obs_next`` the observation of step :math:`t+1` ;
|
||||
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
|
||||
function returns 4 arguments, and the last one is ``info``);
|
||||
function returns 4 arguments, and the last one is ``info``);
|
||||
* ``policy`` the data computed by policy in step :math:`t`;
|
||||
|
||||
For convenience, :class:`~tianshou.data.Batch` supports the mechanism of
|
||||
key reservation: one can specify a key without any value, which serves as
|
||||
a placeholder for the Batch object. For example, you know there will be a
|
||||
key named ``obs``, but do not know the value until the simulator runs. Then
|
||||
you can reserve the key ``obs``. This is done by setting the value to
|
||||
``Batch()``.
|
||||
|
||||
For a Batch object, we call it "incomplete" if: (i) it is ``Batch()``; (ii)
|
||||
it has reserved keys; (iii) any of its sub-Batch is incomplete. Otherwise,
|
||||
the Batch object is finalized.
|
||||
|
||||
Key reservation mechanism is convenient, but also causes some problem in
|
||||
aggregation operators like ``stack`` or ``cat`` of Batch objects. We say
|
||||
that Batch objects are compatible for aggregation with three cases:
|
||||
|
||||
1. finalized Batch objects are compatible if and only if there exists a \
|
||||
way to extend keys so that their structures are exactly the same.
|
||||
|
||||
2. incomplete Batch objects and other finalized objects are compatible if \
|
||||
there exists a way to extend keys so that incomplete Batch objects can \
|
||||
have the same structure as finalized objects.
|
||||
|
||||
3. incomplete Batch objects themselves are compatible if there exists a \
|
||||
way to extend keys so that their structure can be the same.
|
||||
|
||||
In a word, incomplete Batch objects have a set of possible structures
|
||||
in the future, but finalized Batch objects only have a finalized structure.
|
||||
Batch objects are compatible if and only if they share at least one
|
||||
commonly possible structure by extending keys.
|
||||
|
||||
:class:`~tianshou.data.Batch` object can be initialized by a wide variety
|
||||
of arguments, ranging from the key/value pairs or dictionary, to list and
|
||||
Numpy arrays of :class:`dict` or Batch instances where each element is
|
||||
@ -126,8 +162,8 @@ class Batch:
|
||||
)
|
||||
|
||||
:class:`~tianshou.data.Batch` has the same API as a native Python
|
||||
:class:`dict`. In this regard, one can access stored data using string
|
||||
key, or iterate over stored data:
|
||||
:class:`dict`. In this regard, one can access stored data using string key,
|
||||
or iterate over stored data:
|
||||
::
|
||||
|
||||
>>> data = Batch(a=4, b=[5, 5])
|
||||
@ -153,7 +189,7 @@ class Batch:
|
||||
)
|
||||
>>> for sample in data:
|
||||
>>> print(sample.a)
|
||||
[0., 2.]
|
||||
[0. 2.]
|
||||
|
||||
>>> print(data.shape)
|
||||
[1, 2]
|
||||
@ -341,7 +377,7 @@ class Batch:
|
||||
if len(batch_items) > 0:
|
||||
b = Batch()
|
||||
for k, v in batch_items:
|
||||
if isinstance(v, Batch) and len(v.__dict__) == 0:
|
||||
if isinstance(v, Batch) and v.is_empty():
|
||||
b.__dict__[k] = Batch()
|
||||
else:
|
||||
b.__dict__[k] = v[index]
|
||||
@ -376,8 +412,9 @@ class Batch:
|
||||
except KeyError:
|
||||
if isinstance(val, Batch):
|
||||
self.__dict__[key][index] = Batch()
|
||||
elif isinstance(val, np.ndarray) and \
|
||||
issubclass(val.dtype.type, (np.bool_, np.number)):
|
||||
elif isinstance(val, torch.Tensor) or \
|
||||
(isinstance(val, np.ndarray) and
|
||||
issubclass(val.dtype.type, (np.bool_, np.number))):
|
||||
self.__dict__[key][index] = 0
|
||||
else:
|
||||
self.__dict__[key][index] = None
|
||||
@ -389,14 +426,14 @@ class Batch:
|
||||
for (k, r), v in zip(self.__dict__.items(),
|
||||
other.__dict__.values()):
|
||||
# TODO are keys consistent?
|
||||
if r is None:
|
||||
if isinstance(r, Batch) and r.is_empty():
|
||||
continue
|
||||
else:
|
||||
self.__dict__[k] += v
|
||||
return self
|
||||
elif isinstance(other, (Number, np.number)):
|
||||
for k, r in self.items():
|
||||
if r is None:
|
||||
if isinstance(r, Batch) and r.is_empty():
|
||||
continue
|
||||
else:
|
||||
self.__dict__[k] += other
|
||||
@ -413,7 +450,9 @@ class Batch:
|
||||
"""Algebraic multiplication with a scalar value in-place."""
|
||||
assert isinstance(val, (Number, np.number)), \
|
||||
"Only multiplication by a number is supported."
|
||||
for k in self.__dict__.keys():
|
||||
for k, r in self.__dict__.items():
|
||||
if isinstance(r, Batch) and r.is_empty():
|
||||
continue
|
||||
self.__dict__[k] *= val
|
||||
return self
|
||||
|
||||
@ -425,7 +464,9 @@ class Batch:
|
||||
"""Algebraic division with a scalar value in-place."""
|
||||
assert isinstance(val, (Number, np.number)), \
|
||||
"Only division by a number is supported."
|
||||
for k in self.__dict__.keys():
|
||||
for k, r in self.__dict__.items():
|
||||
if isinstance(r, Batch) and r.is_empty():
|
||||
continue
|
||||
self.__dict__[k] /= val
|
||||
return self
|
||||
|
||||
@ -501,6 +542,77 @@ class Batch:
|
||||
elif isinstance(v, Batch):
|
||||
v.to_torch(dtype, device)
|
||||
|
||||
def __cat(self,
          batches: Union['Batch', List[Union[dict, 'Batch']]],
          lens: List[int]) -> None:
    """Concatenate ``batches`` into ``self`` using ``lens`` as a length hint.

    ::

        >>> a = Batch(a=np.random.randn(3, 4))
        >>> x = Batch(a=a, b=np.random.randn(4, 4))
        >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))

    If we want to concatenate x and y, we want to pad y.a.a with zeros.
    Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
    not be able to know how to pad y.a. So ``Batch.cat_`` should compute
    the ``lens`` to give ``Batch.__cat`` a hint.
    ::

        >>> ans = Batch.cat([x, y])
        >>> # this is equivalent to the following line
        >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
        >>> # this lens is equal to [len(x), len(y)]

    Keys are handled in three categories:

    * shared keys (non-empty in every batch) are concatenated directly;
    * reserved keys (empty ``Batch()`` in every batch that has them) stay
      reserved, i.e. they are set to ``Batch()``;
    * partial keys (non-empty in some batches only) are written into a
      buffer created by ``_create_value`` so that the missing segments
      keep the buffer's initial padding value.

    :param batches: the Batch objects to concatenate, in order.
    :param lens: the length contributed by each element of ``batches``;
        it must have the same number of entries as ``batches``.
    """
    # partial keys will be padded by zeros
    # with the shape of [len, rest_shape]
    # sum_lens[i]:sum_lens[i + 1] is the slice owned by batches[i]
    sum_lens = [0]
    for x in lens:
        sum_lens.append(sum_lens[-1] + x)
    # collect non-empty keys
    keys_map = [
        {k for k, v in batch.items()
         if not (isinstance(v, Batch) and v.is_empty())}
        for batch in batches]
    keys_shared = set.intersection(*keys_map)
    values_shared = [[e[k] for e in batches] for k in keys_shared]
    _assert_type_keys(keys_shared)
    for k, v in zip(keys_shared, values_shared):
        if all(isinstance(e, (dict, Batch)) for e in v):
            # nested batches: recurse with the same length hint
            batch_holder = Batch()
            batch_holder.__cat(v, lens=lens)
            self.__dict__[k] = batch_holder
        elif all(isinstance(e, torch.Tensor) for e in v):
            self.__dict__[k] = torch.cat(v)
        else:
            # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
            # will fail here
            v = np.concatenate(v)
            if not issubclass(v.dtype.type, (np.bool_, np.number)):
                # ``np.object`` has been removed from recent NumPy
                # releases; the builtin ``object`` is the same dtype.
                v = v.astype(object)
            self.__dict__[k] = v
    keys_total = set.union(*[set(b.keys()) for b in batches])
    keys_reserve_or_partial = set.difference(keys_total, keys_shared)
    _assert_type_keys(keys_reserve_or_partial)
    # keys that are reserved in all batches
    keys_reserve = set.difference(keys_total, set.union(*keys_map))
    # keys that occur only in some batches, but not all
    keys_partial = keys_reserve_or_partial.difference(keys_reserve)
    for k in keys_reserve:
        # reserved keys
        self.__dict__[k] = Batch()
    for k in keys_partial:
        for i, e in enumerate(batches):
            if k not in e.__dict__:
                continue
            val = e.get(k)
            if isinstance(val, Batch) and val.is_empty():
                continue
            try:
                self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
            except KeyError:
                # first non-empty value for this key: allocate the
                # full-length buffer, then fill this batch's slice
                self.__dict__[k] = \
                    _create_value(val, sum_lens[-1], stack=False)
                self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
|
||||
|
||||
def cat_(self,
|
||||
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
|
||||
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
|
||||
@ -511,40 +623,25 @@ class Batch:
|
||||
if len(batches) == 0:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
if len(self.__dict__) > 0:
|
||||
|
||||
# x.is_empty() means that x is Batch() and should be ignored
|
||||
batches = [x for x in batches if not x.is_empty()]
|
||||
try:
|
||||
# x.is_empty(recurse=True) here means x is a nested empty batch
|
||||
# like Batch(a=Batch), and we have to treat it as length zero and
|
||||
# keep it.
|
||||
lens = [0 if x.is_empty(recurse=True) else len(x)
|
||||
for x in batches]
|
||||
except TypeError as e:
|
||||
e2 = ValueError(
|
||||
f'Batch.cat_ meets an exception. Maybe because there is '
|
||||
f'any scalar in {batches} but Batch.cat_ does not support'
|
||||
f'the concatenation of scalar.')
|
||||
raise Exception([e, e2])
|
||||
if not self.is_empty():
|
||||
batches = [self] + list(batches)
|
||||
# partial keys will be padded by zeros
|
||||
# with the shape of [len, rest_shape]
|
||||
lens = [len(x) for x in batches]
|
||||
sum_lens = [0]
|
||||
for x in lens:
|
||||
sum_lens.append(sum_lens[-1] + x)
|
||||
keys_map = list(map(lambda e: set(e.keys()), batches))
|
||||
keys_shared = set.intersection(*keys_map)
|
||||
values_shared = [[e[k] for e in batches] for k in keys_shared]
|
||||
_assert_type_keys(keys_shared)
|
||||
for k, v in zip(keys_shared, values_shared):
|
||||
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||
self.__dict__[k] = Batch.cat(v)
|
||||
elif all(isinstance(e, torch.Tensor) for e in v):
|
||||
self.__dict__[k] = torch.cat(v)
|
||||
else:
|
||||
v = np.concatenate(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
keys_partial = set.union(*keys_map) - keys_shared
|
||||
_assert_type_keys(keys_partial)
|
||||
for k in keys_partial:
|
||||
for i, e in enumerate(batches):
|
||||
val = e.get(k, None)
|
||||
if val is not None:
|
||||
try:
|
||||
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
|
||||
except KeyError:
|
||||
self.__dict__[k] = \
|
||||
_create_value(val, sum_lens[-1], stack=False)
|
||||
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
|
||||
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
|
||||
return self.__cat(batches, lens)
|
||||
|
||||
@staticmethod
|
||||
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
||||
@ -577,9 +674,13 @@ class Batch:
|
||||
if len(batches) == 0:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
if len(self.__dict__) > 0:
|
||||
if not self.is_empty():
|
||||
batches = [self] + list(batches)
|
||||
keys_map = list(map(lambda e: set(e.keys()), batches))
|
||||
# collect non-empty keys
|
||||
keys_map = [
|
||||
set(k for k, v in batch.items()
|
||||
if not (isinstance(v, Batch) and v.is_empty()))
|
||||
for batch in batches]
|
||||
keys_shared = set.intersection(*keys_map)
|
||||
values_shared = [[e[k] for e in batches] for k in keys_shared]
|
||||
_assert_type_keys(keys_shared)
|
||||
@ -593,22 +694,35 @@ class Batch:
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
keys_partial = set.difference(set.union(*keys_map), keys_shared)
|
||||
# all the keys
|
||||
keys_total = set.union(*[set(b.keys()) for b in batches])
|
||||
# keys that are reserved in all batches
|
||||
keys_reserve = set.difference(keys_total, set.union(*keys_map))
|
||||
# keys that are either partial or reserved
|
||||
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
|
||||
# keys that occur only in some batches, but not all
|
||||
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
|
||||
if keys_partial and axis != 0:
|
||||
raise ValueError(
|
||||
f"Stack of Batch with non-shared keys {keys_partial} "
|
||||
f"is only supported with axis=0, but got axis={axis}!")
|
||||
_assert_type_keys(keys_partial)
|
||||
_assert_type_keys(keys_reserve_or_partial)
|
||||
for k in keys_reserve:
|
||||
# reserved keys
|
||||
self.__dict__[k] = Batch()
|
||||
for k in keys_partial:
|
||||
for i, e in enumerate(batches):
|
||||
val = e.get(k, None)
|
||||
if val is not None:
|
||||
try:
|
||||
self.__dict__[k][i] = val
|
||||
except KeyError:
|
||||
self.__dict__[k] = \
|
||||
_create_value(val, len(batches))
|
||||
self.__dict__[k][i] = val
|
||||
if k not in e.__dict__:
|
||||
continue
|
||||
val = e.get(k)
|
||||
if isinstance(val, Batch) and val.is_empty():
|
||||
continue
|
||||
try:
|
||||
self.__dict__[k][i] = val
|
||||
except KeyError:
|
||||
self.__dict__[k] = \
|
||||
_create_value(val, len(batches))
|
||||
self.__dict__[k][i] = val
|
||||
|
||||
@staticmethod
|
||||
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
||||
@ -691,26 +805,53 @@ class Batch:
|
||||
"""Return len(self)."""
|
||||
r = []
|
||||
for v in self.__dict__.values():
|
||||
if isinstance(v, Batch) and v.is_empty():
|
||||
if isinstance(v, Batch) and v.is_empty(recurse=True):
|
||||
continue
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||
r.append(len(v))
|
||||
else:
|
||||
raise TypeError("Object of type 'Batch' has no len()")
|
||||
raise TypeError(f"Object {v} in {self} has no len()")
|
||||
if len(r) == 0:
|
||||
raise TypeError("Object of type 'Batch' has no len()")
|
||||
raise TypeError(f"Object {self} has no len()")
|
||||
return min(r)
|
||||
|
||||
def is_empty(self):
|
||||
return not any(
|
||||
not x.is_empty() if isinstance(x, Batch)
|
||||
else hasattr(x, '__len__') and len(x) > 0 for x in self.values())
|
||||
def is_empty(self, recurse: bool = False):
    """Check whether this Batch is empty.

    With the default ``recurse=False`` only the existence of keys is
    tested: a Batch is empty if and only if it holds no key at all. With
    ``recurse=True`` a Batch that has keys is still considered empty when
    every value is itself a Batch that is recursively empty.

    ``b.is_empty(recurse=True)`` is mainly used to distinguish
    ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
    exceptions when applied to ``len()``, but the former can be used in
    ``cat``, while the latter is a scalar and cannot be used in ``cat``.

    Another usage is in ``__len__``, where recursively empty sub-Batches
    have to be skipped when computing the length.
    ::

        >>> Batch().is_empty()
        True
        >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
        False
        >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
        True
        >>> Batch(d=1).is_empty()
        False
        >>> Batch(a=np.float64(1.0)).is_empty()
        False
    """
    # no key at all: empty regardless of ``recurse``
    if not self.__dict__:
        return True
    # keys exist and we were asked not to look at the values
    if not recurse:
        return False
    for value in self.values():
        # any non-Batch value is real content
        if not isinstance(value, Batch):
            return False
        if not value.is_empty(recurse=True):
            return False
    return True
|
||||
|
||||
@property
|
||||
def shape(self) -> List[int]:
|
||||
"""Return self.shape."""
|
||||
if len(self.__dict__.keys()) == 0:
|
||||
if self.is_empty():
|
||||
return []
|
||||
else:
|
||||
data_shape = []
|
||||
|
@ -98,7 +98,7 @@ class Collector(object):
|
||||
stat_size: Optional[int] = 100,
|
||||
action_noise: Optional[BaseNoise] = None,
|
||||
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
|
||||
**kwargs) -> None:
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.env = env
|
||||
self.env_num = 1
|
||||
|
@ -108,7 +108,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
batch: Batch,
|
||||
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95) -> Batch:
|
||||
gae_lambda: float = 0.95,
|
||||
) -> Batch:
|
||||
"""Compute returns over given full-length episodes, including the
|
||||
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
|
||||
|
||||
@ -124,18 +125,19 @@ class BasePolicy(ABC, nn.Module):
|
||||
|
||||
:return: a Batch. The result will be stored in batch.returns.
|
||||
"""
|
||||
rew = batch.rew
|
||||
if v_s_ is None:
|
||||
v_s_ = batch.rew * 0.
|
||||
v_s_ = rew * 0.
|
||||
else:
|
||||
if not isinstance(v_s_, np.ndarray):
|
||||
v_s_ = np.array(v_s_, np.float)
|
||||
v_s_ = v_s_.reshape(batch.rew.shape)
|
||||
v_s_ = v_s_.reshape(rew.shape)
|
||||
returns = np.roll(v_s_, 1, axis=0)
|
||||
m = (1. - batch.done) * gamma
|
||||
delta = batch.rew + v_s_ * m - returns
|
||||
delta = rew + v_s_ * m - returns
|
||||
m *= gae_lambda
|
||||
gae = 0.
|
||||
for i in range(len(batch.rew) - 1, -1, -1):
|
||||
for i in range(len(rew) - 1, -1, -1):
|
||||
gae = delta[i] + m[i] * gae
|
||||
returns[i] += gae
|
||||
batch.returns = returns
|
||||
@ -149,7 +151,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
|
||||
gamma: float = 0.99,
|
||||
n_step: int = 1,
|
||||
rew_norm: bool = False
|
||||
rew_norm: bool = False,
|
||||
) -> np.ndarray:
|
||||
r"""Compute n-step return for Q-learning targets:
|
||||
|
||||
@ -180,8 +182,9 @@ class BasePolicy(ABC, nn.Module):
|
||||
:return: a Batch. The result will be stored in batch.returns as a
|
||||
torch.Tensor with shape (bsz, ).
|
||||
"""
|
||||
rew = buffer.rew
|
||||
if rew_norm:
|
||||
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
|
||||
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
|
||||
mean, std = bfr.mean(), bfr.std()
|
||||
if np.isclose(std, 0):
|
||||
mean, std = 0, 1
|
||||
@ -189,7 +192,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
mean, std = 0, 1
|
||||
returns = np.zeros_like(indice)
|
||||
gammas = np.zeros_like(indice) + n_step
|
||||
done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
|
||||
done, buf_len = buffer.done, len(buffer)
|
||||
for n in range(n_step - 1, -1, -1):
|
||||
now = (indice + n) % buf_len
|
||||
gammas[done[now] > 0] = n
|
||||
|
@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy):
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
|
||||
mode: str = 'continuous', **kwargs) -> None:
|
||||
mode: str = 'continuous') -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
|
@ -21,6 +21,8 @@ class DQNPolicy(BasePolicy):
|
||||
ahead.
|
||||
:param int target_update_freq: the target network update frequency (``0``
|
||||
if you do not use the target network).
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to ``False``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -34,6 +36,7 @@ class DQNPolicy(BasePolicy):
|
||||
discount_factor: float = 0.99,
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: Optional[int] = 0,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
@ -49,6 +52,7 @@ class DQNPolicy(BasePolicy):
|
||||
if self._target:
|
||||
self.model_old = deepcopy(self.model)
|
||||
self.model_old.eval()
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def set_eps(self, eps: float) -> None:
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
@ -94,7 +98,8 @@ class DQNPolicy(BasePolicy):
|
||||
to :math:`Q_{new}`.
|
||||
"""
|
||||
batch = self.compute_nstep_return(
|
||||
batch, buffer, indice, self._target_q, self._gamma, self._n_step)
|
||||
batch, buffer, indice, self._target_q,
|
||||
self._gamma, self._n_step, self._rew_norm)
|
||||
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||
batch.update_weight = buffer.update_weight
|
||||
batch.indice = indice
|
||||
|
@ -27,7 +27,6 @@ def offpolicy_trainer(
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
||||
|
@ -27,7 +27,6 @@ def onpolicy_trainer(
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for on-policy trainer procedure.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user