diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 3347187..2cd34d2 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2069,6 +2069,7 @@ class DynamicsWorldModel(Module): image_width = None, return_decoded_video = None, context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc + time_kv_cache: Tensor | None = None, use_time_kv_cache = True, return_rewards_per_frame = False, return_agent_actions = False, @@ -2120,13 +2121,10 @@ class DynamicsWorldModel(Module): # maybe return rewards + decoded_rewards = None if return_rewards_per_frame: decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32) - # handle maybe time kv cache - - time_kv_cache = None - # while all the frames of the video (per latent) is not generated while latents.shape[1] < time_steps: diff --git a/pyproject.toml b/pyproject.toml index 63c5949..0c68faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.54" +version = "0.0.55" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 0498123..d949f24 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -473,3 +473,33 @@ def test_tokenizer_trainer(): ) trainer() + +def test_cache_generate(): + from dreamer4.dreamer4 import DynamicsWorldModel + + dynamics = DynamicsWorldModel( + dim = 16, + dim_latent = 16, + max_steps = 64, + num_tasks = 4, + num_latent_tokens = 4, + depth = 4, + num_spatial_tokens = 1, + pred_orig_latent = True, + num_discrete_actions = 4, + attn_dim_head = 16, + prob_no_shortcut_train = 0.1, + num_residual_streams = 1 + ) + + generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True) + generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True) + generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True) + +@param('vectorized', (False, True)) +def test_online_rl( + vectorized +): + from dreamer4.mocks import MockEnv + + mock_env = MockEnv((256, 256), vectorized = vectorized, batch_size = 4) \ No newline at end of file