function for combining experiences

This commit is contained in:
lucidrains 2025-10-24 08:00:10 -07:00
parent d0ffc6bfed
commit 27ac05efb0
3 changed files with 45 additions and 3 deletions

View File

@ -6,7 +6,7 @@ from random import random
from contextlib import nullcontext from contextlib import nullcontext
from collections import namedtuple from collections import namedtuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass, asdict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -14,6 +14,7 @@ from torch.nested import nested_tensor
from torch.distributions import Normal from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
from torch.utils._pytree import tree_flatten, tree_unflatten
import torchvision import torchvision
from torchvision.models import VGG16_Weights from torchvision.models import VGG16_Weights
@ -82,6 +83,42 @@ class Experience:
agent_index: int = 0 agent_index: int = 0
is_from_world_model: bool = True is_from_world_model: bool = True
def combine_experiences(
exps: list[Experiences]
) -> Experience:
assert len(exps) > 0
exps_dict = [asdict(exp) for exp in exps]
values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
tree_spec = first(tree_specs)
all_field_values = list(zip(*values))
# an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
assert all([
all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
for field_values in all_field_values
])
concatted = []
for field_values in all_field_values:
if is_tensor(first(field_values)):
new_field_value = cat(field_values)
else:
new_field_value = first(list(set(field_values)))
concatted.append(new_field_value)
# return experience
concat_exp_dict = tree_unflatten(concatted, tree_spec)
return Experience(**concat_exp_dict)
# helpers # helpers
def exists(v): def exists(v):

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.0.66" version = "0.0.67"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -637,11 +637,16 @@ def test_online_rl(
) )
from dreamer4.mocks import MockEnv from dreamer4.mocks import MockEnv
from dreamer4.dreamer4 import combine_experiences
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4) mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized) one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage) combined_experience = combine_experiences([one_experience, another_experience])
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
actor_loss.backward() actor_loss.backward()
critic_loss.backward() critic_loss.backward()