take a gradient step with behavioral clone trainer, make sure it works with and without actions and rewards

lucidrains 2025-10-21 10:20:08 -07:00
parent 283d59d75a
commit 2fc3b17149
3 changed files with 167 additions and 3 deletions
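A note on why the new trainer works with and without actions and rewards: each dataset item is a plain dict, the default DataLoader collation preserves the dict keys, and the trainer unpacks the batch straight into the world model's forward, so optional keys are simply never passed when the dataset omits them. Below is a minimal self-contained sketch of that dispatch, using a hypothetical StandInModel and DictDataset in place of the real DynamicsWorldModel and the mock dataset in the new test.

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# stand-in model, only to illustrate the optional-keyword dispatch (not the real DynamicsWorldModel)

class StandInModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, video, discrete_actions = None, rewards = None):
        loss = self.scale * video.mean()

        if discrete_actions is not None:
            loss = loss + discrete_actions.float().mean()

        if rewards is not None:
            loss = loss + rewards.mean()

        return loss

# stand-in dataset yielding dicts with a required video key and optional action / reward keys

class DictDataset(Dataset):
    def __init__(self, with_actions, with_rewards):
        self.with_actions = with_actions
        self.with_rewards = with_rewards

    def __len__(self):
        return 2

    def __getitem__(self, idx):
        pkg = dict(video = torch.randn(3, 2, 64, 64))

        if self.with_actions:
            pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))

        if self.with_rewards:
            pkg.update(rewards = torch.randn(2,))

        return pkg

model = StandInModel()
optim = torch.optim.SGD(model.parameters(), lr = 1e-3)

dl = DataLoader(DictDataset(with_actions = True, with_rewards = False), batch_size = 1)

batch = next(iter(dl))   # default collate preserves the dict structure
loss = model(**batch)    # keys absent from the dataset are simply never passed

loss.backward()          # one gradient step, as in the commit message
optim.step()
optim.zero_grad()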

View File

@@ -38,6 +38,7 @@ class VideoTokenizerTrainer(Module):
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
@@ -45,6 +46,8 @@ class VideoTokenizerTrainer(Module):
        cpu = False,
    ):
        super().__init__()

        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
@@ -73,6 +76,8 @@ class VideoTokenizerTrainer(Module):
        self.optim = optim

        self.max_grad_norm = max_grad_norm
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
@@ -104,6 +109,98 @@ class VideoTokenizerTrainer(Module):
            loss = self.model(video)

            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')
# dynamics world model

class BehaviorCloneTrainer(Module):
    def __init__(
        self,
        model: DynamicsWorldModel,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset

        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.max_grad_norm = max_grad_norm
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        return self.accelerator.device

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            batch_data = next(iter_train_dl)

            loss = self.model(**batch_data)

            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()
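The training loop above leans on two small helpers, exists and cycle, that fall outside this hunk; they are assumed to match the usual utilities defined near the top of trainers.py, roughly:

# assumed helper sketches, not part of this diff

def exists(v):
    # a value counts as provided when it is not None
    return v is not None

def cycle(dl):
    # turn a finite dataloader into an endless stream of batches
    while True:
        for batch in dl:
            yield batch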

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.56"
version = "0.0.57"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -457,8 +457,8 @@ def test_tokenizer_trainer():
    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        encoder_depth = 4,
        decoder_depth = 4,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
@@ -470,6 +470,73 @@ def test_tokenizer_trainer():
        dataset = dataset,
        num_train_steps = 1,
        batch_size = 1,
        cpu = True,
        max_grad_norm = 0.5
    )

    trainer()
@param('with_actions', (True, False))
@param('with_rewards', (True, False))
def test_bc_trainer(
    with_actions,
    with_rewards
):
    from dreamer4.trainers import BehaviorCloneTrainer
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            state = torch.randn(3, 2, 64, 64)

            pkg = dict(video = state)

            if with_actions:
                pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))

            if with_rewards:
                pkg.update(rewards = torch.randn(2,))

            return pkg

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 4,
        decoder_depth = 4,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    model = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 4,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    trainer = BehaviorCloneTrainer(
        model,
        dataset = dataset,
        batch_size = 1,
        num_train_steps = 1,
        cpu = True
    )