From 0f4783f23cd12863b2532cb19ad2af79b3253055 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 4 Oct 2025 07:56:56 -0700 Subject: [PATCH] use a newly built module from x-mlps for multi token prediction --- dreamer4/dreamer4.py | 2 ++ pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 56f7fed..067fd66 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -13,6 +13,7 @@ import torchvision from torchvision.models import VGG16_Weights from x_mlps_pytorch import create_mlp +from x_mlps_pytorch.ensemble import Ensemble from accelerate import Accelerator @@ -802,6 +803,7 @@ class DynamicsModel(Module): ), ff_kwargs: dict = dict(), loss_weight_fn: Callable = ramp_weight, + num_future_predictions = 8 # they do multi-token prediction of 8 steps forward ): super().__init__() diff --git a/pyproject.toml b/pyproject.toml index f489271..71814b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "hl-gauss-pytorch", "torch>=2.4", "torchvision", - "x-mlps-pytorch" + "x-mlps-pytorch>=0.0.29" ] [project.urls]