start the learning in dreams portion
commit 0dba734280 (parent a0161760a0)
@@ -72,6 +72,8 @@ class WorldModelGenerations:
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
+    step_size: int | None = None
+    agent_index: int = 0

 # helpers

@@ -107,7 +109,7 @@ def is_empty(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

-def safe_cat(*tensors, dim):
+def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]

     if len(tensors) == 0:
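Not part of this diff, but for orientation: under the new signature `safe_cat` takes an explicit iterable of tensors rather than varargs, drops the `None` entries and concatenates what remains. A minimal sketch under those assumptions (the early returns are assumptions, only the first two body lines appear in the hunk above):

import torch

def exists(v):
    # assumed helper: a value "exists" if it is not None
    return v is not None

def safe_cat(tensors, dim):
    # drop None entries before concatenating
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return torch.cat(tensors, dim = dim)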
@@ -648,7 +650,7 @@ class ActionEmbedder(Module):
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -656,6 +658,9 @@ def calc_gae(
     assert values.shape[-1] == rewards.shape[-1]
     use_accelerated = default(use_accelerated, rewards.is_cuda)

+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]

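As an aside (not from this commit), `calc_gae` presumably follows the standard generalized advantage estimation recurrence. A minimal non-accelerated reference sketch, assuming `masks` flags non-terminal steps and the function returns the lambda-returns (the name `calc_gae_reference` is hypothetical):

import torch
import torch.nn.functional as F

def calc_gae_reference(rewards, values, masks = None, gamma = 0.99, lam = 0.95):
    # delta_t = r_t + gamma * V_{t+1} * mask_t - V_t
    # A_t     = delta_t + gamma * lam * mask_t * A_{t+1}
    # return_t = A_t + V_t
    if masks is None:
        masks = torch.ones_like(values)

    values_next = F.pad(values, (0, 1), value = 0.)[..., 1:]

    deltas = rewards + gamma * values_next * masks - values

    gae = torch.zeros_like(deltas[..., 0])
    advantages = torch.zeros_like(deltas)

    for t in reversed(range(deltas.shape[-1])):
        gae = deltas[..., t] + gamma * lam * masks[..., t] * gae
        advantages[..., t] = gae

    return advantages + values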
@@ -1456,7 +1461,13 @@ class DynamicsWorldModel(Module):
         policy_head_mlp_depth = 3,
         behavior_clone_weight = 0.1,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
-        num_residual_streams = 1
+        num_residual_streams = 1,
+        gae_discount_factor = 0.997,
+        gae_lambda = 0.95,
+        ppo_eps_clip = 0.2,
+        value_clip = 0.4,
+        policy_entropy_weight = .01,
+        gae_use_accelerated = False
     ):
         super().__init__()

@@ -1614,6 +1625,16 @@ class DynamicsWorldModel(Module):
             final_norm = False
         )

+        # ppo related
+
+        self.gae_use_accelerated = gae_use_accelerated
+        self.gae_discount_factor = gae_discount_factor
+        self.gae_lambda = gae_lambda
+
+        self.ppo_eps_clip = ppo_eps_clip
+        self.value_clip = value_clip
+        self.policy_entropy_weight = policy_entropy_weight
+
         # zero

         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -1642,6 +1663,87 @@ class DynamicsWorldModel(Module):

         return list(set(params) - set(self.video_tokenizer.parameters()))

+    def learn_policy_from_generations(
+        self,
+        generation: WorldModelGenerations
+    ):
+        latents = generation.latents
+        actions = generation.actions
+        old_log_probs = generation.log_probs
+        old_values = generation.values
+        rewards = generation.rewards
+
+        step_size = generation.step_size
+        agent_index = generation.agent_index
+
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+
+        returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
+
+        # apparently they just use the sign of the advantage
+        # https://arxiv.org/abs/2410.04166v1
+
+        advantage = (returns - old_values).sign()
+
+        # replay for the action logits and values
+
+        discrete_actions, continuous_actions = actions
+
+        _, agent_embed = self.forward(
+            latents = latents,
+            signal_levels = self.max_steps - 1,
+            step_sizes = step_size,
+            rewards = rewards,
+            discrete_actions = discrete_actions,
+            continuous_actions = continuous_actions,
+            latent_is_noised = True,
+            return_pred_only = True,
+            return_agent_tokens = True
+        )
+
+        agent_embed = agent_embed[..., agent_index, :]
+
+        # ppo
+
+        policy_embed = self.policy_head(agent_embed)
+
+        log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
+
+        # concat discrete and continuous actions into one for optimizing
+
+        old_log_probs = safe_cat(old_log_probs, dim = -1)
+        log_probs = safe_cat(log_probs, dim = -1)
+
+        ratio = (log_probs - old_log_probs).exp()
+
+        clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
+
+        # clipped surrogate loss
+
+        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+
+        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+        policy_loss = policy_loss.mean()
+
+        # value loss
+
+        value_bins = self.value_head(agent_embed)
+        values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+        clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
+        clipped_value_bins = self.reward_encoder(clipped_values)
+
+        return_bins = self.reward_encoder(returns)
+
+        value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
+        value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
+
+        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+
+        return policy_loss, value_loss
+
     @torch.no_grad()
     def generate(
         self,
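For reference, the objective implemented above is the standard PPO clipped surrogate paired with a clipped value regression over the reward encoder's bins, with the advantage replaced by its sign as in the linked paper. In symbols (a summary of the code, not text from the commit):

$$r_t = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\big), \qquad \hat{A}_t = \mathrm{sign}\big(\hat{R}_t - V_{\text{old}}(s_t)\big)$$

$$\mathcal{L}_{\text{policy}} = -\,\mathbb{E}_t\Big[\min\big(r_t \hat{A}_t,\ \mathrm{clip}(r_t,\ 1 - \epsilon,\ 1 + \epsilon)\,\hat{A}_t\big)\Big]$$

$$\mathcal{L}_{\text{value}} = \mathbb{E}_t\Big[\max\big(\mathrm{CE}(V_\theta(s_t), \hat{R}_t),\ \mathrm{CE}(V_{\text{clip}}(s_t), \hat{R}_t)\big)\Big], \qquad V_{\text{clip}} = V_{\text{old}} + \mathrm{clip}(V_\theta - V_{\text{old}},\ -\delta,\ \delta)$$

where $\hat{R}_t$ is the GAE return from `calc_gae`, $\epsilon$ is `ppo_eps_clip`, $\delta$ is `value_clip`, and CE is cross entropy against the binned return targets.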
@@ -1761,8 +1863,8 @@ class DynamicsWorldModel(Module):

            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)

-            decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
-            decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
+            decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
+            decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)

            if return_log_probs_and_values:
                discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
@@ -1771,13 +1873,13 @@ class DynamicsWorldModel(Module):
                    continuous_targets = sampled_continuous_actions,
                )

-                decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
-                decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
+                decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
+                decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)

                value_bins = self.value_head(one_agent_embed)
                values = self.reward_encoder.bins_to_scalar_value(value_bins)

-                decoded_values = safe_cat(decoded_values, values, dim = 1)
+                decoded_values = safe_cat((decoded_values, values), dim = 1)

            # concat the denoised latent

@@ -1812,7 +1914,12 @@ class DynamicsWorldModel(Module):

        # returning agent actions, rewards, and log probs + values for policy optimization

-        gen = WorldModelGenerations(latents = latents, video = video)
+        gen = WorldModelGenerations(
+            latents = latents,
+            video = video,
+            step_size = step_size,
+            agent_index = agent_index
+        )

        if return_rewards_per_frame:
            gen.rewards = decoded_rewards
@@ -2228,6 +2335,5 @@ class Dreamer(Module):
        self,
        video_tokenizer: VideoTokenizer,
        dynamics_model: DynamicsWorldModel,
-        discount_factor = 0.997
    ):
        super().__init__()
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.26"
+version = "0.0.27"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -216,24 +216,32 @@ def test_action_with_world_model():
     discrete_actions = torch.randint(0, 4, (1, 4, 1))

     gen = dynamics.generate(
-        10,
+        16,
+        batch_size = 4,
         return_rewards_per_frame = True,
         return_agent_actions = True,
         return_log_probs_and_values = True
     )

-    assert gen.video.shape == (1, 3, 10, 256, 256)
-    assert gen.rewards.shape == (1, 10)
+    assert gen.video.shape == (4, 3, 16, 256, 256)
+    assert gen.rewards.shape == (4, 16)

     discrete_actions, continuous_actions = gen.actions

-    assert discrete_actions.shape == (1, 10, 1)
+    assert discrete_actions.shape == (4, 16, 1)
     assert continuous_actions is None

     discrete_log_probs, _ = gen.log_probs

-    assert discrete_log_probs.shape == (1, 10, 1)
-    assert gen.values.shape == (1, 10)
+    assert discrete_log_probs.shape == (4, 16, 1)
+    assert gen.values.shape == (4, 16)
+
+    # take a reinforcement learning step
+
+    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)
+
+    actor_loss.backward(retain_graph = True)
+    critic_loss.backward()

 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder
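Putting the pieces of this commit together, one policy-optimization step "in dreams" might look like the following sketch. Only `generate` (with the flags exercised in the test above) and `learn_policy_from_generations` come from this diff; the optimizer choice and the use of `dynamics.parameters()` are assumptions for illustration:

import torch

# `dynamics` is assumed to be a constructed DynamicsWorldModel, as in the test above
optim = torch.optim.Adam(dynamics.parameters(), lr = 1e-4)

# dream a rollout, keeping actions, rewards, log probs and values
gen = dynamics.generate(
    16,
    batch_size = 4,
    return_rewards_per_frame = True,
    return_agent_actions = True,
    return_log_probs_and_values = True
)

# PPO-style actor / critic update from the imagined trajectory
actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)

(actor_loss + critic_loss).backward()

optim.step()
optim.zero_grad()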