complete behavior cloning for one agent
parent 02558d1f08
commit b2725d9b6e
@@ -8,6 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 
+from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 
@@ -58,6 +59,8 @@ LinearNoBias = partial(Linear, bias = False)
 
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+
 # helpers
 
 def exists(v):
@@ -83,6 +86,9 @@ def is_power_two(num):
 
 # tensor helpers
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def pack_one(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
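For reference, a quick standalone check of the new clamped log helper (the function body is copied from the hunk above; the sample values are only illustrative): clamping before the log keeps the result finite at zero, where a bare log would return -inf.

import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

probs = torch.tensor([0.0, 1e-30, 0.5])

print(log(probs))   # ≈ [-46.05, -46.05, -0.69] - finite everywhere
print(probs.log())  # [-inf, -69.08, -0.69] - the bare log diverges at exactly zero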
@@ -372,7 +378,7 @@ class ActionEmbedder(Module):
 
     def unembed(
        self,
        embeds, # (... d)
        discrete_action_types = None, # (na)
        continuous_action_types = None, # (na)
        return_split_discrete = False
@@ -421,6 +427,46 @@ class ActionEmbedder(Module):
 
         return discrete_action_logits, continuous_action_mean_log_var
 
+    def log_probs(
+        self,
+        embeds, # (... d)
+        discrete_targets = None, # (... na)
+        continuous_targets = None, # (... na)
+        discrete_action_types = None, # (na)
+        continuous_action_types = None, # (na)
+    ):
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+
+        # discrete
+
+        discrete_log_probs = None
+
+        if exists(discrete_targets):
+            discrete_log_probs = []
+
+            for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
+
+                one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
+
+                log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
+                discrete_log_probs.append(log_prob)
+
+            discrete_log_probs = cat(discrete_log_probs, dim = -1)
+
+        # continuous
+
+        continuous_log_probs = None
+
+        if exists(continuous_targets):
+            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
+            std = (0.5 * log_var).exp()
+
+            distr = Normal(mean, std)
+            continuous_log_probs = distr.log_prob(continuous_targets)
+
+        return discrete_log_probs, continuous_log_probs
+
     def forward(
         self,
         *,
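What the new log_probs method computes, as a minimal self-contained sketch (names and shapes here are illustrative stand-ins, not the module's API): a discrete action head scores targets by gathering from a log-softmax over its logits, and a continuous action head scores targets under a Normal parameterized by the predicted mean and log-variance.

import torch
from torch.distributions import Normal

batch, time, num_choices = 2, 3, 4

# discrete: gather the log-softmax entry of the action actually taken
discrete_logits = torch.randn(batch, time, num_choices)
discrete_targets = torch.randint(0, num_choices, (batch, time))

discrete_log_probs = discrete_logits.log_softmax(dim = -1).gather(-1, discrete_targets.unsqueeze(-1)).squeeze(-1)

# continuous: score the taken action under Normal(mean, exp(0.5 * log_var))
mean_log_var = torch.randn(batch, time, 2)
continuous_targets = torch.randn(batch, time)

mean, log_var = mean_log_var.unbind(dim = -1)
continuous_log_probs = Normal(mean, (0.5 * log_var).exp()).log_prob(continuous_targets)

assert discrete_log_probs.shape == continuous_log_probs.shape == (batch, time)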
@@ -1325,6 +1371,7 @@ class DynamicsWorldModel(Module):
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
+        behavior_clone_weight = 0.1,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1
     ):
@@ -1441,6 +1488,8 @@ class DynamicsWorldModel(Module):
             unembed_dim = dim * 4
         )
 
+        self.behavior_clone_weight = behavior_clone_weight
+
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
         self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
@@ -1671,7 +1720,8 @@ class DynamicsWorldModel(Module):
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False
+        return_agent_tokens = False,
+        add_autoregressive_action_loss = False
     ):
         # handle video or latents
 
@@ -2023,17 +2073,48 @@ class DynamicsWorldModel(Module):
         reward_pred = self.to_reward_pred(encoded_agent_tokens)
         reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
 
+        # maybe autoregressive action loss
+
+        behavior_clone_loss = self.zero
+
+        if (
+            self.num_agents == 1 and
+            add_autoregressive_action_loss and
+            (exists(discrete_actions) or exists(continuous_actions))
+        ):
+            assert self.action_embedder.has_actions
+
+            # only for 1 agent
+
+            agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
+            policy_embed = self.policy_head(agent_tokens[:, :-1])
+
+            discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
+                continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
+            )
+
+            if exists(discrete_log_probs):
+                behavior_clone_loss = behavior_clone_loss - discrete_log_probs.sum(dim = -1).mean()
+
+            if exists(continuous_log_probs):
+                behavior_clone_loss = behavior_clone_loss - continuous_log_probs.sum(dim = -1).mean()
+
         # gather losses
 
         total_loss = (
             flow_loss +
-            reward_loss * self.reward_loss_weight
+            reward_loss * self.reward_loss_weight +
+            behavior_clone_loss * self.behavior_clone_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        return total_loss, (flow_loss, reward_loss)
+        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
+
+        return total_loss, losses
 
 # dreamer
 
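A minimal sketch of how the behavior cloning term above fits together under teacher forcing (all tensors and weights below are illustrative stand-ins, not the model's internals): the agent token at step t produces a policy embedding that is scored against the action actually taken at step t + 1, and the resulting negative log likelihood is weighted into the total loss next to the flow and reward terms.

import torch

batch, time, num_choices = 2, 8, 4
reward_loss_weight = 0.1
behavior_clone_weight = 0.1

policy_logits = torch.randn(batch, time - 1, num_choices)  # stand-in for policy_head(agent_tokens[:, :-1])
actions = torch.randint(0, num_choices, (batch, time))     # recorded discrete actions

targets = actions[:, 1:]                                   # token at step t is scored against the action at step t + 1

log_probs = policy_logits.log_softmax(dim = -1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
behavior_clone_loss = -log_probs.mean()                    # negative log likelihood of the demonstrated actions

flow_loss, reward_loss = torch.tensor(1.), torch.tensor(0.5)  # placeholders for the other two loss terms

total_loss = flow_loss + reward_loss * reward_loss_weight + behavior_clone_loss * behavior_clone_weight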
@@ -95,7 +95,8 @@ def test_e2e(
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
-        discrete_actions = actions
+        discrete_actions = actions,
+        add_autoregressive_action_loss = True
     )
 
     assert flow_loss.numel() == 1
@@ -286,3 +287,17 @@ def test_action_embedder():
 
     assert discrete_logits.shape == (2, 3, 4)
     assert continuous_mean_log_var.shape == (2, 3, 1, 2)
+
+    # log probs
+
+    assert discrete_logits.shape == (2, 3, 4)
+    assert continuous_mean_log_var.shape == (2, 3, 1, 2)
+
+    discrete_log_probs, continuous_log_probs = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions
+    )
+
+    assert discrete_log_probs.shape == (2, 3, 2)
+    assert continuous_log_probs.shape == (2, 3, 2)