start carving into the reinforcement learning portion, starting with reward prediction head (single for now)

lucidrains 2025-10-06 11:17:25 -07:00
parent 77724049e2
commit 7180a8cf43
2 changed files with 154 additions and 36 deletions

@@ -29,6 +29,7 @@ from accelerate import Accelerator
# h - attention heads
# d - feature dimension
# f - frequencies (rotary)
# l - logit / predicted bins
# p - positions (3 for spacetime in this work)
# t - time
# g - groups of query heads to key heads (gqa)
@@ -176,26 +177,49 @@ class SymExpTwoHot(Module):
def __init__(
self,
range = (-20., 20.),
bins = 255
num_bins = 255,
learned_embedding = False,
dim_embed = None,
):
super().__init__()
min_value, max_value = range
values = torch.linspace(min_value, max_value, bins)
values = torch.linspace(min_value, max_value, num_bins)
values = values.sign() * (torch.exp(values.abs()) - 1.)
self.num_bins = bins
self.num_bins = num_bins
self.register_buffer('bin_values', values)
# take care of a reward embedding
# for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
self.learned_embedding = learned_embedding
if learned_embedding:
assert exists(dim_embed)
self.bin_embeds = nn.Embedding(num_bins, dim_embed)
@property
def device(self):
return self.bin_values.device
def logits_to_scalar_value(
def embed(
self,
logits # (... l)
two_hot_encoding,
):
return einsum(logits, self.bin_values, '... l, l -> ...')
assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
two_embeds = self.bin_embeds(bin_indices)
return einsum(two_embeds, weights, '... two d, ... two -> ... d')
def bins_to_scalar_value(
self,
two_hot_encoding # (... l)
):
return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
def forward(
self,
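For readers new to the two-hot trick, here is a minimal standalone sketch of the round trip the class above implements: encode a scalar onto the two adjacent symexp-spaced bins, then decode with a weighted sum over bin values (names like `symexp` and `two_hot` below are illustrative, not the repo's API):

import torch

def symexp(x):
    return x.sign() * (torch.exp(x.abs()) - 1.)

def two_hot(value, bin_values):
    # clamp into the representable range, then find the two neighboring bins
    value = value.clamp(float(bin_values[0]), float(bin_values[-1]))
    idx = torch.searchsorted(bin_values, value).clamp(1, len(bin_values) - 1)
    left, right = bin_values[idx - 1], bin_values[idx]
    weight_right = (value - left) / (right - left)

    encoding = torch.zeros(*value.shape, len(bin_values))
    encoding.scatter_(-1, (idx - 1).unsqueeze(-1), (1. - weight_right).unsqueeze(-1))
    encoding.scatter_(-1, idx.unsqueeze(-1), weight_right.unsqueeze(-1))
    return encoding

bin_values = symexp(torch.linspace(-3., 3., 20))
values = torch.randn(10)

encoding = two_hot(values, bin_values)
recon = (encoding * bin_values).sum(dim = -1)   # what bins_to_scalar_value does via einsum
assert torch.allclose(recon, values, atol = 1e-6)

The new `embed` method builds on the same encoding: `topk(k = 2)` recovers the two active bins and their weights, looks up a learned nn.Embedding row for each, and mixes the two rows with those weights, giving a smooth embedding of the scalar reward.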
@@ -996,6 +1020,7 @@ class DynamicsModel(Module):
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
num_latent_tokens = None,
num_tasks = 0,
reward_encoder_kwargs: dict = dict(),
depth = 4,
pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
time_block_every = 4, # every 4th block is time
@@ -1006,8 +1031,10 @@ class DynamicsModel(Module):
attn_softclamp_value = 50.,
ff_kwargs: dict = dict(),
loss_weight_fn: Callable = ramp_weight,
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
add_reward_embed_to_agent_token = False,
reward_loss_weight = 0.1
):
super().__init__()
@@ -1024,6 +1051,8 @@ class DynamicsModel(Module):
# spatial
self.num_latent_tokens = num_latent_tokens
self.dim_latent = dim_latent
self.latent_shape = (num_latent_tokens, dim_latent)
if num_spatial_tokens >= num_latent_tokens:
assert divisible_by(num_spatial_tokens, num_latent_tokens)
@@ -1090,6 +1119,23 @@ class DynamicsModel(Module):
self.num_tasks = num_tasks
self.task_embed = nn.Embedding(num_tasks, dim)
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
self.reward_encoder = SymExpTwoHot(
**reward_encoder_kwargs,
dim_embed = dim,
learned_embedding = add_reward_embed_to_agent_token
)
self.to_reward_pred = Sequential(
RMSNorm(dim),
LinearNoBias(dim, self.reward_encoder.num_bins)
)
self.reward_loss_weight = reward_loss_weight
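For orientation, the new head just produces a distribution over the symexp bins, which gets decoded back to a scalar by expectation, mirroring bins_to_scalar_value. A rough sketch with plain torch modules standing in for the repo's RMSNorm / LinearNoBias (shapes and names here are illustrative; nn.RMSNorm needs a recent torch):

import torch
from torch import nn

dim, num_bins = 512, 255
to_reward_pred = nn.Sequential(nn.RMSNorm(dim), nn.Linear(dim, num_bins, bias = False))

values = torch.linspace(-20., 20., num_bins)
symexp_bin_values = values.sign() * (values.abs().exp() - 1.)

agent_tokens = torch.randn(2, 4, dim)                                          # (batch, time, dim)
logits = to_reward_pred(agent_tokens)                                          # (batch, time, num_bins)
pred_rewards = (logits.softmax(dim = -1) * symexp_bin_values).sum(dim = -1)    # (batch, time) scalar rewards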
# attention
self.attn_softclamp_value = attn_softclamp_value
@@ -1122,6 +1168,10 @@ class DynamicsModel(Module):
self.layers = ModuleList(layers)
self.is_time = is_time
# zero
self.register_buffer('zero', tensor(0.), persistent = False)
def parameters(self):
params = super().parameters()
@@ -1138,7 +1188,9 @@ class DynamicsModel(Module):
signal_levels = None, # (b t)
step_sizes_log2 = None, # (b)
tasks = None, # (b)
return_pred_only = False
rewards = None, # (b t)
return_pred_only = False,
return_all_losses = False
):
# handle video or latents
@@ -1152,7 +1204,7 @@ class DynamicsModel(Module):
if latents.ndim == 3:
latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
assert latents.shape[-2] == self.num_latent_tokens
assert latents.shape[-2:] == self.latent_shape
# variables
@@ -1198,7 +1250,7 @@ class DynamicsModel(Module):
noised_latents = noise.lerp(latents, times)
# reinforcementnet learning related
# reinforcement learning related
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
@@ -1208,9 +1260,24 @@ class DynamicsModel(Module):
task_embeds = self.task_embed(tasks)
agent_tokens = agent_tokens + task_embeds
# handle agent tokens w/ actions and task embeds
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
# maybe add a reward embedding to agent tokens
if exists(rewards):
two_hot_encoding = self.reward_encoder(rewards)
if self.add_reward_embed_to_agent_token:
reward_embeds = self.reward_encoder.embed(two_hot_encoding)
reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.) # shift as each agent token predicts the next reward
agent_tokens = agent_tokens + reward_embeds
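The (1, -1) padding above is just a shift right by one along the time axis, so the agent token at step t carries the reward embedding from step t - 1 and the first step gets zeros. A tiny illustration, with F.pad standing in for pad_at_dim (values are made up):

import torch
import torch.nn.functional as F

reward_embeds = torch.arange(1., 5.).view(1, 4, 1)           # (batch, time, dim) with time steps 1..4
shifted = F.pad(reward_embeds, (0, 0, 1, -1), value = 0.)    # pad one step at the front of time, drop the last
assert shifted.flatten().tolist() == [0., 1., 2., 3.]        # step t now holds the embedding from step t - 1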
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
# latents to spatial tokens
space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -1234,10 +1301,6 @@ class DynamicsModel(Module):
flow_token = cat((signal_embed, step_size_embed), dim = -1)
flow_token = rearrange(flow_token, 'b t d -> b t d')
# handle agent tokens w/ actions and task embeds
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
# pack to tokens for attending
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
@@ -1297,11 +1360,14 @@ class DynamicsModel(Module):
pred = self.to_latent_pred(space_tokens)
return pred
if not return_agent_tokens:
return pred
return pred, agent_tokens
# forward the network
pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
if return_pred_only:
return pred
@@ -1374,17 +1440,37 @@ class DynamicsModel(Module):
# mse loss
losses = F.mse_loss(pred, pred_target, reduction = 'none')
flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
# loss weighting with their ramp function
if exists(self.loss_weight_fn):
loss_weight = self.loss_weight_fn(times)
losses = losses * loss_weight
flow_losses = flow_losses * loss_weight
return losses.mean()
flow_loss = flow_losses.mean()
# now take care of the agent token losses
reward_loss = self.zero
if exists(rewards):
reward_pred = self.to_reward_pred(encoded_agent_tokens)
reward_loss = F.cross_entropy(rearrange(reward_pred, 'b t l -> b l t'), rearrange(two_hot_encoding, 'b t l -> b l t'))
# gather losses
total_loss = (
flow_loss +
reward_loss * self.reward_loss_weight
)
if not return_all_losses:
return total_loss
return total_loss, (flow_loss, reward_loss)
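Since the targets are soft two-hot distributions rather than class indices, the cross entropy here is the soft-label form, and F.cross_entropy expects the bin dimension at position 1. A minimal self-contained check of what that loss amounts to (names and shapes are illustrative):

import torch
import torch.nn.functional as F

batch, time, num_bins = 2, 4, 255
reward_logits = torch.randn(batch, time, num_bins)
target = torch.rand(batch, time, num_bins)
target = target / target.sum(dim = -1, keepdim = True)       # rows sum to 1, like a two-hot encoding

loss = F.cross_entropy(reward_logits.transpose(1, 2), target.transpose(1, 2))
manual = -(target * reward_logits.log_softmax(dim = -1)).sum(dim = -1).mean()
assert torch.allclose(loss, manual, atol = 1e-6)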
# dreamer
@@ -1393,6 +1479,6 @@ class Dreamer(Module):
self,
video_tokenizer: VideoTokenizer,
dynamics_model: DynamicsModel,
discount_factor = 0.9995
discount_factor = 0.997
):
super().__init__()
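For context on the new default: the effective credit-assignment horizon scales roughly as 1 / (1 - discount_factor), so 0.997 corresponds to about 333 steps versus about 2000 steps for the previous 0.9995.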

@@ -8,20 +8,27 @@ import torch
@param('prob_no_shortcut_train', (None, 0., 1.))
@param('add_task_embeds', (False, True))
@param('num_spatial_tokens', (2, 8))
@param('signal_and_step_passed_in', (False, True))
@param('add_reward_embed_to_agent_token', (False, True))
def test_e2e(
pred_orig_latent,
grouped_query_attn,
dynamics_with_video_input,
prob_no_shortcut_train,
add_task_embeds,
num_spatial_tokens
num_spatial_tokens,
signal_and_step_passed_in,
add_reward_embed_to_agent_token
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
16,
encoder_depth = 1,
decoder_depth = 1,
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4
)
@@ -31,28 +38,34 @@ def test_e2e(
assert loss.numel() == 1
latents = tokenizer(video, return_latents = True)
assert latents.shape[-1] == 32
assert latents.shape[-1] == 16
query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
dynamics = DynamicsModel(
512,
dim = 16,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 4,
depth = 4,
num_spatial_tokens = num_spatial_tokens,
pred_orig_latent = pred_orig_latent,
attn_dim_head = 16,
attn_kwargs = dict(
heads = heads,
query_heads = query_heads
query_heads = query_heads,
),
prob_no_shortcut_train = prob_no_shortcut_train
prob_no_shortcut_train = prob_no_shortcut_train,
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
)
signal_levels = torch.randint(0, 500, (2, 4))
step_sizes_log2 = torch.randint(1, 6, (2,))
signal_levels = step_sizes_log2 = None
if signal_and_step_passed_in:
signal_levels = torch.randint(0, 500, (2, 4))
step_sizes_log2 = torch.randint(1, 6, (2,))
if dynamics_with_video_input:
dynamics_input = dict(video = video)
@@ -72,18 +85,37 @@ def test_e2e(
assert flow_loss.numel() == 1
# rl
rewards = torch.randn((2, 4)) * 100.
flow_loss = dynamics(
**dynamics_input,
tasks = tasks,
rewards = rewards
)
def test_symexp_two_hot():
import torch
from dreamer4.dreamer4 import SymExpTwoHot
two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
two_hot_encoder = SymExpTwoHot(
(-3., 3.),
num_bins = 20,
learned_embedding = True,
dim_embed = 512
)
values = torch.randn((10))
encoded = two_hot_encoder(values)
recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
two_hot_encoded = two_hot_encoder(values)
recon_values = two_hot_encoder.bins_to_scalar_value(two_hot_encoded)
assert torch.allclose(recon_values, values, atol = 1e-6)
reward_embeds = two_hot_encoder.embed(two_hot_encoded)
assert reward_embeds.shape == (10, 512)
@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
@param('causal', (False, True))
@param('softclamp_value', (50., None))