tweak bc trainer
This commit is contained in:
parent
2fc3b17149
commit
40da985c6b
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import is_tensor
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
@ -195,7 +196,14 @@ class BehaviorCloneTrainer(Module):
|
|||||||
for _ in range(self.num_train_steps):
|
for _ in range(self.num_train_steps):
|
||||||
batch_data = next(iter_train_dl)
|
batch_data = next(iter_train_dl)
|
||||||
|
|
||||||
|
# just assume raw video dynamics training if batch_data is a tensor
|
||||||
|
# else kwargs for video, actions, rewards
|
||||||
|
|
||||||
|
if is_tensor(batch_data):
|
||||||
|
loss = self.model(batch_data)
|
||||||
|
else:
|
||||||
loss = self.model(**batch_data)
|
loss = self.model(**batch_data)
|
||||||
|
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.57"
|
version = "0.0.59"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user