prepare unembedding parameters in ActionEmbedder as well as the policy head, to allow for behavioral cloning before RL
parent 9101a49cdd
commit 32aa355e37
@@ -282,7 +282,9 @@ class ActionEmbedder(Module):
         *,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
-        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
+        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
+        can_unembed = False,
+        unembed_dim = None
     ):
         super().__init__()
 
@@ -294,6 +296,8 @@ class ActionEmbedder(Module):
         self.num_discrete_action_types = len(num_discrete_actions)
         self.discrete_action_embed = Embedding(total_discrete_actions, dim)
 
+        self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
+
         # continuous actions
 
         self.num_continuous_action_types = num_continuous_actions
@@ -314,6 +318,15 @@ class ActionEmbedder(Module):
         offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
         self.register_buffer('discrete_action_offsets', offsets, persistent = False)
 
+        # unembedding
+
+        self.can_unembed = can_unembed
+
+        if can_unembed:
+            unembed_dim = default(unembed_dim, dim)
+            self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+            self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+
     @property
     def device(self):
         return self.discrete_action_offsets.device
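As a shape check for readers of the diff (not part of the commit, just an illustration using the configuration from the new test further down), the two unembedding parameters come out as follows:

import torch
from torch.nn import Parameter

dim = 512
total_discrete_actions = 4 + 4      # flattened across the two discrete action types
num_continuous_actions = 2
unembed_dim = dim                   # default(unembed_dim, dim) when unembed_dim is None

# small init (* 1e-2) keeps the initial logits / mean-log-var predictions near zero
discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)

print(discrete_action_unembed.shape)    # (8, 512)   -> one row per discrete action, all action types concatenated
print(continuous_action_unembed.shape)  # (2, 512, 2) -> trailing 2 is (mean, log variance) per continuous action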
@@ -322,6 +335,40 @@ class ActionEmbedder(Module):
     def has_actions(self):
         return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
 
+    def unembed(
+        self,
+        embeds,                         # (... d)
+        discrete_action_types = None,   # (na)
+        continuous_action_types = None, # (na)
+
+    ): # (... discrete_na), (... continuous_na 2)
+
+        assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
+
+        assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
+
+        # discrete actions
+
+        discrete_action_logits = None
+
+        if self.num_discrete_action_types > 0:
+            discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
+
+        # continuous actions
+
+        continuous_action_mean_log_var = None
+
+        if self.num_continuous_action_types > 0:
+
+            continuous_action_unembed = self.continuous_action_unembed
+
+            if exists(continuous_action_types):
+                continuous_action_unembed = continuous_action_unembed[continuous_action_types]
+
+            continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
+
+        return discrete_action_logits, continuous_action_mean_log_var
+
     def forward(
         self,
         *,
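The commit message frames this as preparation for behavioral cloning. A minimal sketch of how the unembed outputs could feed a BC loss, assuming targets are per-type discrete indices and normalized continuous values (the targets and the per-type slicing of the flattened discrete logits via the offsets are assumptions here; unembed itself does not yet support selecting discrete action types):

import torch
import torch.nn.functional as F

# outputs from embedder.unembed(embeds), shapes as in the test below
discrete_logits = torch.randn(2, 3, 8)             # (..., total_discrete_actions), two types of 4 actions each
continuous_mean_log_var = torch.randn(2, 3, 2, 2)  # (..., num_continuous_actions, 2)

# hypothetical behavior-cloning targets
target_discrete = torch.randint(0, 4, (2, 3, 2))   # per-type indices
target_continuous = torch.randn(2, 3, 2)

# discrete: slice the flattened logits per action type (offsets are (0, 4) for (4, 4)), then cross entropy
offsets = (0, 4)
discrete_loss = sum(
    F.cross_entropy(
        discrete_logits[..., start:start + 4].reshape(-1, 4),
        target_discrete[..., i].reshape(-1)
    )
    for i, start in enumerate(offsets)
)

# continuous: gaussian negative log likelihood from the predicted mean and log variance
mean, log_var = continuous_mean_log_var.unbind(dim = -1)
continuous_loss = F.gaussian_nll_loss(mean, target_continuous, log_var.exp())

bc_loss = discrete_loss + continuous_loss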
@@ -1220,6 +1267,7 @@ class DynamicsWorldModel(Module):
         continuous_norm_stats = None,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
+        policy_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
     ):
         super().__init__()
@@ -1338,6 +1386,15 @@ class DynamicsWorldModel(Module):
 
         self.reward_loss_weight = reward_loss_weight
 
+        # policy head
+
+        self.policy_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = dim,
+            depth = policy_head_mlp_depth
+        )
+
         # value head
 
         self.value_head = create_mlp(
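The commit only adds the pieces; it does not yet connect the policy head to unembed. A plausible behavioral-cloning forward pass, assuming ActionEmbedder is importable from the package root and using a stand-in MLP for the real create_mlp policy head (import path, activation choice, and wiring below are assumptions, not shown in the diff):

import torch
from torch import nn

from dreamer4 import ActionEmbedder  # assumed export; the diff does not show the package's public API

dim = 512

embedder = ActionEmbedder(dim, num_discrete_actions = (4, 4), num_continuous_actions = 2, can_unembed = True)

# stand-in for create_mlp(dim_in = dim, dim = dim * 4, dim_out = dim, depth = policy_head_mlp_depth)
policy_head = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim))

embeds = torch.randn(2, 3, dim)     # hypothetical per-timestep embeddings from the dynamics world model

policy_embed = policy_head(embeds)  # (2, 3, 512)

discrete_logits, continuous_mean_log_var = embedder.unembed(policy_embed)

# train these against logged actions (cross entropy / gaussian NLL) for behavioral cloning,
# then fine-tune the same heads with RL
assert discrete_logits.shape == (2, 3, 8)
assert continuous_mean_log_var.shape == (2, 3, 2, 2)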
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.8"
+version = "0.0.9"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -251,3 +251,22 @@ def test_action_embedder():
     action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
 
     assert action_embed.shape == (2, 3, 512)
+
+    # unembed
+
+    embedder = ActionEmbedder(
+        512,
+        num_discrete_actions = (4, 4),
+        num_continuous_actions = 2,
+        can_unembed = True
+    )
+
+    discrete_actions = torch.randint(0, 4, (2, 3, 2))
+    continuous_actions = torch.randn(2, 3, 2)
+
+    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
+
+    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)
+
+    assert discrete_logits.shape == (2, 3, 8)
+    assert continuous_mean_log_var.shape == (2, 3, 2, 2)