able to move the experience to cpu easily, and automatically move it to the device of the dynamics world model when learning from it

This commit is contained in:
lucidrains 2025-11-09 16:16:13 +00:00
parent 24ef72d528
commit cfd34f1eba
4 changed files with 19 additions and 3 deletions

View File

@ -14,7 +14,7 @@ from torch.nested import nested_tensor
from torch.distributions import Normal, kl
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
import torchvision
from torchvision.models import VGG16_Weights
@ -91,6 +91,14 @@ class Experience:
agent_index: int = 0
is_from_world_model: bool = True
def cpu(self):
    """Return this experience moved to the CPU by delegating to `self.to`."""
    cpu_device = torch.device('cpu')
    return self.to(cpu_device)
def to(self, device):
    """Return a new experience whose tensor fields live on `device`.

    Non-tensor values (ints, bools, None, ...) are carried over unchanged.
    Containers understood by `tree_map` (tuples, lists, dicts) are traversed
    so tensors nested inside them are moved as well.

    Args:
        device: target device, forwarded as-is to `Tensor.to`.

    Returns:
        A new instance of the same class as `self`.
    """
    from dataclasses import fields

    # Build a shallow field mapping instead of calling `asdict`: `asdict`
    # deep-copies every leaf value, which clones each tensor on its current
    # device before the move and raises for non-leaf tensors that require grad.
    # NOTE(review): unlike `asdict`, nested dataclass fields are passed through
    # untouched rather than flattened to dicts - confirm none are expected here.
    field_values = {field.name: getattr(self, field.name) for field in fields(self)}

    def move(value):
        return value.to(device) if is_tensor(value) else value

    moved_values = tree_map(move, field_values)

    # `type(self)` equals `Experience` for plain instances and keeps subclasses intact.
    return type(self)(**moved_values)
def combine_experiences(
exps: list[Experiences]
) -> Experience:
@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module):
):
assert isinstance(experience, Experience)
experience = experience.to(self.device)
latents = experience.latents
actions = experience.actions
old_log_probs = experience.log_probs

View File

@ -528,7 +528,7 @@ class SimTrainer(Module):
total_experience += num_experience
experiences.append(experience)
experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.4"
version = "0.1.5"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -680,6 +680,12 @@ def test_online_rl(
combined_experience = combine_experiences([one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed:
assert exists(combined_experience.agent_embed)