diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 56f7fed..067fd66 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -13,6 +13,7 @@ import torchvision from torchvision.models import VGG16_Weights from x_mlps_pytorch import create_mlp +from x_mlps_pytorch.ensemble import Ensemble from accelerate import Accelerator @@ -802,6 +803,7 @@ class DynamicsModel(Module): ), ff_kwargs: dict = dict(), loss_weight_fn: Callable = ramp_weight, + num_future_predictions = 8 # they do multi-token prediction of 8 steps forward ): super().__init__() diff --git a/pyproject.toml b/pyproject.toml index f489271..71814b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "hl-gauss-pytorch", "torch>=2.4", "torchvision", - "x-mlps-pytorch" + "x-mlps-pytorch>=0.0.29" ] [project.urls]