able to roll out actions from one agent within the dreams of a world model
parent d74f09f0b3
commit 2d20d0a6c1
@@ -63,6 +63,8 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
 
+WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+
 # helpers
 
 def exists(v):
@@ -74,6 +76,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def has_at_least_one(*bools):
+    return sum([*map(int, bools)]) > 0
+
 def ensure_tuple(t):
     return (t,) if not isinstance(t, tuple) else t
 
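For reference, the new has_at_least_one helper just reports whether any of the passed flags is truthy. A minimal standalone illustration (the asserts are only for demonstration, not part of the commit):

def has_at_least_one(*bools):
    # true if at least one flag is set
    return sum([*map(int, bools)]) > 0

assert has_at_least_one(False, True)
assert not has_at_least_one(False, False)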
@@ -94,6 +99,16 @@ def is_empty(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def safe_cat(*tensors, dim):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+    elif len(tensors) == 1:
+        return tensors[0]
+
+    return cat(tensors, dim = dim)
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
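safe_cat concatenates only the tensors that actually exist, which lets the generation loop further down accumulate decoded actions even while one of the streams is still None. A standalone sketch of its behavior, with the exists helper and torch import reproduced here so it runs on its own (the asserts are illustrative only):

import torch
from torch import cat

def exists(v):
    return v is not None

def safe_cat(*tensors, dim):
    # drop any Nones before concatenating
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return tensors[0]

    return cat(tensors, dim = dim)

assert safe_cat(None, None, dim = 1) is None                               # nothing accumulated yet
assert safe_cat(None, torch.ones(1, 2), dim = 1).shape == (1, 2)           # a single tensor passes through
assert safe_cat(torch.ones(1, 2), torch.ones(1, 3), dim = 1).shape == (1, 5)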
@@ -1630,7 +1645,8 @@ class DynamicsWorldModel(Module):
        image_width = None,
        return_decoded_video = None,
        context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        return_rewards_per_frame = False
+        return_rewards_per_frame = False,
+        return_agent_actions = False
 
    ): # (b t n d) | (b c t h w)
 
@@ -1653,6 +1669,14 @@ class DynamicsWorldModel(Module):
 
        past_context_noise = latents.clone()
 
+        # maybe return actions
+
+        if return_agent_actions:
+            assert self.action_embedder.has_actions
+
+            decoded_discrete_actions = None
+            decoded_continuous_actions = None
+
        # maybe return rewards
 
        if return_rewards_per_frame:
@@ -1679,6 +1703,8 @@ class DynamicsWorldModel(Module):
                signal_levels = signal_levels_with_context,
                step_sizes = step_size,
                rewards = decoded_rewards,
+                discrete_actions = decoded_discrete_actions,
+                continuous_actions = decoded_continuous_actions,
                latent_is_noised = True,
                return_pred_only = True,
                return_agent_tokens = True
@@ -1710,6 +1736,18 @@ class DynamicsWorldModel(Module):
 
                decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
 
+            # decode the agent actions if needed
+
+            if return_agent_actions:
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
+                policy_embed = self.policy_head(one_agent_embed)
+
+                sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
+
+                decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
+                decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
+
            # concat the denoised latent
 
            latents = cat((latents, denoised_latent), dim = 1)
@@ -1739,10 +1777,15 @@ class DynamicsWorldModel(Module):
            width = image_width
        )
 
-        if not return_rewards_per_frame:
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
            return generated_video
 
-        return generated_video, decoded_rewards
+        return WorldModelGenerations(
+            video = generated_video,
+            latents = latents,
+            rewards = decoded_rewards if return_rewards_per_frame else None,
+            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
+        )
 
    def forward(
        self,
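With this change, generate returns a WorldModelGenerations namedtuple whenever rewards or agent actions are requested, so callers can unpack by field name rather than by position. A hedged end-to-end sketch, with constructor arguments mirroring the test added further below (the values are illustrative, not required):

from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

tokenizer = VideoTokenizer(
    512,
    dim_latent = 32,
    patch_size = 32,
    encoder_depth = 1,
    decoder_depth = 1,
    attn_heads = 8,
    image_height = 256,
    image_width = 256,
    attn_kwargs = dict(
        query_heads = 16
    )
)

dynamics = DynamicsWorldModel(
    512,
    num_agents = 1,
    video_tokenizer = tokenizer,
    dim_latent = 32,
    num_discrete_actions = 4
)

generations = dynamics.generate(
    10,
    return_rewards_per_frame = True,
    return_agent_actions = True
)

# fields are now addressable by name
video = generations.video        # (batch, channels, frames, height, width)
rewards = generations.rewards    # (batch, frames), or None if rewards were not requested
discrete_actions, continuous_actions = generations.actions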
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.24"
+version = "0.0.25"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -187,6 +187,46 @@ def test_attend_factory(
 
    assert torch.allclose(flex_out, out, atol = 1e-6)
 
+def test_action_with_world_model():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        num_discrete_actions = 4
+    )
+
+    rewards = torch.randn(1, 4)
+    discrete_actions = torch.randint(0, 4, (1, 4, 1))
+
+    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+        10,
+        return_rewards_per_frame = True,
+        return_agent_actions = True
+    )
+
+    assert generated_video.shape == (1, 3, 10, 256, 256)
+    assert generated_rewards.shape == (1, 10)
+
+    assert discrete_actions.shape == (1, 10, 1)
+    assert continuous_actions is None
+
 def test_action_embedder():
    from dreamer4.dreamer4 import ActionEmbedder
 