diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 091be87..9b2fac5 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -8,7 +8,7 @@ from functools import partial

 import torch
 import torch.nn.functional as F
-from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
+from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange

 import torchvision
@@ -29,6 +29,7 @@ from assoc_scan import AssocScan
 # l - logit / predicted bins
 # p - positions (3 for spacetime in this work)
 # t - time
+# na - action dimension (number of discrete and continuous actions)
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
@@ -66,6 +67,9 @@ def default(v, d):

 def first(arr):
     return arr[0]

+def ensure_tuple(t):
+    return (t,) if not isinstance(t, tuple) else t
+
 def divisible_by(num, den):
     return (num % den) == 0
@@ -267,6 +271,116 @@ class SymExpTwoHot(Module):

         return inverse_pack(encoded, '* l')

+# action related
+
+ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
+
+class ActionEmbedder(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_discrete_actions: int | tuple[int, ...] = 0,
+        num_continuous_actions = 0,
+    ):
+        super().__init__()
+
+        # handle discrete actions
+
+        num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
+        total_discrete_actions = num_discrete_actions.sum().item()
+
+        self.num_discrete_action_types = len(num_discrete_actions)
+        self.discrete_action_embed = Embedding(total_discrete_actions, dim)
+
+        # continuous actions
+
+        self.num_continuous_action_types = num_continuous_actions
+        self.continuous_action_embed = Embedding(num_continuous_actions, dim)
+
+        # defaults
+
+        self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
+        self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
+
+        # calculate offsets
+
+        offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
+        self.register_buffer('discrete_action_offsets', offsets, persistent = False)
+
+    @property
+    def device(self):
+        return self.discrete_action_offsets.device
+
+    @property
+    def has_actions(self):
+        return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
+
+    def forward(
+        self,
+        *,
+        discrete_actions = None,        # (... na)
+        continuous_actions = None,      # (... na)
+        discrete_action_types = None,   # (na)
+        continuous_action_types = None, # (na)
+        return_sum_pooled_embeds = True
+    ):
+
+        discrete_embeds = continuous_embeds = None
+
+        if exists(discrete_actions):
+
+            discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
+
+            if exists(discrete_action_types) and not is_tensor(discrete_action_types):
+                if isinstance(discrete_action_types, int):
+                    discrete_action_types = (discrete_action_types,)
+
+                discrete_action_types = tensor(discrete_action_types, device = self.device)
+
+            offsets = self.discrete_action_offsets[discrete_action_types]
+
+            assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
+
+            # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
+
+            discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+            discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
+
+        if exists(continuous_actions):
+            continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
+
+            if exists(continuous_action_types) and not is_tensor(continuous_action_types):
+                if isinstance(continuous_action_types, int):
+                    continuous_action_types = (continuous_action_types,)
+
+                continuous_action_types = tensor(continuous_action_types, device = self.device)
+
+            assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
+
+            continuous_action_embed = self.continuous_action_embed(continuous_action_types)
+
+            # continuous embed is just the selected continuous action type with the scale
+
+            continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+
+        # return not pooled
+
+        if not return_sum_pooled_embeds:
+            return ActionEmbeds(discrete_embeds, continuous_embeds)
+
+        # handle sum pooling, which is what they did in the paper for all the actions
+
+        pooled = 0.
+
+        if exists(discrete_embeds):
+            pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
+
+        if exists(continuous_embeds):
+            pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
+
+        return pooled
+
 # generalized advantage estimate

 @torch.no_grad()
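
# editor's note, not part of the patch: a minimal standalone sketch of the offset scheme the
# ActionEmbedder above relies on, assuming two hypothetical discrete action types with 4 and 3
# choices. All discrete action types share one Embedding table, and an exclusive cumulative sum
# of the per-type cardinalities shifts each type's raw indices into its own range of rows.

import torch
import torch.nn.functional as F

num_discrete_actions = torch.tensor([4, 3])                                 # choices per discrete action type

offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)  # exclusive cumsum -> tensor([0, 4])

actions = torch.tensor([[2, 1]])                                            # (batch, action types), raw per-type indices

flat_ids = actions + offsets                                                # tensor([[2, 5]]) -> rows of a (4 + 3, dim) Embedding
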
@@ -1089,6 +1203,8 @@ class DynamicsWorldModel(Module):
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        num_discrete_actions: int | tuple[int, ...] = 0,
+        num_continuous_actions = 0,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@@ -1182,6 +1298,14 @@

         self.agent_has_genes = num_latent_genes > 0
         self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)

+        # action embedder
+
+        self.action_embedder = ActionEmbedder(
+            dim = dim,
+            num_discrete_actions = num_discrete_actions,
+            num_continuous_actions = num_continuous_actions
+        )
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token

         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1397,14 +1521,18 @@
     def forward(
         self,
         *,
-        video = None,
-        latents = None,                   # (b t n d) | (b t d)
-        signal_levels = None,             # () | (b) | (b t)
-        step_sizes = None,                # () | (b)
-        step_sizes_log2 = None,           # () | (b)
-        tasks = None,                     # (b)
-        rewards = None,                   # (b t)
-        latent_gene_ids = None,           # (b)
+        video = None,                     # (b c t vh vw)
+        latents = None,                   # (b t n d) | (b t d)
+        signal_levels = None,             # () | (b) | (b t)
+        step_sizes = None,                # () | (b)
+        step_sizes_log2 = None,           # () | (b)
+        latent_gene_ids = None,           # (b)
+        tasks = None,                     # (b)
+        rewards = None,                   # (b t)
+        discrete_actions = None,          # (b t na)
+        continuous_actions = None,        # (b t na)
+        discrete_action_types = None,     # (na)
+        continuous_action_types = None,   # (na)
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -1533,6 +1661,20 @@

         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

+        # maybe add the action embed to the agent tokens per time step
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_embed = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+
+            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
+
         # maybe add a reward embedding to agent tokens

         if exists(rewards):
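
# editor's note, not part of the patch: the einx.add('b t ... d, b t d', ...) call above broadcasts
# the sum-pooled action embedding over every agent token at each time step. A plain-torch sketch
# with assumed shapes (batch 2, time 4, 3 agent tokens, dim 512):

import torch

agent_tokens = torch.randn(2, 4, 3, 512)                    # (b, t, agent tokens, d)
action_embed = torch.randn(2, 4, 512)                       # (b, t, d), pooled embedding per time step

conditioned = agent_tokens + action_embed[:, :, None, :]    # broadcast across the agent token axis

assert conditioned.shape == (2, 4, 3, 512)
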
diff --git a/pyproject.toml b/pyproject.toml
index 82c9f58..855620f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.7"
+version = "0.0.8"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 9e9e01e..09f758b 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -55,6 +55,7 @@ def test_e2e(
         depth = 4,
         num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
+        num_discrete_actions = 4,
         attn_dim_head = 16,
         attn_kwargs = dict(
             heads = heads,
@@ -79,11 +80,14 @@
     if add_task_embeds:
         tasks = torch.randint(0, 4, (2,))

+    actions = torch.randint(0, 4, (2, 4, 1))
+
     flow_loss = dynamics(
         **dynamics_input,
         tasks = tasks,
         signal_levels = signal_levels,
-        step_sizes_log2 = step_sizes_log2
+        step_sizes_log2 = step_sizes_log2,
+        discrete_actions = actions
     )

     assert flow_loss.numel() == 1
@@ -173,3 +177,72 @@ def test_attend_factory(
     out = attend(q, k, v)

     assert torch.allclose(flex_out, out, atol = 1e-6)
+
+def test_action_embedder():
+    from dreamer4.dreamer4 import ActionEmbedder
+
+    # 1 discrete action with 4 choices
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = 4
+    )
+
+    actions = torch.randint(0, 4, (2, 3, 1))
+    action_embed = embedder(discrete_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 discrete actions with 4 choices each
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4)
+    )
+
+    actions = torch.randint(0, 4, (2, 3, 2))
+    action_embed = embedder(discrete_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # picking out only the second discrete action
+
+    actions = torch.randint(0, 4, (2, 3, 1))
+    action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 continuous actions
+
+    embedder = ActionEmbedder(
+        512,
+        num_continuous_actions = 2
+    )
+
+    actions = torch.randn((2, 3, 2))
+    action_embed = embedder(continuous_actions = actions)
+
+    assert action_embed.shape == (2, 3, 512)
+
+    # 2 discrete actions with 4 choices each and 2 continuous actions
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
+    assert action_embed.shape == (2, 3, 512)
+
+    # picking out one discrete and one continuous
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 1))
+    continuous_actions = torch.randn(2, 3, 1)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
+
+    assert action_embed.shape == (2, 3, 512)
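
# editor's note, not part of the patch: the tests above only exercise the default sum-pooled output.
# A sketch of the unpooled path, which returns the ActionEmbeds namedtuple defined in the patch
# instead of a single pooled vector; shapes and dims here are illustrative.

import torch
from dreamer4.dreamer4 import ActionEmbedder

embedder = ActionEmbedder(
    512,
    num_discrete_actions = (4, 4),
    num_continuous_actions = 2
)

discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)

embeds = embedder(
    discrete_actions = discrete_actions,
    continuous_actions = continuous_actions,
    return_sum_pooled_embeds = False
)

assert embeds.discrete.shape == (2, 3, 2, 512)      # one embedding per discrete action type
assert embeds.continuous.shape == (2, 3, 2, 512)    # continuous value scales its action type embedding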