task conditioning when dreaming

parent 22e13c45fc
commit 83cfd2cd1b
@@ -21,7 +21,6 @@ from torch.optim import Optimizer
 from adam_atan2_pytorch import MuonAdamAtan2
 
 from x_mlps_pytorch.normed_mlp import create_mlp
-from x_mlps_pytorch.ensemble import Ensemble
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
@@ -1816,6 +1815,7 @@ class DynamicsWorldModel(Module):
         num_steps = 4,
         batch_size = 1,
         agent_index = 0,
+        tasks: int | Tensor | None = None,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
@@ -1829,9 +1829,18 @@ class DynamicsWorldModel(Module):
         was_training = self.training
         self.eval()
 
+        # validation
+
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
+        if isinstance(tasks, int):
+            tasks = torch.full((batch_size,), tasks, device = self.device)
+
+        assert tasks.shape[0] == batch_size
+
+        # get state latent shape
+
         latent_shape = self.latent_shape
 
         # derive step size
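With the validation above, task conditioning can be supplied either as a single int, which is broadcast to the whole batch via torch.full, or as a tensor holding one task id per batch element (its first dimension must equal batch_size); num_steps still has to be a power of 2 no greater than max_steps. A minimal usage sketch, under assumptions this diff does not confirm: the method carrying these arguments is called `dream`, `world_model` is an already constructed DynamicsWorldModel, and any other inputs the real method may require are omitted here:

    import torch

    # single task id, broadcast internally to every batch element
    dreams = world_model.dream(num_steps = 4, batch_size = 2, tasks = 3)

    # or one task id per batch element, shape (batch_size,)
    task_ids = torch.tensor([0, 5])  # move to the model's device if needed
    dreams = world_model.dream(num_steps = 4, batch_size = 2, tasks = task_ids)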
@@ -1884,6 +1893,7 @@ class DynamicsWorldModel(Module):
             signal_levels = signal_levels_with_context,
             step_sizes = step_size,
             rewards = decoded_rewards,
+            tasks = tasks,
             discrete_actions = decoded_discrete_actions,
             continuous_actions = decoded_continuous_actions,
             latent_is_noised = True,
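Since the per-batch task ids are forwarded unchanged into the call above on every denoising step, a single batch can dream the same number of steps under several different task conditions side by side. A short sketch, reusing the assumed `dream` name and the same caveats as the previous example:

    import torch

    num_tasks = 4

    # one rollout per task id, dreamed together in one batch
    per_task_dreams = world_model.dream(
        num_steps = 4,
        batch_size = num_tasks,
        tasks = torch.arange(num_tasks),
    )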
@@ -2349,6 +2359,7 @@ class DynamicsWorldModel(Module):
         encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
 
         reward_pred = self.to_reward_pred(encoded_agent_tokens)
+
         reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
 
         # maybe autoregressive action loss
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.32"
+version = "0.0.33"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }