145 lines
4.0 KiB
Python
145 lines
4.0 KiB
Python
|
import os
|
||
|
import torch
|
||
|
|
||
|
from experiment_launcher import single_experiment_yaml, run_experiment
|
||
|
from mpd import trainer
|
||
|
from mpd.models import UNET_DIM_MULTS, TemporalUnet
|
||
|
from mpd.trainer import get_dataset, get_model, get_loss, get_summary
|
||
|
from mpd.trainer.trainer import get_num_epochs
|
||
|
from torch_robotics.torch_utils.seed import fix_random_seed
|
||
|
from torch_robotics.torch_utils.torch_utils import get_torch_device
|
||
|
|
||
|
# Turn off HDF5 file locking for this process (the setting is also inherited by
# any child processes) at import time, i.e. before `experiment` loads any data.
# NOTE(review): typically needed when data lives on filesystems that do not
# support HDF5's locking — confirm this is the motivation here.
os.environ.update(HDF5_USE_FILE_LOCKING="FALSE")
|
||
|
|
||
|
|
||
|
@single_experiment_yaml
def experiment(
    ########################################################################
    # Dataset
    dataset_subdir: str = 'EnvSimple2D-RobotPointMass',
    # dataset_subdir: str = 'EnvSpheres3D-RobotPanda',
    include_velocity: bool = True,

    ########################################################################
    # Diffusion Model
    diffusion_model_class: str = 'GaussianDiffusionModel',
    variance_schedule: str = 'exponential',  # cosine
    n_diffusion_steps: int = 25,
    predict_epsilon: bool = True,

    # Unet
    unet_input_dim: int = 32,
    unet_dim_mults_option: int = 1,

    ########################################################################
    # Loss
    loss_class: str = 'GaussianDiffusionLoss',

    # Training parameters
    batch_size: int = 32,
    lr: float = 1e-4,
    num_train_steps: int = 500000,

    use_ema: bool = True,
    use_amp: bool = False,

    # Summary parameters
    steps_til_summary: int = 10,
    summary_class: str = 'SummaryTrajectoryGeneration',

    steps_til_ckpt: int = 50000,

    ########################################################################
    device: str = 'cuda',

    debug: bool = True,

    ########################################################################
    # MANDATORY
    seed: int = 0,
    results_dir: str = 'logs',

    ########################################################################
    # WandB
    wandb_mode: str = 'disabled',  # "online", "offline" or "disabled"
    wandb_entity: str = 'scoreplan',
    wandb_project: str = 'test_train',

    **kwargs
):
    """Train a TemporalUnet-based diffusion model on a trajectory dataset.

    The steps are:
    1. Seed the RNGs and resolve the torch device.
    2. Load the train/validation split via ``get_dataset`` (dataset class
       ``'TrajectoryDataset'``).
    3. Wrap a ``TemporalUnet`` in ``diffusion_model_class``.
    4. Create the loss and summary objects.
    5. Run ``trainer.train``.

    Args:
        dataset_subdir: Dataset sub-directory forwarded to ``get_dataset``.
        include_velocity: Forwarded to ``get_dataset``.
        diffusion_model_class: Model class name resolved by ``get_model``.
        variance_schedule: Diffusion config forwarded to ``get_model``.
        n_diffusion_steps: Diffusion config forwarded to ``get_model``.
        predict_epsilon: Diffusion config forwarded to ``get_model``.
        unet_input_dim: Forwarded to ``TemporalUnet`` (and ``get_model``).
        unet_dim_mults_option: Key into ``UNET_DIM_MULTS`` selecting the
            UNet ``dim_mults``.
        loss_class: Loss class name resolved by ``get_loss``. The same loss
            object is used for training and validation.
        batch_size: Dataloader batch size; also used to convert
            ``num_train_steps`` into epochs.
        lr: Learning rate forwarded to ``trainer.train``.
        num_train_steps: Target number of training steps, converted to an
            epoch count with ``get_num_epochs``.
        use_ema: Forwarded to ``trainer.train``.
        use_amp: Forwarded to ``trainer.train``.
        steps_til_summary: Forwarded to ``trainer.train``.
        summary_class: Summary class name resolved by ``get_summary``.
        steps_til_ckpt: Forwarded to ``trainer.train`` as
            ``steps_til_checkpoint``.
        device: Requested device, resolved with ``get_torch_device``.
        debug: Forwarded to ``trainer.train``.
        seed: Passed to ``fix_random_seed``.
        results_dir: Passed to ``get_dataset`` (which is asked to save the
            split indices) and used as ``model_dir`` for training.
        wandb_mode: Not read in this body; presumably consumed by the
            ``experiment_launcher`` decorator/launcher — verify.
        wandb_entity: Same as ``wandb_mode`` — verify.
        wandb_project: Same as ``wandb_mode`` — verify.
        **kwargs: Not read in this body; presumably absorbs extra
            launcher-provided configuration — verify.
    """
    fix_random_seed(seed)

    device = get_torch_device(device=device)
    tensor_args = {'device': device, 'dtype': torch.float32}

    # Dataset
    train_subset, train_dataloader, val_subset, val_dataloader = get_dataset(
        dataset_class='TrajectoryDataset',
        include_velocity=include_velocity,
        dataset_subdir=dataset_subdir,
        batch_size=batch_size,
        results_dir=results_dir,
        save_indices=True,
        tensor_args=tensor_args
    )

    # Underlying dataset object wrapped by the training subset (presumably a
    # torch Subset); it provides the state/trajectory dimensions below.
    dataset = train_subset.dataset

    # Model
    diffusion_configs = dict(
        variance_schedule=variance_schedule,
        n_diffusion_steps=n_diffusion_steps,
        predict_epsilon=predict_epsilon,
    )

    unet_configs = dict(
        state_dim=dataset.state_dim,
        n_support_points=dataset.n_support_points,
        unet_input_dim=unet_input_dim,
        dim_mults=UNET_DIM_MULTS[unet_dim_mults_option],
    )

    model = get_model(
        model_class=diffusion_model_class,
        model=TemporalUnet(**unet_configs),
        tensor_args=tensor_args,
        **diffusion_configs,
        **unet_configs
    )

    # Loss (shared between training and validation)
    loss_fn = val_loss_fn = get_loss(
        loss_class=loss_class
    )

    # Summary
    summary_fn = get_summary(
        summary_class=summary_class,
    )

    # NOTE(review): epochs are derived from len(dataset), the full dataset
    # wrapped by train_subset, not len(train_subset). If get_num_epochs assumes
    # the training-set size, this under-counts the steps actually run — confirm.
    epochs = get_num_epochs(num_train_steps, batch_size, len(dataset))

    # Train
    trainer.train(
        model=model,
        train_dataloader=train_dataloader,
        train_subset=train_subset,
        val_dataloader=val_dataloader,
        # Previously passed train_subset here. That fed training data into the
        # validation path and left the held-out val_subset unused.
        val_subset=val_subset,
        epochs=epochs,
        model_dir=results_dir,
        summary_fn=summary_fn,
        lr=lr,
        loss_fn=loss_fn,
        val_loss_fn=val_loss_fn,
        steps_til_summary=steps_til_summary,
        steps_til_checkpoint=steps_til_ckpt,
        clip_grad=True,
        use_ema=use_ema,
        use_amp=use_amp,
        debug=debug,
        tensor_args=tensor_args
    )
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: hand the decorated `experiment` function to
    # experiment_launcher's `run_experiment`, which presumably builds the run
    # configuration and calls it — verify against experiment_launcher.
    # Leave unchanged
    run_experiment(experiment)
|