hotfix: keep statistics of buffer when resetting buffer in on-policy trainer (#328)

This commit is contained in:
ChenDRAG 2021-03-27 16:58:48 +08:00 committed by GitHub
parent 8963a14327
commit 105b277b87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 7 deletions

View File

@ -98,11 +98,12 @@ class ReplayBuffer:
buf.__setstate__(from_hdf5(f, device=device))
return buf
def reset(self) -> None:
def reset(self, keep_statistics: bool = False) -> None:
"""Clear all the data in replay buffer and episode statistics."""
self.last_index = np.array([0])
self._index = self._size = 0
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
if not keep_statistics:
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
def set_batch(self, batch: Batch) -> None:
"""Manually choose the batch you want the ReplayBuffer to manage."""

View File

@ -48,11 +48,11 @@ class ReplayBufferManager(ReplayBuffer):
def __len__(self) -> int:
return self._lengths.sum()
def reset(self) -> None:
def reset(self, keep_statistics: bool = False) -> None:
self.last_index = self._offset.copy()
self._lengths = np.zeros_like(self._offset)
for buf in self.buffers:
buf.reset()
buf.reset(keep_statistics=keep_statistics)
def _set_batch_for_children(self) -> None:
for offset, buf in zip(self._offset, self.buffers):

View File

@ -105,9 +105,9 @@ class Collector(object):
"""Reset the statistic variables."""
self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
def reset_buffer(self) -> None:
def reset_buffer(self, keep_statistics: bool = False) -> None:
"""Reset the data buffer."""
self.buffer.reset()
self.buffer.reset(keep_statistics=keep_statistics)
def reset_env(self) -> None:
"""Reset all of the environments."""

View File

@ -134,7 +134,7 @@ def onpolicy_trainer(
losses = policy.update(
0, train_collector.buffer,
batch_size=batch_size, repeat=repeat_per_collect)
train_collector.reset_buffer()
train_collector.reset_buffer(keep_statistics=True)
step = max([1] + [
len(v) for v in losses.values() if isinstance(v, list)])
gradient_step += step