handle variable-length experiences when doing policy optimization
parent 46432aee9b
commit fbfd59e42f
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
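The one-character change to `Experience` above is a real fix, not just cleanup: with the trailing comma, Python parses the field default as the one-element tuple `(None,)` rather than `None`, so a downstream `exists(experience.lens)` check would report lengths as present even when none were supplied. A minimal standalone illustration (hypothetical class, not the repo's actual dataclass):

from dataclasses import dataclass

@dataclass
class Example:
    lens_with_comma: object = None,   # trailing comma: default becomes the tuple (None,)
    lens: object = None               # default is actually None

e = Example()
assert e.lens_with_comma == (None,)   # truthy, so a "was this provided?" check misfires
assert e.lens is None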
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):

         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads

         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
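`lens_to_mask` is imported from elsewhere in the repo and its body is not part of this diff. A minimal sketch of what such a helper typically does, assuming `lens` holds integer lengths of shape `(batch,)` and the mask is boolean of shape `(batch, max_time)`:

import torch
from torch import Tensor

def lens_to_mask(lens: Tensor, max_len: int) -> Tensor:
    # lens: (batch,) integer sequence lengths
    # returns: (batch, max_len) boolean mask, True where the timestep is within the sequence
    arange = torch.arange(max_len, device = lens.device)
    return arange[None, :] < lens[:, None]

# e.g. lens_to_mask(torch.tensor([5, 3]), 5)
# -> [[True, True, True, True, True],
#     [True, True, True, False, False]]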
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):

         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step

         if exists(policy_optim):
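The point of dropping the early `.mean()` on the entropy term is that both the policy surrogate and the entropy bonus now stay per-timestep (`b t`), so padded steps can be masked out before the final reduction. A standalone sketch of that masked reduction, with made-up shapes and a stand-in entropy weight:

import torch

batch, time, num_actions = 2, 5, 3

policy_loss = torch.randn(batch, time)              # per-timestep surrogate loss, (b, t)
entropies = torch.rand(batch, time, num_actions)    # per-action entropies, (b, t, na)
entropy_loss = - entropies.sum(dim = -1)            # same as reduce(..., 'b t na -> b t', 'sum')

policy_entropy_weight = 3e-4                        # stand-in for self.policy_entropy_weight
total_policy_loss = policy_loss + entropy_loss * policy_entropy_weight

lens = torch.tensor([5, 3])                          # variable experience lengths
mask = torch.arange(time)[None, :] < lens[:, None]   # (b, t) validity mask

# average only over valid timesteps, so padding contributes no gradient
loss = total_policy_loss[mask].mean()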
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):

         return_bins = self.reward_encoder(returns)

         value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))

         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step
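Same treatment for the critic: both cross-entropy terms are computed with `reduction = 'none'`, the element-wise maximum (the pessimistic, clipped value objective) stays per-timestep, and the mask is applied before averaging. A self-contained sketch under assumed shapes; the bin count and the soft return targets are placeholders, not the repo's actual reward encoder output:

import torch
import torch.nn.functional as F

batch, time, num_bins = 2, 5, 41

value_bins = torch.randn(batch, num_bins, time)                     # value head logits, (b, l, t)
clipped_value_bins = torch.randn(batch, num_bins, time)             # logits after value clipping
return_bins = torch.randn(batch, num_bins, time).softmax(dim = 1)   # soft target distribution over bins

value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')          # (b, t)
value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')  # (b, t)
value_loss = torch.maximum(value_loss_1, value_loss_2)                               # per-timestep, (b, t)

lens = torch.tensor([5, 3])
mask = torch.arange(time)[None, :] < lens[:, None]

value_loss = value_loss[mask].mean()    # average over valid timesteps only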
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.76"
+version = "0.0.77"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }