diff --git a/models.py b/models.py
index 0ab27f1..a97fbc8 100644
--- a/models.py
+++ b/models.py
@@ -144,10 +144,14 @@ class WorldModel(nn.Module):
                         preds[name] = pred
                 losses = {}
                 for name, pred in preds.items():
-                    like = pred.log_prob(data[name])
-                    losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
-                model_loss = sum(losses.values()) + kl_loss
-            metrics = self._model_opt(model_loss, self.parameters())
+                    loss = -pred.log_prob(data[name])
+                    assert loss.shape == embed.shape[:2], (name, loss.shape)
+                    losses[name] = loss
+                scaled = {
+                    key: value * self._scales[key] for key, value in losses.items()
+                }
+                model_loss = sum(scaled.values()) + kl_loss
+            metrics = self._model_opt(torch.mean(model_loss), self.parameters())

         metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
         metrics["kl_free"] = kl_free
@@ -318,6 +322,8 @@ class ImagBehavior(nn.Module):
                     weights,
                     base,
                 )
+                actor_loss -= self._config.actor["entropy"] * actor_ent[:-1, ..., None]
+                actor_loss = torch.mean(actor_loss)
                 metrics.update(mets)
                 value_input = imag_feat

@@ -382,10 +388,6 @@ class ImagBehavior(nn.Module):
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
-        if self._config.future_entropy and self._config.actor_entropy > 0:
-            reward += self._config.actor_entropy * actor_ent
-        if self._config.future_entropy and self._config.actor_state_entropy > 0:
-            reward += self._config.actor_state_entropy * state_ent
         value = self.value(imag_feat).mode()
         target = tools.lambda_return(
             reward[1:],
@@ -444,14 +446,7 @@ class ImagBehavior(nn.Module):
             metrics["imag_gradient_mix"] = mix
         else:
             raise NotImplementedError(self._config.imag_gradient)
-        if not self._config.future_entropy and self._config.actor_entropy > 0:
-            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
-            actor_target += actor_entropy
-        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
-            state_entropy = self._config.actor_state_entropy * state_ent[:-1]
-            actor_target += state_entropy
-            metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
-        actor_loss = -torch.mean(weights[:-1] * actor_target)
+        actor_loss = -weights[:-1] * actor_target
         return actor_loss, metrics

     def _update_slow_target(self):
diff --git a/networks.py b/networks.py
index 5118ada..f43630c 100644
--- a/networks.py
+++ b/networks.py
@@ -327,8 +327,9 @@ class RSSM(nn.Module):
             dist(sg(post)) if self._discrete else dist(sg(post))._dist,
             dist(prior) if self._discrete else dist(prior)._dist,
         )
-        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
-        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        # the original repo implements this with maximum; gradients are not backpropagated below the free threshold either way.
+        rep_loss = torch.clip(rep_loss, min=free)
+        dyn_loss = torch.clip(dyn_loss, min=free)
         loss = dyn_scale * dyn_loss + rep_scale * rep_loss

         return loss, value, dyn_loss, rep_loss
diff --git a/tools.py b/tools.py
index 8265a57..b12c52e 100644
--- a/tools.py
+++ b/tools.py
@@ -338,7 +338,7 @@ def sample_episodes(episodes, length, seed=0):
             if not ret:
                 index = int(np_random.randint(0, total - 1))
                 ret = {
-                    k: v[index : min(index + length, total)]
+                    k: v[index : min(index + length, total)].copy()
                     for k, v in episode.items()
                     if "log_" not in k
                 }
@@ -350,7 +350,7 @@ def sample_episodes(episodes, length, seed=0):
                 possible = length - size
                 ret = {
                     k: np.append(
-                        ret[k], v[index : min(index + possible, total)], axis=0
+                        ret[k], v[index : min(index + possible, total)].copy(), axis=0
                     )
                     for k, v in episode.items()
                     if "log_" not in k
@@ -482,6 +482,7 @@ class DiscDist:
         above = len(self.buckets) - torch.sum(
             (self.buckets > x[..., None]).to(torch.int32), dim=-1
         )
+        # the original repo implements this with clip as well; gradients are not backpropagated for out-of-range values.
         below = torch.clip(below, 0, len(self.buckets) - 1)
         above = torch.clip(above, 0, len(self.buckets) - 1)
         equal = below == above
@@ -606,7 +607,7 @@ class Bernoulli:
         log_probs0 = -F.softplus(_logits)
         log_probs1 = -F.softplus(-_logits)

-        return log_probs0 * (1 - x) + log_probs1 * x
+        return torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)


 class UnnormalizedHuber(torchd.normal.Normal):
@@ -739,11 +740,12 @@ class Optimizer:
         }[opt]()
         self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

-    def __call__(self, loss, params, retain_graph=False):
+    def __call__(self, loss, params, retain_graph=True):
         assert len(loss.shape) == 0, loss.shape
         metrics = {}
         metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
-        self._scaler.scale(loss).backward()
+        self._opt.zero_grad()
+        self._scaler.scale(loss).backward(retain_graph=retain_graph)
         self._scaler.unscale_(self._opt)
         # loss.backward(retain_graph=retain_graph)
         norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
@@ -1001,11 +1003,9 @@ def recursively_collect_optim_state_dict(


 def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
-    print(optimizers_state_dicts)
     for path, state_dict in optimizers_state_dicts.items():
         keys = path.split(".")
         obj_now = obj
         for key in keys:
             obj_now = getattr(obj_now, key)
-        print(keys)
         obj_now.load_state_dict(state_dict)
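
A quick standalone check of the clip/maximum comments above (illustrative values, not part of the patch): with a free-bits threshold, torch.clip(x, min=free) and torch.maximum(x, free) yield the same gradients, since neither backpropagates through values below the threshold.

    import torch

    free = 1.0  # illustrative free-bits threshold
    x = torch.tensor([0.5, 2.0], requires_grad=True)

    # clip variant (this patch): zero gradient where x < free
    torch.clip(x, min=free).sum().backward()
    print(x.grad)  # tensor([0., 1.])

    x.grad = None

    # maximum variant (original repo): same gradient pattern
    torch.maximum(x, torch.tensor(free)).sum().backward()
    print(x.grad)  # tensor([0., 1.])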