env seed vary between envs of parallel
This commit is contained in:
parent
78e86703f4
commit
4fe9b29ebe
19
dreamer.py
19
dreamer.py
@ -158,13 +158,13 @@ def make_dataset(episodes, config):
|
||||
return dataset
|
||||
|
||||
|
||||
def make_env(config, mode):
|
||||
def make_env(config, mode, id):
|
||||
suite, task = config.task.split("_", 1)
|
||||
if suite == "dmc":
|
||||
import envs.dmc as dmc
|
||||
|
||||
env = dmc.DeepMindControl(
|
||||
task, config.action_repeat, config.size, seed=config.seed
|
||||
task, config.action_repeat, config.size, seed=config.seed + id
|
||||
)
|
||||
env = wrappers.NormalizeActions(env)
|
||||
elif suite == "atari":
|
||||
@ -180,7 +180,7 @@ def make_env(config, mode):
|
||||
sticky=config.stickey,
|
||||
actions=config.actions,
|
||||
resize=config.resize,
|
||||
seed=config.seed,
|
||||
seed=config.seed + id,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "dmlab":
|
||||
@ -190,18 +190,18 @@ def make_env(config, mode):
|
||||
task,
|
||||
mode if "train" in mode else "test",
|
||||
config.action_repeat,
|
||||
seed=config.seed,
|
||||
seed=config.seed + id,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "memorymaze":
|
||||
from envs.memorymaze import MemoryMaze
|
||||
|
||||
env = MemoryMaze(task, seed=config.seed)
|
||||
env = MemoryMaze(task, seed=config.seed + id)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "crafter":
|
||||
import envs.crafter as crafter
|
||||
|
||||
env = crafter.Crafter(task, config.size, seed=config.seed)
|
||||
env = crafter.Crafter(task, config.size, seed=config.seed + id)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "minecraft":
|
||||
import envs.minecraft as minecraft
|
||||
@ -249,9 +249,9 @@ def main(config):
|
||||
else:
|
||||
directory = config.evaldir
|
||||
eval_eps = tools.load_episodes(directory, limit=1)
|
||||
make = lambda mode: make_env(config, mode)
|
||||
train_envs = [make("train") for _ in range(config.envs)]
|
||||
eval_envs = [make("eval") for _ in range(config.envs)]
|
||||
make = lambda mode, id: make_env(config, mode, id)
|
||||
train_envs = [make("train", i) for i in range(config.envs)]
|
||||
eval_envs = [make("eval", i) for i in range(config.envs)]
|
||||
if config.parallel:
|
||||
train_envs = [Parallel(env, "process") for env in train_envs]
|
||||
eval_envs = [Parallel(env, "process") for env in eval_envs]
|
||||
@ -259,6 +259,7 @@ def main(config):
|
||||
train_envs = [Damy(env) for env in train_envs]
|
||||
eval_envs = [Damy(env) for env in eval_envs]
|
||||
acts = train_envs[0].action_space
|
||||
print("Action Space", acts)
|
||||
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
||||
|
||||
state = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user