tweak bc trainer
This commit is contained in:
parent
2fc3b17149
commit
e280592510
@ -195,7 +195,14 @@ class BehaviorCloneTrainer(Module):
|
|||||||
for _ in range(self.num_train_steps):
|
for _ in range(self.num_train_steps):
|
||||||
batch_data = next(iter_train_dl)
|
batch_data = next(iter_train_dl)
|
||||||
|
|
||||||
loss = self.model(**batch_data)
|
# just assume raw video dynamics training if batch_data is a tensor
|
||||||
|
# else kwargs for video, actions, rewards
|
||||||
|
|
||||||
|
if is_tensor(batch_data):
|
||||||
|
loss = self.model(batch_data)
|
||||||
|
else:
|
||||||
|
loss = self.model(**batch_data)
|
||||||
|
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.57"
|
version = "0.0.58"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user