added the option for a deterministic run

This commit is contained in:
NM512 2023-08-16 21:46:06 +09:00
parent 68096d1f62
commit 606ec8af8c
7 changed files with 47 additions and 16 deletions

View File

@ -6,6 +6,7 @@ defaults:
offline_traindir: '' offline_traindir: ''
offline_evaldir: '' offline_evaldir: ''
seed: 0 seed: 0
deterministic_run: False
steps: 1e6 steps: 1e6
parallel: False parallel: False
eval_every: 1e4 eval_every: 1e4

View File

@ -186,7 +186,9 @@ def make_env(config, mode):
if suite == "dmc": if suite == "dmc":
import envs.dmc as dmc import envs.dmc as dmc
env = dmc.DeepMindControl(task, config.action_repeat, config.size) env = dmc.DeepMindControl(
task, config.action_repeat, config.size, seed=config.seed
)
env = wrappers.NormalizeActions(env) env = wrappers.NormalizeActions(env)
elif suite == "atari": elif suite == "atari":
import envs.atari as atari import envs.atari as atari
@ -201,24 +203,28 @@ def make_env(config, mode):
sticky=config.stickey, sticky=config.stickey,
actions=config.actions, actions=config.actions,
resize=config.resize, resize=config.resize,
seed=config.seed,
) )
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "dmlab": elif suite == "dmlab":
import envs.dmlab as dmlab import envs.dmlab as dmlab
env = dmlab.DeepMindLabyrinth( env = dmlab.DeepMindLabyrinth(
task, mode if "train" in mode else "test", config.action_repeat task,
mode if "train" in mode else "test",
config.action_repeat,
seed=config.seed,
) )
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "MemoryMaze": elif suite == "MemoryMaze":
from envs.memorymaze import MemoryMaze from envs.memorymaze import MemoryMaze
env = MemoryMaze(task) env = MemoryMaze(task, seed=config.seed)
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "crafter": elif suite == "crafter":
import envs.crafter as crafter import envs.crafter as crafter
env = crafter.Crafter(task, config.size) env = crafter.Crafter(task, config.size, seed=config.seed)
env = wrappers.OneHotAction(env) env = wrappers.OneHotAction(env)
elif suite == "minecraft": elif suite == "minecraft":
import envs.minecraft as minecraft import envs.minecraft as minecraft
@ -236,6 +242,9 @@ def make_env(config, mode):
def main(config): def main(config):
tools.set_seed_everywhere(config.seed)
if config.deterministic_run:
tools.enable_deterministic_run()
logdir = pathlib.Path(config.logdir).expanduser() logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / "train_eps" config.traindir = config.traindir or logdir / "train_eps"
config.evaldir = config.evaldir or logdir / "eval_eps" config.evaldir = config.evaldir or logdir / "eval_eps"

View File

@ -5,7 +5,7 @@ import numpy as np
class Crafter: class Crafter:
metadata = {} metadata = {}
def __init__(self, task, size=(64, 64), seed=None): def __init__(self, task, size=(64, 64), seed=0):
assert task in ("reward", "noreward") assert task in ("reward", "noreward")
import crafter import crafter

View File

@ -5,14 +5,18 @@ import numpy as np
class DeepMindControl: class DeepMindControl:
metadata = {} metadata = {}
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
domain, task = name.split("_", 1) domain, task = name.split("_", 1)
if domain == "cup": # Only domain with multiple words. if domain == "cup": # Only domain with multiple words.
domain = "ball_in_cup" domain = "ball_in_cup"
if isinstance(domain, str): if isinstance(domain, str):
from dm_control import suite from dm_control import suite
self._env = suite.load(domain, task) self._env = suite.load(
domain,
task,
task_kwargs={"random": seed},
)
else: else:
assert task is None assert task is None
self._env = domain() self._env = domain()

View File

@ -5,11 +5,11 @@ import numpy as np
class MemoryMaze: class MemoryMaze:
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)): def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
if task == "9x9": if task == "9x9":
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0") self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
elif task == "15x15": elif task == "15x15":
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0") self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
else: else:
raise NotImplementedError(task) raise NotImplementedError(task)
self._obs_is_dict = hasattr(self._env.observation_space, "spaces") self._obs_is_dict = hasattr(self._env.observation_space, "spaces")

View File

@ -11,7 +11,7 @@ protobuf==3.20.0
gym==0.19.0 gym==0.19.0
dm_control==1.0.9 dm_control==1.0.9
scipy==1.8.0 scipy==1.8.0
memory_maze==1.0.2 memory_maze==1.0.3
atari-py==0.2.9 atari-py==0.2.9
crafter==1.8.0 crafter==1.8.0
opencv-python==4.7.0.72 opencv-python==4.7.0.72

View File

@ -6,7 +6,7 @@ import json
import pathlib import pathlib
import re import re
import time import time
import uuid import random
import numpy as np import numpy as np
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
def sample_episodes(episodes, length, seed=0): def sample_episodes(episodes, length, seed=0):
random = np.random.RandomState(seed) np_random = np.random.RandomState(seed)
while True: while True:
size = 0 size = 0
ret = None ret = None
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
) )
p = p / np.sum(p) p = p / np.sum(p)
while size < length: while size < length:
episode = random.choice(list(episodes.values()), p=p) episode = np_random.choice(list(episodes.values()), p=p)
total = len(next(iter(episode.values()))) total = len(next(iter(episode.values())))
# make sure at least one transition included # make sure at least one transition included
if total < 2: if total < 2:
continue continue
if not ret: if not ret:
index = int(random.randint(0, total - 1)) index = int(np_random.randint(0, total - 1))
ret = { ret = {
k: v[index : min(index + length, total)] for k, v in episode.items() k: v[index : min(index + length, total)]
for k, v in episode.items()
if "log_" not in k
} }
if "is_first" in ret: if "is_first" in ret:
ret["is_first"][0] = True ret["is_first"][0] = True
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
ret[k], v[index : min(index + possible, total)], axis=0 ret[k], v[index : min(index + possible, total)], axis=0
) )
for k, v in episode.items() for k, v in episode.items()
if "log_" not in k
} }
if "is_first" in ret: if "is_first" in ret:
ret["is_first"][size] = True ret["is_first"][size] = True
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
if prefix: if prefix:
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()} metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
return metrics return metrics
def set_seed_everywhere(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

    When CUDA is available, the generators of all CUDA devices are seeded
    as well.
    """
    # The individual generators are independent, so the order of the
    # seeding calls does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def enable_deterministic_run():
    """Switch PyTorch into its deterministic execution mode.

    Sets the `CUBLAS_WORKSPACE_CONFIG` environment variable, disables the
    cuDNN benchmark mode and asks PyTorch to use deterministic algorithms.
    All three changes are process-wide.
    """
    torch.backends.cudnn.benchmark = False
    os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
    torch.use_deterministic_algorithms(True)