take a gradient step with video tokenizer trainer

lucidrains 2025-10-21 08:52:22 -07:00
parent 15876d34cf
commit ea13d4fcab
3 changed files with 128 additions and 4 deletions


@@ -1,5 +1,8 @@
from __future__ import annotations

import torch
from torch.nn import Module
from torch.utils.data import Dataset, DataLoader

from accelerate import Accelerator
@@ -7,13 +10,100 @@ from adam_atan2_pytorch import MuonAdamAtan2

from dreamer4.dreamer4 import (
    VideoTokenizer,
-    DynamicsModel
+    DynamicsWorldModel
)

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# trainers

class VideoTokenizerTrainer(Module):
    def __init__(
        self,
-        model: VideoTokenizer
+        model: VideoTokenizer,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()
-        raise NotImplementedError

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset

        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        return self.accelerator.device

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            video = next(iter_train_dl)

            loss = self.model(video)
            self.accelerator.backward(loss)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')


@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
-version = "0.0.53"
+version = "0.0.54"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -439,3 +439,37 @@ def test_loss_normalizer():

    normed_losses = loss_normalizer(losses)
    assert (normed_losses == 1.).all()

def test_tokenizer_trainer():
    from dreamer4.trainers import VideoTokenizerTrainer
    from dreamer4.dreamer4 import VideoTokenizer
    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return torch.randn(3, 16, 256, 256)

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4
    )

    trainer = VideoTokenizerTrainer(
        tokenizer,
        dataset = dataset,
        num_train_steps = 1,
        batch_size = 2,
        cpu = True
    )

    trainer()
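
For reference, a minimal usage sketch of the new trainer outside the test harness. It is not part of the commit: the VideoTokenizer hyperparameters simply mirror the mock test above, and MyVideos is a hypothetical stand-in for a dataset that yields (channels, frames, height, width) video tensors.

import torch
from torch.utils.data import Dataset

from dreamer4.dreamer4 import VideoTokenizer
from dreamer4.trainers import VideoTokenizerTrainer

# hypothetical dataset - swap in one that loads real clips as (channels, frames, height, width) tensors
class MyVideos(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.randn(3, 16, 256, 256)

# tokenizer settings copied from the mock test above
tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 4
)

# the trainer sets up the dataloader, the MuonAdamAtan2 optimizer by default, and accelerate preparation
trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset = MyVideos(),
    batch_size = 2,
    num_train_steps = 10,
    cpu = True
)

trainer()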