diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 344049b..a0fc1b8 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -5,6 +5,7 @@ from math import ceil, log2
 from random import random
 from collections import namedtuple
 from functools import partial
+from dataclasses import dataclass
 
 import torch
 import torch.nn.functional as F
@@ -63,7 +64,14 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
 
-WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+@dataclass
+class WorldModelGenerations:
+    latents: Tensor
+    video: Tensor | None = None
+    rewards: Tensor | None = None
+    actions: tuple[Tensor, Tensor] | None = None
+    log_probs: tuple[Tensor, Tensor] | None = None
+    values: Tensor | None = None
 
 # helpers
 
@@ -1646,7 +1654,8 @@ class DynamicsWorldModel(Module):
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
         return_rewards_per_frame = False,
-        return_agent_actions = False
+        return_agent_actions = False,
+        return_log_probs_and_values = False
     ):
 
         # (b t n d) | (b c t h w)
@@ -1671,11 +1680,16 @@ class DynamicsWorldModel(Module):
 
         # maybe return actions
 
-        if return_agent_actions:
-            assert self.action_embedder.has_actions
+        return_agent_actions |= return_log_probs_and_values
 
-            decoded_discrete_actions = None
-            decoded_continuous_actions = None
+        decoded_discrete_actions = None
+        decoded_continuous_actions = None
+
+        # policy optimization related
+
+        decoded_discrete_log_probs = None
+        decoded_continuous_log_probs = None
+        decoded_values = None
 
         # maybe return rewards
 
@@ -1739,6 +1753,8 @@ class DynamicsWorldModel(Module):
             # decode the agent actions if needed
 
             if return_agent_actions:
+                assert self.action_embedder.has_actions
+
                 one_agent_embed = agent_embed[:, -1:, agent_index]
 
                 policy_embed = self.policy_head(one_agent_embed)
@@ -1748,6 +1764,21 @@ class DynamicsWorldModel(Module):
                 decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
                 decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
 
+                if return_log_probs_and_values:
+                    discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                        policy_embed,
+                        discrete_targets = sampled_discrete_actions,
+                        continuous_targets = sampled_continuous_actions,
+                    )
+
+                    decoded_discrete_log_probs = safe_cat(decoded_discrete_log_probs, discrete_log_probs, dim = 1)
+                    decoded_continuous_log_probs = safe_cat(decoded_continuous_log_probs, continuous_log_probs, dim = 1)
+
+                    value_bins = self.value_head(one_agent_embed)
+                    values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+                    decoded_values = safe_cat(decoded_values, values, dim = 1)
+
             # concat the denoised latent
 
             latents = cat((latents, denoised_latent), dim = 1)
@@ -1765,27 +1796,36 @@ class DynamicsWorldModel(Module):
         has_tokenizer = exists(self.video_tokenizer)
         return_decoded_video = default(return_decoded_video, has_tokenizer)
 
-        if not return_decoded_video:
-            if not return_rewards_per_frame:
-                return denoised_latents
+        video = None
 
-            return denoised_latents, decoded_rewards
+        if return_decoded_video:
+            video = self.video_tokenizer.decode(
+                latents,
+                height = image_height,
+                width = image_width
+            )
 
-        generated_video = self.video_tokenizer.decode(
-            latents,
-            height = image_height,
-            width = image_width
-        )
+        # only return video or latent if not requesting anything else, for first stage training
 
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return generated_video
+            return video if return_decoded_video else latents
 
-        return WorldModelGenerations(
-            video = generated_video,
-            latents = latents,
-            rewards = decoded_rewards if return_rewards_per_frame else None,
-            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
-        )
+        # returning agent actions, rewards, and log probs + values for policy optimization
+
+        gen = WorldModelGenerations(latents = latents, video = video)
+
+        if return_rewards_per_frame:
+            gen.rewards = decoded_rewards
+
+        if return_agent_actions:
+            gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
+
+        if return_log_probs_and_values:
+            gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
+
+            gen.values = decoded_values
+
+        return gen
 
     def forward(
         self,
diff --git a/pyproject.toml b/pyproject.toml
index 3167f30..9c41677 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.25"
+version = "0.0.26"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index ae3cff7..11ed491 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -103,7 +103,7 @@ def test_e2e(
 
     # generating
 
-    generated_video, generated_rewards = dynamics.generate(
+    generations = dynamics.generate(
         time_steps = 10,
         image_height = 128,
         image_width = 128,
@@ -111,8 +111,8 @@ def test_e2e(
         return_rewards_per_frame = True
     )
 
-    assert generated_video.shape == (2, 3, 10, 128, 128)
-    assert generated_rewards.shape == (2, 10)
+    assert generations.video.shape == (2, 3, 10, 128, 128)
+    assert generations.rewards.shape == (2, 10)
 
     # rl
 
@@ -215,18 +215,26 @@ def test_action_with_world_model():
     rewards = torch.randn(1, 4)
    discrete_actions = torch.randint(0, 4, (1, 4, 1))
 
-    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+    gen = dynamics.generate(
         10,
         return_rewards_per_frame = True,
-        return_agent_actions = True
+        return_agent_actions = True,
+        return_log_probs_and_values = True
     )
 
-    assert generated_video.shape == (1, 3, 10, 256, 256)
-    assert generated_rewards.shape == (1, 10)
+    assert gen.video.shape == (1, 3, 10, 256, 256)
+    assert gen.rewards.shape == (1, 10)
+
+    discrete_actions, continuous_actions = gen.actions
 
     assert discrete_actions.shape == (1, 10, 1)
     assert continuous_actions is None
 
+    discrete_log_probs, _ = gen.log_probs
+
+    assert discrete_log_probs.shape == (1, 10, 1)
+    assert gen.values.shape == (1, 10)
+
 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder
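For context, a minimal sketch of how the new return path might be consumed downstream, assuming a `DynamicsWorldModel` instance `dynamics` set up as in `test_action_with_world_model` above (discrete actions only, tokenizer attached). Only the `generate` call and the `WorldModelGenerations` fields come from this diff; the baseline-subtracted advantage and the REINFORCE-style surrogate are illustrative assumptions, not the policy optimization objective of this repo.

# sketch only - `dynamics` is assumed to be a constructed DynamicsWorldModel,
# as in tests/test_dreamer.py::test_action_with_world_model

gen = dynamics.generate(
    10,
    return_rewards_per_frame = True,
    return_log_probs_and_values = True  # implies return_agent_actions per the diff
)

discrete_log_probs, _ = gen.log_probs   # (batch, time, num_discrete_actions)
values = gen.values                     # (batch, time)
rewards = gen.rewards                   # (batch, time)

# crude baseline-subtracted advantage over the imagined rollout (assumption)
advantages = (rewards - values).detach()

# REINFORCE-style surrogate, purely to show how the fields fit together (assumption)
policy_loss = -(discrete_log_probs.sum(dim = -1) * advantages).mean()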