diff --git a/README.md b/README.md
index a90ad62..8d21e00 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,75 @@
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + LPIPS
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+# dynamics world model
+
+dynamics = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# states, actions, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# train the dynamics / behavior cloned model
+
+loss = dynamics(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with a lot of data
+
+# then generate dreams
+
+dreams = dynamics.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+```
+
 ## Citation
 
 ```bibtex
@@ -17,3 +83,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
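A note on the usage above: the tokenizer is described as learned through MAE + LPIPS, but no pretraining call is shown. Below is a minimal pretraining sketch; it assumes `VideoTokenizer` follows the same convention as `dynamics(...)` of returning its training loss from the forward call - the exact signature is an assumption, not something this diff confirms.

```python
import torch
from dreamer4 import VideoTokenizer

# assumed interface: calling the tokenizer on raw video returns its
# reconstruction loss (the MAE + LPIPS objectives are handled internally)

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

video = torch.randn(2, 3, 10, 256, 256) # (batch, channels, time, height, width)

loss = tokenizer(video) # assumed loss-returning forward
loss.backward()

# after pretraining, pass the tokenizer to DynamicsWorldModel as in the usage above
```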
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 6a3a8ee..340ecc4 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -2429,6 +2429,7 @@ class DynamicsWorldModel(Module):
         normalize_advantages = None,
         eps = 1e-6
     ):
+        assert isinstance(experience, Experience)
 
         latents = experience.latents
         actions = experience.actions
@@ -2441,7 +2442,7 @@ class DynamicsWorldModel(Module):
         step_size = experience.step_size
         agent_index = experience.agent_index
 
-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_for_policy_optimization = True)'
 
         batch, time = latents.shape[0], latents.shape[1]
 
@@ -2694,12 +2695,22 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
         return_time_kv_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
     ):
         # (b t n d) | (b c t h w)
 
+        # handy flag for returning all the generations needed for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # save a few variables, set module to eval
+
         has_proprio = self.has_proprio
 
         was_training = self.training
         self.eval()
@@ -2769,6 +2780,19 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
+        # determine whether to take an extra denoising step, needed when
+        # (1) using the time kv cache
+        # (2) decoding anything off the agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
         noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
@@ -2776,7 +2800,10 @@ class DynamicsWorldModel(Module):
         if has_proprio:
             noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
 
-        for step in range(num_steps):
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):
+
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2846,11 @@ class DynamicsWorldModel(Module):
             if use_time_kv_cache and is_last_step:
                 time_kv_cache = next_time_kv_cache
 
+            # early break if taking an extra step, used only for decoding off the agent embedding from the cleaned latents
+
+            if take_extra_step and is_last_step:
+                break
+
             # maybe proprio
 
             if has_proprio:
@@ -3021,7 +3053,7 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_intermediates = False,
-        add_autoregressive_action_loss = False,
+        add_autoregressive_action_loss = True,
         update_loss_ema = None,
         latent_has_view_dim = False
     ):
diff --git a/pyproject.toml b/pyproject.toml
index 3e7354b..ab8ddac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.102"
+version = "0.1.0"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
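For clarity on the new `return_for_policy_optimization` flag added in the `generate` hunk above: it simply or's in the three return flags that reinforcement learning needs. Continuing from the `dynamics` model defined in the README usage, the two calls below are intended to be equivalent (a sketch, with other `generate` arguments elided).

```python
# the convenience flag ...

dreams = dynamics.generate(10, batch_size = 2, return_for_policy_optimization = True)

# ... is shorthand for setting the three rl-related flags explicitly

dreams = dynamics.generate(
    10,
    batch_size = 2,
    return_agent_actions = True,
    return_log_probs_and_values = True,
    return_rewards_per_frame = True
)
```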
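Note also the default flip for `add_autoregressive_action_loss` from `False` to `True` in the last dreamer4.py hunk: the autoregressive action (behavior cloning) loss is now on by default during training. A sketch of opting out, again continuing from the README usage:

```python
# pass the flag explicitly to recover the pre-0.1.0 behavior
# (no autoregressive action loss)

loss = dynamics(
    video = video,
    rewards = rewards,
    discrete_actions = discrete_actions,
    add_autoregressive_action_loss = False
)

loss.backward()
```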