move atari wrapper to examples and publish v0.2.4 (#124)
* move atari wrapper to examples
* consistency
* change drqn seed since it is quite unstable with the current seed
* minor fix
* 0.2.4
commit 47e8e2686c (parent ff99662fe6)
@@ -2,6 +2,10 @@ import cv2
 import gym
 import numpy as np
 from gym.spaces.box import Box
+from tianshou.data import Batch
+
+SIZE = 84
+FRAME = 4


 def create_atari_environment(name=None, sticky_actions=True,
@@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
     return env


+def preprocess_fn(obs=None, act=None, rew=None, done=None,
+                  obs_next=None, info=None, policy=None):
+    if obs_next is not None:
+        obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
+        obs_next = np.moveaxis(obs_next, 0, -1)
+        obs_next = cv2.resize(obs_next, (SIZE, SIZE))
+        obs_next = np.asanyarray(obs_next, dtype=np.uint8)
+        obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
+        obs_next = np.moveaxis(obs_next, 1, -1)
+    elif obs is not None:
+        obs = np.reshape(obs, (-1, *obs.shape[2:]))
+        obs = np.moveaxis(obs, 0, -1)
+        obs = cv2.resize(obs, (SIZE, SIZE))
+        obs = np.asanyarray(obs, dtype=np.uint8)
+        obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
+        obs = np.moveaxis(obs, 1, -1)
+
+    return Batch(obs=obs, act=act, rew=rew, done=done,
+                 obs_next=obs_next, info=info)
+
+
 class preprocessing(object):
     def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
                  size=84, max_episode_steps=2000):
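Note: preprocess_fn reshapes and resizes a collector batch of raw stacked screens into (batch, 84, 84, 4) uint8 arrays. A rough standalone sketch of the shape bookkeeping, assuming Pong-sized 210x160 screens; a slicing stand-in replaces cv2.resize and all names below are illustrative:

import numpy as np

N, H, W = 2, 210, 160          # two parallel envs, raw Atari screen size
FRAME, SIZE = 4, 84

obs = np.random.randint(0, 256, (N, FRAME, H, W), dtype=np.uint8)  # collector batch
x = obs.reshape(-1, H, W)               # (N * FRAME, 210, 160)
x = np.moveaxis(x, 0, -1)               # (210, 160, N * FRAME), channels last
rows = np.linspace(0, H - 1, SIZE).astype(int)
cols = np.linspace(0, W - 1, SIZE).astype(int)
x = x[rows][:, cols]                    # crude stand-in for cv2.resize(x, (SIZE, SIZE))
x = x.reshape(-1, FRAME, SIZE, SIZE)    # regroup frames per environment
x = np.moveaxis(x, 1, -1)               # channels-last frame stack
print(x.shape)                          # -> (2, 84, 84, 4)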
@@ -35,7 +60,8 @@ class preprocessing(object):

     @property
     def observation_space(self):
-        return Box(low=0, high=255, shape=(self.size, self.size, 4),
+        return Box(low=0, high=255,
+                   shape=(self.size, self.size, self.frame_skip),
                    dtype=np.uint8)

     def action_space(self):
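Note: the only behavioural difference here is that a non-default frame_skip is now reflected in the advertised space instead of a hard-coded channel count of 4. A standalone illustration (assumes gym is installed; advertised_space is a hypothetical helper mirroring the property):

import numpy as np
from gym.spaces.box import Box

def advertised_space(size=84, frame_skip=4):
    # Channel count follows frame_skip rather than a literal 4.
    return Box(low=0, high=255, shape=(size, size, frame_skip), dtype=np.uint8)

print(advertised_space().shape)              # -> (84, 84, 4), same as before
print(advertised_space(frame_skip=2).shape)  # -> (84, 84, 2), now configurable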
@@ -57,8 +83,8 @@ class preprocessing(object):
         self._grayscale_obs(self.screen_buffer[0])
         self.screen_buffer[1].fill(0)

-        return np.stack([
-            self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
+        return np.array([self._pool_and_resize()
+                         for _ in range(self.frame_skip)])

     def render(self, mode='human'):
         return self.env.render(mode)

@@ -85,19 +111,15 @@ class preprocessing(object):
                 self._grayscale_obs(self.screen_buffer[t_])

            observation.append(self._pool_and_resize())
-        while len(observation) > 0 and len(observation) < self.frame_skip:
+        if len(observation) == 0:
+            observation = [self._pool_and_resize()
+                           for _ in range(self.frame_skip)]
+        while len(observation) > 0 and \
+                len(observation) < self.frame_skip:
             observation.append(observation[-1])
-        if len(observation) > 0:
-            observation = np.stack(observation, axis=-1)
-        else:
-            observation = np.stack([
-                self._pool_and_resize() for _ in range(self.frame_skip)],
-                axis=-1)
-        if self.count >= self.max_episode_steps:
-            terminal = True
-        else:
-            terminal = False
-        return observation, total_reward, (terminal or is_terminal), info
+        terminal = self.count >= self.max_episode_steps
+        return np.array(observation), total_reward, \
+            (terminal or is_terminal), info

     def _grayscale_obs(self, output):
         self.env.ale.getScreenGrayscale(output)

@@ -108,9 +130,4 @@ class preprocessing(object):
             np.maximum(self.screen_buffer[0], self.screen_buffer[1],
                        out=self.screen_buffer[0])

-        transformed_image = cv2.resize(self.screen_buffer[0],
-                                       (self.size, self.size),
-                                       interpolation=cv2.INTER_AREA)
-        int_image = np.asarray(transformed_image, dtype=np.uint8)
-        # return np.expand_dims(int_image, axis=2)
-        return int_image
+        return self.screen_buffer[0]
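Note: the rewritten step() drops the np.stack bookkeeping and instead pads the frame list before converting it with np.array. A standalone sketch of that padding rule (pad_frames and the frame strings are hypothetical, for illustration only):

FRAME_SKIP = 4

def pad_frames(observation, fresh_frame):
    # If nothing was pooled this step, fill the stack with fresh frames;
    # otherwise repeat the last pooled frame until FRAME_SKIP entries exist.
    if len(observation) == 0:
        observation = [fresh_frame for _ in range(FRAME_SKIP)]
    while 0 < len(observation) < FRAME_SKIP:
        observation.append(observation[-1])
    return observation

print(pad_frames(['f0', 'f1'], 'fresh'))  # -> ['f0', 'f1', 'f1', 'f1']
print(pad_frames([], 'fresh'))            # -> ['fresh', 'fresh', 'fresh', 'fresh']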
@@ -8,11 +8,11 @@ from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
-
 from tianshou.utils.net.discrete import Actor, Critic
 from tianshou.utils.net.common import Net

+from atari import create_atari_environment, preprocess_fn
+

 def get_args():
     parser = argparse.ArgumentParser()

@@ -45,20 +45,17 @@ def get_args():


 def test_a2c(args=get_args()):
-    env = create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(
-            args.task, max_episode_steps=args.max_episode_steps)
-            for _ in range(args.training_num)])
+        [lambda: create_atari_environment(args.task)
+         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(
-            args.task, max_episode_steps=args.max_episode_steps)
-            for _ in range(args.test_num)])
+        [lambda: create_atari_environment(args.task)
+         for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -76,8 +73,9 @@ def test_a2c(args=get_args()):
         ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
     writer = SummaryWriter(args.logdir + '/' + 'a2c')


@@ -99,7 +97,7 @@ def test_a2c(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_episode=1, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
@@ -9,7 +9,8 @@ from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.discrete import DQN
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment

+from atari import create_atari_environment, preprocess_fn
+

 def get_args():

@@ -49,8 +50,8 @@ def test_dqn(args=get_args()):
         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv([
-        lambda: create_atari_environment(
-            args.task) for _ in range(args.test_num)])
+        lambda: create_atari_environment(args.task)
+        for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -68,8 +69,9 @@ def test_dqn(args=get_args()):
         target_update_freq=args.target_update_freq)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * 4)
     print(len(train_collector.buffer))

@@ -101,7 +103,7 @@ def test_dqn(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_episode=1, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
@@ -8,10 +8,11 @@ from tianshou.policy import PPOPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env.atari import create_atari_environment
 from tianshou.utils.net.discrete import Actor, Critic
 from tianshou.utils.net.common import Net

+from atari import create_atari_environment, preprocess_fn
+

 def get_args():
     parser = argparse.ArgumentParser()

@@ -44,17 +45,16 @@ def get_args():


 def test_ppo(args=get_args()):
-    env = create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space().shape or env.action_space().n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv([lambda: create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    train_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv([lambda: create_atari_environment(
-        args.task, max_episode_steps=args.max_episode_steps)
+    test_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
         for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)

@@ -77,8 +77,9 @@ def test_ppo(args=get_args()):
         action_range=None)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        preprocess_fn=preprocess_fn)
+    test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
     writer = SummaryWriter(args.logdir + '/' + 'ppo')


@@ -100,7 +101,7 @@ def test_ppo(args=get_args()):
         pprint.pprint(result)
         # Let's watch its performance!
         env = create_atari_environment(args.task)
-        collector = Collector(policy, env)
+        collector = Collector(policy, env, preprocess_fn=preprocess_fn)
         result = collector.collect(n_step=2000, render=args.render)
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
@@ -16,7 +16,7 @@ from tianshou.utils.net.common import Recurrent
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, \
     exploration

-__version__ = '0.2.3'
+__version__ = '0.2.4'
 __all__ = [
     'env',
     'data',
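Note: once this release is installed, the bump can be sanity-checked from Python (assumes tianshou 0.2.4 from PyPI or this commit):

import tianshou
print(tianshou.__version__)  # expected to print: 0.2.4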