From 628b856c63c176aec69e2d0b1de030f19e02a42a Mon Sep 17 00:00:00 2001
From: NM512
Date: Sat, 22 Apr 2023 09:34:23 +0900
Subject: [PATCH] changed the discount head to predict terminal

---
 configs.yaml | 14 ++++++--------
 dreamer.py   |  5 -----
 models.py    | 55 +++++++++++++++++++++++++++++++++----------------------
 networks.py  | 26 +++++++++++---------------
 4 files changed, 50 insertions(+), 50 deletions(-)

diff --git a/configs.yaml b/configs.yaml
index 7d3106d..30943cb 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -42,10 +42,10 @@ defaults:
   dyn_std_act: 'sigmoid2'
   dyn_min_std: 0.1
   dyn_temp_post: True
-  grad_heads: ['image', 'reward', 'discount']
+  grad_heads: ['image', 'reward', 'cont']
   units: 512
   reward_layers: 2
-  discount_layers: 2
+  cont_layers: 2
   value_layers: 2
   actor_layers: 2
   act: 'SiLU'
@@ -55,12 +55,10 @@ defaults:
   decoder_kernels: [4, 4, 4, 4]
   value_head: 'twohot_symlog'
   reward_head: 'twohot_symlog'
-  kl_lscale: '0.1'
-  kl_rscale: '0.5'
+  dyn_scale: '0.5'
+  rep_scale: '0.1'
   kl_free: '1.0'
-  kl_forward: False
-  pred_discount: True
-  discount_scale: 1.0
+  cont_scale: 1.0
   reward_scale: 1.0
   weight_decay: 0.0
   unimix_ratio: 0.01
@@ -80,7 +78,7 @@ defaults:
   value_grad_clip: 100
   actor_grad_clip: 100
   dataset_size: 1000000
-  oversample_ends: False
+  oversample_ends: True
   slow_value_target: True
   slow_target_update: 1
   slow_target_fraction: 0.02
diff --git a/dreamer.py b/dreamer.py
index 796051b..5681db1 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -155,16 +155,11 @@ class Dreamer(nn.Module):
         metrics.update(mets)
         start = post
         # start['deter'] (16, 64, 512)
-        if self._config.pred_discount:  # Last step could be terminal.
-            start = {k: v[:, :-1] for k, v in post.items()}
-            context = {k: v[:, :-1] for k, v in context.items()}
         reward = lambda f, s, a: self._wm.heads["reward"](
             self._wm.dynamics.get_feat(s)
         ).mode()
         metrics.update(self._task_behavior._train(start, reward)[-1])
         if self._config.expl_behavior != "greedy":
-            if self._config.pred_discount:
-                data = {k: v[:, :-1] for k, v in data.items()}
             mets = self._expl_behavior.train(start, context, data)[-1]
             metrics.update({"expl_" + key: value for key, value in mets.items()})
         for name, value in metrics.items():
diff --git a/models.py b/models.py
index 3455f5a..52a640e 100644
--- a/models.py
+++ b/models.py
@@ -107,16 +107,15 @@ class WorldModel(nn.Module):
             dist=config.reward_head,
             outscale=0.0,
         )
-        if config.pred_discount:
-            self.heads["discount"] = networks.DenseHead(
-                feat_size,  # pytorch version
-                [],
-                config.discount_layers,
-                config.units,
-                config.act,
-                config.norm,
-                dist="binary",
-            )
+        self.heads["cont"] = networks.DenseHead(
+            feat_size,  # pytorch version
+            [],
+            config.cont_layers,
+            config.units,
+            config.act,
+            config.norm,
+            dist="binary",
+        )
         for name in config.grad_heads:
             assert name in self.heads, name
         self._model_opt = tools.Optimizer(
@@ -129,7 +128,7 @@ class WorldModel(nn.Module):
             opt=config.opt,
             use_amp=self._use_amp,
         )
-        self._scales = dict(reward=config.reward_scale, discount=config.discount_scale)
+        self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
 
     def _train(self, data):
         # action (batch_size, batch_length, act_dim)
@@ -143,10 +142,10 @@ class WorldModel(nn.Module):
                 embed = self.encoder(data)
                 post, prior = self.dynamics.observe(embed, data["action"])
                 kl_free = tools.schedule(self._config.kl_free, self._step)
-                kl_lscale = tools.schedule(self._config.kl_lscale, self._step)
-                kl_rscale = tools.schedule(self._config.kl_rscale, self._step)
-                kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
-                    post, prior, self._config.kl_forward, kl_free, kl_lscale, kl_rscale
+                dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
+                rep_scale = tools.schedule(self._config.rep_scale, self._step)
+                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
+                    post, prior, kl_free, dyn_scale, rep_scale
                 )
                 losses = {}
                 likes = {}
@@ -163,10 +162,10 @@ class WorldModel(nn.Module):
 
         metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
         metrics["kl_free"] = kl_free
-        metrics["kl_lscale"] = kl_lscale
-        metrics["kl_rscale"] = kl_rscale
-        metrics["loss_lhs"] = to_np(loss_lhs)
-        metrics["loss_rhs"] = to_np(loss_rhs)
+        metrics["dyn_scale"] = dyn_scale
+        metrics["rep_scale"] = rep_scale
+        metrics["dyn_loss"] = to_np(dyn_loss)
+        metrics["rep_loss"] = to_np(rep_loss)
         metrics["kl"] = to_np(torch.mean(kl_value))
         with torch.cuda.amp.autocast(self._use_amp):
             metrics["prior_ent"] = to_np(
@@ -193,6 +192,11 @@ class WorldModel(nn.Module):
             obs["discount"] *= self._config.discount
             # (batch_size, batch_length) -> (batch_size, batch_length, 1)
             obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
+        if "is_terminal" in obs:
+            # this label is necessary to train cont_head
+            obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
+        else:
+            raise ValueError('"is_terminal" was not found in observation.')
         obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
         return obs
 
@@ -347,7 +351,14 @@ class ImagBehavior(nn.Module):
         metrics.update(tools.tensorstats(value.mode(), "value"))
         metrics.update(tools.tensorstats(target, "target"))
         metrics.update(tools.tensorstats(reward, "imag_reward"))
-        metrics.update(tools.tensorstats(imag_action, "imag_action"))
+        if self._config.actor_dist in ["onehot"]:
+            metrics.update(
+                tools.tensorstats(
+                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
+                )
+            )
+        else:
+            metrics.update(tools.tensorstats(imag_action, "imag_action"))
         metrics["actor_ent"] = to_np(torch.mean(actor_ent))
         with tools.RequiresGrad(self):
             metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
@@ -390,9 +401,9 @@ class ImagBehavior(nn.Module):
     def _compute_target(
         self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
     ):
-        if "discount" in self._world_model.heads:
+        if "cont" in self._world_model.heads:
             inp = self._world_model.dynamics.get_feat(imag_state)
-            discount = self._world_model.heads["discount"](inp).mean
+            discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
         if self._config.future_entropy and self._config.actor_entropy() > 0:
diff --git a/networks.py b/networks.py
index 8f360d1..3171afa 100644
--- a/networks.py
+++ b/networks.py
@@ -273,28 +273,24 @@ class RSSM(nn.Module):
             std = std + self._min_std
         return {"mean": mean, "std": std}
 
-    def kl_loss(self, post, prior, forward, free, lscale, rscale):
+    def kl_loss(self, post, prior, free, dyn_scale, rep_scale):
         kld = torchd.kl.kl_divergence
         dist = lambda x: self.get_dist(x)
         sg = lambda x: {k: v.detach() for k, v in x.items()}
-        # forward == false -> (post, prior)
-        lhs, rhs = (prior, post) if forward else (post, prior)
-        # forward == false -> Lrep
-        value_lhs = value = kld(
-            dist(lhs) if self._discrete else dist(lhs)._dist,
-            dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
+        rep_loss = value = kld(
+            dist(post) if self._discrete else dist(post)._dist,
+            dist(sg(prior)) if self._discrete else dist(sg(prior))._dist,
         )
-        # forward == false -> Ldyn
-        value_rhs = kld(
-            dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
-            dist(rhs) if self._discrete else dist(rhs)._dist,
+        dyn_loss = kld(
+            dist(sg(post)) if self._discrete else dist(sg(post))._dist,
+            dist(prior) if self._discrete else dist(prior)._dist,
         )
-        loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
-        loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
-        loss = lscale * loss_lhs + rscale * loss_rhs
+        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
+        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        loss = dyn_scale * dyn_loss + rep_scale * rep_loss
 
-        return loss, value, loss_lhs, loss_rhs
+        return loss, value, dyn_loss, rep_loss
 
 
 class ConvEncoder(nn.Module):
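
Note (reviewer sketch, not part of the patch): the two behavioural changes above are (1) the world model now trains a Bernoulli "cont" head on the label 1 - is_terminal, and imagination rescales its predicted continuation probability by the fixed discount factor instead of reading a learned "discount" head, and (2) the KL term is split into a dynamics loss (posterior detached) and a representation loss (prior detached), each element-wise clipped at the free-bits threshold before weighting by dyn_scale/rep_scale (0.5 and 0.1 in the new config). The standalone Python sketch below illustrates both computations on toy tensors; the names gamma, cont_logits and split_kl are illustrative and do not exist in the repository.

    # Standalone sketch with toy tensors; illustrative names only.
    import torch
    import torch.distributions as torchd

    # --- continuation ("cont") head ---------------------------------------
    # The head is trained as a Bernoulli classifier on cont = 1 - is_terminal.
    is_terminal = torch.tensor([0.0, 0.0, 1.0])     # last step ends the episode
    cont_target = 1.0 - is_terminal                 # label built in preprocess()
    cont_logits = torch.tensor([3.0, 2.5, -2.0])    # stand-in for the head's raw output
    cont_dist = torchd.Bernoulli(logits=cont_logits)
    cont_loss = -cont_dist.log_prob(cont_target).mean()

    # During imagination the predicted continuation probability is rescaled
    # by the fixed discount factor, replacing the old learned discount head.
    gamma = 0.997
    discount = gamma * cont_dist.probs              # weights for the lambda-returns

    # --- split KL with free bits ------------------------------------------
    def split_kl(post, prior, free=1.0, dyn_scale=0.5, rep_scale=0.1):
        """Dynamics term trains the prior (posterior detached); representation
        term trains the posterior (prior detached); both clipped at `free`."""
        sg = lambda d: torchd.Normal(d.mean.detach(), d.stddev.detach())
        dyn = torchd.kl.kl_divergence(sg(post), prior)
        rep = torchd.kl.kl_divergence(post, sg(prior))
        dyn = torch.mean(torch.clip(dyn, min=free))
        rep = torch.mean(torch.clip(rep, min=free))
        return dyn_scale * dyn + rep_scale * rep

    post = torchd.Normal(torch.zeros(8), torch.ones(8))
    prior = torchd.Normal(0.5 * torch.ones(8), torch.ones(8))
    print(cont_loss.item(), discount, split_kl(post, prior).item())

As in the new kl_loss, clipping is applied per element before averaging, so dimensions whose KL already exceeds the free-bits threshold still receive gradient while the rest are ignored.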