hotfix: keep statistics of buffer when resetting the buffer in the on-policy trainer (#328)
This commit is contained in:
parent
8963a14327
commit
105b277b87
@ -98,11 +98,12 @@ class ReplayBuffer:
|
||||
buf.__setstate__(from_hdf5(f, device=device))
|
||||
return buf
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self, keep_statistics: bool = False) -> None:
|
||||
"""Clear all the data in replay buffer and episode statistics."""
|
||||
self.last_index = np.array([0])
|
||||
self._index = self._size = 0
|
||||
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
|
||||
if not keep_statistics:
|
||||
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
|
||||
|
||||
def set_batch(self, batch: Batch) -> None:
|
||||
"""Manually choose the batch you want the ReplayBuffer to manage."""
|
||||
|
@ -48,11 +48,11 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
def __len__(self) -> int:
|
||||
return self._lengths.sum()
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self, keep_statistics: bool = False) -> None:
|
||||
self.last_index = self._offset.copy()
|
||||
self._lengths = np.zeros_like(self._offset)
|
||||
for buf in self.buffers:
|
||||
buf.reset()
|
||||
buf.reset(keep_statistics=keep_statistics)
|
||||
|
||||
def _set_batch_for_children(self) -> None:
|
||||
for offset, buf in zip(self._offset, self.buffers):
|
||||
|
@ -105,9 +105,9 @@ class Collector(object):
|
||||
"""Reset the statistic variables."""
|
||||
self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
|
||||
|
||||
def reset_buffer(self) -> None:
|
||||
def reset_buffer(self, keep_statistics: bool = False) -> None:
|
||||
"""Reset the data buffer."""
|
||||
self.buffer.reset()
|
||||
self.buffer.reset(keep_statistics=keep_statistics)
|
||||
|
||||
def reset_env(self) -> None:
|
||||
"""Reset all of the environments."""
|
||||
|
@ -134,7 +134,7 @@ def onpolicy_trainer(
|
||||
losses = policy.update(
|
||||
0, train_collector.buffer,
|
||||
batch_size=batch_size, repeat=repeat_per_collect)
|
||||
train_collector.reset_buffer()
|
||||
train_collector.reset_buffer(keep_statistics=True)
|
||||
step = max([1] + [
|
||||
len(v) for v in losses.values() if isinstance(v, list)])
|
||||
gradient_step += step
|
||||
|
Loading…
x
Reference in New Issue
Block a user