add kl div on action embedder, working way towards the kl div loss in pmpo
This commit is contained in:
parent 91d697f8ca
commit 691d9ca007
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -198,6 +198,14 @@ def masked_mean(t, mask = None):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
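For context, the new mean_log_var_to_distr helper turns a tensor whose last dimension packs a (mean, log variance) pair into a torch Normal. Predicting log variance leaves the network output unconstrained, while exp(0.5 * log_var) recovers a strictly positive standard deviation. A minimal standalone sketch, with illustrative shapes that are not taken from the repo:

import torch
from torch.distributions import Normal

def mean_log_var_to_distr(mean_log_var):
    # last dim holds (mean, log_var) pairs
    mean, log_var = mean_log_var.unbind(dim = -1)
    std = (0.5 * log_var).exp()  # exp(log_var / 2) == sqrt(exp(log_var))
    return Normal(mean, std)

mean_log_var = torch.randn(2, 3, 2, 2)  # (batch, time, num continuous actions, 2)
distr = mean_log_var_to_distr(mean_log_var)

assert distr.mean.shape == (2, 3, 2)
assert distr.log_prob(torch.randn(2, 3, 2)).shape == (2, 3, 2)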
@@ -824,10 +832,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None

         if exists(continuous_targets):
-            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)

             if return_entropies:
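The replacement above should be numerically a no-op: the helper performs exactly the unbind and exp(0.5 * log_var) steps it inlines away. A quick equivalence check, re-declaring the helper locally for the example:

import torch
from torch.distributions import Normal

def mean_log_var_to_distr(mean_log_var):
    mean, log_var = mean_log_var.unbind(dim = -1)
    return Normal(mean, (0.5 * log_var).exp())

mean_log_var = torch.randn(2, 3, 2, 2)
targets = torch.randn(2, 3, 2)

# old inline path
mean, log_var = mean_log_var.unbind(dim = -1)
old_log_probs = Normal(mean, (0.5 * log_var).exp()).log_prob(targets)

# new helper path
new_log_probs = mean_log_var_to_distr(mean_log_var).log_prob(targets)

assert torch.equal(old_log_probs, new_log_probs)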
@@ -842,6 +847,54 @@ class ActionEmbedder(Module):

         return log_probs, entropies

+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
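A rough sketch of the discrete branch of the new kl_div method, assuming two discrete action heads of 4 actions each (the shapes and split here are illustrative). One convention worth noting: F.kl_div(input, target) expects input as log probabilities and computes target * (log(target) - input), i.e. the KL from the target distribution to the source, while kl.kl_divergence(p, q) on the continuous branch computes KL(p || q) of its first argument, so the two branches order their arguments under different conventions:

import torch
import torch.nn.functional as F

src_logits = torch.randn(2, 3, 8)
tgt_logits = torch.randn(2, 3, 8)

discrete_kl_divs = []

# per-head kl, mirroring the loop in ActionEmbedder.kl_div
for src_logit, tgt_logit in zip(src_logits.split([4, 4], dim = -1), tgt_logits.split([4, 4], dim = -1)):
    src_log_probs = src_logit.log_softmax(dim = -1)
    tgt_prob = tgt_logit.softmax(dim = -1)

    # sum over the action dim -> one kl value per (batch, time) position
    one_kl = F.kl_div(src_log_probs, tgt_prob, reduction = 'none').sum(dim = -1)
    discrete_kl_divs.append(one_kl)

discrete_kl_div = torch.stack(discrete_kl_divs, dim = -1)

assert discrete_kl_div.shape == (2, 3, 2)  # one kl per discrete action head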
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.91"
+version = "0.0.92"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -346,6 +346,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)

+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3, 2)
+    assert continuous_kl_div.shape == (2, 3, 2)
+
     # return discrete split by number of actions

     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
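On the continuous side, torch's registered kl_divergence between two Normals is elementwise and matches the univariate Gaussian closed form, which is what gives the (2, 3, 2) shape asserted above. A small sketch checking that, with all names local to the example:

import torch
from torch.distributions import Normal, kl

src = Normal(torch.randn(2, 3, 2), torch.rand(2, 3, 2) + 0.1)
tgt = Normal(torch.randn(2, 3, 2), torch.rand(2, 3, 2) + 0.1)

out = kl.kl_divergence(src, tgt)  # KL(src || tgt), elementwise

# closed form: log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2
m1, s1, m2, s2 = src.loc, src.scale, tgt.loc, tgt.scale
expected = (s2 / s1).log() + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

assert out.shape == (2, 3, 2)
assert torch.allclose(out, expected, atol = 1e-5)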