modified loss calculation

NM512 2024-01-05 10:44:04 +09:00
parent e0487f8206
commit 78e86703f4
3 changed files with 21 additions and 25 deletions


@@ -144,10 +144,14 @@ class WorldModel(nn.Module):
             preds[name] = pred
         losses = {}
         for name, pred in preds.items():
-            like = pred.log_prob(data[name])
-            losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
-        model_loss = sum(losses.values()) + kl_loss
-        metrics = self._model_opt(model_loss, self.parameters())
+            loss = -pred.log_prob(data[name])
+            assert loss.shape == embed.shape[:2], (name, loss.shape)
+            losses[name] = loss
+        scaled = {
+            key: value * self._scales[key] for key, value in losses.items()
+        }
+        model_loss = sum(scaled.values()) + kl_loss
+        metrics = self._model_opt(torch.mean(model_loss), self.parameters())
         metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
         metrics["kl_free"] = kl_free
@@ -318,6 +322,8 @@ class ImagBehavior(nn.Module):
                 weights,
                 base,
             )
+            actor_loss -= self._config.actor["entropy"] * actor_ent[:-1, ..., None]
+            actor_loss = torch.mean(actor_loss)
             metrics.update(mets)
             value_input = imag_feat
@@ -382,10 +388,6 @@ class ImagBehavior(nn.Module):
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
-        if self._config.future_entropy and self._config.actor_entropy > 0:
-            reward += self._config.actor_entropy * actor_ent
-        if self._config.future_entropy and self._config.actor_state_entropy > 0:
-            reward += self._config.actor_state_entropy * state_ent
         value = self.value(imag_feat).mode()
         target = tools.lambda_return(
             reward[1:],
@@ -444,14 +446,7 @@ class ImagBehavior(nn.Module):
             metrics["imag_gradient_mix"] = mix
         else:
             raise NotImplementedError(self._config.imag_gradient)
-        if not self._config.future_entropy and self._config.actor_entropy > 0:
-            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
-            actor_target += actor_entropy
-        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
-            state_entropy = self._config.actor_state_entropy * state_ent[:-1]
-            actor_target += state_entropy
-            metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
-        actor_loss = -torch.mean(weights[:-1] * actor_target)
+        actor_loss = -weights[:-1] * actor_target
         return actor_loss, metrics

     def _update_slow_target(self):
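
Taken together, the three ImagBehavior hunks pull the entropy bonus out of the target computation: _compute_actor_loss now returns the unreduced -weights[:-1] * actor_target, and the caller subtracts the entropy term elementwise before a single final mean. A toy sketch with made-up tensor shapes:

    import torch

    # Hypothetical imagined-rollout quantities.
    horizon, batch = 15, 16
    actor_target = torch.rand(horizon, batch, 1)
    weights = torch.rand(horizon + 1, batch, 1)
    actor_ent = torch.rand(horizon + 1, batch)
    entropy_scale = 3e-4  # stand-in for config.actor["entropy"]

    # Unreduced actor loss, as _compute_actor_loss now returns it.
    actor_loss = -weights[:-1] * actor_target

    # The caller applies the entropy bonus elementwise, then reduces once.
    actor_loss -= entropy_scale * actor_ent[:-1, ..., None]
    actor_loss = torch.mean(actor_loss)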


@@ -327,8 +327,9 @@ class RSSM(nn.Module):
             dist(sg(post)) if self._discrete else dist(sg(post))._dist,
             dist(prior) if self._discrete else dist(prior)._dist,
         )
-        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
-        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        # the original repo implements this with maximum, since gradients are not backpropagated for out-of-limit values
+        rep_loss = torch.clip(rep_loss, min=free)
+        dyn_loss = torch.clip(dyn_loss, min=free)
         loss = dyn_scale * dyn_loss + rep_scale * rep_loss

         return loss, value, dyn_loss, rep_loss
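
Dropping the torch.mean keeps the free-bits clipping per element, and the new comment's point is that clip(min=free) and maximum(free, x) are interchangeable here: no gradient flows to elements below the floor either way. A quick check:

    import torch

    free = 1.0
    x = torch.tensor([0.3, 2.5], requires_grad=True)

    torch.clip(x, min=free).sum().backward()
    # The element below the free-bits floor gets zero gradient, exactly
    # as torch.maximum(x, torch.tensor(free)) would give.
    print(x.grad)  # tensor([0., 1.])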


@@ -338,7 +338,7 @@ def sample_episodes(episodes, length, seed=0):
         if not ret:
             index = int(np_random.randint(0, total - 1))
             ret = {
-                k: v[index : min(index + length, total)]
+                k: v[index : min(index + length, total)].copy()
                 for k, v in episode.items()
                 if "log_" not in k
             }
@@ -350,7 +350,7 @@ def sample_episodes(episodes, length, seed=0):
             possible = length - size
             ret = {
                 k: np.append(
-                    ret[k], v[index : min(index + possible, total)], axis=0
+                    ret[k], v[index : min(index + possible, total)].copy(), axis=0
                 )
                 for k, v in episode.items()
                 if "log_" not in k
@@ -482,6 +482,7 @@ class DiscDist:
         above = len(self.buckets) - torch.sum(
             (self.buckets > x[..., None]).to(torch.int32), dim=-1
         )
+        # the original repo also implements this with clip, since gradients are not backpropagated for out-of-range values
         below = torch.clip(below, 0, len(self.buckets) - 1)
         above = torch.clip(above, 0, len(self.buckets) - 1)
         equal = below == above
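
For context, below and above are the indices of the two buckets bracketing each target value, and the clip pins out-of-range targets to the edge buckets. A small sketch of the bucket search with made-up bucket edges:

    import torch

    buckets = torch.linspace(-2.0, 2.0, 5)  # [-2, -1, 0, 1, 2]
    x = torch.tensor([-3.0, 0.4, 9.0])      # includes out-of-range targets

    below = torch.sum((buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
    above = len(buckets) - torch.sum((buckets > x[..., None]).to(torch.int32), dim=-1)

    # Out-of-range targets produce -1 / len(buckets); clip pins them to the edges.
    below = torch.clip(below, 0, len(buckets) - 1)
    above = torch.clip(above, 0, len(buckets) - 1)
    print(below.tolist(), above.tolist())  # [0, 2, 4] [0, 3, 4]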
@@ -606,7 +607,7 @@ class Bernoulli:
         log_probs0 = -F.softplus(_logits)
         log_probs1 = -F.softplus(-_logits)
-        return log_probs0 * (1 - x) + log_probs1 * x
+        return torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)

 class UnnormalizedHuber(torchd.normal.Normal):
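
The new return sums per-dimension log-probabilities over the last axis, i.e. it treats that axis as independent Bernoulli components; the softplus identities are the numerically stable forms of log sigmoid(l) and log sigmoid(-l). A sketch checking them against torch.distributions:

    import torch
    import torch.nn.functional as F
    import torch.distributions as torchd

    logits = torch.randn(4, 3)
    x = torch.bernoulli(torch.sigmoid(logits))

    # log p(x=0) = -softplus(l), log p(x=1) = -softplus(-l)
    log_probs0 = -F.softplus(logits)
    log_probs1 = -F.softplus(-logits)
    ours = torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)

    ref = torchd.Independent(torchd.Bernoulli(logits=logits), 1).log_prob(x)
    assert torch.allclose(ours, ref, atol=1e-6)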
@@ -739,11 +740,12 @@ class Optimizer:
         }[opt]()
         self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

-    def __call__(self, loss, params, retain_graph=False):
+    def __call__(self, loss, params, retain_graph=True):
         assert len(loss.shape) == 0, loss.shape
         metrics = {}
         metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
-        self._scaler.scale(loss).backward()
+        self._opt.zero_grad()
+        self._scaler.scale(loss).backward(retain_graph=retain_graph)
         self._scaler.unscale_(self._opt)
         # loss.backward(retain_graph=retain_graph)
         norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
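
The reworked __call__ zeroes stale gradients before backward and threads retain_graph through to it. A sketch of the resulting AMP step order (scale, backward, unscale, clip, step) on a toy model; this stands in for the repo's Optimizer class rather than reproducing it:

    import torch

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so the sketch runs on CPU

    loss = model(torch.randn(8, 4)).square().mean()
    opt.zero_grad()                        # drop stale grads before backward
    scaler.scale(loss).backward(retain_graph=True)
    scaler.unscale_(opt)                   # so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
    scaler.step(opt)
    scaler.update()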
@@ -1001,11 +1003,9 @@ def recursively_collect_optim_state_dict(

 def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
-    print(optimizers_state_dicts)
     for path, state_dict in optimizers_state_dicts.items():
         keys = path.split(".")
         obj_now = obj
         for key in keys:
             obj_now = getattr(obj_now, key)
-        print(keys)
         obj_now.load_state_dict(state_dict)
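
With the debug prints gone, the function simply walks each dotted path attribute by attribute to the object owning the optimizer and restores its state. A tiny self-contained sketch of that walk, using hypothetical stand-in objects:

    class Opt:
        def load_state_dict(self, sd):
            print("restored", sd)

    class Node:
        pass

    obj = Node()
    obj.wm = Node()
    obj.wm.opt = Opt()

    state_dicts = {"wm.opt": {"step": 100}}
    for path, state_dict in state_dicts.items():
        obj_now = obj
        for key in path.split("."):
            obj_now = getattr(obj_now, key)  # walk "wm" then "opt"
        obj_now.load_state_dict(state_dict)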