start carving into the reinforcement learning portion, starting with reward prediction head (single for now)

commit 7180a8cf43 · parent 77724049e2
@@ -29,6 +29,7 @@ from accelerate import Accelerator
 # h - attention heads
 # d - feature dimension
 # f - frequencies (rotary)
+# l - logit / predicted bins
 # p - positions (3 for spacetime in this work)
 # t - time
 # g - groups of query heads to key heads (gqa)
@@ -176,26 +177,49 @@ class SymExpTwoHot(Module):
     def __init__(
         self,
         range = (-20., 20.),
-        bins = 255
+        num_bins = 255,
+        learned_embedding = False,
+        dim_embed = None,
     ):
         super().__init__()

         min_value, max_value = range
-        values = torch.linspace(min_value, max_value, bins)
+        values = torch.linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)

-        self.num_bins = bins
+        self.num_bins = num_bins
         self.register_buffer('bin_values', values)

+        # take care of a reward embedding
+        # for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
+
+        self.learned_embedding = learned_embedding
+
+        if learned_embedding:
+            assert exists(dim_embed)
+            self.bin_embeds = nn.Embedding(num_bins, dim_embed)
+
     @property
     def device(self):
         return self.bin_values.device

-    def logits_to_scalar_value(
+    def embed(
         self,
-        logits # (... l)
+        two_hot_encoding,
     ):
-        return einsum(logits, self.bin_values, '... l, l -> ...')
+        assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
+
+        weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
+
+        two_embeds = self.bin_embeds(bin_indices)
+
+        return einsum(two_embeds, weights, '... two d, ... two -> ... d')
+
+    def bins_to_scalar_value(
+        self,
+        two_hot_encoding # (... l)
+    ):
+        return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')

     def forward(
         self,
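For orientation: the two-hot encoding (produced in `forward`, which this hunk does not show) places a scalar between the two symexp-spaced bin values that bracket it, with weights summing to one, so a dot product with `bin_values` inverts it exactly - which is what `bins_to_scalar_value` does. A standalone sketch of that idea, assuming the usual DreamerV3-style construction rather than this repo's exact helpers:

import torch

# symexp-spaced bin values, matching SymExpTwoHot: sign(x) * (exp(|x|) - 1)
num_bins = 255
raw = torch.linspace(-20., 20., num_bins)
bin_values = raw.sign() * (torch.exp(raw.abs()) - 1.)

def two_hot(scalar):
    # index of the right-neighbor bin, clamped so both neighbors exist
    idx = torch.searchsorted(bin_values, scalar).clamp(1, num_bins - 1)
    lo, hi = bin_values[idx - 1], bin_values[idx]
    weight_hi = ((scalar - lo) / (hi - lo)).clamp(0., 1.)

    encoding = torch.zeros(num_bins)
    encoding[idx - 1] = 1. - weight_hi
    encoding[idx] = weight_hi
    return encoding

value = torch.tensor(3.7)
encoding = two_hot(value)
decoded = (encoding * bin_values).sum()  # same reduction as bins_to_scalar_value
assert torch.allclose(decoded, value, atol = 1e-4)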
@@ -996,6 +1020,7 @@ class DynamicsModel(Module):
         num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
         num_latent_tokens = None,
         num_tasks = 0,
+        reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
@@ -1006,8 +1031,10 @@ class DynamicsModel(Module):
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
-        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
+        add_reward_embed_to_agent_token = False,
+        reward_loss_weight = 0.1
     ):
         super().__init__()

@@ -1024,6 +1051,8 @@ class DynamicsModel(Module):
         # spatial

         self.num_latent_tokens = num_latent_tokens
+        self.dim_latent = dim_latent
+        self.latent_shape = (num_latent_tokens, dim_latent)

         if num_spatial_tokens >= num_latent_tokens:
             assert divisible_by(num_spatial_tokens, num_latent_tokens)
@@ -1090,6 +1119,23 @@ class DynamicsModel(Module):
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)

+        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
+
+        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+
+        self.reward_encoder = SymExpTwoHot(
+            **reward_encoder_kwargs,
+            dim_embed = dim,
+            learned_embedding = add_reward_embed_to_agent_token
+        )
+
+        self.to_reward_pred = Sequential(
+            RMSNorm(dim),
+            LinearNoBias(dim, self.reward_encoder.num_bins)
+        )
+
+        self.reward_loss_weight = reward_loss_weight
+
         # attention

         self.attn_softclamp_value = attn_softclamp_value
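The reward head added here is just an RMSNorm followed by a bias-free projection onto the num_bins symexp bin logits. This commit only trains it; at inference time the logits would presumably be decoded back to a scalar by softmaxing and taking the expectation over the bin values. A hedged sketch of such a decode step (not part of this commit; names are illustrative):

import torch

def decode_reward(reward_logits, bin_values):
    # reward_logits: (batch, time, num_bins) as produced by to_reward_pred
    # bin_values:    (num_bins,) symexp-spaced support, as in SymExpTwoHot
    probs = reward_logits.softmax(dim = -1)
    return (probs * bin_values).sum(dim = -1)  # expected reward under the predicted bin distribution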
@@ -1122,6 +1168,10 @@ class DynamicsModel(Module):
         self.layers = ModuleList(layers)
         self.is_time = is_time

+        # zero
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def parameter(self):
         params = super().parameters()

@@ -1138,7 +1188,9 @@ class DynamicsModel(Module):
         signal_levels = None, # (b t)
         step_sizes_log2 = None, # (b)
         tasks = None, # (b)
-        return_pred_only = False
+        rewards = None, # (b t)
+        return_pred_only = False,
+        return_all_losses = False
     ):
         # handle video or latents

@@ -1152,7 +1204,7 @@ class DynamicsModel(Module):
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case

-        assert latents.shape[-2] == self.num_latent_tokens
+        assert latents.shape[-2:] == self.latent_shape

         # variables

@@ -1198,7 +1250,7 @@ class DynamicsModel(Module):

         noised_latents = noise.lerp(latents, times)

-        # reinforcementnet learning related
+        # reinforcement learning related

         agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)

@@ -1208,9 +1260,24 @@ class DynamicsModel(Module):
             task_embeds = self.task_embed(tasks)
             agent_tokens = agent_tokens + task_embeds

+        # handle agent tokens w/ actions and task embeds
+
+        agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
+
+        # maybe add a reward embedding to agent tokens
+
+        if exists(rewards):
+            two_hot_encoding = self.reward_encoder(rewards)
+
+            if self.add_reward_embed_to_agent_token:
+                reward_embeds = self.reward_encoder.embed(two_hot_encoding)
+                reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+
+                agent_tokens = agent_tokens + reward_embeds
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)
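The pad_at_dim call above shifts the reward embeddings one step forward along time, so the agent token at step t only ever sees the reward from step t-1 while being trained to predict the reward at step t. A minimal sketch of the equivalent shift with plain F.pad, assuming a (batch, time, dim) tensor:

import torch
import torch.nn.functional as F

reward_embeds = torch.randn(2, 4, 8)                       # (b, t, d)

# pad one zero frame at the front of the time axis, drop the last frame
shifted = F.pad(reward_embeds, (0, 0, 1, -1), value = 0.)  # still (b, t, d)

assert torch.equal(shifted[:, 0], torch.zeros(2, 8))       # step 0 sees no past reward
assert torch.equal(shifted[:, 1:], reward_embeds[:, :-1])  # step t sees the reward from t-1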
@@ -1234,10 +1301,6 @@ class DynamicsModel(Module):
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')

-            # handle agent tokens w/ actions and task embeds
-
-            agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
-
             # pack to tokens for attending

             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
@@ -1297,11 +1360,14 @@ class DynamicsModel(Module):

             pred = self.to_latent_pred(space_tokens)

-            return pred
+            if not return_agent_tokens:
+                return pred
+
+            return pred, agent_tokens

         # forward the network

-        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)

         if return_pred_only:
             return pred
@@ -1374,17 +1440,37 @@ class DynamicsModel(Module):

         # mse loss

-        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')

-        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
+        flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)

         # loss weighting with their ramp function

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
-            losses = losses * loss_weight
+            flow_losses = flow_losses * loss_weight

-        return losses.mean()
+        flow_loss = flow_losses.mean()
+
+        # now take care of the agent token losses
+
+        reward_loss = self.zero
+
+        if exists(rewards):
+            reward_pred = self.to_reward_pred(encoded_agent_tokens)
+            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
+
+        # gather losses
+
+        total_loss = (
+            flow_loss +
+            reward_loss * self.reward_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        return total_loss, (flow_loss, reward_loss)

 # dreamer

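For reference, the added reward objective is a cross entropy between the predicted bin logits and the two-hot target distribution, folded into the total loss with a 0.1 weight. A standalone sketch of that combination, with batch and time flattened so the bin dimension sits where F.cross_entropy expects probability-style targets (PyTorch >= 1.10); names are illustrative, not this repo's API:

import torch
import torch.nn.functional as F

def combined_loss(flow_loss, reward_logits, reward_two_hot, reward_loss_weight = 0.1):
    # reward_logits, reward_two_hot: (batch, time, num_bins)
    logits  = reward_logits.flatten(0, 1)   # (batch * time, num_bins)
    targets = reward_two_hot.flatten(0, 1)  # soft targets that sum to 1 per row

    reward_loss = F.cross_entropy(logits, targets)
    return flow_loss + reward_loss * reward_loss_weight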
@@ -1393,6 +1479,6 @@ class Dreamer(Module):
         self,
         video_tokenizer: VideoTokenizer,
         dynamics_model: DynamicsModel,
-        discount_factor = 0.9995
+        discount_factor = 0.997
     ):
         super().__init__()

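Two notes before the test hunks. First, the discount factor change from 0.9995 to 0.997 shortens the effective credit-assignment horizon, roughly 1 / (1 - gamma) steps:

# rule-of-thumb effective horizon ~ 1 / (1 - discount_factor)
print(1 / (1 - 0.9995))  # ~ 2000 steps
print(1 / (1 - 0.997))   # ~ 333 steps

Second, the remaining hunks are against the test suite (the model code above is presumably dreamer4/dreamer4.py, going by the imports in the tests), extending test_e2e and test_symexp_two_hot to cover the new reward path.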
@@ -8,20 +8,27 @@ import torch
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
 @param('num_spatial_tokens', (2, 8))
+@param('signal_and_step_passed_in', (False, True))
+@param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
     add_task_embeds,
-    num_spatial_tokens
+    num_spatial_tokens,
+    signal_and_step_passed_in,
+    add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

     tokenizer = VideoTokenizer(
-        512,
-        dim_latent = 32,
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4
     )

@@ -31,28 +38,34 @@ def test_e2e(
     assert loss.numel() == 1

     latents = tokenizer(video, return_latents = True)
-    assert latents.shape[-1] == 32
+    assert latents.shape[-1] == 16

     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

     dynamics = DynamicsModel(
-        512,
+        dim = 16,
         video_tokenizer = tokenizer,
-        dim_latent = 32,
+        dim_latent = 16,
         max_steps = 64,
         num_tasks = 4,
         num_latent_tokens = 4,
         depth = 4,
         num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
         attn_dim_head = 16,
         attn_kwargs = dict(
             heads = heads,
-            query_heads = query_heads
+            query_heads = query_heads,
         ),
-        prob_no_shortcut_train = prob_no_shortcut_train
+        prob_no_shortcut_train = prob_no_shortcut_train,
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
     )

-    signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes_log2 = torch.randint(1, 6, (2,))
+    signal_levels = step_sizes_log2 = None
+
+    if signal_and_step_passed_in:
+        signal_levels = torch.randint(0, 500, (2, 4))
+        step_sizes_log2 = torch.randint(1, 6, (2,))

     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
@@ -72,18 +85,37 @@ def test_e2e(

     assert flow_loss.numel() == 1

+    # rl
+
+    rewards = torch.randn((2, 4)) * 100.
+
+    flow_loss = dynamics(
+        **dynamics_input,
+        tasks = tasks,
+        rewards = rewards
+    )
+
 def test_symexp_two_hot():
     import torch
     from dreamer4.dreamer4 import SymExpTwoHot

-    two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
+    two_hot_encoder = SymExpTwoHot(
+        (-3., 3.),
+        num_bins = 20,
+        learned_embedding = True,
+        dim_embed = 512
+    )

     values = torch.randn((10))

-    encoded = two_hot_encoder(values)
-    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
+    two_hot_encoded = two_hot_encoder(values)
+    recon_values = two_hot_encoder.bins_to_scalar_value(two_hot_encoded)

     assert torch.allclose(recon_values, values, atol = 1e-6)

+    reward_embeds = two_hot_encoder.embed(two_hot_encoded)
+    assert reward_embeds.shape == (10, 512)
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))
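The round-trip assertion in test_symexp_two_hot holds because the two-hot encoding interpolates linearly between the two bin values bracketing each input, and bins_to_scalar_value undoes exactly that interpolation; the symexp bin spacing itself is the inverse of a symlog squashing. A small sketch of that transform pair (illustrative, not code from this repo):

import torch

def symlog(x):
    return x.sign() * torch.log1p(x.abs())

def symexp(x):
    return x.sign() * (torch.exp(x.abs()) - 1.)  # matches the bin_values construction in SymExpTwoHot

x = torch.randn(10) * 100.
assert torch.allclose(symexp(symlog(x)), x, atol = 1e-4)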