# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import os
import sys
sys.path.append(os.getcwd())
import time
import torch
import ray
import copy
import cv2
import hydra
import multiprocessing
import numpy as np
import imageio
from PIL import Image, ImageDraw
from pathlib import Path
from tqdm.auto import tqdm
from omegaconf import OmegaConf
from torch.cuda.amp import autocast as autocast
import torch.nn.functional as F
from ez import mcts
from ez import agents
from ez.envs import make_envs
from ez.utils.format import formalize_obs_lst, DiscreteSupport, prepare_obs_lst, symexp, profile
from ez.mcts.cy_mcts import Gumbel_MCTS
from ez.utils.distribution import SquashedNormal, TruncatedNormal


@hydra.main(config_path="./config", config_name='config', version_base='1.1')
def main(config):
    # merge the experiment-specific config (if any) into the base config
    if config.exp_config is not None:
        exp_config = OmegaConf.load(config.exp_config)
        config = OmegaConf.merge(config, exp_config)  # update config

    agent = agents.names[config.agent_name](config)

    num_gpus = torch.cuda.device_count()
    num_cpus = multiprocessing.cpu_count()
    # 150 GiB object store for image-based envs, 100 GiB otherwise
    ray.init(num_gpus=num_gpus, num_cpus=num_cpus,
             object_store_memory=150 * 1024 * 1024 * 1024 if config.env.image_based else 100 * 1024 * 1024 * 1024)

    # prepare model
    model = agent.build_model()
    if os.path.exists(config.eval.model_path):
        weights = torch.load(config.eval.model_path)
        model.load_state_dict(weights)
        print('resume model from: ', config.eval.model_path)
    if int(torch.__version__[0]) == 2:
        # torch.compile is only available from PyTorch 2.x
        model = torch.compile(model)

    n_episodes = 1
    save_path = Path(config.eval.save_path)
    eval(agent, model, n_episodes, save_path, config,
         max_steps=27000,
         use_pb=True, verbose=config.eval.verbose)


@torch.no_grad()
def eval(agent, model, n_episodes, save_path, config, max_steps=None, use_pb=False, verbose=0):
    """Run n_episodes evaluation episodes with MCTS action selection.

    Returns the per-episode raw returns; if save_path is given, annotated
    recordings are written to save_path / 'recordings'.
    """
    model.cuda()
    model.eval()

    # prepare logs
    if save_path is not None:
        video_path = save_path / 'recordings'
        video_path.mkdir(parents=True, exist_ok=True)
    else:
        video_path = None

    dones = np.array([False for _ in range(n_episodes)])
    if use_pb:
        pb = tqdm(np.arange(max_steps), leave=True)
    ep_ori_rewards = np.zeros(n_episodes)

    # make env
    if max_steps is not None:
        config.env.max_episode_steps = max_steps
    envs = make_envs(config.env.env, config.env.game, n_episodes, config.env.base_seed, save_path=video_path,
                     episodic_life=False, **config.env)

    # initialization
    stack_obs_windows, game_trajs = agent.init_envs(envs, max_steps)
    # set infinity trajectory size
    for traj in game_trajs:
        traj.set_inf_len()

    # begin to evaluate
    step = 0
    frames = [[] for _ in range(n_episodes)]
    rewards = [[] for _ in range(n_episodes)]
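
    # Evaluation loop: each iteration encodes the stacked observations with the
    # representation network, runs MCTS from the resulting latent states to pick
    # one action per still-running environment, and steps those environments
    # until every episode is done.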
    while not dones.all():
        # debug
        if verbose:
            import ipdb
            ipdb.set_trace()

        # stack obs
        current_stacked_obs = formalize_obs_lst(stack_obs_windows, image_based=config.env.image_based)

        # obtain the statistics at current steps
        with torch.no_grad():
            with autocast():
                states, values, policies = model.initial_inference(current_stacked_obs)
        values = values.detach().cpu().numpy().flatten()

        # tree search for policies
        tree = mcts.names[config.mcts.language](
            # num_actions=config.env.action_space_size if config.env.env == 'Atari' else config.mcts.num_top_actions,
            num_actions=config.env.action_space_size if config.env.env == 'Atari' else config.mcts.num_sampled_actions,
            discount=config.rl.discount,
            env=config.env.env,
            **config.mcts,   # pass mcts related params
            **config.model,  # pass the value and reward support params
        )
        if config.env.env == 'Atari':
            if config.mcts.use_gumbel:
                r_values, r_policies, best_actions, _ = tree.search(
                    model, n_episodes, states, values, policies,
                    use_gumble_noise=False, verbose=verbose)
            else:
                r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
                    model, n_episodes, states, values, policies,
                    use_noise=False)
        else:
            r_values, r_policies, best_actions, _, _, _ = tree.search_continuous(
                model, n_episodes, states, values, policies,
                use_gumble_noise=False, verbose=verbose, add_noise=False
            )
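
        # No exploration noise is injected at evaluation time (the use_gumble_noise /
        # use_noise / add_noise flags above are all False), so the root action of each
        # search tree is executed greedily below.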
        # step action in environments
        for i in range(n_episodes):
            if dones[i]:
                continue

            action = best_actions[i]
            obs, reward, done, info = envs[i].step(action)
            frames[i].append(obs if config.env.image_based else envs[i].render(mode='rgb_array'))
            # rewards[i].append(reward)
            rewards[i].append(info['raw_reward'])
            dones[i] = done

            # save data to trajectory buffer
            game_trajs[i].store_search_results(values[i], r_values[i], r_policies[i])
            game_trajs[i].append(action, obs, reward)
            if config.env.env == 'Atari':
                game_trajs[i].snapshot_lst.append(envs[i].ale.cloneState())
            else:
                game_trajs[i].snapshot_lst.append(envs[i].physics.get_state())

            # slide the stacked-observation window
            del stack_obs_windows[i][0]
            stack_obs_windows[i].append(obs)

            # log
            ep_ori_rewards[i] += info['raw_reward']

        step += 1
        if use_pb:
            pb.set_description('{} | step {}, action {}, scores: {} (max: {}, min: {})'
                               ''.format(config.env.game, step, best_actions,
                                         ep_ori_rewards.mean(), ep_ori_rewards.max(), ep_ori_rewards.min()))
            pb.update(1)
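
    # All episodes have terminated: close the environments and, when a save path
    # was given, write one annotated mp4 recording per episode.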
    for env in envs:
        env.close()

    if video_path is not None:
        for i in range(n_episodes):
            writer = imageio.get_writer(video_path / f'epi_{i}_{max_steps}.mp4')
            # replace the first step reward with the episode return so the overlay on
            # the first frame shows the total score
            rewards[i][0] = sum(rewards[i])
            j = 0
            for frame, reward in zip(frames[i], rewards[i]):
                frame = Image.fromarray(frame)
                draw = ImageDraw.Draw(frame)
                if config.env.game == 'hopper_hop':
                    draw.text((5, 5), f'mu={game_trajs[i].action_lst[j][0]:.2f},{game_trajs[i].action_lst[j][1]:.2f}')
                    draw.text((5, 20), f'{game_trajs[i].action_lst[j][2]:.2f},{game_trajs[i].action_lst[j][3]:.2f}')
                    draw.text((5, 35), f'r={reward:.2f}')
                frame = np.array(frame)
                writer.append_data(frame)
                j += 1
            writer.close()

    return ep_ori_rewards
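
# Example invocation (a sketch; assumes this file is run as a Hydra entry point named
# eval.py, and the override keys below mirror the config fields read above):
#   python eval.py exp_config=path/to/exp.yaml eval.model_path=path/to/model.pth \
#       eval.save_path=./eval_results eval.verbose=0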


if __name__ == '__main__':
    main()