# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torchrl
import torch.nn as nn
import torch.nn.functional as F

from ez.utils.distribution import SquashedNormal, TruncatedNormal, ContDist
from ez.utils.format import atanh, symlog, symexp, DiscreteSupport
from torch.cuda.amp import autocast


def cosine_similarity_loss(f1, f2):
    """Cosine consistency loss: per-sample negative cosine similarity.

    Parameters
    ----------
    f1, f2 : torch.Tensor
        Feature tensors of shape (batch, dim); both are L2-normalized along
        the last dimension before taking the inner product.
    """
    f1 = F.normalize(f1, p=2., dim=-1, eps=1e-5)
    f2 = F.normalize(f2, p=2., dim=-1, eps=1e-5)
    return -(f1 * f2).sum(dim=1)

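# Minimal usage sketch (shapes below are illustrative assumptions): given
# predicted and target latent states of shape (batch, dim), the loss is the
# negative cosine similarity, minimized at -1 for perfectly aligned pairs.
#
#   pred = torch.randn(32, 256)
#   target = torch.randn(32, 256)
#   loss = cosine_similarity_loss(pred, target.detach()).mean()

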
def kl_loss(prediction, target):
    # cross-entropy between the (soft) target distribution and the predicted
    # logits; equals KL(target || softmax(prediction)) up to a constant
    return -(torch.log_softmax(prediction, dim=-1) * target).sum(-1)

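# Minimal usage sketch (shapes are illustrative assumptions): `prediction`
# holds raw logits, `target` a soft distribution over the same support.
#
#   logits = torch.randn(8, 51)
#   target = torch.softmax(torch.randn(8, 51), dim=-1)
#   loss = kl_loss(logits, target).mean()

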
def symlog_loss(prediction, target):
    # squared error in symlog space (the target is compressed via symlog)
    return 0.5 * (prediction.squeeze() - symlog(target)) ** 2

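# Sketch of the intended behavior, assuming the DreamerV3-style definitions
# symlog(x) = sign(x) * log(|x| + 1) with symexp as its inverse: a prediction
# that matches the target in symlog space incurs zero loss.
#
#   target = torch.tensor([3.0, -50.0])
#   pred = symlog(target)          # a perfect prediction in symlog space
#   assert symlog_loss(pred, target).sum().item() == 0.0

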
def Value_loss(preds, targets, config):
    """Asymmetrically weighted value loss (IQL-style expectile weighting)."""
    v_num = config.train.v_num
    targets = targets.repeat(v_num, 1)
    iql_weight = config.train.IQL_weight
    if not config.train.use_IQL:
        # a weight of 0.5 on both sides recovers the symmetric loss
        iql_weight = 0.5

    if config.model.value_support.type == 'symlog':
        loss_func = symlog_loss
        reformed_values = symexp(preds).squeeze()
        target_supports = targets
    elif config.model.value_support.type == 'support':
        loss_func = kl_loss
        reformed_values = DiscreteSupport.vector_to_scalar(preds, **config.model.value_support).squeeze()
        target_supports = DiscreteSupport.scalar_to_vector(targets, **config.model.value_support)
    else:
        raise NotImplementedError

    # expectile-style asymmetric weighting: over-estimation is weighted by
    # (1 - iql_weight), under-estimation by iql_weight
    value_error = reformed_values - targets
    value_sign = (value_error > 0).float().detach()
    value_weight = (1 - value_sign) * iql_weight + value_sign * (1 - iql_weight)
    value_loss = (value_weight * loss_func(preds, target_supports)).mean(0)
    return value_loss

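# Sketch of the asymmetric weighting above with an assumed iql_weight = 0.8
# (the actual value comes from config.train.IQL_weight): over-estimation
# (value_error > 0) gets weight 0.2 and under-estimation gets weight 0.8,
# biasing the learned value towards an upper expectile of the targets.
#
#   error = torch.tensor([-1.0, 2.0])
#   sign = (error > 0).float()
#   weight = (1 - sign) * 0.8 + sign * (1 - 0.8)   # -> tensor([0.8, 0.2])

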
def set_requires_grad(net, value):
    """Enable/disable gradients for a given (sub)network."""
    for param in net.parameters():
        param.requires_grad_(value)

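# Typical usage sketch (`target_model` is an illustrative name): freeze a
# target network so only the online network receives gradients.
#
#   set_requires_grad(target_model, False)   # freeze
#   set_requires_grad(target_model, True)    # unfreeze

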
def continuous_loss(policy, target_action, target_policy, target_best_action, mask=None, distribution_type='squashed_gaussian'):
    """Policy loss and entropy for continuous action spaces.

    ``policy`` concatenates the predicted mean and std along the last dim.
    """
    action_dim = policy.size(1) // 2
    n_branches = target_policy.size(1)  # sampled action branches per state
    if distribution_type == 'squashed_gaussian':
        mean, std = policy[:, :action_dim], policy[:, action_dim:]
        distr = SquashedNormal(mean, std)
    elif distribution_type == 'truncated_gaussian':
        mean, std = policy[:, :action_dim], policy[:, action_dim:]
        distr = torchrl.modules.TruncatedNormal(mean, std)
    else:
        raise NotImplementedError

    # full pi loss in Eq. 6 of the paper: cross-entropy between the target
    # visit distribution and the log-prob of the sampled target actions
    target_action = torch.moveaxis(target_action, 0, 1)       # -> (n_branches, batch, action_dim)
    policy_log_prob = distr.log_prob(target_action).sum(-1)
    policy_log_prob = torch.moveaxis(policy_log_prob, 0, 1)   # -> (batch, n_branches)
    fullpi_loss = (-target_policy * policy_log_prob).sum(1)

    # simple pi loss in Eq. 7 of the paper: maximize the log-prob of the best
    # target action only (the simple policy loss of Gumbel MuZero)
    target_best_action = target_best_action.clip(-0.999, 0.999)
    if distribution_type != 'truncated_gaussian':
        simplepi_loss = -distr.log_prob(target_best_action).sum(-1)
    else:
        # torchrl's TruncatedNormal already sums over the event dimension
        simplepi_loss = -distr.log_prob(target_best_action)

    # choose the policy loss according to the action dimensionality
    if action_dim == 1:
        loss = fullpi_loss
    else:
        loss = simplepi_loss

    if distribution_type in ['squashed_gaussian', 'truncated_gaussian']:
        # Monte Carlo entropy estimate (no closed form for these distributions)
        ent_action = distr.rsample((1024,))
        ent_action = ent_action.clip(-0.999, 0.999)
        if distribution_type != 'truncated_gaussian':
            ent_log_prob = distr.log_prob(ent_action).sum(-1)
        else:
            ent_log_prob = distr.log_prob(ent_action)
        entropy = -ent_log_prob.mean(0)
    else:
        entropy = distr.entropy().sum(-1)

    if mask is not None:
        loss = loss * mask
        entropy = entropy * mask

    return loss, entropy

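# Minimal usage sketch (all shapes and values are illustrative assumptions):
# a batch of 4 states, 8 sampled action branches, 2-dimensional actions.
#
#   policy = torch.cat([torch.zeros(4, 2), 0.5 * torch.ones(4, 2)], dim=1)  # mean || std
#   target_action = torch.rand(4, 8, 2) * 1.8 - 0.9          # actions in (-0.9, 0.9)
#   target_policy = torch.softmax(torch.randn(4, 8), dim=-1)
#   target_best_action = torch.rand(4, 2) * 1.8 - 0.9
#   loss, entropy = continuous_loss(policy, target_action, target_policy,
#                                   target_best_action)

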
class BarlowLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss (Zbontar et al., 2021)."""

    def __init__(self, lmbda, reduction='mean'):
        super().__init__()
        self.lmbda = lmbda
        self.reduction = reduction

    def _off_diagonal(self, x):
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, z1, z2):
        n, d = z1.shape
        # z1 = F.normalize(z1, dim=-1, p=2)
        # z2 = F.normalize(z2, dim=-1, p=2)

        # standardize each feature along the batch dim
        z1 = (z1 - z1.mean(0)) / z1.std(0)  # NxD
        z2 = (z2 - z2.mean(0)) / z2.std(0)  # NxD

        # cross-correlation matrix (DxD)
        cor = torch.mm(z1.T, z2)
        cor.div_(n)

        # invariance term pulls the diagonal towards 1; the redundancy-
        # reduction term pushes off-diagonal entries towards 0
        on_diag = torch.diagonal(cor).add_(-1).pow_(2).sum()
        off_diag = self._off_diagonal(cor).pow_(2).sum()

        loss = on_diag + self.lmbda * off_diag

        if self.reduction == 'mean':
            return loss
        else:
            raise ValueError(f'unsupported reduction: {self.reduction}')

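# Minimal usage sketch (dimensions and lmbda are illustrative; 5e-3 follows
# the Barlow Twins paper): z1 and z2 are two embeddings of the same batch.
#
#   barlow = BarlowLoss(lmbda=5e-3)
#   z1, z2 = torch.randn(64, 128), torch.randn(64, 128)
#   loss = barlow(z1, z2)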