95 lines
4.0 KiB
Python
95 lines
4.0 KiB
Python
# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
|
|
#
|
|
# This source code is licensed under the GNU License, Version 3.0
|
|
# as found in the LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import time
|
|
import ray
|
|
import torch
|
|
from omegaconf import OmegaConf
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
from ez.worker.watchdog_worker import start_watchdog_server
|
|
from ez.data.global_storage import GlobalStorage
|
|
from ez.data.replay_buffer import ReplayBuffer
|
|
from ez.worker.data_worker import start_data_worker
|
|
from ez.worker.batch_worker import start_batch_worker, start_batch_worker_cpu, start_batch_worker_gpu
|
|
from ez.worker.eval_worker import start_eval_worker
|
|
from ez.utils.format import RayQueue, PreQueue
|
|
|
|
|
|
def start_workers(agent, manager, config):
    """Start the shared servers and launch every worker used during training.

    Servers started here:
      * ``GlobalStorage`` (created via ``.remote``), seeded with three models
        produced by ``agent.build_model()``.
      * ``RayQueue(15, 20)`` that holds batches produced by the batch workers.
      * ``ReplayBuffer`` (created via ``.remote``), configured from
        ``config.train``, ``config.data``, ``config.priority`` and ``config.env``.
      * the watchdog server returned by ``start_watchdog_server(manager)``.

    Workers launched here:
      * ``config.actors.data_worker`` data workers,
      * ``config.actors.batch_worker`` batch workers,
      * one evaluation worker.

    Args:
        agent: object exposing ``build_model()``; it is also forwarded to every
            worker launcher.
        manager: forwarded unchanged to ``start_watchdog_server``.
        config: attribute-accessible config (presumably an OmegaConf node; the
            fields read are listed above).

    Returns:
        ``(worker_lst, server_lst)`` where ``worker_lst`` is
        ``[data_workers, batch_workers, eval_worker]`` (each a list;
        ``eval_worker`` has exactly one element) and ``server_lst`` is
        ``[storage_server, replay_buffer_server, watchdog_server, batch_storage]``
        -- the layout that ``join_workers`` unpacks.
    """
    # ==================================================================================================================
    # start server
    # ==================================================================================================================

    # global storage server
    storage_server = GlobalStorage.remote(agent.build_model(), agent.build_model(), agent.build_model())
    print('[main process] Global storage server has been started from main process.')

    # batch queue
    # NOTE(review): the meaning of the (15, 20) arguments is defined by RayQueue
    # elsewhere in the project; confirm before tuning.
    batch_storage = RayQueue(15, 20)
    print('[main process] Batch storage has been initialized.')

    # replay buffer server
    replay_buffer_server = ReplayBuffer.remote(batch_size=config.train.batch_size,
                                               buffer_size=config.data.buffer_size,
                                               top_transitions=config.data.top_transitions,
                                               use_priority=config.priority.use_priority,
                                               env=config.env.env,
                                               total_transitions=config.data.total_transitions)
    print('[main process] Replay buffer server has been started from main process.')

    # watchdog server
    watchdog_server = start_watchdog_server(manager)
    print('[main process] Watchdog server has been started from main process.')

    # ==================================================================================================================
    # start worker
    # ==================================================================================================================

    # data workers
    data_workers = [start_data_worker(rank, agent, replay_buffer_server, storage_server, config)
                    for rank in range(config.actors.data_worker)]
    print('[main process] Data workers have all been launched.')

    # batch worker
    batch_workers = [start_batch_worker(rank, agent, replay_buffer_server, storage_server, batch_storage, config)
                     for rank in range(config.actors.batch_worker)]
    print('[main process] Batch workers have all been launched.')

    # eval worker (kept as a one-element list so the returned layout is unchanged)
    eval_worker = [start_eval_worker(agent, replay_buffer_server, storage_server, config)]

    # Parse the full major component: indexing the first character would read
    # "10.x" as 1. torch.compile exists from 2.0 onward, so use >= 2.
    torch_major = int(torch.__version__.split('.')[0])
    if torch_major >= 2:
        print(f'[main process] torch version is {torch.__version__}, enabled torch_compile.')

    # trainer (in current process)
    worker_lst = [data_workers, batch_workers, eval_worker]
    server_lst = [storage_server, replay_buffer_server, watchdog_server, batch_storage]

    return worker_lst, server_lst
|
|
|
|
|
|
def join_workers(worker_lst, server_lst):
    """Wait for every worker to finish, then shut down the servers.

    Args:
        worker_lst: ``[data_workers, batch_workers, eval_worker]``, the layout
            returned by ``start_workers``. ``data_workers`` and
            ``batch_workers`` are iterables of objects exposing ``join()``.
            ``eval_worker`` may be a single such object or a list/tuple of
            them (``start_workers`` returns a one-element list).
        server_lst: ``[storage_server, replay_buffer_server, watchdog_server,
            batch_queue]``. The first three are stopped with ``terminate()``
            and the last with ``stop()``.
    """
    data_workers, batch_workers, eval_worker = worker_lst
    storage_server, replay_buffer_server, watchdog_server, smos_server = server_lst

    # start_workers wraps the eval worker in a list, and a list has no
    # ``join`` method. Accept both the list form and a bare worker object.
    if isinstance(eval_worker, (list, tuple)):
        eval_workers = list(eval_worker)
    else:
        eval_workers = [eval_worker]

    # wait for all workers to finish (data -> batch -> eval, as before)
    for worker in (*data_workers, *batch_workers, *eval_workers):
        worker.join()
    print('[main process] All workers have stopped.')

    # stop servers
    # NOTE(review): if any of these are Ray actor handles, a plain .terminate()
    # call is rejected by Ray; it would need .terminate.remote() or ray.kill().
    # Confirm the concrete types before relying on this shutdown path.
    for server in (storage_server, replay_buffer_server, watchdog_server):
        server.terminate()
    smos_server.stop()
    print('[main process] All servers have stopped.')
|
|
|