handle subset of discrete action unembedding
This commit is contained in:
parent
c68942b026
commit
9230267d34
@ -322,10 +322,25 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
self.can_unembed = can_unembed
|
self.can_unembed = can_unembed
|
||||||
|
|
||||||
if can_unembed:
|
if not can_unembed:
|
||||||
unembed_dim = default(unembed_dim, dim)
|
return
|
||||||
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
|
|
||||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
unembed_dim = default(unembed_dim, dim)
|
||||||
|
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
|
||||||
|
|
||||||
|
discrete_action_index = arange(total_discrete_actions)
|
||||||
|
|
||||||
|
padded_num_discrete_actions = F.pad(num_discrete_actions, (1, 0), value = 0)
|
||||||
|
exclusive_cumsum = padded_num_discrete_actions.cumsum(dim = -1)
|
||||||
|
|
||||||
|
discrete_action_mask = (
|
||||||
|
einx.greater_equal('j, i -> i j', discrete_action_index, exclusive_cumsum[:-1]) &
|
||||||
|
einx.less('j, i -> i j', discrete_action_index, exclusive_cumsum[1:])
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
|
||||||
|
|
||||||
|
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
@ -335,6 +350,18 @@ class ActionEmbedder(Module):
|
|||||||
def has_actions(self):
|
def has_actions(self):
|
||||||
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
||||||
|
|
||||||
|
def cast_action_types(
    self,
    action_types = None
):
    """Normalize an action-type selector into a tensor on this module's device.

    Args:
        action_types: a single int, a sequence of ints, a tensor, or nothing.

    Returns:
        `action_types` untouched when `exists` reports it as absent
        (presumably a `None` check - `exists` is defined elsewhere in the
        file) or when it is already a tensor; otherwise a tensor built from
        it and placed on `self.device`.
    """
    # nothing to convert - selector is absent or already a tensor
    if not exists(action_types) or is_tensor(action_types):
        return action_types

    # a lone int is wrapped in a tuple so the resulting tensor is 1-d, not a scalar
    if isinstance(action_types, int):
        action_types = (action_types,)

    return tensor(action_types, device = self.device)
|
||||||
|
|
||||||
def unembed(
|
def unembed(
|
||||||
self,
|
self,
|
||||||
embeds, # (... d)
|
embeds, # (... d)
|
||||||
@ -345,14 +372,22 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
||||||
|
|
||||||
assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet'
|
discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
|
||||||
|
|
||||||
# discrete actions
|
# discrete actions
|
||||||
|
|
||||||
discrete_action_logits = None
|
discrete_action_logits = None
|
||||||
|
|
||||||
if self.num_discrete_action_types > 0:
|
if self.num_discrete_action_types > 0:
|
||||||
discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na')
|
|
||||||
|
discrete_action_unembed = self.discrete_action_unembed
|
||||||
|
|
||||||
|
if exists(discrete_action_types):
|
||||||
|
discrete_action_mask = self.discrete_action_mask[discrete_action_types].any(dim = 0)
|
||||||
|
|
||||||
|
discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
|
||||||
|
|
||||||
|
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
|
||||||
|
|
||||||
# continuous actions
|
# continuous actions
|
||||||
|
|
||||||
@ -385,11 +420,7 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
||||||
|
|
||||||
if exists(discrete_action_types) and not is_tensor(discrete_action_types):
|
discrete_action_types = self.cast_action_types(discrete_action_types)
|
||||||
if isinstance(discrete_action_types, int):
|
|
||||||
discrete_action_types = (discrete_action_types,)
|
|
||||||
|
|
||||||
discrete_action_types = tensor(discrete_action_types, device = self.device)
|
|
||||||
|
|
||||||
offsets = self.discrete_action_offsets[discrete_action_types]
|
offsets = self.discrete_action_offsets[discrete_action_types]
|
||||||
|
|
||||||
@ -403,11 +434,7 @@ class ActionEmbedder(Module):
|
|||||||
if exists(continuous_actions):
|
if exists(continuous_actions):
|
||||||
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
||||||
|
|
||||||
if exists(continuous_action_types) and not is_tensor(continuous_action_types):
|
continuous_action_types = self.cast_action_types(continuous_action_types)
|
||||||
if isinstance(continuous_action_types, int):
|
|
||||||
continuous_action_types = (continuous_action_types,)
|
|
||||||
|
|
||||||
continuous_action_types = tensor(continuous_action_types, device = self.device)
|
|
||||||
|
|
||||||
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.9"
|
version = "0.0.10"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -270,3 +270,10 @@ def test_action_embedder():
|
|||||||
|
|
||||||
assert discrete_logits.shape == (2, 3, 8)
|
assert discrete_logits.shape == (2, 3, 8)
|
||||||
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
||||||
|
|
||||||
|
# unembed subset of actions
|
||||||
|
|
||||||
|
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)
|
||||||
|
|
||||||
|
assert discrete_logits.shape == (2, 3, 4)
|
||||||
|
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user