extract the log probs and predicted values (symexp two-hot encoded) for the phase 3 RL training

lucidrains 2025-10-16 10:40:59 -07:00
parent 2d20d0a6c1
commit a0161760a0
3 changed files with 78 additions and 30 deletions
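For orientation, a minimal sketch of how the new flag is consumed, mirroring the updated test further down in this commit (`dynamics` is assumed to be an already configured `DynamicsWorldModel`):

gen = dynamics.generate(
    10,                                  # time steps to imagine
    return_rewards_per_frame = True,
    return_agent_actions = True,
    return_log_probs_and_values = True   # new in this commit
)

# generate now returns a WorldModelGenerations dataclass rather than a tuple
discrete_actions, continuous_actions = gen.actions        # sampled actions per frame
discrete_log_probs, continuous_log_probs = gen.log_probs  # log probs of those samples
values = gen.values                                       # critic values, decoded from symexp two-hot bins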

View File

@@ -5,6 +5,7 @@ from math import ceil, log2
from random import random
from collections import namedtuple
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@@ -63,7 +64,14 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
@dataclass
class WorldModelGenerations:
latents: Tensor
video: Tensor | None = None
rewards: Tensor | None = None
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
values: Tensor | None = None
# helpers
@@ -1646,7 +1654,8 @@ class DynamicsWorldModel(Module):
return_decoded_video = None,
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
return_rewards_per_frame = False,
return_agent_actions = False
return_agent_actions = False,
return_log_probs_and_values = False
): # (b t n d) | (b c t h w)
@@ -1671,11 +1680,16 @@ class DynamicsWorldModel(Module):
# maybe return actions
if return_agent_actions:
assert self.action_embedder.has_actions
return_agent_actions |= return_log_probs_and_values
decoded_discrete_actions = None
decoded_continuous_actions = None
decoded_discrete_actions = None
decoded_continuous_actions = None
# policy optimization related
decoded_discrete_log_probs = None
decoded_continuous_log_probs = None
decoded_values = None
# maybe return rewards
@@ -1739,6 +1753,8 @@ class DynamicsWorldModel(Module):
# decode the agent actions if needed
if return_agent_actions:
assert self.action_embedder.has_actions
one_agent_embed = agent_embed[:, -1:, agent_index]
policy_embed = self.policy_head(one_agent_embed)
@@ -1748,6 +1764,21 @@ class DynamicsWorldModel(Module):
decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
if return_log_probs_and_values:
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
policy_embed,
discrete_targets = sampled_discrete_actions,
continuous_targets = sampled_continuous_actions,
)
decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
value_bins = self.value_head(one_agent_embed)
values = self.reward_encoder.bins_to_scalar_value(value_bins)
decoded_values = safe_cat(decoded_values, values, dim = 1)
# concat the denoised latent
latents = cat((latents, denoised_latent), dim = 1)
@@ -1765,27 +1796,36 @@ class DynamicsWorldModel(Module):
has_tokenizer = exists(self.video_tokenizer)
return_decoded_video = default(return_decoded_video, has_tokenizer)
if not return_decoded_video:
if not return_rewards_per_frame:
return denoised_latents
video = None
return denoised_latents, decoded_rewards
if return_decoded_video:
video = self.video_tokenizer.decode(
latents,
height = image_height,
width = image_width
)
generated_video = self.video_tokenizer.decode(
latents,
height = image_height,
width = image_width
)
# only return video or latent if not requesting anything else, for first stage training
if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
return generated_video
return video if return_decoded_video else latents
return WorldModelGenerations(
video = generated_video,
latents = latents,
rewards = decoded_rewards if return_rewards_per_frame else None,
actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
)
# returning agent actions, rewards, and log probs + values for policy optimization
gen = WorldModelGenerations(latents = latents, video = video)
if return_rewards_per_frame:
gen.rewards = decoded_rewards
if return_agent_actions:
gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
if return_log_probs_and_values:
gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
gen.values = decoded_values
return gen
def forward(
self,

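The commit message describes the predicted values as symexp two-hot encoded: the value head emits logits over a fixed set of bins, and `reward_encoder.bins_to_scalar_value` maps them back to a scalar. Below is a minimal sketch of that decoding step in the DreamerV3 style; the bin count and range are assumptions for illustration, not this repo's actual settings.

import torch

def symexp(x):
    # inverse of symlog: sign(x) * (exp(|x|) - 1)
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

def bins_to_scalar_value(logits, num_bins = 255, min_value = -20., max_value = 20.):
    # bin centers are symexp-spaced: an even grid in symlog space mapped back through symexp
    bins = symexp(torch.linspace(min_value, max_value, num_bins, device = logits.device))
    # the head predicts a categorical distribution over the bins
    probs = logits.softmax(dim = -1)
    # the decoded scalar is the expectation of the bin centers under that distribution
    return (probs * bins).sum(dim = -1)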
View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.25"
version = "0.0.26"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -103,7 +103,7 @@ def test_e2e(
# generating
generated_video, generated_rewards = dynamics.generate(
generations = dynamics.generate(
time_steps = 10,
image_height = 128,
image_width = 128,
@@ -111,8 +111,8 @@ def test_e2e(
return_rewards_per_frame = True
)
assert generated_video.shape == (2, 3, 10, 128, 128)
assert generated_rewards.shape == (2, 10)
assert generations.video.shape == (2, 3, 10, 128, 128)
assert generations.rewards.shape == (2, 10)
# rl
@@ -215,18 +215,26 @@ def test_action_with_world_model():
rewards = torch.randn(1, 4)
discrete_actions = torch.randint(0, 4, (1, 4, 1))
generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
gen = dynamics.generate(
10,
return_rewards_per_frame = True,
return_agent_actions = True
return_agent_actions = True,
return_log_probs_and_values = True
)
assert generated_video.shape == (1, 3, 10, 256, 256)
assert generated_rewards.shape == (1, 10)
assert gen.video.shape == (1, 3, 10, 256, 256)
assert gen.rewards.shape == (1, 10)
discrete_actions, continuous_actions = gen.actions
assert discrete_actions.shape == (1, 10, 1)
assert continuous_actions is None
discrete_log_probs, _ = gen.log_probs
assert discrete_log_probs.shape == (1, 10, 1)
assert gen.values.shape == (1, 10)
def test_action_embedder():
from dreamer4.dreamer4 import ActionEmbedder
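To make the purpose of the new outputs concrete, here is a deliberately generic actor-critic sketch of how per-frame log probs and values could be turned into losses. This is an illustration only, not the repo's phase 3 objective, which is not part of this commit.

import torch
import torch.nn.functional as F

def simple_actor_critic_losses(log_probs, values, rewards, gamma = 0.99):
    # discounted returns over the imagined rollout, computed backwards in time
    returns = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[:, -1])
    for t in reversed(range(rewards.shape[1])):
        running = rewards[:, t] + gamma * running
        returns[:, t] = running

    # advantage-weighted policy gradient on the stored log probs (value baseline, no gradient)
    advantage = (returns - values).detach()
    policy_loss = -(log_probs * advantage.unsqueeze(-1)).mean()

    # plain mse here for brevity; the actual value head would presumably be trained
    # with the same symexp two-hot cross entropy used for rewards
    value_loss = F.mse_loss(values, returns)
    return policy_loss, value_loss

With the shapes asserted in the test above, the inputs would be `gen.log_probs[0]`, `gen.values`, and `gen.rewards`.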