Complete the latent autoregressive state prediction, so that the predicted log variance can be used as a state entropy bonus
This commit is contained in:
parent
125693ce1c
commit
fb6d69f43a
@ -73,12 +73,14 @@ LinearNoBias = partial(Linear, bias = False)
|
|||||||
|
|
||||||
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
|
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
|
||||||
|
|
||||||
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))
|
||||||
|
|
||||||
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
||||||
|
|
||||||
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
|
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
|
||||||
|
|
||||||
|
Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
|
||||||
|
|
||||||
MaybeTensor = Tensor | None
|
MaybeTensor = Tensor | None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -3060,8 +3062,10 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
# maybe proprio
|
# maybe proprio
|
||||||
|
|
||||||
if has_proprio:
|
# maybe proprio
|
||||||
pred, pred_proprio = pred
|
|
||||||
|
pred_proprio = pred.proprioception
|
||||||
|
pred = pred.flow
|
||||||
|
|
||||||
# unpack pred
|
# unpack pred
|
||||||
|
|
||||||
@ -3555,23 +3559,29 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
if self.has_proprio:
|
if self.has_proprio:
|
||||||
pred_proprio = self.to_proprio_pred(proprio_token)
|
pred_proprio = self.to_proprio_pred(proprio_token)
|
||||||
|
else:
|
||||||
pred = (pred, pred_proprio)
|
pred_proprio = None
|
||||||
|
|
||||||
# maybe state pred
|
# maybe state pred
|
||||||
|
|
||||||
if self.should_pred_state:
|
if self.should_pred_state:
|
||||||
pred_state = self.to_state_pred(state_pred_token)
|
pred_state = self.to_state_pred(state_pred_token)
|
||||||
|
else:
|
||||||
|
pred_state = None
|
||||||
|
|
||||||
# returning
|
# returning
|
||||||
|
|
||||||
|
# returning
|
||||||
|
|
||||||
|
predictions = Predictions(pred, pred_proprio, pred_state)
|
||||||
|
|
||||||
if not return_agent_tokens:
|
if not return_agent_tokens:
|
||||||
return pred
|
return predictions
|
||||||
|
|
||||||
if not return_time_cache:
|
if not return_time_cache:
|
||||||
return pred, agent_tokens
|
return predictions, agent_tokens
|
||||||
|
|
||||||
return pred, (agent_tokens, intermediates)
|
return predictions, (agent_tokens, intermediates)
|
||||||
|
|
||||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||||
|
|
||||||
@ -3590,12 +3600,13 @@ class DynamicsWorldModel(Module):
|
|||||||
# pack the predictions to calculate flow for different modalities all at once
|
# pack the predictions to calculate flow for different modalities all at once
|
||||||
|
|
||||||
if self.has_proprio:
|
if self.has_proprio:
|
||||||
pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
|
packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')
|
||||||
|
|
||||||
noised, _ = pack((noised_latents, noised_proprio), 'b t *')
|
noised, _ = pack((noised_latents, noised_proprio), 'b t *')
|
||||||
data, _ = pack((latents, proprio), 'b t *')
|
data, _ = pack((latents, proprio), 'b t *')
|
||||||
noise, _ = pack((noise, proprio_noise), 'b t *')
|
noise, _ = pack((noise, proprio_noise), 'b t *')
|
||||||
else:
|
else:
|
||||||
|
packed_pred = pred.flow
|
||||||
noised = noised_latents
|
noised = noised_latents
|
||||||
data = latents
|
data = latents
|
||||||
|
|
||||||
@ -3614,9 +3625,10 @@ class DynamicsWorldModel(Module):
|
|||||||
pred = fn(noised, noised_proprio, *args, **kwargs)
|
pred = fn(noised, noised_proprio, *args, **kwargs)
|
||||||
|
|
||||||
if self.has_proprio:
|
if self.has_proprio:
|
||||||
pred, _ = pack(pred, 'b t *')
|
packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
|
||||||
|
return packed_flow
|
||||||
|
|
||||||
return pred
|
return pred.flow
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
|
wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
|
||||||
@ -3683,12 +3695,12 @@ class DynamicsWorldModel(Module):
|
|||||||
# need to convert x-space to v-space
|
# need to convert x-space to v-space
|
||||||
|
|
||||||
if is_x_space:
|
if is_x_space:
|
||||||
pred = (pred - noised) / (1. - first_times)
|
packed_pred = (packed_pred - noised) / (1. - first_times)
|
||||||
maybe_shortcut_loss_weight = (1. - first_times) ** 2
|
maybe_shortcut_loss_weight = (1. - first_times) ** 2
|
||||||
|
|
||||||
# mse loss
|
# mse loss
|
||||||
|
|
||||||
flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')
|
||||||
|
|
||||||
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
|
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
|
||||||
|
|
||||||
@ -3740,6 +3752,18 @@ class DynamicsWorldModel(Module):
|
|||||||
else:
|
else:
|
||||||
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
|
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
|
||||||
|
|
||||||
|
# maybe autoregressive state prediction loss
|
||||||
|
|
||||||
|
state_pred_loss = self.zero
|
||||||
|
|
||||||
|
if self.should_pred_state:
|
||||||
|
pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]
|
||||||
|
|
||||||
|
pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
|
||||||
|
pred_latent_var = pred_latent_log_var.exp()
|
||||||
|
|
||||||
|
state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)
|
||||||
|
|
||||||
# maybe autoregressive action loss
|
# maybe autoregressive action loss
|
||||||
|
|
||||||
discrete_action_loss = self.zero
|
discrete_action_loss = self.zero
|
||||||
@ -3807,7 +3831,7 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
# handle loss normalization
|
# handle loss normalization
|
||||||
|
|
||||||
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
|
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)
|
||||||
|
|
||||||
if exists(self.flow_loss_normalizer):
|
if exists(self.flow_loss_normalizer):
|
||||||
flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
|
flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
|
||||||
@ -3827,7 +3851,8 @@ class DynamicsWorldModel(Module):
|
|||||||
flow_loss * self.latent_flow_loss_weight +
|
flow_loss * self.latent_flow_loss_weight +
|
||||||
(reward_loss * self.reward_loss_weight).sum() +
|
(reward_loss * self.reward_loss_weight).sum() +
|
||||||
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
||||||
(continuous_action_loss * self.continuous_action_loss_weight).sum()
|
(continuous_action_loss * self.continuous_action_loss_weight).sum() +
|
||||||
|
(state_pred_loss * self.state_pred_loss_weight)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_all_losses:
|
if not return_all_losses:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.1.20"
|
version = "0.1.21"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user