diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 64c62f0..f0dc400 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -35,6 +35,7 @@ from assoc_scan import AssocScan # d - feature dimension # f - frequencies (rotary) # l - logit / predicted bins +# y - layers of transformer # p - positions (3 for spacetime in this work) # t - time # na - action dimension (number of discrete and continuous actions) @@ -1266,8 +1267,12 @@ class AxialSpaceTimeTransformer(Module): def forward( self, - tokens # (b t s d) - ): + tokens, # (b t s d) + kv_cache: Tensor | None = None, # (y 2 b h t d) + return_kv_cache = False + + ): # (b t s d) | (y 2 b h t d) + batch, time, space_seq_len, _, device = *tokens.shape, tokens.device assert tokens.ndim == 4 @@ -1286,6 +1291,13 @@ class AxialSpaceTimeTransformer(Module): rotary_pos_emb = self.time_rotary(time) + # prepare cache + + time_attn_kv_caches = [] + + kv_cache = default(kv_cache, (None,)) + iter_kv_cache = iter(kv_cache) + # attention tokens = self.expand_streams(tokens) @@ -1302,17 +1314,27 @@ class AxialSpaceTimeTransformer(Module): # attention layer - tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens + tokens, kv_cache = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn, kv_cache = next(iter_kv_cache, None), return_kv_cache = True) tokens = post_attn_rearrange(tokens) # feedforward layer - tokens = ff(tokens) + tokens + tokens = ff(tokens) + + # save kv cache if is time layer + + if layer_is_time: + time_attn_kv_caches.append(kv_cache) tokens = self.reduce_streams(tokens) - return self.final_norm(tokens) + out = self.final_norm(tokens) + + if not return_kv_cache: + return out + + return out, stack(time_attn_kv_caches) # video tokenizer @@ -2347,7 +2369,7 @@ class DynamicsWorldModel(Module): # main function, needs to be defined as such for shortcut training - additional calls for consistency loss - def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False): + def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False): # latents to spatial tokens space_tokens = self.latents_to_spatial_tokens(noised_latents) @@ -2385,7 +2407,7 @@ class DynamicsWorldModel(Module): # attention - tokens = self.transformer(tokens) + tokens, time_kv_cache = self.transformer(tokens, return_kv_cache = True) # unpack diff --git a/pyproject.toml b/pyproject.toml index 38418a4..77dd976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.43" +version = "0.0.44" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }