EfficientZeroV2/ez/train.py

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
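"""Training entry point for EfficientZeroV2.

Spawns one trainer process per DDP rank, starts the Ray workers via
ez.worker.start_workers, trains the agent, and evaluates the final weights.
"""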
import os
import time
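
# Allow Ray to back its object store with slower storage (e.g. disk) instead
# of refusing to start when fast shared memory is insufficient; must be set
# before Ray starts.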
os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
import ray
import wandb
import hydra
import torch
import multiprocessing
import sys
sys.path.append(os.getcwd())

import numpy as np
import torch.distributed as dist
import torch.multiprocessing as mp

from pathlib import Path
from omegaconf import OmegaConf

from ez import agents
from ez.utils.format import set_seed, init_logger
from ez.worker import start_workers, join_workers
from ez.eval import eval
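
# NOTE: `eval` here is ez.eval.eval (episode evaluation), which shadows
# Python's builtin eval within this module.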


@hydra.main(config_path='./config', config_name='config', version_base='1.1')
def main(config):
    # Overlay the optional experiment config onto the base Hydra config.
    if config.exp_config is not None:
        exp_config = OmegaConf.load(config.exp_config)
        config = OmegaConf.merge(config, exp_config)

    # Single-process debug mode: collapse every worker pool to size one.
    if config.ray.single_process:
        config.train.self_play_update_interval = 1
        config.train.reanalyze_update_interval = 1
        config.actors.data_worker = 1
        config.actors.batch_worker = 1
        config.data.num_envs = 1

    if config.ddp.world_size > 1:
        mp.spawn(start_ddp_trainer, args=(config,), nprocs=config.ddp.world_size)
    else:
        start_ddp_trainer(0, config)
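
# mp.spawn passes the process index as the first positional argument, so each
# spawned process enters start_ddp_trainer(rank, config).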

def start_ddp_trainer(rank, config):
    assert rank >= 0
    print(f'start {rank} train worker...')

    agent = agents.names[config.agent_name](config)  # update config
    manager = None

    num_gpus = torch.cuda.device_count()
    num_cpus = multiprocessing.cpu_count()
    # Image-based envs keep frame observations in the object store, so give
    # them more room (150 GiB vs. 100 GiB).
    ray.init(num_gpus=num_gpus, num_cpus=num_cpus,
             object_store_memory=(150 if config.env.image_based else 100) * 1024 ** 3)

    set_seed(config.env.base_seed + rank)  # per-rank seed

    # set log
    if rank == 0:
        # wandb logger
        if config.ddp.training_size == 1:
            wandb_name = config.env.game + '-' + config.wandb.tag
            print(f'wandb_name={wandb_name}')
            logger = wandb.init(
                name=wandb_name,
                project=config.wandb.project,
                # config=config,
            )
        else:
            logger = None
        # file logger
        log_path = os.path.join(config.save_path, 'logs')
        os.makedirs(log_path, exist_ok=True)
        init_logger(log_path)
    else:
        logger = None

    # train
    final_weights = train(rank, agent, manager, logger, config)

    # final evaluation
    if rank == 0:
        model = agent.build_model()
        model.set_weights(final_weights)
        save_path = Path(config.save_path) / 'recordings' / 'final'
        scores = eval(agent, model, config.train.eval_n_episode, save_path, config)
        print('final score: ', np.mean(scores))
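
# train() starts the Ray servers/workers (rank 0 only), runs the learner(s),
# evaluates the final model, and tears the workers back down.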

def train(rank, agent, manager, logger, config):
    # launch for the main process
    if rank == 0:
        workers, server_lst = start_workers(agent, manager, config)
    else:
        workers, server_lst = None, None

    # train
    # NOTE: the servers exist only on rank 0; the unpack below assumes this
    # function is entered with rank == 0 (non-zero ranks get server_lst=None).
    storage_server, replay_buffer_server, watchdog_server, batch_storage = server_lst
    if config.ddp.training_size == 1:
        final_weights, final_model = agent.train(rank, replay_buffer_server, storage_server, batch_storage, logger)
    else:
        from ez.agents.base import train_ddp
        time.sleep(1)
        train_workers = [
            train_ddp.remote(
                agent, rank * config.ddp.training_size + rank_i,
                replay_buffer_server, storage_server, batch_storage, logger
            ) for rank_i in range(config.ddp.training_size)
        ]
        time.sleep(1)
        # ray.get returns one (weights, model) pair per remote trainer; keep
        # the first trainer's result as the final one.
        final_weights, final_model = ray.get(train_workers)[0]

    epi_scores = eval(agent, final_model, 10, Path(config.save_path) / 'evaluation' / 'final', config,
                      max_steps=27000, use_pb=False, verbose=config.eval.verbose)
    print(f'final_mean_score={epi_scores.mean():.3f}')

    # join process
    if rank == 0:
        print('[main process] master worker finished')
        time.sleep(1)
        join_workers(workers, server_lst)

    # return
    if dist.is_initialized():  # the learner may never have created a process group
        dist.destroy_process_group()
    return final_weights


if __name__ == '__main__':
    main()
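
# Example launch (the exp_config path is illustrative; use a config shipped
# with the repo):
#   python ez/train.py exp_config=path/to/exp.yaml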