sum the kl div loss across the number of actions by default for ActionEmbedder.kl_div
commit 586379f2c8 (parent a358a44a53)
@@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

+MaybeTensor = Tensor | None
+
 @dataclass
 class Experience:
     latents: Tensor
-    video: Tensor | None = None
-    proprio: Tensor | None = None
-    agent_embed: Tensor | None = None
+    video: MaybeTensor = None
+    proprio: MaybeTensor = None
+    agent_embed: MaybeTensor = None
     rewards: Tensor | None = None
-    actions: tuple[Tensor, Tensor] | None = None
-    log_probs: tuple[Tensor, Tensor] | None = None
-    old_action_unembeds: tuple[Tensor, Tensor] | None = None
-    values: Tensor | None = None
+    actions: tuple[MaybeTensor, MaybeTensor] | None = None
+    log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
+    old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
+    values: MaybeTensor = None
     step_size: int | None = None
-    lens: Tensor | None = None
-    is_truncated: Tensor | None = None
+    lens: MaybeTensor = None
+    is_truncated: MaybeTensor = None
     agent_index: int = 0
     is_from_world_model: bool = True
@@ -850,9 +852,10 @@ class ActionEmbedder(Module):

     def kl_div(
         self,
-        src: tuple[Tensor | None, Tensor | None],
-        tgt: tuple[Tensor | None, Tensor | None]
-    ) -> tuple[Tensor | None, Tensor | None]:
+        src: tuple[MaybeTensor, MaybeTensor],
+        tgt: tuple[MaybeTensor, MaybeTensor],
+        reduce_across_num_actions = True
+    ) -> tuple[MaybeTensor, MaybeTensor]:

         src_discrete, src_continuous = src
         tgt_discrete, tgt_continuous = tgt

@@ -894,6 +897,15 @@ class ActionEmbedder(Module):

         continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)

+        # maybe reduce
+
+        if reduce_across_num_actions:
+            if exists(discrete_kl_div):
+                discrete_kl_div = discrete_kl_div.sum(dim = -1)
+
+            if exists(continuous_kl_div):
+                continuous_kl_div = continuous_kl_div.sum(dim = -1)
+
         return discrete_kl_div, continuous_kl_div

     def forward(
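In effect, ActionEmbedder.kl_div now sums its per-action KL terms over the trailing action dimension unless the caller opts out. A minimal usage sketch, not taken from the repository: the embedder, src and tgt names and the (batch, time, num_actions) shape reading are assumptions based on the signature above and the test further below.

    # default: KL divergences summed across the number of actions
    discrete_kl, continuous_kl = embedder.kl_div(src, tgt)
    # assumed shape: (batch, time)

    # opt out to keep the per-action values
    discrete_kl, continuous_kl = embedder.kl_div(src, tgt, reduce_across_num_actions = False)
    # assumed shape: (batch, time, num_actions)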
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.94"
+version = "0.0.95"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -352,8 +352,8 @@ def test_action_embedder():

     discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))

-    assert discrete_kl_div.shape == (2, 3, 2)
-    assert continuous_kl_div.shape == (2, 3, 2)
+    assert discrete_kl_div.shape == (2, 3)
+    assert continuous_kl_div.shape == (2, 3)

     # return discrete split by number of actions
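Callers that relied on the previous per-action output should still be able to recover it by disabling the reduction. A hedged sketch mirroring the test above, reusing its inputs:

    # sketch only: same inputs as the test above, reduction turned off
    discrete_kl_div, continuous_kl_div = embedder.kl_div(
        (discrete_logits, continuous_mean_log_var),
        (discrete_logits_tgt, continuous_mean_log_var_tgt),
        reduce_across_num_actions = False
    )

    # without the sum over the action dimension, the old shapes should come back
    assert discrete_kl_div.shape == (2, 3, 2)
    assert continuous_kl_div.shape == (2, 3, 2)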