diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index 53d572d..8b59f98 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +from torch import is_tensor from torch.nn import Module from torch.utils.data import Dataset, DataLoader @@ -195,7 +196,14 @@ class BehaviorCloneTrainer(Module): for _ in range(self.num_train_steps): batch_data = next(iter_train_dl) - loss = self.model(**batch_data) + # just assume raw video dynamics training if batch_data is a tensor + # else kwargs for video, actions, rewards + + if is_tensor(batch_data): + loss = self.model(batch_data) + else: + loss = self.model(**batch_data) + self.accelerator.backward(loss) if exists(self.max_grad_norm): diff --git a/pyproject.toml b/pyproject.toml index 5641a50..cb11e5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.57" +version = "0.0.59" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }