diff --git a/README.md b/README.md
index c892ba5..90f793c 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,10 @@ Train the agent on Alien in Atari 100K:
 ```
 python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
 ```
+Train the agent on Crafter:
+```
+python3 dreamer.py --configs crafter --logdir ./logdir/crafter
+```
 Monitor results:
 ```
 tensorboard --logdir ./logdir
diff --git a/configs.yaml b/configs.yaml
index fce3313..6311333 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -138,6 +138,25 @@ dmc_vision:
   encoder: {mlp_keys: '$^', cnn_keys: 'image'}
   decoder: {mlp_keys: '$^', cnn_keys: 'image'}
 
+crafter:
+  task: crafter_reward
+  steps: 1e6
+  action_repeat: 1
+  envs: 1
+  train_ratio: 512
+  video_pred_log: true
+  dyn_hidden: 1024
+  dyn_deter: 4096
+  units: 1024
+  reward_layers: 5
+  cont_layers: 5
+  value_layers: 5
+  actor_layers: 5
+  encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  actor_dist: 'onehot'
+  imag_gradient: 'reinforce'
+
 atari100k:
   steps: 4e5
   envs: 1
diff --git a/dreamer.py b/dreamer.py
index 0832818..c32d66f 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -211,6 +211,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
             task, mode if "train" in mode else "test", config.action_repeat
         )
         env = wrappers.OneHotAction(env)
+    elif suite == "crafter":
+        import envs.crafter as crafter
+
+        env = crafter.Crafter(task, config.size)
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
diff --git a/envs/crafter.py b/envs/crafter.py
new file mode 100644
index 0000000..cbc476f
--- /dev/null
+++ b/envs/crafter.py
@@ -0,0 +1,70 @@
+import gym
+import numpy as np
+
+
+class Crafter:
+    def __init__(self, task, size=(64, 64), seed=None):
+        assert task in ("reward", "noreward")
+        import crafter
+
+        self._env = crafter.Env(size=size, reward=(task == "reward"), seed=seed)
+        self._achievements = crafter.constants.achievements.copy()
+
+    @property
+    def observation_space(self):
+        spaces = {
+            "image": gym.spaces.Box(
+                0, 255, self._env.observation_space.shape, dtype=np.uint8
+            ),
+            "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+            "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        }
+        spaces.update(
+            {
+                f"log_achievement_{k}": gym.spaces.Box(
+                    -np.inf, np.inf, (1,), dtype=np.float32
+                )
+                for k in self._achievements
+            }
+        )
+        return gym.spaces.Dict(spaces)
+
+    @property
+    def action_space(self):
+        action_space = self._env.action_space
+        action_space.discrete = True
+        return action_space
+
+    def step(self, action):
+        image, reward, done, info = self._env.step(action)
+        reward = np.float32(reward)
+        log_achievements = {
+            f"log_achievement_{k}": info["achievements"][k] if info else 0
+            for k in self._achievements
+        }
+        obs = {
+            "image": image,
+            "reward": reward,
+            "is_first": False,
+            "is_last": done,
+            "is_terminal": info["discount"] == 0,
+            "log_reward": np.float32(info["reward"] if info else 0.0),
+            **log_achievements,
+        }
+        return obs, reward, done, info
+
+    def render(self):
+        return self._env.render()
+
+    def reset(self):
+        image = self._env.reset()
+        obs = {
+            "image": image,
+            "is_first": True,
"is_last": False, + "is_terminal": False, + } + return obs diff --git a/envs/wrappers.py b/envs/wrappers.py index 1a4a58b..c94b8e9 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -179,18 +179,22 @@ class RewardObs: @property def observation_space(self): spaces = self._env.observation_space.spaces - assert "reward" not in spaces - spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32) + if "reward" not in spaces: + spaces["reward"] = gym.spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ) return gym.spaces.Dict(spaces) def step(self, action): obs, reward, done, info = self._env.step(action) - obs["reward"] = reward + if "reward" not in obs: + obs["reward"] = reward return obs, reward, done, info def reset(self): obs = self._env.reset() - obs["reward"] = 0.0 + if "reward" not in obs: + obs["reward"] = 0.0 return obs diff --git a/networks.py b/networks.py index 1ba7673..b628b73 100644 --- a/networks.py +++ b/networks.py @@ -347,7 +347,11 @@ class MultiEncoder(nn.Module): ): super(MultiEncoder, self).__init__() excluded = ("is_first", "is_last", "is_terminal", "reward") - shapes = {k: v for k, v in shapes.items() if k not in excluded} + shapes = { + k: v + for k, v in shapes.items() + if k not in excluded and not k.startswith("log_") + } self.cnn_shapes = { k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) }