Be able to move the experience to CPU easily, and automatically move it to the device of the dynamics world model when learning from it
This commit is contained in:
parent
24ef72d528
commit
cfd34f1eba
@ -14,7 +14,7 @@ from torch.nested import nested_tensor
|
|||||||
from torch.distributions import Normal, kl
|
from torch.distributions import Normal, kl
|
||||||
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||||
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
||||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision.models import VGG16_Weights
|
from torchvision.models import VGG16_Weights
|
||||||
@ -91,6 +91,14 @@ class Experience:
|
|||||||
agent_index: int = 0
|
agent_index: int = 0
|
||||||
is_from_world_model: bool = True
|
is_from_world_model: bool = True
|
||||||
|
|
||||||
|
def cpu(self):
|
||||||
|
return self.to(torch.device('cpu'))
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
experience_dict = asdict(self)
|
||||||
|
experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
|
||||||
|
return Experience(**experience_dict)
|
||||||
|
|
||||||
def combine_experiences(
|
def combine_experiences(
|
||||||
exps: list[Experiences]
|
exps: list[Experiences]
|
||||||
) -> Experience:
|
) -> Experience:
|
||||||
@ -2435,6 +2443,8 @@ class DynamicsWorldModel(Module):
|
|||||||
):
|
):
|
||||||
assert isinstance(experience, Experience)
|
assert isinstance(experience, Experience)
|
||||||
|
|
||||||
|
experience = experience.to(self.device)
|
||||||
|
|
||||||
latents = experience.latents
|
latents = experience.latents
|
||||||
actions = experience.actions
|
actions = experience.actions
|
||||||
old_log_probs = experience.log_probs
|
old_log_probs = experience.log_probs
|
||||||
|
|||||||
@ -528,7 +528,7 @@ class SimTrainer(Module):
|
|||||||
|
|
||||||
total_experience += num_experience
|
total_experience += num_experience
|
||||||
|
|
||||||
experiences.append(experience)
|
experiences.append(experience.cpu())
|
||||||
|
|
||||||
combined_experiences = combine_experiences(experiences)
|
combined_experiences = combine_experiences(experiences)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.1.4"
|
version = "0.1.5"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -680,6 +680,12 @@ def test_online_rl(
|
|||||||
|
|
||||||
combined_experience = combine_experiences([one_experience, another_experience])
|
combined_experience = combine_experiences([one_experience, another_experience])
|
||||||
|
|
||||||
|
# quick test moving the experience to different devices
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
combined_experience = combined_experience.to(torch.device('cuda'))
|
||||||
|
combined_experience = combined_experience.to(world_model_and_policy.device)
|
||||||
|
|
||||||
if store_agent_embed:
|
if store_agent_embed:
|
||||||
assert exists(combined_experience.agent_embed)
|
assert exists(combined_experience.agent_embed)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user