diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 5c5cc2e..a945674 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -322,10 +322,25 @@ class ActionEmbedder(Module): self.can_unembed = can_unembed - if can_unembed: - unembed_dim = default(unembed_dim, dim) - self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2) - self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2) + if not can_unembed: + return + + unembed_dim = default(unembed_dim, dim) + self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2) + + discrete_action_index = arange(total_discrete_actions) + + padded_num_discrete_actions = F.pad(num_discrete_actions, (1, 0), value = 0) + exclusive_cumsum = padded_num_discrete_actions.cumsum(dim = -1) + + discrete_action_mask = ( + einx.greater_equal('j, i -> i j', discrete_action_index, exclusive_cumsum[:-1]) & + einx.less('j, i -> i j', discrete_action_index, exclusive_cumsum[1:]) + ) + + self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False) + + self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2) @property def device(self): @@ -335,6 +350,18 @@ class ActionEmbedder(Module): def has_actions(self): return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0 + def cast_action_types( + self, + action_types = None + ): + if exists(action_types) and not is_tensor(action_types): + if isinstance(action_types, int): + action_types = (action_types,) + + action_types = tensor(action_types, device = self.device) + + return action_types + def unembed( self, embeds, # (... d) @@ -345,14 +372,22 @@ class ActionEmbedder(Module): assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init' - assert not exists(discrete_action_types), 'selecting subset of discrete action types to unembed not completed yet' + discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types)) # discrete actions discrete_action_logits = None if self.num_discrete_action_types > 0: - discrete_action_logits = einsum(embeds, self.discrete_action_unembed, '... d, na d -> ... na') + + discrete_action_unembed = self.discrete_action_unembed + + if exists(discrete_action_types): + discrete_action_mask = self.discrete_action_mask[discrete_action_types].any(dim = 0) + + discrete_action_unembed = discrete_action_unembed[discrete_action_mask] + + discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na') # continuous actions @@ -385,11 +420,7 @@ class ActionEmbedder(Module): discrete_action_types = default(discrete_action_types, self.default_discrete_action_types) - if exists(discrete_action_types) and not is_tensor(discrete_action_types): - if isinstance(discrete_action_types, int): - discrete_action_types = (discrete_action_types,) - - discrete_action_types = tensor(discrete_action_types, device = self.device) + discrete_action_types = self.cast_action_types(discrete_action_types) offsets = self.discrete_action_offsets[discrete_action_types] @@ -403,11 +434,7 @@ class ActionEmbedder(Module): if exists(continuous_actions): continuous_action_types = default(continuous_action_types, self.default_continuous_action_types) - if exists(continuous_action_types) and not is_tensor(continuous_action_types): - if isinstance(continuous_action_types, int): - continuous_action_types = (continuous_action_types,) - - continuous_action_types = tensor(continuous_action_types, device = self.device) + continuous_action_types = self.cast_action_types(continuous_action_types) assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions' diff --git a/pyproject.toml b/pyproject.toml index 1ebca49..f8bb675 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.9" +version = "0.0.10" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index bfd79eb..65f3d7e 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -270,3 +270,10 @@ def test_action_embedder(): assert discrete_logits.shape == (2, 3, 8) assert continuous_mean_log_var.shape == (2, 3, 2, 2) + + # unembed subset of actions + + discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0) + + assert discrete_logits.shape == (2, 3, 4) + assert continuous_mean_log_var.shape == (2, 3, 1, 2)