Added the option for a deterministic run.
This commit is contained in:
parent
68096d1f62
commit
606ec8af8c
@ -6,6 +6,7 @@ defaults:
|
||||
offline_traindir: ''
|
||||
offline_evaldir: ''
|
||||
seed: 0
|
||||
deterministic_run: False
|
||||
steps: 1e6
|
||||
parallel: False
|
||||
eval_every: 1e4
|
||||
|
17
dreamer.py
17
dreamer.py
@ -186,7 +186,9 @@ def make_env(config, mode):
|
||||
if suite == "dmc":
|
||||
import envs.dmc as dmc
|
||||
|
||||
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
|
||||
env = dmc.DeepMindControl(
|
||||
task, config.action_repeat, config.size, seed=config.seed
|
||||
)
|
||||
env = wrappers.NormalizeActions(env)
|
||||
elif suite == "atari":
|
||||
import envs.atari as atari
|
||||
@ -201,24 +203,28 @@ def make_env(config, mode):
|
||||
sticky=config.stickey,
|
||||
actions=config.actions,
|
||||
resize=config.resize,
|
||||
seed=config.seed,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "dmlab":
|
||||
import envs.dmlab as dmlab
|
||||
|
||||
env = dmlab.DeepMindLabyrinth(
|
||||
task, mode if "train" in mode else "test", config.action_repeat
|
||||
task,
|
||||
mode if "train" in mode else "test",
|
||||
config.action_repeat,
|
||||
seed=config.seed,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "MemoryMaze":
|
||||
from envs.memorymaze import MemoryMaze
|
||||
|
||||
env = MemoryMaze(task)
|
||||
env = MemoryMaze(task, seed=config.seed)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "crafter":
|
||||
import envs.crafter as crafter
|
||||
|
||||
env = crafter.Crafter(task, config.size)
|
||||
env = crafter.Crafter(task, config.size, seed=config.seed)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "minecraft":
|
||||
import envs.minecraft as minecraft
|
||||
@ -236,6 +242,9 @@ def make_env(config, mode):
|
||||
|
||||
|
||||
def main(config):
|
||||
tools.set_seed_everywhere(config.seed)
|
||||
if config.deterministic_run:
|
||||
tools.enable_deterministic_run()
|
||||
logdir = pathlib.Path(config.logdir).expanduser()
|
||||
config.traindir = config.traindir or logdir / "train_eps"
|
||||
config.evaldir = config.evaldir or logdir / "eval_eps"
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
class Crafter:
|
||||
metadata = {}
|
||||
|
||||
def __init__(self, task, size=(64, 64), seed=None):
|
||||
def __init__(self, task, size=(64, 64), seed=0):
|
||||
assert task in ("reward", "noreward")
|
||||
import crafter
|
||||
|
||||
|
@ -5,14 +5,18 @@ import numpy as np
|
||||
class DeepMindControl:
|
||||
metadata = {}
|
||||
|
||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
|
||||
domain, task = name.split("_", 1)
|
||||
if domain == "cup": # Only domain with multiple words.
|
||||
domain = "ball_in_cup"
|
||||
if isinstance(domain, str):
|
||||
from dm_control import suite
|
||||
|
||||
self._env = suite.load(domain, task)
|
||||
self._env = suite.load(
|
||||
domain,
|
||||
task,
|
||||
task_kwargs={"random": seed},
|
||||
)
|
||||
else:
|
||||
assert task is None
|
||||
self._env = domain()
|
||||
|
@ -5,11 +5,11 @@ import numpy as np
|
||||
|
||||
|
||||
class MemoryMaze:
|
||||
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
|
||||
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
|
||||
if task == "9x9":
|
||||
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
|
||||
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
|
||||
elif task == "15x15":
|
||||
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
|
||||
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
|
||||
else:
|
||||
raise NotImplementedError(task)
|
||||
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
|
||||
|
@ -11,7 +11,7 @@ protobuf==3.20.0
|
||||
gym==0.19.0
|
||||
dm_control==1.0.9
|
||||
scipy==1.8.0
|
||||
memory_maze==1.0.2
|
||||
memory_maze==1.0.3
|
||||
atari-py==0.2.9
|
||||
crafter==1.8.0
|
||||
opencv-python==4.7.0.72
|
||||
|
27
tools.py
27
tools.py
@ -6,7 +6,7 @@ import json
|
||||
import pathlib
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
|
||||
|
||||
|
||||
def sample_episodes(episodes, length, seed=0):
|
||||
random = np.random.RandomState(seed)
|
||||
np_random = np.random.RandomState(seed)
|
||||
while True:
|
||||
size = 0
|
||||
ret = None
|
||||
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
|
||||
)
|
||||
p = p / np.sum(p)
|
||||
while size < length:
|
||||
episode = random.choice(list(episodes.values()), p=p)
|
||||
episode = np_random.choice(list(episodes.values()), p=p)
|
||||
total = len(next(iter(episode.values())))
|
||||
# make sure at least one transition included
|
||||
if total < 2:
|
||||
continue
|
||||
if not ret:
|
||||
index = int(random.randint(0, total - 1))
|
||||
index = int(np_random.randint(0, total - 1))
|
||||
ret = {
|
||||
k: v[index : min(index + length, total)] for k, v in episode.items()
|
||||
k: v[index : min(index + length, total)]
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
}
|
||||
if "is_first" in ret:
|
||||
ret["is_first"][0] = True
|
||||
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
|
||||
ret[k], v[index : min(index + possible, total)], axis=0
|
||||
)
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
}
|
||||
if "is_first" in ret:
|
||||
ret["is_first"][size] = True
|
||||
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
|
||||
if prefix:
|
||||
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
|
||||
return metrics
|
||||
|
||||
|
||||
def set_seed_everywhere(seed):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed to PyTorch (and to all CUDA devices when
    CUDA is available), to NumPy's global generator, and to Python's
    built-in ``random`` module.

    Args:
        seed: Seed value forwarded unchanged to each library.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed the generators of every visible CUDA device.
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
||||
|
||||
|
||||
def enable_deterministic_run():
    """Switch PyTorch into its deterministic execution mode.

    Sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable, turns off
    cuDNN benchmarking, and asks PyTorch to use deterministic algorithms.
    Takes no arguments and returns nothing; it only mutates process-wide
    state.
    """
    # NOTE(review): PyTorch's reproducibility notes require this variable
    # for deterministic cuBLAS on CUDA >= 10.2; presumably it has to be set
    # before the first CUDA matmul -- confirm call order in the entry point.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Benchmark mode may select a different convolution algorithm per run.
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user