extract the log probs and predicted values (symexp two hot encoded) for the phase 3 RL training

lucidrains 2025-10-16 10:40:59 -07:00
parent 2d20d0a6c1
commit a0161760a0
3 changed files with 78 additions and 30 deletions

View File

@@ -5,6 +5,7 @@ from math import ceil, log2
 from random import random
 from collections import namedtuple
 from functools import partial
+from dataclasses import dataclass

 import torch
 import torch.nn.functional as F
@@ -63,7 +64,14 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

-WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+@dataclass
+class WorldModelGenerations:
+    latents: Tensor
+    video: Tensor | None = None
+    rewards: Tensor | None = None
+    actions: tuple[Tensor, Tensor] | None = None
+    log_probs: tuple[Tensor, Tensor] | None = None
+    values: Tensor | None = None

 # helpers
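Switching WorldModelGenerations from a namedtuple to a dataclass lets callers read results by attribute, with unrequested fields simply left as None, instead of unpacking a fixed-arity tuple. A minimal sketch of the new access pattern, assuming a constructed DynamicsWorldModel as in the tests further down (the call is illustrative):

# hypothetical caller - fields that were not requested stay None
gen = dynamics.generate(
    10,
    return_rewards_per_frame = True,
    return_agent_actions = True
)

rewards = gen.rewards                               # (batch, time_steps)
discrete_actions, continuous_actions = gen.actions  # either may be None if that modality is unused
assert gen.log_probs is None                        # only populated with return_log_probs_and_values = True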
@@ -1646,7 +1654,8 @@ class DynamicsWorldModel(Module):
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
         return_rewards_per_frame = False,
-        return_agent_actions = False
+        return_agent_actions = False,
+        return_log_probs_and_values = False
     ): # (b t n d) | (b c t h w)
@@ -1671,11 +1680,16 @@ class DynamicsWorldModel(Module):
         # maybe return actions

-        if return_agent_actions:
-            assert self.action_embedder.has_actions
+        return_agent_actions |= return_log_probs_and_values

         decoded_discrete_actions = None
         decoded_continuous_actions = None

+        # policy optimization related
+
+        decoded_discrete_log_probs = None
+        decoded_continuous_log_probs = None
+        decoded_values = None
+
         # maybe return rewards
@@ -1739,6 +1753,8 @@ class DynamicsWorldModel(Module):
            # decode the agent actions if needed

            if return_agent_actions:
+                assert self.action_embedder.has_actions
+
                one_agent_embed = agent_embed[:, -1:, agent_index]

                policy_embed = self.policy_head(one_agent_embed)
@@ -1748,6 +1764,21 @@ class DynamicsWorldModel(Module):
                decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
                decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)

+                if return_log_probs_and_values:
+                    discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                        policy_embed,
+                        discrete_targets = sampled_discrete_actions,
+                        continuous_targets = sampled_continuous_actions,
+                    )
+
+                    decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
+                    decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
+
+                    value_bins = self.value_head(one_agent_embed)
+                    values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+                    decoded_values = safe_cat(decoded_values, values, dim = 1)

            # concat the denoised latent

            latents = cat((latents, denoised_latent), dim = 1)
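The values gathered above come from reward_encoder.bins_to_scalar_value, which, per the commit message, decodes a symexp two-hot (DreamerV3-style) distribution back to a scalar. A rough sketch of that decoding under the usual parameterization; the bin count, value range, and helper names below are assumptions for illustration, not this repo's actual API:

import torch

def symexp(x):
    # inverse of symlog(x) = sign(x) * log(1 + |x|)
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.)

def bins_to_scalar(value_logits, num_bins = 255, support = 20.):
    # bins are evenly spaced in symlog space; take the expectation under the
    # softmax distribution, then map back to the original scale with symexp
    bin_values = torch.linspace(-support, support, num_bins, device = value_logits.device)
    probs = value_logits.softmax(dim = -1)
    return symexp((probs * bin_values).sum(dim = -1))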
@@ -1765,27 +1796,36 @@ class DynamicsWorldModel(Module):
        has_tokenizer = exists(self.video_tokenizer)
        return_decoded_video = default(return_decoded_video, has_tokenizer)

-        if not return_decoded_video:
-            if not return_rewards_per_frame:
-                return denoised_latents
-
-            return denoised_latents, decoded_rewards
-
-        generated_video = self.video_tokenizer.decode(
-            latents,
-            height = image_height,
-            width = image_width
-        )
+        video = None
+
+        if return_decoded_video:
+            video = self.video_tokenizer.decode(
+                latents,
+                height = image_height,
+                width = image_width
+            )
+
+        # only return video or latent if not requesting anything else, for first stage training

        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return generated_video
+            return video if return_decoded_video else latents

-        return WorldModelGenerations(
-            video = generated_video,
-            latents = latents,
-            rewards = decoded_rewards if return_rewards_per_frame else None,
-            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
-        )
+        # returning agent actions, rewards, and log probs + values for policy optimization
+
+        gen = WorldModelGenerations(latents = latents, video = video)
+
+        if return_rewards_per_frame:
+            gen.rewards = decoded_rewards
+
+        if return_agent_actions:
+            gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
+
+        if return_log_probs_and_values:
+            gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
+            gen.values = decoded_values
+
+        return gen

    def forward(
        self,
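Together with the log prob extraction above, generate now returns everything a phase 3 policy-gradient step would need from imagined rollouts: the old log probs for an importance ratio and per-frame values for advantage estimation. A hedged sketch of how that output could be consumed, assuming a trained dynamics world model as in the tests; the PPO-style usage is illustrative and not part of this commit:

gen = dynamics.generate(
    16,
    return_rewards_per_frame = True,
    return_agent_actions = True,
    return_log_probs_and_values = True
)

discrete_actions, _ = gen.actions
old_log_probs, _ = gen.log_probs

# advantages would come from gen.rewards and gen.values (e.g. GAE),
# then a clipped surrogate uses the ratio of new to old action log probs:
# ratio = (new_log_probs - old_log_probs).exp()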

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.25"
+version = "0.0.26"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -103,7 +103,7 @@ def test_e2e(
    # generating

-    generated_video, generated_rewards = dynamics.generate(
+    generations = dynamics.generate(
        time_steps = 10,
        image_height = 128,
        image_width = 128,
@@ -111,8 +111,8 @@ def test_e2e(
        return_rewards_per_frame = True
    )

-    assert generated_video.shape == (2, 3, 10, 128, 128)
-    assert generated_rewards.shape == (2, 10)
+    assert generations.video.shape == (2, 3, 10, 128, 128)
+    assert generations.rewards.shape == (2, 10)

    # rl
@@ -215,18 +215,26 @@ def test_action_with_world_model():
    rewards = torch.randn(1, 4)
    discrete_actions = torch.randint(0, 4, (1, 4, 1))

-    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+    gen = dynamics.generate(
        10,
        return_rewards_per_frame = True,
-        return_agent_actions = True
+        return_agent_actions = True,
+        return_log_probs_and_values = True
    )

-    assert generated_video.shape == (1, 3, 10, 256, 256)
-    assert generated_rewards.shape == (1, 10)
+    assert gen.video.shape == (1, 3, 10, 256, 256)
+    assert gen.rewards.shape == (1, 10)
+
+    discrete_actions, continuous_actions = gen.actions

    assert discrete_actions.shape == (1, 10, 1)
    assert continuous_actions is None

+    discrete_log_probs, _ = gen.log_probs
+
+    assert discrete_log_probs.shape == (1, 10, 1)
+    assert gen.values.shape == (1, 10)
+
 def test_action_embedder():
    from dreamer4.dreamer4 import ActionEmbedder