move atari wrapper to examples and publish v0.2.4 (#124)

* move atari wrapper to examples

* consistency

* change drqn seed since training is quite unstable with the current seed

* minor fix

* 0.2.4
n+e 2020-07-10 17:20:39 +08:00 committed by GitHub
parent ff99662fe6
commit 47e8e2686c
6 changed files with 70 additions and 52 deletions

View File

@@ -2,6 +2,10 @@ import cv2
import gym
import numpy as np
from gym.spaces.box import Box
+from tianshou.data import Batch
+SIZE = 84
+FRAME = 4
def create_atari_environment(name=None, sticky_actions=True,
@@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
return env
+def preprocess_fn(obs=None, act=None, rew=None, done=None,
+obs_next=None, info=None, policy=None):
+if obs_next is not None:
+obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
+obs_next = np.moveaxis(obs_next, 0, -1)
+obs_next = cv2.resize(obs_next, (SIZE, SIZE))
+obs_next = np.asanyarray(obs_next, dtype=np.uint8)
+obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
+obs_next = np.moveaxis(obs_next, 1, -1)
+elif obs is not None:
+obs = np.reshape(obs, (-1, *obs.shape[2:]))
+obs = np.moveaxis(obs, 0, -1)
+obs = cv2.resize(obs, (SIZE, SIZE))
+obs = np.asanyarray(obs, dtype=np.uint8)
+obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
+obs = np.moveaxis(obs, 1, -1)
+return Batch(obs=obs, act=act, rew=rew, done=done,
+obs_next=obs_next, info=info)
class preprocessing(object):
def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
size=84, max_episode_steps=2000):
@@ -35,7 +60,8 @@ class preprocessing(object):
@property
def observation_space(self):
-return Box(low=0, high=255, shape=(self.size, self.size, 4),
+return Box(low=0, high=255,
+shape=(self.size, self.size, self.frame_skip),
dtype=np.uint8)
def action_space(self):
@@ -57,8 +83,8 @@ class preprocessing(object):
self._grayscale_obs(self.screen_buffer[0])
self.screen_buffer[1].fill(0)
-return np.stack([
-self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
+return np.array([self._pool_and_resize()
+for _ in range(self.frame_skip)])
def render(self, mode='human'):
return self.env.render(mode)
@@ -85,19 +111,15 @@ class preprocessing(object):
self._grayscale_obs(self.screen_buffer[t_])
observation.append(self._pool_and_resize())
-while len(observation) > 0 and len(observation) < self.frame_skip:
+if len(observation) == 0:
+observation = [self._pool_and_resize()
+for _ in range(self.frame_skip)]
+while len(observation) > 0 and \
+len(observation) < self.frame_skip:
observation.append(observation[-1])
-if len(observation) > 0:
-observation = np.stack(observation, axis=-1)
-else:
-observation = np.stack([
-self._pool_and_resize() for _ in range(self.frame_skip)],
-axis=-1)
-if self.count >= self.max_episode_steps:
-terminal = True
-else:
-terminal = False
-return observation, total_reward, (terminal or is_terminal), info
+terminal = self.count >= self.max_episode_steps
+return np.array(observation), total_reward, \
+(terminal or is_terminal), info
def _grayscale_obs(self, output):
self.env.ale.getScreenGrayscale(output)
@@ -108,9 +130,4 @@ class preprocessing(object):
np.maximum(self.screen_buffer[0], self.screen_buffer[1],
out=self.screen_buffer[0])
-transformed_image = cv2.resize(self.screen_buffer[0],
-(self.size, self.size),
-interpolation=cv2.INTER_AREA)
-int_image = np.asarray(transformed_image, dtype=np.uint8)
-# return np.expand_dims(int_image, axis=2)
-return int_image
+return self.screen_buffer[0]
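
With this change _pool_and_resize returns the pooled grayscale screen unresized, and the resize to 84x84 now happens once per collected batch inside preprocess_fn. Below is a rough shape trace of that function; the batch size (2 parallel environments) and the raw 210x160 screen size are illustrative assumptions, not taken from the diff.

# Illustrative only: trace the array shapes through the new preprocess_fn.
# The number of envs (2) and the raw screen size (210x160) are assumptions.
import cv2
import numpy as np

SIZE, FRAME = 84, 4
obs = np.zeros((2, FRAME, 210, 160), dtype=np.uint8)  # (n_env, frames, H, W)
x = np.reshape(obs, (-1, *obs.shape[2:]))             # (8, 210, 160)
x = np.moveaxis(x, 0, -1)                             # (210, 160, 8)
x = cv2.resize(x, (SIZE, SIZE))                       # (84, 84, 8)
x = np.asanyarray(x, dtype=np.uint8)
x = np.reshape(x, (-1, FRAME, SIZE, SIZE))            # (2, 4, 84, 84)
x = np.moveaxis(x, 1, -1)                             # (2, 84, 84, 4)
print(x.shape)  # per-env frames end up channels-last, matching observation_space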

View File

@@ -8,11 +8,11 @@ from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net
+from atari import create_atari_environment, preprocess_fn
def get_args():
parser = argparse.ArgumentParser()
@@ -45,20 +45,17 @@ def get_args():
def test_a2c(args=get_args()):
-env = create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
+env = create_atari_environment(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
-[lambda: create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
-for _ in range(args.training_num)])
+[lambda: create_atari_environment(args.task)
+for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
-[lambda: create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
-for _ in range(args.test_num)])
+[lambda: create_atari_environment(args.task)
+for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -76,8 +73,9 @@ def test_a2c(args=get_args()):
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
# collector
train_collector = Collector(
-policy, train_envs, ReplayBuffer(args.buffer_size))
-test_collector = Collector(policy, test_envs)
+policy, train_envs, ReplayBuffer(args.buffer_size),
+preprocess_fn=preprocess_fn)
+test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'a2c')
@@ -99,7 +97,7 @@ def test_a2c(args=get_args()):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
-collector = Collector(policy, env)
+collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
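
The DQN and PPO examples below make the same changes, so the whole pattern condenses to the sketch here: import the wrapper as a local module and hand its preprocess_fn to every Collector. Only the import line and the preprocess_fn keyword come from this commit; the policy object, task name, and sizes are placeholders.

# Condensed wiring sketch; everything marked "placeholder" is not part of the diff.
from atari import create_atari_environment, preprocess_fn
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv

task = 'Pong'   # placeholder task name
policy = ...    # placeholder: build an A2C/DQN/PPO policy as in the scripts
train_envs = SubprocVectorEnv(
    [lambda: create_atari_environment(task) for _ in range(8)])
test_envs = SubprocVectorEnv(
    [lambda: create_atari_environment(task) for _ in range(8)])

# preprocess_fn resizes and restacks the raw frames before they reach the buffer
train_collector = Collector(
    policy, train_envs, ReplayBuffer(20000), preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)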

View File

@@ -9,7 +9,8 @@ from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
+from atari import create_atari_environment, preprocess_fn
def get_args():
@@ -49,8 +50,8 @@ def test_dqn(args=get_args()):
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv([
-lambda: create_atari_environment(
-args.task) for _ in range(args.test_num)])
+lambda: create_atari_environment(args.task)
+for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -68,8 +69,9 @@ def test_dqn(args=get_args()):
target_update_freq=args.target_update_freq)
# collector
train_collector = Collector(
-policy, train_envs, ReplayBuffer(args.buffer_size))
-test_collector = Collector(policy, test_envs)
+policy, train_envs, ReplayBuffer(args.buffer_size),
+preprocess_fn=preprocess_fn)
+test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * 4)
print(len(train_collector.buffer))
@@ -101,7 +103,7 @@ def test_dqn(args=get_args()):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
-collector = Collector(policy, env)
+collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()

View File

@@ -8,10 +8,11 @@ from tianshou.policy import PPOPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net
+from atari import create_atari_environment, preprocess_fn
def get_args():
parser = argparse.ArgumentParser()
@@ -44,17 +45,16 @@ def get_args():
def test_ppo(args=get_args()):
-env = create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
+env = create_atari_environment(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space().shape or env.action_space().n
# train_envs = gym.make(args.task)
-train_envs = SubprocVectorEnv([lambda: create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
+train_envs = SubprocVectorEnv([
+lambda: create_atari_environment(args.task)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
-test_envs = SubprocVectorEnv([lambda: create_atari_environment(
-args.task, max_episode_steps=args.max_episode_steps)
+test_envs = SubprocVectorEnv([
+lambda: create_atari_environment(args.task)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@@ -77,8 +77,9 @@ def test_ppo(args=get_args()):
action_range=None)
# collector
train_collector = Collector(
-policy, train_envs, ReplayBuffer(args.buffer_size))
-test_collector = Collector(policy, test_envs)
+policy, train_envs, ReplayBuffer(args.buffer_size),
+preprocess_fn=preprocess_fn)
+test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')
@@ -100,7 +101,7 @@ def test_ppo(args=get_args()):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
-collector = Collector(policy, env)
+collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_step=2000, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()

View File

@@ -16,7 +16,7 @@ from tianshou.utils.net.common import Recurrent
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
-parser.add_argument('--seed', type=int, default=1626)
+parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)

View File

@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
exploration
-__version__ = '0.2.3'
+__version__ = '0.2.4'
__all__ = [
'env',
'data',