take a gradient step with video tokenizer trainer
This commit is contained in:
parent
15876d34cf
commit
ea13d4fcab
@ -1,5 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
@ -7,13 +10,100 @@ from adam_atan2_pytorch import MuonAdamAtan2
|
|||||||
|
|
||||||
from dreamer4.dreamer4 import (
|
from dreamer4.dreamer4 import (
|
||||||
VideoTokenizer,
|
VideoTokenizer,
|
||||||
DynamicsModel
|
DynamicsWorldModel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
def exists(v):
|
||||||
|
return v is not None
|
||||||
|
|
||||||
|
def default(v, d):
|
||||||
|
return v if exists(v) else d
|
||||||
|
|
||||||
|
def cycle(dl):
|
||||||
|
while True:
|
||||||
|
for batch in dl:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
# trainers
|
||||||
|
|
||||||
class VideoTokenizerTrainer(Module):
|
class VideoTokenizerTrainer(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: VideoTokenizer
|
model: VideoTokenizer,
|
||||||
|
dataset: Dataset,
|
||||||
|
optim_klass = MuonAdamAtan2,
|
||||||
|
batch_size = 16,
|
||||||
|
learning_rate = 3e-4,
|
||||||
|
num_train_steps = 10_000,
|
||||||
|
weight_decay = 0.,
|
||||||
|
accelerate_kwargs: dict = dict(),
|
||||||
|
optim_kwargs: dict = dict(),
|
||||||
|
cpu = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
raise NotImplementedError
|
self.accelerator = Accelerator(
|
||||||
|
cpu = cpu,
|
||||||
|
**accelerate_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.dataset = dataset
|
||||||
|
self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
|
||||||
|
|
||||||
|
optim_kwargs = dict(
|
||||||
|
lr = learning_rate,
|
||||||
|
weight_decay = weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if optim_klass is MuonAdamAtan2:
|
||||||
|
optim = MuonAdamAtan2(
|
||||||
|
model.muon_parameters(),
|
||||||
|
model.parameters(),
|
||||||
|
**optim_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim = optim_klass(
|
||||||
|
model.parameters(),
|
||||||
|
**optim_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optim = optim
|
||||||
|
|
||||||
|
self.num_train_steps = num_train_steps
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
(
|
||||||
|
self.model,
|
||||||
|
self.train_dataloader,
|
||||||
|
self.optim
|
||||||
|
) = self.accelerator.prepare(
|
||||||
|
self.model,
|
||||||
|
self.train_dataloader,
|
||||||
|
self.optim
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.accelerator.device
|
||||||
|
|
||||||
|
def print(self, *args, **kwargs):
|
||||||
|
return self.accelerator.print(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self
|
||||||
|
):
|
||||||
|
|
||||||
|
iter_train_dl = cycle(self.train_dataloader)
|
||||||
|
|
||||||
|
for _ in range(self.num_train_steps):
|
||||||
|
video = next(iter_train_dl)
|
||||||
|
|
||||||
|
loss = self.model(video)
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
self.optim.step()
|
||||||
|
self.optim.zero_grad()
|
||||||
|
|
||||||
|
self.print('training complete')
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.53"
|
version = "0.0.54"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -439,3 +439,37 @@ def test_loss_normalizer():
|
|||||||
normed_losses = loss_normalizer(losses)
|
normed_losses = loss_normalizer(losses)
|
||||||
|
|
||||||
assert (normed_losses == 1.).all()
|
assert (normed_losses == 1.).all()
|
||||||
|
|
||||||
|
def test_tokenizer_trainer():
|
||||||
|
from dreamer4.trainers import VideoTokenizerTrainer
|
||||||
|
from dreamer4.dreamer4 import VideoTokenizer
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
class MockDataset(Dataset):
|
||||||
|
def __len__(self):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return torch.randn(3, 16, 256, 256)
|
||||||
|
|
||||||
|
dataset = MockDataset()
|
||||||
|
|
||||||
|
tokenizer = VideoTokenizer(
|
||||||
|
16,
|
||||||
|
encoder_depth = 1,
|
||||||
|
decoder_depth = 1,
|
||||||
|
dim_latent = 16,
|
||||||
|
patch_size = 32,
|
||||||
|
attn_dim_head = 16,
|
||||||
|
num_latent_tokens = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = VideoTokenizerTrainer(
|
||||||
|
tokenizer,
|
||||||
|
dataset = dataset,
|
||||||
|
num_train_steps = 1,
|
||||||
|
batch_size = 2,
|
||||||
|
cpu = True
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user