modified training step display
This commit is contained in:
parent
f3fe3a872e
commit
34a44916f7
19
dreamer.py
19
dreamer.py
@ -213,10 +213,12 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "MemoryMaze":
|
elif suite == "MemoryMaze":
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
|
||||||
env = MemoryMaze(task)
|
env = MemoryMaze(task)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
import envs.crafter as crafter
|
import envs.crafter as crafter
|
||||||
|
|
||||||
env = crafter.Crafter(task, config.size)
|
env = crafter.Crafter(task, config.size)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
else:
|
else:
|
||||||
@ -254,17 +256,19 @@ class ProcessEpisodeWrap:
|
|||||||
length = len(episode["reward"]) - 1
|
length = len(episode["reward"]) - 1
|
||||||
score = float(episode["reward"].astype(np.float64).sum())
|
score = float(episode["reward"].astype(np.float64).sum())
|
||||||
video = episode["image"]
|
video = episode["image"]
|
||||||
|
# add new episode
|
||||||
cache[str(filename)] = episode
|
cache[str(filename)] = episode
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
total = 0
|
step_in_dataset = 0
|
||||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||||
if not config.dataset_size or total <= config.dataset_size - length:
|
if (
|
||||||
total += len(ep["reward"]) - 1
|
not config.dataset_size
|
||||||
|
or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
|
||||||
|
):
|
||||||
|
step_in_dataset += len(ep["reward"]) - 1
|
||||||
else:
|
else:
|
||||||
del cache[key]
|
del cache[key]
|
||||||
logger.scalar("dataset_size", total)
|
logger.scalar("dataset_size", step_in_dataset)
|
||||||
# use dataset_size as log step for a condition of envs > 1
|
|
||||||
log_step = total * config.action_repeat
|
|
||||||
elif mode == "eval":
|
elif mode == "eval":
|
||||||
# keep only last item for saving memory
|
# keep only last item for saving memory
|
||||||
while len(cache) > 1:
|
while len(cache) > 1:
|
||||||
@ -285,7 +289,6 @@ class ProcessEpisodeWrap:
|
|||||||
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
||||||
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
||||||
episode_num = len(cls.eval_scores)
|
episode_num = len(cls.eval_scores)
|
||||||
log_step = logger.step
|
|
||||||
logger.video(f"{mode}_policy", video[None])
|
logger.video(f"{mode}_policy", video[None])
|
||||||
cls.eval_done = True
|
cls.eval_done = True
|
||||||
|
|
||||||
@ -295,7 +298,7 @@ class ProcessEpisodeWrap:
|
|||||||
logger.scalar(
|
logger.scalar(
|
||||||
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
||||||
)
|
)
|
||||||
logger.write(step=log_step)
|
logger.write(step=logger.step)
|
||||||
|
|
||||||
|
|
||||||
def main(config):
|
def main(config):
|
||||||
|
@ -32,13 +32,16 @@ class MemoryMaze:
|
|||||||
spaces = self._env.observation_space.spaces.copy()
|
spaces = self._env.observation_space.spaces.copy()
|
||||||
else:
|
else:
|
||||||
spaces = {self._obs_key: self._env.observation_space}
|
spaces = {self._obs_key: self._env.observation_space}
|
||||||
return gym.spaces.Dict({
|
return gym.spaces.Dict(
|
||||||
|
{
|
||||||
**spaces,
|
**spaces,
|
||||||
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
|
||||||
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
space = self._env.action_space
|
space = self._env.action_space
|
||||||
@ -49,7 +52,7 @@ class MemoryMaze:
|
|||||||
obs, reward, done, info = self._env.step(action)
|
obs, reward, done, info = self._env.step(action)
|
||||||
if not self._obs_is_dict:
|
if not self._obs_is_dict:
|
||||||
obs = {self._obs_key: obs}
|
obs = {self._obs_key: obs}
|
||||||
obs['reward'] = reward
|
obs["reward"] = reward
|
||||||
obs["is_first"] = False
|
obs["is_first"] = False
|
||||||
obs["is_last"] = done
|
obs["is_last"] = done
|
||||||
obs["is_terminal"] = info.get("is_terminal", False)
|
obs["is_terminal"] = info.get("is_terminal", False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user