diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index a945674..34dacc9 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -367,6 +367,7 @@ class ActionEmbedder(Module): embeds, # (... d) discrete_action_types = None, # (na) continuous_action_types = None, # (na) + return_split_discrete = False ): # (... discrete_na), (... continuous_na 2) @@ -389,6 +390,14 @@ class ActionEmbedder(Module): discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na') + # whether to split the discrete action logits by the number of actions per action type + + if exists(discrete_action_logits) and return_split_discrete: + + split_sizes = self.num_discrete_actions[discrete_action_types] if exists(discrete_action_types) else self.num_discrete_actions + + discrete_action_logits = discrete_action_logits.split(split_sizes.tolist(), dim = -1) + # continuous actions continuous_action_mean_log_var = None diff --git a/pyproject.toml b/pyproject.toml index f8bb675..f6129e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.10" +version = "0.0.11" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 65f3d7e..1a45a79 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -271,6 +271,11 @@ def test_action_embedder(): assert discrete_logits.shape == (2, 3, 8) assert continuous_mean_log_var.shape == (2, 3, 2, 2) + # return discrete split by number of actions + + discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True) + assert discrete_logits[0].shape == discrete_logits[1].shape == (2, 3, 4) + # unembed subset of actions discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)