add a complete action embedder that can accept any number of discrete actions with variable bins as well as any number of continuous actions, pooled and added to the agent token as described in the paper (seems like they fixed that horrendous hack in dreamer v3 with sticky action)
This commit is contained in:
parent
b62c08be65
commit
e2d86a4543
@ -8,7 +8,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
|
||||
import torchvision
|
||||
@ -29,6 +29,7 @@ from assoc_scan import AssocScan
|
||||
# l - logit / predicted bins
|
||||
# p - positions (3 for spacetime in this work)
|
||||
# t - time
|
||||
# na - action dimension (number of discrete and continuous actions)
|
||||
# g - groups of query heads to key heads (gqa)
|
||||
# vc - video channels
|
||||
# vh, vw - video height and width
|
||||
@ -66,6 +67,9 @@ def default(v, d):
|
||||
def first(arr):
|
||||
return arr[0]
|
||||
|
||||
def ensure_tuple(t):
|
||||
return (t,) if not isinstance(t, tuple) else t
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
@ -267,6 +271,116 @@ class SymExpTwoHot(Module):
|
||||
|
||||
return inverse_pack(encoded, '* l')
|
||||
|
||||
# action related
|
||||
|
||||
ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
|
||||
|
||||
class ActionEmbedder(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# handle discrete actions
|
||||
|
||||
num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
|
||||
total_discrete_actions = num_discrete_actions.sum().item()
|
||||
|
||||
self.num_discrete_action_types = len(num_discrete_actions)
|
||||
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
|
||||
|
||||
# continuous actions
|
||||
|
||||
self.num_continuous_action_types = num_continuous_actions
|
||||
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
|
||||
|
||||
# defaults
|
||||
|
||||
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
|
||||
self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
|
||||
|
||||
# calculate offsets
|
||||
|
||||
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
||||
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.discrete_action_offsets.device
|
||||
|
||||
@property
|
||||
def has_actions(self):
|
||||
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
discrete_actions = None, # (... na)
|
||||
continuous_actions = None, # (... na)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
return_sum_pooled_embeds = True
|
||||
):
|
||||
|
||||
discrete_embeds = continuous_embeds = None
|
||||
|
||||
if exists(discrete_actions):
|
||||
|
||||
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
||||
|
||||
if exists(discrete_action_types) and not is_tensor(discrete_action_types):
|
||||
if isinstance(discrete_action_types, int):
|
||||
discrete_action_types = (discrete_action_types,)
|
||||
|
||||
discrete_action_types = tensor(discrete_action_types, device = self.device)
|
||||
|
||||
offsets = self.discrete_action_offsets[discrete_action_types]
|
||||
|
||||
assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
|
||||
|
||||
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
|
||||
|
||||
discrete_actions_offsetted = einx.add('... na, na', discrete_actions, offsets)
|
||||
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
|
||||
|
||||
if exists(continuous_actions):
|
||||
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
||||
|
||||
if exists(continuous_action_types) and not is_tensor(continuous_action_types):
|
||||
if isinstance(continuous_action_types, int):
|
||||
continuous_action_types = (continuous_action_types,)
|
||||
|
||||
continuous_action_types = tensor(continuous_action_types, device = self.device)
|
||||
|
||||
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
||||
|
||||
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
|
||||
|
||||
# continuous embed is just the selected continuous action type with the scale
|
||||
|
||||
continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
|
||||
|
||||
# return not pooled
|
||||
|
||||
if not return_sum_pooled_embeds:
|
||||
return ActionEmbeds(discrete_embeds, continuous_embeds)
|
||||
|
||||
# handle sum pooling, which is what they did in the paper for all the actions
|
||||
|
||||
pooled = 0.
|
||||
|
||||
if exists(discrete_embeds):
|
||||
pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
|
||||
|
||||
if exists(continuous_embeds):
|
||||
pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
|
||||
|
||||
return pooled
|
||||
|
||||
# generalized advantage estimate
|
||||
|
||||
@torch.no_grad()
|
||||
@ -1089,6 +1203,8 @@ class DynamicsWorldModel(Module):
|
||||
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
||||
add_reward_embed_to_agent_token = False,
|
||||
add_reward_embed_dropout = 0.1,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
reward_loss_weight = 0.1,
|
||||
value_head_mlp_depth = 3,
|
||||
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
@ -1182,6 +1298,14 @@ class DynamicsWorldModel(Module):
|
||||
self.agent_has_genes = num_latent_genes > 0
|
||||
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
||||
|
||||
# action embedder
|
||||
|
||||
self.action_embedder = ActionEmbedder(
|
||||
dim = dim,
|
||||
num_discrete_actions = num_discrete_actions,
|
||||
num_continuous_actions = num_continuous_actions
|
||||
)
|
||||
|
||||
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
||||
|
||||
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
|
||||
@ -1397,14 +1521,18 @@ class DynamicsWorldModel(Module):
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
video = None,
|
||||
latents = None, # (b t n d) | (b t d)
|
||||
signal_levels = None, # () | (b) | (b t)
|
||||
step_sizes = None, # () | (b)
|
||||
step_sizes_log2 = None, # () | (b)
|
||||
tasks = None, # (b)
|
||||
rewards = None, # (b t)
|
||||
latent_gene_ids = None, # (b)
|
||||
video = None, # (b c t vh vw)
|
||||
latents = None, # (b t n d) | (b t d)
|
||||
signal_levels = None, # () | (b) | (b t)
|
||||
step_sizes = None, # () | (b)
|
||||
step_sizes_log2 = None, # () | (b)
|
||||
latent_gene_ids = None, # (b)
|
||||
tasks = None, # (b)
|
||||
rewards = None, # (b t)
|
||||
discrete_actions = None, # (b t na)
|
||||
continuous_actions = None, # (b t na)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False,
|
||||
@ -1533,6 +1661,20 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
||||
|
||||
# maybe add the action embed to the agent tokens per time step
|
||||
|
||||
if exists(discrete_actions) or exists(continuous_actions):
|
||||
assert self.action_embedder.has_actions
|
||||
|
||||
action_embed = self.action_embedder(
|
||||
discrete_actions = discrete_actions,
|
||||
discrete_action_types = discrete_action_types,
|
||||
continuous_actions = continuous_actions,
|
||||
continuous_action_types = continuous_action_types
|
||||
)
|
||||
|
||||
agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
|
||||
|
||||
# maybe add a reward embedding to agent tokens
|
||||
|
||||
if exists(rewards):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.7"
|
||||
version = "0.0.8"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -55,6 +55,7 @@ def test_e2e(
|
||||
depth = 4,
|
||||
num_spatial_tokens = num_spatial_tokens,
|
||||
pred_orig_latent = pred_orig_latent,
|
||||
num_discrete_actions = 4,
|
||||
attn_dim_head = 16,
|
||||
attn_kwargs = dict(
|
||||
heads = heads,
|
||||
@ -79,11 +80,14 @@ def test_e2e(
|
||||
if add_task_embeds:
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
|
||||
actions = torch.randint(0, 4, (2, 4, 1))
|
||||
|
||||
flow_loss = dynamics(
|
||||
**dynamics_input,
|
||||
tasks = tasks,
|
||||
signal_levels = signal_levels,
|
||||
step_sizes_log2 = step_sizes_log2
|
||||
step_sizes_log2 = step_sizes_log2,
|
||||
discrete_actions = actions
|
||||
)
|
||||
|
||||
assert flow_loss.numel() == 1
|
||||
@ -173,3 +177,72 @@ def test_attend_factory(
|
||||
out = attend(q, k, v)
|
||||
|
||||
assert torch.allclose(flex_out, out, atol = 1e-6)
|
||||
|
||||
def test_action_embedder():
|
||||
from dreamer4.dreamer4 import ActionEmbedder
|
||||
|
||||
# 1 discrete action with 4 choices
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_discrete_actions = 4
|
||||
)
|
||||
|
||||
actions = torch.randint(0, 4, (2, 3, 1))
|
||||
action_embed = embedder(discrete_actions = actions)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# 2 discrete actions with 4 choices each
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_discrete_actions = (4, 4)
|
||||
)
|
||||
|
||||
actions = torch.randint(0, 4, (2, 3, 2))
|
||||
action_embed = embedder(discrete_actions = actions)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# picking out only the second discrete action
|
||||
|
||||
actions = torch.randint(0, 4, (2, 3, 1))
|
||||
action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# 2 continuous actions
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_continuous_actions = 2
|
||||
)
|
||||
|
||||
actions = torch.randn((2, 3, 2))
|
||||
action_embed = embedder(continuous_actions = actions)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# 2 discrete actions with 4 choices each and 2 continuous actions
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_discrete_actions = (4, 4),
|
||||
num_continuous_actions = 2
|
||||
)
|
||||
|
||||
discrete_actions = torch.randint(0, 4, (2, 3, 2))
|
||||
continuous_actions = torch.randn(2, 3, 2)
|
||||
|
||||
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# picking out one discrete and one continuous
|
||||
|
||||
discrete_actions = torch.randint(0, 4, (2, 3, 1))
|
||||
continuous_actions = torch.randn(2, 3, 1)
|
||||
|
||||
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user