allow for the combining of experiences from environment and dream
This commit is contained in:
parent
690ecf07dc
commit
2e7f406d49
@ -96,7 +96,7 @@ class Experience:
|
||||
lens: MaybeTensor = None
|
||||
is_truncated: MaybeTensor = None
|
||||
agent_index: int = 0
|
||||
is_from_world_model: bool = True
|
||||
is_from_world_model: bool | Tensor = True
|
||||
|
||||
def cpu(self):
|
||||
return self.to(torch.device('cpu'))
|
||||
@ -124,6 +124,9 @@ def combine_experiences(
|
||||
if not exists(exp.is_truncated):
|
||||
exp.is_truncated = full((batch,), True, device = device)
|
||||
|
||||
if isinstance(exp.is_from_world_model, bool):
|
||||
exp.is_from_world_model = tensor(exp.is_from_world_model)
|
||||
|
||||
# convert to dictionary
|
||||
|
||||
exps_dict = [asdict(exp) for exp in exps]
|
||||
@ -145,11 +148,15 @@ def combine_experiences(
|
||||
|
||||
for field_values in all_field_values:
|
||||
|
||||
if is_tensor(first(field_values)):
|
||||
first_value = first(field_values)
|
||||
|
||||
if is_tensor(first_value):
|
||||
|
||||
field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
|
||||
|
||||
new_field_value = cat(field_values)
|
||||
cat_or_stack = cat if first_value.ndim > 0 else stack
|
||||
|
||||
new_field_value = cat_or_stack(field_values)
|
||||
else:
|
||||
new_field_value = first(list(set(field_values)))
|
||||
|
||||
@ -2408,7 +2415,7 @@ class DynamicsWorldModel(Module):
|
||||
env,
|
||||
seed = None,
|
||||
agent_index = 0,
|
||||
step_size = 4,
|
||||
num_steps = 4,
|
||||
max_timesteps = 16,
|
||||
env_is_vectorized = False,
|
||||
use_time_cache = True,
|
||||
@ -2448,6 +2455,11 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
episode_lens = full((batch,), 0, device = device)
|
||||
|
||||
# derive step size
|
||||
|
||||
assert divisible_by(self.max_steps, num_steps)
|
||||
step_size = self.max_steps // num_steps
|
||||
|
||||
# maybe time kv cache
|
||||
|
||||
time_cache = None
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.1.19"
|
||||
version = "0.1.20"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -643,7 +643,9 @@ def test_online_rl(
|
||||
dim_latent = 16,
|
||||
patch_size = 32,
|
||||
attn_dim_head = 16,
|
||||
num_latent_tokens = 1
|
||||
num_latent_tokens = 1,
|
||||
image_height = 256,
|
||||
image_width = 256,
|
||||
)
|
||||
|
||||
world_model_and_policy = DynamicsWorldModel(
|
||||
@ -677,10 +679,12 @@ def test_online_rl(
|
||||
|
||||
# manually
|
||||
|
||||
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
|
||||
|
||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
|
||||
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
|
||||
|
||||
combined_experience = combine_experiences([one_experience, another_experience])
|
||||
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
|
||||
|
||||
# quick test moving the experience to different devices
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user