prepare unembedding parameters in ActionEmbedder as well as the policy head, to allow for behavioral cloning before RL
This commit is contained in:
parent
9101a49cdd
commit
32aa355e37
@ -282,7 +282,9 @@ class ActionEmbedder(Module):
|
||||
*,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
|
||||
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
|
||||
can_unembed = False,
|
||||
unembed_dim = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -294,6 +296,8 @@ class ActionEmbedder(Module):
|
||||
self.num_discrete_action_types = len(num_discrete_actions)
|
||||
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
|
||||
|
||||
self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
|
||||
|
||||
# continuous actions
|
||||
|
||||
self.num_continuous_action_types = num_continuous_actions
|
||||
@ -314,6 +318,15 @@ class ActionEmbedder(Module):
|
||||
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
||||
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
|
||||
|
||||
# unembedding
|
||||
|
||||
self.can_unembed = can_unembed
|
||||
|
||||
if can_unembed:
|
||||
unembed_dim = default(unembed_dim, dim)
|
||||
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
|
||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.discrete_action_offsets.device
|
||||
@ -322,6 +335,40 @@ class ActionEmbedder(Module):
|
||||
def has_actions(self):
|
||||
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
||||
|
||||
def unembed(
|
||||
self,
|
||||
embeds, # (... d)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
|
||||
): # (... discrete_na), (... continuous_na 2)
|
||||
|
||||
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
||||
|
||||
assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
|
||||
|
||||
# discrete actions
|
||||
|
||||
discrete_action_logits = None
|
||||
|
||||
if self.num_discrete_action_types > 0:
|
||||
discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
|
||||
|
||||
# continuous actions
|
||||
|
||||
continuous_action_mean_log_var = None
|
||||
|
||||
if self.num_continuous_action_types > 0:
|
||||
|
||||
continuous_action_unembed = self.continuous_action_unembed
|
||||
|
||||
if exists(continuous_action_types):
|
||||
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
|
||||
|
||||
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
|
||||
|
||||
return discrete_action_logits, continuous_action_mean_log_var
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
@ -1220,6 +1267,7 @@ class DynamicsWorldModel(Module):
|
||||
continuous_norm_stats = None,
|
||||
reward_loss_weight = 0.1,
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
):
|
||||
super().__init__()
|
||||
@ -1338,6 +1386,15 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
self.reward_loss_weight = reward_loss_weight
|
||||
|
||||
# policy head
|
||||
|
||||
self.policy_head = create_mlp(
|
||||
dim_in = dim,
|
||||
dim = dim * 4,
|
||||
dim_out = dim,
|
||||
depth = policy_head_mlp_depth
|
||||
)
|
||||
|
||||
# value head
|
||||
|
||||
self.value_head = create_mlp(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.8"
|
||||
version = "0.0.9"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -251,3 +251,22 @@ def test_action_embedder():
|
||||
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
|
||||
|
||||
assert action_embed.shape == (2, 3, 512)
|
||||
|
||||
# unembed
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_discrete_actions = (4, 4),
|
||||
num_continuous_actions = 2,
|
||||
can_unembed = True
|
||||
)
|
||||
|
||||
discrete_actions = torch.randint(0, 4, (2, 3, 2))
|
||||
continuous_actions = torch.randn(2, 3, 2)
|
||||
|
||||
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
|
||||
|
||||
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)
|
||||
|
||||
assert discrete_logits.shape == (2, 3, 8)
|
||||
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user