task conditioning when dreaming
This commit is contained in:
parent
22e13c45fc
commit
83cfd2cd1b
@ -21,7 +21,6 @@ from torch.optim import Optimizer
|
||||
from adam_atan2_pytorch import MuonAdamAtan2
|
||||
|
||||
from x_mlps_pytorch.normed_mlp import create_mlp
|
||||
from x_mlps_pytorch.ensemble import Ensemble
|
||||
|
||||
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
||||
|
||||
@ -1816,6 +1815,7 @@ class DynamicsWorldModel(Module):
|
||||
num_steps = 4,
|
||||
batch_size = 1,
|
||||
agent_index = 0,
|
||||
tasks: int | Tensor | None = None,
|
||||
image_height = None,
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
@ -1829,9 +1829,18 @@ class DynamicsWorldModel(Module):
|
||||
was_training = self.training
|
||||
self.eval()
|
||||
|
||||
# validation
|
||||
|
||||
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
|
||||
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
|
||||
|
||||
if isinstance(tasks, int):
|
||||
tasks = torch.full((batch_size,), tasks, device = self.device)
|
||||
|
||||
assert tasks.shape[0] == batch_size
|
||||
|
||||
# get state latent shape
|
||||
|
||||
latent_shape = self.latent_shape
|
||||
|
||||
# derive step size
|
||||
@ -1884,6 +1893,7 @@ class DynamicsWorldModel(Module):
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes = step_size,
|
||||
rewards = decoded_rewards,
|
||||
tasks = tasks,
|
||||
discrete_actions = decoded_discrete_actions,
|
||||
continuous_actions = decoded_continuous_actions,
|
||||
latent_is_noised = True,
|
||||
@ -2349,6 +2359,7 @@ class DynamicsWorldModel(Module):
|
||||
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
|
||||
|
||||
reward_pred = self.to_reward_pred(encoded_agent_tokens)
|
||||
|
||||
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
|
||||
|
||||
# maybe autoregressive action loss
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.32"
|
||||
version = "0.0.33"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user