task conditioning when dreaming

lucidrains 2025-10-18 07:47:13 -07:00
parent 22e13c45fc
commit 83cfd2cd1b
2 changed files with 13 additions and 2 deletions

dreamer4/dreamer4.py

@@ -21,7 +21,6 @@ from torch.optim import Optimizer
 from adam_atan2_pytorch import MuonAdamAtan2
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -1816,6 +1815,7 @@ class DynamicsWorldModel(Module):
         num_steps = 4,
         batch_size = 1,
         agent_index = 0,
+        tasks: int | Tensor | None = None,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
@@ -1829,9 +1829,18 @@ class DynamicsWorldModel(Module):
         was_training = self.training
         self.eval()
 
+        # validation
+
+        assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
+        assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
+
+        if isinstance(tasks, int):
+            tasks = torch.full((batch_size,), tasks, device = self.device)
+
+        assert tasks.shape[0] == batch_size
 
         # get state latent shape
 
         latent_shape = self.latent_shape
 
         # derive step size
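
In short, the new tasks argument accepts either a plain int, which is broadcast across the batch, or a (batch,) tensor of task ids. A minimal self-contained sketch of that broadcast plus the step-count validation, with hypothetical helper names (the repo inlines this logic in the dreaming method rather than factoring it out):

import torch
from torch import Tensor
from math import log2

def normalize_tasks(tasks: int | Tensor | None, batch_size: int, device = None) -> Tensor | None:
    # a single int task id is broadcast to one id per batch element
    if isinstance(tasks, int):
        tasks = torch.full((batch_size,), tasks, device = device)

    if tasks is not None:
        assert tasks.shape[0] == batch_size, 'need one task id per batch element'

    return tasks

def validate_num_steps(num_steps: int, max_steps: int):
    # mirrors the asserts added above: a power-of-2 step count within bounds
    assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
    assert 0 < num_steps <= max_steps, f'number of steps {num_steps} must be between 0 and {max_steps}'

assert normalize_tasks(3, batch_size = 2).tolist() == [3, 3]
assert normalize_tasks(torch.tensor([0, 5]), batch_size = 2).tolist() == [0, 5]
validate_num_steps(4, max_steps = 64)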
@@ -1884,6 +1893,7 @@ class DynamicsWorldModel(Module):
             signal_levels = signal_levels_with_context,
             step_sizes = step_size,
             rewards = decoded_rewards,
+            tasks = tasks,
             discrete_actions = decoded_discrete_actions,
             continuous_actions = decoded_continuous_actions,
             latent_is_noised = True,
@@ -2349,6 +2359,7 @@ class DynamicsWorldModel(Module):
         encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
 
         reward_pred = self.to_reward_pred(encoded_agent_tokens)
 
         reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
 
+        # maybe autoregressive action loss
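
For context on the reward loss in this hunk: the reward head is trained with cross entropy against a two-hot target, the distributional reward trick used throughout the Dreamer line of work. Each scalar reward places its probability mass on the two nearest bins of a fixed support, weighted by proximity (DreamerV3-style implementations typically symlog-transform rewards first). A minimal sketch, with an illustrative support and helper name rather than the repo's exact implementation:

import torch
import torch.nn.functional as F

def two_hot(values, support):
    # place each scalar's mass on its two nearest bins, linearly weighted
    # by proximity; support must be a strictly increasing 1d tensor
    values = values.clamp(support[0], support[-1])
    idx_hi = torch.searchsorted(support, values).clamp(1, len(support) - 1)
    idx_lo = idx_hi - 1
    weight_hi = (values - support[idx_lo]) / (support[idx_hi] - support[idx_lo])
    target = torch.zeros(*values.shape, len(support))
    target.scatter_(-1, idx_lo.unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    target.scatter_(-1, idx_hi.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return target

support = torch.linspace(-1., 1., 9)            # 9 illustrative reward bins
two_hot_encoding = two_hot(torch.tensor([0.3, -0.5]), support)
reward_pred = torch.randn(2, 9)                 # stand-in logits from a reward head
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)  # soft-label cross entropy, as in the diff above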

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.32"
+version = "0.0.33"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }