diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c60ae67..4685539 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -71,6 +71,8 @@ except ImportError: LinearNoBias = partial(Linear, bias = False) +VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon')) + TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr')) WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred')) @@ -1940,11 +1942,23 @@ class VideoTokenizer(Module): def forward( self, - video, # (b c t h w) + video_or_image, # (b c t h w) | (b c h w) return_latents = False, mask_patches = None, - return_all_losses = False + return_intermediates = False, ): + + # handle image pretraining + + is_image = video_or_image.ndim == 4 + + if is_image: + video = rearrange(video_or_image, 'b c h w -> b c 1 h w') + else: + video = video_or_image + + # shapes + batch, _, time, height, width = video.shape patch_size, device = self.patch_size, video.device @@ -2024,12 +2038,21 @@ class VideoTokenizer(Module): space_decorr_loss * self.decorr_aux_loss_weight ) - if not return_all_losses: + if not return_intermediates: return total_loss - losses = (recon_loss, lpips_loss, decorr_loss) + losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss) - return total_loss, TokenizerLosses(*losses) + out = losses + + # handle returning of reconstructed, and image pretraining + + if is_image: + recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w') + + out = (losses, recon_video) + + return total_loss, out # dynamics model, axial space-time transformer diff --git a/pyproject.toml b/pyproject.toml index 5dca4f0..7dd874a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.23" +version = "0.1.24" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 06c2bec..c0b1579 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -810,3 +810,22 @@ def test_epo(): fitness = torch.randn(16,) dynamics.evolve_(fitness) + +def test_images_to_video_tokenizer(): + import torch + from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer + + tokenizer = VideoTokenizer( + dim = 512, + dim_latent = 32, + patch_size = 32, + image_height = 256, + image_width = 256, + encoder_add_decor_aux_loss = True + ) + + images = torch.randn(2, 3, 256, 256) + loss, (losses, recon_images) = tokenizer(images, return_intermediates = True) + loss.backward() + + assert images.shape == recon_images.shape