diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index e83eb97..2fab2bf 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -70,12 +70,14 @@ except ImportError:

 LinearNoBias = partial(Linear, bias = False)

-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))

+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
 MaybeTensor = Tensor | None

 @dataclass
@@ -1540,6 +1542,7 @@ class AxialSpaceTimeTransformer(Module):
         # normed attention inputs

         normed_time_attn_inputs = []
+        normed_space_attn_inputs = []

         # attention

@@ -1582,8 +1585,9 @@ class AxialSpaceTimeTransformer(Module):

             # save time attention inputs for decorr

-            if layer_is_time:
-                normed_time_attn_inputs.append(attn_intermediates.normed_inputs)
+            space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+            space_or_time_inputs.append(attn_intermediates.normed_inputs)

         tokens = self.reduce_streams(tokens)

@@ -1596,9 +1600,10 @@ class AxialSpaceTimeTransformer(Module):
         if not return_intermediates:
             return out

-        intermediates = (
+        intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
-            stack(normed_time_attn_inputs)
+            stack(normed_time_attn_inputs),
+            stack(normed_space_attn_inputs)
         )

         return out, intermediates
@@ -1851,7 +1856,7 @@ class VideoTokenizer(Module):

         # encoder attention

-        tokens, (_, time_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)

         # latent bottleneck

@@ -1873,17 +1878,19 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             lpips_loss = self.lpips(video, recon_video)

-        decorr_loss = self.zero
+        time_decorr_loss = space_decorr_loss = self.zero

         if self.encoder_add_decor_aux_loss:
-            decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+            time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+            space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)

         # losses

         total_loss = (
             recon_loss +
             lpips_loss * self.lpips_loss_weight +
-            decorr_loss * self.decorr_aux_loss_weight
+            time_decorr_loss * self.decorr_aux_loss_weight +
+            space_decorr_loss * self.decorr_aux_loss_weight
         )

         if not return_all_losses:
@@ -3370,7 +3377,7 @@ class DynamicsWorldModel(Module):

         # attention

-        tokens, (next_time_kv_cache, _) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
+        tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

         # unpack

diff --git a/pyproject.toml b/pyproject.toml
index fc3047b..4e1beec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.8"
+version = "0.1.10"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
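
Note on the decorrelation auxiliary loss extended above: the tokenizer now applies `self.decorr_loss` to the normed inputs of both the time-attention and the space-attention layers, with both terms scaled by the same `decorr_aux_loss_weight`. As an illustration only (the repo's actual `decorr_loss` may be formulated differently), a common version of such a penalty decorrelates features by penalizing the off-diagonal entries of their empirical covariance:

    import torch
    from torch import Tensor

    # illustrative sketch only -- not the repo's actual decorr_loss
    def decorr_loss_sketch(normed_inputs: Tensor) -> Tensor:
        # normed_inputs: (..., n, d); flatten all leading dims into samples
        feats = normed_inputs.reshape(-1, normed_inputs.shape[-1])

        # center and scale so the covariance diagonal is ~1
        feats = feats - feats.mean(dim = 0)
        feats = feats / feats.std(dim = 0).clamp(min = 1e-5)

        # empirical covariance across the sample dimension
        cov = (feats.t() @ feats) / (feats.shape[0] - 1)

        # penalize only the off-diagonal covariance, pushing features
        # toward mutual decorrelation
        off_diag = cov - torch.diag(cov.diagonal())
        return off_diag.pow(2).mean()

Under that reading, the two new terms in `total_loss` correspond to `decorr_loss_sketch(time_attn_normed_inputs)` and `decorr_loss_sketch(space_attn_normed_inputs)`, each weighted by the shared `decorr_aux_loss_weight`.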