From 2fc3b17149cbc17802de6679c3b9e1e3b07dd307 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 10:20:08 -0700 Subject: [PATCH] take a gradient step with behavioral clone trainer, make sure it works with and without actions and rewards --- dreamer4/trainers.py | 97 +++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- tests/test_dreamer.py | 71 ++++++++++++++++++++++++++++++- 3 files changed, 167 insertions(+), 3 deletions(-) diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index ce72f92..53d572d 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -38,6 +38,7 @@ class VideoTokenizerTrainer(Module): optim_klass = MuonAdamAtan2, batch_size = 16, learning_rate = 3e-4, + max_grad_norm = None, num_train_steps = 10_000, weight_decay = 0., accelerate_kwargs: dict = dict(), @@ -45,6 +46,8 @@ class VideoTokenizerTrainer(Module): cpu = False, ): super().__init__() + batch_size = min(batch_size, len(dataset)) + self.accelerator = Accelerator( cpu = cpu, **accelerate_kwargs @@ -73,6 +76,8 @@ class VideoTokenizerTrainer(Module): self.optim = optim + self.max_grad_norm = max_grad_norm + self.num_train_steps = num_train_steps self.batch_size = batch_size @@ -104,6 +109,98 @@ class VideoTokenizerTrainer(Module): loss = self.model(video) self.accelerator.backward(loss) + if exists(self.max_grad_norm): + self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + + self.optim.step() + self.optim.zero_grad() + + self.print('training complete') + +# dynamics world model + +class BehaviorCloneTrainer(Module): + def __init__( + self, + model: DynamicsWorldModel, + dataset: Dataset, + optim_klass = MuonAdamAtan2, + batch_size = 16, + learning_rate = 3e-4, + max_grad_norm = None, + num_train_steps = 10_000, + weight_decay = 0., + accelerate_kwargs: dict = dict(), + optim_kwargs: dict = dict(), + cpu = False, + ): + super().__init__() + batch_size = min(batch_size, len(dataset)) + + self.accelerator = Accelerator( + cpu = cpu, + **accelerate_kwargs + ) + + self.model = model + self.dataset = dataset + self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True) + + optim_kwargs = dict( + lr = learning_rate, + weight_decay = weight_decay + ) + + if optim_klass is MuonAdamAtan2: + optim = MuonAdamAtan2( + model.muon_parameters(), + model.parameters(), + **optim_kwargs + ) + else: + optim = optim_klass( + model.parameters(), + **optim_kwargs + ) + + self.optim = optim + + self.max_grad_norm = max_grad_norm + + self.num_train_steps = num_train_steps + self.batch_size = batch_size + + ( + self.model, + self.train_dataloader, + self.optim + ) = self.accelerator.prepare( + self.model, + self.train_dataloader, + self.optim + ) + + @property + def device(self): + return self.accelerator.device + + def print(self, *args, **kwargs): + return self.accelerator.print(*args, **kwargs) + + def forward( + self + ): + iter_train_dl = cycle(self.train_dataloader) + + for _ in range(self.num_train_steps): + batch_data = next(iter_train_dl) + + loss = self.model(**batch_data) + self.accelerator.backward(loss) + + if exists(self.max_grad_norm): + self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optim.step() self.optim.zero_grad() diff --git a/pyproject.toml b/pyproject.toml index bab94eb..5641a50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.56" +version = "0.0.57" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 94a395d..3acf1e5 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -457,8 +457,8 @@ def test_tokenizer_trainer(): tokenizer = VideoTokenizer( 16, - encoder_depth = 1, - decoder_depth = 1, + encoder_depth = 4, + decoder_depth = 4, dim_latent = 16, patch_size = 32, attn_dim_head = 16, @@ -470,6 +470,73 @@ def test_tokenizer_trainer(): dataset = dataset, num_train_steps = 1, batch_size = 1, + cpu = True, + max_grad_norm = 0.5 + ) + + trainer() + +@param('with_actions', (True, False)) +@param('with_rewards', (True, False)) +def test_bc_trainer( + with_actions, + with_rewards +): + from dreamer4.trainers import BehaviorCloneTrainer + from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer + + from torch.utils.data import Dataset + + class MockDataset(Dataset): + def __len__(self): + return 2 + + def __getitem__(self, idx): + state = torch.randn(3, 2, 64, 64) + + pkg = dict(video = state) + + if with_actions: + pkg.update(discrete_actions = torch.randint(0, 4, (2, 1))) + + if with_rewards: + pkg.update(rewards = torch.randn(2,)) + + return pkg + + dataset = MockDataset() + + tokenizer = VideoTokenizer( + 16, + encoder_depth = 4, + decoder_depth = 4, + dim_latent = 16, + patch_size = 32, + attn_dim_head = 16, + num_latent_tokens = 1 + ) + + model = DynamicsWorldModel( + video_tokenizer = tokenizer, + dim = 16, + dim_latent = 16, + max_steps = 64, + num_tasks = 4, + num_latent_tokens = 1, + depth = 4, + num_spatial_tokens = 1, + pred_orig_latent = True, + num_discrete_actions = 4, + attn_dim_head = 16, + prob_no_shortcut_train = 0.1, + num_residual_streams = 1 + ) + + trainer = BehaviorCloneTrainer( + model, + dataset = dataset, + batch_size = 1, + num_train_steps = 1, cpu = True )