diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 228edc1..0da5594 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -198,6 +198,14 @@ def masked_mean(t, mask = None):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
 
@@ -824,10 +832,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None
 
         if exists(continuous_targets):
-            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)
 
             if return_entropies:
@@ -842,6 +847,54 @@ class ActionEmbedder(Module):
 
         return log_probs, entropies
 
+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
diff --git a/pyproject.toml b/pyproject.toml
index 48dfec2..84aa1f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.91"
+version = "0.0.92"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 85bd78b..be8ce14 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -346,6 +346,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)
 
+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3, 2)
+    assert continuous_kl_div.shape == (2, 3, 2)
+
     # return discrete split by number of actions
 
     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)