# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import copy
import os
import time
# import SMOS
import ray
import torch
import wandb
import logging
import random
import numpy as np
import torch.optim as optim
import torch.distributed as dist
import torch.nn.functional as F

from pathlib import Path
from tqdm.auto import tqdm
from torch.nn import L1Loss
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP

from ez.utils.format import get_ddp_model_weights, DiscreteSupport, symexp
from ez.utils.loss import kl_loss, cosine_similarity_loss, continuous_loss, symlog_loss, Value_loss
from ez.data.trajectory import GameTrajectory
from ez.data.augmentation import Transforms

def DDP_setup(**kwargs):
    # set the master node address
    os.environ['MASTER_ADDR'] = kwargs.get('address')
    # os.environ['MASTER_PORT'] = kwargs.get('port')

    # initialize the process group; fall back to gloo if NCCL is unavailable
    try:
        dist.init_process_group('nccl', rank=kwargs.get('rank'), world_size=kwargs.get('world_size') * kwargs.get('training_size'))
    except Exception:
        dist.init_process_group('gloo', rank=kwargs.get('rank'), world_size=kwargs.get('world_size') * kwargs.get('training_size'))

    print(f'DDP backend={dist.get_backend()}')
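

# Typical usage (mirroring the call made in train_ddp further below): each training
# rank calls
#     DDP_setup(rank=rank, world_size=config.ddp.world_size,
#               training_size=config.ddp.training_size, address='127.0.0.1')
# before wrapping its model in DistributedDataParallel.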


class Agent:
    def __init__(self, config):
        self.config = config
        self.transforms = None
        self.obs_shape = None
        self.input_shape = None
        self.action_space_size = None
        self._update = False
        self.use_ddp = config.ddp.world_size > 1 or config.ddp.training_size > 1

    def update_config(self):
        raise NotImplementedError
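
    # train() below is the training worker loop. A rough summary of one iteration,
    # as implemented in the body that follows:
    #   1. pop a batch from batch_storage (sleep and retry if it is empty);
    #   2. adjust the learning rate and periodically push 'latest' / 'self_play' /
    #      'reanalyze' weights to the shared storage;
    #   3. call update_weights() to perform one optimizer step;
    #   4. checkpoint and log scalars / histograms at their configured intervals.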

    def train(self, rank, replay_buffer, storage, batch_storage, logger):
        assert self._update
        # update image augmentation transform
        self.update_augmentation_transform()

        # save path
        model_path = Path(self.config.save_path) / 'models'
        model_path.mkdir(parents=True, exist_ok=True)

        is_main_process = (rank == 0)
        if is_main_process:
            train_logger = logging.getLogger('Train')
            eval_logger = logging.getLogger('Eval')

            train_logger.info('config: {}'.format(self.config))
            train_logger.info('save model in: {}'.format(model_path))

        # prepare model
        model = self.build_model().cuda()
        target_model = self.build_model().cuda()
        # load model
        load_path = self.config.train.load_model_path
        if os.path.exists(load_path):
            if is_main_process:
                train_logger.info('resume model from path: {}'.format(load_path))
            weights = torch.load(load_path)
            storage.set_weights.remote(weights, 'self_play')
            storage.set_weights.remote(weights, 'reanalyze')
            storage.set_weights.remote(weights, 'latest')
            model.load_state_dict(weights)
            target_model.load_state_dict(weights)

        # DDP
        if self.use_ddp:
            model = DDP(model, device_ids=[rank])

        if int(torch.__version__[0]) == 2:
            model = torch.compile(model)
            target_model = torch.compile(target_model)
        model.train()
        target_model.eval()

        # optimizer
        if self.config.optimizer.type == 'SGD':
            optimizer = optim.SGD(model.parameters(),
                                  lr=self.config.optimizer.lr,
                                  weight_decay=self.config.optimizer.weight_decay,
                                  momentum=self.config.optimizer.momentum)
        elif self.config.optimizer.type == 'Adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=self.config.optimizer.lr,
                                   weight_decay=self.config.optimizer.weight_decay)
        elif self.config.optimizer.type == 'AdamW':
            optimizer = optim.AdamW(model.parameters(),
                                    lr=self.config.optimizer.lr,
                                    weight_decay=self.config.optimizer.weight_decay)
        else:
            raise NotImplementedError

        if self.config.optimizer.lr_decay_type == 'cosine':
            max_steps = self.config.train.training_steps - int(self.config.train.training_steps * self.config.optimizer.lr_warm_up)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps * 3, eta_min=0)
        elif self.config.optimizer.lr_decay_type == 'full_cosine':
            max_steps = self.config.train.training_steps - int(self.config.train.training_steps * self.config.optimizer.lr_warm_up)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps // 2, eta_min=0)
        else:
            scheduler = None

        scaler = GradScaler()

        # wait until enough data has been collected to start training
        while not (ray.get(replay_buffer.get_transition_num.remote()) >= self.config.train.start_transitions):
            time.sleep(1)
        print('[Train] Begin training...')

        # set signals for other workers
        if is_main_process:
            storage.set_start_signal.remote()
        step_count = 0

        # Note: the interval of the current model and the target model is between x and 2x. (x = target_model_interval)
        # recent_weights is the param of the target model
        recent_weights = self.get_weights(model)

        # some logs
        total_time = 0
        total_steps = self.config.train.training_steps + self.config.train.offline_training_steps
        if is_main_process:
            pb = tqdm(np.arange(total_steps), leave=True)

        # training loop
        self_play_return = 0.
        traj_num, transition_num = 0, 0
        eval_score, eval_best_score = 0., 0.
        prev_eval_counter = -1
        eval_counter = 0

        while not self.is_finished(step_count):
            start_time = time.time()

            # obtain a batch
            batch = batch_storage.pop()
            end_time1 = time.time()
            if batch is None:
                time.sleep(0.3)
                # print('batch is None')
                continue

            # adjust learning rate
            if is_main_process:
                storage.increase_counter.remote()
            lr = self.adjust_lr(optimizer, step_count, scheduler)

            if is_main_process and step_count % 30 == 0:
                latest_weights = self.get_weights(model)
                storage.set_weights.remote(latest_weights, 'latest')

            # update model for self-play
            if is_main_process and step_count % self.config.train.self_play_update_interval == 0:
                weights = self.get_weights(model)
                storage.set_weights.remote(weights, 'self_play')

            # update model for reanalyzing
            if is_main_process and step_count % self.config.train.reanalyze_update_interval == 0:
                storage.set_weights.remote(recent_weights, 'reanalyze')
                target_model.set_weights(recent_weights)
                target_model.cuda()
                target_model.eval()
                recent_weights = self.get_weights(model)

            if step_count % self.config.train.eval_interval == 0:
                if eval_counter == prev_eval_counter:
                    time.sleep(1)
                    continue

            scalers, log_data = self.update_weights(model, batch, optimizer, replay_buffer, scaler, step_count, target_model=target_model)
            scaler = scalers[0]

            loss_data, other_scalar, other_distribution = log_data

            # TODO: maybe this barrier can be removed
            if self.config.ddp.training_size > 1 or self.config.ddp.world_size > 1:
                dist.barrier()

            # save models
            if is_main_process and step_count % self.config.train.save_ckpt_interval == 0:
                cur_model_path = model_path / 'model_{}.p'.format(step_count)
                torch.save(self.get_weights(model), cur_model_path)

            end_time = time.time()
            total_time += end_time - start_time

            step_count += 1
            avg_time = total_time / step_count
            log_scalars = {}
            log_distribution = {}

            pb_interval = 50
            if is_main_process and step_count % pb_interval == 0:
                left_steps = (self.config.train.training_steps + self.config.train.offline_training_steps - step_count)
                left_time = (left_steps * avg_time) / 3600
                batch_queue_size = batch_storage.get_len()
                train_log_str = '[Train] {}, step {}/{}, {:.3f}h left. lr={:.3f}, avg time={:.3f}s, batchQ={}, ' \
                                'self-play return={:.3f}, collect {}/{:.3f}k, eval score={:.3f}/{:.3f}. ' \
                                'Loss: reward={:.3f}, value={:.3f}, policy={:.3f}, ' \
                                'consistency={:.3f}, entropy={:.3f}' \
                                ''.format(self.config.env.game, step_count, total_steps, left_time, lr, avg_time,
                                          batch_queue_size, self_play_return, traj_num, transition_num / 1000,
                                          eval_score, eval_best_score, loss_data['loss/value_prefix'],
                                          loss_data['loss/value'], loss_data['loss/policy'],
                                          loss_data['loss/consistency'], loss_data['loss/entropy'])
                # print(f'target policy={batch[-1][-1][0, 0]}')
                pb.set_description(train_log_str)

                pb.update(pb_interval)

                log_scalars.update({
                    'train/step_per_second (s)': end_time - start_time,
                    'train/total time (h)': total_time / 3600,
                    'train/avg time (s)': avg_time,
                    'train/lr': lr,
                    'train/queue size': batch_queue_size
                })

            if is_main_process and step_count % self.config.log.log_interval == 0:
                # train_logger.info(train_log_str)
                # self-play statistics
                eval_scalar, remote_scalar, remote_distribution = ray.get(storage.get_log.remote())
                log_scalars.update(remote_scalar)
                log_distribution.update(remote_distribution)

                if remote_scalar.get('self_play/episode_return'):
                    self_play_return = remote_scalar.get('self_play/episode_return')
                if len(eval_scalar) > 0:
                    eval_score = eval_scalar['eval/mean_score']
                    min_score, max_score = eval_scalar['eval/min_score'], eval_scalar['eval/max_score']
                    eval_counter, eval_best_score = ray.get([storage.get_eval_counter.remote(), storage.get_best_score.remote()])

                    eval_log_str = 'Eval {} at step {}, score = {:.3f}(min: {:.3f}, max: {:.3f}), ' \
                                   'best score over past evaluation = {:.3f}' \
                                   ''.format(self.config.env.game, eval_counter, eval_score, min_score, max_score,
                                             eval_best_score)
                    eval_logger.info(eval_log_str)
                    # TODO: fix the counter issue
                    # logger.log(eval_scalar, eval_counter)
                    logger.log(eval_scalar, step_count)
                    print('[Eval] ', eval_log_str)

                # replay statistics
                traj_num, transition_num, total_priorities = ray.get([
                    replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(), replay_buffer.get_priorities.remote()
                ])
                log_scalars.update({
                    'buffer/total_episode_num': traj_num,
                    'buffer/total_transition_num': transition_num
                })
                log_distribution.update({
                    'dist/priorities_in_buffer': total_priorities,
                })
                log_distribution.update(other_distribution)
                self.log_hist(logger, log_distribution, step_count)

            if step_count % 20000 == 0 and self.config.train.periodic_reset:
                print('-------------------------reset network------------------------------')
                model = self.periodic_reset_model(model)

            # training statistics
            log_scalars.update(loss_data)
            log_scalars.update(other_scalar)
            if step_count > 500 and step_count % 1000 == 0:
                logger.log(log_scalars, step_count)

            traj_num, transition_num, total_priorities = ray.get([
                replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(),
                replay_buffer.get_priorities.remote()
            ])
            log_distribution.update({
                'dist/priorities_in_buffer': total_priorities,
            })
            log_distribution.update(other_distribution)

        if is_main_process:
            final_weights = self.get_weights(model)
            storage.set_weights.remote(final_weights, 'self_play')
        else:
            final_weights = None

        return final_weights, model

    def reset_network(self, network):
        for layer in network.children():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def periodic_reset_model(self, model):
        if self.config.env.image_based:
            # reset prediction backbone
            self.reset_network(model.value_policy_model.resblocks)

            # reset policy head
            self.reset_network(model.value_policy_model.conv1x1_policy)
            self.reset_network(model.value_policy_model.bn_policy)
            self.reset_network(model.value_policy_model.fc_policy)

            # reset value head
            self.reset_network(model.value_policy_model.conv1x1_values)
            self.reset_network(model.value_policy_model.bn_values)
            self.reset_network(model.value_policy_model.fc_values)

        else:
            self.reset_network(model.value_policy_model.val_resblock)
            self.reset_network(model.value_policy_model.pi_resblock)
            self.reset_network(model.value_policy_model.val_ln)
            self.reset_network(model.value_policy_model.pi_ln)
            self.reset_network(model.value_policy_model.val_net)
            self.reset_network(model.value_policy_model.pi_net)

        return model
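
    # update_weights() below performs one gradient step. As a rough summary of the
    # code that follows (all coefficients come from config.train):
    #   loss = reward_loss_coeff * value_prefix_loss
    #        + value_loss_coeff * value_loss
    #        + policy_loss_coeff * policy_loss
    #        + consistency_coeff * consistency_loss
    #        (+ entropy_coeff * policy_entropy_loss for continuous-control envs),
    # weighted per sample by the prioritized-replay importance weights before backward.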

    # @profile
    def update_weights(self, model, batch, optimizer, replay_buffer, scaler, step_count, target_model=None):
        target_model.eval()
        # init
        batch_size = self.config.train.batch_size
        image_channel = self.config.env.obs_shape[0] if self.config.env.image_based else self.config.env.obs_shape
        unroll_steps = self.config.rl.unroll_steps
        n_stack = self.config.env.n_stack
        gradient_scale = 1. / unroll_steps
        reward_hidden = self.init_reward_hidden(batch_size)
        loss_data = {}
        other_scalar = {}
        other_distribution = {}

        # obtain the batch data
        inputs_batch, targets_batch = batch
        obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time, prior_lst = inputs_batch
        target_value_prefixes, target_values, target_actions, target_policies, target_best_actions, \
            top_value_masks, mismatch_masks, search_values = targets_batch
        target_value_prefixes = target_value_prefixes[:, :unroll_steps]

        # obs_batch_raw: [s_{t - stack} ... s_{t} ... s_{t + unroll}]
        if self.config.env.image_based:
            obs_batch_raw = torch.from_numpy(obs_batch_ori).cuda().float() / 255.
        else:
            obs_batch_raw = torch.from_numpy(obs_batch_ori).cuda().float()

        obs_batch = obs_batch_raw[:, 0: n_stack * image_channel]   # obs_batch: current observation
        obs_target_batch = obs_batch_raw[:, image_channel:]        # obs_target_batch: observation of next steps
        # if self.config.train.use_decorrelation:
        #     obs_batch_all = copy.deepcopy(obs_batch)
        #     for step_i in range(1, unroll_steps + 1):
        #         obs_batch_all = torch.cat((obs_batch_all, obs_batch_raw[:, step_i * image_channel: (step_i + n_stack) * image_channel]), dim=0)

        # augmentation
        obs_batch = self.transform(obs_batch)
        obs_target_batch = self.transform(obs_target_batch)
        # if self.config.train.use_decorrelation:
        #     obs_batch_aug1 = self.transform(obs_batch_all)
        #     obs_batch_aug2 = self.transform(obs_batch_all)

        # others to gpu
        if self.config.env.env in ['DMC', 'Gym']:
            action_batch = torch.from_numpy(action_batch).float().cuda()
        else:
            action_batch = torch.from_numpy(action_batch).cuda().unsqueeze(-1).long()
        mask_batch = torch.from_numpy(mask_batch).cuda().float()
        weights = torch.from_numpy(weights_lst).cuda().float()

        max_value_target = np.array([target_values, search_values]).max(0)

        target_value_prefixes = torch.from_numpy(target_value_prefixes).cuda().float()
        target_values = torch.from_numpy(target_values).cuda().float()
        target_actions = torch.from_numpy(target_actions).cuda().float()
        target_policies = torch.from_numpy(target_policies).cuda().float()
        target_best_actions = torch.from_numpy(target_best_actions).cuda().float()
        top_value_masks = torch.from_numpy(top_value_masks).cuda().float()
        mismatch_masks = torch.from_numpy(mismatch_masks).cuda().float()
        search_values = torch.from_numpy(search_values).cuda().float()
        max_value_target = torch.from_numpy(max_value_target).cuda().float()

        # transform value and reward to support
        target_value_prefixes_support = DiscreteSupport.scalar_to_vector(target_value_prefixes, **self.config.model.reward_support)

        with autocast():
            states, values, policies = model.initial_inference(obs_batch, training=True)

        if self.config.model.value_support.type == 'symlog':
            scaled_value = symexp(values).min(0)[0]
        else:
            scaled_value = DiscreteSupport.vector_to_scalar(values, **self.config.model.value_support).min(0)[0]
        if self.config.env.env in ['DMC', 'Gym']:
            scaled_value = scaled_value.clip(0, 1e5)

        # loss of first step
        # multiple options for the value target (value loss)
        if self.config.train.value_target == 'sarsa':
            this_target_values = target_values
        elif self.config.train.value_target == 'search':
            this_target_values = search_values
        elif self.config.train.value_target == 'mixed':
            if step_count < self.config.train.start_use_mix_training_steps:
                this_target_values = target_values
            else:
                this_target_values = target_values * top_value_masks.unsqueeze(1).repeat(1, unroll_steps + 1) \
                                     + search_values * (1 - top_value_masks).unsqueeze(1).repeat(1, unroll_steps + 1)
        elif self.config.train.value_target == 'max':
            this_target_values = max_value_target
        else:
            raise NotImplementedError

        # update priority
        fresh_priority = L1Loss(reduction='none')(scaled_value.squeeze(-1), this_target_values[:, 0]).detach().cpu().numpy()
        fresh_priority += self.config.priority.min_prior
        replay_buffer.update_priorities.remote(indices, fresh_priority, make_time)

        value_loss = torch.zeros(batch_size).cuda()
        value_loss += Value_loss(values, this_target_values[:, 0], self.config)
        prev_values = values.clone()

        if self.config.env.env in ['DMC', 'Gym']:
            policy_loss, entropy_loss = continuous_loss(
                policies, target_actions[:, 0], target_policies[:, 0],
                target_best_actions[:, 0],
                distribution_type=self.config.model.policy_distribution
            )
            mu = policies[:, :policies.shape[-1] // 2].detach().cpu().numpy().flatten()
            sigma = policies[:, policies.shape[-1] // 2:].detach().cpu().numpy().flatten()
            other_distribution.update({
                'dist/policy_mu': mu,
                'dist/policy_sigma': sigma,
            })

        else:
            policy_loss = kl_loss(policies, target_policies[:, 0])
            entropy_loss = torch.zeros(batch_size).cuda()

        value_prefix_loss = torch.zeros(batch_size).cuda()
        consistency_loss = torch.zeros(batch_size).cuda()
        policy_entropy_loss = torch.zeros(batch_size).cuda()
        policy_entropy_loss -= entropy_loss

        prev_value_prefixes = torch.zeros_like(policy_loss)
        # unroll k steps recurrently
        with autocast():
            for step_i in range(unroll_steps):
                mask = mask_batch[:, step_i]
                states, value_prefixes, values, policies, reward_hidden = model.recurrent_inference(states, action_batch[:, step_i], reward_hidden, training=True)

                beg_index = image_channel * step_i
                end_index = image_channel * (step_i + n_stack)

                # consistency loss
                gt_next_states = model.do_representation(obs_target_batch[:, beg_index:end_index])

                # projection for consistency
                dynamic_states_proj = model.do_projection(states, with_grad=True)
                gt_states_proj = model.do_projection(gt_next_states, with_grad=False)
                consistency_loss += cosine_similarity_loss(dynamic_states_proj, gt_states_proj) * mask

                # reward, value, policy loss
                if self.config.model.reward_support.type == 'symlog':
                    value_prefix_loss += symlog_loss(value_prefixes, target_value_prefixes[:, step_i]) * mask
                else:
                    value_prefix_loss += kl_loss(value_prefixes, target_value_prefixes_support[:, step_i]) * mask

                value_loss += Value_loss(values, this_target_values[:, step_i + 1], self.config) * mask

                if self.config.env.env in ['DMC', 'Gym']:
                    policy_loss_i, entropy_loss_i = continuous_loss(
                        policies, target_actions[:, step_i + 1], target_policies[:, step_i + 1],
                        target_best_actions[:, step_i + 1],
                        mask=mask,
                        distribution_type=self.config.model.policy_distribution
                    )
                    policy_loss += policy_loss_i
                    policy_entropy_loss -= entropy_loss_i
                else:
                    policy_loss_i = kl_loss(policies, target_policies[:, step_i + 1]) * mask
                    policy_loss += policy_loss_i

                # set half gradient due to two branches of states
                states.register_hook(lambda grad: grad * 0.5)

                # reset reward hidden
                if self.config.model.value_prefix and (step_i + 1) % self.config.model.lstm_horizon_len == 0:
                    reward_hidden = self.init_reward_hidden(batch_size)

        # total loss
        loss = (value_prefix_loss * self.config.train.reward_loss_coeff
                + value_loss * self.config.train.value_loss_coeff
                + policy_loss * self.config.train.policy_loss_coeff
                + consistency_loss * self.config.train.consistency_coeff)

        if self.config.env.env in ['DMC', 'Gym']:
            loss += policy_entropy_loss * self.config.train.entropy_coeff

        weighted_loss = (weights * loss).mean()

        if weighted_loss.isnan():
            import ipdb
            ipdb.set_trace()
            print('loss nan')

        # backward
        parameters = model.parameters()
        with autocast():
            weighted_loss.register_hook(lambda grad: grad * gradient_scale)
        optimizer.zero_grad()
        scaler.scale(weighted_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, self.config.train.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        if self.config.model.noisy_net:
            model.value_policy_model.reset_noise()
            target_model.value_policy_model.reset_noise()

        # log
        loss_data.update({
            'loss/total': loss.mean().item(), 'loss/weighted': weighted_loss.mean().item(),
            'loss/consistency': consistency_loss.mean().item(), 'loss/value_prefix': value_prefix_loss.mean().item(),
            'loss/value': value_loss.mean().item(), 'loss/policy': policy_loss.mean().item(),
            'loss/entropy': policy_entropy_loss.mean().item(),
        })

        other_scalar.update({
            'other_log/target_value_prefix_max': target_value_prefixes.detach().cpu().numpy().max(),
            'other_log/target_value_prefix_min': target_value_prefixes.detach().cpu().numpy().min(),
            'other_log/target_value_prefix_mean': target_value_prefixes.detach().cpu().numpy().mean(),
            'other_log/target_value_mean': target_values.detach().cpu().numpy().mean(),
            'other_log/target_value_max': target_values.detach().cpu().numpy().max(),
            'other_log/target_value_min': target_values.detach().cpu().numpy().min(),
            'other_log/mismatch_num': batch_size * (unroll_steps + 1) - mismatch_masks.sum().detach().cpu().numpy()
        })

        other_distribution.update({
            'dist/recent_priority': fresh_priority,
            'dist/weights': weights.detach().cpu().numpy().flatten(),
            'dist/prior_in_batch': prior_lst,
            'dist/indices': indices.flatten(),
            'dist/mask': mask.detach().cpu().numpy().flatten(),
            'dist/target_policy': target_policies.detach().cpu().numpy().flatten(),
        })

        scalers = [scaler]
        return scalers, (loss_data, other_scalar, other_distribution)

    def get_weights(self, model):
        if self.use_ddp:
            return get_ddp_model_weights(model)
        else:
            return model.get_weights()
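
    # adjust_lr() below does linear warm-up followed by either the cosine scheduler or
    # a stepped exponential decay. A small worked example with hypothetical config
    # values (lr=0.2, lr_warm_up=0.01, training_steps=100_000, lr_decay_rate=0.1,
    # lr_decay_steps=50_000, lr_decay_type != 'cosine'):
    #   step    500 -> lr = 0.2 * 500 / 1000 = 0.1   (warm-up)
    #   step 30_000 -> lr = 0.2 * 0.1 ** 0   = 0.2   (past warm-up, before first decay)
    #   step 80_000 -> lr = 0.2 * 0.1 ** 1   = 0.02  (after one decay step)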

    def adjust_lr(self, optimizer, step_count, scheduler):
        lr_warm_step = int(self.config.train.training_steps * self.config.optimizer.lr_warm_up)
        optimize_config = self.config.optimizer

        # adjust learning rate: warm up linearly, then decay every lr_decay_steps
        if step_count < lr_warm_step:
            lr = optimize_config.lr * step_count / lr_warm_step
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        else:
            if self.config.optimizer.lr_decay_type == 'cosine':
                if scheduler is not None:
                    scheduler.step()
                lr = scheduler.get_last_lr()[0]  # get_last_lr() returns a list
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            else:
                lr = optimize_config.lr * optimize_config.lr_decay_rate ** (
                        (step_count - lr_warm_step) // optimize_config.lr_decay_steps)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        return lr

    def log_hist(self, logger, distribution_dict, step_count):
        for key, hist in distribution_dict.items():
            table = wandb.Histogram(hist, num_bins=200)
            logger.log({key: table}, step_count)

    def transform(self, observation):
        if self.transforms is not None:
            return self.transforms(observation)
        else:
            return observation

    def build_model(self):
        raise NotImplementedError

    def update_augmentation_transform(self):
        if self.config.augmentation and self.config.env.image_based:
            self.transforms = Transforms(self.config.augmentation, image_shape=(self.obs_shape[1], self.obs_shape[2]))
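
    # get_temperature() below anneals the action-sampling temperature over training
    # when change_temperature is enabled: 1.0 for the first half of the total steps,
    # 0.5 until three quarters, then 0.25; otherwise it stays at a constant 1.0.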

    def get_temperature(self, trained_steps):
        if self.config.train.change_temperature:
            total_steps = self.config.train.training_steps + self.config.train.offline_training_steps
            # if self.config.env.env == 'Atari':
            if trained_steps < 0.5 * total_steps:     # prev 0.5
                return 1.0
            elif trained_steps < 0.75 * total_steps:  # prev 0.75
                return 0.5
            else:
                return 0.25
        else:
            return 1.0

    def init_env(self, env, max_steps):
        assert self._update
        obs = env.reset()

        traj = self.new_game(max_steps)
        stacked_obs = [obs for _ in range(self.config.env.n_stack)]
        traj.init(stacked_obs)

        return stacked_obs, traj

    def init_envs(self, envs, max_steps=None):
        assert self._update

        stacked_obs_lst, game_trajs = [], []
        # initialization for envs: repeat the initial obs n_stack times to form the first stacked observation
        for env in envs:
            stacked_obs, traj = self.init_env(env, max_steps)
            stacked_obs_lst.append(stacked_obs)
            game_trajs.append(traj)
        return stacked_obs_lst, game_trajs

    def init_reward_hidden(self, batch_size):
        if self.config.model.value_prefix:
            reward_hidden = (torch.zeros(1, batch_size, self.config.model.lstm_hidden_size).cuda(),
                             torch.zeros(1, batch_size, self.config.model.lstm_hidden_size).cuda())
        else:
            reward_hidden = None
        return reward_hidden

    def new_game(self, max_steps):
        assert self._update

        traj = GameTrajectory(**self.config.env, **self.config.rl, **self.config.model, trajectory_size=max_steps)
        if max_steps is None:
            traj.set_inf_len()
        return traj

    def is_finished(self, trained_steps):
        if trained_steps >= self.config.train.training_steps + self.config.train.offline_training_steps:
            time.sleep(1)
            return True
        else:
            return False
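

# train_ddp below is a Ray remote task. A minimal launch sketch (hypothetical driver
# code, not part of this file; assumes `agent`, `replay_buffer`, `storage`,
# `batch_storage` and `logger` handles already exist):
#
#     workers = [
#         train_ddp.remote(agent, rank, replay_buffer, storage, batch_storage, logger)
#         for rank in range(agent.config.ddp.training_size)
#     ]
#     final_weights, model = ray.get(workers)[0]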


@ray.remote(num_gpus=0.55)
def train_ddp(agent, rank, replay_buffer, storage, batch_storage, logger):
    print(f'training_rank={rank}')
    if rank == 0:
        wandb_name = agent.config.env.game + '-' + agent.config.wandb.tag
        logger = wandb.init(
            name=wandb_name,
            project=agent.config.wandb.project,
            # config=config,
        )
    assert agent._update
    # update image augmentation transform
    agent.update_augmentation_transform()

    # save path
    model_path = Path(agent.config.save_path) / 'models'
    model_path.mkdir(parents=True, exist_ok=True)

    is_main_process = (rank == 0)
    if is_main_process:
        train_logger = logging.getLogger('Train')
        eval_logger = logging.getLogger('Eval')

        train_logger.info('config: {}'.format(agent.config))
        train_logger.info('save model in: {}'.format(model_path))

    # prepare model
    model = agent.build_model().cuda()
    target_model = agent.build_model().cuda()
    # load model
    load_path = agent.config.train.load_model_path
    if os.path.exists(load_path):
        if is_main_process:
            train_logger.info('resume model from path: {}'.format(load_path))
        weights = torch.load(load_path)
        storage.set_weights.remote(weights, 'self_play')
        storage.set_weights.remote(weights, 'reanalyze')
        storage.set_weights.remote(weights, 'latest')
        model.load_state_dict(weights)
        target_model.load_state_dict(weights)

    # DDP
    if agent.use_ddp:
        DDP_setup(rank=rank, world_size=agent.config.ddp.world_size, training_size=agent.config.ddp.training_size, address='127.0.0.1')
        model = DDP(model, device_ids=[rank])

    if int(torch.__version__[0]) == 2:
        model = torch.compile(model)
        target_model = torch.compile(target_model)
    model.train()
    target_model.eval()

    # optimizer
    if agent.config.optimizer.type == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=agent.config.optimizer.lr,
                              weight_decay=agent.config.optimizer.weight_decay,
                              momentum=agent.config.optimizer.momentum)
    elif agent.config.optimizer.type == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=agent.config.optimizer.lr,
                               weight_decay=agent.config.optimizer.weight_decay)
    elif agent.config.optimizer.type == 'AdamW':
        optimizer = optim.AdamW(model.parameters(),
                                lr=agent.config.optimizer.lr,
                                weight_decay=agent.config.optimizer.weight_decay)
    else:
        raise NotImplementedError

    if agent.config.optimizer.lr_decay_type == 'cosine':
        max_steps = agent.config.train.training_steps - int(agent.config.train.training_steps * agent.config.optimizer.lr_warm_up)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps * 3, eta_min=0)
    elif agent.config.optimizer.lr_decay_type == 'full_cosine':
        max_steps = agent.config.train.training_steps - int(agent.config.train.training_steps * agent.config.optimizer.lr_warm_up)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps // 2, eta_min=0)
    else:
        scheduler = None

    scaler = GradScaler()

    # wait until enough data has been collected to start training
    while not (ray.get(replay_buffer.get_transition_num.remote()) >= agent.config.train.start_transitions):
        time.sleep(1)
    print('[Train] Begin training...')

    # set signals for other workers
    if is_main_process:
        storage.set_start_signal.remote()
    step_count = 0

    # Note: the interval of the current model and the target model is between x and 2x. (x = target_model_interval)
    # recent_weights is the param of the target model
    recent_weights = agent.get_weights(model)

    # some logs
    total_time = 0
    total_steps = agent.config.train.training_steps + agent.config.train.offline_training_steps
    if is_main_process:
        pb = tqdm(np.arange(total_steps), leave=True)

    # training loop
    self_play_return = 0.
    traj_num, transition_num = 0, 0
    eval_score, eval_best_score = 0., 0.

    while not agent.is_finished(step_count):
        start_time = time.time()

        # obtain a batch
        batch = batch_storage.pop()
        end_time1 = time.time()
        if batch is None:
            time.sleep(0.3)
            # print('batch is None')
            continue

        # adjust learning rate
        if is_main_process:
            storage.increase_counter.remote()
        lr = agent.adjust_lr(optimizer, step_count, scheduler)

        if is_main_process and step_count % 30 == 0:
            latest_weights = agent.get_weights(model)
            ray.get(storage.set_weights.remote(latest_weights, 'latest'))

        # update model for self-play
        if is_main_process and step_count % agent.config.train.self_play_update_interval == 0:
            weights = agent.get_weights(model)
            storage.set_weights.remote(weights, 'self_play')

        # update model for reanalyzing
        if is_main_process and step_count % agent.config.train.reanalyze_update_interval == 0:
            storage.set_weights.remote(recent_weights, 'reanalyze')
            target_model.set_weights(recent_weights)
            target_model.cuda()
            target_model.eval()
            recent_weights = agent.get_weights(model)

        scalers, log_data = agent.update_weights(model.module, batch, optimizer, replay_buffer, scaler, step_count, target_model=target_model)
        scaler = scalers[0]

        loss_data, other_scalar, other_distribution = log_data

        # TODO: maybe this barrier can be removed
        if agent.config.ddp.training_size > 1 or agent.config.ddp.world_size > 1:
            dist.barrier()

        # save models
        if is_main_process and step_count % agent.config.train.save_ckpt_interval == 0:
            cur_model_path = model_path / 'model_{}.p'.format(step_count)
            torch.save(agent.get_weights(model), cur_model_path)

        end_time = time.time()
        total_time += end_time - start_time

        step_count += 1
        avg_time = total_time / step_count
        log_scalars = {}
        log_distribution = {}

        pb_interval = 50
        if is_main_process and step_count % pb_interval == 0:
            left_steps = (agent.config.train.training_steps + agent.config.train.offline_training_steps - step_count)
            left_time = (left_steps * avg_time) / 3600
            batch_queue_size = batch_storage.get_len()
            train_log_str = '[Train] {}, step {}/{}, {:.3f}h left. lr={:.3f}, avg time={:.3f}s, batchQ={}, ' \
                            'self-play return={:.3f}, collect {}/{:.3f}k, eval score={:.3f}/{:.3f}. ' \
                            'Loss: reward={:.3f}, value={:.3f}, policy={:.3f}, ' \
                            'consistency={:.3f}, entropy={:.3f}' \
                            ''.format(agent.config.env.game, step_count, total_steps, left_time, lr, avg_time,
                                      batch_queue_size, self_play_return, traj_num, transition_num / 1000,
                                      eval_score, eval_best_score, loss_data['loss/value_prefix'],
                                      loss_data['loss/value'], loss_data['loss/policy'],
                                      loss_data['loss/consistency'], loss_data['loss/entropy'])
            # print(f'target policy={batch[-1][-1][0, 0]}')
            pb.set_description(train_log_str)

            pb.update(pb_interval)

            log_scalars.update({
                'train/step_per_second (s)': end_time - start_time,
                'train/total time (h)': total_time / 3600,
                'train/avg time (s)': avg_time,
                'train/lr': lr,
                'train/queue size': batch_queue_size
            })

        if is_main_process and step_count % agent.config.log.log_interval == 0:
            # train_logger.info(train_log_str)
            # self-play statistics
            eval_scalar, remote_scalar, remote_distribution = ray.get(storage.get_log.remote())
            log_scalars.update(remote_scalar)
            log_distribution.update(remote_distribution)

            if remote_scalar.get('self_play/episode_return'):
                self_play_return = remote_scalar.get('self_play/episode_return')
            if len(eval_scalar) > 0:
                # TODO: fix the counter issue
                # logger.log(eval_scalar, eval_counter)
                logger.log(eval_scalar, step_count)

                eval_score = eval_scalar['eval/mean_score']
                min_score, max_score = eval_scalar['eval/min_score'], eval_scalar['eval/max_score']
                eval_counter, eval_best_score = ray.get([storage.get_eval_counter.remote(), storage.get_best_score.remote()])

                eval_log_str = 'Eval {} at step {}, score = {:.3f}(min: {:.3f}, max: {:.3f}), ' \
                               'best score over past evaluation = {:.3f}' \
                               ''.format(agent.config.env.game, eval_counter, eval_score, min_score, max_score,
                                         eval_best_score)
                eval_logger.info(eval_log_str)
                print('[Eval] ', eval_log_str)

            # replay statistics
            traj_num, transition_num, total_priorities = ray.get([
                replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(), replay_buffer.get_priorities.remote()
            ])
            log_scalars.update({
                'buffer/total_episode_num': traj_num,
                'buffer/total_transition_num': transition_num
            })
            log_distribution.update({
                'dist/priorities_in_buffer': total_priorities,
            })
            log_distribution.update(other_distribution)
            agent.log_hist(logger, log_distribution, step_count)

        if step_count % 20000 == 0 and agent.config.train.periodic_reset:
            print('-------------------------reset network------------------------------')
            model = agent.periodic_reset_model(model)

        # training statistics
        log_scalars.update(loss_data)
        log_scalars.update(other_scalar)
        if is_main_process and step_count > 100 and step_count % 1000 == 0:
            logger.log(log_scalars, step_count)

        traj_num, transition_num, total_priorities = ray.get([
            replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(),
            replay_buffer.get_priorities.remote()
        ])
        log_distribution.update({
            'dist/priorities_in_buffer': total_priorities,
        })
        log_distribution.update(other_distribution)

    final_weights = agent.get_weights(model)
    storage.set_weights.remote(final_weights, 'self_play')

    return final_weights, model