use a newly built module from x-mlps for multi-token prediction

This commit is contained in:
lucidrains 2025-10-04 07:56:56 -07:00
parent 0a26e0f92f
commit 0f4783f23c
2 changed files with 3 additions and 1 deletion

View File

@@ -13,6 +13,7 @@ import torchvision
from torchvision.models import VGG16_Weights
from x_mlps_pytorch import create_mlp
from x_mlps_pytorch.ensemble import Ensemble
from accelerate import Accelerator
@@ -802,6 +803,7 @@ class DynamicsModel(Module):
),
ff_kwargs: dict = dict(),
loss_weight_fn: Callable = ramp_weight,
num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
):
super().__init__()

View File

@@ -33,7 +33,7 @@ dependencies = [
"hl-gauss-pytorch",
"torch>=2.4",
"torchvision",
"x-mlps-pytorch"
"x-mlps-pytorch>=0.0.29"
]
[project.urls]