From b2725d9b6ee81841c17bd9f524e9d5ccb1d73577 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 11 Oct 2025 09:24:49 -0700
Subject: [PATCH] complete behavior cloning for one agent

---
 dreamer4/dreamer4.py  | 89 +++++++++++++++++++++++++++++++++++++++++--
 tests/test_dreamer.py | 14 ++++++-
 2 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 0df98e5..fb598bf 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -8,6 +8,7 @@ from functools import partial
 
 import torch
 import torch.nn.functional as F
+from torch.distributions import Normal
 
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@@ -58,6 +59,8 @@ LinearNoBias = partial(Linear, bias = False)
 
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+
 # helpers
 
 def exists(v):
@@ -83,6 +86,9 @@ def is_power_two(num):
 
 # tensor helpers
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def pack_one(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
@@ -372,7 +378,7 @@ class ActionEmbedder(Module):
 
     def unembed(
         self,
-        embeds,           # (... d)
+        embeds,                           # (... d)
         discrete_action_types = None,     # (na)
         continuous_action_types = None,   # (na)
         return_split_discrete = False
@@ -421,6 +427,46 @@
 
         return discrete_action_logits, continuous_action_mean_log_var
 
+    def log_probs(
+        self,
+        embeds,                           # (... d)
+        discrete_targets = None,          # (... na)
+        continuous_targets = None,        # (... na)
+        discrete_action_types = None,     # (na)
+        continuous_action_types = None,   # (na)
+    ):
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+
+        # discrete
+
+        discrete_log_probs = None
+
+        if exists(discrete_targets):
+            discrete_log_probs = []
+
+            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+
+                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+
+                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
+                discrete_log_probs.append(log_prob)
+
+            discrete_log_probs = cat(discrete_log_probs, dim = -1)
+
+        # continuous
+
+        continuous_log_probs = None
+
+        if exists(continuous_targets):
+            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
+            std = (0.5 * log_var).exp()
+
+            distr = Normal(mean, std)
+            continuous_log_probs = distr.log_prob(continuous_targets)
+
+        return discrete_log_probs, continuous_log_probs
+
     def forward(
         self,
         *,
@@ -1325,6 +1371,7 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
+        behavior_clone_weight = 0.1,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1
     ):
@@ -1441,6 +1488,8 @@
             unembed_dim = dim * 4
         )
 
+        self.behavior_clone_weight = behavior_clone_weight
+
        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1671,7 +1720,8 @@
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False
+        return_agent_tokens = False,
+        add_autoregressive_action_loss = False
     ):
         # handle video or latents
 
@@ -2023,17 +2073,48 @@
             reward_pred = self.to_reward_pred(encoded_agent_tokens)
             reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
 
+        # maybe autoregressive action loss
+
+        behavior_clone_loss = self.zero
+
+        if (
+            self.num_agents == 1 and
+            add_autoregressive_action_loss and
+            (exists(discrete_actions) or exists(continuous_actions))
+        ):
+            assert self.action_embedder.has_actions
+
+            # only for 1 agent
+
+            agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
+            policy_embed = self.policy_head(agent_tokens[:, :-1])
+
+            discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
+                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
+            )
+
+            # behavior cloning minimizes the negative log likelihood of the logged actions
+
+            if exists(discrete_log_probs):
+                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
+
+            if exists(continuous_log_probs):
+                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
+
         # gather losses
 
         total_loss = (
             flow_loss +
-            reward_loss * self.reward_loss_weight
+            reward_loss * self.reward_loss_weight +
+            behavior_clone_loss * self.behavior_clone_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        return total_loss, (flow_loss, reward_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
+
+        return total_loss, losses
 
 # dreamer
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 0008797..d8f40f8 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -95,7 +95,8 @@ def test_e2e(
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
-        discrete_actions = actions
+        discrete_actions = actions,
+        add_autoregressive_action_loss = True
     )
 
     assert flow_loss.numel() == 1
@@ -286,3 +287,14 @@
 
     assert discrete_logits.shape == (2, 3, 4)
     assert continuous_mean_log_var.shape == (2, 3, 1, 2)
+
+    # log probs
+
+    discrete_log_probs, continuous_log_probs = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions
+    )
+
+    assert discrete_log_probs.shape == (2, 3, 2)
+    assert continuous_log_probs.shape == (2, 3, 2)
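
A minimal runnable sketch of the objective the new code computes, not repository code: tensor names and toy shapes here are assumptions, and it collapses the per-action-type list handling of ActionEmbedder.log_probs into a single discrete head. In the patch itself the targets are shifted by one step (policy_head reads agent_tokens[:, :-1] and is scored against actions[:, 1:]), and the behavior cloning term is the negative of the summed log probabilities.

    # illustrative sketch, assumed toy shapes - single discrete action head

    import torch
    from torch.distributions import Normal

    batch, time, num_actions, num_bins = 2, 3, 2, 4

    # discrete: log_softmax over bins, then gather the log prob of each target index

    discrete_logits = torch.randn(batch, time, num_actions, num_bins)
    discrete_targets = torch.randint(0, num_bins, (batch, time, num_actions))

    discrete_log_probs = discrete_logits.log_softmax(dim = -1)
    discrete_log_probs = discrete_log_probs.gather(-1, discrete_targets.unsqueeze(-1)).squeeze(-1)  # (b t na)

    # continuous: diagonal gaussian from the predicted mean and log variance

    mean_log_var = torch.randn(batch, time, num_actions, 2)
    continuous_targets = torch.randn(batch, time, num_actions)

    mean, log_var = mean_log_var.unbind(dim = -1)
    distr = Normal(mean, (0.5 * log_var).exp())
    continuous_log_probs = distr.log_prob(continuous_targets)                                       # (b t na)

    # behavior cloning = negative log likelihood of the logged actions

    behavior_clone_loss = -(discrete_log_probs.sum(dim = -1).mean() + continuous_log_probs.sum(dim = -1).mean())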
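
And a hypothetical call site mirroring the updated end-to-end test. Only the two keyword flags and the WorldModelLosses fields come from this patch; world_model, video, and discrete_actions are assumed stand-ins, with video assumed per the forward's video / latents handling.

    # hypothetical usage sketch - names other than the flags and namedtuple fields are assumptions

    total_loss, losses = world_model(
        video = video,                            # assumed input
        discrete_actions = discrete_actions,
        add_autoregressive_action_loss = True,    # adds the shifted-target behavior cloning term
        return_all_losses = True                  # returns the WorldModelLosses namedtuple alongside the total
    )

    print(losses.flow, losses.reward, losses.behavior_clone)

    total_loss.backward()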