From 83cfd2cd1b92282dcfd33fd109d8bcac2793c6c3 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 18 Oct 2025 07:47:13 -0700 Subject: [PATCH] task conditioning when dreaming --- dreamer4/dreamer4.py | 13 ++++++++++++- pyproject.toml | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 2d82883..c263d26 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -21,7 +21,6 @@ from torch.optim import Optimizer from adam_atan2_pytorch import MuonAdamAtan2 from x_mlps_pytorch.normed_mlp import create_mlp -from x_mlps_pytorch.ensemble import Ensemble from hyper_connections import get_init_and_expand_reduce_stream_functions @@ -1816,6 +1815,7 @@ class DynamicsWorldModel(Module): num_steps = 4, batch_size = 1, agent_index = 0, + tasks: int | Tensor | None = None, image_height = None, image_width = None, return_decoded_video = None, @@ -1829,9 +1829,18 @@ class DynamicsWorldModel(Module): was_training = self.training self.eval() + # validation + assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2' assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}' + if isinstance(tasks, int): + tasks = torch.full((batch_size,), tasks, device = self.device) + + assert tasks.shape[0] == batch_size + + # get state latent shape + latent_shape = self.latent_shape # derive step size @@ -1884,6 +1893,7 @@ class DynamicsWorldModel(Module): signal_levels = signal_levels_with_context, step_sizes = step_size, rewards = decoded_rewards, + tasks = tasks, discrete_actions = decoded_discrete_actions, continuous_actions = decoded_continuous_actions, latent_is_noised = True, @@ -2349,6 +2359,7 @@ class DynamicsWorldModel(Module): encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean') reward_pred = self.to_reward_pred(encoded_agent_tokens) + reward_loss = F.cross_entropy(reward_pred, two_hot_encoding) # maybe autoregressive action loss diff --git a/pyproject.toml b/pyproject.toml index 5dd5629..8e7ad54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.32" +version = "0.0.33" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }