diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index fff3b4f..de0425c 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -29,6 +29,7 @@ from accelerate import Accelerator
 # h - attention heads
 # d - feature dimension
 # f - frequencies (rotary)
+# l - logit / predicted bins
 # p - positions (3 for spacetime in this work)
 # t - time
 # g - groups of query heads to key heads (gqa)
@@ -176,26 +177,49 @@ class SymExpTwoHot(Module):
     def __init__(
         self,
         range = (-20., 20.),
-        bins = 255
+        num_bins = 255,
+        learned_embedding = False,
+        dim_embed = None,
     ):
         super().__init__()
         min_value, max_value = range
 
-        values = torch.linspace(min_value, max_value, bins)
+        values = torch.linspace(min_value, max_value, num_bins)
         values = values.sign() * (torch.exp(values.abs()) - 1.)
 
-        self.num_bins = bins
+        self.num_bins = num_bins
         self.register_buffer('bin_values', values)
 
+        # take care of a reward embedding
+        # for an improvisation where agent tokens can also see the past rewards - it makes sense that this information should not be thrown out, a la Decision Transformer
+
+        self.learned_embedding = learned_embedding
+
+        if learned_embedding:
+            assert exists(dim_embed)
+            self.bin_embeds = nn.Embedding(num_bins, dim_embed)
+
     @property
     def device(self):
         return self.bin_values.device
 
-    def logits_to_scalar_value(
+    def embed(
         self,
-        logits # (... l)
+        two_hot_encoding,
     ):
-        return einsum(logits, self.bin_values, '... l, l -> ...')
+        assert self.learned_embedding, f'can only embed if `learned_embedding` is True'
+
+        weights, bin_indices = two_hot_encoding.topk(k = 2, dim = -1)
+
+        two_embeds = self.bin_embeds(bin_indices)
+
+        return einsum(two_embeds, weights, '... two d, ... two -> ... d')
+
+    def bins_to_scalar_value(
+        self,
+        two_hot_encoding # (... l)
+    ):
+        return einsum(two_hot_encoding, self.bin_values, '... l, l -> ...')
 
     def forward(
         self,
@@ -996,6 +1020,7 @@ class DynamicsModel(Module):
         num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
         num_latent_tokens = None,
         num_tasks = 0,
+        reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
@@ -1006,8 +1031,10 @@
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
-        prob_no_shortcut_train = None # probability of no shortcut training, defaults to 1 / num_step_sizes
+        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
+        add_reward_embed_to_agent_token = False,
+        reward_loss_weight = 0.1
     ):
         super().__init__()
@@ -1024,6 +1051,8 @@
         # spatial
 
         self.num_latent_tokens = num_latent_tokens
+        self.dim_latent = dim_latent
+        self.latent_shape = (num_latent_tokens, dim_latent)
 
         if num_spatial_tokens >= num_latent_tokens:
             assert divisible_by(num_spatial_tokens, num_latent_tokens)
@@ -1090,6 +1119,23 @@
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
 
+        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
+
+        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+
+        self.reward_encoder = SymExpTwoHot(
+            **reward_encoder_kwargs,
+            dim_embed = dim,
+            learned_embedding = add_reward_embed_to_agent_token
+        )
+
+        self.to_reward_pred = Sequential(
+            RMSNorm(dim),
+            LinearNoBias(dim, self.reward_encoder.num_bins)
+        )
+
+        self.reward_loss_weight = reward_loss_weight
+
         # attention
 
         self.attn_softclamp_value = attn_softclamp_value
@@ -1122,6 +1168,10 @@
         self.layers = ModuleList(layers)
         self.is_time = is_time
 
+        # zero
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def parameter(self):
         params = super().parameters()
@@ -1138,7 +1188,9 @@
         signal_levels = None, # (b t)
         step_sizes_log2 = None, # (b)
         tasks = None, # (b)
-        return_pred_only = False
+        rewards = None, # (b t)
+        return_pred_only = False,
+        return_all_losses = False
     ):
 
         # handle video or latents
@@ -1152,7 +1204,7 @@
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
 
-        assert latents.shape[-2] == self.num_latent_tokens
+        assert latents.shape[-2:] == self.latent_shape
 
         # variables
@@ -1198,7 +1250,7 @@
         noised_latents = noise.lerp(latents, times)
 
-        # reinforcementnet learning related
+        # reinforcement learning related
 
         agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
@@ -1208,9 +1260,24 @@
             task_embeds = self.task_embed(tasks)
             agent_tokens = agent_tokens + task_embeds
 
+        # handle agent tokens w/ actions and task embeds
+
+        agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
+
+        # maybe add a reward embedding to agent tokens
+
+        if exists(rewards):
+            two_hot_encoding = self.reward_encoder(rewards)
+
+            if self.add_reward_embed_to_agent_token:
+                reward_embeds = self.reward_encoder.embed(two_hot_encoding)
+                reward_embeds = pad_at_dim(reward_embeds, (1, -1), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+
+                agent_tokens = agent_tokens + reward_embeds
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -1234,10 +1301,6 @@
             flow_token = cat((signal_embed, step_size_embed), dim = -1)
             flow_token = rearrange(flow_token, 'b t d -> b t d')
 
-            # handle agent tokens w/ actions and task embeds
-
-            agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
-
             # pack to tokens for attending
 
             tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
@@ -1297,11 +1360,14 @@
             pred = self.to_latent_pred(space_tokens)
 
-            return pred
+            if not return_agent_tokens:
+                return pred
+
+            return pred, agent_tokens
 
         # forward the network
 
-        pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             return pred
@@ -1374,17 +1440,37 @@
         # mse loss
 
-        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+        flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
 
-        losses = losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
+        flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)
 
         # loss weighting with their ramp function
 
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
-            losses = losses * loss_weight
+            flow_losses = flow_losses * loss_weight
 
-        return losses.mean()
+        flow_loss = flow_losses.mean()
+
+        # now take care of the agent token losses
+
+        reward_loss = self.zero
+
+        if exists(rewards):
+            reward_pred = self.to_reward_pred(encoded_agent_tokens)
+            reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
+
+        # gather losses
+
+        total_loss = (
+            flow_loss +
+            reward_loss * self.reward_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        return total_loss, (flow_loss, reward_loss)
 
 # dreamer
@@ -1393,6 +1479,6 @@ class Dreamer(Module):
         self,
         video_tokenizer: VideoTokenizer,
         dynamics_model: DynamicsModel,
-        discount_factor = 0.9995
+        discount_factor = 0.997
     ):
         super().__init__()
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index d11cc53..696af1c 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -8,20 +8,27 @@ import torch
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
 @param('num_spatial_tokens', (2, 8))
+@param('signal_and_step_passed_in', (False, True))
+@param('add_reward_embed_to_agent_token', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
     add_task_embeds,
-    num_spatial_tokens
+    num_spatial_tokens,
+    signal_and_step_passed_in,
+    add_reward_embed_to_agent_token
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
     tokenizer = VideoTokenizer(
-        512,
-        dim_latent = 32,
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        dim_latent = 16,
         patch_size = 32,
+        attn_dim_head = 16,
         num_latent_tokens = 4
     )
@@ -31,28 +38,34 @@ def test_e2e(
     assert loss.numel() == 1
 
     latents = tokenizer(video, return_latents = True)
-    assert latents.shape[-1] == 32
+    assert latents.shape[-1] == 16
 
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
 
     dynamics = DynamicsModel(
-        512,
+        dim = 16,
         video_tokenizer = tokenizer,
-        dim_latent = 32,
+        dim_latent = 16,
         max_steps = 64,
         num_tasks = 4,
         num_latent_tokens = 4,
+        depth = 4,
         num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
+        attn_dim_head = 16,
         attn_kwargs = dict(
             heads = heads,
-            query_heads = query_heads
+            query_heads = query_heads,
         ),
-        prob_no_shortcut_train = prob_no_shortcut_train
+        prob_no_shortcut_train = prob_no_shortcut_train,
+        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
     )
 
-    signal_levels = torch.randint(0, 500, (2, 4))
-    step_sizes_log2 = torch.randint(1, 6, (2,))
+    signal_levels = step_sizes_log2 = None
+
+    if signal_and_step_passed_in:
+        signal_levels = torch.randint(0, 500, (2, 4))
+        step_sizes_log2 = torch.randint(1, 6, (2,))
 
     if dynamics_with_video_input:
         dynamics_input = dict(video = video)
@@ -72,18 +85,37 @@ def test_e2e(
     assert flow_loss.numel() == 1
 
+    # rl
+
+    rewards = torch.randn((2, 4)) * 100.
+
+    flow_loss = dynamics(
+        **dynamics_input,
+        tasks = tasks,
+        rewards = rewards
+    )
+
 def test_symexp_two_hot():
     import torch
     from dreamer4.dreamer4 import SymExpTwoHot
 
-    two_hot_encoder = SymExpTwoHot((-3., 3.), 20)
+    two_hot_encoder = SymExpTwoHot(
+        (-3., 3.),
+        num_bins = 20,
+        learned_embedding = True,
+        dim_embed = 512
+    )
+
     values = torch.randn((10))
 
-    encoded = two_hot_encoder(values)
-    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
+    two_hot_encoded = two_hot_encoder(values)
+    recon_values = two_hot_encoder.bins_to_scalar_value(two_hot_encoded)
 
     assert torch.allclose(recon_values, values, atol = 1e-6)
 
+    reward_embeds = two_hot_encoder.embed(two_hot_encoded)
+    assert reward_embeds.shape == (10, 512)
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))
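
Usage note (not part of the patch): a minimal sketch of the reworked SymExpTwoHot interface added above, mirroring test_symexp_two_hot; the range, bin count and embedding dimension are illustrative.

import torch
from dreamer4.dreamer4 import SymExpTwoHot

# symexp-spaced bins with an optional learned embedding per bin
encoder = SymExpTwoHot((-3., 3.), num_bins = 20, learned_embedding = True, dim_embed = 512)

values = torch.randn(10)

two_hot = encoder(values)                        # (10, 20) two-hot weights over the symexp bins
scalars = encoder.bins_to_scalar_value(two_hot)  # expectation over bin values, recovers the inputs
embeds = encoder.embed(two_hot)                  # (10, 512) weighted sum of the two active bin embeddings

assert torch.allclose(scalars, values, atol = 1e-6)
assert embeds.shape == (10, 512)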
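
And a sketch of training the dynamics model with rewards, based on the shapes used in test_e2e. The hyperparameters are illustrative, the latents / tasks construction is assumed rather than taken from the test, and return_all_losses unpacks to (total_loss, (flow_loss, reward_loss)) per the new return in the patch.

import torch
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

tokenizer = VideoTokenizer(16, encoder_depth = 1, decoder_depth = 1, dim_latent = 16, patch_size = 32, attn_dim_head = 16, num_latent_tokens = 4)

dynamics = DynamicsModel(
    dim = 16,
    video_tokenizer = tokenizer,
    dim_latent = 16,
    max_steps = 64,
    num_tasks = 4,
    num_latent_tokens = 4,
    depth = 4,
    attn_dim_head = 16,
    add_reward_embed_to_agent_token = True  # agent tokens also receive the previous step's reward embedding
)

latents = torch.randn(2, 4, 4, 16)  # (batch, time, num latent tokens, dim latent) - assumed layout
tasks = torch.randint(0, 4, (2,))   # (batch,) task ids
rewards = torch.randn(2, 4) * 100.  # (batch, time) scalar rewards, symexp two-hot encoded internally

# rewards both condition the agent tokens (shifted one step) and supervise the new reward head
total_loss, (flow_loss, reward_loss) = dynamics(
    latents = latents,
    tasks = tasks,
    rewards = rewards,
    return_all_losses = True
)

total_loss.backward()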