From 4fe9b29ebe9fdc6150977fe04bb356b36b5ea0c0 Mon Sep 17 00:00:00 2001 From: NM512 Date: Fri, 5 Jan 2024 10:44:20 +0900 Subject: [PATCH] env seed vary between envs of parallel --- dreamer.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/dreamer.py b/dreamer.py index 3c70d88..4584e36 100644 --- a/dreamer.py +++ b/dreamer.py @@ -158,13 +158,13 @@ def make_dataset(episodes, config): return dataset -def make_env(config, mode): +def make_env(config, mode, id): suite, task = config.task.split("_", 1) if suite == "dmc": import envs.dmc as dmc env = dmc.DeepMindControl( - task, config.action_repeat, config.size, seed=config.seed + task, config.action_repeat, config.size, seed=config.seed + id ) env = wrappers.NormalizeActions(env) elif suite == "atari": @@ -180,7 +180,7 @@ def make_env(config, mode): sticky=config.stickey, actions=config.actions, resize=config.resize, - seed=config.seed, + seed=config.seed + id, ) env = wrappers.OneHotAction(env) elif suite == "dmlab": @@ -190,18 +190,18 @@ def make_env(config, mode): task, mode if "train" in mode else "test", config.action_repeat, - seed=config.seed, + seed=config.seed + id, ) env = wrappers.OneHotAction(env) elif suite == "memorymaze": from envs.memorymaze import MemoryMaze - env = MemoryMaze(task, seed=config.seed) + env = MemoryMaze(task, seed=config.seed + id) env = wrappers.OneHotAction(env) elif suite == "crafter": import envs.crafter as crafter - env = crafter.Crafter(task, config.size, seed=config.seed) + env = crafter.Crafter(task, config.size, seed=config.seed + id) env = wrappers.OneHotAction(env) elif suite == "minecraft": import envs.minecraft as minecraft @@ -249,9 +249,9 @@ def main(config): else: directory = config.evaldir eval_eps = tools.load_episodes(directory, limit=1) - make = lambda mode: make_env(config, mode) - train_envs = [make("train") for _ in range(config.envs)] - eval_envs = [make("eval") for _ in range(config.envs)] + make = lambda mode, id: make_env(config, mode, id) + train_envs = [make("train", i) for i in range(config.envs)] + eval_envs = [make("eval", i) for i in range(config.envs)] if config.parallel: train_envs = [Parallel(env, "process") for env in train_envs] eval_envs = [Parallel(env, "process") for env in eval_envs] @@ -259,6 +259,7 @@ def main(config): train_envs = [Damy(env) for env in train_envs] eval_envs = [Damy(env) for env in eval_envs] acts = train_envs[0].action_space + print("Action Space", acts) config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] state = None