From ea13d4fcabb91504cbf9071791155a3e59bab4e8 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 21 Oct 2025 08:52:22 -0700
Subject: [PATCH] take a gradient step with video tokenizer trainer

---
 dreamer4/trainers.py  | 97 +++++++++++++++++++++++++++++++++++++++++---
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 34 ++++++++++++++++
 3 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py
index 755da23..b578f44 100644
--- a/dreamer4/trainers.py
+++ b/dreamer4/trainers.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import torch
 from torch.nn import Module
+from torch.utils.data import Dataset, DataLoader
 
 from accelerate import Accelerator
 
@@ -7,13 +10,101 @@ from adam_atan2_pytorch import MuonAdamAtan2
 
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel
+    DynamicsWorldModel
 )
 
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def cycle(dl):
+    while True:
+        for batch in dl:
+            yield batch
+
+# trainers
+
 class VideoTokenizerTrainer(Module):
     def __init__(
         self,
-        model: VideoTokenizer
+        model: VideoTokenizer,
+        dataset: Dataset,
+        optim_klass = MuonAdamAtan2,
+        batch_size = 16,
+        learning_rate = 3e-4,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
     ):
         super().__init__()
-        raise NotImplementedError
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+        self.dataset = dataset
+        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay,
+            **optim_kwargs
+        )
+
+        if optim_klass is MuonAdamAtan2:
+            optim = MuonAdamAtan2(
+                model.muon_parameters(),
+                model.parameters(),
+                **optim_kwargs
+            )
+        else:
+            optim = optim_klass(
+                model.parameters(),
+                **optim_kwargs
+            )
+
+        self.optim = optim
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+
+        (
+            self.model,
+            self.train_dataloader,
+            self.optim
+        ) = self.accelerator.prepare(
+            self.model,
+            self.train_dataloader,
+            self.optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+
+        iter_train_dl = cycle(self.train_dataloader)
+
+        for _ in range(self.num_train_steps):
+            video = next(iter_train_dl)
+
+            loss = self.model(video)
+            self.accelerator.backward(loss)
+
+            self.optim.step()
+            self.optim.zero_grad()
+
+        self.print('training complete')

diff --git a/pyproject.toml b/pyproject.toml
index 307a4b4..3965716 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.53"
+version = "0.0.54"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 44b74bf..0498123 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -439,3 +439,37 @@ def test_loss_normalizer():
     normed_losses = loss_normalizer(losses)
 
     assert (normed_losses == 1.).all()
+
+def test_tokenizer_trainer():
+    from dreamer4.trainers import VideoTokenizerTrainer
+    from dreamer4.dreamer4 import VideoTokenizer
+    from torch.utils.data import Dataset
+
+    class MockDataset(Dataset):
+        def __len__(self):
+            return 4
+
+        def __getitem__(self, idx):
+            return torch.randn(3, 16, 256, 256)
+
+    dataset = MockDataset()
+
+    tokenizer = VideoTokenizer(
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        dim_latent = 16,
+        patch_size = 32,
+        attn_dim_head = 16,
+        num_latent_tokens = 4
+    )
+
+    trainer = VideoTokenizerTrainer(
+        tokenizer,
+        dataset = dataset,
+        num_train_steps = 1,
+        batch_size = 2,
+        cpu = True
+    )
+
+    trainer()
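
A minimal usage sketch of the VideoTokenizerTrainer introduced above, mirroring the hyperparameters from test_tokenizer_trainer. RandomVideos is a hypothetical stand-in for a real dataset; the only assumption is that each item is a (channels, frames, height, width) video tensor, matching the test's MockDataset.

import torch
from torch.utils.data import Dataset

from dreamer4.dreamer4 import VideoTokenizer
from dreamer4.trainers import VideoTokenizerTrainer

class RandomVideos(Dataset):
    # hypothetical dataset of random videos, shaped like the test fixture
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        # (channels, frames, height, width)
        return torch.randn(3, 16, 256, 256)

tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 4
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset = RandomVideos(),
    batch_size = 2,
    num_train_steps = 100,
    cpu = True # set to False to let accelerate place training on available gpus
)

trainer() # takes 100 gradient steps on the tokenizer, then prints 'training complete'

One caveat: the trainer builds its dataloader with drop_last = True, so the dataset must hold at least batch_size items, or the cycled dataloader will never yield a batch.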