Allow storing the agent embeddings during rollouts (with imagination or the environment) for efficient policy optimization; this will also allow for finetuning the world model for the heads.
parent 789f091c63
commit d476fa7b14
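In short, callers can now opt in to keeping the per-timestep agent embeddings on the returned Experience. A minimal usage sketch, mirroring the updated test at the bottom of this diff (the model and environment setup are assumed to be the same as in that test; `world_model_and_policy` and `mock_env` come from there):

    # environment rollout, keeping the agent embeddings alongside latents, rewards, actions, etc.
    experience = world_model_and_policy.interact_with_env(
        mock_env,
        max_timesteps = 8,
        env_is_vectorized = False,
        store_agent_embed = True
    )

    # the embeddings ride along on the Experience, ready to be reused for
    # policy optimization or for finetuning the world model heads
    assert experience.agent_embed is not None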
@@ -77,6 +77,7 @@ class Experience:
 latents: Tensor
 video: Tensor | None = None
 proprio: Tensor | None = None
+agent_embed: Tensor | None = None
 rewards: Tensor | None = None
 actions: tuple[Tensor, Tensor] | None = None
 log_probs: tuple[Tensor, Tensor] | None = None
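For context, the changed container roughly looks as follows after this commit. This is a sketch reconstructed from the hunk above, assuming Experience is a plain dataclass; only the fields shown in the hunk are listed, while the real class defines further fields (values, step_size, agent_index, is_truncated, lens) that appear elsewhere in this diff:

    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class Experience:
        latents: Tensor
        video: Tensor | None = None
        proprio: Tensor | None = None
        agent_embed: Tensor | None = None  # new: rollout agent embeddings, concatenated over time
        rewards: Tensor | None = None
        actions: tuple[Tensor, Tensor] | None = None
        log_probs: tuple[Tensor, Tensor] | None = None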
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
 step_size = 4,
 max_timesteps = 16,
 env_is_vectorized = False,
-use_time_kv_cache = True
+use_time_kv_cache = True,
+store_agent_embed = False
 ):
 assert exists(self.video_tokenizer)

@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
 values = None
 latents = None

+acc_agent_embed = None
+
 # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

 is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
 video = cat((video, next_frame), dim = 2)
 rewards = safe_cat((rewards, reward), dim = 1)

+acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
 # package up one experience for learning

 batch, device = latents.shape[0], latents.device
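`safe_cat` is an existing helper in this codebase whose implementation is not part of the diff; the accumulation above relies on it tolerating a `None` accumulator on the first step. A hypothetical stand-in, written only to illustrate that pattern (the real helper may differ):

    import torch

    def safe_cat(tensors, dim = 0):
        # drop missing entries, so safe_cat((None, t), dim = 1) simply returns t
        tensors = [t for t in tensors if t is not None]

        if len(tensors) == 0:
            return None

        if len(tensors) == 1:
            return tensors[0]

        return torch.cat(tensors, dim = dim)

With that behavior, `acc_agent_embed` starts out as `None` and grows by one timestep per iteration along `dim = 1`.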
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
 actions = (discrete_actions, continuous_actions),
 log_probs = (discrete_log_probs, continuous_log_probs),
 values = values,
+agent_embed = acc_agent_embed if store_agent_embed else None,
 step_size = step_size,
 agent_index = agent_index,
 is_truncated = is_truncated,
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
 return_agent_actions = False,
 return_log_probs_and_values = False,
 return_time_kv_cache = False,
+store_agent_embed = False

 ): # (b t n d) | (b c t h w)

@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
 decoded_continuous_log_probs = None
 decoded_values = None

+# maybe store agent embed
+
+acc_agent_embed = None
+
 # maybe return rewards

 decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):

 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+# maybe store agent embed
+
+acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
 # decode the agent actions if needed

 if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
 latents = latents,
 video = video,
 proprio = proprio if has_proprio else None,
+agent_embed = acc_agent_embed if store_agent_embed else None,
 step_size = step_size,
 agent_index = agent_index,
 lens = experience_lens,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.83"
+version = "0.0.85"
 description = "Dreamer 4"
 authors = [
 { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -614,11 +614,13 @@ def test_cache_generate():
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
 vectorized,
 use_signed_advantage,
 env_can_terminate,
-env_can_truncate
+env_can_truncate,
+store_agent_embed
 ):
 from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

@@ -664,11 +666,14 @@ def test_online_rl(

 # manually

-one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

 combined_experience = combine_experiences([one_experience, another_experience])

+if store_agent_embed:
+    assert exists(combined_experience.agent_embed)
+
 actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

 actor_loss.backward()