function for combining experiences
parent d0ffc6bfed
commit 27ac05efb0
@@ -6,7 +6,7 @@ from random import random
 from contextlib import nullcontext
 from collections import namedtuple
 from functools import partial
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict

 import torch
 import torch.nn.functional as F
@@ -14,6 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchvision
 from torchvision.models import VGG16_Weights
@@ -82,6 +83,42 @@ class Experience:
     agent_index: int = 0
     is_from_world_model: bool = True

+def combine_experiences(
+    exps: list[Experience]
+) -> Experience:
+
+    assert len(exps) > 0
+    exps_dict = [asdict(exp) for exp in exps]
+
+    values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
+
+    tree_spec = first(tree_specs)
+
+    all_field_values = list(zip(*values))
+
+    # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
+
+    assert all([
+        all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
+        for field_values in all_field_values
+    ])
+
+    concatted = []
+
+    for field_values in all_field_values:
+        if is_tensor(first(field_values)):
+            new_field_value = cat(field_values)
+        else:
+            new_field_value = first(list(set(field_values)))
+
+        concatted.append(new_field_value)
+
+    # return experience
+
+    concat_exp_dict = tree_unflatten(concatted, tree_spec)
+
+    return Experience(**concat_exp_dict)
+
 # helpers

 def exists(v):
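A minimal runnable sketch of the mechanism above, for reference: asdict turns each dataclass instance into a dict, tree_flatten lines the fields up in a stable order, tensor fields are concatenated along the batch dimension, and shared non-tensor fields are carried through. The Pair dataclass and combine helper below are hypothetical stand-ins for illustration, not part of the repo.

from dataclasses import dataclass, asdict

from torch import Tensor, tensor, cat, is_tensor
from torch.utils._pytree import tree_flatten, tree_unflatten

@dataclass
class Pair:                # hypothetical stand-in for Experience
    values: Tensor
    step_size: int = 1

def combine(pairs: list[Pair]) -> Pair:
    # flatten each instance into leaves plus a tree spec (identical for all)
    dicts = [asdict(p) for p in pairs]
    leaves, specs = zip(*[tree_flatten(d) for d in dicts])

    combined = []
    for field_values in zip(*leaves):
        if is_tensor(field_values[0]):
            combined.append(cat(field_values))   # concat tensor fields
        else:
            combined.append(field_values[0])     # carry the shared value

    return Pair(**tree_unflatten(combined, specs[0]))

merged = combine([Pair(tensor([1, 2])), Pair(tensor([3, 4]))])
assert merged.values.tolist() == [1, 2, 3, 4] and merged.step_size == 1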
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.66"
+version = "0.0.67"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -637,11 +637,16 @@ def test_online_rl(
     )

     from dreamer4.mocks import MockEnv
+    from dreamer4.dreamer4 import combine_experiences

     mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)

     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)

-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage)
+    combined_experience = combine_experiences([one_experience, another_experience])
+
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

     actor_loss.backward()
     critic_loss.backward()
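The updated test exercises the intended pattern: collect more than one rollout, merge them with combine_experiences, and learn from the merged batch in a single call. Generalized to N rollouts it would look roughly like the sketch below; num_rollouts is an assumed knob, the other names come straight from the diff.

# assumed usage pattern; `num_rollouts` is hypothetical
experiences = [
    world_model_and_policy.interact_with_env(
        mock_env,
        max_timesteps = 16,
        env_is_vectorized = vectorized
    )
    for _ in range(num_rollouts)
]

combined = combine_experiences(experiences)

actor_loss, critic_loss = world_model_and_policy.learn_from_experience(
    combined,
    use_signed_advantage = use_signed_advantage
)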