# “Shengjiewang-Jason” 1367bca203 first commit
# 2024-06-07 16:02:01 +08:00
#
# 925 lines
# 40 KiB
# Python
# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import os
import time
# import SMOS
import ray
import torch
import wandb
import logging
import random
import numpy as np
import torch.optim as optim
import torch.distributed as dist
import torch.nn.functional as F
from pathlib import Path
from tqdm.auto import tqdm
from torch.nn import L1Loss
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from ez.utils.format import get_ddp_model_weights, DiscreteSupport, symexp
from ez.utils.loss import kl_loss, cosine_similarity_loss, continuous_loss, symlog_loss, Value_loss
from ez.data.trajectory import GameTrajectory
from ez.data.augmentation import Transforms
def DDP_setup(**kwargs):
    """Initialise the default ``torch.distributed`` process group for this process.

    Keyword Args:
        address: master node address, exported as ``MASTER_ADDR`` when given.
        port: optional master port, exported as ``MASTER_PORT`` when given.
            Otherwise ``MASTER_PORT`` must already be present in the environment,
            because the default ``env://`` rendezvous reads it from there.
        rank: global rank of this process.
        world_size, training_size: their product is passed as the total world size.

    The NCCL backend is tried first. If it cannot be initialised, gloo is used instead.
    """
    # set master node
    address = kwargs.get('address')
    if address is not None:
        os.environ['MASTER_ADDR'] = str(address)
    port = kwargs.get('port')
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)

    rank = kwargs.get('rank')
    world_size = kwargs.get('world_size') * kwargs.get('training_size')

    # initialize the process group; only fall back on real errors, never on
    # KeyboardInterrupt/SystemExit (which a bare ``except:`` would swallow)
    try:
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
    except Exception as err:
        print(f'DDP: nccl init failed ({err}), falling back to gloo')
        dist.init_process_group('gloo', rank=rank, world_size=world_size)
    print(f'DDP backend={dist.get_backend()}')
class Agent:
    """Base agent holding the training loop, the loss/update step and game helpers.

    ``update_config`` and ``build_model`` raise ``NotImplementedError`` here and must be
    provided by a subclass. ``train``, ``init_env``, ``init_envs`` and ``new_game`` assert
    ``self._update``. Presumably the subclass sets it once ``obs_shape`` and the other
    attributes are filled in (TODO confirm in subclasses).
    """

    def __init__(self, config):
        self.config = config
        self.transforms = None
        self.obs_shape = None
        self.input_shape = None
        self.action_space_size = None
        self._update = False
        # DDP is needed as soon as there is more than one worker or trainer.
        self.use_ddp = config.ddp.world_size > 1 or config.ddp.training_size > 1

    def update_config(self):
        raise NotImplementedError

    def train(self, rank, replay_buffer, storage, batch_storage, logger):
        """Run the training loop on training process ``rank``.

        Args:
            rank: index of this training process. Rank 0 is the main process: it
                publishes weights to ``storage``, logs and saves checkpoints.
            replay_buffer: ray actor handle of the replay buffer.
            storage: ray actor handle used to share weights, counters and logs.
            batch_storage: object whose ``pop()`` returns a batch or None and whose
                ``get_len()`` reports the queue size.
            logger: object with a ``log(dict, step)`` method (used on rank 0 only).

        Returns:
            ``(final_weights, model)``, where ``final_weights`` is None on non-main ranks.
        """
        assert self._update
        # update image augmentation transform
        self.update_augmentation_transform()

        # save path
        model_path = Path(self.config.save_path) / 'models'
        model_path.mkdir(parents=True, exist_ok=True)

        is_main_process = (rank == 0)
        if is_main_process:
            train_logger = logging.getLogger('Train')
            eval_logger = logging.getLogger('Eval')
            train_logger.info('config: {}'.format(self.config))
            train_logger.info('save model in: {}'.format(model_path))

        # prepare model
        model = self.build_model().cuda()
        target_model = self.build_model().cuda()

        # resume from a checkpoint when one is configured and present on disk
        load_path = self.config.train.load_model_path
        if load_path and os.path.exists(load_path):
            if is_main_process:
                train_logger.info('resume model from path: {}'.format(load_path))
            weights = torch.load(load_path)
            storage.set_weights.remote(weights, 'self_play')
            storage.set_weights.remote(weights, 'reanalyze')
            storage.set_weights.remote(weights, 'latest')
            model.load_state_dict(weights)
            target_model.load_state_dict(weights)

        # DDP
        if self.use_ddp:
            model = DDP(model, device_ids=[rank])
        # torch.compile exists from torch 2.x on; parse the major version properly
        if int(torch.__version__.split('.')[0]) >= 2:
            model = torch.compile(model)
            target_model = torch.compile(target_model)
        model.train()
        target_model.eval()

        optimizer = self._build_optimizer(model)
        scheduler = self._build_scheduler(optimizer)
        scaler = GradScaler()

        # wait until collecting enough data to start
        while ray.get(replay_buffer.get_transition_num.remote()) < self.config.train.start_transitions:
            time.sleep(1)

        print('[Train] Begin training...')
        # set signals for other workers
        if is_main_process:
            storage.set_start_signal.remote()

        step_count = 0
        # Note: the interval of the current model and the target model is between x and 2x. (x = target_model_interval)
        # recent_weights is the param of the target model
        recent_weights = self.get_weights(model)

        # some logs
        total_time = 0
        total_steps = self.config.train.training_steps + self.config.train.offline_training_steps
        if is_main_process:
            pb = tqdm(np.arange(total_steps), leave=True)

        self_play_return = 0.
        traj_num, transition_num = 0, 0
        eval_score, eval_best_score = 0., 0.
        prev_eval_counter = -1
        eval_counter = 0
        while not self.is_finished(step_count):
            start_time = time.time()
            # obtain a batch
            batch = batch_storage.pop()
            if batch is None:
                time.sleep(0.3)
                continue

            # adjust learning rate
            if is_main_process:
                storage.increase_counter.remote()
            lr = self.adjust_lr(optimizer, step_count, scheduler)

            if is_main_process and step_count % 30 == 0:
                storage.set_weights.remote(self.get_weights(model), 'latest')

            # update model for self-play
            if is_main_process and step_count % self.config.train.self_play_update_interval == 0:
                storage.set_weights.remote(self.get_weights(model), 'self_play')

            # update model for reanalyzing: the target model takes the previous snapshot
            if is_main_process and step_count % self.config.train.reanalyze_update_interval == 0:
                storage.set_weights.remote(recent_weights, 'reanalyze')
                target_model.set_weights(recent_weights)
                target_model.cuda()
                target_model.eval()
                recent_weights = self.get_weights(model)

            # NOTE(review): prev_eval_counter is never updated, so this wait only fires if
            # eval_counter is -1; and since eval_counter is refreshed only later in the
            # loop (after step_count advances), it would then spin forever.
            if step_count % self.config.train.eval_interval == 0:
                if eval_counter == prev_eval_counter:
                    time.sleep(1)
                    continue

            scalers, log_data = self.update_weights(model, batch, optimizer, replay_buffer, scaler, step_count, target_model=target_model)
            scaler = scalers[0]
            loss_data, other_scalar, other_distribution = log_data

            # TODO: maybe this barrier can be removed
            if self.config.ddp.training_size > 1 or self.config.ddp.world_size > 1:
                dist.barrier()

            # save models
            if is_main_process and step_count % self.config.train.save_ckpt_interval == 0:
                cur_model_path = model_path / 'model_{}.p'.format(step_count)
                torch.save(self.get_weights(model), cur_model_path)

            end_time = time.time()
            total_time += end_time - start_time
            step_count += 1
            avg_time = total_time / step_count
            log_scalars = {}
            log_distribution = {}

            pb_interval = 50
            if is_main_process and step_count % pb_interval == 0:
                left_steps = total_steps - step_count
                left_time = (left_steps * avg_time) / 3600
                batch_queue_size = batch_storage.get_len()
                train_log_str = (
                    '[Train] {}, step {}/{}, {:.3f}h left. lr={:.3f}, avg time={:.3f}s, batchQ={}, '
                    'self-play return={:.3f}, collect {}/{:.3f}k, eval score={:.3f}/{:.3f}. '
                    'Loss: reward={:.3f}, value={:.3f}, policy={:.3f}, '
                    'consistency={:.3f}, entropy={:.3f}'
                ).format(self.config.env.game, step_count, total_steps, left_time, lr, avg_time,
                         batch_queue_size, self_play_return, traj_num, transition_num / 1000,
                         eval_score, eval_best_score, loss_data['loss/value_prefix'],
                         loss_data['loss/value'], loss_data['loss/policy'],
                         loss_data['loss/consistency'], loss_data['loss/entropy'])
                pb.set_description(train_log_str)
                pb.update(pb_interval)
                log_scalars.update({
                    'train/step_per_second (s)': end_time - start_time,
                    'train/total time (h)': total_time / 3600,
                    'train/avg time (s)': avg_time,
                    'train/lr': lr,
                    'train/queue size': batch_queue_size
                })

            if is_main_process and step_count % self.config.log.log_interval == 0:
                # self-play statistics
                eval_scalar, remote_scalar, remote_distribution = ray.get(storage.get_log.remote())
                log_scalars.update(remote_scalar)
                log_distribution.update(remote_distribution)
                episode_return = remote_scalar.get('self_play/episode_return')
                if episode_return is not None:
                    self_play_return = episode_return
                if len(eval_scalar) > 0:
                    eval_score = eval_scalar['eval/mean_score']
                    min_score, max_score = eval_scalar['eval/min_score'], eval_scalar['eval/max_score']
                    eval_counter, eval_best_score = ray.get([storage.get_eval_counter.remote(), storage.get_best_score.remote()])
                    eval_log_str = (
                        'Eval {} at step {}, score = {:.3f}(min: {:.3f}, max: {:.3f}), '
                        'best score over past evaluation = {:.3f}'
                    ).format(self.config.env.game, eval_counter, eval_score, min_score, max_score,
                             eval_best_score)
                    eval_logger.info(eval_log_str)
                    # TODO: fix the counter issue (log against eval_counter instead of step_count)
                    logger.log(eval_scalar, step_count)
                    print('[Eval] ', eval_log_str)
                # replay statistics
                traj_num, transition_num, total_priorities = ray.get([
                    replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(),
                    replay_buffer.get_priorities.remote()
                ])
                log_scalars.update({
                    'buffer/total_episode_num': traj_num,
                    'buffer/total_transition_num': transition_num
                })
                log_distribution.update({
                    'dist/priorities_in_buffer': total_priorities,
                })
                log_distribution.update(other_distribution)
                self.log_hist(logger, log_distribution, step_count)

            if step_count % 20000 == 0 and self.config.train.periodic_reset:
                print('-------------------------reset network------------------------------')
                model = self.periodic_reset_model(model)

            # training statistics (only the main process owns the logger)
            log_scalars.update(loss_data)
            log_scalars.update(other_scalar)
            if is_main_process and step_count > 500 and step_count % 1000 == 0:
                logger.log(log_scalars, step_count)
                # refresh the buffer counters shown in the progress bar
                traj_num, transition_num = ray.get([
                    replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote()
                ])

        if is_main_process:
            final_weights = self.get_weights(model)
            storage.set_weights.remote(final_weights, 'self_play')
        else:
            final_weights = None
        return final_weights, model

    def _build_optimizer(self, model):
        """Create the optimizer named by ``config.optimizer.type`` over ``model``'s parameters."""
        opt_config = self.config.optimizer
        if opt_config.type == 'SGD':
            return optim.SGD(model.parameters(),
                             lr=opt_config.lr,
                             weight_decay=opt_config.weight_decay,
                             momentum=opt_config.momentum)
        if opt_config.type == 'Adam':
            return optim.Adam(model.parameters(),
                              lr=opt_config.lr,
                              weight_decay=opt_config.weight_decay)
        if opt_config.type == 'AdamW':
            return optim.AdamW(model.parameters(),
                               lr=opt_config.lr,
                               weight_decay=opt_config.weight_decay)
        raise NotImplementedError('unsupported optimizer type: {}'.format(opt_config.type))

    def _build_scheduler(self, optimizer):
        """Create the post-warm-up cosine scheduler for ``lr_decay_type``, or return None."""
        opt_config = self.config.optimizer
        training_steps = self.config.train.training_steps
        if opt_config.lr_decay_type == 'cosine':
            max_steps = training_steps - int(training_steps * opt_config.lr_warm_up)
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps * 3, eta_min=0)
        if opt_config.lr_decay_type == 'full_cosine':
            max_steps = training_steps - int(training_steps * opt_config.lr_warm_up)
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps // 2, eta_min=0)
        return None

    def reset_network(self, network):
        """Re-initialise ``network`` and every sub-module that defines ``reset_parameters``.

        ``modules()`` yields the module itself plus all descendants, so a bare layer
        (e.g. a BatchNorm or Linear passed directly) and layers nested inside custom
        blocks are reset as well. ``children()`` only yielded the direct children.
        """
        for module in network.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def periodic_reset_model(self, model):
        """Reset the value/policy heads of ``model.value_policy_model`` in place and return ``model``."""
        if self.config.env.image_based:
            # reset prediction_backbone
            self.reset_network(model.value_policy_model.resblocks)
            # reset policy
            self.reset_network(model.value_policy_model.conv1x1_policy)
            self.reset_network(model.value_policy_model.bn_policy)
            self.reset_network(model.value_policy_model.fc_policy)
            # reset value
            self.reset_network(model.value_policy_model.conv1x1_values)
            self.reset_network(model.value_policy_model.bn_values)
            self.reset_network(model.value_policy_model.fc_values)
        else:
            self.reset_network(model.value_policy_model.val_resblock)
            self.reset_network(model.value_policy_model.pi_resblock)
            self.reset_network(model.value_policy_model.val_ln)
            self.reset_network(model.value_policy_model.pi_ln)
            self.reset_network(model.value_policy_model.val_net)
            self.reset_network(model.value_policy_model.pi_net)
        return model

    # @profile
    def update_weights(self, model, batch, optimizer, replay_buffer, scaler, step_count, target_model=None):
        """Run one optimisation step on ``batch`` and return the scaler plus logging data.

        ``batch`` is ``(inputs, targets)``:
        ``inputs = (obs, actions, masks, indices, weights, make_time, priors)`` and
        ``targets = (value_prefixes, values, actions, policies, best_actions,
        top_value_masks, mismatch_masks, search_values)``, all numpy arrays.
        As a side effect, fresh priorities for the first step are sent to ``replay_buffer``.

        Returns:
            ``([scaler], (loss_data, other_scalar, other_distribution))``
        """
        if target_model is not None:
            target_model.eval()
        # init
        batch_size = self.config.train.batch_size
        image_channel = self.config.env.obs_shape[0] if self.config.env.image_based else self.config.env.obs_shape
        unroll_steps = self.config.rl.unroll_steps
        n_stack = self.config.env.n_stack
        gradient_scale = 1. / unroll_steps
        reward_hidden = self.init_reward_hidden(batch_size)
        is_continuous = self.config.env.env in ['DMC', 'Gym']
        loss_data = {}
        other_scalar = {}
        other_distribution = {}

        # obtain the batch data
        inputs_batch, targets_batch = batch
        obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time, prior_lst = inputs_batch
        target_value_prefixes, target_values, target_actions, target_policies, target_best_actions, \
            top_value_masks, mismatch_masks, search_values = targets_batch
        target_value_prefixes = target_value_prefixes[:, :unroll_steps]

        # obs_batch_raw: [s_{t - stack} ... s_{t} ... s_{t + unroll}]
        if self.config.env.image_based:
            obs_batch_raw = torch.from_numpy(obs_batch_ori).cuda().float() / 255.
        else:
            obs_batch_raw = torch.from_numpy(obs_batch_ori).cuda().float()
        obs_batch = obs_batch_raw[:, 0: n_stack * image_channel]  # obs_batch: current observation
        obs_target_batch = obs_batch_raw[:, image_channel:]  # obs_target_batch: observation of next steps

        # augmentation
        obs_batch = self.transform(obs_batch)
        obs_target_batch = self.transform(obs_target_batch)

        # others to gpu
        if is_continuous:
            action_batch = torch.from_numpy(action_batch).float().cuda()
        else:
            action_batch = torch.from_numpy(action_batch).cuda().unsqueeze(-1).long()
        mask_batch = torch.from_numpy(mask_batch).cuda().float()
        weights = torch.from_numpy(weights_lst).cuda().float()

        max_value_target = np.array([target_values, search_values]).max(0)

        target_value_prefixes = torch.from_numpy(target_value_prefixes).cuda().float()
        target_values = torch.from_numpy(target_values).cuda().float()
        target_actions = torch.from_numpy(target_actions).cuda().float()
        target_policies = torch.from_numpy(target_policies).cuda().float()
        target_best_actions = torch.from_numpy(target_best_actions).cuda().float()
        top_value_masks = torch.from_numpy(top_value_masks).cuda().float()
        mismatch_masks = torch.from_numpy(mismatch_masks).cuda().float()
        search_values = torch.from_numpy(search_values).cuda().float()
        max_value_target = torch.from_numpy(max_value_target).cuda().float()

        # transform value and reward to support
        target_value_prefixes_support = DiscreteSupport.scalar_to_vector(target_value_prefixes, **self.config.model.reward_support)

        with autocast():
            states, values, policies = model.initial_inference(obs_batch, training=True)

        if self.config.model.value_support.type == 'symlog':
            scaled_value = symexp(values).min(0)[0]
        else:
            scaled_value = DiscreteSupport.vector_to_scalar(values, **self.config.model.value_support).min(0)[0]
        if is_continuous:
            scaled_value = scaled_value.clip(0, 1e5)

        # loss of first step
        # multi options (Value Loss)
        value_target = self.config.train.value_target
        if value_target == 'sarsa':
            this_target_values = target_values
        elif value_target == 'search':
            this_target_values = search_values
        elif value_target == 'mixed':
            if step_count < self.config.train.start_use_mix_training_steps:
                this_target_values = target_values
            else:
                this_target_values = target_values * top_value_masks.unsqueeze(1).repeat(1, unroll_steps + 1) \
                    + search_values * (1 - top_value_masks).unsqueeze(1).repeat(1, unroll_steps + 1)
        elif value_target == 'max':
            this_target_values = max_value_target
        else:
            raise NotImplementedError

        # update priority
        fresh_priority = L1Loss(reduction='none')(scaled_value.squeeze(-1), this_target_values[:, 0]).detach().cpu().numpy()
        fresh_priority += self.config.priority.min_prior
        replay_buffer.update_priorities.remote(indices, fresh_priority, make_time)

        value_loss = torch.zeros(batch_size).cuda()
        value_loss += Value_loss(values, this_target_values[:, 0], self.config)
        if is_continuous:
            policy_loss, entropy_loss = continuous_loss(
                policies, target_actions[:, 0], target_policies[:, 0],
                target_best_actions[:, 0],
                distribution_type=self.config.model.policy_distribution
            )
            # the first half of the policy output is logged as mu, the second as sigma
            mu = policies[:, :policies.shape[-1] // 2].detach().cpu().numpy().flatten()
            sigma = policies[:, policies.shape[-1] // 2:].detach().cpu().numpy().flatten()
            other_distribution.update({
                'dist/policy_mu': mu,
                'dist/policy_sigma': sigma,
            })
        else:
            policy_loss = kl_loss(policies, target_policies[:, 0])
            entropy_loss = torch.zeros(batch_size).cuda()

        value_prefix_loss = torch.zeros(batch_size).cuda()
        consistency_loss = torch.zeros(batch_size).cuda()
        policy_entropy_loss = torch.zeros(batch_size).cuda()
        policy_entropy_loss -= entropy_loss

        # unroll k steps recurrently
        with autocast():
            for step_i in range(unroll_steps):
                mask = mask_batch[:, step_i]
                states, value_prefixes, values, policies, reward_hidden = model.recurrent_inference(states, action_batch[:, step_i], reward_hidden, training=True)

                beg_index = image_channel * step_i
                end_index = image_channel * (step_i + n_stack)

                # consistency loss
                gt_next_states = model.do_representation(obs_target_batch[:, beg_index:end_index])
                # projection for consistency
                dynamic_states_proj = model.do_projection(states, with_grad=True)
                gt_states_proj = model.do_projection(gt_next_states, with_grad=False)
                consistency_loss += cosine_similarity_loss(dynamic_states_proj, gt_states_proj) * mask

                # reward, value, policy loss
                if self.config.model.reward_support.type == 'symlog':
                    value_prefix_loss += symlog_loss(value_prefixes, target_value_prefixes[:, step_i]) * mask
                else:
                    value_prefix_loss += kl_loss(value_prefixes, target_value_prefixes_support[:, step_i]) * mask

                value_loss += Value_loss(values, this_target_values[:, step_i + 1], self.config) * mask

                if is_continuous:
                    policy_loss_i, entropy_loss_i = continuous_loss(
                        policies, target_actions[:, step_i + 1], target_policies[:, step_i + 1],
                        target_best_actions[:, step_i + 1],
                        mask=mask,
                        distribution_type=self.config.model.policy_distribution
                    )
                    policy_loss += policy_loss_i
                    policy_entropy_loss -= entropy_loss_i
                else:
                    policy_loss_i = kl_loss(policies, target_policies[:, step_i + 1]) * mask
                    policy_loss += policy_loss_i

                # set half gradient due to two branches of states
                states.register_hook(lambda grad: grad * 0.5)

                # reset reward hidden
                if self.config.model.value_prefix and (step_i + 1) % self.config.model.lstm_horizon_len == 0:
                    reward_hidden = self.init_reward_hidden(batch_size)

        # total loss
        loss = (value_prefix_loss * self.config.train.reward_loss_coeff
                + value_loss * self.config.train.value_loss_coeff
                + policy_loss * self.config.train.policy_loss_coeff
                + consistency_loss * self.config.train.consistency_coeff)
        if is_continuous:
            loss += policy_entropy_loss * self.config.train.entropy_coeff

        weighted_loss = (weights * loss).mean()
        if weighted_loss.isnan():
            # Previously this dropped into ipdb, which blocks a headless worker. With NaN
            # gradients, GradScaler.step() below skips the optimizer update for this step.
            print(f'[Train] loss nan at step {step_count}')

        # backward: outside autocast, as recommended by the PyTorch AMP docs
        weighted_loss.register_hook(lambda grad: grad * gradient_scale)
        optimizer.zero_grad()
        scaler.scale(weighted_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.train.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        if self.config.model.noisy_net:
            model.value_policy_model.reset_noise()
            if target_model is not None:
                target_model.value_policy_model.reset_noise()

        # log
        loss_data.update({
            'loss/total': loss.mean().item(), 'loss/weighted': weighted_loss.mean().item(),
            'loss/consistency': consistency_loss.mean().item(), 'loss/value_prefix': value_prefix_loss.mean().item(),
            'loss/value': value_loss.mean().item(), 'loss/policy': policy_loss.mean().item(),
            'loss/entropy': policy_entropy_loss.mean().item(),
        })
        target_value_prefixes_np = target_value_prefixes.detach().cpu().numpy()
        target_values_np = target_values.detach().cpu().numpy()
        other_scalar.update({
            'other_log/target_value_prefix_max': target_value_prefixes_np.max(),
            'other_log/target_value_prefix_min': target_value_prefixes_np.min(),
            'other_log/target_value_prefix_mean': target_value_prefixes_np.mean(),
            'other_log/target_value_mean': target_values_np.mean(),
            'other_log/target_value_max': target_values_np.max(),
            'other_log/target_value_min': target_values_np.min(),
            'other_log/mismatch_num': batch_size * (unroll_steps + 1) - mismatch_masks.sum().detach().cpu().numpy()
        })
        other_distribution.update({
            'dist/recent_priority': fresh_priority,
            'dist/weights': weights.detach().cpu().numpy().flatten(),
            'dist/prior_in_batch': prior_lst,
            'dist/indices': indices.flatten(),
            # mask of the last unrolled step
            'dist/mask': mask.detach().cpu().numpy().flatten(),
            'dist/target_policy': target_policies.detach().cpu().numpy().flatten(),
        })
        scalers = [scaler]
        return scalers, (loss_data, other_scalar, other_distribution)

    def get_weights(self, model):
        """Return the model weights, unwrapping DDP via ``get_ddp_model_weights`` when used."""
        if self.use_ddp:
            return get_ddp_model_weights(model)
        else:
            return model.get_weights()

    def adjust_lr(self, optimizer, step_count, scheduler):
        """Set every param group's learning rate for ``step_count`` and return it.

        For the first ``lr_warm_up * training_steps`` steps the rate grows linearly
        from 0 to ``optimizer.lr``. Afterwards:

        - for the ``'cosine'`` / ``'full_cosine'`` decay types, ``scheduler`` is stepped
          (both types get a scheduler in ``_build_scheduler``);
        - otherwise step decay is used:
          ``lr * lr_decay_rate ** ((step_count - warm_steps) // lr_decay_steps)``.
        """
        optimize_config = self.config.optimizer
        lr_warm_step = int(self.config.train.training_steps * optimize_config.lr_warm_up)
        if step_count < lr_warm_step:
            lr = optimize_config.lr * step_count / lr_warm_step
        elif scheduler is not None and optimize_config.lr_decay_type in ('cosine', 'full_cosine'):
            scheduler.step()
            lr = scheduler.get_last_lr()[0]  # get_last_lr returns a list
        else:
            lr = optimize_config.lr * optimize_config.lr_decay_rate ** (
                (step_count - lr_warm_step) // optimize_config.lr_decay_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def log_hist(self, logger, distribution_dict, step_count):
        """Log each entry of ``distribution_dict`` as a 200-bin wandb histogram."""
        for key, hist in distribution_dict.items():
            table = wandb.Histogram(hist, num_bins=200)
            logger.log({key: table}, step_count)

    def transform(self, observation):
        """Apply the augmentation transforms if configured, otherwise return the input."""
        if self.transforms is not None:
            return self.transforms(observation)
        else:
            return observation

    def build_model(self):
        raise NotImplementedError

    def update_augmentation_transform(self):
        """Build ``self.transforms`` for image-based envs when augmentation is configured."""
        if self.config.augmentation and self.config.env.image_based:
            self.transforms = Transforms(self.config.augmentation, image_shape=(self.obs_shape[1], self.obs_shape[2]))

    def get_temperature(self, trained_steps):
        """Return the temperature for ``trained_steps``.

        With ``change_temperature`` the schedule is 1.0 below 50% of total steps,
        0.5 below 75%, then 0.25. Without it, the temperature is always 1.0.
        """
        if self.config.train.change_temperature:
            total_steps = self.config.train.training_steps + self.config.train.offline_training_steps
            if trained_steps < 0.5 * total_steps:  # prev 0.5
                return 1.0
            elif trained_steps < 0.75 * total_steps:  # prev 0.75
                return 0.5
            else:
                return 0.25
        else:
            return 1.0

    def init_env(self, env, max_steps):
        """Reset ``env`` and start a new trajectory seeded with ``n_stack`` copies of the first obs."""
        assert self._update
        obs = env.reset()
        traj = self.new_game(max_steps)
        stacked_obs = [obs for _ in range(self.config.env.n_stack)]
        traj.init(stacked_obs)
        return stacked_obs, traj

    def init_envs(self, envs, max_steps=None):
        """Initialise every env in ``envs``; return the stacked observations and trajectories."""
        assert self._update
        stacked_obs_lst, game_trajs = [], []
        # each env's stack holds n_stack copies of its reset observation
        for env in envs:
            stacked_obs, traj = self.init_env(env, max_steps)
            stacked_obs_lst.append(stacked_obs)
            game_trajs.append(traj)
        return stacked_obs_lst, game_trajs

    def init_reward_hidden(self, batch_size):
        """Return zero hidden state for the value-prefix model, or None without value_prefix.

        The state is a pair of CUDA zero tensors shaped ``(1, batch_size, lstm_hidden_size)``
        (presumably the LSTM hidden/cell state — TODO confirm against the model).
        """
        if self.config.model.value_prefix:
            reward_hidden = (torch.zeros(1, batch_size, self.config.model.lstm_hidden_size).cuda(),
                             torch.zeros(1, batch_size, self.config.model.lstm_hidden_size).cuda())
        else:
            reward_hidden = None
        return reward_hidden

    def new_game(self, max_steps):
        """Create a GameTrajectory of size ``max_steps``; unbounded when ``max_steps`` is None."""
        assert self._update
        traj = GameTrajectory(**self.config.env, **self.config.rl, **self.config.model, trajectory_size=max_steps)
        if max_steps is None:
            traj.set_inf_len()
        return traj

    def is_finished(self, trained_steps):
        """Return True (after a 1 s sleep) once ``trained_steps`` reaches the total step budget."""
        if trained_steps >= self.config.train.training_steps + self.config.train.offline_training_steps:
            time.sleep(1)
            return True
        else:
            return False
@ray.remote(num_gpus=0.55)
def train_ddp(agent, rank, replay_buffer, storage, batch_storage, logger):
    """Ray task that runs ``agent``'s training loop on training process ``rank``.

    Rank 0 is the main process. It replaces ``logger`` with its own wandb run,
    publishes weights to ``storage``, logs and saves checkpoints. When
    ``agent.use_ddp`` is set, the process group is initialised via ``DDP_setup``
    and the model is wrapped in DDP.

    Returns:
        ``(final_weights, model)``
    """
    print(f'training_rank={rank}')
    is_main_process = (rank == 0)
    if is_main_process:
        wandb_name = agent.config.env.game + '-' + agent.config.wandb.tag
        logger = wandb.init(
            name=wandb_name,
            project=agent.config.wandb.project,
        )
    assert agent._update
    # update image augmentation transform
    agent.update_augmentation_transform()

    # save path
    model_path = Path(agent.config.save_path) / 'models'
    model_path.mkdir(parents=True, exist_ok=True)

    if is_main_process:
        train_logger = logging.getLogger('Train')
        eval_logger = logging.getLogger('Eval')
        train_logger.info('config: {}'.format(agent.config))
        train_logger.info('save model in: {}'.format(model_path))

    # prepare model
    model = agent.build_model().cuda()
    target_model = agent.build_model().cuda()

    # resume from a checkpoint when one is configured and present on disk
    load_path = agent.config.train.load_model_path
    if load_path and os.path.exists(load_path):
        if is_main_process:
            train_logger.info('resume model from path: {}'.format(load_path))
        weights = torch.load(load_path)
        storage.set_weights.remote(weights, 'self_play')
        storage.set_weights.remote(weights, 'reanalyze')
        storage.set_weights.remote(weights, 'latest')
        model.load_state_dict(weights)
        target_model.load_state_dict(weights)

    # DDP
    if agent.use_ddp:
        DDP_setup(rank=rank, world_size=agent.config.ddp.world_size, training_size=agent.config.ddp.training_size, address='127.0.0.1')
        model = DDP(model, device_ids=[rank])
    # torch.compile exists from torch 2.x on; parse the major version properly
    if int(torch.__version__.split('.')[0]) >= 2:
        model = torch.compile(model)
        target_model = torch.compile(target_model)
    model.train()
    target_model.eval()

    # optimizer
    opt_config = agent.config.optimizer
    if opt_config.type == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=opt_config.lr,
                              weight_decay=opt_config.weight_decay,
                              momentum=opt_config.momentum)
    elif opt_config.type == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt_config.lr,
                               weight_decay=opt_config.weight_decay)
    elif opt_config.type == 'AdamW':
        optimizer = optim.AdamW(model.parameters(),
                                lr=opt_config.lr,
                                weight_decay=opt_config.weight_decay)
    else:
        raise NotImplementedError('unsupported optimizer type: {}'.format(opt_config.type))

    training_steps = agent.config.train.training_steps
    if opt_config.lr_decay_type == 'cosine':
        max_steps = training_steps - int(training_steps * opt_config.lr_warm_up)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps * 3, eta_min=0)
    elif opt_config.lr_decay_type == 'full_cosine':
        max_steps = training_steps - int(training_steps * opt_config.lr_warm_up)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps // 2, eta_min=0)
    else:
        scheduler = None

    scaler = GradScaler()

    # wait until collecting enough data to start
    while ray.get(replay_buffer.get_transition_num.remote()) < agent.config.train.start_transitions:
        time.sleep(1)

    print('[Train] Begin training...')
    # set signals for other workers
    if is_main_process:
        storage.set_start_signal.remote()

    step_count = 0
    # Note: the interval of the current model and the target model is between x and 2x. (x = target_model_interval)
    # recent_weights is the param of the target model
    recent_weights = agent.get_weights(model)

    # some logs
    total_time = 0
    total_steps = agent.config.train.training_steps + agent.config.train.offline_training_steps
    if is_main_process:
        pb = tqdm(np.arange(total_steps), leave=True)

    self_play_return = 0.
    traj_num, transition_num = 0, 0
    eval_score, eval_best_score = 0., 0.
    while not agent.is_finished(step_count):
        start_time = time.time()
        # obtain a batch
        batch = batch_storage.pop()
        if batch is None:
            time.sleep(0.3)
            continue

        # adjust learning rate
        if is_main_process:
            storage.increase_counter.remote()
        lr = agent.adjust_lr(optimizer, step_count, scheduler)

        if is_main_process and step_count % 30 == 0:
            latest_weights = agent.get_weights(model)
            ray.get(storage.set_weights.remote(latest_weights, 'latest'))

        # update model for self-play
        if is_main_process and step_count % agent.config.train.self_play_update_interval == 0:
            storage.set_weights.remote(agent.get_weights(model), 'self_play')

        # update model for reanalyzing: the target model takes the previous snapshot
        if is_main_process and step_count % agent.config.train.reanalyze_update_interval == 0:
            storage.set_weights.remote(recent_weights, 'reanalyze')
            target_model.set_weights(recent_weights)
            target_model.cuda()
            target_model.eval()
            recent_weights = agent.get_weights(model)

        # update_weights calls model-specific methods (initial_inference, ...), which a DDP
        # wrapper does not expose; unwrap it, but also work when DDP is disabled.
        inner_model = getattr(model, 'module', model)
        scalers, log_data = agent.update_weights(inner_model, batch, optimizer, replay_buffer, scaler, step_count, target_model=target_model)
        scaler = scalers[0]
        loss_data, other_scalar, other_distribution = log_data

        # TODO: maybe this barrier can be removed
        if agent.config.ddp.training_size > 1 or agent.config.ddp.world_size > 1:
            dist.barrier()

        # save models
        if is_main_process and step_count % agent.config.train.save_ckpt_interval == 0:
            cur_model_path = model_path / 'model_{}.p'.format(step_count)
            torch.save(agent.get_weights(model), cur_model_path)

        end_time = time.time()
        total_time += end_time - start_time
        step_count += 1
        avg_time = total_time / step_count
        log_scalars = {}
        log_distribution = {}

        pb_interval = 50
        if is_main_process and step_count % pb_interval == 0:
            left_steps = total_steps - step_count
            left_time = (left_steps * avg_time) / 3600
            batch_queue_size = batch_storage.get_len()
            train_log_str = (
                '[Train] {}, step {}/{}, {:.3f}h left. lr={:.3f}, avg time={:.3f}s, batchQ={}, '
                'self-play return={:.3f}, collect {}/{:.3f}k, eval score={:.3f}/{:.3f}. '
                'Loss: reward={:.3f}, value={:.3f}, policy={:.3f}, '
                'consistency={:.3f}, entropy={:.3f}'
            ).format(agent.config.env.game, step_count, total_steps, left_time, lr, avg_time,
                     batch_queue_size, self_play_return, traj_num, transition_num / 1000,
                     eval_score, eval_best_score, loss_data['loss/value_prefix'],
                     loss_data['loss/value'], loss_data['loss/policy'],
                     loss_data['loss/consistency'], loss_data['loss/entropy'])
            pb.set_description(train_log_str)
            pb.update(pb_interval)
            log_scalars.update({
                'train/step_per_second (s)': end_time - start_time,
                'train/total time (h)': total_time / 3600,
                'train/avg time (s)': avg_time,
                'train/lr': lr,
                'train/queue size': batch_queue_size
            })

        if is_main_process and step_count % agent.config.log.log_interval == 0:
            # self-play statistics
            eval_scalar, remote_scalar, remote_distribution = ray.get(storage.get_log.remote())
            log_scalars.update(remote_scalar)
            log_distribution.update(remote_distribution)
            episode_return = remote_scalar.get('self_play/episode_return')
            if episode_return is not None:
                self_play_return = episode_return
            if len(eval_scalar) > 0:
                # TODO: fix the counter issue (log against eval_counter instead of step_count)
                logger.log(eval_scalar, step_count)
                eval_score = eval_scalar['eval/mean_score']
                min_score, max_score = eval_scalar['eval/min_score'], eval_scalar['eval/max_score']
                eval_counter, eval_best_score = ray.get([storage.get_eval_counter.remote(), storage.get_best_score.remote()])
                eval_log_str = (
                    'Eval {} at step {}, score = {:.3f}(min: {:.3f}, max: {:.3f}), '
                    'best score over past evaluation = {:.3f}'
                ).format(agent.config.env.game, eval_counter, eval_score, min_score, max_score,
                         eval_best_score)
                eval_logger.info(eval_log_str)
                print('[Eval] ', eval_log_str)
            # replay statistics
            traj_num, transition_num, total_priorities = ray.get([
                replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote(),
                replay_buffer.get_priorities.remote()
            ])
            log_scalars.update({
                'buffer/total_episode_num': traj_num,
                'buffer/total_transition_num': transition_num
            })
            log_distribution.update({
                'dist/priorities_in_buffer': total_priorities,
            })
            log_distribution.update(other_distribution)
            agent.log_hist(logger, log_distribution, step_count)

        if step_count % 20000 == 0 and agent.config.train.periodic_reset:
            print('-------------------------reset network------------------------------')
            model = agent.periodic_reset_model(model)

        # training statistics
        log_scalars.update(loss_data)
        log_scalars.update(other_scalar)
        if is_main_process and step_count > 100 and step_count % 1000 == 0:
            logger.log(log_scalars, step_count)
            # refresh the buffer counters shown in the progress bar
            traj_num, transition_num = ray.get([
                replay_buffer.get_traj_num.remote(), replay_buffer.get_transition_num.remote()
            ])

    final_weights = agent.get_weights(model)
    # only the main process publishes, so ranks do not race on the shared storage
    if is_main_process:
        storage.set_weights.remote(final_weights, 'self_play')
    return final_weights, model