tweak bc trainer
This commit is contained in:
parent
2fc3b17149
commit
40da985c6b
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import is_tensor
|
||||
from torch.nn import Module
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
@ -195,7 +196,14 @@ class BehaviorCloneTrainer(Module):
|
||||
for _ in range(self.num_train_steps):
|
||||
batch_data = next(iter_train_dl)
|
||||
|
||||
loss = self.model(**batch_data)
|
||||
# just assume raw video dynamics training if batch_data is a tensor
|
||||
# else kwargs for video, actions, rewards
|
||||
|
||||
if is_tensor(batch_data):
|
||||
loss = self.model(batch_data)
|
||||
else:
|
||||
loss = self.model(**batch_data)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.57"
|
||||
version = "0.0.59"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user