env seed vary between envs of parallel
This commit is contained in:
parent
78e86703f4
commit
4fe9b29ebe
19
dreamer.py
19
dreamer.py
@ -158,13 +158,13 @@ def make_dataset(episodes, config):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def make_env(config, mode):
|
def make_env(config, mode, id):
|
||||||
suite, task = config.task.split("_", 1)
|
suite, task = config.task.split("_", 1)
|
||||||
if suite == "dmc":
|
if suite == "dmc":
|
||||||
import envs.dmc as dmc
|
import envs.dmc as dmc
|
||||||
|
|
||||||
env = dmc.DeepMindControl(
|
env = dmc.DeepMindControl(
|
||||||
task, config.action_repeat, config.size, seed=config.seed
|
task, config.action_repeat, config.size, seed=config.seed + id
|
||||||
)
|
)
|
||||||
env = wrappers.NormalizeActions(env)
|
env = wrappers.NormalizeActions(env)
|
||||||
elif suite == "atari":
|
elif suite == "atari":
|
||||||
@ -180,7 +180,7 @@ def make_env(config, mode):
|
|||||||
sticky=config.stickey,
|
sticky=config.stickey,
|
||||||
actions=config.actions,
|
actions=config.actions,
|
||||||
resize=config.resize,
|
resize=config.resize,
|
||||||
seed=config.seed,
|
seed=config.seed + id,
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "dmlab":
|
elif suite == "dmlab":
|
||||||
@ -190,18 +190,18 @@ def make_env(config, mode):
|
|||||||
task,
|
task,
|
||||||
mode if "train" in mode else "test",
|
mode if "train" in mode else "test",
|
||||||
config.action_repeat,
|
config.action_repeat,
|
||||||
seed=config.seed,
|
seed=config.seed + id,
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "memorymaze":
|
elif suite == "memorymaze":
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
|
||||||
env = MemoryMaze(task, seed=config.seed)
|
env = MemoryMaze(task, seed=config.seed + id)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
import envs.crafter as crafter
|
import envs.crafter as crafter
|
||||||
|
|
||||||
env = crafter.Crafter(task, config.size, seed=config.seed)
|
env = crafter.Crafter(task, config.size, seed=config.seed + id)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "minecraft":
|
elif suite == "minecraft":
|
||||||
import envs.minecraft as minecraft
|
import envs.minecraft as minecraft
|
||||||
@ -249,9 +249,9 @@ def main(config):
|
|||||||
else:
|
else:
|
||||||
directory = config.evaldir
|
directory = config.evaldir
|
||||||
eval_eps = tools.load_episodes(directory, limit=1)
|
eval_eps = tools.load_episodes(directory, limit=1)
|
||||||
make = lambda mode: make_env(config, mode)
|
make = lambda mode, id: make_env(config, mode, id)
|
||||||
train_envs = [make("train") for _ in range(config.envs)]
|
train_envs = [make("train", i) for i in range(config.envs)]
|
||||||
eval_envs = [make("eval") for _ in range(config.envs)]
|
eval_envs = [make("eval", i) for i in range(config.envs)]
|
||||||
if config.parallel:
|
if config.parallel:
|
||||||
train_envs = [Parallel(env, "process") for env in train_envs]
|
train_envs = [Parallel(env, "process") for env in train_envs]
|
||||||
eval_envs = [Parallel(env, "process") for env in eval_envs]
|
eval_envs = [Parallel(env, "process") for env in eval_envs]
|
||||||
@ -259,6 +259,7 @@ def main(config):
|
|||||||
train_envs = [Damy(env) for env in train_envs]
|
train_envs = [Damy(env) for env in train_envs]
|
||||||
eval_envs = [Damy(env) for env in eval_envs]
|
eval_envs = [Damy(env) for env in eval_envs]
|
||||||
acts = train_envs[0].action_space
|
acts = train_envs[0].action_space
|
||||||
|
print("Action Space", acts)
|
||||||
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user