complete behavior cloning for one agent
parent 02558d1f08 · commit b2725d9b6e

@@ -8,6 +8,7 @@ from functools import partial

 import torch
 import torch.nn.functional as F
+from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace

@@ -58,6 +59,8 @@ LinearNoBias = partial(Linear, bias = False)

 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+
 # helpers

 def exists(v):
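The losses are bundled in a `namedtuple` so callers can read them by name while still unpacking positionally. A tiny self-contained illustration (the numeric values are made up):

```python
from collections import namedtuple

WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))

# dummy values, just to show named access vs positional unpacking
losses = WorldModelLosses(flow = 0.42, reward = 0.13, behavior_clone = 0.07)

print(losses.behavior_clone)   # named access -> 0.07
flow, reward, bc = losses      # positional unpacking still works
```
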
@@ -83,6 +86,9 @@ def is_power_two(num):

 # tensor helpers

+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def pack_one(t, pattern):
     packed, packed_shape = pack([t], pattern)

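The `log` helper clamps before taking the logarithm, so a zero probability yields a large negative number instead of `-inf`. A quick check:

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

probs = torch.tensor([0.5, 0.0])

print(probs.log())   # -> -0.6931, -inf
print(log(probs))    # -> -0.6931, -46.0517 (finite, since 0 is clamped to 1e-20)
```
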
@@ -372,7 +378,7 @@ class ActionEmbedder(Module):

     def unembed(
         self,
         embeds, # (... d)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
         return_split_discrete = False
@@ -421,6 +427,46 @@ class ActionEmbedder(Module):

         return discrete_action_logits, continuous_action_mean_log_var

+    def log_probs(
+        self,
+        embeds, # (... d)
+        discrete_targets = None, # (... na)
+        continuous_targets = None, # (... na)
+        discrete_action_types = None, # (na)
+        continuous_action_types = None, # (na)
+    ):
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+
+        # discrete
+
+        discrete_log_probs = None
+
+        if exists(discrete_targets):
+            discrete_log_probs = []
+
+            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+
+                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+
+                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
+                discrete_log_probs.append(log_prob)
+
+            discrete_log_probs = cat(discrete_log_probs, dim = -1)
+
+        # continuous
+
+        continuous_log_probs = None
+
+        if exists(continuous_targets):
+            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
+            std = (0.5 * log_var).exp()
+
+            distr = Normal(mean, std)
+            continuous_log_probs = distr.log_prob(continuous_targets)
+
+        return discrete_log_probs, continuous_log_probs
+
     def forward(
         self,
         *,
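For intuition, here is a minimal standalone sketch of the two log-probability paths that `log_probs` combines, written in plain PyTorch with made-up shapes: gathering the chosen class's entry from a log-softmax for a discrete head, and evaluating a `Normal` parameterized by a predicted mean and log-variance for a continuous head.

```python
import torch
from torch.distributions import Normal

# discrete path: log-softmax over logits, then gather the log prob of the taken action
logits = torch.randn(2, 3, 5)                     # (batch, time, num_classes) - made-up shapes
taken = torch.randint(0, 5, (2, 3))               # logged discrete action per step

log_probs = logits.log_softmax(dim = -1)
picked = log_probs.gather(-1, taken.unsqueeze(-1)).squeeze(-1)   # (batch, time)

# continuous path: predicted mean / log-variance parameterize a Normal
mean_log_var = torch.randn(2, 3, 2)               # last dim holds (mean, log_var)
mean, log_var = mean_log_var.unbind(dim = -1)
std = (0.5 * log_var).exp()                       # log-variance -> standard deviation

cont_taken = torch.randn(2, 3)                    # logged continuous action per step
cont_log_probs = Normal(mean, std).log_prob(cont_taken)          # (batch, time)
```
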
@@ -1325,6 +1371,7 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
+        behavior_clone_weight = 0.1,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1
     ):
@@ -1441,6 +1488,8 @@ class DynamicsWorldModel(Module):
             unembed_dim = dim * 4
         )

+        self.behavior_clone_weight = behavior_clone_weight
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token

         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1671,7 +1720,8 @@ class DynamicsWorldModel(Module):
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False
+        return_agent_tokens = False,
+        add_autoregressive_action_loss = False
     ):
         # handle video or latents

@@ -2023,17 +2073,48 @@ class DynamicsWorldModel(Module):
         reward_pred = self.to_reward_pred(encoded_agent_tokens)
         reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)

+        # maybe autoregressive action loss
+
+        behavior_clone_loss = self.zero
+
+        if (
+            self.num_agents == 1 and
+            add_autoregressive_action_loss and
+            (exists(discrete_actions) or exists(continuous_actions))
+        ):
+            assert self.action_embedder.has_actions
+
+            # only for 1 agent
+
+            agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
+            policy_embed = self.policy_head(agent_tokens[:, :-1])
+
+            discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
+                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
+            )
+
+            if exists(discrete_log_probs):
+                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean() # negative log likelihood
+
+            if exists(continuous_log_probs):
+                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean() # negative log likelihood
+
         # gather losses

         total_loss = (
             flow_loss +
-            reward_loss * self.reward_loss_weight
+            reward_loss * self.reward_loss_weight +
+            behavior_clone_loss * self.behavior_clone_weight
         )

         if not return_all_losses:
             return total_loss

-        return total_loss, (flow_loss, reward_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
+
+        return total_loss, losses

     # dreamer
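As a rough, self-contained picture of the teacher-forcing alignment used above (sizes and module names here are illustrative, not the repo's API): agent tokens up to step t-1 predict the action logged at step t, and the behavior-cloning term is the mean negative log-likelihood of those actions.

```python
import torch
import torch.nn.functional as F

# illustrative sizes: batch, time, token dim, number of discrete action classes
b, t, d, num_actions = 2, 6, 16, 5

agent_tokens = torch.randn(b, t, d)                 # one agent -> (b, t, d)
actions = torch.randint(0, num_actions, (b, t))     # logged discrete actions

policy_head = torch.nn.Linear(d, num_actions)       # stand-in for the policy MLP head

logits = policy_head(agent_tokens[:, :-1])          # tokens at steps 0 .. t-2 ...
targets = actions[:, 1:]                            # ... predict the actions taken at steps 1 .. t-1

# behavior cloning as cross entropy, i.e. mean negative log-likelihood of the logged actions
bc_loss = F.cross_entropy(logits.reshape(-1, num_actions), targets.reshape(-1))
```

In the diff, the same idea runs through `self.policy_head` and `ActionEmbedder.log_probs`, which additionally covers the continuous case with a `Normal` log-density.
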
@@ -95,7 +95,8 @@ def test_e2e(
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
-        discrete_actions = actions
+        discrete_actions = actions,
+        add_autoregressive_action_loss = True
     )

     assert flow_loss.numel() == 1
@@ -286,3 +287,17 @@ def test_action_embedder():

     assert discrete_logits.shape == (2, 3, 4)
     assert continuous_mean_log_var.shape == (2, 3, 1, 2)
+
+    # log probs
+
+    assert discrete_logits.shape == (2, 3, 4)
+    assert continuous_mean_log_var.shape == (2, 3, 1, 2)
+
+    discrete_log_probs, continuous_log_probs = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions
+    )
+
+    assert discrete_log_probs.shape == (2, 3, 2)
+    assert continuous_log_probs.shape == (2, 3, 2)