complete behavior cloning for one agent

This commit is contained in:
lucidrains 2025-10-11 09:24:49 -07:00
parent 02558d1f08
commit b2725d9b6e
2 changed files with 101 additions and 5 deletions

View File

@ -8,6 +8,7 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@ -58,6 +59,8 @@ LinearNoBias = partial(Linear, bias = False)
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

# helpers

def exists(v):
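The new WorldModelLosses namedtuple follows the existing TokenizerLosses pattern, so callers that ask for all losses back can read the behavior cloning term off by name. A tiny standalone illustration of the container itself (toy values, no model involved):

from collections import namedtuple
import torch

WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

losses = WorldModelLosses(torch.tensor(1.2), torch.tensor(0.3), torch.tensor(0.05))

# fields unpack like a plain tuple and are also addressable by name
flow_loss, reward_loss, behavior_clone_loss = losses
assert losses.behavior_clone is losses[2]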
@ -83,6 +86,9 @@ def is_power_two(num):
# tensor helpers

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def pack_one(t, pattern):
    packed, packed_shape = pack([t], pattern)
@ -372,7 +378,7 @@ class ActionEmbedder(Module):
    def unembed(
        self,
        embeds,                          # (... d)
        discrete_action_types = None,    # (na)
        continuous_action_types = None,  # (na)
        return_split_discrete = False
@ -421,6 +427,46 @@ class ActionEmbedder(Module):
        return discrete_action_logits, continuous_action_mean_log_var

    def log_probs(
        self,
        embeds,                          # (... d)
        discrete_targets = None,         # (... na)
        continuous_targets = None,       # (... na)
        discrete_action_types = None,    # (na)
        continuous_action_types = None,  # (na)
    ):
        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)

        # discrete

        discrete_log_probs = None

        if exists(discrete_targets):
            discrete_log_probs = []

            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)

                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)

                discrete_log_probs.append(log_prob)

            discrete_log_probs = cat(discrete_log_probs, dim = -1)

        # continuous

        continuous_log_probs = None

        if exists(continuous_targets):
            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
            std = (0.5 * log_var).exp()

            distr = Normal(mean, std)
            continuous_log_probs = distr.log_prob(continuous_targets)

        return discrete_log_probs, continuous_log_probs
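For intuition: the discrete branch above is the usual gather-based categorical log-likelihood (a negated cross entropy), and the continuous branch is the closed-form Gaussian log-density parameterized by the predicted mean and log-variance. A standalone sketch of the same math with made-up shapes, in plain PyTorch rather than this repo's modules:

import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal

batch, time, num_classes = 2, 3, 4

# discrete: gathered log softmax of the taken action == negated cross entropy
logits = torch.randn(batch, time, num_classes)
targets = torch.randint(0, num_classes, (batch, time))

gathered = logits.log_softmax(dim = -1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
reference = -F.cross_entropy(logits.permute(0, 2, 1), targets, reduction = 'none')
assert torch.allclose(gathered, reference, atol = 1e-6)

# continuous: closed-form Gaussian log density from a predicted mean and log variance
mean, log_var, actions = torch.randn(3, batch, time).unbind(0)

from_distribution = Normal(mean, (0.5 * log_var).exp()).log_prob(actions)
by_hand = -0.5 * (log_var + math.log(2 * math.pi) + (actions - mean).pow(2) / log_var.exp())
assert torch.allclose(from_distribution, by_hand, atol = 1e-6)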
    def forward(
        self,
        *,
@ -1325,6 +1371,7 @@ class DynamicsWorldModel(Module):
        reward_loss_weight = 0.1,
        value_head_mlp_depth = 3,
        policy_head_mlp_depth = 3,
        behavior_clone_weight = 0.1,
        num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
        num_residual_streams = 1
    ):
@ -1441,6 +1488,8 @@ class DynamicsWorldModel(Module):
            unembed_dim = dim * 4
        )

        self.behavior_clone_weight = behavior_clone_weight

        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token

        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@ -1671,7 +1720,8 @@ class DynamicsWorldModel(Module):
        return_pred_only = False,
        latent_is_noised = False,
        return_all_losses = False,
        return_agent_tokens = False,
        add_autoregressive_action_loss = False
    ):
        # handle video or latents
@ -2023,17 +2073,48 @@ class DynamicsWorldModel(Module):
        reward_pred = self.to_reward_pred(encoded_agent_tokens)

        reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)

        # maybe autoregressive action loss

        behavior_clone_loss = self.zero

        if (
            self.num_agents == 1 and
            add_autoregressive_action_loss and
            (exists(discrete_actions) or exists(continuous_actions))
        ):
            assert self.action_embedder.has_actions

            # only for 1 agent

            agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')

            policy_embed = self.policy_head(agent_tokens[:, :-1])

            discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                policy_embed,
                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
            )

            # accumulate the negative log likelihood of the next action (behavior cloning objective)

            if exists(discrete_log_probs):
                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()

            if exists(continuous_log_probs):
                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()

        # gather losses

        total_loss = (
            flow_loss +
            reward_loss * self.reward_loss_weight +
            behavior_clone_loss * self.behavior_clone_weight
        )

        if not return_all_losses:
            return total_loss

        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)

        return total_loss, losses
# dreamer # dreamer
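The behavior cloning term is a next-action negative log likelihood under teacher forcing: the policy head reads the agent token at step t and is scored against the action the agent actually took at step t + 1, hence the [:, :-1] / [:, 1:] shift above. A toy standalone sketch of that alignment for a single discrete action head, with made-up shapes and a plain Linear standing in for the policy head MLP:

import torch
import torch.nn.functional as F
from torch import nn

batch, time, dim, num_actions = 2, 6, 16, 4

agent_tokens = torch.randn(batch, time, dim)                 # single agent, agent axis already squeezed out
expert_actions = torch.randint(0, num_actions, (batch, time))

policy_head = nn.Linear(dim, num_actions)                    # hypothetical stand-in for the policy head

# the token at step t is scored against the expert action taken at step t + 1
logits = policy_head(agent_tokens[:, :-1])                   # (batch, time - 1, num_actions)
targets = expert_actions[:, 1:]                              # (batch, time - 1)

behavior_clone_loss = F.cross_entropy(
    logits.permute(0, 2, 1),                                 # cross entropy expects (batch, classes, time - 1)
    targets
)                                                            # mean negative log likelihood of the next action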

View File

@ -95,7 +95,8 @@ def test_e2e(
        tasks = tasks,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2,
        discrete_actions = actions,
        add_autoregressive_action_loss = True
    )

    assert flow_loss.numel() == 1
@ -286,3 +287,17 @@ def test_action_embedder():
    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    # log probs

    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    discrete_log_probs, continuous_log_probs = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions
    )

    assert discrete_log_probs.shape == (2, 3, 2)
    assert continuous_log_probs.shape == (2, 3, 2)