diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 5504951..91bc687 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -73,12 +73,14 @@ LinearNoBias = partial(Linear, bias = False)
 
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
 
-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))
 
 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
 
 TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
 
+Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -3060,8 +3062,10 @@ class DynamicsWorldModel(Module):
 
         # maybe proprio
 
-        if has_proprio:
-            pred, pred_proprio = pred
+        # maybe proprio
+
+        pred_proprio = pred.proprioception
+        pred = pred.flow
 
         # unpack pred
 
@@ -3555,23 +3559,29 @@ class DynamicsWorldModel(Module):
 
            if self.has_proprio:
                pred_proprio = self.to_proprio_pred(proprio_token)
-
-                pred = (pred, pred_proprio)
+            else:
+                pred_proprio = None
 
            # maybe state pred
 
            if self.should_pred_state:
                pred_state = self.to_state_pred(state_pred_token)
+            else:
+                pred_state = None
 
            # returning
 
+            # returning
+
+            predictions = Predictions(pred, pred_proprio, pred_state)
+
            if not return_agent_tokens:
-                return pred
+                return predictions
 
            if not return_time_cache:
-                return pred, agent_tokens
+                return predictions, agent_tokens
 
-            return pred, (agent_tokens, intermediates)
+            return predictions, (agent_tokens, intermediates)
 
        # curry into get_prediction what does not change during first call as well as the shortcut ones
@@ -3590,12 +3600,13 @@ class DynamicsWorldModel(Module):
 
        # pack the predictions to calculate flow for different modalities all at once
 
        if self.has_proprio:
-            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+            packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')
 
            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
            data, _ = pack((latents, proprio), 'b t *')
            noise, _ = pack((noise, proprio_noise), 'b t *')
        else:
+            packed_pred = pred.flow
            noised = noised_latents
            data = latents
@@ -3614,9 +3625,10 @@ class DynamicsWorldModel(Module):
                pred = fn(noised, noised_proprio, *args, **kwargs)
                if self.has_proprio:
-                    pred, _ = pack(pred, 'b t *')
+                    packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
+                    return packed_flow
 
-                return pred
+                return pred.flow
 
            return inner
 
        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
@@ -3683,12 +3695,12 @@ class DynamicsWorldModel(Module):
 
        # need to convert x-space to v-space
 
        if is_x_space:
-            pred = (pred - noised) / (1. - first_times)
+            packed_pred = (packed_pred - noised) / (1. - first_times)
 
            maybe_shortcut_loss_weight = (1. - first_times) ** 2
 
        # mse loss
 
-        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')
        flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
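Note on the pack(..., 'b t *') calls in the hunks at 3590 and 3614 above: einops' pack flattens and concatenates everything after the named axes into one wildcard axis and records the original shapes, which is what lets the latent and proprioception predictions share a single flow loss. A minimal runnable sketch, with shapes that are purely illustrative and not taken from this repo:

import torch
from einops import pack, unpack

b, t = 2, 5
latent_pred = torch.randn(b, t, 8, 16)  # (batch, time, space, dim) - hypothetical sizes
proprio_pred = torch.randn(b, t, 7)     # (batch, time, proprio dim) - hypothetical sizes

# 'b t *' keeps (batch, time) and flattens + concatenates all trailing axes

packed, packed_shape = pack((latent_pred, proprio_pred), 'b t *')
assert packed.shape == (b, t, 8 * 16 + 7)

# the saved shapes let the packed tensor be split back into its parts

latent_out, proprio_out = unpack(packed, packed_shape, 'b t *')
assert latent_out.shape == latent_pred.shape
assert proprio_out.shape == proprio_pred.shape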
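On the x-space to v-space conversion in the hunk at 3683: assuming the linear interpolation convention noised = (1 - t) * noise + t * data with velocity target v = data - noise (this convention is implied by the division by (1 - t) but is not itself shown in the diff), a clean-data prediction x̂ converts to a velocity prediction via v̂ = (x̂ - noised) / (1 - t), and the (1 - t)^2 weight in the following line undoes the variance amplification of that division near t = 1, per the eq(7) comment. A quick self-contained check of the identity:

import torch

t = torch.rand(8, 1) * 0.9             # stay away from t = 1 to avoid dividing by ~0
data = torch.randn(8, 16)
noise = torch.randn(8, 16)

noised = (1 - t) * noise + t * data    # assumed interpolation convention
target_v = data - noise                # flow matching velocity target

pred_x = data                          # a perfect x-space prediction, for the check
pred_v = (pred_x - noised) / (1 - t)   # same conversion as in the hunk above

assert torch.allclose(pred_v, target_v, atol = 1e-5)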
@@ -3740,6 +3752,18 @@ class DynamicsWorldModel(Module):
        else:
            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
+        # maybe autoregressive state prediction loss
+
+        state_pred_loss = self.zero
+
+        if self.should_pred_state:
+            pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]
+
+            pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
+            pred_latent_var = pred_latent_log_var.exp()
+
+            state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)
+
        # maybe autoregressive action loss
 
        discrete_action_loss = self.zero
@@ -3807,7 +3831,7 @@ class DynamicsWorldModel(Module):
 
        # handle loss normalization
 
-        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)
 
        if exists(self.flow_loss_normalizer):
            flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
@@ -3827,7 +3851,8 @@ class DynamicsWorldModel(Module):
            flow_loss * self.latent_flow_loss_weight +
            (reward_loss * self.reward_loss_weight).sum() +
            (discrete_action_loss * self.discrete_action_loss_weight).sum() +
-            (continuous_action_loss * self.continuous_action_loss_weight).sum()
+            (continuous_action_loss * self.continuous_action_loss_weight).sum() +
+            (state_pred_loss * self.state_pred_loss_weight)
        )
 
        if not return_all_losses:
diff --git a/pyproject.toml b/pyproject.toml
index 12ac647..95cd36a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.20"
+version = "0.1.21"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
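On the new state prediction loss (the gaussian_nll_loss hunk above): the unbind(dim = -1) implies the state head stacks a mean and a log-variance on a final axis of size 2, and the latent at step t + 1 is scored under a diagonal Gaussian predicted from the state token at step t. A self-contained sketch with made-up shapes, showing that the unbind + exp + F.gaussian_nll_loss sequence matches the usual closed form 0.5 * (log var + (target - mean)^2 / var):

import torch
import torch.nn.functional as F

b, t, d = 2, 6, 32

pred_state = torch.randn(b, t, d, 2)   # last axis stacks (mean, log variance) - assumed layout
latents = torch.randn(b, t, d)

# state token at time i predicts the latent at time i + 1, as in the diff

pred_latent, latent_to_pred = pred_state[:, :-1], latents[:, 1:]

mean, log_var = pred_latent.unbind(dim = -1)
var = log_var.exp()

loss = F.gaussian_nll_loss(mean, latent_to_pred, var = var)

# the same negative log likelihood written out by hand, averaged like reduction = 'mean'

manual = (0.5 * (var.log() + (latent_to_pred - mean) ** 2 / var)).mean()
assert torch.allclose(loss, manual, atol = 1e-5)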