added the option for a deterministic run
This commit is contained in:
		
							parent
							
								
									68096d1f62
								
							
						
					
					
						commit
						606ec8af8c
					
				@ -6,6 +6,7 @@ defaults:
 | 
			
		||||
  offline_traindir: ''
 | 
			
		||||
  offline_evaldir: ''
 | 
			
		||||
  seed: 0
 | 
			
		||||
  deterministic_run: False
 | 
			
		||||
  steps: 1e6
 | 
			
		||||
  parallel: False
 | 
			
		||||
  eval_every: 1e4
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								dreamer.py
									
									
									
									
									
								
							@ -186,7 +186,9 @@ def make_env(config, mode):
 | 
			
		||||
    if suite == "dmc":
 | 
			
		||||
        import envs.dmc as dmc
 | 
			
		||||
 | 
			
		||||
        env = dmc.DeepMindControl(task, config.action_repeat, config.size)
 | 
			
		||||
        env = dmc.DeepMindControl(
 | 
			
		||||
            task, config.action_repeat, config.size, seed=config.seed
 | 
			
		||||
        )
 | 
			
		||||
        env = wrappers.NormalizeActions(env)
 | 
			
		||||
    elif suite == "atari":
 | 
			
		||||
        import envs.atari as atari
 | 
			
		||||
@ -201,24 +203,28 @@ def make_env(config, mode):
 | 
			
		||||
            sticky=config.stickey,
 | 
			
		||||
            actions=config.actions,
 | 
			
		||||
            resize=config.resize,
 | 
			
		||||
            seed=config.seed,
 | 
			
		||||
        )
 | 
			
		||||
        env = wrappers.OneHotAction(env)
 | 
			
		||||
    elif suite == "dmlab":
 | 
			
		||||
        import envs.dmlab as dmlab
 | 
			
		||||
 | 
			
		||||
        env = dmlab.DeepMindLabyrinth(
 | 
			
		||||
            task, mode if "train" in mode else "test", config.action_repeat
 | 
			
		||||
            task,
 | 
			
		||||
            mode if "train" in mode else "test",
 | 
			
		||||
            config.action_repeat,
 | 
			
		||||
            seed=config.seed,
 | 
			
		||||
        )
 | 
			
		||||
        env = wrappers.OneHotAction(env)
 | 
			
		||||
    elif suite == "MemoryMaze":
 | 
			
		||||
        from envs.memorymaze import MemoryMaze
 | 
			
		||||
 | 
			
		||||
        env = MemoryMaze(task)
 | 
			
		||||
        env = MemoryMaze(task, seed=config.seed)
 | 
			
		||||
        env = wrappers.OneHotAction(env)
 | 
			
		||||
    elif suite == "crafter":
 | 
			
		||||
        import envs.crafter as crafter
 | 
			
		||||
 | 
			
		||||
        env = crafter.Crafter(task, config.size)
 | 
			
		||||
        env = crafter.Crafter(task, config.size, seed=config.seed)
 | 
			
		||||
        env = wrappers.OneHotAction(env)
 | 
			
		||||
    elif suite == "minecraft":
 | 
			
		||||
        import envs.minecraft as minecraft
 | 
			
		||||
@ -236,6 +242,9 @@ def make_env(config, mode):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(config):
 | 
			
		||||
    tools.set_seed_everywhere(config.seed)
 | 
			
		||||
    if config.deterministic_run:
 | 
			
		||||
        tools.enable_deterministic_run()
 | 
			
		||||
    logdir = pathlib.Path(config.logdir).expanduser()
 | 
			
		||||
    config.traindir = config.traindir or logdir / "train_eps"
 | 
			
		||||
    config.evaldir = config.evaldir or logdir / "eval_eps"
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ import numpy as np
 | 
			
		||||
class Crafter:
 | 
			
		||||
    metadata = {}
 | 
			
		||||
 | 
			
		||||
    def __init__(self, task, size=(64, 64), seed=None):
 | 
			
		||||
    def __init__(self, task, size=(64, 64), seed=0):
 | 
			
		||||
        assert task in ("reward", "noreward")
 | 
			
		||||
        import crafter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5,14 +5,18 @@ import numpy as np
 | 
			
		||||
class DeepMindControl:
 | 
			
		||||
    metadata = {}
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
 | 
			
		||||
    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
 | 
			
		||||
        domain, task = name.split("_", 1)
 | 
			
		||||
        if domain == "cup":  # Only domain with multiple words.
 | 
			
		||||
            domain = "ball_in_cup"
 | 
			
		||||
        if isinstance(domain, str):
 | 
			
		||||
            from dm_control import suite
 | 
			
		||||
 | 
			
		||||
            self._env = suite.load(domain, task)
 | 
			
		||||
            self._env = suite.load(
 | 
			
		||||
                domain,
 | 
			
		||||
                task,
 | 
			
		||||
                task_kwargs={"random": seed},
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            assert task is None
 | 
			
		||||
            self._env = domain()
 | 
			
		||||
 | 
			
		||||
@ -5,11 +5,11 @@ import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MemoryMaze:
 | 
			
		||||
    def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
 | 
			
		||||
    def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
 | 
			
		||||
        if task == "9x9":
 | 
			
		||||
            self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
 | 
			
		||||
            self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
 | 
			
		||||
        elif task == "15x15":
 | 
			
		||||
            self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
 | 
			
		||||
            self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(task)
 | 
			
		||||
        self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ protobuf==3.20.0
 | 
			
		||||
gym==0.19.0
 | 
			
		||||
dm_control==1.0.9
 | 
			
		||||
scipy==1.8.0
 | 
			
		||||
memory_maze==1.0.2
 | 
			
		||||
memory_maze==1.0.3
 | 
			
		||||
atari-py==0.2.9
 | 
			
		||||
crafter==1.8.0
 | 
			
		||||
opencv-python==4.7.0.72
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										27
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								tools.py
									
									
									
									
									
								
							@ -6,7 +6,7 @@ import json
 | 
			
		||||
import pathlib
 | 
			
		||||
import re
 | 
			
		||||
import time
 | 
			
		||||
import uuid
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sample_episodes(episodes, length, seed=0):
 | 
			
		||||
    random = np.random.RandomState(seed)
 | 
			
		||||
    np_random = np.random.RandomState(seed)
 | 
			
		||||
    while True:
 | 
			
		||||
        size = 0
 | 
			
		||||
        ret = None
 | 
			
		||||
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
 | 
			
		||||
        )
 | 
			
		||||
        p = p / np.sum(p)
 | 
			
		||||
        while size < length:
 | 
			
		||||
            episode = random.choice(list(episodes.values()), p=p)
 | 
			
		||||
            episode = np_random.choice(list(episodes.values()), p=p)
 | 
			
		||||
            total = len(next(iter(episode.values())))
 | 
			
		||||
            # make sure at least one transition included
 | 
			
		||||
            if total < 2:
 | 
			
		||||
                continue
 | 
			
		||||
            if not ret:
 | 
			
		||||
                index = int(random.randint(0, total - 1))
 | 
			
		||||
                index = int(np_random.randint(0, total - 1))
 | 
			
		||||
                ret = {
 | 
			
		||||
                    k: v[index : min(index + length, total)] for k, v in episode.items()
 | 
			
		||||
                    k: v[index : min(index + length, total)]
 | 
			
		||||
                    for k, v in episode.items()
 | 
			
		||||
                    if "log_" not in k
 | 
			
		||||
                }
 | 
			
		||||
                if "is_first" in ret:
 | 
			
		||||
                    ret["is_first"][0] = True
 | 
			
		||||
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
 | 
			
		||||
                        ret[k], v[index : min(index + possible, total)], axis=0
 | 
			
		||||
                    )
 | 
			
		||||
                    for k, v in episode.items()
 | 
			
		||||
                    if "log_" not in k
 | 
			
		||||
                }
 | 
			
		||||
                if "is_first" in ret:
 | 
			
		||||
                    ret["is_first"][size] = True
 | 
			
		||||
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
 | 
			
		||||
    if prefix:
 | 
			
		||||
        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
 | 
			
		||||
    return metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_seed_everywhere(seed):
    """Seed every global random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    default generator; when CUDA is available, all CUDA devices are seeded too.

    Args:
        seed: Integer seed applied to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Explicitly seed every visible GPU as well.
    torch.cuda.manual_seed_all(seed)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def enable_deterministic_run():
    """Switch PyTorch into its deterministic execution mode.

    Exports ``CUBLAS_WORKSPACE_CONFIG=:4096:8`` for the current process,
    turns off cuDNN benchmark mode, and asks PyTorch to use deterministic
    algorithms only.
    """
    # NOTE(review): per the PyTorch reproducibility notes, cuBLAS needs this
    # variable for deterministic behavior; presumably it must be in place
    # before CUDA work starts, so it is set first -- confirm against callers.
    workspace_key = "CUBLAS_WORKSPACE_CONFIG"
    os.environ[workspace_key] = ":4096:8"
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user