extract the log probs and predicted values (symexp two-hot encoded) for the phase 3 RL training

parent 2d20d0a6c1
commit a0161760a0
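The "symexp two-hot encoded" predicted values follow the Dreamer family's value parameterization: the value head outputs logits over a fixed set of bins, the expectation is taken over bin centers in symlog space, and the result is mapped back through symexp to a scalar. A minimal sketch of that decoding, assuming linearly spaced symlog-domain bins and the hypothetical names below; the repo's actual bins_to_scalar_value may differ in bin placement and signature:

import torch

def symexp(x):
    # inverse of symlog: sign(x) * (exp(|x|) - 1)
    return x.sign() * (x.abs().exp() - 1.)

def bins_to_scalar_value(value_logits, num_bins = 255, value_range = (-20., 20.)):
    # hypothetical stand-in for the repo's decoder: softmax over bins,
    # expectation over symlog-spaced bin centers, then symexp back to a scalar
    bins = torch.linspace(*value_range, num_bins, device = value_logits.device)
    probs = value_logits.softmax(dim = -1)       # (..., num_bins)
    return symexp(probs @ bins)                  # (...)

# example: per-frame value logits -> scalar values
value_logits = torch.randn(1, 10, 255)           # (batch, frames, bins)
values = bins_to_scalar_value(value_logits)      # (batch, frames) -> (1, 10)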
@@ -5,6 +5,7 @@ from math import ceil, log2
 from random import random
 from collections import namedtuple
 from functools import partial
+from dataclasses import dataclass

 import torch
 import torch.nn.functional as F
@@ -63,7 +64,14 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

-WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+@dataclass
+class WorldModelGenerations:
+    latents: Tensor
+    video: Tensor | None = None
+    rewards: Tensor | None = None
+    actions: tuple[Tensor, Tensor] | None = None
+    log_probs: tuple[Tensor, Tensor] | None = None
+    values: Tensor | None = None

 # helpers
@@ -1646,7 +1654,8 @@ class DynamicsWorldModel(Module):
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
         return_rewards_per_frame = False,
-        return_agent_actions = False
+        return_agent_actions = False,
+        return_log_probs_and_values = False

     ): # (b t n d) | (b c t h w)
@@ -1671,11 +1680,16 @@ class DynamicsWorldModel(Module):

         # maybe return actions

         if return_agent_actions:
             assert self.action_embedder.has_actions

+        return_agent_actions |= return_log_probs_and_values
+
         decoded_discrete_actions = None
         decoded_continuous_actions = None

+        # policy optimization related
+
+        decoded_discrete_log_probs = None
+        decoded_continuous_log_probs = None
+        decoded_values = None
+
         # maybe return rewards
@@ -1739,6 +1753,8 @@ class DynamicsWorldModel(Module):
             # decode the agent actions if needed

             if return_agent_actions:
                 assert self.action_embedder.has_actions

+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
                 policy_embed = self.policy_head(one_agent_embed)
@@ -1748,6 +1764,21 @@ class DynamicsWorldModel(Module):
                 decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
                 decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)

+                if return_log_probs_and_values:
+                    discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                        policy_embed,
+                        discrete_targets = sampled_discrete_actions,
+                        continuous_targets = sampled_continuous_actions,
+                    )
+
+                    decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
+                    decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
+
+                    value_bins = self.value_head(one_agent_embed)
+                    values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+                    decoded_values = safe_cat(decoded_values, values, dim = 1)
+
             # concat the denoised latent

             latents = cat((latents, denoised_latent), dim = 1)
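For context on the action_embedder.log_probs(...) call added above: it scores the just-sampled actions under the current policy head so their log probabilities can be cached for the later policy update. A hypothetical sketch of what such a method computes, assuming a categorical head for discrete actions and a diagonal Gaussian head for continuous ones; the actual ActionEmbedder in this repo may be organized differently:

import torch
from torch.distributions import Normal

def policy_log_probs(
    discrete_logits,                  # (b, t, num_discrete, num_choices) or None
    continuous_mean, continuous_std,  # (b, t, num_continuous) each, or None
    discrete_targets = None,          # (b, t, num_discrete) sampled action indices (long)
    continuous_targets = None         # (b, t, num_continuous) sampled continuous actions
):
    discrete_log_probs = continuous_log_probs = None

    if discrete_logits is not None and discrete_targets is not None:
        # log prob of each sampled discrete action under its categorical distribution
        log_probs = discrete_logits.log_softmax(dim = -1)
        discrete_log_probs = log_probs.gather(-1, discrete_targets.unsqueeze(-1)).squeeze(-1)

    if continuous_mean is not None and continuous_targets is not None:
        # log prob of each sampled continuous action under the Gaussian head
        continuous_log_probs = Normal(continuous_mean, continuous_std).log_prob(continuous_targets)

    return discrete_log_probs, continuous_log_probs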
@@ -1765,27 +1796,36 @@ class DynamicsWorldModel(Module):
         has_tokenizer = exists(self.video_tokenizer)
         return_decoded_video = default(return_decoded_video, has_tokenizer)

-        if not return_decoded_video:
-            if not return_rewards_per_frame:
-                return denoised_latents
-
-            return denoised_latents, decoded_rewards
-
-        generated_video = self.video_tokenizer.decode(
-            latents,
-            height = image_height,
-            width = image_width
-        )
-
+        video = None
+
+        if return_decoded_video:
+            video = self.video_tokenizer.decode(
+                latents,
+                height = image_height,
+                width = image_width
+            )
+
+        # only return video or latent if not requesting anything else, for first stage training
+
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return generated_video
-
-        return WorldModelGenerations(
-            video = generated_video,
-            latents = latents,
-            rewards = decoded_rewards if return_rewards_per_frame else None,
-            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
-        )
+            return video if return_decoded_video else latents
+
+        # returning agent actions, rewards, and log probs + values for policy optimization
+
+        gen = WorldModelGenerations(latents = latents, video = video)
+
+        if return_rewards_per_frame:
+            gen.rewards = decoded_rewards
+
+        if return_agent_actions:
+            gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
+
+        if return_log_probs_and_values:
+            gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
+
+            gen.values = decoded_values
+
+        return gen

     def forward(
         self,
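These cached log probs and values are what the phase 3 stage consumes when running policy optimization on rollouts imagined inside the world model. How they would typically be used is sketched below with generalized advantage estimation and a clipped PPO-style surrogate; this is a generic illustration under those assumptions, not the trainer this repo ships (the cached values stand in for a re-evaluated critic to keep the sketch short):

import torch

def clipped_policy_losses(
    new_log_probs,   # (b, t) log probs of the same actions under the updated policy
    old_log_probs,   # (b, t) gen.log_probs cached from generate(..., return_log_probs_and_values = True)
    values,          # (b, t) gen.values decoded from the symexp two-hot value head
    rewards,         # (b, t) gen.rewards per imagined frame
    gamma = 0.99,
    lam = 0.95,
    clip_eps = 0.2
):
    # generalized advantage estimation over the imagined trajectory
    steps = rewards.shape[1]
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros_like(rewards[:, 0])
    next_value = torch.zeros_like(rewards[:, 0])

    for t in reversed(range(steps)):
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_adv = delta + gamma * lam * last_adv
        advantages[:, t] = last_adv
        next_value = values[:, t]

    returns = (advantages + values).detach()
    advantages = advantages.detach()

    # clipped surrogate for the actor, regression toward the lambda-style returns for the critic
    ratio = (new_log_probs - old_log_probs).exp()
    policy_loss = -torch.minimum(ratio * advantages, ratio.clamp(1. - clip_eps, 1. + clip_eps) * advantages).mean()
    value_loss = (values - returns).pow(2).mean()

    return policy_loss, value_loss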
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.25"
+version = "0.0.26"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -103,7 +103,7 @@ def test_e2e(

     # generating

-    generated_video, generated_rewards = dynamics.generate(
+    generations = dynamics.generate(
         time_steps = 10,
         image_height = 128,
         image_width = 128,
@@ -111,8 +111,8 @@ def test_e2e(
         return_rewards_per_frame = True
     )

-    assert generated_video.shape == (2, 3, 10, 128, 128)
-    assert generated_rewards.shape == (2, 10)
+    assert generations.video.shape == (2, 3, 10, 128, 128)
+    assert generations.rewards.shape == (2, 10)

     # rl
@@ -215,18 +215,26 @@ def test_action_with_world_model():
     rewards = torch.randn(1, 4)
     discrete_actions = torch.randint(0, 4, (1, 4, 1))

-    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+    gen = dynamics.generate(
         10,
         return_rewards_per_frame = True,
-        return_agent_actions = True
+        return_agent_actions = True,
+        return_log_probs_and_values = True
     )

-    assert generated_video.shape == (1, 3, 10, 256, 256)
-    assert generated_rewards.shape == (1, 10)
+    assert gen.video.shape == (1, 3, 10, 256, 256)
+    assert gen.rewards.shape == (1, 10)

+    discrete_actions, continuous_actions = gen.actions
+
     assert discrete_actions.shape == (1, 10, 1)
     assert continuous_actions is None

+    discrete_log_probs, _ = gen.log_probs
+
+    assert discrete_log_probs.shape == (1, 10, 1)
+    assert gen.values.shape == (1, 10)
+

 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder