diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 346e9ce..b017ea7 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -14,7 +14,7 @@ from torch.nested import nested_tensor from torch.distributions import Normal, kl from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange -from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten import torchvision from torchvision.models import VGG16_Weights @@ -91,6 +91,14 @@ class Experience: agent_index: int = 0 is_from_world_model: bool = True + def cpu(self): + return self.to(torch.device('cpu')) + + def to(self, device): + experience_dict = asdict(self) + experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict) + return Experience(**experience_dict) + def combine_experiences( exps: list[Experiences] ) -> Experience: @@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module): ): assert isinstance(experience, Experience) + experience = experience.to(self.device) + latents = experience.latents actions = experience.actions old_log_probs = experience.log_probs diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index bb4b198..09d0971 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -528,7 +528,7 @@ class SimTrainer(Module): total_experience += num_experience - experiences.append(experience) + experiences.append(experience.cpu()) combined_experiences = combine_experiences(experiences) diff --git a/pyproject.toml b/pyproject.toml index 57574ca..2746a1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.4" +version = "0.1.5" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index b9f4de1..15156fe 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -680,6 +680,12 @@ def test_online_rl( combined_experience = combine_experiences([one_experience, another_experience]) + # quick test moving the experience to different devices + + if torch.cuda.is_available(): + combined_experience = combined_experience.to(torch.device('cuda')) + combined_experience = combined_experience.to(world_model_and_policy.device) + if store_agent_embed: assert exists(combined_experience.agent_embed)