# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import os
import time

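# allow Ray's object store to live on slower storage (e.g. /tmp instead of
# /dev/shm) without erroring out; must be set before Ray starts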
os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"

import ray
import wandb
import hydra
import torch
import multiprocessing

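# make the current working directory importable so the `ez` package resolves
# (assumes the script is launched from the repository root)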
import sys
sys.path.append(os.getcwd())

import numpy as np
import torch.distributed as dist
import torch.multiprocessing as mp

from pathlib import Path
from omegaconf import OmegaConf

from ez import agents
from ez.utils.format import set_seed, init_logger
from ez.worker import start_workers, join_workers
from ez.eval import eval


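# Hydra entry point: loads config/config.yaml, then merges the experiment
# config referenced by `config.exp_config`, if one is given.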
@hydra.main(config_path='./config', config_name='config', version_base='1.1')
def main(config):
    if config.exp_config is not None:
        exp_config = OmegaConf.load(config.exp_config)
        config = OmegaConf.merge(config, exp_config)

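    # single-process debug mode: collapse every actor to one worker and one
    # environment so the whole pipeline runs in a single process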
    if config.ray.single_process:
        config.train.self_play_update_interval = 1
        config.train.reanalyze_update_interval = 1
        config.actors.data_worker = 1
        config.actors.batch_worker = 1
        config.data.num_envs = 1

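    # spawn one trainer process per DDP rank; a world size of 1 runs in-process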
    if config.ddp.world_size > 1:
        mp.spawn(start_ddp_trainer, args=(config,), nprocs=config.ddp.world_size)
    else:
        start_ddp_trainer(0, config)


def start_ddp_trainer(rank, config):
    assert rank >= 0
    print(f'start {rank} train worker...')
    agent = agents.names[config.agent_name](config)  # update config
    manager = None

    # give Ray all local devices; image-based envs need a larger object store
    num_gpus = torch.cuda.device_count()
    num_cpus = multiprocessing.cpu_count()
    object_store_memory = (150 if config.env.image_based else 100) * 1024 ** 3
    ray.init(num_gpus=num_gpus, num_cpus=num_cpus, object_store_memory=object_store_memory)
    set_seed(config.env.base_seed + rank)  # set seed

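    # set log: rank 0 owns a wandb run (single-trainer case only) plus a file
    # logger under <save_path>/logs; other ranks stay silent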
    if rank == 0:
        # wandb logger
        if config.ddp.training_size == 1:
            wandb_name = config.env.game + '-' + config.wandb.tag
            print(f'wandb_name={wandb_name}')
            logger = wandb.init(
                name=wandb_name,
                project=config.wandb.project,
                # config=config,
            )
        else:
            logger = None
        # file logger
        log_path = os.path.join(config.save_path, 'logs')
        os.makedirs(log_path, exist_ok=True)
        init_logger(log_path)
    else:
        logger = None

    # train
    final_weights = train(rank, agent, manager, logger, config)

    # final evaluation
    if rank == 0:
        model = agent.build_model()
        model.set_weights(final_weights)
        save_path = Path(config.save_path) / 'recordings' / 'final'

        scores = eval(agent, model, config.train.eval_n_episode, save_path, config)
        print('final score: ', np.mean(scores))


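# Training driver: rank 0 launches the Ray workers and servers, the learner(s)
# run to completion, and rank 0 joins everything at the end.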
def train(rank, agent, manager, logger, config):
    # launch for the main process
    if rank == 0:
        workers, server_lst = start_workers(agent, manager, config)
    else:
        workers, server_lst = None, None

    # train (unpacking assumes rank 0, where the servers were created)
    storage_server, replay_buffer_server, watchdog_server, batch_storage = server_lst

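    # a single learner trains in-process; multiple learners run as Ray DDP tasks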
    if config.ddp.training_size == 1:
        final_weights, final_model = agent.train(rank, replay_buffer_server, storage_server, batch_storage, logger)
    else:
        from ez.agents.base import train_ddp
        time.sleep(1)
        train_workers = [
            train_ddp.remote(
                agent, rank * config.ddp.training_size + rank_i,
                replay_buffer_server, storage_server, batch_storage, logger
            ) for rank_i in range(config.ddp.training_size)
        ]
        time.sleep(1)
        # each task returns a (weights, model) pair; DDP keeps the replicas
        # synchronized, so taking the first result is sufficient
        results = ray.get(train_workers)
        final_weights, final_model = results[0]

    # 27000 env steps is the standard 30-minute Atari evaluation limit
    epi_scores = eval(agent, final_model, 10, Path(config.save_path) / 'evaluation' / 'final', config,
                      max_steps=27000, use_pb=False, verbose=config.eval.verbose)
    print(f'final_mean_score={epi_scores.mean():.3f}')

    # join process
    if rank == 0:
        print('[main process] master worker finished')
        time.sleep(1)
        join_workers(workers, server_lst)

    # tear down the process group if DDP was initialized, then return
    if dist.is_initialized():
        dist.destroy_process_group()
    return final_weights


if __name__ == '__main__':
    main()