start the learning in dreams portion

lucidrains 2025-10-17 08:00:47 -07:00
parent a0161760a0
commit 0dba734280
3 changed files with 131 additions and 17 deletions

View File

@@ -72,6 +72,8 @@ class WorldModelGenerations:
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
values: Tensor | None = None
step_size: int | None = None
agent_index: int = 0
# helpers
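
For orientation, the container being extended here can be sketched as a plain dataclass; the `latents`, `video`, and `rewards` fields are inferred from how the object is populated later in this diff, so treat them as assumptions about the surrounding (unshown) code.

from dataclasses import dataclass
from torch import Tensor

@dataclass
class WorldModelGenerationsSketch:
    # reconstruction of the fields visible in this commit; the first three
    # names are inferred from `gen.latents`, `gen.video`, `gen.rewards`
    latents: Tensor | None = None
    video: Tensor | None = None
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None    # (discrete, continuous)
    log_probs: tuple[Tensor, Tensor] | None = None  # (discrete, continuous)
    values: Tensor | None = None
    step_size: int | None = None   # new: denoising step size used during generation
    agent_index: int = 0           # new: which agent token the policy / value heads read
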
@@ -107,7 +109,7 @@ def is_empty(t):
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def safe_cat(*tensors, dim):
def safe_cat(tensors, dim):
tensors = [*filter(exists, tensors)]
if len(tensors) == 0:
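
`safe_cat` changes from varargs to taking an explicit tuple, which is why the call sites later in this diff wrap their arguments in parentheses. A minimal self-contained sketch of the implied behaviour (the empty-input return value is an assumption, based on how the result is fed back through `safe_cat` in `generate`):

import torch

def exists(v):
    return v is not None

def safe_cat(tensors, dim):
    # drop None entries, then concatenate whatever remains
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None  # assumption: nothing to concatenate yields None

    return torch.cat(tensors, dim = dim)

# usage mirroring the updated call sites in `generate`
decoded = None
for chunk in (torch.randn(2, 1), torch.randn(2, 1)):
    decoded = safe_cat((decoded, chunk), dim = 1)

assert decoded.shape == (2, 2)
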
@@ -648,7 +650,7 @@ class ActionEmbedder(Module):
def calc_gae(
rewards,
values,
masks,
masks = None,
gamma = 0.99,
lam = 0.95,
use_accelerated = None
@@ -656,6 +658,9 @@ def calc_gae(
assert values.shape[-1] == rewards.shape[-1]
use_accelerated = default(use_accelerated, rewards.is_cuda)
if not exists(masks):
masks = torch.ones_like(values)
values = F.pad(values, (0, 1), value = 0.)
values, values_next = values[..., :-1], values[..., 1:]
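
`calc_gae` now defaults `masks` to all ones, and the padding above bootstraps the final step with a value of zero. For readers unfamiliar with generalized advantage estimation, here is a plain reference version consistent with how the result is consumed as `returns` later in this commit (the accelerated scan path behind `use_accelerated` is omitted, and this is a sketch rather than the repository implementation):

import torch
import torch.nn.functional as F

def calc_gae_reference(rewards, values, masks = None, gamma = 0.99, lam = 0.95):
    # rewards, values, masks: (..., time)
    if masks is None:
        masks = torch.ones_like(values)

    # treat the value after the last step as zero, matching the pad above
    values = F.pad(values, (0, 1), value = 0.)
    values, values_next = values[..., :-1], values[..., 1:]

    # td residuals
    delta = rewards + gamma * values_next * masks - values

    # backwards scan: gae_t = delta_t + gamma * lam * mask_t * gae_{t+1}
    gae = torch.zeros_like(delta[..., 0])
    returns = []
    for t in reversed(range(delta.shape[-1])):
        gae = delta[..., t] + gamma * lam * masks[..., t] * gae
        returns.append(gae + values[..., t])

    # the returned quantity is the lambda-return (advantage + value)
    return torch.stack(returns[::-1], dim = -1)
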
@@ -1456,7 +1461,13 @@ class DynamicsWorldModel(Module):
policy_head_mlp_depth = 3,
behavior_clone_weight = 0.1,
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
num_residual_streams = 1
num_residual_streams = 1,
gae_discount_factor = 0.997,
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
value_clip = 0.4,
policy_entropy_weight = .01,
gae_use_accelerated = False
):
super().__init__()
@@ -1614,6 +1625,16 @@ class DynamicsWorldModel(Module):
final_norm = False
)
# ppo related
self.gae_use_accelerated = gae_use_accelerated
self.gae_discount_factor = gae_discount_factor
self.gae_lambda = gae_lambda
self.ppo_eps_clip = ppo_eps_clip
self.value_clip = value_clip
self.policy_entropy_weight = policy_entropy_weight
# zero
self.register_buffer('zero', tensor(0.), persistent = False)
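
The new constructor arguments are stored as flat attributes on `DynamicsWorldModel`; grouped together purely for illustration (the dataclass below is not part of the repository), they amount to a small PPO configuration:

from dataclasses import dataclass

@dataclass
class PPOConfigSketch:
    gae_discount_factor: float = 0.997   # gamma passed to calc_gae
    gae_lambda: float = 0.95             # lambda passed to calc_gae
    ppo_eps_clip: float = 0.2            # clip range on the probability ratio
    value_clip: float = 0.4              # clip range on the value update
    policy_entropy_weight: float = 0.01  # stored, not yet applied in learn_policy_from_generations
    gae_use_accelerated: bool = False    # whether calc_gae takes the accelerated path
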
@@ -1642,6 +1663,87 @@ class DynamicsWorldModel(Module):
return list(set(params) - set(self.video_tokenizer.parameters()))
def learn_policy_from_generations(
self,
generation: WorldModelGenerations
):
latents = generation.latents
actions = generation.actions
old_log_probs = generation.log_probs
old_values = generation.values
rewards = generation.rewards
step_size = generation.step_size
agent_index = generation.agent_index
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the actions, log probs, values, rewards, and step size for policy optimization'
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
# apparently they just use the sign of the advantage
# https://arxiv.org/abs/2410.04166v1
advantage = (returns - old_values).sign()
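
Per the referenced paper, only the sign of the advantage is kept, so every transition nudges the policy with equal magnitude. A tiny worked example of what `.sign()` discards:

import torch

returns = torch.tensor([1.2, 0.3, -0.5])
old_values = torch.tensor([1.0, 0.8, -0.1])

raw_advantage = returns - old_values   # approx tensor([ 0.2, -0.5, -0.4])
sign_advantage = raw_advantage.sign()  # tensor([ 1., -1., -1.])
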
# replay for the action logits and values
discrete_actions, continuous_actions = actions
_, agent_embed = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
latent_is_noised = True,
return_pred_only = True,
return_agent_tokens = True
)
agent_embed = agent_embed[..., agent_index, :]
# ppo
policy_embed = self.policy_head(agent_embed)
log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
# concat discrete and continuous actions into one for optimizing
old_log_probs = safe_cat(old_log_probs, dim = -1)
log_probs = safe_cat(log_probs, dim = -1)
ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
# clipped surrogate loss
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
policy_loss = policy_loss.mean()
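
This is the standard PPO clipped surrogate objective, here summed over action dimensions before averaging. As a standalone sketch for a single action dimension:

import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advantage, eps_clip = 0.2):
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
    # pessimistic (min) combination of the clipped and unclipped objectives
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()

# when new and old policies agree, the ratio is 1 and the loss reduces to -advantage.mean()
lp = torch.tensor([-0.7, -1.2])
adv = torch.tensor([1., -1.])
assert torch.allclose(clipped_surrogate_loss(lp, lp, adv), -adv.mean())
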
# value loss
value_bins = self.value_head(agent_embed)
values = self.reward_encoder.bins_to_scalar_value(value_bins)
clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
clipped_value_bins = self.reward_encoder(clipped_values)
return_bins = self.reward_encoder(returns)
value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
return policy_loss, value_loss
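
The value update mirrors PPO's clipped value loss, except that both the unclipped and clipped predictions are scored with a cross-entropy against the encoded return bins rather than with a squared error. The underlying clipping idea, in its more familiar scalar form (a sketch, not the repository code):

import torch

def clipped_value_loss(values, old_values, returns, value_clip = 0.4):
    # scalar version of the clipped value update; the commit applies the same
    # clamp but measures both candidates with a cross-entropy over return bins
    clipped_values = old_values + (values - old_values).clamp(-value_clip, value_clip)
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (clipped_values - returns) ** 2
    return torch.maximum(loss_unclipped, loss_clipped).mean()
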
@torch.no_grad()
def generate(
self,
@@ -1761,8 +1863,8 @@ class DynamicsWorldModel(Module):
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
if return_log_probs_and_values:
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
@@ -1771,13 +1873,13 @@ class DynamicsWorldModel(Module):
continuous_targets = sampled_continuous_actions,
)
decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)
value_bins = self.value_head(one_agent_embed)
values = self.reward_encoder.bins_to_scalar_value(value_bins)
decoded_values = safe_cat(decoded_values, values, dim = 1)
decoded_values = safe_cat((decoded_values, values), dim = 1)
# concat the denoised latent
@@ -1812,7 +1914,12 @@ class DynamicsWorldModel(Module):
# returning agent actions, rewards, and log probs + values for policy optimization
gen = WorldModelGenerations(latents = latents, video = video)
gen = WorldModelGenerations(
latents = latents,
video = video,
step_size = step_size,
agent_index = agent_index
)
if return_rewards_per_frame:
gen.rewards = decoded_rewards
@@ -2228,6 +2335,5 @@ class Dreamer(Module):
self,
video_tokenizer: VideoTokenizer,
dynamics_model: DynamicsWorldModel,
discount_factor = 0.997
):
super().__init__()

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.26"
version = "0.0.27"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -216,24 +216,32 @@ def test_action_with_world_model():
discrete_actions = torch.randint(0, 4, (1, 4, 1))
gen = dynamics.generate(
10,
16,
batch_size = 4,
return_rewards_per_frame = True,
return_agent_actions = True,
return_log_probs_and_values = True
)
assert gen.video.shape == (1, 3, 10, 256, 256)
assert gen.rewards.shape == (1, 10)
assert gen.video.shape == (4, 3, 16, 256, 256)
assert gen.rewards.shape == (4, 16)
discrete_actions, continuous_actions = gen.actions
assert discrete_actions.shape == (1, 10, 1)
assert discrete_actions.shape == (4, 16, 1)
assert continuous_actions is None
discrete_log_probs, _ = gen.log_probs
assert discrete_log_probs.shape == (1, 10, 1)
assert gen.values.shape == (1, 10)
assert discrete_log_probs.shape == (4, 16, 1)
assert gen.values.shape == (4, 16)
# take a reinforcement learning step
actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)
actor_loss.backward(retain_graph = True)
critic_loss.backward()
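
Taken together, the test above sketches the full "learning in dreams" step this commit enables: roll out imagined trajectories inside the world model, then run a PPO update on them. A hypothetical outer loop (the optimizer choice and the already-constructed `dynamics` model are assumptions; only `generate` and `learn_policy_from_generations` come from this commit):

from torch.optim import Adam

# `dynamics` is assumed to be a constructed DynamicsWorldModel, as in the test above
optim = Adam(dynamics.parameters(), lr = 1e-4)

num_dream_steps = 4  # illustrative

for _ in range(num_dream_steps):
    # imagine a rollout, collecting actions, log probs, values, and rewards
    gen = dynamics.generate(
        16,
        batch_size = 4,
        return_rewards_per_frame = True,
        return_agent_actions = True,
        return_log_probs_and_values = True
    )

    # one PPO update on the imagined trajectory
    actor_loss, critic_loss = dynamics.learn_policy_from_generations(gen)

    (actor_loss + critic_loss).backward()
    optim.step()
    optim.zero_grad()
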
def test_action_embedder():
from dreamer4.dreamer4 import ActionEmbedder