tweak bc trainer

This commit is contained in:
lucidrains 2025-10-21 10:55:24 -07:00
parent 2fc3b17149
commit 40da985c6b
2 changed files with 10 additions and 2 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import torch
from torch import is_tensor
from torch.nn import Module
from torch.utils.data import Dataset, DataLoader
@ -195,7 +196,14 @@ class BehaviorCloneTrainer(Module):
for _ in range(self.num_train_steps):
batch_data = next(iter_train_dl)
loss = self.model(**batch_data)
# just assume raw video dynamics training if batch_data is a tensor
# else kwargs for video, actions, rewards
if is_tensor(batch_data):
loss = self.model(batch_data)
else:
loss = self.model(**batch_data)
self.accelerator.backward(loss)
if exists(self.max_grad_norm):

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.57"
version = "0.0.59"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }