diff --git a/README.md b/README.md index 46a4fda..0478402 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,17 @@ Monitor results: tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG ``` +## ToDo +- [ ] Prototyping +- [ ] Modify implementation details based on the author's implementation +- [ ] Evaluate on visual DMC suite(~10 tasks) +- [ ] Add other tasks and corresponding model sizes implementation +- [ ] Continuous implementation improvement + + ## Acknowledgments This code is heavily inspired by the following works: +- danijar's Dreamer-v3 jax implementation: https://github.com/danijar/dreamerv3 - danijar's Dreamer-v2 tensorflow implementation: https://github.com/danijar/dreamerv2 - jsikyoon's Dreamer-v2 pytorch implementation: https://github.com/jsikyoon/dreamer-torch - RajGhugare19's Dreamer-v2 pytorch implementation: https://github.com/RajGhugare19/dreamerv2 diff --git a/configs.yaml b/configs.yaml index 796e42a..40b58f2 100644 --- a/configs.yaml +++ b/configs.yaml @@ -8,9 +8,9 @@ defaults: seed: 0 steps: 1e6 eval_every: 1e4 + eval_episode_num: 10 log_every: 1e4 reset_every: 0 - #gpu_growth: True device: 'cuda:0' precision: 16 debug: False @@ -25,9 +25,6 @@ defaults: grayscale: False prefill: 2500 eval_noise: 0.0 - reward_trans: 'symlog' - obs_trans: 'normalize' - critic_trans: 'symlog' reward_EMA: True # Model @@ -36,8 +33,8 @@ defaults: dyn_deter: 512 dyn_stoch: 32 dyn_discrete: 32 - dyn_input_layers: 2 - dyn_output_layers: 2 + dyn_input_layers: 1 + dyn_output_layers: 1 dyn_rec_depth: 1 dyn_shared: False dyn_mean_act: 'none' @@ -53,11 +50,10 @@ defaults: act: 'SiLU' norm: 'LayerNorm' cnn_depth: 32 - encoder_kernels: [3, 3, 3, 3] - decoder_kernels: [3, 3, 3, 3] - # changed here - value_head: 'twohot' - reward_head: 'twohot' + encoder_kernels: [4, 4, 4, 4] + decoder_kernels: [4, 4, 4, 4] + value_head: 'twohot_symlog' + reward_head: 'twohot_symlog' kl_lscale: '0.1' kl_rscale: '0.5' kl_free: '1.0' @@ -71,7 +67,7 @@ defaults: # Training batch_size: 16 batch_length: 64 - train_every: 5 + train_ratio: 512 train_steps: 1 pretrain: 100 model_lr: 1e-4 @@ -85,9 +81,8 @@ defaults: dataset_size: 0 oversample_ends: False slow_value_target: True - slow_actor_target: True - slow_target_update: 100 - slow_target_fraction: 0.01 + slow_target_update: 1 + slow_target_fraction: 0.02 opt: 'adam' # Behavior. @@ -95,16 +90,15 @@ defaults: discount_lambda: 0.95 imag_horizon: 15 imag_gradient: 'dynamics' - imag_gradient_mix: '0.1' + imag_gradient_mix: '0.0' imag_sample: True - actor_dist: 'trunc_normal' + actor_dist: 'normal' actor_entropy: '3e-4' actor_state_entropy: 0.0 actor_init_std: 1.0 actor_min_std: 0.1 - actor_disc: 5 + actor_max_std: 1.0 actor_temp: 0.1 - actor_outscale: 0.0 expl_amount: 0.0 eval_state_mean: False collect_dyn_sample: True @@ -134,3 +128,14 @@ debug: batch_size: 10 batch_length: 20 +cheetah: + task: 'dmc_cheetah_run' + +pendulum: + task: 'dmc_pendulum_swingup' + +cup: + task: 'dmc_cup_catch' + +acrobot: + task: 'dmc_acrobot_swingup' diff --git a/dreamer.py b/dreamer.py index b855e64..eff53ce 100644 --- a/dreamer.py +++ b/dreamer.py @@ -22,6 +22,7 @@ import torch from torch import nn from torch import distributions as torchd + to_np = lambda x: x.detach().cpu().numpy() @@ -31,7 +32,8 @@ class Dreamer(nn.Module): self._config = config self._logger = logger self._should_log = tools.Every(config.log_every) - self._should_train = tools.Every(config.train_every) + batch_steps = config.batch_size * config.batch_length + self._should_train = tools.Every(batch_steps / config.train_ratio) self._should_pretrain = tools.Once() self._should_reset = tools.Every(config.reset_every) self._should_expl = tools.Until(int(config.expl_until / config.action_repeat)) @@ -146,16 +148,17 @@ class Dreamer(nn.Module): post, context, mets = self._wm._train(data) metrics.update(mets) start = post + # start['deter'] (16, 64, 512) if self._config.pred_discount: # Last step could be terminal. - start = {k: v[:, :-1] for k, v in post.items()} - context = {k: v[:, :-1] for k, v in context.items()} + start = {k: v[:-1] for k, v in post.items()} + context = {k: v[:-1] for k, v in context.items()} reward = lambda f, s, a: self._wm.heads["reward"]( self._wm.dynamics.get_feat(s) ).mode() metrics.update(self._task_behavior._train(start, reward)[-1]) if self._config.expl_behavior != "greedy": if self._config.pred_discount: - data = {k: v[:, :-1] for k, v in data.items()} + data = {k: v[:-1] for k, v in data.items()} mets = self._expl_behavior.train(start, context, data)[-1] metrics.update({"expl_" + key: value for key, value in mets.items()}) for name, value in metrics.items(): @@ -205,7 +208,12 @@ def make_env(config, logger, mode, train_eps, eval_eps): if (mode == "train") or (mode == "eval"): callbacks = [ functools.partial( - process_episode, config, logger, mode, train_eps, eval_eps + ProcessEpisodeWrap.process_episode, + config, + logger, + mode, + train_eps, + eval_eps, ) ] env = wrappers.CollectDataset(env, callbacks) @@ -213,31 +221,51 @@ def make_env(config, logger, mode, train_eps, eval_eps): return env -def process_episode(config, logger, mode, train_eps, eval_eps, episode): - directory = dict(train=config.traindir, eval=config.evaldir)[mode] - cache = dict(train=train_eps, eval=eval_eps)[mode] - filename = tools.save_episodes(directory, [episode])[0] - length = len(episode["reward"]) - 1 - score = float(episode["reward"].astype(np.float64).sum()) - video = episode["image"] - if mode == "eval": - cache.clear() - if mode == "train" and config.dataset_size: - total = 0 - for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): - if total <= config.dataset_size - length: - total += len(ep["reward"]) - 1 +class ProcessEpisodeWrap: + eval_scores = [] + eval_lengths = [] + + @classmethod + def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode): + directory = dict(train=config.traindir, eval=config.evaldir)[mode] + cache = dict(train=train_eps, eval=eval_eps)[mode] + # this saved episodes is given as train_eps or eval_eps from next call + filename = tools.save_episodes(directory, [episode])[0] + length = len(episode["reward"]) - 1 + score = float(episode["reward"].astype(np.float64).sum()) + video = episode["image"] + cache[str(filename)] = episode + if mode == "eval": + cls.eval_scores.append(score) + cls.eval_lengths.append(length) + # save when enought number of episodes are stored + if len(cls.eval_scores) < config.eval_episode_num: + return else: - del cache[key] - logger.scalar("dataset_size", total + length) - cache[str(filename)] = episode - print(f"{mode.title()} episode has {length} steps and return {score:.1f}.") - logger.scalar(f"{mode}_return", score) - logger.scalar(f"{mode}_length", length) - logger.scalar(f"{mode}_episodes", len(cache)) - if mode == "eval" or config.expl_gifs: - logger.video(f"{mode}_policy", video[None]) - logger.write() + score = sum(cls.eval_scores) / len(cls.eval_scores) + length = sum(cls.eval_lengths) / len(cls.eval_lengths) + episode_num = len(cls.eval_scores) + cls.eval_scores = [] + cls.eval_lengths = [] + + if mode == "train" and config.dataset_size: + total = 0 + for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): + if total <= config.dataset_size - length: + total += len(ep["reward"]) - 1 + else: + del cache[key] + logger.scalar("dataset_size", total + length) + print(f"{mode.title()} episode has {length} steps and return {score:.1f}.") + logger.scalar(f"{mode}_return", score) + logger.scalar(f"{mode}_length", length) + logger.scalar( + f"{mode}_episodes", len(cache) if mode == "train" else episode_num + ) + if mode == "eval" or config.expl_gifs: + # only last video in eval videos is preservad + logger.video(f"{mode}_policy", video[None]) + logger.write() def main(config): @@ -315,7 +343,7 @@ def main(config): video_pred = agent._wm.video_pred(next(eval_dataset)) logger.video("eval_openl", to_np(video_pred)) eval_policy = functools.partial(agent, training=False) - tools.simulate(eval_policy, eval_envs, episodes=1) + tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num) print("Start training.") state = tools.simulate(agent, train_envs, config.eval_every, state=state) torch.save(agent.state_dict(), logdir / "latest_model.pt") diff --git a/models.py b/models.py index 7489a17..7460130 100644 --- a/models.py +++ b/models.py @@ -10,30 +10,22 @@ import tools to_np = lambda x: x.detach().cpu().numpy() -def symlog(x): - return torch.sign(x) * torch.log(torch.abs(x) + 1.0) - - -def symexp(x): - return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) - - class RewardEMA(object): """running mean and std""" def __init__(self, device, alpha=1e-2): self.device = device - self.scale = torch.zeros((1,)).to(device) + self.values = torch.zeros((2,)).to(device) self.alpha = alpha self.range = torch.tensor([0.05, 0.95]).to(device) def __call__(self, x): flat_x = torch.flatten(x.detach()) x_quantile = torch.quantile(input=flat_x, q=self.range) - scale = x_quantile[1] - x_quantile[0] - new_scale = self.alpha * scale + (1 - self.alpha) * self.scale - self.scale = new_scale - return x / torch.clip(self.scale, min=1.0) + self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values + scale = torch.clip(self.values[1] - self.values[0], min=1.0) + offset = self.values[0] + return offset.detach(), scale.detach() class WorldModel(nn.Module): @@ -93,7 +85,7 @@ class WorldModel(nn.Module): shape, config.decoder_kernels, ) - if config.reward_head == "twohot": + if config.reward_head == "twohot_symlog": self.heads["reward"] = networks.DenseHead( feat_size, # pytorch version (255,), @@ -102,6 +94,7 @@ class WorldModel(nn.Module): config.act, config.norm, dist=config.reward_head, + outscale=0.0, ) else: self.heads["reward"] = networks.DenseHead( @@ -112,9 +105,8 @@ class WorldModel(nn.Module): config.act, config.norm, dist=config.reward_head, + outscale=0.0, ) - # added this - self.heads["reward"].apply(tools.weight_init) if config.pred_discount: self.heads["discount"] = networks.DenseHead( feat_size, # pytorch version @@ -163,8 +155,6 @@ class WorldModel(nn.Module): feat = self.dynamics.get_feat(post) feat = feat if grad_head else feat.detach() pred = head(feat) - # if name == 'image': - # losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum') like = pred.log_prob(data[name]) likes[name] = like losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) @@ -196,24 +186,9 @@ class WorldModel(nn.Module): def preprocess(self, obs): obs = obs.copy() - if self._config.obs_trans == "normalize": - obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5 - elif self._config.obs_trans == "identity": - obs["image"] = torch.Tensor(obs["image"]) - elif self._config.obs_trans == "symlog": - obs["image"] = symlog(torch.Tensor(obs["image"])) - else: - raise NotImplemented(f"{self._config.reward_trans} is not implemented") - if self._config.reward_trans == "tanh": - # (batch_size, batch_length) -> (batch_size, batch_length, 1) - obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1) - elif self._config.reward_trans == "identity": - # (batch_size, batch_length) -> (batch_size, batch_length, 1) - obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1) - elif self._config.reward_trans == "symlog": - obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1) - else: - raise NotImplemented(f"{self._config.reward_trans} is not implemented") + obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5 + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1) if "discount" in obs: obs["discount"] *= self._config.discount # (batch_size, batch_length) -> (batch_size, batch_length, 1) @@ -234,13 +209,9 @@ class WorldModel(nn.Module): reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode() # observed image is given until 5 steps model = torch.cat([recon[:, :5], openl], 1) - if self._config.obs_trans == "normalize": - truth = data["image"][:6] + 0.5 - model += 0.5 - elif self._config.obs_trans == "symlog": - truth = symexp(data["image"][:6]) / 255.0 - model = symexp(model) / 255.0 - error = (model - truth + 1) / 2 + truth = data["image"][:6] + 0.5 + model = model + 0.5 + error = (model - truth + 1.0) / 2.0 return torch.cat([truth, model, error], 2) @@ -267,11 +238,11 @@ class ImagBehavior(nn.Module): config.actor_dist, config.actor_init_std, config.actor_min_std, - config.actor_dist, + config.actor_max_std, config.actor_temp, - config.actor_outscale, + outscale=1.0, ) # action_dist -> action_disc? - if config.value_head == "twohot": + if config.value_head == "twohot_symlog": self.value = networks.DenseHead( feat_size, # pytorch version (255,), @@ -280,6 +251,7 @@ class ImagBehavior(nn.Module): config.act, config.norm, config.value_head, + outscale=0.0, ) else: self.value = networks.DenseHead( @@ -290,9 +262,9 @@ class ImagBehavior(nn.Module): config.act, config.norm, config.value_head, + outscale=0.0, ) - self.value.apply(tools.weight_init) - if config.slow_value_target or config.slow_actor_target: + if config.slow_value_target: self._slow_value = copy.deepcopy(self.value) self._updates = 0 kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) @@ -335,21 +307,12 @@ class ImagBehavior(nn.Module): start, self.actor, self._config.imag_horizon, repeats ) reward = objective(imag_feat, imag_state, imag_action) - if self._config.reward_trans == "symlog": - # rescale predicted reward by head['reward'] - reward = symexp(reward) actor_ent = self.actor(imag_feat).entropy() state_ent = self._world_model.dynamics.get_dist(imag_state).entropy() # this target is not scaled # slow is flag to indicate whether slow_target is used for lambda-return - target, weights = self._compute_target( - imag_feat, - imag_state, - imag_action, - reward, - actor_ent, - state_ent, - self._config.slow_actor_target, + target, weights, base = self._compute_target( + imag_feat, imag_state, imag_action, reward, actor_ent, state_ent ) actor_loss, mets = self._compute_actor_loss( imag_feat, @@ -359,42 +322,31 @@ class ImagBehavior(nn.Module): actor_ent, state_ent, weights, + base, ) metrics.update(mets) - if self._config.slow_value_target != self._config.slow_actor_target: - target, weights = self._compute_target( - imag_feat, - imag_state, - imag_action, - reward, - actor_ent, - state_ent, - self._config.slow_value_target, - ) value_input = imag_feat with tools.RequiresGrad(self.value): with torch.cuda.amp.autocast(self._use_amp): value = self.value(value_input[:-1].detach()) target = torch.stack(target, dim=1) - # only critic target is processed using symlog(not actor) - if self._config.critic_trans == "symlog": - metrics["unscaled_target_mean"] = to_np(torch.mean(target)) - target = symlog(target) # (time, batch, 1), (time, batch, 1) -> (time, batch) value_loss = -value.log_prob(target.detach()) + slow_target = self._slow_value(value_input[:-1].detach()) + if self._config.slow_value_target: + value_loss = value_loss - value.log_prob( + slow_target.mode().detach() + ) if self._config.value_decay: value_loss += self._config.value_decay * value.mode() # (time, batch, 1), (time, batch, 1) -> (1,) value_loss = torch.mean(weights[:-1] * value_loss[:, :, None]) - metrics["value_mean"] = to_np(torch.mean(value.mode())) - metrics["value_max"] = to_np(torch.max(value.mode())) - metrics["value_min"] = to_np(torch.min(value.mode())) - metrics["value_std"] = to_np(torch.std(value.mode())) - metrics["target_mean"] = to_np(torch.mean(target)) - metrics["reward_mean"] = to_np(torch.mean(reward)) - metrics["reward_std"] = to_np(torch.std(reward)) + metrics.update(tools.tensorstats(value.mode(), "value")) + metrics.update(tools.tensorstats(target, "target")) + metrics.update(tools.tensorstats(reward, "imag_reward")) + metrics.update(tools.tensorstats(imag_action, "imag_action")) metrics["actor_ent"] = to_np(torch.mean(actor_ent)) with tools.RequiresGrad(self): metrics.update(self._actor_opt(actor_loss, self.actor.parameters())) @@ -402,6 +354,11 @@ class ImagBehavior(nn.Module): return imag_feat, imag_state, imag_action, weights, metrics def _imagine(self, start, policy, horizon, repeats=None): + # horizon: 15 + # start = dict(stoch, deter, logit) + # start["stoch"] (16, 63, 32, 32) + # start["deter"] (16, 63, 512) + # start["logit"] (16, 63, 32, 32) dynamics = self._world_model.dynamics if repeats: raise NotImplemented("repeats is not implemented in this version") @@ -418,6 +375,8 @@ class ImagBehavior(nn.Module): feat = 0 * dynamics.get_feat(start) action = policy(feat).mode() + # Is this action deterministic or stochastic? + # action = policy(feat).sample() succ, feats, actions = tools.static_scan( step, [torch.arange(horizon)], (start, feat, action) ) @@ -428,7 +387,7 @@ class ImagBehavior(nn.Module): return feats, states, actions def _compute_target( - self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow + self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent ): if "discount" in self._world_model.heads: inp = self._world_model.dynamics.get_feat(imag_state) @@ -439,13 +398,10 @@ class ImagBehavior(nn.Module): reward += self._config.actor_entropy() * actor_ent if self._config.future_entropy and self._config.actor_state_entropy() > 0: reward += self._config.actor_state_entropy() * state_ent - if slow: - value = self._slow_value(imag_feat).mode() - else: - value = self.value(imag_feat).mode() - if self._config.critic_trans == "symlog": - # After adding this line there is issue - value = symexp(value) + value = self.value(imag_feat).mode() + # value(15, 960, ch) + # action(15, 960, ch) + # discount(15, 960, ch) target = tools.lambda_return( reward[:-1], value[:-1], @@ -457,10 +413,18 @@ class ImagBehavior(nn.Module): weights = torch.cumprod( torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0 ).detach() - return target, weights + return target, weights, value[:-1] def _compute_actor_loss( - self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights + self, + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + base, ): metrics = {} inp = imag_feat.detach() if self._stop_grad_actor else imag_feat @@ -469,11 +433,17 @@ class ImagBehavior(nn.Module): # Q-val for actor is not transformed using symlog target = torch.stack(target, dim=1) if self._config.reward_EMA: - target = self.reward_ema(target) - metrics["EMA_scale"] = to_np(self.reward_ema.scale) + offset, scale = self.reward_ema(target) + normed_target = (target - offset) / scale + normed_base = (base - offset) / scale + adv = normed_target - normed_base + metrics.update(tools.tensorstats(normed_target, "normed_target")) + values = self.reward_ema.values + metrics["EMA_005"] = to_np(values[0]) + metrics["EMA_095"] = to_np(values[1]) if self._config.imag_gradient == "dynamics": - actor_target = target + actor_target = adv elif self._config.imag_gradient == "reinforce": actor_target = ( policy.log_prob(imag_action)[:-1][:, :, None] @@ -501,7 +471,7 @@ class ImagBehavior(nn.Module): return actor_loss, metrics def _update_slow_target(self): - if self._config.slow_value_target or self._config.slow_actor_target: + if self._config.slow_value_target: if self._updates % self._config.slow_target_update == 0: mix = self._config.slow_target_fraction for s, d in zip(self.value.parameters(), self._slow_value.parameters()): diff --git a/networks.py b/networks.py index a9cb0d7..9e2ae7b 100644 --- a/networks.py +++ b/networks.py @@ -59,29 +59,33 @@ class RSSM(nn.Module): if self._shared: inp_dim += self._embed for i in range(self._layers_input): - inp_layers.append(nn.Linear(inp_dim, self._hidden)) - inp_layers.append(self._norm(self._hidden)) + inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + inp_layers.append(self._norm(self._hidden, eps=1e-03)) inp_layers.append(self._act()) if i == 0: inp_dim = self._hidden self._inp_layers = nn.Sequential(*inp_layers) + self._inp_layers.apply(tools.weight_init) if cell == "gru": self._cell = GRUCell(self._hidden, self._deter) + self._cell.apply(tools.weight_init) elif cell == "gru_layer_norm": self._cell = GRUCell(self._hidden, self._deter, norm=True) + self._cell.apply(tools.weight_init) else: raise NotImplementedError(cell) img_out_layers = [] inp_dim = self._deter for i in range(self._layers_output): - img_out_layers.append(nn.Linear(inp_dim, self._hidden)) - img_out_layers.append(self._norm(self._hidden)) + img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + img_out_layers.append(self._norm(self._hidden, eps=1e-03)) img_out_layers.append(self._act()) if i == 0: inp_dim = self._hidden self._img_out_layers = nn.Sequential(*img_out_layers) + self._img_out_layers.apply(tools.weight_init) obs_out_layers = [] if self._temp_post: @@ -89,19 +93,24 @@ class RSSM(nn.Module): else: inp_dim = self._embed for i in range(self._layers_output): - obs_out_layers.append(nn.Linear(inp_dim, self._hidden)) - obs_out_layers.append(self._norm(self._hidden)) + obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + obs_out_layers.append(self._norm(self._hidden, eps=1e-03)) obs_out_layers.append(self._act()) if i == 0: inp_dim = self._hidden self._obs_out_layers = nn.Sequential(*obs_out_layers) + self._obs_out_layers.apply(tools.weight_init) if self._discrete: self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._ims_stat_layer.apply(tools.weight_init) self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._obs_stat_layer.apply(tools.weight_init) else: self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._ims_stat_layer.apply(tools.weight_init) self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._obs_stat_layer.apply(tools.weight_init) def initial(self, batch_size): deter = torch.zeros(batch_size, self._deter).to(self._device) @@ -178,6 +187,7 @@ class RSSM(nn.Module): def obs_step(self, prev_state, prev_action, embed, sample=True): # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() prior = self.img_step(prev_state, prev_action, None, sample) if self._shared: post = self.img_step(prev_state, prev_action, embed, sample) @@ -200,6 +210,7 @@ class RSSM(nn.Module): # this is used for making future image def img_step(self, prev_state, prev_action, embed=None, sample=True): # (batch, stoch, discrete_num) + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() prev_stoch = prev_state["stoch"] if self._discrete: shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] @@ -317,12 +328,15 @@ class ConvEncoder(nn.Module): out_channels=depth, kernel_size=(kernel, kernel), stride=(2, 2), + bias=False, ) ) - h, w = h // 2, w // 2 - # layers.append(norm([depth, h, w])) + layers.append(ChLayerNorm(depth)) layers.append(act()) + h, w = h // 2, w // 2 + self.layers = nn.Sequential(*layers) + self.layers.apply(tools.weight_init) def __call__(self, obs): x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:])) @@ -343,6 +357,7 @@ class ConvDecoder(nn.Module): norm=nn.LayerNorm, shape=(3, 64, 64), kernels=(3, 3, 3, 3), + outscale=1.0, ): super(ConvDecoder, self).__init__() self._inp_depth = inp_depth @@ -358,19 +373,25 @@ class ConvDecoder(nn.Module): self._linear_layer = nn.Linear(inp_depth, self._embed_size) inp_dim = self._embed_size // 16 - cnnt_layers = [] + layers = [] h, w = 4, 4 for i, kernel in enumerate(self._kernels): depth = self._embed_size // 16 // (2 ** (i + 1)) act = self._act + bias = False + initializer = tools.weight_init if i == len(self._kernels) - 1: depth = self._shape[0] - act = None + act = False + bias = True + norm = False + initializer = tools.uniform_weight_init(outscale) + if i != 0: inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth - pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1) - pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1) - cnnt_layers.append( + pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1) + pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1) + layers.append( nn.ConvTranspose2d( inp_dim, depth, @@ -378,26 +399,32 @@ class ConvDecoder(nn.Module): 2, padding=(pad_h, pad_w), output_padding=(outpad_h, outpad_w), + bias=bias, ) ) + if norm: + layers.append(ChLayerNorm(depth)) + if act: + layers.append(act()) + [m.apply(initializer) for m in layers[-3:]] h, w = h * 2, w * 2 - # cnnt_layers.append(norm([depth, h, w])) - if act is not None: - cnnt_layers.append(act()) - self._cnnt_layers = nn.Sequential(*cnnt_layers) + + self.layers = nn.Sequential(*layers) + + def calc_same_pad(self, k, s, d): + val = d * (k - 1) - s + 1 + pad = math.ceil(val / 2) + outpad = pad * 2 - val + return pad, outpad def __call__(self, features, dtype=None): x = self._linear_layer(features) x = x.reshape([-1, 4, 4, self._embed_size // 16]) x = x.permute(0, 3, 1, 2) - x = self._cnnt_layers(x) + x = self.layers(x) mean = x.reshape(features.shape[:-1] + self._shape) mean = mean.permute(0, 1, 3, 4, 2) - return tools.ContDist( - torchd.independent.Independent( - torchd.normal.Normal(mean, 1), len(self._shape) - ) - ) + return tools.SymlogDist(mean) class DenseHead(nn.Module): @@ -411,7 +438,7 @@ class DenseHead(nn.Module): norm=nn.LayerNorm, dist="normal", std=1.0, - unimix_ratio=0.0, + outscale=1.0, ): super(DenseHead, self).__init__() self._shape = (shape,) if isinstance(shape, int) else shape @@ -423,27 +450,30 @@ class DenseHead(nn.Module): self._norm = norm self._dist = dist self._std = std - self._unimix_ratio = unimix_ratio - mean_layers = [] + layers = [] for index in range(self._layers): - mean_layers.append(nn.Linear(inp_dim, self._units)) - mean_layers.append(norm(self._units)) - mean_layers.append(act()) + layers.append(nn.Linear(inp_dim, self._units, bias=False)) + layers.append(norm(self._units, eps=1e-03)) + layers.append(act()) if index == 0: inp_dim = self._units - mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape))) - self._mean_layers = nn.Sequential(*mean_layers) + self.layers = nn.Sequential(*layers) + self.layers.apply(tools.weight_init) + + self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape)) + self.mean_layer.apply(tools.uniform_weight_init(outscale)) if self._std == "learned": - self._std_layer = nn.Linear(self._units, np.prod(self._shape)) + self.std_layer = nn.Linear(self._units, np.prod(self._shape)) + self.std_layer.apply(tools.uniform_weight_init(outscale)) def __call__(self, features, dtype=None): x = features - mean = self._mean_layers(x) + out = self.layers(x) + mean = self.mean_layer(out) if self._std == "learned": - std = self._std_layer(x) - std = torch.softplus(std) + 0.01 + std = self.std_layer(out) else: std = self._std if self._dist == "normal": @@ -464,8 +494,8 @@ class DenseHead(nn.Module): torchd.bernoulli.Bernoulli(logits=mean), len(self._shape) ) ) - if self._dist == "twohot": - return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio) + if self._dist == "twohot_symlog": + return tools.TwoHotDistSymlog(logits=mean) raise NotImplementedError(self._dist) @@ -481,9 +511,9 @@ class ActionHead(nn.Module): dist="trunc_normal", init_std=0.0, min_std=0.1, - action_disc=5, + max_std=1.0, temp=0.1, - outscale=0, + outscale=1.0, ): super(ActionHead, self).__init__() self._size = size @@ -493,24 +523,27 @@ class ActionHead(nn.Module): self._act = act self._norm = norm self._min_std = min_std + self._max_std = max_std self._init_std = init_std - self._action_disc = action_disc self._temp = temp() if callable(temp) else temp - self._outscale = outscale pre_layers = [] for index in range(self._layers): - pre_layers.append(nn.Linear(inp_dim, self._units)) - pre_layers.append(norm(self._units)) + pre_layers.append(nn.Linear(inp_dim, self._units, bias=False)) + pre_layers.append(norm(self._units, eps=1e-03)) pre_layers.append(act()) if index == 0: inp_dim = self._units self._pre_layers = nn.Sequential(*pre_layers) + self._pre_layers.apply(tools.weight_init) if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: self._dist_layer = nn.Linear(self._units, 2 * self._size) + self._dist_layer.apply(tools.uniform_weight_init(outscale)) + elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: self._dist_layer = nn.Linear(self._units, self._size) + self._dist_layer.apply(tools.uniform_weight_init(outscale)) def __call__(self, features, dtype=None): x = features @@ -539,9 +572,11 @@ class ActionHead(nn.Module): dist = tools.SampleDist(dist) elif self._dist == "normal": x = self._dist_layer(x) - mean, std = torch.split(x, 2, -1) - std = F.softplus(std + self._init_std) + self._min_std - dist = torchd.normal.Normal(mean, std) + mean, std = torch.split(x, [self._size] * 2, -1) + std = (self._max_std - self._min_std) * torch.sigmoid( + std + 2.0 + ) + self._min_std + dist = torchd.normal.Normal(torch.tanh(mean), std) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "normal_1": x = self._dist_layer(x) @@ -574,9 +609,9 @@ class GRUCell(nn.Module): self._act = act self._norm = norm self._update_bias = update_bias - self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None) + self._layer = nn.Linear(inp_size + size, 3 * size, bias=False) if norm: - self._norm = nn.LayerNorm(3 * size) + self._norm = nn.LayerNorm(3 * size, eps=1e-03) @property def state_size(self): @@ -625,8 +660,13 @@ class Conv2dSame(torch.nn.Conv2d): return ret -def calc_same_pad(k, s, d): - val = d * (k - 1) - s + 1 - pad = math.ceil(val / 2) - outpad = pad * 2 - val - return pad, outpad +class ChLayerNorm(nn.Module): + def __init__(self, ch, eps=1e-03): + super(ChLayerNorm, self).__init__() + self.norm = torch.nn.LayerNorm(ch, eps=eps) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + return x diff --git a/tools.py b/tools.py index 13b351e..419af2c 100644 --- a/tools.py +++ b/tools.py @@ -17,6 +17,14 @@ from torch.utils.data import Dataset from torch.utils.tensorboard import SummaryWriter +to_np = lambda x: x.detach().cpu().numpy() + +def symlog(x): + return torch.sign(x) * torch.log(torch.abs(x) + 1.0) + +def symexp(x): + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) + class RequiresGrad: def __init__(self, model): @@ -269,11 +277,13 @@ class SampleDist: class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): def __init__(self, logits=None, probs=None, unimix_ratio=0.0): - if logits is not None and probs is None and unimix_ratio > 0.0: + if logits is not None and unimix_ratio > 0.0: probs = F.softmax(logits, dim=-1) probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] - logits = None - super().__init__(logits=logits, probs=probs) + logits = torch.log(probs) + super().__init__(logits=logits, probs=None) + else: + super().__init__(logits=logits, probs=probs) def mode(self): _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) @@ -290,42 +300,81 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): return sample -class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical): +class TwoHotDistSymlog(): - def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'): - if logits is not None and probs is None and unimix_ratio > 0.0: - probs = F.softmax(logits, dim=-1) - probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] - logits = None - super().__init__(logits=logits, probs=probs) - - self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device) + def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'): + self.logits = logits + self.probs = torch.softmax(logits, -1) + self.buckets = torch.linspace(low, high, steps=255).to(device) self.width = (self.buckets[-1] - self.buckets[0]) / 255 + def mean(self): + print("mean called") + _mode = self.probs * self.buckets + return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + def mode(self): - _mode = super().probs * self.buckets - return torch.sum(_mode, dim=-1, keepdim=True) + _mode = self.probs * self.buckets + return symexp(torch.sum(_mode, dim=-1, keepdim=True)) # Inside OneHotCategorical, log_prob is calculated using only max element in targets def log_prob(self, x): + x = symlog(x) # x(time, batch, 1) - x = (x - self.buckets[0]) / self.width - lower_indices = (x).to(torch.int64) - # lower_indices is idnside 0 ~ len(buckets)-2 - lower_indices = torch.clip(lower_indices, max=len(self.buckets)-2) - # upper_indices is inside 1 ~ len(buckets)-1 - upper_indices = lower_indices + 1 - lower_weight = torch.abs(x - upper_indices).squeeze(-1) - upper_weight = torch.abs(x - lower_indices).squeeze(-1) - # (time, batch, 1) -> (time, batch, bucket_class) - lower_log_prob = super().log_prob(F.one_hot(lower_indices.squeeze(-1), num_classes=len(self.buckets))) - upper_log_prob = super().log_prob(F.one_hot(upper_indices.squeeze(-1), num_classes=len(self.buckets))) + below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) -1 + above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1) + below = torch.clip(below, 0, len(self.buckets)-1) + above = torch.clip(above, 0, len(self.buckets)-1) + equal = (below == above) - # label = lower_log_prob * lower_weight + upper_log_prob * upper_weight - # # (time, batch, bucket_class) -> (time, batch) - # cross_entropy = torch.sum(torch.log(super().probs) * label, axis=-1) + dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) + dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) + total = dist_to_below + dist_to_above + weight_below = dist_to_above / total + weight_above = dist_to_below / total + target = ( + F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] + + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]) + log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True) + target = target.squeeze(-2) - return lower_weight * lower_log_prob + upper_weight * upper_log_prob + return (target * log_pred).sum(-1) + + def log_prob_target(self, target): + log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True) + return (target * log_pred).sum(-1) + +class SymlogDist(): + def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): + self._mode = mode + self._dist = dist + self._agg = agg + self._tol = tol + self._dim_to_reduce = dim_to_reduce + + def mode(self): + return symexp(self._mode) + + def mean(self): + return symexp(self._mode) + + def log_prob(self, value): + assert self._mode.shape == value.shape + if self._dist == 'mse': + distance = (self._mode - symlog(value)) ** 2.0 + distance = torch.where(distance < self._tol, 0, distance) + elif self._dist == 'abs': + distance = torch.abs(self._mode - symlog(value)) + distance = torch.where(distance < self._tol, 0, distance) + else: + raise NotImplementedError(self._dist) + if self._agg == 'mean': + loss = distance.mean(self._dim_to_reduce) + elif self._agg == 'sum': + loss = distance.sum(self._dim_to_reduce) + else: + raise NotImplementedError(self._agg) + return -loss class ContDist: @@ -438,6 +487,7 @@ def static_scan_for_lambda_return(fn, inputs, start): indices = reversed(indices) flag = True for index in indices: + # (inputs, pcont) -> (inputs[index], pcont[index]) inp = lambda x: (_input[x] for _input in inputs) last = fn(last, *inp(index)) if flag: @@ -446,6 +496,7 @@ def static_scan_for_lambda_return(fn, inputs, start): else: outputs = torch.cat([outputs, last], dim=-1) outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1]) + outputs = torch.flip(outputs, [1]) outputs = torch.unbind(outputs, dim=0) return outputs @@ -687,14 +738,53 @@ def schedule(string, step): def weight_init(m): if isinstance(m, nn.Linear): - nn.init.orthogonal_(m.weight.data) + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0) if hasattr(m.bias, 'data'): m.bias.data.fill_(0.0) elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - gain = nn.init.calculate_gain('relu') - nn.init.orthogonal_(m.weight.data, gain) + space = m.kernel_size[0] * m.kernel_size[1] + in_num = space * m.in_channels + out_num = space * m.out_channels + denoms = (in_num + out_num) / 2.0 + scale = 1.0 / denoms + std = np.sqrt(scale) / 0.87962566103423978 + nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=- 2.0, b=2.0) if hasattr(m.bias, 'data'): m.bias.data.fill_(0.0) elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) if hasattr(m.bias, 'data'): - m.bias.data.fill_(0.0) \ No newline at end of file + m.bias.data.fill_(0.0) + +def uniform_weight_init(given_scale): + def f(m): + if isinstance(m, nn.Linear): + in_num = m.in_features + out_num = m.out_features + denoms = (in_num + out_num) / 2.0 + scale = given_scale / denoms + limit = np.sqrt(3 * scale) + nn.init.uniform_(m.weight.data, a=-limit, b=limit) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + m.weight.data.fill_(1.0) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + return f + +def tensorstats(tensor, prefix=None): + metrics = { + 'mean': to_np(torch.mean(tensor)), + 'std': to_np(torch.std(tensor)), + 'min': to_np(torch.min(tensor)), + 'max': to_np(torch.max(tensor)), + } + if prefix: + metrics = {f'{prefix}_{k}': v for k, v in metrics.items()} + return metrics