able to roll out actions from one agent within the dreams of a world model

This commit is contained in:
lucidrains 2025-10-16 10:15:43 -07:00
parent d74f09f0b3
commit 2d20d0a6c1
3 changed files with 87 additions and 4 deletions

View File

@ -63,6 +63,8 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
# helpers
def exists(v):
@ -74,6 +76,9 @@ def default(v, d):
def first(arr):
return arr[0]
def has_at_least_one(*bools):
return sum([*map(int, bools)]) > 0
def ensure_tuple(t):
return (t,) if not isinstance(t, tuple) else t
@ -94,6 +99,16 @@ def is_empty(t):
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def safe_cat(*tensors, dim):
tensors = [*filter(exists, tensors)]
if len(tensors) == 0:
return None
elif len(tensors) == 1:
return tensors[0]
return cat(tensors, dim = dim)
def gumbel_noise(t):
noise = torch.rand_like(t)
return -log(-log(noise))
@ -1630,7 +1645,8 @@ class DynamicsWorldModel(Module):
image_width = None,
return_decoded_video = None,
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
return_rewards_per_frame = False
return_rewards_per_frame = False,
return_agent_actions = False
): # (b t n d) | (b c t h w)
@ -1653,6 +1669,14 @@ class DynamicsWorldModel(Module):
past_context_noise = latents.clone()
# maybe return actions
if return_agent_actions:
assert self.action_embedder.has_actions
decoded_discrete_actions = None
decoded_continuous_actions = None
# maybe return rewards
if return_rewards_per_frame:
@ -1679,6 +1703,8 @@ class DynamicsWorldModel(Module):
signal_levels = signal_levels_with_context,
step_sizes = step_size,
rewards = decoded_rewards,
discrete_actions = decoded_discrete_actions,
continuous_actions = decoded_continuous_actions,
latent_is_noised = True,
return_pred_only = True,
return_agent_tokens = True
@ -1710,6 +1736,18 @@ class DynamicsWorldModel(Module):
decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
# decode the agent actions if needed
if return_agent_actions:
one_agent_embed = agent_embed[:, -1:, agent_index]
policy_embed = self.policy_head(one_agent_embed)
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
# concat the denoised latent
latents = cat((latents, denoised_latent), dim = 1)
@ -1739,10 +1777,15 @@ class DynamicsWorldModel(Module):
width = image_width
)
if not return_rewards_per_frame:
if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
return generated_video
return generated_video, decoded_rewards
return WorldModelGenerations(
video = generated_video,
latents = latents,
rewards = decoded_rewards if return_rewards_per_frame else None,
actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
)
def forward(
self,

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.24"
version = "0.0.25"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -187,6 +187,46 @@ def test_attend_factory(
assert torch.allclose(flex_out, out, atol = 1e-6)
def test_action_with_world_model():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 1,
decoder_depth = 1,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
num_discrete_actions = 4
)
rewards = torch.randn(1, 4)
discrete_actions = torch.randint(0, 4, (1, 4, 1))
generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
10,
return_rewards_per_frame = True,
return_agent_actions = True
)
assert generated_video.shape == (1, 3, 10, 256, 256)
assert generated_rewards.shape == (1, 10)
assert discrete_actions.shape == (1, 10, 1)
assert continuous_actions is None
def test_action_embedder():
from dreamer4.dreamer4 import ActionEmbedder