use a newly built module from x-mlps for multi token prediction
This commit is contained in:
parent
0a26e0f92f
commit
0f4783f23c
@ -13,6 +13,7 @@ import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
|
||||
from x_mlps_pytorch import create_mlp
|
||||
from x_mlps_pytorch.ensemble import Ensemble
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
@ -802,6 +803,7 @@ class DynamicsModel(Module):
|
||||
),
|
||||
ff_kwargs: dict = dict(),
|
||||
loss_weight_fn: Callable = ramp_weight,
|
||||
num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ dependencies = [
|
||||
"hl-gauss-pytorch",
|
||||
"torch>=2.4",
|
||||
"torchvision",
|
||||
"x-mlps-pytorch"
|
||||
"x-mlps-pytorch>=0.0.29"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user