added the option for a deterministic run
This commit is contained in:
parent
68096d1f62
commit
606ec8af8c
@ -6,6 +6,7 @@ defaults:
|
|||||||
offline_traindir: ''
|
offline_traindir: ''
|
||||||
offline_evaldir: ''
|
offline_evaldir: ''
|
||||||
seed: 0
|
seed: 0
|
||||||
|
deterministic_run: False
|
||||||
steps: 1e6
|
steps: 1e6
|
||||||
parallel: False
|
parallel: False
|
||||||
eval_every: 1e4
|
eval_every: 1e4
|
||||||
|
17
dreamer.py
17
dreamer.py
@ -186,7 +186,9 @@ def make_env(config, mode):
|
|||||||
if suite == "dmc":
|
if suite == "dmc":
|
||||||
import envs.dmc as dmc
|
import envs.dmc as dmc
|
||||||
|
|
||||||
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
|
env = dmc.DeepMindControl(
|
||||||
|
task, config.action_repeat, config.size, seed=config.seed
|
||||||
|
)
|
||||||
env = wrappers.NormalizeActions(env)
|
env = wrappers.NormalizeActions(env)
|
||||||
elif suite == "atari":
|
elif suite == "atari":
|
||||||
import envs.atari as atari
|
import envs.atari as atari
|
||||||
@ -201,24 +203,28 @@ def make_env(config, mode):
|
|||||||
sticky=config.stickey,
|
sticky=config.stickey,
|
||||||
actions=config.actions,
|
actions=config.actions,
|
||||||
resize=config.resize,
|
resize=config.resize,
|
||||||
|
seed=config.seed,
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "dmlab":
|
elif suite == "dmlab":
|
||||||
import envs.dmlab as dmlab
|
import envs.dmlab as dmlab
|
||||||
|
|
||||||
env = dmlab.DeepMindLabyrinth(
|
env = dmlab.DeepMindLabyrinth(
|
||||||
task, mode if "train" in mode else "test", config.action_repeat
|
task,
|
||||||
|
mode if "train" in mode else "test",
|
||||||
|
config.action_repeat,
|
||||||
|
seed=config.seed,
|
||||||
)
|
)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "MemoryMaze":
|
elif suite == "MemoryMaze":
|
||||||
from envs.memorymaze import MemoryMaze
|
from envs.memorymaze import MemoryMaze
|
||||||
|
|
||||||
env = MemoryMaze(task)
|
env = MemoryMaze(task, seed=config.seed)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "crafter":
|
elif suite == "crafter":
|
||||||
import envs.crafter as crafter
|
import envs.crafter as crafter
|
||||||
|
|
||||||
env = crafter.Crafter(task, config.size)
|
env = crafter.Crafter(task, config.size, seed=config.seed)
|
||||||
env = wrappers.OneHotAction(env)
|
env = wrappers.OneHotAction(env)
|
||||||
elif suite == "minecraft":
|
elif suite == "minecraft":
|
||||||
import envs.minecraft as minecraft
|
import envs.minecraft as minecraft
|
||||||
@ -236,6 +242,9 @@ def make_env(config, mode):
|
|||||||
|
|
||||||
|
|
||||||
def main(config):
|
def main(config):
|
||||||
|
tools.set_seed_everywhere(config.seed)
|
||||||
|
if config.deterministic_run:
|
||||||
|
tools.enable_deterministic_run()
|
||||||
logdir = pathlib.Path(config.logdir).expanduser()
|
logdir = pathlib.Path(config.logdir).expanduser()
|
||||||
config.traindir = config.traindir or logdir / "train_eps"
|
config.traindir = config.traindir or logdir / "train_eps"
|
||||||
config.evaldir = config.evaldir or logdir / "eval_eps"
|
config.evaldir = config.evaldir or logdir / "eval_eps"
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
class Crafter:
|
class Crafter:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(self, task, size=(64, 64), seed=None):
|
def __init__(self, task, size=(64, 64), seed=0):
|
||||||
assert task in ("reward", "noreward")
|
assert task in ("reward", "noreward")
|
||||||
import crafter
|
import crafter
|
||||||
|
|
||||||
|
@ -5,14 +5,18 @@ import numpy as np
|
|||||||
class DeepMindControl:
|
class DeepMindControl:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
|
||||||
domain, task = name.split("_", 1)
|
domain, task = name.split("_", 1)
|
||||||
if domain == "cup": # Only domain with multiple words.
|
if domain == "cup": # Only domain with multiple words.
|
||||||
domain = "ball_in_cup"
|
domain = "ball_in_cup"
|
||||||
if isinstance(domain, str):
|
if isinstance(domain, str):
|
||||||
from dm_control import suite
|
from dm_control import suite
|
||||||
|
|
||||||
self._env = suite.load(domain, task)
|
self._env = suite.load(
|
||||||
|
domain,
|
||||||
|
task,
|
||||||
|
task_kwargs={"random": seed},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert task is None
|
assert task is None
|
||||||
self._env = domain()
|
self._env = domain()
|
||||||
|
@ -5,11 +5,11 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMaze:
|
class MemoryMaze:
|
||||||
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
|
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
|
||||||
if task == "9x9":
|
if task == "9x9":
|
||||||
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
|
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
|
||||||
elif task == "15x15":
|
elif task == "15x15":
|
||||||
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
|
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(task)
|
raise NotImplementedError(task)
|
||||||
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
||||||
|
@ -11,7 +11,7 @@ protobuf==3.20.0
|
|||||||
gym==0.19.0
|
gym==0.19.0
|
||||||
dm_control==1.0.9
|
dm_control==1.0.9
|
||||||
scipy==1.8.0
|
scipy==1.8.0
|
||||||
memory_maze==1.0.2
|
memory_maze==1.0.3
|
||||||
atari-py==0.2.9
|
atari-py==0.2.9
|
||||||
crafter==1.8.0
|
crafter==1.8.0
|
||||||
opencv-python==4.7.0.72
|
opencv-python==4.7.0.72
|
||||||
|
27
tools.py
27
tools.py
@ -6,7 +6,7 @@ import json
|
|||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
|
|||||||
|
|
||||||
|
|
||||||
def sample_episodes(episodes, length, seed=0):
|
def sample_episodes(episodes, length, seed=0):
|
||||||
random = np.random.RandomState(seed)
|
np_random = np.random.RandomState(seed)
|
||||||
while True:
|
while True:
|
||||||
size = 0
|
size = 0
|
||||||
ret = None
|
ret = None
|
||||||
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
|
|||||||
)
|
)
|
||||||
p = p / np.sum(p)
|
p = p / np.sum(p)
|
||||||
while size < length:
|
while size < length:
|
||||||
episode = random.choice(list(episodes.values()), p=p)
|
episode = np_random.choice(list(episodes.values()), p=p)
|
||||||
total = len(next(iter(episode.values())))
|
total = len(next(iter(episode.values())))
|
||||||
# make sure at least one transition included
|
# make sure at least one transition included
|
||||||
if total < 2:
|
if total < 2:
|
||||||
continue
|
continue
|
||||||
if not ret:
|
if not ret:
|
||||||
index = int(random.randint(0, total - 1))
|
index = int(np_random.randint(0, total - 1))
|
||||||
ret = {
|
ret = {
|
||||||
k: v[index : min(index + length, total)] for k, v in episode.items()
|
k: v[index : min(index + length, total)]
|
||||||
|
for k, v in episode.items()
|
||||||
|
if "log_" not in k
|
||||||
}
|
}
|
||||||
if "is_first" in ret:
|
if "is_first" in ret:
|
||||||
ret["is_first"][0] = True
|
ret["is_first"][0] = True
|
||||||
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
|
|||||||
ret[k], v[index : min(index + possible, total)], axis=0
|
ret[k], v[index : min(index + possible, total)], axis=0
|
||||||
)
|
)
|
||||||
for k, v in episode.items()
|
for k, v in episode.items()
|
||||||
|
if "log_" not in k
|
||||||
}
|
}
|
||||||
if "is_first" in ret:
|
if "is_first" in ret:
|
||||||
ret["is_first"][size] = True
|
ret["is_first"][size] = True
|
||||||
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
|
|||||||
if prefix:
|
if prefix:
|
||||||
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
|
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed_everywhere(seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_deterministic_run():
|
||||||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user