introduce an is_truncated field on Experience, and mask out rewards and values before calculating gae appropriately
This commit is contained in:
parent
fbfd59e42f
commit
59c458aea3
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
||||
from torch.nested import nested_tensor
|
||||
from torch.distributions import Normal
|
||||
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
import torchvision
|
||||
@ -83,6 +83,7 @@ class Experience:
|
||||
values: Tensor | None = None
|
||||
step_size: int | None = None
|
||||
lens: Tensor | None = None
|
||||
is_truncated: Tensor | None = None
|
||||
agent_index: int = 0
|
||||
is_from_world_model: bool = True
|
||||
|
||||
@ -99,7 +100,10 @@ def combine_experiences(
|
||||
batch, time, device = *latents.shape[:2], latents.device
|
||||
|
||||
if not exists(exp.lens):
|
||||
exp.lens = torch.full((batch,), time, device = device)
|
||||
exp.lens = full((batch,), time, device = device)
|
||||
|
||||
if not exists(exp.is_truncated):
|
||||
exp.is_truncated = full((batch,), True, device = device)
|
||||
|
||||
# convert to dictionary
|
||||
|
||||
@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
time_kv_cache = None
|
||||
|
||||
for _ in range(max_timesteps):
|
||||
for i in range(max_timesteps + 1):
|
||||
|
||||
latents = self.video_tokenizer(video, return_latents = True)
|
||||
|
||||
@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
one_agent_embed = agent_embed[..., -1:, agent_index, :]
|
||||
|
||||
# values
|
||||
|
||||
value_bins = self.value_head(one_agent_embed)
|
||||
value = self.reward_encoder.bins_to_scalar_value(value_bins)
|
||||
|
||||
values = safe_cat((values, value), dim = 1)
|
||||
|
||||
# policy embed
|
||||
|
||||
policy_embed = self.policy_head(one_agent_embed)
|
||||
|
||||
# sample actions
|
||||
@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
|
||||
discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
|
||||
continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
|
||||
|
||||
value_bins = self.value_head(one_agent_embed)
|
||||
value = self.reward_encoder.bins_to_scalar_value(value_bins)
|
||||
|
||||
values = safe_cat((values, value), dim = 1)
|
||||
|
||||
# pass the sampled action to the environment and get back next state and reward
|
||||
|
||||
next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
|
||||
@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# package up one experience for learning
|
||||
|
||||
batch, device = latents.shape[0], latents.device
|
||||
|
||||
one_experience = Experience(
|
||||
latents = latents,
|
||||
video = video[:, :, :-1],
|
||||
@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
|
||||
values = values,
|
||||
step_size = step_size,
|
||||
agent_index = agent_index,
|
||||
lens = full((batch,), max_timesteps, device = device),
|
||||
is_from_world_model = False
|
||||
)
|
||||
|
||||
@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
|
||||
|
||||
batch, time = latents.shape[0], latents.shape[1]
|
||||
|
||||
# calculate returns
|
||||
|
||||
# for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
|
||||
# for terminated, will just mask out any after lens
|
||||
|
||||
# if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
|
||||
|
||||
if not exists(experience.is_truncated):
|
||||
experience.is_truncated = full((batch,), True, device = latents.device)
|
||||
|
||||
lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
|
||||
mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
|
||||
|
||||
rewards = rewards.masked_fill(mask_for_gae, 0.)
|
||||
old_values = old_values.masked_fill(mask_for_gae, 0.)
|
||||
|
||||
# calculate returns
|
||||
|
||||
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
|
||||
|
||||
# handle variable lengths
|
||||
@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
|
||||
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
|
||||
|
||||
if isinstance(tasks, int):
|
||||
tasks = torch.full((batch_size,), tasks, device = self.device)
|
||||
tasks = full((batch_size,), tasks, device = self.device)
|
||||
|
||||
assert not exists(tasks) or tasks.shape[0] == batch_size
|
||||
|
||||
@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
|
||||
# returning agent actions, rewards, and log probs + values for policy optimization
|
||||
|
||||
batch, device = latents.shape[0], latents.device
|
||||
experience_lens = torch.full((batch,), time_steps, device = device)
|
||||
experience_lens = full((batch,), time_steps, device = device)
|
||||
|
||||
gen = Experience(
|
||||
latents = latents,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.77"
|
||||
version = "0.0.78"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user