start carving into the reinforcement learning portion, starting with reward prediction head (single for now)
parent 77724049e2 · commit 7180a8cf43
@@ -29,6 +29,7 @@ from accelerate import Accelerator
 # h - attention heads
 # d - feature dimension
 # f - frequencies (rotary)
+# l - logit / predicted bins
 # p - positions (3 for spacetime in this work)
 # t - time
 # g - groups of query heads to key heads (gqa)
@@ -176,26 +177,49 @@ class SymExpTwoHot(Module):
     def __init__(
         self,
         range = (-20., 20.),
-        bins = 255
+        num_bins = 255,
+        learned_embedding = False,
+        dim_embed = None,
     ):
         super().__init__()

         min_value, max_value = range
-        values = torch.linspace(min_value, max_value, bins)
+        values = torch.linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)

-        self.num_bins = bins
+        self.num_bins = num_bins
         self.register_buffer('bin_values', values)

+        # take care of a reward embedding
+        # for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
+
+        self.learned_embedding = learned_embedding
+
+        if learned_embedding:
+            assert exists(dim_embed)
+            self.bin_embeds = nn.Embedding(num_bins, dim_embed)
+
     @property
     def device(self):
         return self.bin_values.device

-    def logits_to_scalar_value(
+    def embed(
         self,
-        logits # (... l)
+        two_hot_encoding,
     ):
-        return einsum(logits, self.bin_values, '... l, l -> ...')
+        assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
+
+        weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
+
+        two_embeds = self.bin_embeds(bin_indices)
+
+        return einsum(two_embeds, weights, '... two d, ... two -> ... d')
+
+    def bins_to_scalar_value(
+        self,
+        two_hot_encoding # (... l)
+    ):
+        return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')

     def forward(
         self,
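For orientation, here is a minimal standalone sketch of the symexp two-hot scheme the class above implements: bin centers are spaced uniformly in symlog space and mapped back with sign(x) * (exp(|x|) - 1), a scalar is encoded by splitting its mass across the two neighboring bins, and decoding is the expectation over bin values (what `bins_to_scalar_value` computes). The encoding helper below is illustrative, not the repo's API.

import torch

num_bins = 255
support = torch.linspace(-20., 20., num_bins)
bin_values = support.sign() * (torch.exp(support.abs()) - 1.)  # symexp-spaced bin centers

def two_hot_encode(scalar):
    # split each scalar's probability mass across its two neighboring bins
    scalar = scalar.clamp(bin_values[0], bin_values[-1])
    idx_hi = torch.searchsorted(bin_values, scalar).clamp(1, num_bins - 1)
    idx_lo = idx_hi - 1
    lo, hi = bin_values[idx_lo], bin_values[idx_hi]
    w_hi = (scalar - lo) / (hi - lo)
    encoding = torch.zeros(*scalar.shape, num_bins)
    encoding.scatter_(-1, idx_lo.unsqueeze(-1), (1. - w_hi).unsqueeze(-1))
    encoding.scatter_(-1, idx_hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return encoding

def decode(encoding):
    # expectation over bin values, mirroring bins_to_scalar_value above
    return (encoding * bin_values).sum(dim = -1)

rewards = torch.tensor([0., 1.5, -30.])
assert torch.allclose(decode(two_hot_encode(rewards)), rewards, atol = 1e-4)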
@@ -996,6 +1020,7 @@ class DynamicsModel(Module):
         num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
         num_latent_tokens = None,
         num_tasks = 0,
+        reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
@@ -1007,7 +1032,9 @@ class DynamicsModel(Module):
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
         num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
-        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes
+        prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
+        add_reward_embed_to_agent_token = False,
+        reward_loss_weight = 0.1
     ):
         super().__init__()

@@ -1024,6 +1051,8 @@ class DynamicsModel(Module):
         # spatial

         self.num_latent_tokens = num_latent_tokens
+        self.dim_latent = dim_latent
+        self.latent_shape = (num_latent_tokens, dim_latent)

         if num_spatial_tokens >= num_latent_tokens:
             assert divisible_by(num_spatial_tokens, num_latent_tokens)
@@ -1090,6 +1119,23 @@ class DynamicsModel(Module):
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)

+        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
+
+        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+
+        self.reward_encoder = SymExpTwoHot(
+            **reward_encoder_kwargs,
+            dim_embed = dim,
+            learned_embedding = add_reward_embed_to_agent_token
+        )
+
+        self.to_reward_pred = Sequential(
+            RMSNorm(dim),
+            LinearNoBias(dim, self.reward_encoder.num_bins)
+        )
+
+        self.reward_loss_weight = reward_loss_weight
+
         # attention

         self.attn_softclamp_value = attn_softclamp_value
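For shape intuition, a small sketch of what this reward head does, using plain PyTorch modules in place of the repo's RMSNorm / LinearNoBias / Sequential helpers (nn.RMSNorm assumes a recent PyTorch; sizes are illustrative):

import torch
from torch import nn

dim, num_bins = 512, 255  # illustrative sizes

to_reward_pred = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, num_bins, bias = False)  # bias-free projection onto the reward bins
)

agent_tokens = torch.randn(2, 4, dim)         # (batch, time, dim) - one agent token per time step
reward_logits = to_reward_pred(agent_tokens)  # (batch, time, num_bins) - logits over the symexp two-hot bins
assert reward_logits.shape == (2, 4, num_bins)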
@@ -1122,6 +1168,10 @@ class DynamicsModel(Module):
         self.layers = ModuleList(layers)
         self.is_time = is_time

+        # zero
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def parameter(self):
         params = super().parameters()

@@ -1138,7 +1188,9 @@ class DynamicsModel(Module):
         signal_levels = None, # (b t)
         step_sizes_log2 = None, # (b)
         tasks = None, # (b)
-        return_pred_only = False
+        rewards = None, # (b t)
+        return_pred_only = False,
+        return_all_losses = False
     ):
         # handle video or latents

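With the new keyword arguments, a training call that supplies per-step rewards could look like the following sketch, where `dynamics`, `dynamics_input` and `tasks` are assumed to be constructed as in the updated test further down:

rewards = torch.randn(2, 4) * 100.  # (b, t) raw scalar rewards, two-hot encoded internally

total_loss, (flow_loss, reward_loss) = dynamics(
    **dynamics_input,               # video or latents, as in the test
    tasks = tasks,                  # (b,) task ids
    rewards = rewards,
    return_all_losses = True
)

total_loss.backward()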
@@ -1152,7 +1204,7 @@ class DynamicsModel(Module):
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case

-        assert latents.shape[-2] == self.num_latent_tokens
+        assert latents.shape[-2:] == self.latent_shape

         # variables

@@ -1198,7 +1250,7 @@ class DynamicsModel(Module):

         noised_latents = noise.lerp(latents, times)

-        # reinforcementnet learning related
+        # reinforcement learning related

         agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)

@@ -1208,9 +1260,24 @@ class DynamicsModel(Module):
             task_embeds = self.task_embed(tasks)
             agent_tokens = agent_tokens + task_embeds

+        # handle agent tokens w/ actions and task embeds
+
+        agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
+
+        # maybe add a reward embedding to agent tokens
+
+        if exists(rewards):
+            two_hot_encoding = self.reward_encoder(rewards)
+
+            if self.add_reward_embed_to_agent_token:
+                reward_embeds = self.reward_encoder.embed(two_hot_encoding)
+                reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+
+                agent_tokens = agent_tokens + reward_embeds
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)
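The `pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.)` call shifts the reward embeddings one step forward along time, so the agent token at step t only ever sees the reward from step t-1 (step 0 sees zeros) while the head still predicts the current step's reward. A minimal sketch of that shift with plain `F.pad`, standing in for the repo's `pad_at_dim` helper:

import torch
import torch.nn.functional as F

def shift_one_step(reward_embeds):  # (b, t, d)
    # pad one zero vector at the start of the time axis and trim the last step,
    # i.e. pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.)
    return F.pad(reward_embeds, (0, 0, 1, -1), value = 0.)

x = torch.arange(1., 4.).view(1, 3, 1).expand(1, 3, 2)  # reward embeds for steps 1, 2, 3
assert torch.equal(shift_one_step(x)[0, :, 0], torch.tensor([0., 1., 2.]))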
@@ -1234,10 +1301,6 @@ class DynamicsModel(Module):
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')

-            # handle agent tokens w/ actions and task embeds
-
-            agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
-
             # pack to tokens for attending

             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
@@ -1297,11 +1360,14 @@ class DynamicsModel(Module):

             pred = self.to_latent_pred(space_tokens)

+            if not return_agent_tokens:
                 return pred

+            return pred, agent_tokens
+
         # forward the network

-        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)

         if return_pred_only:
             return pred
@@ -1374,17 +1440,37 @@ class DynamicsModel(Module):

         # mse loss

-        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')

-        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
+        flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)

         # loss weighting with their ramp function

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
-            losses = losses * loss_weight
+            flow_losses = flow_losses * loss_weight

-        return losses.mean()
+        flow_loss = flow_losses.mean()

+        # now take care of the agent token losses
+
+        reward_loss = self.zero
+
+        if exists(rewards):
+            reward_pred = self.to_reward_pred(encoded_agent_tokens)
+            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
+
+        # gather losses
+
+        total_loss = (
+            flow_loss +
+            reward_loss * self.reward_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        return total_loss, (flow_loss, reward_loss)
+
 # dreamer

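The reward loss is a cross entropy between the predicted bin logits and the two-hot target distribution, scaled by `reward_loss_weight` before being added to the flow loss. A small sketch of that soft-target cross entropy with illustrative dims; batch and time are flattened here so the bins land on the class dimension that `F.cross_entropy` expects:

import torch
import torch.nn.functional as F

b, t, num_bins = 2, 4, 255
reward_loss_weight = 0.1

reward_logits = torch.randn(b, t, num_bins)  # from the to_reward_pred head
two_hot_target = torch.zeros(b, t, num_bins)
two_hot_target[..., 100] = 0.4               # toy two-hot target: mass split across two adjacent bins
two_hot_target[..., 101] = 0.6

reward_loss = F.cross_entropy(               # cross entropy with probabilistic (soft) targets
    reward_logits.flatten(0, 1),             # (b * t, num_bins)
    two_hot_target.flatten(0, 1)             # (b * t, num_bins)
)

flow_loss = torch.tensor(1.)                 # stand-in for the flow matching loss
total_loss = flow_loss + reward_loss * reward_loss_weight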
@@ -1393,6 +1479,6 @@ class Dreamer(Module):
         self,
         video_tokenizer: VideoTokenizer,
         dynamics_model: DynamicsModel,
-        discount_factor = 0.9995
+        discount_factor = 0.997
     ):
         super().__init__()
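The discount factor drop from 0.9995 to 0.997 shortens the effective credit-assignment horizon, roughly 1 / (1 - gamma):

print(1. / (1. - 0.9995))  # ~2000 steps (old)
print(1. / (1. - 0.997))   # ~333 steps (new)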
@@ -8,20 +8,27 @@ import torch
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
 @param('num_spatial_tokens', (2, 8))
+@param('signal_and_step_passed_in', (False, True))
+@param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
     add_task_embeds,
-    num_spatial_tokens
+    num_spatial_tokens,
+    signal_and_step_passed_in,
+    add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

     tokenizer = VideoTokenizer(
-        512,
-        dim_latent = 32,
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        dim_latent = 16,
         patch_size = 32,
+        attn_dim_head = 16,
         num_latent_tokens = 4
     )

@@ -31,26 +38,32 @@ def test_e2e(
     assert loss.numel() == 1

     latents = tokenizer(video, return_latents = True)
-    assert latents.shape[-1] == 32
+    assert latents.shape[-1] == 16

     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

     dynamics = DynamicsModel(
-        512,
+        dim = 16,
         video_tokenizer = tokenizer,
-        dim_latent = 32,
+        dim_latent = 16,
         max_steps = 64,
         num_tasks = 4,
         num_latent_tokens = 4,
+        depth = 4,
         num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
+        attn_dim_head = 16,
         attn_kwargs = dict(
             heads = heads,
-            query_heads = query_heads
+            query_heads = query_heads,
         ),
-        prob_no_shortcut_train = prob_no_shortcut_train
+        prob_no_shortcut_train = prob_no_shortcut_train,
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
     )

+    signal_levels = step_sizes_log2 = None
+
+    if signal_and_step_passed_in:
         signal_levels = torch.randint(0, 500, (2, 4))
         step_sizes_log2 = torch.randint(1, 6, (2,))

@@ -72,18 +85,37 @@ def test_e2e(

     assert flow_loss.numel() == 1

+    # rl
+
+    rewards = torch.randn((2, 4)) * 100.
+
+    flow_loss = dynamics(
+        **dynamics_input,
+        tasks = tasks,
+        rewards = rewards
+    )
+
 def test_symexp_two_hot():
     import torch
     from dreamer4.dreamer4 import SymExpTwoHot

-    two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
+    two_hot_encoder = SymExpTwoHot(
+        (-3., 3.),
+        num_bins = 20,
+        learned_embedding = True,
+        dim_embed = 512
+    )

     values = torch.randn((10))

-    encoded = two_hot_encoder(values)
-    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
+    two_hot_encoded = two_hot_encoder(values)
+    recon_values = two_hot_encoder.bins_to_scalar_value(two_hot_encoded)

     assert torch.allclose(recon_values, values, atol = 1e-6)

+    reward_embeds = two_hot_encoder.embed(two_hot_encoded)
+    assert reward_embeds.shape == (10, 512)
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))