add kl div on action embedder, working way towards the kl div loss in pmpo
This commit is contained in:
parent
91d697f8ca
commit
691d9ca007
@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nested import nested_tensor
|
||||
from torch.distributions import Normal
|
||||
from torch.distributions import Normal, kl
|
||||
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
@ -198,6 +198,14 @@ def masked_mean(t, mask = None):
|
||||
def log(t, eps = 1e-20):
|
||||
return t.clamp(min = eps).log()
|
||||
|
||||
def mean_log_var_to_distr(
|
||||
mean_log_var: Tensor
|
||||
) -> Normal:
|
||||
|
||||
mean, log_var = mean_log_var.unbind(dim = -1)
|
||||
std = (0.5 * log_var).exp()
|
||||
return Normal(mean, std)
|
||||
|
||||
def safe_cat(tensors, dim):
|
||||
tensors = [*filter(exists, tensors)]
|
||||
|
||||
@ -824,10 +832,7 @@ class ActionEmbedder(Module):
|
||||
continuous_entropies = None
|
||||
|
||||
if exists(continuous_targets):
|
||||
mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
|
||||
std = (0.5 * log_var).exp()
|
||||
|
||||
distr = Normal(mean, std)
|
||||
distr = mean_log_var_to_distr(continuous_action_mean_log_var)
|
||||
continuous_log_probs = distr.log_prob(continuous_targets)
|
||||
|
||||
if return_entropies:
|
||||
@ -842,6 +847,54 @@ class ActionEmbedder(Module):
|
||||
|
||||
return log_probs, entropies
|
||||
|
||||
def kl_div(
|
||||
self,
|
||||
src: tuple[Tensor | None, Tensor | None],
|
||||
tgt: tuple[Tensor | None, Tensor | None]
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
|
||||
src_discrete, src_continuous = src
|
||||
tgt_discrete, tgt_continuous = tgt
|
||||
|
||||
discrete_kl_div = None
|
||||
|
||||
# split discrete if it is not already (multiple discrete actions)
|
||||
|
||||
if exists(src_discrete):
|
||||
|
||||
discrete_split = self.num_discrete_actions.tolist()
|
||||
|
||||
if is_tensor(src_discrete):
|
||||
src_discrete = src_discrete.split(discrete_split, dim = -1)
|
||||
|
||||
if is_tensor(tgt_discrete):
|
||||
tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
|
||||
|
||||
discrete_kl_divs = []
|
||||
|
||||
for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
|
||||
|
||||
src_log_probs = src_logit.log_softmax(dim = -1)
|
||||
tgt_prob = tgt_logit.softmax(dim = -1)
|
||||
|
||||
one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
|
||||
|
||||
discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
|
||||
|
||||
discrete_kl_div = stack(discrete_kl_divs, dim = -1)
|
||||
|
||||
# calculate kl divergence for continuous
|
||||
|
||||
continuous_kl_div = None
|
||||
|
||||
if exists(src_continuous):
|
||||
src_normal = mean_log_var_to_distr(src_continuous)
|
||||
tgt_normal = mean_log_var_to_distr(tgt_continuous)
|
||||
|
||||
continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
|
||||
|
||||
return discrete_kl_div, continuous_kl_div
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.91"
|
||||
version = "0.0.92"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -346,6 +346,15 @@ def test_action_embedder():
|
||||
assert discrete_logits.shape == (2, 3, 8)
|
||||
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
||||
|
||||
# test kl div
|
||||
|
||||
discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
|
||||
|
||||
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
|
||||
|
||||
assert discrete_kl_div.shape == (2, 3, 2)
|
||||
assert continuous_kl_div.shape == (2, 3, 2)
|
||||
|
||||
# return discrete split by number of actions
|
||||
|
||||
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user