This commit is contained in:
lucidrains 2025-10-16 06:44:28 -07:00
parent 2a902eaaf7
commit ec18bc0fa4

View File

@ -2100,9 +2100,13 @@ class DynamicsWorldModel(Module):
return pred, agent_tokens
# curry into get_prediction what does not change during first call as well as the shortcut ones
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
# forward the network
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
if return_pred_only:
if not return_agent_tokens:
@ -2134,12 +2138,12 @@ class DynamicsWorldModel(Module):
# basically a consistency loss where you ensure quantity of two half steps equals one step
# dreamer then makes it works for x-space with some math
get_prediction_no_grad = torch.no_grad()(get_prediction)
get_prediction_no_grad = torch.no_grad()(_get_prediction)
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
half_step_size = 2 ** step_sizes_log2_minus_one
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
# first derive b'
@ -2158,7 +2162,7 @@ class DynamicsWorldModel(Module):
# get second prediction for b''
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
if is_v_space_pred:
second_step_pred_flow = second_step_pred