able to roll out actions from one agent within the dreams of a world model
parent: d74f09f0b3
commit: 2d20d0a6c1
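In short, DynamicsWorldModel.generate can now sample an agent's actions from its policy head at every imagined frame and return them alongside the video, latents and rewards, bundled into the new WorldModelGenerations namedtuple. A minimal usage sketch follows, with sizes borrowed from the test added at the bottom of this diff (they are illustrative, not prescriptive):

from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

# tokenizer and world model sized as in the new test below
tokenizer = VideoTokenizer(
    512,
    dim_latent = 32,
    patch_size = 32,
    encoder_depth = 1,
    decoder_depth = 1,
    attn_heads = 8,
    image_height = 256,
    image_width = 256,
    attn_kwargs = dict(query_heads = 16)
)

dynamics = DynamicsWorldModel(
    512,
    num_agents = 1,
    video_tokenizer = tokenizer,
    dim_latent = 32,
    num_discrete_actions = 4
)

# dream 10 frames, also decoding per-frame rewards and the agent's sampled actions
generations = dynamics.generate(
    10,
    return_rewards_per_frame = True,
    return_agent_actions = True
)

video = generations.video        # (1, 3, 10, 256, 256)
rewards = generations.rewards    # (1, 10)
discrete_actions, continuous_actions = generations.actions  # (1, 10, 1) and None (no continuous actions configured)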
@@ -63,6 +63,8 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

+WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+
 # helpers

 def exists(v):
@@ -74,6 +76,9 @@ def default(v, d):

 def first(arr):
     return arr[0]

+def has_at_least_one(*bools):
+    return sum([*map(int, bools)]) > 0
+
 def ensure_tuple(t):
     return (t,) if not isinstance(t, tuple) else t
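A quick note on the new has_at_least_one helper: it just counts truthy arguments, and is used further down to decide whether generate should return the richer WorldModelGenerations result. A tiny self-contained illustration (the helper is copied verbatim from the hunk above):

def has_at_least_one(*bools):
    return sum([*map(int, bools)]) > 0

assert has_at_least_one(False, True)        # at least one flag is set
assert not has_at_least_one(False, False)   # no flags set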
@@ -94,6 +99,16 @@ def is_empty(t):

 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def safe_cat(*tensors, dim):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+    elif len(tensors) == 1:
+        return tensors[0]
+
+    return cat(tensors, dim = dim)
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
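The new safe_cat helper is what lets the generation loop below accumulate sampled actions starting from None. A small self-contained illustration of its three cases (safe_cat is copied from the hunk above; exists is assumed to be the usual "v is not None" check):

import torch
from torch import cat

def exists(v):
    return v is not None

def safe_cat(*tensors, dim):
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return tensors[0]

    return cat(tensors, dim = dim)

a = torch.randn(1, 2, 4)

assert safe_cat(None, None, dim = 1) is None           # nothing gathered yet
assert safe_cat(None, a, dim = 1).shape == (1, 2, 4)   # a single tensor passes through untouched
assert safe_cat(a, a, dim = 1).shape == (1, 4, 4)      # otherwise a normal concat along dim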
@@ -1630,7 +1645,8 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        return_rewards_per_frame = False
+        return_rewards_per_frame = False,
+        return_agent_actions = False

     ): # (b t n d) | (b c t h w)
@@ -1653,6 +1669,14 @@ class DynamicsWorldModel(Module):

         past_context_noise = latents.clone()

+        # maybe return actions
+
+        if return_agent_actions:
+            assert self.action_embedder.has_actions
+
+        decoded_discrete_actions = None
+        decoded_continuous_actions = None
+
         # maybe return rewards

         if return_rewards_per_frame:
@@ -1679,6 +1703,8 @@ class DynamicsWorldModel(Module):
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
                 rewards = decoded_rewards,
+                discrete_actions = decoded_discrete_actions,
+                continuous_actions = decoded_continuous_actions,
                 latent_is_noised = True,
                 return_pred_only = True,
                 return_agent_tokens = True
@@ -1710,6 +1736,18 @@ class DynamicsWorldModel(Module):

                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+            # decode the agent actions if needed
+
+            if return_agent_actions:
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
+                policy_embed = self.policy_head(one_agent_embed)
+
+                sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
+
+                decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
+                decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
+
             # concat the denoised latent

             latents = cat((latents, denoised_latent), dim = 1)
@@ -1739,10 +1777,15 @@ class DynamicsWorldModel(Module):
             width = image_width
         )

-        if not return_rewards_per_frame:
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
             return generated_video

-        return generated_video, decoded_rewards
+        return WorldModelGenerations(
+            video = generated_video,
+            latents = latents,
+            rewards = decoded_rewards if return_rewards_per_frame else None,
+            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
+        )

     def forward(
         self,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.24"
+version = "0.0.25"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -187,6 +187,46 @@ def test_attend_factory(

     assert torch.allclose(flex_out, out, atol = 1e-6)

+def test_action_with_world_model():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        num_discrete_actions = 4
+    )
+
+    rewards = torch.randn(1, 4)
+    discrete_actions = torch.randint(0, 4, (1, 4, 1))
+
+    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+        10,
+        return_rewards_per_frame = True,
+        return_agent_actions = True
+    )
+
+    assert generated_video.shape == (1, 3, 10, 256, 256)
+    assert generated_rewards.shape == (1, 10)
+
+    assert discrete_actions.shape == (1, 10, 1)
+    assert continuous_actions is None
+
 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder