add kl div on action embedder, working my way towards the kl div loss in pmpo

lucidrains 2025-10-29 10:02:25 -07:00
parent 91d697f8ca
commit 691d9ca007
3 changed files with 68 additions and 6 deletions

View File

@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -198,6 +198,14 @@ def masked_mean(t, mask = None):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
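
For context, a minimal standalone sketch (not part of the commit, shapes illustrative) of what the new helper computes: the trailing dimension packs (mean, log_var), and the standard deviation is recovered as exp(log_var / 2), which is always positive.

import torch
from torch.distributions import Normal

mean_log_var = torch.randn(2, 3, 2)            # last dim packs (mean, log_var)
mean, log_var = mean_log_var.unbind(dim = -1)  # each of shape (2, 3)
std = (0.5 * log_var).exp()                    # std = exp(log_var / 2)
distr = Normal(mean, std)

print(distr.log_prob(torch.randn(2, 3)).shape) # torch.Size([2, 3])
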
@@ -824,10 +832,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None

         if exists(continuous_targets):
-            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)

             continuous_log_probs = distr.log_prob(continuous_targets)

         if return_entropies:
@@ -842,6 +847,54 @@ class ActionEmbedder(Module):

         return log_probs, entropies

+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
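
A standalone sketch (not part of the commit, shapes illustrative) spelling out which direction each branch of kl_div computes. F.kl_div(input, target), with input as log probs and target as probs, returns the elementwise terms target * (log target - input), so summing over classes gives KL(target || input); torch.distributions.kl.kl_divergence(p, q) is KL(p || q), which has a closed form for a pair of Normals.

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl

# discrete branch: per-class terms p_tgt * (log p_tgt - log p_src), summed over classes
src_log_probs = torch.randn(2, 3, 8).log_softmax(dim = -1)
tgt_prob = torch.randn(2, 3, 8).softmax(dim = -1)

discrete_kl = F.kl_div(src_log_probs, tgt_prob, reduction = 'none').sum(dim = -1)
manual = (tgt_prob * (tgt_prob.log() - src_log_probs)).sum(dim = -1)
assert torch.allclose(discrete_kl, manual, atol = 1e-5)

# continuous branch: kl_divergence(p, q) = KL(p || q), with closed form
# log(sq / sp) + (sp^2 + (mp - mq)^2) / (2 sq^2) - 1/2 for two Normals
p = Normal(torch.randn(2, 3), torch.rand(2, 3) + 0.5)
q = Normal(torch.randn(2, 3), torch.rand(2, 3) + 0.5)

continuous_kl = kl.kl_divergence(p, q)
closed_form = (q.scale / p.scale).log() + (p.scale ** 2 + (p.loc - q.loc) ** 2) / (2 * q.scale ** 2) - 0.5
assert torch.allclose(continuous_kl, closed_form, atol = 1e-4)
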

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.91"
+version = "0.0.92"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -346,6 +346,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)

+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3, 2)
+    assert continuous_kl_div.shape == (2, 3, 2)
+
     # return discrete split by number of actions

     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
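
The (2, 3, 2) shape of discrete_kl_div falls out of the split inside kl_div: one KL value per discrete action group. A hypothetical illustration, assuming the embedder in this test was built with two discrete actions of 4 choices each, so the 8 logits in the asserts would split as [4, 4]:

import torch

logits = torch.randn(2, 3, 8)               # (batch, time, total logits over all groups)
splits = [4, 4]                              # what self.num_discrete_actions.tolist() could yield
per_group = logits.split(splits, dim = -1)   # two tensors, each (2, 3, 4)

# summing the per-class KL terms within each group and stacking the groups
# yields one value per group, hence the (2, 3, 2) assert above
print([tuple(t.shape) for t in per_group])   # [(2, 3, 4), (2, 3, 4)]
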