# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import copy
import os
import time
import ray
import torch
import numpy as np

from torch.nn import L1Loss
from pathlib import Path
from torch.cuda.amp import autocast as autocast

from ez.worker.base import Worker
from ez import mcts
from ez.envs import make_envs, make_env
from ez.utils.format import formalize_obs_lst, DiscreteSupport, allocate_gpu, prepare_obs_lst, symexp
from ez.mcts.cy_mcts import Gumbel_MCTS


# @ray.remote(num_gpus=0.05)
@ray.remote(num_gpus=0.05)
class DataWorker(Worker):
    def __init__(self, rank, agent, replay_buffer, storage, config):
        super().__init__(rank, agent, replay_buffer, storage, config)

        self.model_update_interval = config.train.self_play_update_interval
        self.traj_pool = []
        self.pool_size = 1

        # time.sleep(10000)

    @torch.no_grad()
    def run(self):
        config = self.config

        # create the model for self-play data collection
        self.model = self.agent.build_model()
        self.model.cuda()
        if int(torch.__version__.split('.')[0]) >= 2:
            # torch.compile is only available from PyTorch 2.x onwards
            self.model = torch.compile(self.model)
        self.model.eval()
        self.resume_model()

        # make env
        num_envs = config.data.num_envs
        save_path = Path(config.save_path)
        if config.data.save_video:
            video_path = save_path / 'self_play_videos'
        else:
            video_path = None
        cur_seed = config.env.base_seed

        envs = make_envs(config.env.env, config.env.game, num_envs, cur_seed + self.rank * num_envs,
                         save_path=video_path, episodic_life=config.env.episodic, **config.env)  # prev episodic_life=True

        # initialization
        trained_steps = 0  # current training steps
        prev_train_steps = -1  # last training step observed by this worker (used in single-process mode)
        collected_transitions = ray.get(self.replay_buffer.get_transition_num.remote())  # total transitions collected so far
        start_training = False  # whether the learner has started training
        max_transitions = config.data.total_transitions // config.actors.data_worker  # max transitions to collect in this worker
        dones = [False for _ in range(num_envs)]
        traj_len = [0 for _ in range(num_envs)]

        stack_obs_windows, game_trajs = self.agent.init_envs(envs, max_steps=self.config.data.trajectory_size)
        # previous game trajectories (a full game trajectory is split into several sub-trajectories)
        prev_game_trajs = [None for _ in range(num_envs)]

        # log data
        episode_return = [0. for _ in range(num_envs)]

        # main data-collection loop
        while not self.is_finished(trained_steps):
            trained_steps = ray.get(self.storage.get_counter.remote())
            if not start_training:
                start_training = ray.get(self.storage.get_start_signal.remote())

            # get the fresh model weights
            self.get_recent_model(trained_steps, 'self_play')

            # stop collecting once this worker has reached its transition budget
            if collected_transitions > max_transitions:
                time.sleep(10)
                continue

            # self-play is running faster than training (or training has finished): wait
            if start_training and (collected_transitions / max_transitions) > (trained_steps / self.config.train.training_steps):
                time.sleep(1)
                continue
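
            # The ratio test above paces self-play against training: collection pauses whenever
            # collected_transitions / max_transitions runs ahead of trained_steps / training_steps.
            # For example, if this worker has gathered 60% of its transition budget while the learner
            # has finished only 50% of its training steps, the worker sleeps until training catches up.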

            if self.config.ray.single_process:
                trained_steps = ray.get(self.storage.get_counter.remote())
                if start_training and trained_steps <= prev_train_steps:
                    time.sleep(0.1)
                    continue
                prev_train_steps = trained_steps
                print(f'selfplay[{self.rank}] rollouts at step {trained_steps}, collected transitions {collected_transitions}')

            # print('self-playing')
            # sampling temperature for the current training stage
            temperature = self.agent.get_temperature(trained_steps=trained_steps)  # * np.ones((num_envs, 1))
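
            # The temperature controls exploration in the action selection below; in EfficientZero-style
            # training it is typically annealed from a higher value (more random actions) towards a lower
            # one as trained_steps grows.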

            # stack the most recent observations for each environment
            current_stacked_obs = formalize_obs_lst(stack_obs_windows, image_based=config.env.image_based)
            # obtain the root statistics at the current step
            with autocast():
                states, values, policies = self.model.initial_inference(current_stacked_obs)

            # process outputs
            values = values.detach().cpu().numpy().flatten()
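
            # initial_inference encodes the stacked observations into latent states and predicts a value
            # and policy prior per environment; these serve as the root statistics for the tree search below.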

            if collected_transitions % 200 == 0 and self.config.model.noisy_net and self.rank == 0:
                # periodically log NoisyNet parameter statistics of the policy head
                print('*******************************')
                print(f'w_ep={self.model.value_policy_model.pi_net[0].weight_epsilon.mean()}')
                print(f'w_mu={self.model.value_policy_model.pi_net[0].weight_mu.mean()}')
                print(f'w_si={self.model.value_policy_model.pi_net[0].weight_sigma.mean()}')
                print(f'b_ep={self.model.value_policy_model.pi_net[0].bias_epsilon.mean()}')
                print(f'b_mu={self.model.value_policy_model.pi_net[0].bias_mu.mean()}')
                print(f'b_si={self.model.value_policy_model.pi_net[0].bias_sigma.mean()}')

            # tree search for policies
            tree = mcts.names[config.mcts.language](
                # num_actions=config.env.action_space_size if config.env.env == 'Atari' else config.mcts.num_top_actions,
                num_actions=config.env.action_space_size if config.env.env == 'Atari' else config.mcts.num_sampled_actions,
                discount=config.rl.discount,
                env=config.env.env,
                **config.mcts,   # pass mcts-related params
                **config.model,  # pass the value and reward support params
            )
            if self.config.env.env == 'Atari':
                if self.config.mcts.use_gumbel:
                    r_values, r_policies, best_actions, _ = tree.search(
                        self.model, num_envs, states, values, policies,
                        # use_gumble_noise=False,  # for test search
                        temperature=temperature,
                    )
                else:
                    r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
                        self.model, num_envs, states, values, policies,
                        use_noise=True, temperature=temperature,
                    )
            else:
                r_values, r_policies, best_actions, sampled_actions, best_indexes, mcts_info = tree.search_continuous(
                    self.model, num_envs, states, values, policies, temperature=temperature,
                    # use_gumble_noise=True,
                    input_noises=None,
                )
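
            # For Atari (discrete actions) the search uses either Gumbel MCTS or the original MCTS with
            # exploration noise; for continuous control it searches over sampled candidate actions and
            # returns search values, improved policies and the best action per environment.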

            # step the environments with the selected actions
            for i in range(num_envs):
                action = best_actions[i]
                obs, reward, done, info = envs[i].step(action)
                dones[i] = done
                traj_len[i] += 1
                episode_return[i] += info['raw_reward']

                # save data to the trajectory buffer
                game_trajs[i].store_search_results(values[i], r_values[i], r_policies[i])
                game_trajs[i].append(action, obs, reward)
                # game_trajs[i].raw_obs_lst.append(obs)
                game_trajs[i].snapshot_lst.append([])

                # refresh the stacked-observation window
                del stack_obs_windows[i][0]
                stack_obs_windows[i].append(obs)
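
                # Trajectory splitting: a full sub-trajectory is not written to the replay buffer right away.
                # It is kept as prev_game_trajs[i] until the next sub-trajectory can supply the overlapping
                # observations, rewards and values needed to pad its tail (see save_previous_trajectory),
                # and a fresh trajectory is started in its place.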

                # if the current trajectory is full, save the previous trajectory
                if game_trajs[i].is_full():
                    if prev_game_trajs[i] is not None:
                        self.save_previous_trajectory(i, prev_game_trajs, game_trajs,
                                                      # padding=not dones[i]
                                                      )

                    prev_game_trajs[i] = game_trajs[i]

                    # start a new trajectory
                    game_trajs[i] = self.agent.new_game(max_steps=self.config.data.trajectory_size)
                    game_trajs[i].init(stack_obs_windows[i])

                    traj_len[i] = 0

                # reset the env if the episode is done
                if dones[i]:
                    # save the previous trajectory
                    if prev_game_trajs[i] is not None:
                        self.save_previous_trajectory(i, prev_game_trajs, game_trajs,
                                                      # padding=False
                                                      )

                    if len(game_trajs[i]) > 0:
                        # save the current trajectory
                        game_trajs[i].pad_over([], [], [], [], [])
                        game_trajs[i].save_to_memory()
                        self.put_trajs(game_trajs[i])

                    # log episode statistics
                    self.storage.add_log_scalar.remote({
                        'self_play/episode_len': traj_len[i],
                        'self_play/episode_return': episode_return[i],
                        'self_play/temperature': temperature
                    })

                    # reset the finished env and start a new trajectory
                    if self.config.env.env == 'DMC':
                        envs[i] = make_env(config.env.env, config.env.game, num_envs, cur_seed + self.rank * num_envs,
                                           save_path=video_path, episodic_life=config.env.episodic, **config.env)
                    stacked_obs, traj = self.agent.init_env(envs[i], max_steps=self.config.data.trajectory_size)
                    stack_obs_windows[i] = stacked_obs
                    game_trajs[i] = traj
                    prev_game_trajs[i] = None

                    traj_len[i] = 0
                    episode_return[i] = 0
                collected_transitions += 1

    def save_previous_trajectory(self, idx, prev_game_trajs, game_trajs, padding=True):
        """Put the previous game trajectory into the pool once the current trajectory is full.

        Parameters
        ----------
        idx: int
            index of the trajectory to handle
        prev_game_trajs: list
            list of the previous game trajectories
        game_trajs: list
            list of the current game trajectories
        padding: bool
            whether to pad the tail of the previous trajectory with data from the current one
        """
        if padding:
            # pad over the last block of the trajectory
            if self.config.model.value_target == 'bootstrapped':
                gap_step = self.config.env.n_stack + self.config.rl.td_steps
            else:
                extra = max(0, min(int(1 / (1 - self.config.rl.td_lambda)), self.config.model.GAE_max_steps)
                            - self.config.rl.unroll_steps - 1)
                gap_step = self.config.env.n_stack + 1 + extra + 1
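
            # gap_step is the number of extra transitions the previous block borrows from the current one:
            # n_stack steps for observation stacking plus the value-target horizon (td_steps for
            # bootstrapped targets, or the truncated lambda-return horizon for GAE-style targets).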

            beg_index = self.config.env.n_stack
            end_index = beg_index + self.config.rl.unroll_steps

            pad_obs_lst = game_trajs[idx].obs_lst[beg_index:end_index]

            pad_policy_lst = game_trajs[idx].policy_lst[0:self.config.rl.unroll_steps]
            pad_reward_lst = game_trajs[idx].reward_lst[0:gap_step - 1]
            pad_pred_values_lst = game_trajs[idx].pred_value_lst[0:gap_step]
            pad_search_values_lst = game_trajs[idx].search_value_lst[0:gap_step]

            # pad over and save
            prev_game_trajs[idx].pad_over(pad_obs_lst, pad_reward_lst, pad_pred_values_lst, pad_search_values_lst,
                                          pad_policy_lst)

        prev_game_trajs[idx].save_to_memory()
        self.put_trajs(prev_game_trajs[idx])

        # reset the last block
        prev_game_trajs[idx] = None

    def put_trajs(self, traj):
        if self.config.priority.use_priority:
            traj_len = len(traj)
            pred_values = torch.from_numpy(np.array(traj.pred_value_lst)).cuda().float()
            # search_values = torch.from_numpy(np.array(traj.search_value_lst)).cuda().float()
            if self.config.model.value_target == 'bootstrapped':
                target_values = torch.from_numpy(np.asarray(traj.get_bootstrapped_value())).cuda().float()
            elif self.config.model.value_target == 'GAE':
                target_values = torch.from_numpy(np.asarray(traj.get_gae_value())).cuda().float()
            else:
                raise NotImplementedError
            priorities = L1Loss(reduction='none')(pred_values[:traj_len], target_values[:traj_len]).detach().cpu().numpy() + self.config.priority.min_prior
        else:
            priorities = None
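
        # New trajectories enter the replay buffer with an initial priority equal to the per-step L1 error
        # between predicted values and value targets (plus min_prior), so poorly predicted states are
        # sampled more often by prioritized replay; when priorities is None the buffer falls back to its
        # default (typically uniform) sampling.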
        self.traj_pool.append(traj)
        # save the pooled game trajectories to the replay buffer and clear the pool
        if len(self.traj_pool) >= self.pool_size:
            self.replay_buffer.save_pools.remote(self.traj_pool, priorities)
            del self.traj_pool[:]


# ======================================================================================================================
# data worker for self-play
# ======================================================================================================================
def start_data_worker(rank, agent, replay_buffer, storage, config):
    """
    Start a data worker. Call this function remotely.
    """
    data_worker = DataWorker.remote(rank, agent, replay_buffer, storage, config)
    data_worker.run.remote()
    print(f'[Data worker] Started data worker {rank} at process {os.getpid()}.')
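

# Example launch sketch (illustrative only): this assumes ray.init() has already been called and that
# `agent`, `replay_buffer` and `storage` handles were created by the training entry point, with
# config.actors.data_worker giving the number of self-play workers.
#
#     workers = [DataWorker.remote(rank, agent, replay_buffer, storage, config)
#                for rank in range(config.actors.data_worker)]
#     handles = [w.run.remote() for w in workers]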