From 432a359bcfc0f912344288df750f308bbb0edd24 Mon Sep 17 00:00:00 2001 From: NM512 Date: Mon, 24 Apr 2023 06:25:17 +0900 Subject: [PATCH] put running episode into replay buffer --- dreamer.py | 2 +- envs/wrappers.py | 26 +++++++++++++++++++++++++- tools.py | 2 +- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/dreamer.py b/dreamer.py index 5681db1..a84105a 100644 --- a/dreamer.py +++ b/dreamer.py @@ -225,7 +225,7 @@ def make_env(config, logger, mode, train_eps, eval_eps): eval_eps, ) ] - env = wrappers.CollectDataset(env, callbacks) + env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks) env = wrappers.RewardObs(env) return env diff --git a/envs/wrappers.py b/envs/wrappers.py index 177f2d9..a70611d 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -1,13 +1,18 @@ import gym import numpy as np +import uuid class CollectDataset: - def __init__(self, env, callbacks=None, precision=32): + def __init__( + self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32 + ): self._env = env self._callbacks = callbacks or () self._precision = precision self._episode = None + self._cache = dict(train=train_eps, eval=eval_eps)[mode] + self._temp_name = str(uuid.uuid4()) def __getattr__(self, name): return getattr(self._env, name) @@ -23,7 +28,11 @@ class CollectDataset: transition["reward"] = reward transition["discount"] = info.get("discount", np.array(1 - float(done))) self._episode.append(transition) + self.add_to_cache(transition) if done: + # detele transitions before whole episode is stored + del self._cache[self._temp_name] + self._temp_name = str(uuid.uuid4()) for key, value in self._episode[1].items(): if key not in self._episode[0]: self._episode[0][key] = 0 * value @@ -43,8 +52,23 @@ class CollectDataset: transition["reward"] = 0.0 transition["discount"] = 1.0 self._episode = [transition] + self.add_to_cache(transition) return obs + def add_to_cache(self, transition): + if self._temp_name not in self._cache: + self._cache[self._temp_name] = dict() + for key, val in transition.items(): + self._cache[self._temp_name][key] = [self._convert(val)] + else: + for key, val in transition.items(): + if key not in self._cache[self._temp_name]: + # fill missing data(action) + self._cache[self._temp_name][key] = [self._convert(0 * val)] + self._cache[self._temp_name][key].append(self._convert(val)) + else: + self._cache[self._temp_name][key].append(self._convert(val)) + def _convert(self, value): value = np.array(value) if np.issubdtype(value.dtype, np.floating): diff --git a/tools.py b/tools.py index bd73beb..de7ccc4 100644 --- a/tools.py +++ b/tools.py @@ -207,7 +207,7 @@ def sample_episodes(episodes, length=None, balance=False, seed=0): total = len(next(iter(episode.values()))) available = total - length if available < 1: - print(f"Skipped short episode of length {available}.") + # print(f"Skipped short episode of length {available}.") continue if balance: index = min(random.randint(0, total), available)