cleanup
This commit is contained in:
parent
2a902eaaf7
commit
ec18bc0fa4
@ -2100,9 +2100,13 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
return pred, agent_tokens
|
return pred, agent_tokens
|
||||||
|
|
||||||
|
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||||
|
|
||||||
|
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||||
|
|
||||||
# forward the network
|
# forward the network
|
||||||
|
|
||||||
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = True)
|
pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
|
||||||
|
|
||||||
if return_pred_only:
|
if return_pred_only:
|
||||||
if not return_agent_tokens:
|
if not return_agent_tokens:
|
||||||
@ -2134,12 +2138,12 @@ class DynamicsWorldModel(Module):
|
|||||||
# basically a consistency loss where you ensure quantity of two half steps equals one step
|
# basically a consistency loss where you ensure quantity of two half steps equals one step
|
||||||
# dreamer then makes it works for x-space with some math
|
# dreamer then makes it works for x-space with some math
|
||||||
|
|
||||||
get_prediction_no_grad = torch.no_grad()(get_prediction)
|
get_prediction_no_grad = torch.no_grad()(_get_prediction)
|
||||||
|
|
||||||
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
|
||||||
half_step_size = 2 ** step_sizes_log2_minus_one
|
half_step_size = 2 ** step_sizes_log2_minus_one
|
||||||
|
|
||||||
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
|
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
|
||||||
|
|
||||||
# first derive b'
|
# first derive b'
|
||||||
|
|
||||||
@ -2158,7 +2162,7 @@ class DynamicsWorldModel(Module):
|
|||||||
# get second prediction for b''
|
# get second prediction for b''
|
||||||
|
|
||||||
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
|
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
|
||||||
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, reward_tokens, agent_tokens)
|
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
|
||||||
|
|
||||||
if is_v_space_pred:
|
if is_v_space_pred:
|
||||||
second_step_pred_flow = second_step_pred
|
second_step_pred_flow = second_step_pred
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user