From e2d86a4543ee75de59a3d4c99a8b90a833e761b3 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 9 Oct 2025 07:53:42 -0700
Subject: [PATCH] add a complete action embedder that can accept any number
 of discrete actions with variable bins as well as any number of continuous
 actions, pooled and added to the agent token as described in the paper
 (seems like they fixed that horrendous hack in dreamer v3 with sticky
 action)

---
 dreamer4/dreamer4.py  | 160 +++++++++++++++++++++++++++++++++++++++---
 pyproject.toml        |   2 +-
 tests/test_dreamer.py |  75 +++++++++++++++++++-
 3 files changed, 226 insertions(+), 11 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 091be87..9b2fac5 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -8,7 +8,7 @@ from functools import partial
 
 import torch
 import torch.nn.functional as F
-from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
+from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 
 import torchvision
@@ -29,6 +29,7 @@ from assoc_scan import AssocScan
 # l - logit / predicted bins
 # p - positions (3 for spacetime in this work)
 # t - time
+# na - action dimension (number of discrete and continuous actions)
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
@@ -66,6 +67,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def ensure_tuple(t):
+    return (t,) if not isinstance(t, tuple) else t
+
 def divisible_by(num, den):
     return (num % den) == 0
 
@@ -267,6 +271,116 @@ class SymExpTwoHot(Module):
 
         return inverse_pack(encoded, '* l')
 
+# action related
+
+ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
+
+class ActionEmbedder(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_discrete_actions: int | tuple[int, ...] = 0,
+        num_continuous_actions = 0,
+    ):
+        super().__init__()
+
+        # handle discrete actions
+
+        num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
+        total_discrete_actions = num_discrete_actions.sum().item()
+
+        self.num_discrete_action_types = len(num_discrete_actions)
+        self.discrete_action_embed = Embedding(total_discrete_actions, dim)
+
+        # continuous actions
+
+        self.num_continuous_action_types = num_continuous_actions
+        self.continuous_action_embed = Embedding(num_continuous_actions, dim)
+
+        # defaults
+
+        self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
+        self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
+
+        # calculate offsets
+
+        offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
+        self.register_buffer('discrete_action_offsets', offsets, persistent = False)
+
+    @property
+    def device(self):
+        return self.discrete_action_offsets.device
+
+    @property
+    def has_actions(self):
+        return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
+
+    def forward(
+        self,
+        *,
+        discrete_actions = None,         # (... na)
+        continuous_actions = None,       # (... na)
+        discrete_action_types = None,    # (na)
+        continuous_action_types = None,  # (na)
+        return_sum_pooled_embeds = True
+    ):
+
+        discrete_embeds = continuous_embeds = None
+
+        if exists(discrete_actions):
+
+            discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
+
+            if exists(discrete_action_types) and not is_tensor(discrete_action_types):
+                if isinstance(discrete_action_types, int):
+                    discrete_action_types = (discrete_action_types,)
+
+                discrete_action_types = tensor(discrete_action_types, device = self.device)
+
+            offsets = self.discrete_action_offsets[discrete_action_types]
+
+            assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
+
+            # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
+
+            discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+            discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
+
+        if exists(continuous_actions):
+            continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
+
+            if exists(continuous_action_types) and not is_tensor(continuous_action_types):
+                if isinstance(continuous_action_types, int):
+                    continuous_action_types = (continuous_action_types,)
+
+                continuous_action_types = tensor(continuous_action_types, device = self.device)
+
+            assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
+
+            continuous_action_embed = self.continuous_action_embed(continuous_action_types)
+
+            # continuous embed is just the selected continuous action type with the scale
+
+            continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+
+        # return not pooled
+
+        if not return_sum_pooled_embeds:
+            return ActionEmbeds(discrete_embeds, continuous_embeds)
+
+        # handle sum pooling, which is what they did in the paper for all the actions
+
+        pooled = 0.
+
+        if exists(discrete_embeds):
+            pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
+
+        if exists(continuous_embeds):
+            pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
+
+        return pooled
+
 # generalized advantage estimate
 
 @torch.no_grad()
@@ -1089,6 +1203,8 @@ class DynamicsWorldModel(Module):
         prob_no_shortcut_train = None,  # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        num_discrete_actions: int | tuple[int, ...] = 0,
+        num_continuous_actions = 0,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         num_latent_genes = 0            # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@@ -1182,6 +1298,14 @@ class DynamicsWorldModel(Module):
         self.agent_has_genes = num_latent_genes > 0
         self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
 
+        # action embedder
+
+        self.action_embedder = ActionEmbedder(
+            dim = dim,
+            num_discrete_actions = num_discrete_actions,
+            num_continuous_actions = num_continuous_actions
+        )
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1397,14 +1521,18 @@ class DynamicsWorldModel(Module):
     def forward(
         self,
         *,
-        video = None,
-        latents = None,           # (b t n d) | (b t d)
-        signal_levels = None,     # () | (b) | (b t)
-        step_sizes = None,        # () | (b)
-        step_sizes_log2 = None,   # () | (b)
-        tasks = None,             # (b)
-        rewards = None,           # (b t)
-        latent_gene_ids = None,   # (b)
+        video = None,                    # (b c t vh vw)
+        latents = None,                  # (b t n d) | (b t d)
+        signal_levels = None,            # () | (b) | (b t)
+        step_sizes = None,               # () | (b)
+        step_sizes_log2 = None,          # () | (b)
+        latent_gene_ids = None,          # (b)
+        tasks = None,                    # (b)
+        rewards = None,                  # (b t)
+        discrete_actions = None,         # (b t na)
+        continuous_actions = None,       # (b t na)
+        discrete_action_types = None,    # (na)
+        continuous_action_types = None,  # (na)
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -1533,6 +1661,20 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
+        # maybe add the action embed to the agent tokens per time step
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_embed = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+
+            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
+
         # maybe add a reward embedding to agent tokens
 
         if exists(rewards):
diff --git a/pyproject.toml b/pyproject.toml
index 82c9f58..855620f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.7"
+version = "0.0.8"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 9e9e01e..09f758b 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -55,6 +55,7 @@ def test_e2e(
         depth = 4,
         num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
+        num_discrete_actions = 4,
         attn_dim_head = 16,
         attn_kwargs = dict(
             heads = heads,
@@ -79,11 +80,14 @@ def test_e2e(
     if add_task_embeds:
         tasks = torch.randint(0, 4, (2,))
 
+    actions = torch.randint(0, 4, (2, 4, 1))
+
     flow_loss = dynamics(
         **dynamics_input,
         tasks = tasks,
         signal_levels = signal_levels,
-        step_sizes_log2 = step_sizes_log2
+        step_sizes_log2 = step_sizes_log2,
+        discrete_actions = actions
     )
 
     assert flow_loss.numel() == 1
@@ -173,3 +177,72 @@ def test_attend_factory(
     out = attend(q, k, v)
 
     assert torch.allclose(flex_out, out, atol = 1e-6)
+
+def test_action_embedder():
+    from dreamer4.dreamer4 import ActionEmbedder
+
+    # 1 discrete action with 4 choices
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = 4
+    )
+
+    actions = torch.randint(0, 4, (2, 3, 1))
+    action_embed = embedder(discrete_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 discrete actions with 4 choices each
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4)
+    )
+
+    actions = torch.randint(0, 4, (2, 3, 2))
+    action_embed = embedder(discrete_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # picking out only the second discrete action
+
+    actions = torch.randint(0, 4, (2, 3, 1))
+    action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 continuous actions
+
+    embedder = ActionEmbedder(
+        512,
+        num_continuous_actions = 2
+    )
+
+    actions = torch.randn((2, 3, 2))
+    action_embed = embedder(continuous_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 discrete actions with 4 choices each and 2 continuous actions
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
+    assert action_embed.shape == (2, 3, 512)
+
+    # picking out one discrete and one continuous
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 1))
+    continuous_actions = torch.randn(2, 3, 1)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
+
+    assert action_embed.shape == (2, 3, 512)
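
For reference, a minimal usage sketch (not part of the patch) of the unpooled path of the ActionEmbedder defined above: with return_sum_pooled_embeds = False the embedder returns the ActionEmbeds namedtuple of per-action embeddings instead of the single sum-pooled vector that gets added to the agent token. Shapes below follow from the class as written in the diff; the bin counts and tensor sizes are illustrative.

    # sketch only - exercises the unpooled return path of ActionEmbedder from the diff above

    import torch
    from dreamer4.dreamer4 import ActionEmbedder

    # two discrete actions with 4 and 6 bins respectively, plus one continuous action
    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 6),
        num_continuous_actions = 1
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))  # (batch, time, num discrete actions)
    continuous_actions = torch.randn(2, 3, 1)          # (batch, time, num continuous actions)

    # skip the sum pooling and get back the per-action embeddings as a namedtuple
    embeds = embedder(
        discrete_actions = discrete_actions,
        continuous_actions = continuous_actions,
        return_sum_pooled_embeds = False
    )

    assert embeds.discrete.shape == (2, 3, 2, 512)    # (... na d)
    assert embeds.continuous.shape == (2, 3, 1, 512)  # (... na d)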