From e2805925102281d7e99acdd78cb2292a1da71bbc Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 10:54:47 -0700 Subject: [PATCH] tweak bc trainer --- dreamer4/trainers.py | 9 ++++++++- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index 53d572d..ee2c019 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -195,7 +195,14 @@ class BehaviorCloneTrainer(Module): for _ in range(self.num_train_steps): batch_data = next(iter_train_dl) - loss = self.model(**batch_data) + # just assume raw video dynamics training if batch_data is a tensor + # else kwargs for video, actions, rewards + + if is_tensor(batch_data): + loss = self.model(batch_data) + else: + loss = self.model(**batch_data) + self.accelerator.backward(loss) if exists(self.max_grad_norm): diff --git a/pyproject.toml b/pyproject.toml index 5641a50..c753616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.57" +version = "0.0.58" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }