import copy
from math import ceil
import numpy as np
import os
import time
import torch
import wandb
from collections import defaultdict
from tqdm.autonotebook import tqdm

from torch_robotics.torch_utils.torch_timer import TimerCUDA
from torch_robotics.torch_utils.torch_utils import dict_to_device, DEFAULT_TENSOR_ARGS, to_numpy


def get_num_epochs(num_train_steps, batch_size, dataset_len):
    """Return the number of epochs needed to perform ``num_train_steps`` optimizer steps.

    Each step consumes ``batch_size`` samples, so ``num_train_steps * batch_size`` samples are
    needed in total; dividing by the dataset length and rounding up gives the epoch count.
    """
    return ceil(num_train_steps * batch_size / dataset_len)


def save_models_to_disk(models_prefix_l, epoch, total_steps, checkpoints_dir=None):
    """Save every ``(model, prefix)`` pair, plus each of its submodules, to ``checkpoints_dir``.

    ``None`` models are skipped (e.g. the EMA model when EMA is disabled). Each entry of the
    model's ``submodules`` dict (if it has one) is saved separately with prefix
    ``{prefix}_{submodule_key}_``.
    """
    for model, prefix in models_prefix_l:
        if model is None:
            continue
        save_model_to_disk(model, epoch, total_steps, checkpoints_dir, prefix=f'{prefix}_')
        # Models without a ``submodules`` dict simply have nothing extra to save.
        for submodule_key, submodule_value in getattr(model, 'submodules', {}).items():
            save_model_to_disk(submodule_value, epoch, total_steps, checkpoints_dir,
                               prefix=f'{prefix}_{submodule_key}_')


def save_model_to_disk(model, epoch, total_steps, checkpoints_dir=None, prefix='model_'):
    """Save ``model`` both as a state dict and as a full pickled module.

    For each format two files are written: a rolling ``{prefix}current...`` file, overwritten on
    every call, and a versioned file tagged with the epoch and the global step count.
    Frozen models (truthy ``model.is_frozen``) are skipped because their parameters do not change.
    """
    if hasattr(model, 'is_frozen') and model.is_frozen:
        return

    state_dict = model.state_dict()
    torch.save(state_dict, os.path.join(checkpoints_dir, f'{prefix}current_state_dict.pth'))
    torch.save(state_dict, os.path.join(
        checkpoints_dir, f'{prefix}epoch_{epoch:04d}_iter_{total_steps:06d}_state_dict.pth'))
    torch.save(model, os.path.join(checkpoints_dir, f'{prefix}current.pth'))
    torch.save(model, os.path.join(
        checkpoints_dir, f'{prefix}epoch_{epoch:04d}_iter_{total_steps:06d}.pth'))


def save_losses_to_disk(train_losses, val_losses, checkpoints_dir=None):
    """Write the training and validation loss histories to ``.npy`` files in ``checkpoints_dir``.

    ``train()`` passes lists of ``(step, {loss_name: value})`` tuples, which numpy stores as object
    arrays; reading them back therefore needs ``np.load(..., allow_pickle=True)``.
    """
    np.save(os.path.join(checkpoints_dir, f'train_losses.npy'), train_losses)
    np.save(os.path.join(checkpoints_dir, f'val_losses.npy'), val_losses)


class EarlyStopper:
    """Signal that training should stop once the validation loss stops improving.

    Adapted from https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
    """

    def __init__(self, patience=10, min_delta=0):
        self.patience = patience  # non-improving checks tolerated before stopping; -1 deactivates it
        self.min_delta = min_delta  # a loss only counts as worse if it exceeds the best by more than this
        self.counter = 0  # consecutive checks whose loss was worse than the best by > min_delta
        self.min_validation_loss = torch.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True once ``patience`` worse losses have been seen.

        A new best loss resets the counter. Losses within ``min_delta`` of the best neither reset
        nor increment it. Always returns False when ``patience == -1``.
        """
        if self.patience == -1:
            return False
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


class EMA:
    """(Empirical) exponential moving average of model parameters.

    Adapted from https://github.com/jannerm/diffuser
    """

    def __init__(self, beta=0.995):
        super().__init__()
        self.beta = beta  # weight kept on the old (averaged) value at each update

    def update_model_average(self, ema_model, current_model):
        """Blend ``current_model``'s parameters into ``ema_model`` in place.

        Parameters are paired in ``.parameters()`` order; buffers are not averaged.
        """
        for ema_params, current_params in zip(ema_model.parameters(), current_model.parameters()):
            old_weight, up_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` if there is no previous value."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def do_summary(
        summary_fn,
        train_steps_current,
        model,
        batch_dict,
        loss_info,
        datasubset,
        **kwargs
):
    """Run ``summary_fn`` on ``model`` in eval mode with gradients disabled.

    Does nothing if ``summary_fn`` is None. The model is always put back into training mode
    afterwards, even if ``summary_fn`` raises.
    """
    if summary_fn is None:
        return

    with torch.no_grad():
        model.eval()
        try:
            summary_fn(train_steps_current,
                       model,
                       batch_dict=batch_dict,
                       loss_info=loss_info,
                       datasubset=datasubset,
                       **kwargs
                       )
        finally:
            model.train()


def _compute_validation_losses(model, val_dataloader, val_loss_fn, val_subset, steps_per_validation,
                               train_steps_current, tensor_args):
    """Evaluate ``val_loss_fn`` on up to ``steps_per_validation`` batches of ``val_dataloader``.

    A negative (or None) ``steps_per_validation`` evaluates the whole dataloader; otherwise at
    least one batch is evaluated. The model is used in whatever mode it is currently in.

    Returns:
        validation_losses: ``{'VALIDATION <name>': mean of that loss over all evaluated values}``.
        total_val_loss: sum, over batches and loss terms, of each per-batch mean loss (the
            quantity monitored by early stopping).
        last_batch_dict, last_loss_info: the final evaluated batch and its loss info, or
            ``(None, None)`` if the dataloader yielded nothing.
    """
    if steps_per_validation is None or steps_per_validation < 0:
        max_batches = None
    else:
        max_batches = max(steps_per_validation, 1)

    val_losses = defaultdict(list)
    total_val_loss = 0.
    batch_dict_val, val_loss_info = None, None
    # Only loss values are needed here, so do not build an autograd graph.
    with torch.no_grad():
        for step_val, batch_dict_val in enumerate(val_dataloader):
            batch_dict_val = dict_to_device(batch_dict_val, tensor_args['device'])
            val_loss, val_loss_info = val_loss_fn(
                model, batch_dict_val, val_subset.dataset, step=train_steps_current)
            for name, value in val_loss.items():
                single_loss = to_numpy(value)
                val_losses[name].append(single_loss)
                total_val_loss += np.mean(single_loss).item()
            if max_batches is not None and step_val + 1 >= max_batches:
                break

    # Flatten before averaging so batches of different sizes (e.g. a short last batch) are fine.
    validation_losses = {
        f'VALIDATION {loss_name}': np.mean(np.concatenate([np.ravel(v) for v in loss_values])).item()
        for loss_name, loss_values in val_losses.items()
    }
    return validation_losses, total_val_loss, batch_dict_val, val_loss_info


def _step_ema(ema, ema_model, model, train_steps_current, step_start_ema):
    """Advance the EMA copy of ``model`` by one update.

    Before ``step_start_ema`` steps the EMA model is first reset to the current model's full state
    dict, so the moving average effectively starts only after that warm-up.
    """
    if train_steps_current < step_start_ema:
        ema_model.load_state_dict(model.state_dict())
    ema.update_model_average(ema_model, model)


def train(model=None, train_dataloader=None, epochs=None, lr=None, steps_til_summary=None, model_dir=None, loss_fn=None,
          train_subset=None,
          summary_fn=None, steps_til_checkpoint=None,
          val_dataloader=None, val_subset=None,
          clip_grad=False,
          clip_grad_max_norm=1.0,
          val_loss_fn=None,
          optimizers=None, steps_per_validation=10, max_steps=None,
          use_ema: bool = True,
          ema_decay: float = 0.995, step_start_ema: int = 1000, update_ema_every: int = 10,
          use_amp=False,
          early_stopper_patience=-1,
          debug=False,
          tensor_args=DEFAULT_TENSOR_ARGS,
          **kwargs
          ):
    """Train ``model`` with optional EMA, mixed precision, validation, checkpoints and early stopping.

    Every ``steps_til_summary`` steps:
    - the training losses are printed;
    - ``summary_fn`` is run (on the EMA model when EMA is enabled);
    - if ``val_dataloader`` is given, validation losses are computed with ``val_loss_fn`` and fed to
      the early stopper;
    - all logged losses are sent to ``wandb.log``.
    
    Checkpoints are written before training, every ``steps_til_checkpoint`` steps, and at the end.

    Args:
        model: module being trained.
        train_dataloader: iterable of batch dicts.
        epochs: maximum number of passes over ``train_dataloader``.
        lr: Adam learning rate; only used when ``optimizers`` is None.
        steps_til_summary: summary/validation period, in optimizer steps.
        model_dir: output directory; ``summaries/`` and ``checkpoints/`` are created inside it.
        loss_fn: called as ``loss_fn(model, batch_dict, train_subset.dataset)`` and must return
            ``(dict of loss tensors, loss_info)``. The batch loss is the sum of each tensor's mean.
        train_subset, val_subset: objects exposing ``.dataset``; passed to the losses and summaries.
        summary_fn: optional callback; see ``do_summary``.
        steps_til_checkpoint: checkpoint period in steps, or None to save only at start and end.
        val_dataloader: optional iterable of validation batch dicts.
        clip_grad: False disables clipping. True clips at ``clip_grad_max_norm``. Any other truthy
            value is used directly as the max norm.
        val_loss_fn: called as ``val_loss_fn(model, batch_dict, val_subset.dataset, step=step)``;
            required when ``val_dataloader`` is given.
        optimizers: list of optimizers; defaults to one Adam over ``model.parameters()``.
        steps_per_validation: validation batches per evaluation; negative or None means all batches.
        max_steps: optional cap on the total number of optimizer steps.
        use_ema, ema_decay, step_start_ema, update_ema_every: exponential-moving-average settings.
        use_amp: enable CUDA float16 autocast and gradient scaling.
        early_stopper_patience: see ``EarlyStopper``; -1 disables early stopping.
        debug: forwarded to ``summary_fn``; also makes the progress bar refresh every second.
        tensor_args: dict whose ``'device'`` entry is where batches are moved.
        **kwargs: ignored.
    """
    print(f'\n------- TRAINING STARTED -------\n')

    ema = None
    ema_model = None
    if use_ema:
        # Exponential moving average model
        ema = EMA(beta=ema_decay)
        ema_model = copy.deepcopy(model)

    # Model optimizers
    if optimizers is None:
        optimizers = [torch.optim.Adam(lr=lr, params=model.parameters())]

    # Automatic Mixed Precision
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    if val_dataloader is not None:
        assert val_loss_fn is not None, "If validation set is passed, have to pass a validation loss_fn!"

    ## Build saving directories
    os.makedirs(model_dir, exist_ok=True)

    summaries_dir = os.path.join(model_dir, 'summaries')
    os.makedirs(summaries_dir, exist_ok=True)

    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Early stopping
    early_stopper = EarlyStopper(patience=early_stopper_patience, min_delta=0)

    stop_training = False
    train_steps_current = 0
    epoch = 0  # defined up front so the final save also works when ``epochs == 0``
    train_losses_l = []
    validation_losses_l = []

    def reached_max_steps():
        return max_steps is not None and train_steps_current >= max_steps

    # save models before training
    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')], 0, 0, checkpoints_dir)

    with tqdm(total=len(train_dataloader) * epochs, mininterval=1 if debug else 60) as pbar:
        for epoch in range(epochs):
            model.train()  # set model to training mode
            for step, train_batch_dict in enumerate(train_dataloader):
                ####################################################################################################
                # TRAINING LOSS
                ####################################################################################################
                with TimerCUDA() as t_training_loss:
                    train_batch_dict = dict_to_device(train_batch_dict, tensor_args['device'])

                    # Compute losses
                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                        train_losses, train_losses_info = loss_fn(model, train_batch_dict, train_subset.dataset)

                    train_loss_batch = 0.
                    train_losses_log = {}
                    for loss_name, loss in train_losses.items():
                        single_loss = loss.mean()
                        train_loss_batch += single_loss
                        train_losses_log[loss_name] = to_numpy(single_loss).item()

                ####################################################################################################
                # SUMMARY
                if train_steps_current % steps_til_summary == 0:
                    # TRAINING
                    print(f"\n-----------------------------------------")
                    print(f"train_steps_current: {train_steps_current}")
                    print(f"t_training_loss: {t_training_loss.elapsed:.4f} sec")
                    print(f"Total training loss {train_loss_batch:.4f}")
                    print(f"Training losses {train_losses}")

                    train_losses_l.append((train_steps_current, train_losses_log))

                    with TimerCUDA() as t_training_summary:
                        do_summary(
                            summary_fn,
                            train_steps_current,
                            ema_model if ema_model is not None else model,
                            train_batch_dict,
                            train_losses_info,
                            train_subset,
                            prefix='TRAINING ',
                            debug=debug,
                            tensor_args=tensor_args
                        )
                    print(f"t_training_summary: {t_training_summary.elapsed:.4f} sec")

                    ################################################################################################
                    # VALIDATION LOSS and SUMMARY
                    validation_losses_log = {}
                    total_val_loss = None  # stays None when there is no validation set
                    if val_dataloader is not None:
                        with TimerCUDA() as t_validation_loss:
                            print("Running validation...")
                            validation_losses, total_val_loss, batch_dict_val, val_loss_info = \
                                _compute_validation_losses(
                                    model, val_dataloader, val_loss_fn, val_subset, steps_per_validation,
                                    train_steps_current, tensor_args
                                )
                            print("... finished validation.")

                        print(f"t_validation_loss: {t_validation_loss.elapsed:.4f} sec")
                        print(f"Validation losses {validation_losses}")

                        validation_losses_log = validation_losses
                        validation_losses_l.append((train_steps_current, validation_losses_log))

                        # The validation summary is done only on one batch of the validation data
                        if batch_dict_val is not None:
                            with TimerCUDA() as t_validation_summary:
                                do_summary(
                                    summary_fn,
                                    train_steps_current,
                                    ema_model if ema_model is not None else model,
                                    batch_dict_val,
                                    val_loss_info,
                                    val_subset,
                                    prefix='VALIDATION ',
                                    debug=debug,
                                    tensor_args=tensor_args
                                )
                            print(f"t_validation_summary: {t_validation_summary.elapsed:.4f} sec")

                    wandb.log({**train_losses_log, **validation_losses_log}, step=train_steps_current)

                    ####################################################################################################
                    # Early stopping (only meaningful when a validation loss was computed)
                    if total_val_loss is not None and early_stopper.early_stop(total_val_loss):
                        print(f'Early stopped training at {train_steps_current} steps.')
                        stop_training = True

                ####################################################################################################
                # OPTIMIZE TRAIN LOSS BATCH
                with TimerCUDA() as t_training_optimization:
                    for optim in optimizers:
                        optim.zero_grad()

                    scaler.scale(train_loss_batch).backward()

                    if clip_grad:
                        # Gradients must be unscaled before clipping so the norm is in real units.
                        for optim in optimizers:
                            scaler.unscale_(optim)
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            max_norm=clip_grad_max_norm if isinstance(clip_grad, bool) else clip_grad
                        )

                    for optim in optimizers:
                        scaler.step(optim)

                    scaler.update()

                    if ema_model is not None and train_steps_current % update_ema_every == 0:
                        _step_ema(ema, ema_model, model, train_steps_current, step_start_ema)

                if train_steps_current % steps_til_summary == 0:
                    print(f"t_training_optimization: {t_training_optimization.elapsed:.4f} sec")

                ####################################################################################################
                # SAVING
                ####################################################################################################
                pbar.update(1)
                train_steps_current += 1

                if (steps_til_checkpoint is not None) and (train_steps_current % steps_til_checkpoint == 0):
                    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')],
                                        epoch, train_steps_current, checkpoints_dir)
                    save_losses_to_disk(train_losses_l, validation_losses_l, checkpoints_dir)

                if stop_training or reached_max_steps():
                    break

            # Leave the epoch loop too; previously early stopping only ended the current epoch.
            if stop_training or reached_max_steps():
                break

    # Update ema model at the end of training
    if ema_model is not None:
        _step_ema(ema, ema_model, model, train_steps_current, step_start_ema)

    # Save model at end of training
    save_models_to_disk([(model, 'model'), (ema_model, 'ema_model')],
                        epoch, train_steps_current, checkpoints_dir)
    save_losses_to_disk(train_losses_l, validation_losses_l, checkpoints_dir)

    print(f'\n------- TRAINING FINISHED -------')