Standardized behavior of Batch.cat and misc code refactor (#137)

* code refactor; remove unused kwargs; add reward_normalization for dqn

* bugfix for __setitem__ with torch.Tensor; add Batch.condense

* minor fix

* support cat with empty Batch

* remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases

* support stack with empty Batch

* remove condense

* refactor code to reflect the shared / partial / reserved categories of keys

* add is_empty(recursive=False)

* doc fix

* docfix and bugfix for _is_batch_set

* add doc for key reservation

* bugfix for algebra operators

* fix cat with lens hint

* code refactor

* bugfix for storing None

* use ValueError instead of exception

* hide lens away from users

* add comment for __cat

* move the computation of the initial value of lens in cat_ itself.

* change the place of doc string

* doc fix for Batch doc string

* change recursive to recurse

* doc string fix

* minor fix for batch doc
This commit is contained in:
youkaichao 2020-07-16 19:36:32 +08:00 committed by Trinkle23897
parent 09e10e384f
commit 3a08e27ed4
8 changed files with 299 additions and 83 deletions

View File

@ -10,7 +10,12 @@ from tianshou.data import Batch, to_torch
def test_batch():
assert list(Batch()) == []
assert Batch().is_empty()
assert Batch(b={'c': {}}).is_empty()
assert not Batch(b={'c': {}}).is_empty()
assert Batch(b={'c': {}}).is_empty(recurse=True)
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
assert not Batch(d=1).is_empty()
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch()
@ -109,6 +114,11 @@ def test_batch():
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])
# None is a valid object and can be stored in Batch
a = Batch.stack([Batch(a=None), Batch(b=None)])
assert a.a[0] is None and a.a[1] is None
assert a.b[0] is None and a.b[1] is None
def test_batch_over_batch():
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
@ -162,6 +172,20 @@ def test_batch_cat_and_stack():
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
assert b12_cat_in.a.d.e.ndim == 1
a = Batch(a=Batch(a=np.random.randn(3, 4)))
assert np.allclose(
np.concatenate([a.a.a, a.a.a]),
Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)
# test cat with lens infer
a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a,
np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()
b12_stack = Batch.stack((b1, b2))
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
@ -177,6 +201,32 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with reserved keys (values are Batch())
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with all reserved keys (values are Batch())
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=Batch(),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert ans.a.is_empty()
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test stack with compatible keys
b3 = Batch(a=np.zeros((3, 4)),
b=torch.ones((2, 5)),
@ -205,6 +255,25 @@ def test_batch_cat_and_stack():
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])

View File

@ -14,11 +14,17 @@ warnings.filterwarnings(
def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with np.object type,
# where each element is a dict/Batch object
if isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
return True
elif isinstance(data, np.ndarray) and data.dtype == np.object:
if all(isinstance(e, (dict, Batch)) for e in data.tolist()):
# ``for e in data`` will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
if all(isinstance(e, (dict, Batch)) for e in data):
return True
return False
@ -39,7 +45,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
# here we do not consider scalar types, following the
# behavior of numpy which does not support concatenation
# of zero-dimensional arrays (scalars)
raise TypeError(f"cannot cat {inst} with which is scalar")
raise TypeError(f"cannot concatenate with {inst} which is scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
@ -95,9 +101,9 @@ class Batch:
In short, you can define a :class:`Batch` with any key-value pair.
For Numpy arrays, only data types with ``np.object``, bool, and number
are supported. For strings or other data types, however, they can be
held in ``np.object`` arrays.
For Numpy arrays, only data types with ``np.object``, bool, and number are
supported. For strings or other data types, however, they can be held in
``np.object`` arrays.
The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:
@ -108,9 +114,39 @@ class Batch:
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function returns 4 arguments, and the last one is ``info``);
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
For convenience, :class:`~tianshou.data.Batch` supports the mechanism of
key reservation: one can specify a key without any value, which serves as
a placeholder for the Batch object. For example, you know there will be a
key named ``obs``, but do not know the value until the simulator runs. Then
you can reserve the key ``obs``. This is done by setting the value to
``Batch()``.
We call a Batch object "incomplete" if: (i) it is ``Batch()``; (ii) it has
reserved keys; or (iii) any of its sub-Batches is incomplete. Otherwise,
the Batch object is "finalized".
The key reservation mechanism is convenient, but it also complicates
aggregation operators such as ``stack`` or ``cat`` of Batch objects. We say
that Batch objects are compatible for aggregation in three cases:
1. finalized Batch objects are compatible if and only if there exists a \
way to extend keys so that their structures are exactly the same.
2. incomplete Batch objects and other finalized objects are compatible if \
there exists a way to extend keys so that incomplete Batch objects can \
have the same structure as finalized objects.
3. incomplete Batch objects themselves are compatible if there exists a \
way to extend keys so that their structure can be the same.
In short, an incomplete Batch object has a set of possible future
structures, whereas a finalized Batch object has exactly one structure.
Batch objects are compatible if and only if they share at least one
common possible structure that can be reached by extending keys.
:class:`~tianshou.data.Batch` object can be initialized by a wide variety
of arguments, ranging from the key/value pairs or dictionary, to list and
Numpy arrays of :class:`dict` or Batch instances where each element is
@ -126,8 +162,8 @@ class Batch:
)
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access stored data using string
key, or iterate over stored data:
:class:`dict`. In this regard, one can access stored data using string key,
or iterate over stored data:
::
>>> data = Batch(a=4, b=[5, 5])
@ -153,7 +189,7 @@ class Batch:
)
>>> for sample in data:
>>> print(sample.a)
[0., 2.]
[0. 2.]
>>> print(data.shape)
[1, 2]
@ -341,7 +377,7 @@ class Batch:
if len(batch_items) > 0:
b = Batch()
for k, v in batch_items:
if isinstance(v, Batch) and len(v.__dict__) == 0:
if isinstance(v, Batch) and v.is_empty():
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
@ -376,8 +412,9 @@ class Batch:
except KeyError:
if isinstance(val, Batch):
self.__dict__[key][index] = Batch()
elif isinstance(val, np.ndarray) and \
issubclass(val.dtype.type, (np.bool_, np.number)):
elif isinstance(val, torch.Tensor) or \
(isinstance(val, np.ndarray) and
issubclass(val.dtype.type, (np.bool_, np.number))):
self.__dict__[key][index] = 0
else:
self.__dict__[key][index] = None
@ -389,14 +426,14 @@ class Batch:
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
# TODO are keys consistent?
if r is None:
if isinstance(r, Batch) and r.is_empty():
continue
else:
self.__dict__[k] += v
return self
elif isinstance(other, (Number, np.number)):
for k, r in self.items():
if r is None:
if isinstance(r, Batch) and r.is_empty():
continue
else:
self.__dict__[k] += other
@ -413,7 +450,9 @@ class Batch:
"""Algebraic multiplication with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
"Only multiplication by a number is supported."
for k in self.__dict__.keys():
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
self.__dict__[k] *= val
return self
@ -425,7 +464,9 @@ class Batch:
"""Algebraic division with a scalar value in-place."""
assert isinstance(val, (Number, np.number)), \
"Only division by a number is supported."
for k in self.__dict__.keys():
for k, r in self.__dict__.items():
if isinstance(r, Batch) and r.is_empty():
continue
self.__dict__[k] /= val
return self
@ -501,6 +542,77 @@ class Batch:
elif isinstance(v, Batch):
v.to_torch(dtype, device)
def __cat(self,
          batches: Union['Batch', List[Union[dict, 'Batch']]],
          lens: List[int]) -> None:
    """Concatenate ``batches`` into ``self`` using ``lens`` as a length hint.

    ::

        >>> a = Batch(a=np.random.randn(3, 4))
        >>> x = Batch(a=a, b=np.random.randn(4, 4))
        >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))

    If we want to concatenate x and y, we want to pad y.a.a with zeros.
    Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
    not be able to know how to pad y.a. So ``Batch.cat_`` should compute
    the ``lens`` to give ``Batch.__cat`` a hint.

    ::

        >>> ans = Batch.cat([x, y])
        >>> # this is equivalent to the following line
        >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
        >>> # this lens is equal to [len(x), len(y)]

    :param batches: the Batch objects to concatenate, in order.
    :param lens: one length per element of ``batches``; ``lens[i]`` is the
        number of rows ``batches[i]`` occupies in the result.
    """
    # partial keys will be padded by zeros
    # with the shape of [len, rest_shape];
    # sum_lens[i]:sum_lens[i + 1] is the row range owned by batches[i]
    sum_lens = [0]
    for x in lens:
        sum_lens.append(sum_lens[-1] + x)
    # collect non-empty keys: a key whose value is an empty Batch is only
    # reserved, so it does not count as present in that batch
    keys_map = [
        set(k for k, v in batch.items()
            if not (isinstance(v, Batch) and v.is_empty()))
        for batch in batches]
    # keys that hold a real value in every batch
    keys_shared = set.intersection(*keys_map)
    values_shared = [[e[k] for e in batches] for k in keys_shared]
    _assert_type_keys(keys_shared)
    for k, v in zip(keys_shared, values_shared):
        if all(isinstance(e, (dict, Batch)) for e in v):
            # nested Batch: recurse with the same lens hint so that
            # sub-batches are padded consistently with their parents
            batch_holder = Batch()
            batch_holder.__cat(v, lens=lens)
            self.__dict__[k] = batch_holder
        elif all(isinstance(e, torch.Tensor) for e in v):
            self.__dict__[k] = torch.cat(v)
        else:
            # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
            # will fail here
            v = np.concatenate(v)
            if not issubclass(v.dtype.type, (np.bool_, np.number)):
                # use the builtin ``object``: the ``np.object`` alias was
                # deprecated in NumPy 1.20 and removed in NumPy 1.24
                v = v.astype(object)
            self.__dict__[k] = v
    keys_total = set.union(*[set(b.keys()) for b in batches])
    keys_reserve_or_partial = set.difference(keys_total, keys_shared)
    _assert_type_keys(keys_reserve_or_partial)
    # keys that are reserved in all batches
    keys_reserve = set.difference(keys_total, set.union(*keys_map))
    # keys that occur only in some batches, but not all
    keys_partial = keys_reserve_or_partial.difference(keys_reserve)
    for k in keys_reserve:
        # reserved keys
        self.__dict__[k] = Batch()
    for k in keys_partial:
        for i, e in enumerate(batches):
            if k not in e.__dict__:
                continue
            val = e.get(k)
            if isinstance(val, Batch) and val.is_empty():
                continue
            try:
                self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
            except KeyError:
                # first batch that provides this key: allocate the full
                # placeholder, then fill in this batch's slice
                self.__dict__[k] = \
                    _create_value(val, sum_lens[-1], stack=False)
                self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
def cat_(self,
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
@ -511,40 +623,25 @@ class Batch:
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if len(self.__dict__) > 0:
# x.is_empty() means that x is Batch() and should be ignored
batches = [x for x in batches if not x.is_empty()]
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if x.is_empty(recurse=True) else len(x)
for x in batches]
except TypeError as e:
e2 = ValueError(
f'Batch.cat_ meets an exception. Maybe because there is '
f'any scalar in {batches} but Batch.cat_ does not support'
f'the concatenation of scalar.')
raise Exception([e, e2])
if not self.is_empty():
batches = [self] + list(batches)
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
lens = [len(x) for x in batches]
sum_lens = [0]
for x in lens:
sum_lens.append(sum_lens[-1] + x)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
self.__dict__[k] = Batch.cat(v)
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.cat(v)
else:
v = np.concatenate(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = set.union(*keys_map) - keys_shared
_assert_type_keys(keys_partial)
for k in keys_partial:
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
return self.__cat(batches, lens)
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
@ -577,9 +674,13 @@ class Batch:
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if len(self.__dict__) > 0:
if not self.is_empty():
batches = [self] + list(batches)
keys_map = list(map(lambda e: set(e.keys()), batches))
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
if not (isinstance(v, Batch) and v.is_empty()))
for batch in batches]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
@ -593,22 +694,35 @@ class Batch:
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = set.difference(set.union(*keys_map), keys_shared)
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
# keys that are reserved in all batches
keys_reserve = set.difference(keys_total, set.union(*keys_map))
# keys that are either partial or reserved
keys_reserve_or_partial = set.difference(keys_total, keys_shared)
# keys that occur only in some batches, but not all
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} "
f"is only supported with axis=0, but got axis={axis}!")
_assert_type_keys(keys_partial)
_assert_type_keys(keys_reserve_or_partial)
for k in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
for k in keys_partial:
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
try:
self.__dict__[k][i] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, len(batches))
self.__dict__[k][i] = val
if k not in e.__dict__:
continue
val = e.get(k)
if isinstance(val, Batch) and val.is_empty():
continue
try:
self.__dict__[k][i] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, len(batches))
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
@ -691,26 +805,53 @@ class Batch:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
else:
raise TypeError("Object of type 'Batch' has no len()")
raise TypeError(f"Object {v} in {self} has no len()")
if len(r) == 0:
raise TypeError("Object of type 'Batch' has no len()")
raise TypeError(f"Object {self} has no len()")
return min(r)
def is_empty(self):
return not any(
not x.is_empty() if isinstance(x, Batch)
else hasattr(x, '__len__') and len(x) > 0 for x in self.values())
def is_empty(self, recurse: bool = False) -> bool:
    """Test if a Batch is empty.

    If ``recurse=True``, it further tests the values of the object; else
    it only tests the existence of any key.

    ``b.is_empty(recurse=True)`` is mainly used to distinguish
    ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
    exceptions when applied to ``len()``, but the former can be used in
    ``cat``, while the latter is a scalar and cannot be used in ``cat``.

    Another usage is in ``__len__``, where we have to skip checking the
    length of recursively empty Batch.

    ::

        >>> Batch().is_empty()
        True
        >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
        False
        >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
        True
        >>> Batch(d=1).is_empty()
        False
        >>> Batch(a=np.float64(1.0)).is_empty()
        False

    :param recurse: when ``True``, a Batch whose values are all
        (recursively) empty Batch objects also counts as empty.
    :return: ``True`` if the Batch is empty in the requested sense.
    """
    # no key at all: empty in both the shallow and the recursive sense
    if not self.__dict__:
        return True
    # shallow check: having any key makes the Batch non-empty
    if not recurse:
        return False
    # recursive check: any non-Batch value makes the Batch non-empty
    return all(isinstance(x, Batch) and x.is_empty(recurse=True)
               for x in self.values())
@property
def shape(self) -> List[int]:
"""Return self.shape."""
if len(self.__dict__.keys()) == 0:
if self.is_empty():
return []
else:
data_shape = []

View File

@ -98,7 +98,7 @@ class Collector(object):
stat_size: Optional[int] = 100,
action_noise: Optional[BaseNoise] = None,
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
**kwargs) -> None:
) -> None:
super().__init__()
self.env = env
self.env_num = 1

View File

@ -108,7 +108,8 @@ class BasePolicy(ABC, nn.Module):
batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95) -> Batch:
gae_lambda: float = 0.95,
) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@ -124,18 +125,19 @@ class BasePolicy(ABC, nn.Module):
:return: a Batch. The result will be stored in batch.returns.
"""
rew = batch.rew
if v_s_ is None:
v_s_ = batch.rew * 0.
v_s_ = rew * 0.
else:
if not isinstance(v_s_, np.ndarray):
v_s_ = np.array(v_s_, np.float)
v_s_ = v_s_.reshape(batch.rew.shape)
v_s_ = v_s_.reshape(rew.shape)
returns = np.roll(v_s_, 1, axis=0)
m = (1. - batch.done) * gamma
delta = batch.rew + v_s_ * m - returns
delta = rew + v_s_ * m - returns
m *= gae_lambda
gae = 0.
for i in range(len(batch.rew) - 1, -1, -1):
for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
returns[i] += gae
batch.returns = returns
@ -149,7 +151,7 @@ class BasePolicy(ABC, nn.Module):
target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
gamma: float = 0.99,
n_step: int = 1,
rew_norm: bool = False
rew_norm: bool = False,
) -> np.ndarray:
r"""Compute n-step return for Q-learning targets:
@ -180,8 +182,9 @@ class BasePolicy(ABC, nn.Module):
:return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with shape (bsz, ).
"""
rew = buffer.rew
if rew_norm:
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0):
mean, std = 0, 1
@ -189,7 +192,7 @@ class BasePolicy(ABC, nn.Module):
mean, std = 0, 1
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + n_step
done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
done, buf_len = buffer.done, len(buffer)
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n

View File

@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy):
"""
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: str = 'continuous', **kwargs) -> None:
mode: str = 'continuous') -> None:
super().__init__()
self.model = model
self.optim = optim

View File

@ -21,6 +21,8 @@ class DQNPolicy(BasePolicy):
ahead.
:param int target_update_freq: the target network update frequency (``0``
if you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
.. seealso::
@ -34,6 +36,7 @@ class DQNPolicy(BasePolicy):
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: Optional[int] = 0,
reward_normalization: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.model = model
@ -49,6 +52,7 @@ class DQNPolicy(BasePolicy):
if self._target:
self.model_old = deepcopy(self.model)
self.model_old.eval()
self._rew_norm = reward_normalization
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
@ -94,7 +98,8 @@ class DQNPolicy(BasePolicy):
to :math:`Q_{new}`.
"""
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q, self._gamma, self._n_step)
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice

View File

@ -27,7 +27,6 @@ def offpolicy_trainer(
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure.

View File

@ -27,7 +27,6 @@ def onpolicy_trainer(
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.