Use a newly built module from x-mlps for multi-token prediction
This commit is contained in:
parent
0a26e0f92f
commit
0f4783f23c
@ -13,6 +13,7 @@ import torchvision
|
|||||||
from torchvision.models import VGG16_Weights
|
from torchvision.models import VGG16_Weights
|
||||||
|
|
||||||
from x_mlps_pytorch import create_mlp
|
from x_mlps_pytorch import create_mlp
|
||||||
|
from x_mlps_pytorch.ensemble import Ensemble
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
@ -802,6 +803,7 @@ class DynamicsModel(Module):
|
|||||||
),
|
),
|
||||||
ff_kwargs: dict = dict(),
|
ff_kwargs: dict = dict(),
|
||||||
loss_weight_fn: Callable = ramp_weight,
|
loss_weight_fn: Callable = ramp_weight,
|
||||||
|
num_future_predictions = 8 # they do multi-token prediction of 8 steps forward
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,7 @@ dependencies = [
|
|||||||
"hl-gauss-pytorch",
|
"hl-gauss-pytorch",
|
||||||
"torch>=2.4",
|
"torch>=2.4",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"x-mlps-pytorch"
|
"x-mlps-pytorch>=0.0.29"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user