import copy
import os
import time
from collections import defaultdict
from math import ceil

import numpy as np
import torch
import wandb
from tqdm.autonotebook import tqdm

from torch_robotics.torch_utils.torch_timer import TimerCUDA
from torch_robotics.torch_utils.torch_utils import dict_to_device, DEFAULT_TENSOR_ARGS, to_numpy


def get_num_epochs(num_train_steps, batch_size, dataset_len):
    return ceil(num_train_steps * batch_size / dataset_len)
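

# Example (hypothetical numbers): training for 50_000 optimizer steps with a batch size of 128
# on a dataset of 1_000_000 samples requires ceil(50_000 * 128 / 1_000_000) = 7 epochs.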


def save_models_to_disk(models_prefix_l, epoch, total_steps, checkpoints_dir=None):
    # Save each (model, prefix) pair, plus every registered submodule under its own prefix
    for model, prefix in models_prefix_l:
        if model is not None:
            save_model_to_disk(model, epoch, total_steps, checkpoints_dir, prefix=f'{prefix}_')
            for submodule_key, submodule_value in model.submodules.items():
                save_model_to_disk(submodule_value, epoch, total_steps, checkpoints_dir,
                                   prefix=f'{prefix}_{submodule_key}_')


def save_model_to_disk(model, epoch, total_steps, checkpoints_dir=None, prefix='model_'):
    # If the model is frozen we do not save it again, since the parameters did not change
    if hasattr(model, 'is_frozen') and model.is_frozen:
        return

    # Save both the state_dict and the full module, once as the rolling "current" checkpoint
    # and once as an epoch/iteration-stamped checkpoint
    torch.save(model.state_dict(), os.path.join(checkpoints_dir, f'{prefix}current_state_dict.pth'))
    torch.save(model.state_dict(), os.path.join(checkpoints_dir, f'{prefix}epoch_{epoch:04d}_iter_{total_steps:06d}_state_dict.pth'))
    torch.save(model, os.path.join(checkpoints_dir, f'{prefix}current.pth'))
    torch.save(model, os.path.join(checkpoints_dir, f'{prefix}epoch_{epoch:04d}_iter_{total_steps:06d}.pth'))
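

# Files written per call (hypothetical values: epoch 3, step 15000, prefix 'model_'):
#   model_current_state_dict.pth
#   model_epoch_0003_iter_015000_state_dict.pth
#   model_current.pth
#   model_epoch_0003_iter_015000.pth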


def save_losses_to_disk(train_losses, val_losses, checkpoints_dir=None):
    np.save(os.path.join(checkpoints_dir, 'train_losses.npy'), train_losses)
    np.save(os.path.join(checkpoints_dir, 'val_losses.npy'), val_losses)


class EarlyStopper:
    # https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

    def __init__(self, patience=10, min_delta=0):
        self.patience = patience  # use -1 to deactivate it
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = torch.inf

    def early_stop(self, validation_loss):
        if self.patience == -1:
            return False
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
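

# Usage sketch (hypothetical values): feed the stopper the aggregated validation loss and stop
# once it has not improved by more than `min_delta` for `patience` consecutive checks.
#
#   early_stopper = EarlyStopper(patience=5, min_delta=0.0)
#   for val_loss in validation_losses:
#       if early_stopper.early_stop(val_loss):
#           break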


class EMA:
    """
    https://github.com/jannerm/diffuser
    (empirical) exponential moving average parameters
    """

    def __init__(self, beta=0.995):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ema_model, current_model):
        for ema_params, current_params in zip(ema_model.parameters(), current_model.parameters()):
            old_weight, up_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
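

# The update rule above is the standard exponential moving average,
#   ema_param <- beta * ema_param + (1 - beta) * current_param,
# so with beta = 0.995 roughly the last 1 / (1 - beta) = 200 parameter updates dominate the average.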


def do_summary(
        summary_fn,
        train_steps_current,
        model,
        batch_dict,
        loss_info,
        datasubset,
        **kwargs
):
    if summary_fn is None:
        return

    with torch.no_grad():
        # set model to evaluation mode
        model.eval()

        summary_fn(train_steps_current,
                   model,
                   batch_dict=batch_dict,
                   loss_info=loss_info,
                   datasubset=datasubset,
                   **kwargs
                   )

    # set model to training mode
    model.train()
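

# `summary_fn` is therefore expected to accept
#   summary_fn(train_steps_current, model, batch_dict=..., loss_info=..., datasubset=..., **kwargs)
# and runs under torch.no_grad() with the model temporarily in eval mode.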


def train(model=None, train_dataloader=None, epochs=None, lr=None, steps_til_summary=None, model_dir=None, loss_fn=None,
          train_subset=None,
          summary_fn=None, steps_til_checkpoint=None,
          val_dataloader=None, val_subset=None,
          clip_grad=False,
          clip_grad_max_norm=1.0,
          val_loss_fn=None,
          optimizers=None, steps_per_validation=10, max_steps=None,
          use_ema: bool = True,
          ema_decay: float = 0.995, step_start_ema: int = 1000, update_ema_every: int = 10,
          use_amp=False,
          early_stopper_patience=-1,
          debug=False,
          tensor_args=DEFAULT_TENSOR_ARGS,
          **kwargs
          ):
    print(f'\n------- TRAINING STARTED -------\n')

    ema_model = None
    if use_ema:
        # Exponential moving average model
        ema = EMA(beta=ema_decay)
        ema_model = copy.deepcopy(model)

    # Model optimizers
    if optimizers is None:
        optimizers = [torch.optim.Adam(lr=lr, params=model.parameters())]

    # Automatic Mixed Precision
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    if val_dataloader is not None:
        assert val_loss_fn is not None, "If a validation set is passed, a validation loss_fn must be passed as well!"

    # Build saving directories
    os.makedirs(model_dir, exist_ok=True)

    summaries_dir = os.path.join(model_dir, 'summaries')
    os.makedirs(summaries_dir, exist_ok=True)

    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Early stopping
    early_stopper = EarlyStopper(patience=early_stopper_patience, min_delta=0)

    stop_training = False
    train_steps_current = 0

    # save models before training
    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')], 0, 0, checkpoints_dir)

    with tqdm(total=len(train_dataloader) * epochs, mininterval=1 if debug else 60) as pbar:
        train_losses_l = []
        validation_losses_l = []
        for epoch in range(epochs):
            model.train()  # set model to training mode
            for step, train_batch_dict in enumerate(train_dataloader):
                ####################################################################################################
                # TRAINING LOSS
                ####################################################################################################
                with TimerCUDA() as t_training_loss:
                    train_batch_dict = dict_to_device(train_batch_dict, tensor_args['device'])

                    # Compute losses
                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                        train_losses, train_losses_info = loss_fn(model, train_batch_dict, train_subset.dataset)

                    train_loss_batch = 0.
                    train_losses_log = {}
                    for loss_name, loss in train_losses.items():
                        single_loss = loss.mean()
                        train_loss_batch += single_loss
                        train_losses_log[loss_name] = to_numpy(single_loss).item()

                ####################################################################################################
                # SUMMARY
                if train_steps_current % steps_til_summary == 0:
                    # TRAINING
                    print(f"\n-----------------------------------------")
                    print(f"train_steps_current: {train_steps_current}")
                    print(f"t_training_loss: {t_training_loss.elapsed:.4f} sec")
                    print(f"Total training loss {train_loss_batch:.4f}")
                    print(f"Training losses {train_losses}")

                    train_losses_l.append((train_steps_current, train_losses_log))

                    with TimerCUDA() as t_training_summary:
                        do_summary(
                            summary_fn,
                            train_steps_current,
                            ema_model if ema_model is not None else model,
                            train_batch_dict,
                            train_losses_info,
                            train_subset,
                            prefix='TRAINING ',
                            debug=debug,
                            tensor_args=tensor_args
                        )
                    print(f"t_training_summary: {t_training_summary.elapsed:.4f} sec")

                    ################################################################################################
                    # VALIDATION LOSS and SUMMARY
                    validation_losses_log = {}
                    if val_dataloader is not None:
                        with TimerCUDA() as t_validation_loss:
                            print("Running validation...")
                            val_losses = defaultdict(list)
                            total_val_loss = 0.
                            for step_val, batch_dict_val in enumerate(val_dataloader):
                                batch_dict_val = dict_to_device(batch_dict_val, tensor_args['device'])
                                val_loss, val_loss_info = loss_fn(
                                    model, batch_dict_val, val_subset.dataset, step=train_steps_current)
                                for name, value in val_loss.items():
                                    single_loss = to_numpy(value)
                                    val_losses[name].append(single_loss)
                                    total_val_loss += np.mean(single_loss).item()

                                if step_val == steps_per_validation:
                                    break

                            validation_losses = {}
                            for loss_name, loss in val_losses.items():
                                single_loss = np.mean(loss).item()
                                validation_losses[f'VALIDATION {loss_name}'] = single_loss
                            print("... finished validation.")

                        print(f"t_validation_loss: {t_validation_loss.elapsed:.4f} sec")
                        print(f"Validation losses {validation_losses}")

                        validation_losses_log = validation_losses
                        validation_losses_l.append((train_steps_current, validation_losses_log))

                        # The validation summary is done only on one batch of the validation data
                        with TimerCUDA() as t_validation_summary:
                            do_summary(
                                summary_fn,
                                train_steps_current,
                                ema_model if ema_model is not None else model,
                                batch_dict_val,
                                val_loss_info,
                                val_subset,
                                prefix='VALIDATION ',
                                debug=debug,
                                tensor_args=tensor_args
                            )
                        print(f"t_validation_summary: {t_validation_summary.elapsed:.4f} sec")

                    wandb.log({**train_losses_log, **validation_losses_log}, step=train_steps_current)

                    ################################################################################################
                    # Early stopping (only meaningful when a validation loader is provided)
                    if val_dataloader is not None and early_stopper.early_stop(total_val_loss):
                        print(f'Early stopped training at {train_steps_current} steps.')
                        stop_training = True

                ####################################################################################################
                # OPTIMIZE TRAIN LOSS BATCH
                with TimerCUDA() as t_training_optimization:
                    for optim in optimizers:
                        optim.zero_grad()

                    scaler.scale(train_loss_batch).backward()

                    if clip_grad:
                        # Unscale gradients before clipping, so clipping acts on the true gradient norm
                        for optim in optimizers:
                            scaler.unscale_(optim)
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            max_norm=clip_grad_max_norm if isinstance(clip_grad, bool) else clip_grad
                        )

                    for optim in optimizers:
                        scaler.step(optim)

                    scaler.update()

                    if ema_model is not None:
                        if train_steps_current % update_ema_every == 0:
                            # update ema
                            if train_steps_current < step_start_ema:
                                # reset parameters ema
                                ema_model.load_state_dict(model.state_dict())
                            ema.update_model_average(ema_model, model)

                if train_steps_current % steps_til_summary == 0:
                    print(f"t_training_optimization: {t_training_optimization.elapsed:.4f} sec")

                ####################################################################################################
                # SAVING
                ####################################################################################################
                pbar.update(1)
                train_steps_current += 1

                if (steps_til_checkpoint is not None) and (train_steps_current % steps_til_checkpoint == 0):
                    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')],
                                        epoch, train_steps_current, checkpoints_dir)
                    save_losses_to_disk(train_losses_l, validation_losses_l, checkpoints_dir)

                if stop_training or (max_steps is not None and train_steps_current == max_steps):
                    break

            if stop_training or (max_steps is not None and train_steps_current == max_steps):
                break

    # Update ema model at the end of training
    if ema_model is not None:
        # update ema
        if train_steps_current < step_start_ema:
            # reset parameters ema
            ema_model.load_state_dict(model.state_dict())
        ema.update_model_average(ema_model, model)

    # Save model at end of training
    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')],
                        epoch, train_steps_current, checkpoints_dir)
    save_losses_to_disk(train_losses_l, validation_losses_l, checkpoints_dir)

    print(f'\n------- TRAINING FINISHED -------')
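

# Minimal usage sketch (all names below are hypothetical and not part of this module; it also
# assumes wandb.init(...) has been called, since the training loop logs to wandb):
#
#   train_dl = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True)
#   val_dl = torch.utils.data.DataLoader(val_subset, batch_size=128, shuffle=False)
#
#   def loss_fn(model, batch_dict, dataset, step=None):
#       ...  # must return (dict of loss tensors, info dict), as consumed in the training loop above
#
#   train(
#       model=model,
#       train_dataloader=train_dl, train_subset=train_subset,
#       val_dataloader=val_dl, val_subset=val_subset, val_loss_fn=loss_fn,
#       loss_fn=loss_fn,
#       epochs=get_num_epochs(num_train_steps=50_000, batch_size=128, dataset_len=len(train_subset.dataset)),
#       lr=3e-4, steps_til_summary=1_000, steps_til_checkpoint=10_000,
#       model_dir='./logs/my_experiment', use_ema=True, use_amp=False,
#   )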