tweak bc trainer
This commit is contained in:
parent
2fc3b17149
commit
e280592510
@ -195,7 +195,14 @@ class BehaviorCloneTrainer(Module):
|
||||
for _ in range(self.num_train_steps):
|
||||
batch_data = next(iter_train_dl)
|
||||
|
||||
loss = self.model(**batch_data)
|
||||
# just assume raw video dynamics training if batch_data is a tensor
|
||||
# else kwargs for video, actions, rewards
|
||||
|
||||
if is_tensor(batch_data):
|
||||
loss = self.model(batch_data)
|
||||
else:
|
||||
loss = self.model(**batch_data)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.57"
|
||||
version = "0.0.58"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user