take a gradient step with video tokenizer trainer
parent 15876d34cf
commit ea13d4fcab
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+import torch
 from torch.nn import Module
+from torch.utils.data import Dataset, DataLoader
 
+from accelerate import Accelerator
 
@@ -7,13 +10,100 @@ from adam_atan2_pytorch import MuonAdamAtan2
 
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsModel
+    DynamicsWorldModel
 )
 
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def cycle(dl):
+    while True:
+        for batch in dl:
+            yield batch
+
+# trainers
+
 class VideoTokenizerTrainer(Module):
     def __init__(
         self,
-        model: VideoTokenizer
+        model: VideoTokenizer,
+        dataset: Dataset,
+        optim_klass = MuonAdamAtan2,
+        batch_size = 16,
+        learning_rate = 3e-4,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
     ):
         super().__init__()
-        raise NotImplementedError
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+        self.dataset = dataset
+        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        if optim_klass is MuonAdamAtan2:
+            optim = MuonAdamAtan2(
+                model.muon_parameters(),
+                model.parameters(),
+                **optim_kwargs
+            )
+        else:
+            optim = optim_klass(
+                model.parameters(),
+                **optim_kwargs
+            )
+
+        self.optim = optim
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+
+        (
+            self.model,
+            self.train_dataloader,
+            self.optim
+        ) = self.accelerator.prepare(
+            self.model,
+            self.train_dataloader,
+            self.optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+
+        iter_train_dl = cycle(self.train_dataloader)
+
+        for _ in range(self.num_train_steps):
+            video = next(iter_train_dl)
+
+            loss = self.model(video)
+            self.accelerator.backward(loss)
+
+            self.optim.step()
+            self.optim.zero_grad()
+
+        self.print('training complete')
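Note on the trainer's forward above: with no distributed setup or mixed precision configured, accelerate's backward reduces to a plain loss.backward(), so the loop is roughly equivalent to the following single-device sketch. AdamW stands in for the default MuonAdamAtan2, and model(video) is assumed to return a scalar loss, as the trainer itself assumes.

import torch
from torch.utils.data import DataLoader

# rough single-device equivalent of VideoTokenizerTrainer.forward,
# assuming model(video) returns a scalar training loss.
# AdamW is a stand-in here for the default MuonAdamAtan2

def train(model, dataset, num_train_steps = 10_000, batch_size = 16, learning_rate = 3e-4):
    dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
    optim = torch.optim.AdamW(model.parameters(), lr = learning_rate)

    def cycle(dl): # same helper as in the trainer
        while True:
            for batch in dl:
                yield batch

    iter_train_dl = cycle(dataloader)

    for _ in range(num_train_steps):
        video = next(iter_train_dl)

        loss = model(video)
        loss.backward() # accelerator.backward(loss) in the trainer

        optim.step()
        optim.zero_grad()

    print('training complete')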
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.53"
+version = "0.0.54"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -439,3 +439,37 @@ def test_loss_normalizer():
     normed_losses = loss_normalizer(losses)
 
     assert (normed_losses == 1.).all()
+
+def test_tokenizer_trainer():
+    from dreamer4.trainers import VideoTokenizerTrainer
+    from dreamer4.dreamer4 import VideoTokenizer
+    from torch.utils.data import Dataset
+
+    class MockDataset(Dataset):
+        def __len__(self):
+            return 4
+
+        def __getitem__(self, idx):
+            return torch.randn(3, 16, 256, 256)
+
+    dataset = MockDataset()
+
+    tokenizer = VideoTokenizer(
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        dim_latent = 16,
+        patch_size = 32,
+        attn_dim_head = 16,
+        num_latent_tokens = 4
+    )
+
+    trainer = VideoTokenizerTrainer(
+        tokenizer,
+        dataset = dataset,
+        num_train_steps = 1,
+        batch_size = 2,
+        cpu = True
+    )
+
+    trainer()
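The test only exercises the default MuonAdamAtan2 path. A sketch of the alternate branch, reusing the mock setup above but passing torch.optim.AdamW through optim_klass (AdamW is an illustrative choice, not part of this commit):

import torch
from torch.optim import AdamW
from torch.utils.data import Dataset

from dreamer4.dreamer4 import VideoTokenizer
from dreamer4.trainers import VideoTokenizerTrainer

class MockDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.randn(3, 16, 256, 256) # presumably (channels, time, height, width), as in the test

tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 4
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset = MockDataset(),
    optim_klass = AdamW, # takes the else-branch: optim_klass(model.parameters(), lr = ..., weight_decay = ...)
    num_train_steps = 1,
    batch_size = 2,
    cpu = True
)

trainer()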