diff --git a/README.md b/README.md
index c892ba5..90f793c 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,10 @@ Train the agent on Alien in Atari 100K:
 ```
 python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
 ```
+Train the agent on Crafter:
+```
+python3 dreamer.py --configs crafter --logdir ./logdir/crafter
+```
 Monitor results:
 ```
 tensorboard --logdir ./logdir
diff --git a/configs.yaml b/configs.yaml
index 18e96a1..33b6efd 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -120,22 +120,46 @@ defaults:
   disag_units: 400
   disag_action_cond: False
 
-dmc_vision:
-  steps: 1e6
-  train_ratio: 512
-  video_pred_log: true
-  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
-  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
-
 dmc_proprio:
   steps: 5e5
+  action_repeat: 2
+  envs: 4
   train_ratio: 512
   video_pred_log: false
   encoder: {mlp_keys: '.*', cnn_keys: '$^'}
   decoder: {mlp_keys: '.*', cnn_keys: '$^'}
 
+dmc_vision:
+  steps: 1e6
+  action_repeat: 2
+  envs: 4
+  train_ratio: 512
+  video_pred_log: true
+  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
+
+crafter:
+  task: crafter_reward
+  steps: 1e6
+  action_repeat: 1
+  envs: 1
+  train_ratio: 512
+  video_pred_log: true
+  dyn_hidden: 1024
+  dyn_deter: 4096
+  units: 1024
+  reward_layers: 5
+  cont_layers: 5
+  value_layers: 5
+  actor_layers: 5
+  encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  actor_dist: 'onehot'
+  imag_gradient: 'reinforce'
+
 atari100k:
   steps: 4e5
+  envs: 1
   action_repeat: 4
   eval_episode_num: 100
   stickey: False
diff --git a/dreamer.py b/dreamer.py
index 24750d1..fec590f 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
-        if config.compile and os.name != 'nt':  # compilation is not supported on windows
+        if (
+            config.compile and os.name != "nt"
+        ):  # compilation is not supported on windows
             self._wm = torch.compile(self._wm)
             self._task_behavior = torch.compile(self._task_behavior)
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
         post, context, mets = self._wm._train(data)
         metrics.update(mets)
         start = post
-        # start['deter'] (16, 64, 512)
         reward = lambda f, s, a: self._wm.heads["reward"](
             self._wm.dynamics.get_feat(s)
         ).mode()
@@ -221,6 +222,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         from envs.memorymaze import MemoryMaze
         env = MemoryMaze(env)
         env = wrappers.OneHotAction(env)
+    elif suite == "crafter":
+        import envs.crafter as crafter
+
+        env = crafter.Crafter(task, config.size)
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
diff --git a/envs/crafter.py b/envs/crafter.py
new file mode 100644
index 0000000..cbc476f
--- /dev/null
+++ b/envs/crafter.py
@@ -0,0 +1,70 @@
+import gym
+import numpy as np
+
+
+class Crafter:
+    def __init__(self, task, size=(64, 64), seed=None):
+        assert task in ("reward", "noreward")
+        import crafter
+
+        self._env = crafter.Env(size=size, reward=(task == "reward"), seed=seed)
+        self._achievements = crafter.constants.achievements.copy()
+
+    @property
+    def observation_space(self):
+        spaces = {
+            "image": gym.spaces.Box(
+                0, 255, self._env.observation_space.shape, dtype=np.uint8
+            ),
+            "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+            "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        }
+        spaces.update(
+            {
+                f"log_achievement_{k}": gym.spaces.Box(
+                    -np.inf, np.inf, (1,), dtype=np.float32
+                )
+                for k in self._achievements
+            }
+        )
+        return gym.spaces.Dict(spaces)
+
+    @property
+    def action_space(self):
+        action_space = self._env.action_space
+        action_space.discrete = True
+        return action_space
+
+    def step(self, action):
+        image, reward, done, info = self._env.step(action)
+        reward = np.float32(reward)
+        log_achievements = {
+            f"log_achievement_{k}": info["achievements"][k] if info else 0
+            for k in self._achievements
+        }
+        obs = {
+            "image": image,
+            "reward": reward,
+            "is_first": False,
+            "is_last": done,
+            "is_terminal": info["discount"] == 0,
+            "log_reward": np.float32(info["reward"] if info else 0.0),
+            **log_achievements,
+        }
+        return obs, reward, done, info
+
+    def render(self):
+        return self._env.render()
+
+    def reset(self):
+        image = self._env.reset()
+        obs = {
+            "image": image,
+            "is_first": True,
+            "is_last": False,
+            "is_terminal": False,
+        }
+        return obs
diff --git a/envs/wrappers.py b/envs/wrappers.py
index 1a4a58b..c94b8e9 100644
--- a/envs/wrappers.py
+++ b/envs/wrappers.py
@@ -179,18 +179,22 @@ class RewardObs:
     @property
     def observation_space(self):
         spaces = self._env.observation_space.spaces
-        assert "reward" not in spaces
-        spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)
+        if "reward" not in spaces:
+            spaces["reward"] = gym.spaces.Box(
+                -np.inf, np.inf, shape=(1,), dtype=np.float32
+            )
         return gym.spaces.Dict(spaces)
 
     def step(self, action):
         obs, reward, done, info = self._env.step(action)
-        obs["reward"] = reward
+        if "reward" not in obs:
+            obs["reward"] = reward
         return obs, reward, done, info
 
     def reset(self):
         obs = self._env.reset()
-        obs["reward"] = 0.0
+        if "reward" not in obs:
+            obs["reward"] = 0.0
         return obs
 
 
diff --git a/exploration.py b/exploration.py
index 5eefbf9..bb2e60b 100644
--- a/exploration.py
+++ b/exploration.py
@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0,  # pytorch version
+            inp_dim=feat_size + config.num_actions
+            if config.disag_action_cond
+            else 0,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
@@ -93,7 +95,9 @@
         }[self._config.disag_target]
         inputs = context["feat"]
         if self._config.disag_action_cond:
-            inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
+            inputs = torch.concat(
+                [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
+            )
         metrics.update(self._train_ensemble(inputs, target))
         metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
diff --git a/models.py b/models.py
index 8402d10..2c00597 100644
--- a/models.py
+++ b/models.py
@@ -399,13 +399,10 @@ class ImagBehavior(nn.Module):
         if self._config.future_entropy and self._config.actor_state_entropy() > 0:
             reward += self._config.actor_state_entropy() * state_ent
         value = self.value(imag_feat).mode()
-        # value(15, 960, ch)
-        # action(15, 960, ch)
-        # discount(15, 960, ch)
         target = tools.lambda_return(
-            reward[:-1],
+            reward[1:],
             value[:-1],
-            discount[:-1],
+            discount[1:],
             bootstrap=value[-1],
             lambda_=self._config.discount_lambda,
             axis=0,
diff --git a/networks.py b/networks.py
index 3a767fa..b628b73 100644
--- a/networks.py
+++ b/networks.py
@@ -215,7 +215,9 @@ class RSSM(nn.Module):
                     is_first,
                     is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
                 )
-                prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                prev_state[key] = (
+                    val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                )
 
         prior = self.img_step(prev_state, prev_action, None, sample)
         if self._shared:
@@ -345,7 +347,11 @@ class MultiEncoder(nn.Module):
     ):
         super(MultiEncoder, self).__init__()
         excluded = ("is_first", "is_last", "is_terminal", "reward")
-        shapes = {k: v for k, v in shapes.items() if k not in excluded}
+        shapes = {
+            k: v
+            for k, v in shapes.items()
+            if k not in excluded and not k.startswith("log_")
+        }
         self.cnn_shapes = {
             k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
         }
@@ -547,6 +553,7 @@ class ConvDecoder(nn.Module):
         self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
 
         self._linear_layer = nn.Linear(feat_size, self._embed_size)
+        self._linear_layer.apply(tools.weight_init)
         in_dim = self._embed_size // (minres**2)
 
         layers = []
diff --git a/tools.py b/tools.py
index bc46903..8b52379 100644
--- a/tools.py
+++ b/tools.py
@@ -804,7 +804,9 @@ def weight_init(m):
         denoms = (in_num + out_num) / 2.0
         scale = 1.0 / denoms
         std = np.sqrt(scale) / 0.87962566103423978
-        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+        nn.init.trunc_normal_(
+            m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
+        )
         if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)
     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
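
For reference, a minimal smoke test of the `envs/crafter.py` wrapper added in this patch, run from the repository root. This is a sketch, assuming the `crafter` package is installed (`pip install crafter`); the keys and shapes it prints are the ones declared by the wrapper above.
```
# Minimal check of the Crafter wrapper introduced in this patch.
# Assumes `pip install crafter`; gym is already required by the wrapper itself.
from envs.crafter import Crafter

env = Crafter("reward", size=(64, 64), seed=0)

obs = env.reset()
print(obs["image"].shape, obs["is_first"])  # (64, 64, 3) True

action = env.action_space.sample()  # integer action from the discrete space
obs, reward, done, info = env.step(action)
print(reward, done, obs["log_reward"])
```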