prepare unembedding parameters in ActionEmbedder as well as the policy head, to allow for behavioral cloning before RL

lucidrains 2025-10-10 10:41:48 -07:00
parent 9101a49cdd
commit 32aa355e37
3 changed files with 78 additions and 2 deletions


@@ -282,7 +282,9 @@ class ActionEmbedder(Module):
         *,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
-        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
+        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
+        can_unembed = False,
+        unembed_dim = None
     ):
         super().__init__()
@@ -294,6 +296,8 @@ class ActionEmbedder(Module):
         self.num_discrete_action_types = len(num_discrete_actions)
         self.discrete_action_embed = Embedding(total_discrete_actions, dim)

+        self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
+
         # continuous actions

         self.num_continuous_action_types = num_continuous_actions
@@ -314,6 +318,15 @@ class ActionEmbedder(Module):
         offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
         self.register_buffer('discrete_action_offsets', offsets, persistent = False)

+        # unembedding
+
+        self.can_unembed = can_unembed
+
+        if can_unembed:
+            unembed_dim = default(unembed_dim, dim)
+            self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+            self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+
     @property
     def device(self):
         return self.discrete_action_offsets.device
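
For concreteness, a quick shape sketch of the two unembedding tables added above, assuming the configuration used in the test at the bottom of this commit (num_discrete_actions = (4, 4), num_continuous_actions = 2, dim = 512, and the default unembed_dim = dim); the 1e-2 scale presumably keeps initial logits and predicted statistics near zero:

import torch

total_discrete_actions = 4 + 4                                        # all discrete action types share one concatenated table
discrete_unembed = torch.randn(total_discrete_actions, 512) * 1e-2    # (8, 512) -> one row per discrete action id
continuous_unembed = torch.randn(2, 512, 2) * 1e-2                    # (2, 512, 2) -> trailing 2 for (mean, log variance)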
@@ -322,6 +335,40 @@ class ActionEmbedder(Module):
     @property
     def has_actions(self):
         return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0

+    def unembed(
+        self,
+        embeds,                         # (... d)
+        discrete_action_types = None,   # (na)
+        continuous_action_types = None, # (na)
+    ): # (... discrete_na), (... continuous_na 2)
+
+        assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
+        assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
+
+        # discrete actions
+
+        discrete_action_logits = None
+
+        if self.num_discrete_action_types > 0:
+            discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
+
+        # continuous actions
+
+        continuous_action_mean_log_var = None
+
+        if self.num_continuous_action_types > 0:
+            continuous_action_unembed = self.continuous_action_unembed
+
+            if exists(continuous_action_types):
+                continuous_action_unembed = continuous_action_unembed[continuous_action_types]
+
+            continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
+
+        return discrete_action_logits, continuous_action_mean_log_var
+
     def forward(
         self,
         *,
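
The method above returns the discrete logits flattened across all action types, and stacks the two continuous statistics in a trailing dimension. A hedged unpacking sketch, not from the commit: the import path, the per-type split sizes, and the std = exp(0.5 * log_var) reading of the trailing pair are all assumptions based on the variable naming:

import torch
from dreamer4 import ActionEmbedder                         # assumed import path

embedder = ActionEmbedder(512, num_discrete_actions = (4, 4), num_continuous_actions = 2, can_unembed = True)

embeds = torch.randn(2, 3, 512)                             # (... d)
discrete_logits, continuous_mean_log_var = embedder.unembed(embeds)

per_type_logits = discrete_logits.split((4, 4), dim = -1)   # one (2, 3, 4) logit tensor per discrete action type
mean, log_var = continuous_mean_log_var.unbind(dim = -1)    # (2, 3, 2) each
std = (0.5 * log_var).exp()                                 # gaussian std, assuming the trailing pair is (mean, log variance)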
@@ -1220,6 +1267,7 @@ class DynamicsWorldModel(Module):
         continuous_norm_stats = None,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
+        policy_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
     ):
         super().__init__()
@@ -1338,6 +1386,15 @@ class DynamicsWorldModel(Module):
         self.reward_loss_weight = reward_loss_weight

+        # policy head
+
+        self.policy_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = dim,
+            depth = policy_head_mlp_depth
+        )
+
         # value head

         self.value_head = create_mlp(

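Putting the two pieces together, a rough behavioral-cloning sketch under stated assumptions: the nn.Sequential stand-in for create_mlp, the wiring of the policy head output into ActionEmbedder.unembed, and the expert tensors are all guesses at the intended usage, since this commit only prepares the parameters:

import torch
import torch.nn.functional as F
from torch import nn
from dreamer4 import ActionEmbedder                         # assumed import path

embedder = ActionEmbedder(512, num_discrete_actions = (4, 4), num_continuous_actions = 2, can_unembed = True)
policy_head = nn.Sequential(nn.Linear(512, 2048), nn.SiLU(), nn.Linear(2048, 512))  # stand-in for create_mlp

hidden = torch.randn(2, 3, 512)                             # placeholder world model embeddings (batch, time, dim)
logits, mean_log_var = embedder.unembed(policy_head(hidden))

expert_discrete = torch.randint(0, 4, (2, 3, 2))            # fabricated expert demonstrations, shapes only
expert_continuous = torch.randn(2, 3, 2)

discrete_loss = sum(
    F.cross_entropy(l.movedim(-1, 1), a)                    # (batch, classes, time) logits vs (batch, time) targets
    for l, a in zip(logits.split((4, 4), dim = -1), expert_discrete.unbind(dim = -1))
)

mean, log_var = mean_log_var.unbind(dim = -1)
continuous_loss = F.gaussian_nll_loss(mean, expert_continuous, log_var.exp())

bc_loss = discrete_loss + continuous_loss                   # supervised objective before any RL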

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.8"
+version = "0.0.9"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -251,3 +251,22 @@ def test_action_embedder():
     action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)

     assert action_embed.shape == (2, 3, 512)
+
+    # unembed
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2,
+        can_unembed = True
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
+
+    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)
+
+    assert discrete_logits.shape == (2, 3, 8)
+    assert continuous_mean_log_var.shape == (2, 3, 2, 2)
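
A possible continuation of this test (not part of the commit) that turns the unembedded outputs into sampled actions; the per-type logit split and the gaussian parameterization are the same assumptions as in the sketches above:

import torch
from torch.distributions import Categorical, Normal

per_type_logits = discrete_logits.split((4, 4), dim = -1)
sampled_discrete = torch.stack([Categorical(logits = l).sample() for l in per_type_logits], dim = -1)

mean, log_var = continuous_mean_log_var.unbind(dim = -1)
sampled_continuous = Normal(mean, (0.5 * log_var).exp()).sample()

assert sampled_discrete.shape == (2, 3, 2)
assert sampled_continuous.shape == (2, 3, 2)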