able to move the experience to cpu easily, and auto matically move it to the device of the dynamics world model when learning from it

This commit is contained in:
lucidrains 2025-11-09 16:16:13 +00:00
parent 24ef72d528
commit cfd34f1eba
4 changed files with 19 additions and 3 deletions

View File

@ -14,7 +14,7 @@ from torch.nested import nested_tensor
from torch.distributions import Normal, kl from torch.distributions import Normal, kl
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
import torchvision import torchvision
from torchvision.models import VGG16_Weights from torchvision.models import VGG16_Weights
@ -91,6 +91,14 @@ class Experience:
agent_index: int = 0 agent_index: int = 0
is_from_world_model: bool = True is_from_world_model: bool = True
def cpu(self):
return self.to(torch.device('cpu'))
def to(self, device):
experience_dict = asdict(self)
experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
return Experience(**experience_dict)
def combine_experiences( def combine_experiences(
exps: list[Experiences] exps: list[Experiences]
) -> Experience: ) -> Experience:
@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module):
): ):
assert isinstance(experience, Experience) assert isinstance(experience, Experience)
experience = experience.to(self.device)
latents = experience.latents latents = experience.latents
actions = experience.actions actions = experience.actions
old_log_probs = experience.log_probs old_log_probs = experience.log_probs

View File

@ -528,7 +528,7 @@ class SimTrainer(Module):
total_experience += num_experience total_experience += num_experience
experiences.append(experience) experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences) combined_experiences = combine_experiences(experiences)

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.1.4" version = "0.1.5"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -680,6 +680,12 @@ def test_online_rl(
combined_experience = combine_experiences([one_experience, another_experience]) combined_experience = combine_experiences([one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed: if store_agent_embed:
assert exists(combined_experience.agent_embed) assert exists(combined_experience.agent_embed)