# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torchrl
import torch.nn as nn
import torch.nn.functional as F
from ez.utils.distribution import SquashedNormal, TruncatedNormal, ContDist
from ez.utils.format import atanh, symlog, symexp, DiscreteSupport
from torch.cuda.amp import autocast
def cosine_similarity_loss(f1, f2):
"""Cosine Consistency loss function: similarity loss
Parameters
"""
f1 = F.normalize(f1, p=2., dim=-1, eps=1e-5)
f2 = F.normalize(f2, p=2., dim=-1, eps=1e-5)
return -(f1 * f2).sum(dim=1)
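# Illustrative usage sketch (not from the original repo): the consistency loss is
# typically applied between dynamics-predicted latent states and encoder targets.
# The tensor names and shapes below are assumptions for demonstration only.
def _example_cosine_consistency():
    pred_state = torch.randn(8, 256)    # hypothetical predicted latent states (batch, dim)
    target_state = torch.randn(8, 256)  # hypothetical target latent states (batch, dim)
    loss = cosine_similarity_loss(pred_state, target_state)  # per-sample loss, shape (8,)
    return loss.mean()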
def kl_loss(prediction, target):
    # cross-entropy between the target distribution and the predicted logits
    # (equal to KL(target || softmax(prediction)) up to the constant target entropy)
    return -(torch.log_softmax(prediction, dim=-1) * target).sum(-1)
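# Sketch (assumed shapes): kl_loss is used with categorical value/reward supports,
# where `prediction` holds raw logits and `target` is a normalized distribution.
def _example_kl_loss():
    logits = torch.randn(8, 51)                     # hypothetical logits over 51 support bins
    target = torch.softmax(torch.randn(8, 51), -1)  # hypothetical target distribution
    return kl_loss(logits, target)                  # per-sample cross-entropy, shape (8,)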
def symlog_loss(prediction, target):
    # squared error between the raw prediction and the symlog-transformed target
    return 0.5 * (prediction.squeeze() - symlog(target)) ** 2
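# Sketch (assumed shapes): symlog_loss regresses the raw prediction against the
# symlog-compressed target, and symexp decodes the prediction back to a scalar value.
def _example_symlog_loss():
    prediction = torch.zeros(8, 1)          # hypothetical raw network output
    target = torch.full((8,), 10.0)         # hypothetical scalar targets
    loss = symlog_loss(prediction, target)  # 0.5 * (pred - symlog(target))^2, shape (8,)
    decoded = symexp(prediction).squeeze()  # value estimate in the original scale
    return loss.mean(), decoded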
def Value_loss(preds, targets, config):
    """Value loss over v_num value heads with asymmetric (IQL-style) weighting."""
    v_num = config.train.v_num
    targets = targets.repeat(v_num, 1)
    iql_weight = config.train.IQL_weight
    if not config.train.use_IQL:
        # fall back to symmetric weighting (plain regression) when IQL is disabled
        iql_weight = 0.5
if config.model.value_support.type == 'symlog':
loss_func = symlog_loss
reformed_values = symexp(preds).squeeze()
target_supports = targets
elif config.model.value_support.type == 'support':
loss_func = kl_loss
reformed_values = DiscreteSupport.vector_to_scalar(preds, **config.model.value_support).squeeze()
target_supports = DiscreteSupport.scalar_to_vector(targets, **config.model.value_support)
else:
raise NotImplementedError
    # asymmetric (expectile-style) weighting: over-estimations of the target are
    # weighted by (1 - iql_weight), under-estimations by iql_weight
    value_error = reformed_values - targets
    value_sign = (value_error > 0).float().detach()
    value_weight = (1 - value_sign) * iql_weight + value_sign * (1 - iql_weight)
    value_loss = (value_weight * loss_func(preds, target_supports)).mean(0)
return value_loss
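# Minimal sketch of driving Value_loss with the 'symlog' support type. The config
# is built with SimpleNamespace purely for illustration (the repo uses its own
# config objects), and all shapes and values below are assumptions.
def _example_value_loss():
    from types import SimpleNamespace
    config = SimpleNamespace(
        train=SimpleNamespace(v_num=2, use_IQL=True, IQL_weight=0.7),
        model=SimpleNamespace(value_support=SimpleNamespace(type='symlog')),
    )
    preds = torch.zeros(2, 8, 1)     # assumed layout: (v_num, batch, 1) raw predictions
    targets = torch.full((8,), 5.0)  # hypothetical scalar value targets, shape (batch,)
    return Value_loss(preds, targets, config)  # per-sample loss, shape (batch,)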
def set_requires_grad(net, value):
"""Enable/disable gradients for a given (sub)network."""
for param in net.parameters():
param.requires_grad_(value)
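# Sketch: set_requires_grad is typically used to freeze a target (momentum) network
# during representation learning; the module below is hypothetical.
def _example_freeze_target():
    target_net = nn.Linear(256, 256)  # hypothetical target sub-network
    set_requires_grad(target_net, False)
    return all(not p.requires_grad for p in target_net.parameters())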
def continuous_loss(policy, target_action, target_policy, target_best_action, mask=None, distribution_type='squashed_gaussian'):
    """Policy loss and entropy for continuous action spaces.

    `policy` concatenates the mean and std of the action distribution along dim 1,
    `target_action`/`target_policy` hold the sampled actions and their improved
    policy weights, and `target_best_action` is the selected best action.
    """
    action_dim = policy.size(1) // 2
    n_branches = target_policy.size(1)
if distribution_type == 'squashed_gaussian':
mean, std = policy[:, :action_dim], policy[:, action_dim:]
distr = SquashedNormal(mean, std)
elif distribution_type == 'truncated_gaussian':
mean, std = policy[:, :action_dim], policy[:, action_dim:]
distr = torchrl.modules.TruncatedNormal(mean, std)
else:
raise NotImplementedError
# full pi loss in Eq. 6 of the paper
target_action = torch.moveaxis(target_action, 0, 1)
policy_log_prob = distr.log_prob(target_action).sum(-1)
policy_log_prob = torch.moveaxis(policy_log_prob, 0, 1)
fullpi_loss = (-target_policy * policy_log_prob).sum(1)
# simple pi loss in Eq. 7 of the paper
target_best_action = target_best_action.clip(-0.999, 0.999)
if distribution_type != 'truncated_gaussian':
simplepi_loss = -distr.log_prob(target_best_action).sum(-1) # simple policy loss of Gumbel MuZero
else:
simplepi_loss = -distr.log_prob(target_best_action)
# choose action loss according to action dim
if action_dim == 1:
loss = fullpi_loss
else:
loss = simplepi_loss
    if distribution_type in ['squashed_gaussian', 'truncated_gaussian']:
        # Monte Carlo entropy estimate from 1024 samples of the action distribution
        ent_action = distr.rsample((1024,))
ent_action = ent_action.clip(-0.999, 0.999)
if distribution_type != 'truncated_gaussian':
ent_log_prob = distr.log_prob(ent_action).sum(-1)
else:
ent_log_prob = distr.log_prob(ent_action)
entropy = -ent_log_prob.mean(0)
else:
entropy = distr.entropy().sum(-1)
if mask is not None:
loss = loss * mask
entropy = entropy * mask
return loss, entropy
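# Minimal sketch of calling continuous_loss with the 'squashed_gaussian' branch.
# Batch size, action_dim, n_branches, and the tensor layouts are assumptions:
# target_action is taken as (batch, n_branches, action_dim) so that the moveaxis
# above lines up with the distribution's batch shape.
def _example_continuous_loss():
    batch, action_dim, n_branches = 4, 3, 16
    policy = torch.cat([torch.zeros(batch, action_dim),           # means
                        0.5 * torch.ones(batch, action_dim)], 1)  # stds
    target_action = torch.rand(batch, n_branches, action_dim) * 1.8 - 0.9  # sampled actions in (-0.9, 0.9)
    target_policy = torch.softmax(torch.randn(batch, n_branches), -1)      # improved policy weights
    target_best_action = torch.rand(batch, action_dim) * 1.8 - 0.9
    loss, entropy = continuous_loss(policy, target_action, target_policy,
                                    target_best_action, distribution_type='squashed_gaussian')
    return loss.mean(), entropy.mean()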
class BarlowLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss between two batches of embeddings."""
    def __init__(self, lmbda, reduction='mean'):
super().__init__()
self.lmbda = lmbda
self.reduction = reduction
def _off_diagonal(self, x):
# return a flattened view of the off-diagonal elements of a square matrix
n, m = x.shape
assert n == m
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
def forward(self, z1, z2):
n, d = z1.shape
# z1 = F.normalize(z1, dim=-1, p=2)
# z2 = F.normalize(z2, dim=-1, p=2)
# normalize along batch dim
z1 = (z1 - z1.mean(0)) / z1.std(0) # NxD
z2 = (z2 - z2.mean(0)) / z2.std(0) # NxD
        # cross-correlation matrix
cor = torch.mm(z1.T, z2)
cor.div_(n)
# loss
on_diag = torch.diagonal(cor).add_(-1).pow_(2).sum()
off_diag = self._off_diagonal(cor).pow_(2).sum()
loss = on_diag + self.lmbda * off_diag
        if self.reduction == 'mean':
            return loss
        else:
            raise ValueError(f'unsupported reduction: {self.reduction}')
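# Sketch of BarlowLoss usage: z1/z2 would normally be projections of two augmented
# views (or of online/target latents); sizes and lambda below are assumptions.
def _example_barlow_loss():
    barlow = BarlowLoss(lmbda=0.005)      # hypothetical redundancy-reduction weight
    z1 = torch.randn(64, 128)             # hypothetical projections of view 1
    z2 = z1 + 0.1 * torch.randn(64, 128)  # slightly perturbed projections of view 2
    return barlow(z1, z2)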