From 47e8e2686cde0cb54bf1bbdcefa2046b0c5e0392 Mon Sep 17 00:00:00 2001
From: n+e <463003665@qq.com>
Date: Fri, 10 Jul 2020 17:20:39 +0800
Subject: [PATCH] move atari wrapper to examples and publish v0.2.4 (#124)

* move atari wrapper to examples

* consistency

* change drqn seed since it is quite unstable in current seed

* minor fix

* 0.2.4
---
 {tianshou/env => examples}/atari.py | 59 +++++++++++++++++++----------
 examples/pong_a2c.py                | 24 ++++++------
 examples/pong_dqn.py                | 14 ++++---
 examples/pong_ppo.py                | 21 +++++-----
 test/discrete/test_drqn.py          |  2 +-
 tianshou/__init__.py                |  2 +-
 6 files changed, 70 insertions(+), 52 deletions(-)
 rename {tianshou/env => examples}/atari.py (63%)

diff --git a/tianshou/env/atari.py b/examples/atari.py
similarity index 63%
rename from tianshou/env/atari.py
rename to examples/atari.py
index 9904f08..7960efc 100644
--- a/tianshou/env/atari.py
+++ b/examples/atari.py
@@ -2,6 +2,10 @@ import cv2
 import gym
 import numpy as np
 from gym.spaces.box import Box
+from tianshou.data import Batch
+
+SIZE = 84
+FRAME = 4
 
 
 def create_atari_environment(name=None, sticky_actions=True,
@@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
     return env
 
 
+def preprocess_fn(obs=None, act=None, rew=None, done=None,
+                  obs_next=None, info=None, policy=None):
+    if obs_next is not None:
+        obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
+        obs_next = np.moveaxis(obs_next, 0, -1)
+        obs_next = cv2.resize(obs_next, (SIZE, SIZE))
+        obs_next = np.asanyarray(obs_next, dtype=np.uint8)
+        obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
+        obs_next = np.moveaxis(obs_next, 1, -1)
+    elif obs is not None:
+        obs = np.reshape(obs, (-1, *obs.shape[2:]))
+        obs = np.moveaxis(obs, 0, -1)
+        obs = cv2.resize(obs, (SIZE, SIZE))
+        obs = np.asanyarray(obs, dtype=np.uint8)
+        obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
+        obs = np.moveaxis(obs, 1, -1)
+
+    return Batch(obs=obs, act=act, rew=rew, done=done,
+                 obs_next=obs_next, info=info)
+
+
 class preprocessing(object):
     def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
                  size=84, max_episode_steps=2000):
@@ -35,7 +60,8 @@ class preprocessing(object):
 
     @property
     def observation_space(self):
-        return Box(low=0, high=255, shape=(self.size, self.size, 4),
+        return Box(low=0, high=255,
+                   shape=(self.size, self.size, self.frame_skip),
                    dtype=np.uint8)
 
     def action_space(self):
@@ -57,8 +83,8 @@ class preprocessing(object):
         self._grayscale_obs(self.screen_buffer[0])
         self.screen_buffer[1].fill(0)
 
-        return np.stack([
-            self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
+        return np.array([self._pool_and_resize()
+                         for _ in range(self.frame_skip)])
 
     def render(self, mode='human'):
         return self.env.render(mode)
@@ -85,19 +111,15 @@ class preprocessing(object):
                 self._grayscale_obs(self.screen_buffer[t_])
 
             observation.append(self._pool_and_resize())
-        while len(observation) > 0 and len(observation) < self.frame_skip:
+        if len(observation) == 0:
+            observation = [self._pool_and_resize()
+                           for _ in range(self.frame_skip)]
+        while len(observation) > 0 and \
+                len(observation) < self.frame_skip:
             observation.append(observation[-1])
-        if len(observation) > 0:
-            observation = np.stack(observation, axis=-1)
-        else:
-            observation = np.stack([
-                self._pool_and_resize() for _ in range(self.frame_skip)],
-                axis=-1)
-        if self.count >= self.max_episode_steps:
-            terminal = True
-        else:
-            terminal = False
-        return observation, total_reward, (terminal or is_terminal), info
+        terminal = self.count >= self.max_episode_steps
+        return np.array(observation), total_reward, \
+            (terminal or is_terminal), info
 
     def _grayscale_obs(self, output):
         self.env.ale.getScreenGrayscale(output)
@@ -108,9 +130,4 @@ class preprocessing(object):
             np.maximum(self.screen_buffer[0], self.screen_buffer[1],
                        out=self.screen_buffer[0])
 
-        transformed_image = cv2.resize(self.screen_buffer[0],
-                                       (self.size, self.size),
-                                       interpolation=cv2.INTER_AREA)
-        int_image = np.asarray(transformed_image, dtype=np.uint8)
-        # return np.expand_dims(int_image, axis=2)
-        return int_image
+        return self.screen_buffer[0]
diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py
index 544153e..7261a05 100644
--- a/examples/pong_a2c.py
+++ b/examples/pong_a2c.py
@@ -8,11 +8,11 @@ from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
-
 from tianshou.utils.net.discrete import Actor, Critic
 from tianshou.utils.net.common import Net
 
+from atari import create_atari_environment, preprocess_fn
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -45,20 +45,17 @@ def get_args():
 
 
 def test_a2c(args=get_args()):
-    env = create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(
-            args.task, max_episode_steps=args.max_episode_steps)
-            for _ in range(args.training_num)])
+        [lambda: create_atari_environment(args.task)
+         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(
-            args.task, max_episode_steps=args.max_episode_steps)
-            for _ in range(args.test_num)])
+        [lambda: create_atari_environment(args.task)
+         for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -76,8 +73,9 @@ def test_a2c(args=get_args()):
         ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
     writer = SummaryWriter(args.logdir + '/' + 'a2c')
 
@@ -99,7 +97,7 @@ def test_a2c(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_episode=1, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
diff --git a/examples/pong_dqn.py b/examples/pong_dqn.py
index 85a7e6e..e99c04e 100644
--- a/examples/pong_dqn.py
+++ b/examples/pong_dqn.py
@@ -9,7 +9,8 @@ from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.discrete import DQN
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
+
+from atari import create_atari_environment, preprocess_fn
 
 
 def get_args():
@@ -49,8 +50,8 @@ def test_dqn(args=get_args()):
         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv([
-        lambda: create_atari_environment(
-            args.task) for _ in range(args.test_num)])
+        lambda: create_atari_environment(args.task)
+        for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -68,8 +69,9 @@ def test_dqn(args=get_args()):
         target_update_freq=args.target_update_freq)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * 4)
     print(len(train_collector.buffer))
@@ -101,7 +103,7 @@ def test_dqn(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_episode=1, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py
index e475344..083c0df 100644
--- a/examples/pong_ppo.py
+++ b/examples/pong_ppo.py
@@ -8,10 +8,11 @@ from tianshou.policy import PPOPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
 from tianshou.utils.net.discrete import Actor, Critic
 from tianshou.utils.net.common import Net
 
+from atari import create_atari_environment, preprocess_fn
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -44,17 +45,16 @@ def get_args():
 
 
 def test_ppo(args=get_args()):
-    env = create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space().shape or env.action_space().n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv([lambda: create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    train_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv([lambda: create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    test_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
         for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
@@ -77,8 +77,9 @@ def test_ppo(args=get_args()):
         action_range=None)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
     writer = SummaryWriter(args.logdir + '/' + 'ppo')
 
@@ -100,7 +101,7 @@ def test_ppo(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_step=2000, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index f0f34ed..42ee534 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -16,7 +16,7 @@ from tianshou.utils.net.common import Recurrent
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index 8637643..71d1f36 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, \
     exploration
 
-__version__ = '0.2.3'
+__version__ = '0.2.4'
 __all__ = [
     'env',
     'data',
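
A quick standalone illustration of the preprocess_fn introduced in examples/atari.py above. This is only a sketch, not part of the patch: it assumes tianshou 0.2.4 and opencv-python are installed, that it is run from the examples/ directory (so the local atari module is importable), and the random array merely stands in for a real stack of Atari frames.

    import numpy as np

    from atari import preprocess_fn  # the relocated examples/atari.py

    # fake observations from 2 vectorized envs, each a stack of 4 raw 210x160
    # grayscale frames, as produced by the preprocessing wrapper after this patch
    obs = np.random.randint(0, 255, size=(2, 4, 210, 160), dtype=np.uint8)

    # the Collector calls preprocess_fn on collected data before it is stored;
    # calling it directly here just shows the resulting observation shape
    batch = preprocess_fn(obs=obs)
    print(batch.obs.shape)  # (2, 84, 84, 4): 84x84 uint8 frames, channels last

The example scripts wire the same function into data collection via
Collector(policy, envs, ReplayBuffer(args.buffer_size), preprocess_fn=preprocess_fn),
as the pong_a2c.py, pong_dqn.py and pong_ppo.py hunks above show.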