modified loss calculation
commit 78e86703f4 (parent e0487f8206)
models.py (27 lines changed)
@@ -144,10 +144,14 @@ class WorldModel(nn.Module):
                         preds[name] = pred
                 losses = {}
                 for name, pred in preds.items():
-                    like = pred.log_prob(data[name])
-                    losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
-                model_loss = sum(losses.values()) + kl_loss
-            metrics = self._model_opt(model_loss, self.parameters())
+                    loss = -pred.log_prob(data[name])
+                    assert loss.shape == embed.shape[:2], (name, loss.shape)
+                    losses[name] = loss
+                scaled = {
+                    key: value * self._scales[key] for key, value in losses.items()
+                }
+                model_loss = sum(scaled.values()) + kl_loss
+            metrics = self._model_opt(torch.mean(model_loss), self.parameters())

         metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
         metrics["kl_free"] = kl_free
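To make the effect of this hunk concrete, here is a minimal sketch of the new aggregation; the head names, shapes, and scale values below are invented for illustration. Each head's loss is now kept as an unreduced [batch, time] tensor, scaled per head only when summed into the model loss, and reduced to a scalar just before the optimizer call.

    import torch

    # hypothetical per-head losses of shape [batch, time]
    losses = {"image": torch.rand(4, 16), "reward": torch.rand(4, 16)}
    scales = {"image": 1.0, "reward": 0.5}          # stands in for self._scales
    kl_loss = torch.rand(4, 16)                     # also left unreduced

    scaled = {key: value * scales[key] for key, value in losses.items()}
    model_loss = sum(scaled.values()) + kl_loss     # still [batch, time]
    optimizer_loss = torch.mean(model_loss)         # reduced only for the optimizer
    assert optimizer_loss.dim() == 0

Note that the scale lookup also changed from `self._scales.get(name, 1.0)` to `self._scales[key]`, so a head without an explicit scale now raises a KeyError instead of silently defaulting to 1.0.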
@@ -318,6 +322,8 @@ class ImagBehavior(nn.Module):
                 weights,
                 base,
             )
+            actor_loss -= self._config.actor["entropy"] * actor_ent[:-1, ..., None]
+            actor_loss = torch.mean(actor_loss)
             metrics.update(mets)
             value_input = imag_feat

@@ -382,10 +388,6 @@ class ImagBehavior(nn.Module):
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
-        if self._config.future_entropy and self._config.actor_entropy > 0:
-            reward += self._config.actor_entropy * actor_ent
-        if self._config.future_entropy and self._config.actor_state_entropy > 0:
-            reward += self._config.actor_state_entropy * state_ent
         value = self.value(imag_feat).mode()
         target = tools.lambda_return(
             reward[1:],
@@ -444,14 +446,7 @@ class ImagBehavior(nn.Module):
             metrics["imag_gradient_mix"] = mix
         else:
             raise NotImplementedError(self._config.imag_gradient)
-        if not self._config.future_entropy and self._config.actor_entropy > 0:
-            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
-            actor_target += actor_entropy
-        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
-            state_entropy = self._config.actor_state_entropy * state_ent[:-1]
-            actor_target += state_entropy
-            metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
-        actor_loss = -torch.mean(weights[:-1] * actor_target)
+        actor_loss = -weights[:-1] * actor_target
         return actor_loss, metrics

     def _update_slow_target(self):
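Taken together with the earlier ImagBehavior hunk (where the entropy term and the mean were added in the caller), this moves both the entropy bonus and the reduction out of `_compute_actor_loss`, which now returns an unreduced loss. A rough sketch under assumed shapes (horizon H, batch B; the entropy scale stands in for `self._config.actor["entropy"]`):

    import torch

    H, B = 15, 8                      # assumed imagination horizon and batch size
    weights = torch.rand(H, B, 1)
    actor_target = torch.rand(H - 1, B, 1)
    actor_ent = torch.rand(H, B)      # per-step policy entropy
    entropy_scale = 3e-4              # stands in for self._config.actor["entropy"]

    # returned by _compute_actor_loss: unreduced, shape [H - 1, B, 1]
    actor_loss = -weights[:-1] * actor_target
    # applied by the caller before the single mean reduction
    actor_loss -= entropy_scale * actor_ent[:-1, ..., None]
    actor_loss = torch.mean(actor_loss)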
@@ -327,8 +327,9 @@ class RSSM(nn.Module):
             dist(sg(post)) if self._discrete else dist(sg(post))._dist,
             dist(prior) if self._discrete else dist(prior)._dist,
         )
-        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
-        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        # this is implemented using maximum in the original repo; gradients are not backpropagated for out-of-limit values either way
+        rep_loss = torch.clip(rep_loss, min=free)
+        dyn_loss = torch.clip(dyn_loss, min=free)
         loss = dyn_scale * dyn_loss + rep_scale * rep_loss

         return loss, value, dyn_loss, rep_loss
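As the new comment notes, the reference implementation uses `maximum` where this code uses `clip`; both zero out gradients wherever the KL value falls below the free-bits threshold, so only the reduction (the removed `torch.mean`) actually changes. A quick check with made-up values:

    import torch

    free = 1.0
    kl = torch.tensor([0.2, 1.7], requires_grad=True)

    clipped = torch.clip(kl, min=free)                # used here
    maxed = torch.maximum(kl, torch.tensor(free))     # used in the original repo
    assert torch.equal(clipped, maxed)

    clipped.sum().backward()
    print(kl.grad)   # tensor([0., 1.]): no gradient below the free threshold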
tools.py (14 lines changed)
@@ -338,7 +338,7 @@ def sample_episodes(episodes, length, seed=0):
         if not ret:
             index = int(np_random.randint(0, total - 1))
             ret = {
-                k: v[index : min(index + length, total)]
+                k: v[index : min(index + length, total)].copy()
                 for k, v in episode.items()
                 if "log_" not in k
             }
@@ -350,7 +350,7 @@ def sample_episodes(episodes, length, seed=0):
             possible = length - size
             ret = {
                 k: np.append(
-                    ret[k], v[index : min(index + possible, total)], axis=0
+                    ret[k], v[index : min(index + possible, total)].copy(), axis=0
                 )
                 for k, v in episode.items()
                 if "log_" not in k
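Both `.copy()` calls in `sample_episodes` serve the same purpose: the sampled batch no longer aliases the stored episode arrays, so writes to one cannot leak into the other. A small illustration with an invented episode dict:

    import numpy as np

    episode = {"obs": np.arange(10, dtype=np.float32)}

    view = episode["obs"][2:5]          # plain slice: a view sharing memory
    chunk = episode["obs"][2:5].copy()  # what the diff switches to

    view[:] = -1.0                      # mutating the view corrupts the episode
    print(episode["obs"][2:5])          # [-1., -1., -1.]
    print(chunk)                        # [2., 3., 4.], unaffected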
@@ -482,6 +482,7 @@ class DiscDist:
         above = len(self.buckets) - torch.sum(
             (self.buckets > x[..., None]).to(torch.int32), dim=-1
         )
+        # this is implemented using clip in the original repo; gradients are not backpropagated for out-of-limit values either way
         below = torch.clip(below, 0, len(self.buckets) - 1)
         above = torch.clip(above, 0, len(self.buckets) - 1)
         equal = below == above
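For context, the clipping here keeps the bucket indices of the two-hot encoding valid when targets fall outside the bucket range. A toy version with made-up buckets and targets (the `below` formula is assumed to mirror the symmetric `above` computation shown in the context lines):

    import torch

    buckets = torch.linspace(-2.0, 2.0, 5)    # assumed buckets: [-2, -1, 0, 1, 2]
    x = torch.tensor([-3.0, 0.3, 5.0])        # targets, including out-of-range ones

    below = torch.sum((buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
    above = len(buckets) - torch.sum((buckets > x[..., None]).to(torch.int32), dim=-1)
    print(below, above)                        # tensor([-1, 2, 4]) tensor([0, 3, 5])

    below = torch.clip(below, 0, len(buckets) - 1)
    above = torch.clip(above, 0, len(buckets) - 1)
    print(below, above)                        # tensor([0, 2, 4]) tensor([0, 3, 4])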
@@ -606,7 +607,7 @@ class Bernoulli:
         log_probs0 = -F.softplus(_logits)
         log_probs1 = -F.softplus(-_logits)

-        return log_probs0 * (1 - x) + log_probs1 * x
+        return torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)


 class UnnormalizedHuber(torchd.normal.Normal):
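The `Bernoulli.log_prob` change sums the per-element log-likelihood over the trailing dimension, treating that axis as independent event dimensions. A quick equivalence check against `torch.distributions` with arbitrary logits:

    import torch
    import torch.nn.functional as F
    import torch.distributions as torchd

    logits = torch.randn(4, 3)
    x = torch.bernoulli(torch.sigmoid(logits))

    log_probs0 = -F.softplus(logits)
    log_probs1 = -F.softplus(-logits)
    lp = torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)

    ref = torchd.independent.Independent(
        torchd.bernoulli.Bernoulli(logits=logits), 1
    ).log_prob(x)
    assert torch.allclose(lp, ref)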
@@ -739,11 +740,12 @@ class Optimizer:
         }[opt]()
         self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

-    def __call__(self, loss, params, retain_graph=False):
+    def __call__(self, loss, params, retain_graph=True):
         assert len(loss.shape) == 0, loss.shape
         metrics = {}
         metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
-        self._scaler.scale(loss).backward()
+        self._opt.zero_grad()
+        self._scaler.scale(loss).backward(retain_graph=retain_graph)
         self._scaler.unscale_(self._opt)
         # loss.backward(retain_graph=retain_graph)
         norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
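The `Optimizer.__call__` changes amount to a new call order: gradients are cleared inside the call rather than by the caller, and `retain_graph` now defaults to True, which keeps the autograd graph alive for any later backward pass through shared tensors. A minimal sketch of the resulting sequence with illustrative parameters; the scaler is disabled here so the snippet runs on CPU:

    import torch

    params = [torch.nn.Parameter(torch.randn(3))]
    opt = torch.optim.Adam(params, lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=False)   # disabled: pure passthrough

    loss = (params[0] ** 2).sum()

    opt.zero_grad()                                      # now done inside __call__
    scaler.scale(loss).backward(retain_graph=True)       # retain_graph defaults to True
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(params, 100.0)
    scaler.step(opt)
    scaler.update()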
@@ -1001,11 +1003,9 @@ def recursively_collect_optim_state_dict(


 def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
-    print(optimizers_state_dicts)
     for path, state_dict in optimizers_state_dicts.items():
         keys = path.split(".")
         obj_now = obj
         for key in keys:
             obj_now = getattr(obj_now, key)
-        print(keys)
         obj_now.load_state_dict(state_dict)