complete behavior cloning for one agent

This commit is contained in:
lucidrains 2025-10-11 09:24:49 -07:00
parent 02558d1f08
commit b2725d9b6e
2 changed files with 101 additions and 5 deletions

View File

@ -8,6 +8,7 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
@ -58,6 +59,8 @@ LinearNoBias = partial(Linear, bias = False)
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
# helpers
def exists(v):
@ -83,6 +86,9 @@ def is_power_two(num):
# tensor helpers
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def pack_one(t, pattern):
packed, packed_shape = pack([t], pattern)
@ -372,7 +378,7 @@ class ActionEmbedder(Module):
def unembed(
self,
embeds, # (... d)
embeds, # (... d)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
return_split_discrete = False
@ -421,6 +427,46 @@ class ActionEmbedder(Module):
return discrete_action_logits, continuous_action_mean_log_var
def log_probs(
self,
embeds, # (... d)
discrete_targets = None, # (... na)
continuous_targets = None, # (... na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
):
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
# discrete
discrete_log_probs = None
if exists(discrete_targets):
discrete_log_probs = []
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
discrete_log_probs.append(log_prob)
discrete_log_probs = cat(discrete_log_probs, dim = -1)
# continuous
continuous_log_probs = None
if exists(continuous_targets):
mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
std = (0.5 * log_var).exp()
distr = Normal(mean, std)
continuous_log_probs = distr.log_prob(continuous_targets)
return discrete_log_probs, continuous_log_probs
def forward(
self,
*,
@ -1325,6 +1371,7 @@ class DynamicsWorldModel(Module):
reward_loss_weight = 0.1,
value_head_mlp_depth = 3,
policy_head_mlp_depth = 3,
behavior_clone_weight = 0.1,
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
num_residual_streams = 1
):
@ -1441,6 +1488,8 @@ class DynamicsWorldModel(Module):
unembed_dim = dim * 4
)
self.behavior_clone_weight = behavior_clone_weight
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@ -1671,7 +1720,8 @@ class DynamicsWorldModel(Module):
return_pred_only = False,
latent_is_noised = False,
return_all_losses = False,
return_agent_tokens = False
return_agent_tokens = False,
add_autoregressive_action_loss = False
):
# handle video or latents
@ -2023,17 +2073,48 @@ class DynamicsWorldModel(Module):
reward_pred = self.to_reward_pred(encoded_agent_tokens)
reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
# maybe autoregressive action loss
behavior_clone_loss = self.zero
if (
self.num_agents == 1 and
add_autoregressive_action_loss and
(exists(discrete_actions) or exists(continuous_actions))
):
assert self.action_embedder.has_actions
# only for 1 agent
agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
policy_embed = self.policy_head(agent_tokens[:, :-1])
discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
policy_embed,
discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
)
if exists(discrete_log_probs):
behavior_clone_loss = behavior_clone_loss + discrete_log_probs.sum(dim = -1).mean()
if exists(continuous_log_probs):
behavior_clone_loss = behavior_clone_loss + continuous_log_probs.sum(dim = -1).mean()
# gather losses
total_loss = (
flow_loss +
reward_loss * self.reward_loss_weight
reward_loss * self.reward_loss_weight +
behavior_clone_loss * self.behavior_clone_weight
)
if not return_all_losses:
return total_loss
return total_loss, (flow_loss, reward_loss)
losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
return total_loss, losses
# dreamer

View File

@ -95,7 +95,8 @@ def test_e2e(
tasks = tasks,
signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2,
discrete_actions = actions
discrete_actions = actions,
add_autoregressive_action_loss = True
)
assert flow_loss.numel() == 1
@ -286,3 +287,17 @@ def test_action_embedder():
assert discrete_logits.shape == (2, 3, 4)
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
# log probs
assert discrete_logits.shape == (2, 3, 4)
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
discrete_log_probs, continuous_log_probs = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions
)
assert discrete_log_probs.shape == (2, 3, 2)
assert continuous_log_probs.shape == (2, 3, 2)