diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 02e9822..02db5b9 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -6,7 +6,7 @@ from random import random from contextlib import nullcontext from collections import namedtuple from functools import partial -from dataclasses import dataclass +from dataclasses import dataclass, asdict import torch import torch.nn.functional as F @@ -14,6 +14,7 @@ from torch.nested import nested_tensor from torch.distributions import Normal from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange +from torch.utils._pytree import tree_flatten, tree_unflatten import torchvision from torchvision.models import VGG16_Weights @@ -82,6 +83,42 @@ class Experience: agent_index: int = 0 is_from_world_model: bool = True +def combine_experiences( + exps: list[Experiences] +) -> Experience: + + assert len(exps) > 0 + exps_dict = [asdict(exp) for exp in exps] + + values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict]) + + tree_spec = first(tree_specs) + + all_field_values = list(zip(*values)) + + # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later + + assert all([ + all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1 + for field_values in all_field_values + ]) + + concatted = [] + + for field_values in all_field_values: + if is_tensor(first(field_values)): + new_field_value = cat(field_values) + else: + new_field_value = first(list(set(field_values))) + + concatted.append(new_field_value) + + # return experience + + concat_exp_dict = tree_unflatten(concatted, tree_spec) + + return Experience(**concat_exp_dict) + # helpers def exists(v): diff --git a/pyproject.toml b/pyproject.toml index 8b02c2b..04b0cd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.66" +version = "0.0.67" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 99355b9..39dbb9c 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -637,11 +637,16 @@ def test_online_rl( ) from dreamer4.mocks import MockEnv + from dreamer4.dreamer4 import combine_experiences + mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4) one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized) + another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized) - actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience, use_signed_advantage = use_signed_advantage) + combined_experience = combine_experiences([one_experience, another_experience]) + + actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage) actor_loss.backward() critic_loss.backward()