EfficientZeroV2/ez/worker/data_worker.py

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import os
import time

import ray
import torch
import numpy as np
from torch.nn import L1Loss
from pathlib import Path
from torch.cuda.amp import autocast as autocast

from ez.worker.base import Worker
from ez import mcts
from ez.envs import make_envs, make_env
from ez.utils.format import formalize_obs_lst, DiscreteSupport, allocate_gpu, prepare_obs_lst, symexp
from ez.mcts.cy_mcts import Gumbel_MCTS
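

# each self-play actor reserves a small fraction of a GPU, so many workers
# can share one device for batched network inference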
@ray.remote(num_gpus=0.05)
class DataWorker(Worker):
    def __init__(self, rank, agent, replay_buffer, storage, config):
        super().__init__(rank, agent, replay_buffer, storage, config)
        self.model_update_interval = config.train.self_play_update_interval
        self.traj_pool = []  # finished trajectories waiting to be pushed to the replay buffer
        self.pool_size = 1   # flush the pool once it holds this many trajectories
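
    # Main self-play loop: keep the model weights fresh, run MCTS at the root for every
    # environment, step the environments with the selected actions, and stream finished
    # (sub-)trajectories to the replay buffer.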
    @torch.no_grad()
    def run(self):
        config = self.config

        # create the model for self-play data collection
        self.model = self.agent.build_model()
        self.model.cuda()
        if int(torch.__version__.split('.')[0]) >= 2:
            # torch.compile is available from PyTorch 2.x and speeds up repeated inference
            self.model = torch.compile(self.model)
        self.model.eval()
        self.resume_model()

        # make envs
        num_envs = config.data.num_envs
        save_path = Path(config.save_path)
        video_path = save_path / 'self_play_videos' if config.data.save_video else None
        cur_seed = config.env.base_seed
        # each worker seeds its envs in a disjoint range so parallel actors explore differently
        envs = make_envs(config.env.env, config.env.game, num_envs, cur_seed + self.rank * num_envs,
                         save_path=video_path, episodic_life=config.env.episodic, **config.env)  # prev episodic_life=True

        # initialization
        trained_steps = 0        # current number of training steps
        prev_train_steps = 0     # last training step seen (used in single-process mode)
        collected_transitions = ray.get(self.replay_buffer.get_transition_num.remote())  # total transitions collected
        start_training = False   # whether the learner has started training
        max_transitions = config.data.total_transitions // config.actors.data_worker  # max transitions to collect in this worker
        dones = [False for _ in range(num_envs)]
        traj_len = [0 for _ in range(num_envs)]
        stack_obs_windows, game_trajs = self.agent.init_envs(envs, max_steps=self.config.data.trajectory_size)
        # previous game trajectories (a full episode is split into several sub-trajectories)
        prev_game_trajs = [None for _ in range(num_envs)]
        # log data
        episode_return = [0. for _ in range(num_envs)]

        # main loop for collecting data
        while not self.is_finished(trained_steps):
            trained_steps = ray.get(self.storage.get_counter.remote())
            if not start_training:
                start_training = ray.get(self.storage.get_start_signal.remote())

            # fetch the latest model weights
            self.get_recent_model(trained_steps, 'self_play')

            if collected_transitions > max_transitions:
                time.sleep(10)
                continue

            # throttle when self-play outpaces training, keeping the data/steps ratio balanced
            if start_training and (collected_transitions / max_transitions) > (trained_steps / self.config.train.training_steps):
                time.sleep(1)
                continue
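
            # e.g. if this worker may collect 25k of 100k total transitions (4 workers) and
            # training runs for 100k steps, the worker sleeps whenever collected/25k exceeds
            # trained/100k, i.e. it contributes roughly one transition per four training steps
            # (numbers here are illustrative, the real ones come from the config)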

            if self.config.ray.single_process:
                trained_steps = ray.get(self.storage.get_counter.remote())
                # in single-process mode, wait until the learner has made progress
                if start_training and trained_steps <= prev_train_steps:
                    time.sleep(0.1)
                    continue
                prev_train_steps = trained_steps
                print(f'selfplay[{self.rank}] rollouts at step {trained_steps}, collected transitions {collected_transitions}')

            # temperature for action selection
            temperature = self.agent.get_temperature(trained_steps=trained_steps)
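            # (the returned temperature is annealed with trained_steps: high early for
            # exploration, lower later so actions follow the improved policy more closely)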

            # stack obs
            current_stacked_obs = formalize_obs_lst(stack_obs_windows, image_based=config.env.image_based)

            # obtain the root statistics at the current step
            with autocast():
                states, values, policies = self.model.initial_inference(current_stacked_obs)

            # process outputs
            values = values.detach().cpu().numpy().flatten()

            # periodically print noisy-net statistics for debugging
            if collected_transitions % 200 == 0 and self.config.model.noisy_net and self.rank == 0:
                print('*******************************')
                print(f'w_ep={self.model.value_policy_model.pi_net[0].weight_epsilon.mean()}')
                print(f'w_mu={self.model.value_policy_model.pi_net[0].weight_mu.mean()}')
                print(f'w_si={self.model.value_policy_model.pi_net[0].weight_sigma.mean()}')
                print(f'b_ep={self.model.value_policy_model.pi_net[0].bias_epsilon.mean()}')
                print(f'b_mu={self.model.value_policy_model.pi_net[0].bias_mu.mean()}')
                print(f'b_si={self.model.value_policy_model.pi_net[0].bias_sigma.mean()}')

            # tree search for improved policies and values
            tree = mcts.names[config.mcts.language](
                num_actions=config.env.action_space_size if config.env.env == 'Atari' else config.mcts.num_sampled_actions,
                discount=config.rl.discount,
                env=config.env.env,
                **config.mcts,   # pass mcts-related params
                **config.model,  # pass the value and reward support params
            )
            if self.config.env.env == 'Atari':
                if self.config.mcts.use_gumbel:
                    r_values, r_policies, best_actions, _ = tree.search(
                        self.model, num_envs, states, values, policies, temperature=temperature)
                else:
                    r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
                        self.model, num_envs, states, values, policies, use_noise=True, temperature=temperature)
            else:
                r_values, r_policies, best_actions, sampled_actions, best_indexes, mcts_info = tree.search_continuous(
                    self.model, num_envs, states, values, policies, temperature=temperature, input_noises=None)
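
            # r_values / r_policies are the searched root values and improved policies that
            # become training targets; best_actions are the actions actually executed below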

            # step the selected actions in the environments
            for i in range(num_envs):
                action = best_actions[i]
                obs, reward, done, info = envs[i].step(action)
                dones[i] = done
                traj_len[i] += 1
                episode_return[i] += info['raw_reward']

                # save data to the trajectory buffer
                game_trajs[i].store_search_results(values[i], r_values[i], r_policies[i])
                game_trajs[i].append(action, obs, reward)
                game_trajs[i].snapshot_lst.append([])

                # refresh the stacked-observation window
                del stack_obs_windows[i][0]
                stack_obs_windows[i].append(obs)

                # if the current trajectory is full, save the previous trajectory
                if game_trajs[i].is_full():
                    if prev_game_trajs[i] is not None:
                        self.save_previous_trajectory(i, prev_game_trajs, game_trajs)
                    prev_game_trajs[i] = game_trajs[i]

                    # start a new trajectory
                    game_trajs[i] = self.agent.new_game(max_steps=self.config.data.trajectory_size)
                    game_trajs[i].init(stack_obs_windows[i])
                    traj_len[i] = 0

                # reset the env if done
                if dones[i]:
                    # save the previous trajectory
                    if prev_game_trajs[i] is not None:
                        self.save_previous_trajectory(i, prev_game_trajs, game_trajs)

                    if len(game_trajs[i]) > 0:
                        # save the current trajectory
                        game_trajs[i].pad_over([], [], [], [], [])
                        game_trajs[i].save_to_memory()
                        self.put_trajs(game_trajs[i])

                    # log episode statistics
                    self.storage.add_log_scalar.remote({
                        'self_play/episode_len': traj_len[i],
                        'self_play/episode_return': episode_return[i],
                        'self_play/temperature': temperature
                    })

                    # reset the finished env (DMC envs are re-created)
                    if self.config.env.env == 'DMC':
                        envs[i] = make_env(config.env.env, config.env.game, num_envs, cur_seed + self.rank * num_envs,
                                           save_path=video_path, episodic_life=config.env.episodic, **config.env)
                    stacked_obs, traj = self.agent.init_env(envs[i], max_steps=self.config.data.trajectory_size)
                    stack_obs_windows[i] = stacked_obs
                    game_trajs[i] = traj
                    prev_game_trajs[i] = None
                    traj_len[i] = 0
                    episode_return[i] = 0

                collected_transitions += 1
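
    # Sub-trajectory bookkeeping: a long episode is split into fixed-size blocks, and each
    # block is only written to the replay buffer once the next block holds enough steps to
    # pad its tail, so value/policy targets near the block boundary stay computable.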
    def save_previous_trajectory(self, idx, prev_game_trajs, game_trajs, padding=True):
        """Put the previous game trajectory into the pool once the current trajectory is full.

        Parameters
        ----------
        idx: int
            index of the trajectory to handle
        prev_game_trajs: list
            list of the previous game trajectories
        game_trajs: list
            list of the current game trajectories
        padding: bool
            whether to pad the tail of the previous trajectory with the head of the current one
        """
        if padding:
            # pad the previous block with the head of the current one
            if self.config.model.value_target == 'bootstrapped':
                gap_step = self.config.env.n_stack + self.config.rl.td_steps
            else:
                extra = max(0, min(int(1 / (1 - self.config.rl.td_lambda)), self.config.model.GAE_max_steps)
                            - self.config.rl.unroll_steps - 1)
                gap_step = self.config.env.n_stack + 1 + extra + 1
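
            # e.g. with n_stack=4 and td_steps=5 under bootstrapped targets, gap_step = 9:
            # the previous block borrows 9 predicted/search values (and gap_step - 1 = 8
            # rewards) from the head of the current block so its n-step value targets stay
            # computable (numbers here are illustrative, the real ones come from the config)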

            beg_index = self.config.env.n_stack
            end_index = beg_index + self.config.rl.unroll_steps

            pad_obs_lst = game_trajs[idx].obs_lst[beg_index:end_index]
            pad_policy_lst = game_trajs[idx].policy_lst[0:self.config.rl.unroll_steps]
            pad_reward_lst = game_trajs[idx].reward_lst[0:gap_step - 1]
            pad_pred_values_lst = game_trajs[idx].pred_value_lst[0:gap_step]
            pad_search_values_lst = game_trajs[idx].search_value_lst[0:gap_step]

            # pad over and save
            prev_game_trajs[idx].pad_over(pad_obs_lst, pad_reward_lst, pad_pred_values_lst,
                                          pad_search_values_lst, pad_policy_lst)

        prev_game_trajs[idx].save_to_memory()
        self.put_trajs(prev_game_trajs[idx])

        # reset the last block
        prev_game_trajs[idx] = None
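
    # Initial replay priorities follow the usual prioritized-replay recipe:
    # p_t = |v_pred(s_t) - v_target(s_t)| + min_prior, so transitions where the value
    # network is most wrong are replayed more often.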
    def put_trajs(self, traj):
        if self.config.priority.use_priority:
            traj_len = len(traj)
            pred_values = torch.from_numpy(np.array(traj.pred_value_lst)).cuda().float()
            if self.config.model.value_target == 'bootstrapped':
                target_values = torch.from_numpy(np.asarray(traj.get_bootstrapped_value())).cuda().float()
            elif self.config.model.value_target == 'GAE':
                target_values = torch.from_numpy(np.asarray(traj.get_gae_value())).cuda().float()
            else:
                raise NotImplementedError
            priorities = (L1Loss(reduction='none')(pred_values[:traj_len], target_values[:traj_len])
                          .detach().cpu().numpy() + self.config.priority.min_prior)
        else:
            priorities = None

        self.traj_pool.append(traj)
        # save the game trajectories and clear the pool
        if len(self.traj_pool) >= self.pool_size:
            self.replay_buffer.save_pools.remote(self.traj_pool, priorities)
            del self.traj_pool[:]


# ======================================================================================================================
# data worker for self-play
# ======================================================================================================================
def start_data_worker(rank, agent, replay_buffer, storage, config):
    """
    Start a data worker. Call this method remotely.
    """
    data_worker = DataWorker.remote(rank, agent, replay_buffer, storage, config)
    data_worker.run.remote()
    print(f'[Data worker] Start data worker {rank} at process {os.getpid()}.')
    # return the actor handle so the caller can keep it alive; Ray terminates an actor
    # once its last handle goes out of scope
    return data_worker
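

# A minimal launch sketch (assumes the training entry point builds `agent`,
# `replay_buffer`, `storage`, and `config`; these names are illustrative):
#
#   ray.init()
#   handles = [start_data_worker(rank, agent, replay_buffer, storage, config)
#              for rank in range(config.actors.data_worker)]
#   # keep `handles` referenced for the lifetime of training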