Sum the KL divergence loss across the number of actions by default in the action embedder's `.kl_div`
This commit is contained in:
parent
a358a44a53
commit
586379f2c8
@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
|||||||
|
|
||||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
||||||
|
|
||||||
|
MaybeTensor = Tensor | None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Experience:
|
class Experience:
|
||||||
latents: Tensor
|
latents: Tensor
|
||||||
video: Tensor | None = None
|
video: MaybeTensor = None
|
||||||
proprio: Tensor | None = None
|
proprio: MaybeTensor = None
|
||||||
agent_embed: Tensor | None = None
|
agent_embed: MaybeTensor = None
|
||||||
rewards: Tensor | None = None
|
rewards: Tensor | None = None
|
||||||
actions: tuple[Tensor, Tensor] | None = None
|
actions: tuple[MaybeTensor, MaybeTensor] | None = None
|
||||||
log_probs: tuple[Tensor, Tensor] | None = None
|
log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
|
||||||
old_action_unembeds: tuple[Tensor, Tensor] | None = None
|
old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
|
||||||
values: Tensor | None = None
|
values: MaybeTensor = None
|
||||||
step_size: int | None = None
|
step_size: int | None = None
|
||||||
lens: Tensor | None = None
|
lens: MaybeTensor = None
|
||||||
is_truncated: Tensor | None = None
|
is_truncated: MaybeTensor = None
|
||||||
agent_index: int = 0
|
agent_index: int = 0
|
||||||
is_from_world_model: bool = True
|
is_from_world_model: bool = True
|
||||||
|
|
||||||
@ -850,9 +852,10 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
def kl_div(
|
def kl_div(
|
||||||
self,
|
self,
|
||||||
src: tuple[Tensor | None, Tensor | None],
|
src: tuple[MaybeTensor, MaybeTensor],
|
||||||
tgt: tuple[Tensor | None, Tensor | None]
|
tgt: tuple[MaybeTensor, MaybeTensor],
|
||||||
) -> tuple[Tensor | None, Tensor | None]:
|
reduce_across_num_actions = True
|
||||||
|
) -> tuple[MaybeTensor, MaybeTensor]:
|
||||||
|
|
||||||
src_discrete, src_continuous = src
|
src_discrete, src_continuous = src
|
||||||
tgt_discrete, tgt_continuous = tgt
|
tgt_discrete, tgt_continuous = tgt
|
||||||
@ -894,6 +897,15 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
|
continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
|
||||||
|
|
||||||
|
# maybe reduce
|
||||||
|
|
||||||
|
if reduce_across_num_actions:
|
||||||
|
if exists(discrete_kl_div):
|
||||||
|
discrete_kl_div = discrete_kl_div.sum(dim = -1)
|
||||||
|
|
||||||
|
if exists(continuous_kl_div):
|
||||||
|
continuous_kl_div = continuous_kl_div.sum(dim = -1)
|
||||||
|
|
||||||
return discrete_kl_div, continuous_kl_div
|
return discrete_kl_div, continuous_kl_div
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.94"
|
version = "0.0.95"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -352,8 +352,8 @@ def test_action_embedder():
|
|||||||
|
|
||||||
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
|
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
|
||||||
|
|
||||||
assert discrete_kl_div.shape == (2, 3, 2)
|
assert discrete_kl_div.shape == (2, 3)
|
||||||
assert continuous_kl_div.shape == (2, 3, 2)
|
assert continuous_kl_div.shape == (2, 3)
|
||||||
|
|
||||||
# return discrete split by number of actions
|
# return discrete split by number of actions
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user