From 32aa355e37b93986d99be66f792a9c0e4d2cbd4f Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 10 Oct 2025 10:41:48 -0700
Subject: [PATCH] prepare unembedding parameters in ActionEmbedder as well as
 the policy head, to allow for behavioral cloning before RL

---
 dreamer4/dreamer4.py  | 59 ++++++++++++++++++++++++++++++++++++++++++-
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 19 ++++++++++++++
 3 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 2998f28..13eb4b2 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -282,7 +282,9 @@ class ActionEmbedder(Module):
         *,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
-        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
+        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
+        can_unembed = False,
+        unembed_dim = None
     ):
         super().__init__()
 
@@ -294,6 +296,8 @@
         self.num_discrete_action_types = len(num_discrete_actions)
         self.discrete_action_embed = Embedding(total_discrete_actions, dim)
 
+        self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
+
         # continuous actions
 
         self.num_continuous_action_types = num_continuous_actions
@@ -314,6 +318,15 @@
         offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
         self.register_buffer('discrete_action_offsets', offsets, persistent = False)
 
+        # unembedding
+
+        self.can_unembed = can_unembed
+
+        if can_unembed:
+            unembed_dim = default(unembed_dim, dim)
+            self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+            self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+
     @property
     def device(self):
         return self.discrete_action_offsets.device
@@ -322,6 +335,40 @@
     def has_actions(self):
         return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
 
+    def unembed(
+        self,
+        embeds,                          # (... d)
+        discrete_action_types = None,    # (na)
+        continuous_action_types = None,  # (na)
+
+    ): # (... discrete_na), (... continuous_na 2)
+
+        assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
+
+        assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
+
+        # discrete actions
+
+        discrete_action_logits = None
+
+        if self.num_discrete_action_types > 0:
+            discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
+
+        # continuous actions
+
+        continuous_action_mean_log_var = None
+
+        if self.num_continuous_action_types > 0:
+
+            continuous_action_unembed = self.continuous_action_unembed
+
+            if exists(continuous_action_types):
+                continuous_action_unembed = continuous_action_unembed[continuous_action_types]
+
+            continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
+
+        return discrete_action_logits, continuous_action_mean_log_var
+
     def forward(
         self,
         *,
@@ -1220,6 +1267,7 @@
         continuous_norm_stats = None,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
+        policy_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
     ):
         super().__init__()
@@ -1338,6 +1386,15 @@
 
         self.reward_loss_weight = reward_loss_weight
 
+        # policy head
+
+        self.policy_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = dim,
+            depth = policy_head_mlp_depth
+        )
+
         # value head
 
         self.value_head = create_mlp(
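The commit subject positions these pieces as groundwork for behavioral cloning ahead of RL: latents pass through the new policy_head, and ActionEmbedder.unembed projects the result back out to per-action-type predictions. Below is a minimal sketch of that loop against the additions above, not part of the patch: the policy embeds and expert actions are random stand-in data, the cross-entropy plus Gaussian-NLL objective is a conventional choice rather than anything this commit specifies, and the import path is assumed from the module layout in the diff.

import torch
import torch.nn.functional as F

from dreamer4.dreamer4 import ActionEmbedder

embedder = ActionEmbedder(
    512,
    num_discrete_actions = (4, 4),
    num_continuous_actions = 2,
    can_unembed = True
)

# stand-in for policy head output on (batch, time, dim) world model latents

policy_embeds = torch.randn(2, 3, 512)

discrete_logits, continuous_mean_log_var = embedder.unembed(policy_embeds)

# random placeholders for the expert actions being cloned

expert_discrete = torch.randint(0, 4, (2, 3, 2))   # (batch, time, action types)
expert_continuous = torch.randn(2, 3, 2)

# discrete logits come out flat over all action types (4 + 4 = 8 here),
# so split per type before cross entropy

discrete_loss = sum(
    F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    for logits, targets in zip(
        discrete_logits.split((4, 4), dim = -1),
        expert_discrete.unbind(dim = -1)
    )
)

# continuous predictions carry (mean, log variance) in the last dim,
# giving a gaussian negative log likelihood up to a constant

mean, log_var = continuous_mean_log_var.unbind(dim = -1)
continuous_loss = (0.5 * (log_var + (expert_continuous - mean) ** 2 / log_var.exp())).mean()

bc_loss = discrete_loss + continuous_loss
bc_loss.backward()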
diff --git a/pyproject.toml b/pyproject.toml
index 855620f..1ebca49 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.8"
+version = "0.0.9"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 6b89f8d..bfd79eb 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -251,3 +251,22 @@ def test_action_embedder():
     action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
 
     assert action_embed.shape == (2, 3, 512)
+
+    # unembed
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2,
+        can_unembed = True
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
+
+    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)
+
+    assert discrete_logits.shape == (2, 3, 8)
+    assert continuous_mean_log_var.shape == (2, 3, 2, 2)
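One path the new test leaves uncovered: unembed already accepts continuous_action_types for unembedding only a subset of continuous actions (the discrete counterpart is asserted out as unfinished). A short sketch continuing from the embedder and action_embed built in the test above; passing the index as a 1-d tensor is an assumption about the intended usage.

discrete_logits, sub_mean_log_var = embedder.unembed(
    action_embed,
    continuous_action_types = torch.tensor([0])
)

# per the diff, the (na, d, 2) continuous unembed parameter is indexed down
# to the selected action types before the einsum, so only those types'
# mean and log variance come back

assert discrete_logits.shape == (2, 3, 8)      # discrete output unaffected
assert sub_mean_log_var.shape == (2, 3, 1, 2)  # one continuous action type remains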