introduce an is_truncated field on Experience, and mask out rewards and values before calculating gae appropriately
parent fbfd59e42f
commit 59c458aea3
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchvision
@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True

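For orientation, a minimal sketch of the experience container as it reads after this hunk. The field names are taken from the diff; the dataclass decorator and the reading of is_truncated as a per-episode boolean (True = episode cut off, False = terminated) are assumptions.

    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class Experience:
        # only the fields visible in this diff; the real class has more
        latents: Tensor | None = None
        values: Tensor | None = None
        step_size: int | None = None
        lens: Tensor | None = None               # (batch,) valid timesteps per episode
        is_truncated: Tensor | None = None       # (batch,) bool, assumed True = cut off mid-episode
        agent_index: int = 0
        is_from_world_model: bool = True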
@@ -99,7 +100,10 @@ def combine_experiences(
        batch, time, device = *latents.shape[:2], latents.device

        if not exists(exp.lens):
-            exp.lens = torch.full((batch,), time, device = device)
+            exp.lens = full((batch,), time, device = device)

+        if not exists(exp.is_truncated):
+            exp.is_truncated = full((batch,), True, device = device)
+
        # convert to dictionary

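When experiences are combined, missing metadata is filled in first so that older rollouts without lens or is_truncated still merge cleanly. A minimal sketch of that defaulting pattern, assuming the usual exists helper; the fill_defaults name is hypothetical.

    import torch
    from torch import full

    def exists(v):
        return v is not None

    def fill_defaults(exp, batch, time, device):
        # hypothetical helper mirroring the defaulting shown in the hunk above
        if not exists(exp.lens):
            exp.lens = full((batch,), time, device = device)

        # when nothing is recorded, every episode is assumed truncated
        if not exists(exp.is_truncated):
            exp.is_truncated = full((batch,), True, device = device)

        return exp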
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

         time_kv_cache = None

-        for _ in range(max_timesteps):
+        for i in range(max_timesteps + 1):

             latents = self.video_tokenizer(video, return_latents = True)

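The rollout loop now runs one extra iteration. A plausible reading, given the GAE change further down, is that the extra step supplies a value estimate for the state after the final transition, which is what bootstrapping needs when an episode is truncated rather than terminated. A toy illustration of that extra bootstrap value (shapes and numbers are made up):

    import torch

    max_timesteps = 4
    rewards = torch.randn(max_timesteps)        # r_0 .. r_{T-1}
    values  = torch.randn(max_timesteps + 1)    # V(s_0) .. V(s_T); the +1 is the bootstrap estimate

    gamma = 0.99
    # the TD target at the last collected step can now use V(s_T) instead of assuming 0
    td_target_last = rewards[-1] + gamma * values[-1]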
@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

             one_agent_embed = agent_embed[..., -1:, agent_index, :]

+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)

             # sample actions
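The per-step value estimates are accumulated with safe_cat. That helper lives in the codebase itself; a common implementation, assumed here rather than copied from the source, simply treats None as an empty accumulator:

    import torch

    def safe_cat(tensors, dim = -1):
        # concatenate, ignoring None entries; None if nothing remains
        tensors = [t for t in tensors if t is not None]

        if len(tensors) == 0:
            return None

        if len(tensors) == 1:
            return tensors[0]

        return torch.cat(tensors, dim = dim)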
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward

             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
||||||
@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
# package up one experience for learning
|
# package up one experience for learning
|
||||||
|
|
||||||
|
batch, device = latents.shape[0], latents.device
|
||||||
|
|
||||||
one_experience = Experience(
|
one_experience = Experience(
|
||||||
latents = latents,
|
latents = latents,
|
||||||
video = video[:, :, :-1],
|
video = video[:, :, :-1],
|
||||||
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps, device = device),
             is_from_world_model = False
         )

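With the lengths recorded explicitly, a real-environment rollout packages up roughly like the following (using the Experience sketch above; shapes and the step_size value are placeholders):

    import torch
    from torch import full

    batch, max_timesteps, device = 2, 4, 'cpu'

    one_experience = Experience(
        latents = torch.randn(batch, max_timesteps, 16),
        values = torch.randn(batch, max_timesteps + 1),
        step_size = 1,
        lens = full((batch,), max_timesteps, device = device),
        is_from_world_model = False,
    )

    # is_truncated is left as None; downstream code defaults it to all-True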
@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):

         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+        batch, time = latents.shape[0], latents.shape[1]
+
+        # calculate returns
+
+        # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+        # for terminated, will just mask out any after lens
+
+        # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+        mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

         # handle variable lengths
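This hunk is the core of the change: per-episode lengths become a timestep mask, rewards and values outside the kept region are zeroed, and GAE runs over the padded batch. The sketch below is not the repository's implementation, only a minimal illustration of the technique: lens_to_mask is assumed to return True inside each episode, the invalid tail is zeroed via the inverted mask, and the reference GAE treats the zeroed tail as a zero bootstrap. Following the intent stated in the new comments, truncated episodes keep lens + 1 positions so the final value estimate survives as the bootstrap, while terminated ones keep exactly lens.

    import torch

    def lens_to_mask(lens, total_len):
        # (batch,) lengths -> (batch, total_len) bool mask, True inside each episode
        positions = torch.arange(total_len, device = lens.device)
        return positions[None, :] < lens[:, None]

    def gae_reference(rewards, values, gamma = 0.99, lam = 0.95):
        # rewards, values: (batch, time), with everything past the kept region zeroed,
        # so the bootstrap beyond the last kept position is implicitly zero
        batch, time = rewards.shape
        advantages = torch.zeros_like(rewards)
        next_value = torch.zeros(batch, device = rewards.device)
        running = torch.zeros(batch, device = rewards.device)

        for t in reversed(range(time)):
            delta = rewards[:, t] + gamma * next_value - values[:, t]
            running = delta + gamma * lam * running
            advantages[:, t] = running
            next_value = values[:, t]

        return advantages + values    # return targets for the value head

    # toy usage: one truncated and one terminated episode in a padded batch of length 6
    time = 6
    lens = torch.tensor([3, 5])
    is_truncated = torch.tensor([True, False])

    # truncated episodes keep one extra position so V(s_len) can act as the bootstrap
    lens_for_gae = torch.where(is_truncated, lens + 1, lens)
    mask = lens_to_mask(lens_for_gae, time)

    rewards = torch.randn(2, time).masked_fill(~mask, 0.)
    old_values = torch.randn(2, time).masked_fill(~mask, 0.)

    returns = gae_reference(rewards, old_values)

Zeroing with the inverted mask leaves everything inside the kept region intact; only the padding tail contributes nothing to the recursion.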
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

         if isinstance(tasks, int):
-            tasks = torch.full((batch_size,), tasks, device = self.device)
+            tasks = full((batch_size,), tasks, device = self.device)

         assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization

         batch, device = latents.shape[0], latents.device
-        experience_lens = torch.full((batch,), time_steps, device = device)
+        experience_lens = full((batch,), time_steps, device = device)

         gen = Experience(
             latents = latents,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.77"
+version = "0.0.78"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }