From ff81dd761bdb995f28bc155cdafec252c2f1de65 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 13 Oct 2025 11:36:21 -0700
Subject: [PATCH] separate action and agent embeds

---
 dreamer4/dreamer4.py | 21 ++++++++++++++-------
 pyproject.toml       |  2 +-
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 18bf488..3db4e7d 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -39,6 +39,7 @@ from assoc_scan import AssocScan
 # vh, vw - video height and width
 
 import einx
+from einx import add, multiply
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -579,7 +580,7 @@ class ActionEmbedder(Module):
 
             # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
 
-            discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
+            discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
             discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
 
         if exists(continuous_actions):
@@ -599,7 +600,7 @@ class ActionEmbedder(Module):
 
             # continuous embed is just the selected continuous action type with the scale
 
-            continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
+            continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
 
         # return not pooled
 
@@ -769,7 +770,7 @@ class MultiHeadRMSNorm(Module):
     ):
         normed = l2norm(x)
         scale = (self.gamma + 1.) * self.scale
-        return einx.multiply('... h n d, h d', normed, scale)
+        return multiply('... h n d, h d', normed, scale)
 
 # naive attend
 
@@ -1559,6 +1560,8 @@ class DynamicsWorldModel(Module):
         # they sum all the actions into a single token
 
         self.num_agents = num_agents
+
+        self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
         self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
 
         self.num_tasks = num_tasks
@@ -1927,13 +1930,13 @@ class DynamicsWorldModel(Module):
 
         # reinforcement learning related
 
-        agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
+        agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)
 
         if exists(tasks):
             assert self.num_tasks > 0
             task_embeds = self.task_embed(tasks)
 
-            agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds)
+            agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
 
         # maybe evolution
 
@@ -1941,7 +1944,7 @@ class DynamicsWorldModel(Module):
             assert exists(self.latent_genes)
             latent_genes = self.latent_genes[latent_gene_ids]
 
-            agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
+            agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)
 
         # handle agent tokens w/ actions and task embeds
 
@@ -1962,12 +1965,13 @@ class DynamicsWorldModel(Module):
 
             reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
 
-            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
+            agent_tokens = add('b t ... d, b t d', agent_tokens, reward_embeds)
 
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
             assert self.action_embedder.has_actions
+            assert self.num_agents == 1, 'only one agent allowed for now'
 
             action_tokens = self.action_embedder(
                 discrete_actions = discrete_actions,
@@ -1975,6 +1979,9 @@ class DynamicsWorldModel(Module):
                 continuous_actions = continuous_actions,
                 continuous_action_types = continuous_action_types
             )
+
+            action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
+
         else:
             action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
 
diff --git a/pyproject.toml b/pyproject.toml
index 0c004a6..9f40547 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.18"
+version = "0.0.19"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
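Below is a minimal standalone sketch (not part of the patch) of how the two now-separate learned embeddings are used after this change: agent tokens are seeded from agent_learned_embed, while action_learned_embed is summed into the action tokens. The shapes and the toy setup are illustrative assumptions; only the repeat/add calls mirror the patched lines in DynamicsWorldModel.

import torch
from einx import add
from einops import repeat

b, t, na, d = 2, 4, 1, 8  # assumed batch, time, num_agents, dim

# the two learned parameters this patch separates
agent_learned_embed  = torch.randn(na, d) * 1e-2   # now seeds the agent tokens
action_learned_embed = torch.randn(na, d) * 1e-2   # now added onto the action tokens

# agent tokens start from the agent embed, broadcast over the batch (as in the patched repeat)
agent_tokens = repeat(agent_learned_embed, '... d -> b ... d', b = b)

# stand-in for the pooled per-timestep output of the ActionEmbedder (hypothetical values)
action_tokens = torch.randn(b, t, d)

# the newly added line: the action learned embed is summed into the action tokens
# (the patch asserts num_agents == 1, hence the '1 d' pattern)
action_tokens = add('1 d, b t d', action_learned_embed, action_tokens)

print(agent_tokens.shape, action_tokens.shape)  # (2, 1, 8) and (2, 4, 8)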