add a complete action embedder that can accept any number of discrete actions with variable bins, as well as any number of continuous actions, pooled and added to the agent token as described in the paper (seems like they fixed that horrendous hack in dreamer v3 with sticky actions)

lucidrains 2025-10-09 07:53:42 -07:00
parent b62c08be65
commit e2d86a4543
3 changed files with 226 additions and 11 deletions
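
As a quick orientation before the diff, here is a minimal usage sketch of the new embedder, mirroring the tests added below (two discrete action heads with 4 bins each plus two continuous actions; action tensors are shaped (batch, time, num actions)):

import torch
from dreamer4.dreamer4 import ActionEmbedder

# two discrete action heads with 4 bins each, and two continuous actions
embedder = ActionEmbedder(
    512,
    num_discrete_actions = (4, 4),
    num_continuous_actions = 2
)

discrete_actions = torch.randint(0, 4, (2, 3, 2))   # (batch, time, num discrete actions)
continuous_actions = torch.randn(2, 3, 2)           # (batch, time, num continuous actions)

# sum pooled by default, ready to be added to the agent token per time step
pooled = embedder(
    discrete_actions = discrete_actions,
    continuous_actions = continuous_actions
)

assert pooled.shape == (2, 3, 512)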

dreamer4/dreamer4.py

@@ -8,7 +8,7 @@ from functools import partial
import torch
import torch.nn.functional as F
-from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
+from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace
import torchvision
@@ -29,6 +29,7 @@ from assoc_scan import AssocScan
# l - logit / predicted bins
# p - positions (3 for spacetime in this work)
# t - time
# na - action dimension (number of discrete and continuous actions)
# g - groups of query heads to key heads (gqa)
# vc - video channels
# vh, vw - video height and width
@@ -66,6 +67,9 @@ def default(v, d):
def first(arr):
return arr[0]
def ensure_tuple(t):
return (t,) if not isinstance(t, tuple) else t
def divisible_by(num, den):
return (num % den) == 0
@@ -267,6 +271,116 @@ class SymExpTwoHot(Module):
return inverse_pack(encoded, '* l')
# action related
ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
class ActionEmbedder(Module):
def __init__(
self,
dim,
*,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
):
super().__init__()
# handle discrete actions
num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
total_discrete_actions = num_discrete_actions.sum().item()
self.num_discrete_action_types = len(num_discrete_actions)
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
# continuous actions
self.num_continuous_action_types = num_continuous_actions
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
# defaults
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
# calculate offsets
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
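# e.g. num_discrete_actions = (4, 6, 2) -> cumsum (4, 10, 12) -> offsets (0, 4, 10),
# so each discrete action type indexes its own disjoint slice of the shared embedding table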
@property
def device(self):
return self.discrete_action_offsets.device
@property
def has_actions(self):
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
def forward(
self,
*,
discrete_actions = None, # (... na)
continuous_actions = None, # (... na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
return_sum_pooled_embeds = True
):
discrete_embeds = continuous_embeds = None
if exists(discrete_actions):
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
if exists(discrete_action_types) and not is_tensor(discrete_action_types):
if isinstance(discrete_action_types, int):
discrete_action_types = (discrete_action_types,)
discrete_action_types = tensor(discrete_action_types, device = self.device)
offsets = self.discrete_action_offsets[discrete_action_types]
assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
if exists(continuous_actions):
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
if exists(continuous_action_types) and not is_tensor(continuous_action_types):
if isinstance(continuous_action_types, int):
continuous_action_types = (continuous_action_types,)
continuous_action_types = tensor(continuous_action_types, device = self.device)
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
# continuous embed is just the selected continuous action type with the scale
continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
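# e.g. a continuous action value of 0.5 simply scales that action type's learned embedding by 0.5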
# return not pooled
if not return_sum_pooled_embeds:
return ActionEmbeds(discrete_embeds, continuous_embeds)
# handle sum pooling, which is what they did in the paper for all the actions
pooled = 0.
if exists(discrete_embeds):
pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
if exists(continuous_embeds):
pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
return pooled
# generalized advantage estimate
@torch.no_grad()
@@ -1089,6 +1203,8 @@ class DynamicsWorldModel(Module):
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
add_reward_embed_to_agent_token = False,
add_reward_embed_dropout = 0.1,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
reward_loss_weight = 0.1,
value_head_mlp_depth = 3,
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@@ -1182,6 +1298,14 @@
self.agent_has_genes = num_latent_genes > 0
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
# action embedder
self.action_embedder = ActionEmbedder(
dim = dim,
num_discrete_actions = num_discrete_actions,
num_continuous_actions = num_continuous_actions
)
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1397,14 +1521,18 @@
def forward(
self,
*,
-video = None,
-latents = None, # (b t n d) | (b t d)
-signal_levels = None, # () | (b) | (b t)
-step_sizes = None, # () | (b)
-step_sizes_log2 = None, # () | (b)
-tasks = None, # (b)
-rewards = None, # (b t)
-latent_gene_ids = None, # (b)
+video = None, # (b c t vh vw)
+latents = None, # (b t n d) | (b t d)
+signal_levels = None, # () | (b) | (b t)
+step_sizes = None, # () | (b)
+step_sizes_log2 = None, # () | (b)
+latent_gene_ids = None, # (b)
+tasks = None, # (b)
+rewards = None, # (b t)
+discrete_actions = None, # (b t na)
+continuous_actions = None, # (b t na)
+discrete_action_types = None, # (na)
+continuous_action_types = None, # (na)
return_pred_only = False,
latent_is_noised = False,
return_all_losses = False,
@@ -1533,6 +1661,20 @@
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
# maybe add the action embed to the agent tokens per time step
if exists(discrete_actions) or exists(continuous_actions):
assert self.action_embedder.has_actions
action_embed = self.action_embedder(
discrete_actions = discrete_actions,
discrete_action_types = discrete_action_types,
continuous_actions = continuous_actions,
continuous_action_types = continuous_action_types
)
agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
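# the pooled action embedding (b t d) is broadcast by einx across any remaining agent token dims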
# maybe add a reward embedding to agent tokens
if exists(rewards):

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.7"
version = "0.0.8"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -55,6 +55,7 @@ def test_e2e(
depth = 4,
num_spatial_tokens = num_spatial_tokens,
pred_orig_latent = pred_orig_latent,
num_discrete_actions = 4,
attn_dim_head = 16,
attn_kwargs = dict(
heads = heads,
@@ -79,11 +80,14 @@
if add_task_embeds:
tasks = torch.randint(0, 4, (2,))
+actions = torch.randint(0, 4, (2, 4, 1))
flow_loss = dynamics(
**dynamics_input,
tasks = tasks,
signal_levels = signal_levels,
-step_sizes_log2 = step_sizes_log2
+step_sizes_log2 = step_sizes_log2,
+discrete_actions = actions
)
assert flow_loss.numel() == 1
@@ -173,3 +177,72 @@
out = attend(q, k, v)
assert torch.allclose(flex_out, out, atol = 1e-6)
def test_action_embedder():
from dreamer4.dreamer4 import ActionEmbedder
# 1 discrete action with 4 choices
embedder = ActionEmbedder(
512,
num_discrete_actions = 4
)
actions = torch.randint(0, 4, (2, 3, 1))
action_embed = embedder(discrete_actions = actions)
assert action_embed.shape == (2, 3, 512)
# 2 discrete actions with 4 choices each
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4)
)
actions = torch.randint(0, 4, (2, 3, 2))
action_embed = embedder(discrete_actions = actions)
assert action_embed.shape == (2, 3, 512)
# picking out only the second discrete action
actions = torch.randint(0, 4, (2, 3, 1))
action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)
assert action_embed.shape == (2, 3, 512)
# 2 continuous actions
embedder = ActionEmbedder(
512,
num_continuous_actions = 2
)
actions = torch.randn((2, 3, 2))
action_embed = embedder(continuous_actions = actions)
assert action_embed.shape == (2, 3, 512)
# 2 discrete actions with 4 choices each and 2 continuous actions
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4),
num_continuous_actions = 2
)
discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
assert action_embed.shape == (2, 3, 512)
# picking out one discrete and one continuous
discrete_actions = torch.randint(0, 4, (2, 3, 1))
continuous_actions = torch.randn(2, 3, 1)
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
assert action_embed.shape == (2, 3, 512)
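
The un-pooled path can also be exercised directly; a small sketch (not part of the committed tests) of return_sum_pooled_embeds = False returning the per-action ActionEmbeds namedtuple:

import torch
from dreamer4.dreamer4 import ActionEmbedder

embedder = ActionEmbedder(
    512,
    num_discrete_actions = (4, 6),   # two discrete heads with different bin counts
    num_continuous_actions = 2
)

discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)

embeds = embedder(
    discrete_actions = discrete_actions,
    continuous_actions = continuous_actions,
    return_sum_pooled_embeds = False
)

# one embedding per individual action, before sum pooling
assert embeds.discrete.shape == (2, 3, 2, 512)
assert embeds.continuous.shape == (2, 3, 2, 512)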