last commit for the day
This commit is contained in:
parent
9230267d34
commit
5df3e69583
@ -367,6 +367,7 @@ class ActionEmbedder(Module):
|
|||||||
embeds, # (... d)
|
embeds, # (... d)
|
||||||
discrete_action_types = None, # (na)
|
discrete_action_types = None, # (na)
|
||||||
continuous_action_types = None, # (na)
|
continuous_action_types = None, # (na)
|
||||||
|
return_split_discrete = False
|
||||||
|
|
||||||
): # (... discrete_na), (... continuous_na 2)
|
): # (... discrete_na), (... continuous_na 2)
|
||||||
|
|
||||||
@ -389,6 +390,14 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
|
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
|
||||||
|
|
||||||
|
# whether to split the discrete action logits by the number of actions per action type
|
||||||
|
|
||||||
|
if exists(discrete_action_logits) and return_split_discrete:
|
||||||
|
|
||||||
|
split_sizes = self.num_discrete_actions[discrete_action_types] if exists(discrete_action_types) else self.num_discrete_actions
|
||||||
|
|
||||||
|
discrete_action_logits = discrete_action_logits.split(split_sizes.tolist(), dim = -1)
|
||||||
|
|
||||||
# continuous actions
|
# continuous actions
|
||||||
|
|
||||||
continuous_action_mean_log_var = None
|
continuous_action_mean_log_var = None
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.10"
|
version = "0.0.11"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -271,6 +271,11 @@ def test_action_embedder():
|
|||||||
assert discrete_logits.shape == (2, 3, 8)
|
assert discrete_logits.shape == (2, 3, 8)
|
||||||
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
|
||||||
|
|
||||||
|
# return discrete split by number of actions
|
||||||
|
|
||||||
|
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
|
||||||
|
assert discrete_logits[0].shape == discrete_logits[1].shape == (2, 3, 4)
|
||||||
|
|
||||||
# unembed subset of actions
|
# unembed subset of actions
|
||||||
|
|
||||||
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)
|
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user