cleanup
This commit is contained in:
parent
2a902eaaf7
commit
ec18bc0fa4
@ -2100,9 +2100,13 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
return pred, agent_tokens
|
||||
|
||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||
|
||||
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||
|
||||
# forward the network
|
||||
|
||||
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
|
||||
pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
|
||||
|
||||
if return_pred_only:
|
||||
if not return_agent_tokens:
|
||||
@ -2134,12 +2138,12 @@ class DynamicsWorldModel(Module):
|
||||
# basically a consistency loss where you ensure quantity of two half steps equals one step
|
||||
# dreamer then makes it works for x-space with some math
|
||||
|
||||
get_prediction_no_grad = torch.no_grad()(get_prediction)
|
||||
get_prediction_no_grad = torch.no_grad()(_get_prediction)
|
||||
|
||||
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
||||
half_step_size = 2 ** step_sizes_log2_minus_one
|
||||
|
||||
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
|
||||
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
|
||||
|
||||
# first derive b'
|
||||
|
||||
@ -2158,7 +2162,7 @@ class DynamicsWorldModel(Module):
|
||||
# get second prediction for b''
|
||||
|
||||
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
|
||||
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
|
||||
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
|
||||
|
||||
if is_v_space_pred:
|
||||
second_step_pred_flow = second_step_pred
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user