allow combining experiences from the environment and from dreams

lucidrains 2025-11-13 16:37:35 -08:00
parent 690ecf07dc
commit 2e7f406d49
3 changed files with 23 additions and 7 deletions


@@ -96,7 +96,7 @@ class Experience:
     lens: MaybeTensor = None
     is_truncated: MaybeTensor = None
     agent_index: int = 0
-    is_from_world_model: bool = True
+    is_from_world_model: bool | Tensor = True

     def cpu(self):
         return self.to(torch.device('cpu'))
@@ -124,6 +124,9 @@ def combine_experiences(
         if not exists(exp.is_truncated):
             exp.is_truncated = full((batch,), True, device = device)

+        if isinstance(exp.is_from_world_model, bool):
+            exp.is_from_world_model = tensor(exp.is_from_world_model)
+
     # convert to dictionary

     exps_dict = [asdict(exp) for exp in exps]
@@ -145,11 +148,15 @@ def combine_experiences(
     for field_values in all_field_values:
-        if is_tensor(first(field_values)):
+        first_value = first(field_values)
+
+        if is_tensor(first_value):
             field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
-            new_field_value = cat(field_values)
+
+            cat_or_stack = cat if first_value.ndim > 0 else stack
+            new_field_value = cat_or_stack(field_values)
         else:
             new_field_value = first(list(set(field_values)))
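
For context on the two hunks above: a plain python bool cannot flow through the tensor-combining path, so it is first converted to a 0-dim tensor, and 0-dim tensors must be stacked rather than concatenated, since torch.cat rejects zero-dimensional inputs. A minimal standalone sketch of that logic (variable names here are illustrative, not the repo's internals):

    import torch
    from torch import tensor, cat, stack

    # 0-dim tensors, e.g. is_from_world_model after the bool -> tensor conversion above
    flags = [tensor(True), tensor(False)]

    # torch.cat errors on zero-dimensional tensors, so 0-dim fields are stacked instead
    cat_or_stack = cat if flags[0].ndim > 0 else stack
    combined = cat_or_stack(flags)          # tensor([ True, False])

    # batched fields (ndim > 0) still concatenate along the batch dimension
    batched = [torch.ones(2, 3), torch.ones(4, 3)]
    combined_batched = cat(batched)         # shape (6, 3)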
@@ -2408,7 +2415,7 @@ class DynamicsWorldModel(Module):
         env,
         seed = None,
         agent_index = 0,
-        step_size = 4,
+        num_steps = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_cache = True,
@@ -2448,6 +2455,11 @@ class DynamicsWorldModel(Module):
         episode_lens = full((batch,), 0, device = device)

+        # derive step size
+
+        assert divisible_by(self.max_steps, num_steps)
+        step_size = self.max_steps // num_steps
+
         # maybe time kv cache

         time_cache = None
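
The rename above switches the caller-facing knob from a raw step_size to num_steps, with the step size now derived from the model's configured max_steps. A small sketch of the derivation, assuming max_steps = 16 as an illustrative value and a divisible_by helper like the repo's:

    # minimal sketch of the num_steps -> step_size derivation above;
    # max_steps = 16 is an assumed example value, not pulled from the repo config

    def divisible_by(num, den):
        return (num % den) == 0

    max_steps = 16   # assumed model setting for illustration
    num_steps = 4    # callers now pass a step count rather than a raw step size

    assert divisible_by(max_steps, num_steps)
    step_size = max_steps // num_steps   # -> 4 timesteps advanced per sampling step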


@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.19"
+version = "0.1.20"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -643,7 +643,9 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1
+        num_latent_tokens = 1,
+        image_height = 256,
+        image_width = 256,
     )

     world_model_and_policy = DynamicsWorldModel(
@@ -677,10 +679,12 @@ def test_online_rl(
     # manually

+    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

-    combined_experience = combine_experiences([one_experience, another_experience])
+    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])

     # quick test moving the experience to different devices
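
Putting the commit together, a hedged usage sketch of the feature it enables (assuming the same objects as the test above; not verbatim repo code):

    # dream rollouts from the world model and real rollouts from the environment
    # now combine into a single Experience for policy optimization

    dream_experience = world_model_and_policy.generate(
        10, batch_size = 1, return_for_policy_optimization = True
    )
    env_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8)

    combined = combine_experiences([dream_experience, env_experience])

    # Experience.cpu() moves all tensor fields to cpu, per the first hunk
    combined = combined.cpu()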