Vary the env seed between parallel envs

This commit is contained in:
NM512 2024-01-05 10:44:20 +09:00
parent 78e86703f4
commit 4fe9b29ebe

View File

@ -158,13 +158,13 @@ def make_dataset(episodes, config):
return dataset
def make_env(config, mode):
def make_env(config, mode, id):
suite, task = config.task.split("_", 1)
if suite == "dmc":
import envs.dmc as dmc
env = dmc.DeepMindControl(
task, config.action_repeat, config.size, seed=config.seed
task, config.action_repeat, config.size, seed=config.seed + id
)
env = wrappers.NormalizeActions(env)
elif suite == "atari":
@ -180,7 +180,7 @@ def make_env(config, mode):
sticky=config.stickey,
actions=config.actions,
resize=config.resize,
seed=config.seed,
seed=config.seed + id,
)
env = wrappers.OneHotAction(env)
elif suite == "dmlab":
@ -190,18 +190,18 @@ def make_env(config, mode):
task,
mode if "train" in mode else "test",
config.action_repeat,
seed=config.seed,
seed=config.seed + id,
)
env = wrappers.OneHotAction(env)
elif suite == "memorymaze":
from envs.memorymaze import MemoryMaze
env = MemoryMaze(task, seed=config.seed)
env = MemoryMaze(task, seed=config.seed + id)
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size, seed=config.seed)
env = crafter.Crafter(task, config.size, seed=config.seed + id)
env = wrappers.OneHotAction(env)
elif suite == "minecraft":
import envs.minecraft as minecraft
@ -249,9 +249,9 @@ def main(config):
else:
directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)
make = lambda mode: make_env(config, mode)
train_envs = [make("train") for _ in range(config.envs)]
eval_envs = [make("eval") for _ in range(config.envs)]
make = lambda mode, id: make_env(config, mode, id)
train_envs = [make("train", i) for i in range(config.envs)]
eval_envs = [make("eval", i) for i in range(config.envs)]
if config.parallel:
train_envs = [Parallel(env, "process") for env in train_envs]
eval_envs = [Parallel(env, "process") for env in eval_envs]
@ -259,6 +259,7 @@ def main(config):
train_envs = [Damy(env) for env in train_envs]
eval_envs = [Damy(env) for env in eval_envs]
acts = train_envs[0].action_space
print("Action Space", acts)
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
state = None