From 606ec8af8c096cb49658df70a85a0ca66ea6c3c8 Mon Sep 17 00:00:00 2001 From: NM512 Date: Wed, 16 Aug 2023 21:46:06 +0900 Subject: [PATCH] added the option for a deterministic run --- configs.yaml | 1 + dreamer.py | 17 +++++++++++++---- envs/crafter.py | 2 +- envs/dmc.py | 8 ++++++-- envs/memorymaze.py | 6 +++--- requirements.txt | 2 +- tools.py | 27 ++++++++++++++++++++++----- 7 files changed, 47 insertions(+), 16 deletions(-) diff --git a/configs.yaml b/configs.yaml index d023001..a9bd5bd 100644 --- a/configs.yaml +++ b/configs.yaml @@ -6,6 +6,7 @@ defaults: offline_traindir: '' offline_evaldir: '' seed: 0 + deterministic_run: False steps: 1e6 parallel: False eval_every: 1e4 diff --git a/dreamer.py b/dreamer.py index a02dec3..1b82f60 100644 --- a/dreamer.py +++ b/dreamer.py @@ -186,7 +186,9 @@ def make_env(config, mode): if suite == "dmc": import envs.dmc as dmc - env = dmc.DeepMindControl(task, config.action_repeat, config.size) + env = dmc.DeepMindControl( + task, config.action_repeat, config.size, seed=config.seed + ) env = wrappers.NormalizeActions(env) elif suite == "atari": import envs.atari as atari @@ -201,24 +203,28 @@ def make_env(config, mode): sticky=config.stickey, actions=config.actions, resize=config.resize, + seed=config.seed, ) env = wrappers.OneHotAction(env) elif suite == "dmlab": import envs.dmlab as dmlab env = dmlab.DeepMindLabyrinth( - task, mode if "train" in mode else "test", config.action_repeat + task, + mode if "train" in mode else "test", + config.action_repeat, + seed=config.seed, ) env = wrappers.OneHotAction(env) elif suite == "MemoryMaze": from envs.memorymaze import MemoryMaze - env = MemoryMaze(task) + env = MemoryMaze(task, seed=config.seed) env = wrappers.OneHotAction(env) elif suite == "crafter": import envs.crafter as crafter - env = crafter.Crafter(task, config.size) + env = crafter.Crafter(task, config.size, seed=config.seed) env = wrappers.OneHotAction(env) elif suite == "minecraft": import envs.minecraft as minecraft @@ -236,6 +242,9 @@ def make_env(config, mode): def main(config): + tools.set_seed_everywhere(config.seed) + if config.deterministic_run: + tools.enable_deterministic_run() logdir = pathlib.Path(config.logdir).expanduser() config.traindir = config.traindir or logdir / "train_eps" config.evaldir = config.evaldir or logdir / "eval_eps" diff --git a/envs/crafter.py b/envs/crafter.py index 5a67494..5d72483 100644 --- a/envs/crafter.py +++ b/envs/crafter.py @@ -5,7 +5,7 @@ import numpy as np class Crafter: metadata = {} - def __init__(self, task, size=(64, 64), seed=None): + def __init__(self, task, size=(64, 64), seed=0): assert task in ("reward", "noreward") import crafter diff --git a/envs/dmc.py b/envs/dmc.py index 1907410..874d1ad 100644 --- a/envs/dmc.py +++ b/envs/dmc.py @@ -5,14 +5,18 @@ import numpy as np class DeepMindControl: metadata = {} - def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0): domain, task = name.split("_", 1) if domain == "cup": # Only domain with multiple words. domain = "ball_in_cup" if isinstance(domain, str): from dm_control import suite - self._env = suite.load(domain, task) + self._env = suite.load( + domain, + task, + task_kwargs={"random": seed}, + ) else: assert task is None self._env = domain() diff --git a/envs/memorymaze.py b/envs/memorymaze.py index 93603a6..19ca980 100644 --- a/envs/memorymaze.py +++ b/envs/memorymaze.py @@ -5,11 +5,11 @@ import numpy as np class MemoryMaze: - def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)): + def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0): if task == "9x9": - self._env = gym.make("memory_maze:MemoryMaze-9x9-v0") + self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed) elif task == "15x15": - self._env = gym.make("memory_maze:MemoryMaze-15x15-v0") + self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed) else: raise NotImplementedError(task) self._obs_is_dict = hasattr(self._env.observation_space, "spaces") diff --git a/requirements.txt b/requirements.txt index fc7a2a3..930f0f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ protobuf==3.20.0 gym==0.19.0 dm_control==1.0.9 scipy==1.8.0 -memory_maze==1.0.2 +memory_maze==1.0.3 atari-py==0.2.9 crafter==1.8.0 opencv-python==4.7.0.72 diff --git a/tools.py b/tools.py index d5da141..69b9edd 100644 --- a/tools.py +++ b/tools.py @@ -6,7 +6,7 @@ import json import pathlib import re import time -import uuid +import random import numpy as np @@ -321,7 +321,7 @@ def from_generator(generator, batch_size): def sample_episodes(episodes, length, seed=0): - random = np.random.RandomState(seed) + np_random = np.random.RandomState(seed) while True: size = 0 ret = None @@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0): ) p = p / np.sum(p) while size < length: - episode = random.choice(list(episodes.values()), p=p) + episode = np_random.choice(list(episodes.values()), p=p) total = len(next(iter(episode.values()))) # make sure at least one transition included if total < 2: continue if not ret: - index = int(random.randint(0, total - 1)) + index = int(np_random.randint(0, total - 1)) ret = { - k: v[index : min(index + length, total)] for k, v in episode.items() + k: v[index : min(index + length, total)] + for k, v in episode.items() + if "log_" not in k } if "is_first" in ret: ret["is_first"][0] = True @@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0): ret[k], v[index : min(index + possible, total)], axis=0 ) for k, v in episode.items() + if "log_" not in k } if "is_first" in ret: ret["is_first"][size] = True @@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None): if prefix: metrics = {f"{prefix}_{k}": v for k, v in metrics.items()} return metrics + + +def set_seed_everywhere(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def enable_deterministic_run(): + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True)