added the option for a deterministic run
This commit is contained in:
		
							parent
							
								
									68096d1f62
								
							
						
					
					
						commit
						606ec8af8c
					
				@ -6,6 +6,7 @@ defaults:
 | 
				
			|||||||
  offline_traindir: ''
 | 
					  offline_traindir: ''
 | 
				
			||||||
  offline_evaldir: ''
 | 
					  offline_evaldir: ''
 | 
				
			||||||
  seed: 0
 | 
					  seed: 0
 | 
				
			||||||
 | 
					  deterministic_run: False
 | 
				
			||||||
  steps: 1e6
 | 
					  steps: 1e6
 | 
				
			||||||
  parallel: False
 | 
					  parallel: False
 | 
				
			||||||
  eval_every: 1e4
 | 
					  eval_every: 1e4
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										17
									
								
								dreamer.py
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								dreamer.py
									
									
									
									
									
								
							@ -186,7 +186,9 @@ def make_env(config, mode):
 | 
				
			|||||||
    if suite == "dmc":
 | 
					    if suite == "dmc":
 | 
				
			||||||
        import envs.dmc as dmc
 | 
					        import envs.dmc as dmc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        env = dmc.DeepMindControl(task, config.action_repeat, config.size)
 | 
					        env = dmc.DeepMindControl(
 | 
				
			||||||
 | 
					            task, config.action_repeat, config.size, seed=config.seed
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        env = wrappers.NormalizeActions(env)
 | 
					        env = wrappers.NormalizeActions(env)
 | 
				
			||||||
    elif suite == "atari":
 | 
					    elif suite == "atari":
 | 
				
			||||||
        import envs.atari as atari
 | 
					        import envs.atari as atari
 | 
				
			||||||
@ -201,24 +203,28 @@ def make_env(config, mode):
 | 
				
			|||||||
            sticky=config.stickey,
 | 
					            sticky=config.stickey,
 | 
				
			||||||
            actions=config.actions,
 | 
					            actions=config.actions,
 | 
				
			||||||
            resize=config.resize,
 | 
					            resize=config.resize,
 | 
				
			||||||
 | 
					            seed=config.seed,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        env = wrappers.OneHotAction(env)
 | 
					        env = wrappers.OneHotAction(env)
 | 
				
			||||||
    elif suite == "dmlab":
 | 
					    elif suite == "dmlab":
 | 
				
			||||||
        import envs.dmlab as dmlab
 | 
					        import envs.dmlab as dmlab
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        env = dmlab.DeepMindLabyrinth(
 | 
					        env = dmlab.DeepMindLabyrinth(
 | 
				
			||||||
            task, mode if "train" in mode else "test", config.action_repeat
 | 
					            task,
 | 
				
			||||||
 | 
					            mode if "train" in mode else "test",
 | 
				
			||||||
 | 
					            config.action_repeat,
 | 
				
			||||||
 | 
					            seed=config.seed,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        env = wrappers.OneHotAction(env)
 | 
					        env = wrappers.OneHotAction(env)
 | 
				
			||||||
    elif suite == "MemoryMaze":
 | 
					    elif suite == "MemoryMaze":
 | 
				
			||||||
        from envs.memorymaze import MemoryMaze
 | 
					        from envs.memorymaze import MemoryMaze
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        env = MemoryMaze(task)
 | 
					        env = MemoryMaze(task, seed=config.seed)
 | 
				
			||||||
        env = wrappers.OneHotAction(env)
 | 
					        env = wrappers.OneHotAction(env)
 | 
				
			||||||
    elif suite == "crafter":
 | 
					    elif suite == "crafter":
 | 
				
			||||||
        import envs.crafter as crafter
 | 
					        import envs.crafter as crafter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        env = crafter.Crafter(task, config.size)
 | 
					        env = crafter.Crafter(task, config.size, seed=config.seed)
 | 
				
			||||||
        env = wrappers.OneHotAction(env)
 | 
					        env = wrappers.OneHotAction(env)
 | 
				
			||||||
    elif suite == "minecraft":
 | 
					    elif suite == "minecraft":
 | 
				
			||||||
        import envs.minecraft as minecraft
 | 
					        import envs.minecraft as minecraft
 | 
				
			||||||
@ -236,6 +242,9 @@ def make_env(config, mode):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main(config):
 | 
					def main(config):
 | 
				
			||||||
 | 
					    tools.set_seed_everywhere(config.seed)
 | 
				
			||||||
 | 
					    if config.deterministic_run:
 | 
				
			||||||
 | 
					        tools.enable_deterministic_run()
 | 
				
			||||||
    logdir = pathlib.Path(config.logdir).expanduser()
 | 
					    logdir = pathlib.Path(config.logdir).expanduser()
 | 
				
			||||||
    config.traindir = config.traindir or logdir / "train_eps"
 | 
					    config.traindir = config.traindir or logdir / "train_eps"
 | 
				
			||||||
    config.evaldir = config.evaldir or logdir / "eval_eps"
 | 
					    config.evaldir = config.evaldir or logdir / "eval_eps"
 | 
				
			||||||
 | 
				
			|||||||
@ -5,7 +5,7 @@ import numpy as np
 | 
				
			|||||||
class Crafter:
 | 
					class Crafter:
 | 
				
			||||||
    metadata = {}
 | 
					    metadata = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, task, size=(64, 64), seed=None):
 | 
					    def __init__(self, task, size=(64, 64), seed=0):
 | 
				
			||||||
        assert task in ("reward", "noreward")
 | 
					        assert task in ("reward", "noreward")
 | 
				
			||||||
        import crafter
 | 
					        import crafter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -5,14 +5,18 @@ import numpy as np
 | 
				
			|||||||
class DeepMindControl:
 | 
					class DeepMindControl:
 | 
				
			||||||
    metadata = {}
 | 
					    metadata = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
 | 
					    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0):
 | 
				
			||||||
        domain, task = name.split("_", 1)
 | 
					        domain, task = name.split("_", 1)
 | 
				
			||||||
        if domain == "cup":  # Only domain with multiple words.
 | 
					        if domain == "cup":  # Only domain with multiple words.
 | 
				
			||||||
            domain = "ball_in_cup"
 | 
					            domain = "ball_in_cup"
 | 
				
			||||||
        if isinstance(domain, str):
 | 
					        if isinstance(domain, str):
 | 
				
			||||||
            from dm_control import suite
 | 
					            from dm_control import suite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self._env = suite.load(domain, task)
 | 
					            self._env = suite.load(
 | 
				
			||||||
 | 
					                domain,
 | 
				
			||||||
 | 
					                task,
 | 
				
			||||||
 | 
					                task_kwargs={"random": seed},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            assert task is None
 | 
					            assert task is None
 | 
				
			||||||
            self._env = domain()
 | 
					            self._env = domain()
 | 
				
			||||||
 | 
				
			|||||||
@ -5,11 +5,11 @@ import numpy as np
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MemoryMaze:
 | 
					class MemoryMaze:
 | 
				
			||||||
    def __init__(self, task, obs_key="image", act_key="action", size=(64, 64)):
 | 
					    def __init__(self, task, obs_key="image", act_key="action", size=(64, 64), seed=0):
 | 
				
			||||||
        if task == "9x9":
 | 
					        if task == "9x9":
 | 
				
			||||||
            self._env = gym.make("memory_maze:MemoryMaze-9x9-v0")
 | 
					            self._env = gym.make("memory_maze:MemoryMaze-9x9-v0", seed=seed)
 | 
				
			||||||
        elif task == "15x15":
 | 
					        elif task == "15x15":
 | 
				
			||||||
            self._env = gym.make("memory_maze:MemoryMaze-15x15-v0")
 | 
					            self._env = gym.make("memory_maze:MemoryMaze-15x15-v0", seed=seed)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise NotImplementedError(task)
 | 
					            raise NotImplementedError(task)
 | 
				
			||||||
        self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
 | 
					        self._obs_is_dict = hasattr(self._env.observation_space, "spaces")
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ protobuf==3.20.0
 | 
				
			|||||||
gym==0.19.0
 | 
					gym==0.19.0
 | 
				
			||||||
dm_control==1.0.9
 | 
					dm_control==1.0.9
 | 
				
			||||||
scipy==1.8.0
 | 
					scipy==1.8.0
 | 
				
			||||||
memory_maze==1.0.2
 | 
					memory_maze==1.0.3
 | 
				
			||||||
atari-py==0.2.9
 | 
					atari-py==0.2.9
 | 
				
			||||||
crafter==1.8.0
 | 
					crafter==1.8.0
 | 
				
			||||||
opencv-python==4.7.0.72
 | 
					opencv-python==4.7.0.72
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										27
									
								
								tools.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								tools.py
									
									
									
									
									
								
							@ -6,7 +6,7 @@ import json
 | 
				
			|||||||
import pathlib
 | 
					import pathlib
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import uuid
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -321,7 +321,7 @@ def from_generator(generator, batch_size):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def sample_episodes(episodes, length, seed=0):
 | 
					def sample_episodes(episodes, length, seed=0):
 | 
				
			||||||
    random = np.random.RandomState(seed)
 | 
					    np_random = np.random.RandomState(seed)
 | 
				
			||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        size = 0
 | 
					        size = 0
 | 
				
			||||||
        ret = None
 | 
					        ret = None
 | 
				
			||||||
@ -330,15 +330,17 @@ def sample_episodes(episodes, length, seed=0):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        p = p / np.sum(p)
 | 
					        p = p / np.sum(p)
 | 
				
			||||||
        while size < length:
 | 
					        while size < length:
 | 
				
			||||||
            episode = random.choice(list(episodes.values()), p=p)
 | 
					            episode = np_random.choice(list(episodes.values()), p=p)
 | 
				
			||||||
            total = len(next(iter(episode.values())))
 | 
					            total = len(next(iter(episode.values())))
 | 
				
			||||||
            # make sure at least one transition included
 | 
					            # make sure at least one transition included
 | 
				
			||||||
            if total < 2:
 | 
					            if total < 2:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            if not ret:
 | 
					            if not ret:
 | 
				
			||||||
                index = int(random.randint(0, total - 1))
 | 
					                index = int(np_random.randint(0, total - 1))
 | 
				
			||||||
                ret = {
 | 
					                ret = {
 | 
				
			||||||
                    k: v[index : min(index + length, total)] for k, v in episode.items()
 | 
					                    k: v[index : min(index + length, total)]
 | 
				
			||||||
 | 
					                    for k, v in episode.items()
 | 
				
			||||||
 | 
					                    if "log_" not in k
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                if "is_first" in ret:
 | 
					                if "is_first" in ret:
 | 
				
			||||||
                    ret["is_first"][0] = True
 | 
					                    ret["is_first"][0] = True
 | 
				
			||||||
@ -351,6 +353,7 @@ def sample_episodes(episodes, length, seed=0):
 | 
				
			|||||||
                        ret[k], v[index : min(index + possible, total)], axis=0
 | 
					                        ret[k], v[index : min(index + possible, total)], axis=0
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    for k, v in episode.items()
 | 
					                    for k, v in episode.items()
 | 
				
			||||||
 | 
					                    if "log_" not in k
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                if "is_first" in ret:
 | 
					                if "is_first" in ret:
 | 
				
			||||||
                    ret["is_first"][size] = True
 | 
					                    ret["is_first"][size] = True
 | 
				
			||||||
@ -980,3 +983,17 @@ def tensorstats(tensor, prefix=None):
 | 
				
			|||||||
    if prefix:
 | 
					    if prefix:
 | 
				
			||||||
        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
 | 
					        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
 | 
				
			||||||
    return metrics
 | 
					    return metrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def set_seed_everywhere(seed):
 | 
				
			||||||
 | 
					    torch.manual_seed(seed)
 | 
				
			||||||
 | 
					    if torch.cuda.is_available():
 | 
				
			||||||
 | 
					        torch.cuda.manual_seed_all(seed)
 | 
				
			||||||
 | 
					    np.random.seed(seed)
 | 
				
			||||||
 | 
					    random.seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def enable_deterministic_run():
 | 
				
			||||||
 | 
					    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 | 
				
			||||||
 | 
					    torch.backends.cudnn.benchmark = False
 | 
				
			||||||
 | 
					    torch.use_deterministic_algorithms(True)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user