take another step for variable len experiences
This commit is contained in:
parent 1ed6a15cb0
commit 32cf142b4d
@@ -91,6 +91,18 @@ def combine_experiences(
 ) -> Experience:
 
     assert len(exps) > 0
+
+    # set lens if not there
+
+    for exp in exps:
+        latents = exp.latents
+        batch, time, device = *latents.shape[:2], latents.device
+
+        if not exists(exp.lens):
+            exp.lens = torch.full((batch,), time, device = device)
+
+    # convert to dictionary
+
     exps_dict = [asdict(exp) for exp in exps]
 
     values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
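For context, the block added above gives every experience an explicit `lens` tensor when none was recorded: each batch element is assumed to span the full number of timesteps in `latents`. A minimal standalone sketch of that default, using plain torch in place of the library's `exists` helper (the latent shape here is made up for illustration):

import torch

latents = torch.randn(3, 16, 512)        # hypothetical (batch, time, dim) latents
batch, time = latents.shape[:2]

lens = None                              # experience collected without explicit lengths
if lens is None:                         # stand-in for `not exists(exp.lens)`
    lens = torch.full((batch,), time)    # every sequence defaults to the full 16 steps

print(lens)                              # tensor([16, 16, 16])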
@@ -109,7 +121,11 @@ def combine_experiences(
     concatted = []
 
     for field_values in all_field_values:
+
         if is_tensor(first(field_values)):
+
+            field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
+
             new_field_value = cat(field_values)
         else:
             new_field_value = first(list(set(field_values)))
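The padding call added above is what lets experiences of different lengths be concatenated: tensor fields are first right-padded along dims 1 and 2 (the time dim, and a second sequence dim where present) to the longest entry, then `cat` joins them on the batch dim. A hedged sketch of the same idea using plain `torch.nn.functional.pad` (the shapes are made up for illustration):

import torch
import torch.nn.functional as F

a = torch.randn(2, 8, 4)                 # 8-step rollout
b = torch.randn(2, 16, 4)                # 16-step rollout

max_time = max(a.shape[1], b.shape[1])

# right-pad the shorter tensor along dim 1 so the two can be concatenated on batch
a = F.pad(a, (0, 0, 0, max_time - a.shape[1]))

combined = torch.cat([a, b])             # shape (4, 16, 4)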
@@ -223,6 +239,27 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def pad_to_len(t, target_len, *, dim):
+    curr_len = t.shape[dim]
+
+    if curr_len >= target_len:
+        return t
+
+    return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
+
+def pad_tensors_at_dim_to_max_len(
+    tensors: list[Tensor],
+    dims: tuple[int, ...]
+):
+    for dim in dims:
+        if dim >= first(tensors).ndim:
+            continue
+
+        max_time = max([t.shape[dim] for t in tensors])
+        tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
+
+    return tensors
+
 def align_dims_left(t, aligned_to):
     shape = t.shape
     num_right_dims = aligned_to.ndim - t.ndim
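The two helpers above are the core of the change: `pad_to_len` right-pads a single tensor along one dim, and `pad_tensors_at_dim_to_max_len` applies that across a list for each requested dim, silently skipping dims a tensor does not have. A condensed, self-contained sketch of that behavior (`pad_dim_to` is a stand-in written for this example, not part of the diff):

import torch
import torch.nn.functional as F

def pad_dim_to(t, length, dim):
    # right-pad `dim` of `t` out to `length`, mirroring pad_to_len above
    if t.shape[dim] >= length:
        return t
    pad = [0, 0] * (t.ndim - dim - 1) + [0, length - t.shape[dim]]
    return F.pad(t, pad)

tensors = [torch.ones(2, 8), torch.ones(2, 16)]

for dim in (1, 2):                       # dim 2 exceeds ndim here and is skipped, as in the helper
    if dim >= tensors[0].ndim:
        continue
    max_len = max(t.shape[dim] for t in tensors)
    tensors = [pad_dim_to(t, max_len, dim) for t in tensors]

assert all(t.shape == (2, 16) for t in tensors)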
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.73"
+version = "0.0.74"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -653,7 +653,7 @@ def test_online_rl(
 
     # manually
 
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
 
     combined_experience = combine_experiences([one_experience, another_experience])
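The test now mixes an 8-step and a 16-step rollout, which exercises both the default `lens` and the padding path: after `combine_experiences`, tensor fields share the padded 16-step time dim while `lens` records each rollout's true length. A hypothetical illustration of what the combined lengths look like (the batch size of 4 is made up):

import torch

lens_a = torch.full((4,), 8)             # from the max_timesteps = 8 rollout
lens_b = torch.full((4,), 16)            # from the max_timesteps = 16 rollout

combined_lens = torch.cat([lens_a, lens_b])
print(combined_lens)                     # tensor([ 8,  8,  8,  8, 16, 16, 16, 16])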