diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 4112820..c6f66c0 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1220,6 +1220,22 @@ class DynamicsModel(Module):
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+    @property
+    def device(self):
+        return self.zero.device
+
+    def get_times_from_signal_level(
+        self,
+        signal_levels,
+        align_dims_left_to = None
+    ):
+        times = signal_levels.float() / self.max_steps
+
+        if not exists(align_dims_left_to):
+            return times
+
+        return align_dims_left(times, align_dims_left_to)
+
     def parameters(self):
         params = super().parameters()
 
@@ -1228,25 +1244,87 @@ class DynamicsModel(Module):
 
         return list(set(params) - set(self.video_tokenizer.parameters()))
 
+    @torch.no_grad()
     def generate(
         self,
-        num_frames,
+        time_steps,
         num_steps = 4,
+        batch_size = 1,
         image_height = None,
-        image_width = None
+        image_width = None,
+        return_decoded_video = None
     ): # (b t n d) | (b c t h w)
 
-        assert log(num_steps).is_integer(), f'number of steps must be a power of 2'
+        assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
+        assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
-        raise NotImplementedError
+        latent_shape = self.latent_shape
+
+        # derive step size
+
+        step_size = self.max_steps // num_steps
+        step_size_log2 = tensor(log2(step_size), dtype = torch.long, device = self.device)
+
+        # denoising
+        # teacher forcing to start with
+
+        latents = torch.empty((batch_size, 0, *latent_shape), device = self.device)
+
+        while latents.shape[1] < time_steps:
+
+            noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
+
+            for step in range(num_steps):
+                signal_level = tensor(step * step_size, device = self.device)
+
+                noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
+
+                pred = self.forward(
+                    latents = noised_latent_with_context,
+                    signal_levels = signal_level,
+                    step_sizes_log2 = step_size_log2,
+                    return_pred_only = True
+                )
+
+                _, pred = unpack(pred, pack_context_shape, 'b * n d')
+
+                # derive flow, based on whether in x-space or not
+
+                if self.pred_orig_latent:
+                    times = self.get_times_from_signal_level(signal_level, noised_latent)
+                    flow = (pred - noised_latent) / (1. - times)
+                else:
+                    flow = pred
+
+                # denoise
+
+                noised_latent += flow * (step_size / self.max_steps)
+
+            latents = cat((latents, noised_latent), dim = 1)
+
+        # returning video
+
+        has_tokenizer = exists(self.video_tokenizer)
+        return_decoded_video = default(return_decoded_video, has_tokenizer)
+
+        if not return_decoded_video:
+            return latents
+
+        generated_video = self.video_tokenizer.decode(
+            latents,
+            height = image_height,
+            width = image_width
+        )
+
+        return generated_video
 
     def forward(
         self,
         *,
         video = None,
         latents = None,           # (b t n d) | (b t d)
-        signal_levels = None,     # (b t)
-        step_sizes_log2 = None,   # (b)
+        signal_levels = None,     # () | (b) | (b t)
+        step_sizes_log2 = None,   # () | (b)
         tasks = None,             # (b)
         rewards = None,           # (b t)
         return_pred_only = False,
@@ -1270,6 +1348,18 @@ class DynamicsModel(Module):
 
         batch, time, device = *latents.shape[:2], latents.device
 
+        # shape related
+
+        if exists(signal_levels) and signal_levels.ndim == 0:
+            signal_levels = repeat(signal_levels, '-> b', b = batch)
+
+        if exists(step_sizes_log2):
+            if step_sizes_log2.ndim == 0:
+                step_sizes_log2 = repeat(step_sizes_log2, '-> b', b = batch)
+
+            if step_sizes_log2.ndim == 1:
+                step_sizes_log2 = repeat(step_sizes_log2, 'b -> b t', t = time)
+
         # flow related
 
         assert not (exists(signal_levels) ^ exists(step_sizes_log2))
@@ -1300,11 +1390,7 @@ class DynamicsModel(Module):
 
         # times is from 0 to 1
 
-        def get_times_from_signal_level(signal_levels):
-            times = signal_levels.float() / self.max_steps
-            return align_dims_left(times, latents)
-
-        times = get_times_from_signal_level(signal_levels)
+        times = self.get_times_from_signal_level(signal_levels, latents)
 
         # noise from 0 as noise to 1 as data
 
@@ -1468,7 +1554,7 @@ class DynamicsModel(Module):
             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = get_times_from_signal_level(signal_levels)
+                first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
                 first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
 
             # take a half step
@@ -1485,7 +1571,7 @@ class DynamicsModel(Module):
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = get_times_from_signal_level(signal_levels_plus_half_step)
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
                 second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
 
             # pred target is sg(b' + b'') / 2
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 888885c..05299b3 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -88,6 +88,17 @@ def test_e2e(
 
     assert flow_loss.numel() == 1
 
+    # generating
+
+    generated_video = dynamics.generate(
+        time_steps = 10,
+        image_height = 128,
+        image_width = 128,
+        batch_size = 2
+    )
+
+    assert generated_video.shape == (2, 3, 10, 128, 128)
+
     # rl
 
     rewards = torch.randn((2, 4)) * 100.
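
For reference, a minimal usage sketch of the new sampling entrypoint, mirroring the added test. It assumes `dynamics` is an already constructed `DynamicsModel` wired to a video tokenizer, as in `tests/test_dreamer.py`; the constructor arguments are not part of this diff.

    # assumes `dynamics` is a DynamicsModel with a video tokenizer attached

    # sample 10 frames, denoising each frame in 4 flow steps
    video = dynamics.generate(
        time_steps = 10,      # number of latent frames to roll out autoregressively
        num_steps = 4,        # denoising steps per frame, must be a power of 2
        batch_size = 2,
        image_height = 128,
        image_width = 128
    )

    # decoded video comes back as (batch, channels, time, height, width)
    assert video.shape == (2, 3, 10, 128, 128)

    # pass return_decoded_video = False to skip the tokenizer and get raw latents
    latents = dynamics.generate(time_steps = 10, batch_size = 2, return_decoded_video = False)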