able to move the experience to cpu easily, and auto matically move it to the device of the dynamics world model when learning from it
This commit is contained in:
parent
24ef72d528
commit
cfd34f1eba
@ -14,7 +14,7 @@ from torch.nested import nested_tensor
|
||||
from torch.distributions import Normal, kl
|
||||
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
|
||||
import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
@ -91,6 +91,14 @@ class Experience:
|
||||
agent_index: int = 0
|
||||
is_from_world_model: bool = True
|
||||
|
||||
def cpu(self):
|
||||
return self.to(torch.device('cpu'))
|
||||
|
||||
def to(self, device):
|
||||
experience_dict = asdict(self)
|
||||
experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
|
||||
return Experience(**experience_dict)
|
||||
|
||||
def combine_experiences(
|
||||
exps: list[Experiences]
|
||||
) -> Experience:
|
||||
@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module):
|
||||
):
|
||||
assert isinstance(experience, Experience)
|
||||
|
||||
experience = experience.to(self.device)
|
||||
|
||||
latents = experience.latents
|
||||
actions = experience.actions
|
||||
old_log_probs = experience.log_probs
|
||||
|
||||
@ -528,7 +528,7 @@ class SimTrainer(Module):
|
||||
|
||||
total_experience += num_experience
|
||||
|
||||
experiences.append(experience)
|
||||
experiences.append(experience.cpu())
|
||||
|
||||
combined_experiences = combine_experiences(experiences)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.4"
|
||||
version = "0.1.5"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -680,6 +680,12 @@ def test_online_rl(
|
||||
|
||||
combined_experience = combine_experiences([one_experience, another_experience])
|
||||
|
||||
# quick test moving the experience to different devices
|
||||
|
||||
if torch.cuda.is_available():
|
||||
combined_experience = combined_experience.to(torch.device('cuda'))
|
||||
combined_experience = combined_experience.to(world_model_and_policy.device)
|
||||
|
||||
if store_agent_embed:
|
||||
assert exists(combined_experience.agent_embed)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user