diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 6461a81..2d11d75 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2100,9 +2100,13 @@ class DynamicsWorldModel(Module): return pred, agent_tokens + # curry into get_prediction what does not change during first call as well as the shortcut ones + + _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens) + # forward the network - pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True) + pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True) if return_pred_only: if not return_agent_tokens: @@ -2134,12 +2138,12 @@ class DynamicsWorldModel(Module): # basically a consistency loss where you ensure quantity of two half steps equals one step # dreamer then makes it works for x-space with some math - get_prediction_no_grad = torch.no_grad()(get_prediction) + get_prediction_no_grad = torch.no_grad()(_get_prediction) step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2 half_step_size = 2 ** step_sizes_log2_minus_one - first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens) + first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one) # first derive b' @@ -2158,7 +2162,7 @@ class DynamicsWorldModel(Module): # get second prediction for b'' signal_levels_plus_half_step = signal_levels + half_step_size[:, None] - second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens) + second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one) if is_v_space_pred: second_step_pred_flow = second_step_pred