From 32cf142b4d547f97257dfc26a7c5c15082ee494f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 25 Oct 2025 11:31:41 -0700 Subject: [PATCH] take another step for variable len experiences --- dreamer4/dreamer4.py | 37 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- tests/test_dreamer.py | 2 +- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 7d333d2..e0f8adc 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -91,6 +91,18 @@ def combine_experiences( ) -> Experience: assert len(exps) > 0 + + # set lens if not there + + for exp in exps: + latents = exp.latents + batch, time, device = *latents.shape[:2], latents.device + + if not exists(exp.lens): + exp.lens = torch.full((batch,), time, device = device) + + # convert to dictionary + exps_dict = [asdict(exp) for exp in exps] values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict]) @@ -109,7 +121,11 @@ def combine_experiences( concatted = [] for field_values in all_field_values: + if is_tensor(first(field_values)): + + field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2)) + new_field_value = cat(field_values) else: new_field_value = first(list(set(field_values))) @@ -223,6 +239,27 @@ def pad_at_dim( zeros = ((0, 0) * dims_from_right) return F.pad(t, (*zeros, *pad), value = value) +def pad_to_len(t, target_len, *, dim): + curr_len = t.shape[dim] + + if curr_len >= target_len: + return t + + return pad_at_dim(t, (0, target_len - curr_len), dim = dim) + +def pad_tensors_at_dim_to_max_len( + tensors: list[Tensor], + dims: tuple[int, ...] +): + for dim in dims: + if dim >= first(tensors).ndim: + continue + + max_time = max([t.shape[dim] for t in tensors]) + tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors] + + return tensors + def align_dims_left(t, aligned_to): shape = t.shape num_right_dims = aligned_to.ndim - t.ndim diff --git a/pyproject.toml b/pyproject.toml index ab2e923..bca30ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.73" +version = "0.0.74" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 87b0228..ae1f679 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -653,7 +653,7 @@ def test_online_rl( # manually - one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized) + one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized) another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized) combined_experience = combine_experiences([one_experience, another_experience])