added the option for a deterministic run

This commit is contained in:
NM512 2023-08-16 21:46:06 +09:00
parent 68096d1f62
commit 606ec8af8c
7 changed files with 47 additions and 16 deletions

View File

@ -6,6 +6,7 @@ defaults:
offline_traindir: ''
offline_evaldir: ''
seed: 0
deterministic_run: False
steps: 1e6
parallel: False
eval_every: 1e4

View File

@ -186,7 +186,9 @@ def make_env(config, mode):
if suite == "dmc":
import envs.dmc as dmc
env = dmc.DeepMindControl(task, config.action_repeat, config.size)
env = dmc.DeepMindControl(
task, config.action_repeat, config.size, seed=config.seed
)
env = wrappers.NormalizeActions(env)
elif suite == "atari":
import envs.atari as atari
@ -201,24 +203,28 @@ def make_env(config, mode):
sticky=config.stickey,
actions=config.actions,
resize=config.resize,
seed=config.seed,
)
env = wrappers.OneHotAction(env)
elif suite == "dmlab":
import envs.dmlab as dmlab
env = dmlab.DeepMindLabyrinth(
task, mode if "train" in mode else "test", config.action_repeat
task,
mode if "train" in mode else "test",
config.action_repeat,
seed=config.seed,
)
env = wrappers.OneHotAction(env)
elif suite == "MemoryMaze":
from envs.memorymaze import MemoryMaze
env = MemoryMaze(task)
env = MemoryMaze(task, seed=config.seed)
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size)
env = crafter.Crafter(task, config.size, seed=config.seed)
env = wrappers.OneHotAction(env)
elif suite == "minecraft":
import envs.minecraft as minecraft
@ -236,6 +242,9 @@ def make_env(config, mode):
def main(config):
tools.set_seed_everywhere(config.seed)
if config.deterministic_run:
tools.enable_deterministic_run()
logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / "train_eps"
config.evaldir = config.evaldir or logdir / "eval_eps"

View File

@ -5,7 +5,7 @@ import numpy as np
class Crafter:
metadata = {}
def __init__(self, task, size=(64, 64), seed=None):
def __init__(self, task, size=(64, 64), seed=0):
assert task in ("reward", "noreward")
import crafter

View File

@ -5,14 +5,18 @@ import numpy as np
class DeepMindControl:
metadata = {}
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
domain, task = name.split("_", 1)
if domain == "cup": # Only domain with multiple words.
domain = "ball_in_cup"
if isinstance(domain, str):
from dm_control import suite
self._env = suite.load(domain, task)
self._env = suite.load(
domain,
task,
task_kwargs={"random": seed},
)
else:
assert task is None
self._env = domain()

View File

@ -5,11 +5,11 @@ import numpy as np
class MemoryMaze:
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
if task == "9x9":
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
elif task == "15x15":
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
else:
raise NotImplementedError(task)
self._obs_is_dict = hasattr(self._env.observation_space, "spaces")

View File

@ -11,7 +11,7 @@ protobuf==3.20.0
gym==0.19.0
dm_control==1.0.9
scipy==1.8.0
memory_maze==1.0.2
memory_maze==1.0.3
atari-py==0.2.9
crafter==1.8.0
opencv-python==4.7.0.72

View File

@ -6,7 +6,7 @@ import json
import pathlib
import re
import time
import uuid
import random
import numpy as np
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
def sample_episodes(episodes, length, seed=0):
random = np.random.RandomState(seed)
np_random = np.random.RandomState(seed)
while True:
size = 0
ret = None
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
)
p = p / np.sum(p)
while size < length:
episode = random.choice(list(episodes.values()), p=p)
episode = np_random.choice(list(episodes.values()), p=p)
total = len(next(iter(episode.values())))
# make sure at least one transition included
if total < 2:
continue
if not ret:
index = int(random.randint(0, total - 1))
index = int(np_random.randint(0, total - 1))
ret = {
k: v[index : min(index + length, total)] for k, v in episode.items()
k: v[index : min(index + length, total)]
for k, v in episode.items()
if "log_" not in k
}
if "is_first" in ret:
ret["is_first"][0] = True
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
ret[k], v[index : min(index + possible, total)], axis=0
)
for k, v in episode.items()
if "log_" not in k
}
if "is_first" in ret:
ret["is_first"][size] = True
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
if prefix:
metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
return metrics
def set_seed_everywhere(seed):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def enable_deterministic_run():
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)