From cfd34f1ebaddd5482e2e0513bbe12e548dbc28f7 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 9 Nov 2025 16:16:13 +0000 Subject: [PATCH] able to move the experience to cpu easily, and auto matically move it to the device of the dynamics world model when learning from it --- dreamer4/dreamer4.py | 12 +++++++++++- dreamer4/trainers.py | 2 +- pyproject.toml | 2 +- tests/test_dreamer.py | 6 ++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 346e9ce..b017ea7 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -14,7 +14,7 @@ from torch.nested import nested_tensor from torch.distributions import Normal, kl from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange -from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten import torchvision from torchvision.models import VGG16_Weights @@ -91,6 +91,14 @@ class Experience: agent_index: int = 0 is_from_world_model: bool = True + def cpu(self): + return self.to(torch.device('cpu')) + + def to(self, device): + experience_dict = asdict(self) + experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict) + return Experience(**experience_dict) + def combine_experiences( exps: list[Experiences] ) -> Experience: @@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module): ): assert isinstance(experience, Experience) + experience = experience.to(self.device) + latents = experience.latents actions = experience.actions old_log_probs = experience.log_probs diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index bb4b198..09d0971 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -528,7 +528,7 @@ class SimTrainer(Module): total_experience += num_experience - experiences.append(experience) + experiences.append(experience.cpu()) combined_experiences = combine_experiences(experiences) diff --git a/pyproject.toml b/pyproject.toml index 57574ca..2746a1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.4" +version = "0.1.5" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index b9f4de1..15156fe 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -680,6 +680,12 @@ def test_online_rl( combined_experience = combine_experiences([one_experience, another_experience]) + # quick test moving the experience to different devices + + if torch.cuda.is_available(): + combined_experience = combined_experience.to(torch.device('cuda')) + combined_experience = combined_experience.to(world_model_and_policy.device) + if store_agent_embed: assert exists(combined_experience.agent_embed)