take a gradient step with behavioral clone trainer, make sure it works with and without actions and rewards
This commit is contained in:
parent
283d59d75a
commit
2fc3b17149
@ -38,6 +38,7 @@ class VideoTokenizerTrainer(Module):
|
|||||||
optim_klass = MuonAdamAtan2,
|
optim_klass = MuonAdamAtan2,
|
||||||
batch_size = 16,
|
batch_size = 16,
|
||||||
learning_rate = 3e-4,
|
learning_rate = 3e-4,
|
||||||
|
max_grad_norm = None,
|
||||||
num_train_steps = 10_000,
|
num_train_steps = 10_000,
|
||||||
weight_decay = 0.,
|
weight_decay = 0.,
|
||||||
accelerate_kwargs: dict = dict(),
|
accelerate_kwargs: dict = dict(),
|
||||||
@ -45,6 +46,8 @@ class VideoTokenizerTrainer(Module):
|
|||||||
cpu = False,
|
cpu = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
batch_size = min(batch_size, len(dataset))
|
||||||
|
|
||||||
self.accelerator = Accelerator(
|
self.accelerator = Accelerator(
|
||||||
cpu = cpu,
|
cpu = cpu,
|
||||||
**accelerate_kwargs
|
**accelerate_kwargs
|
||||||
@ -73,6 +76,8 @@ class VideoTokenizerTrainer(Module):
|
|||||||
|
|
||||||
self.optim = optim
|
self.optim = optim
|
||||||
|
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
self.num_train_steps = num_train_steps
|
self.num_train_steps = num_train_steps
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
@ -104,6 +109,98 @@ class VideoTokenizerTrainer(Module):
|
|||||||
loss = self.model(video)
|
loss = self.model(video)
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
self.optim.step()
|
||||||
|
self.optim.zero_grad()
|
||||||
|
|
||||||
|
self.print('training complete')
|
||||||
|
|
||||||
|
# dynamics world model
|
||||||
|
|
||||||
|
class BehaviorCloneTrainer(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: DynamicsWorldModel,
|
||||||
|
dataset: Dataset,
|
||||||
|
optim_klass = MuonAdamAtan2,
|
||||||
|
batch_size = 16,
|
||||||
|
learning_rate = 3e-4,
|
||||||
|
max_grad_norm = None,
|
||||||
|
num_train_steps = 10_000,
|
||||||
|
weight_decay = 0.,
|
||||||
|
accelerate_kwargs: dict = dict(),
|
||||||
|
optim_kwargs: dict = dict(),
|
||||||
|
cpu = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
batch_size = min(batch_size, len(dataset))
|
||||||
|
|
||||||
|
self.accelerator = Accelerator(
|
||||||
|
cpu = cpu,
|
||||||
|
**accelerate_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.dataset = dataset
|
||||||
|
self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
|
||||||
|
|
||||||
|
optim_kwargs = dict(
|
||||||
|
lr = learning_rate,
|
||||||
|
weight_decay = weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if optim_klass is MuonAdamAtan2:
|
||||||
|
optim = MuonAdamAtan2(
|
||||||
|
model.muon_parameters(),
|
||||||
|
model.parameters(),
|
||||||
|
**optim_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim = optim_klass(
|
||||||
|
model.parameters(),
|
||||||
|
**optim_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optim = optim
|
||||||
|
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
self.num_train_steps = num_train_steps
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
(
|
||||||
|
self.model,
|
||||||
|
self.train_dataloader,
|
||||||
|
self.optim
|
||||||
|
) = self.accelerator.prepare(
|
||||||
|
self.model,
|
||||||
|
self.train_dataloader,
|
||||||
|
self.optim
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.accelerator.device
|
||||||
|
|
||||||
|
def print(self, *args, **kwargs):
|
||||||
|
return self.accelerator.print(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self
|
||||||
|
):
|
||||||
|
iter_train_dl = cycle(self.train_dataloader)
|
||||||
|
|
||||||
|
for _ in range(self.num_train_steps):
|
||||||
|
batch_data = next(iter_train_dl)
|
||||||
|
|
||||||
|
loss = self.model(**batch_data)
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
self.optim.step()
|
self.optim.step()
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.56"
|
version = "0.0.57"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -457,8 +457,8 @@ def test_tokenizer_trainer():
|
|||||||
|
|
||||||
tokenizer = VideoTokenizer(
|
tokenizer = VideoTokenizer(
|
||||||
16,
|
16,
|
||||||
encoder_depth = 1,
|
encoder_depth = 4,
|
||||||
decoder_depth = 1,
|
decoder_depth = 4,
|
||||||
dim_latent = 16,
|
dim_latent = 16,
|
||||||
patch_size = 32,
|
patch_size = 32,
|
||||||
attn_dim_head = 16,
|
attn_dim_head = 16,
|
||||||
@ -470,6 +470,73 @@ def test_tokenizer_trainer():
|
|||||||
dataset = dataset,
|
dataset = dataset,
|
||||||
num_train_steps = 1,
|
num_train_steps = 1,
|
||||||
batch_size = 1,
|
batch_size = 1,
|
||||||
|
cpu = True,
|
||||||
|
max_grad_norm = 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer()
|
||||||
|
|
||||||
|
@param('with_actions', (True, False))
|
||||||
|
@param('with_rewards', (True, False))
|
||||||
|
def test_bc_trainer(
|
||||||
|
with_actions,
|
||||||
|
with_rewards
|
||||||
|
):
|
||||||
|
from dreamer4.trainers import BehaviorCloneTrainer
|
||||||
|
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
class MockDataset(Dataset):
|
||||||
|
def __len__(self):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
state = torch.randn(3, 2, 64, 64)
|
||||||
|
|
||||||
|
pkg = dict(video = state)
|
||||||
|
|
||||||
|
if with_actions:
|
||||||
|
pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))
|
||||||
|
|
||||||
|
if with_rewards:
|
||||||
|
pkg.update(rewards = torch.randn(2,))
|
||||||
|
|
||||||
|
return pkg
|
||||||
|
|
||||||
|
dataset = MockDataset()
|
||||||
|
|
||||||
|
tokenizer = VideoTokenizer(
|
||||||
|
16,
|
||||||
|
encoder_depth = 4,
|
||||||
|
decoder_depth = 4,
|
||||||
|
dim_latent = 16,
|
||||||
|
patch_size = 32,
|
||||||
|
attn_dim_head = 16,
|
||||||
|
num_latent_tokens = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
model = DynamicsWorldModel(
|
||||||
|
video_tokenizer = tokenizer,
|
||||||
|
dim = 16,
|
||||||
|
dim_latent = 16,
|
||||||
|
max_steps = 64,
|
||||||
|
num_tasks = 4,
|
||||||
|
num_latent_tokens = 1,
|
||||||
|
depth = 4,
|
||||||
|
num_spatial_tokens = 1,
|
||||||
|
pred_orig_latent = True,
|
||||||
|
num_discrete_actions = 4,
|
||||||
|
attn_dim_head = 16,
|
||||||
|
prob_no_shortcut_train = 0.1,
|
||||||
|
num_residual_streams = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = BehaviorCloneTrainer(
|
||||||
|
model,
|
||||||
|
dataset = dataset,
|
||||||
|
batch_size = 1,
|
||||||
|
num_train_steps = 1,
|
||||||
cpu = True
|
cpu = True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user