move atari wrapper to examples and publish v0.2.4 (#124)
* move atari wrapper to examples * consistency * change drqn seed since it is quite unstable in current seed * minor fix * 0.2.4
This commit is contained in:
parent
ff99662fe6
commit
47e8e2686c
@ -2,6 +2,10 @@ import cv2
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.spaces.box import Box
|
from gym.spaces.box import Box
|
||||||
|
from tianshou.data import Batch
|
||||||
|
|
||||||
|
SIZE = 84
|
||||||
|
FRAME = 4
|
||||||
|
|
||||||
|
|
||||||
def create_atari_environment(name=None, sticky_actions=True,
|
def create_atari_environment(name=None, sticky_actions=True,
|
||||||
@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_fn(obs=None, act=None, rew=None, done=None,
|
||||||
|
obs_next=None, info=None, policy=None):
|
||||||
|
if obs_next is not None:
|
||||||
|
obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
|
||||||
|
obs_next = np.moveaxis(obs_next, 0, -1)
|
||||||
|
obs_next = cv2.resize(obs_next, (SIZE, SIZE))
|
||||||
|
obs_next = np.asanyarray(obs_next, dtype=np.uint8)
|
||||||
|
obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
|
||||||
|
obs_next = np.moveaxis(obs_next, 1, -1)
|
||||||
|
elif obs is not None:
|
||||||
|
obs = np.reshape(obs, (-1, *obs.shape[2:]))
|
||||||
|
obs = np.moveaxis(obs, 0, -1)
|
||||||
|
obs = cv2.resize(obs, (SIZE, SIZE))
|
||||||
|
obs = np.asanyarray(obs, dtype=np.uint8)
|
||||||
|
obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
|
||||||
|
obs = np.moveaxis(obs, 1, -1)
|
||||||
|
|
||||||
|
return Batch(obs=obs, act=act, rew=rew, done=done,
|
||||||
|
obs_next=obs_next, info=info)
|
||||||
|
|
||||||
|
|
||||||
class preprocessing(object):
|
class preprocessing(object):
|
||||||
def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
|
def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
|
||||||
size=84, max_episode_steps=2000):
|
size=84, max_episode_steps=2000):
|
||||||
@ -35,7 +60,8 @@ class preprocessing(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
return Box(low=0, high=255, shape=(self.size, self.size, 4),
|
return Box(low=0, high=255,
|
||||||
|
shape=(self.size, self.size, self.frame_skip),
|
||||||
dtype=np.uint8)
|
dtype=np.uint8)
|
||||||
|
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
@ -57,8 +83,8 @@ class preprocessing(object):
|
|||||||
self._grayscale_obs(self.screen_buffer[0])
|
self._grayscale_obs(self.screen_buffer[0])
|
||||||
self.screen_buffer[1].fill(0)
|
self.screen_buffer[1].fill(0)
|
||||||
|
|
||||||
return np.stack([
|
return np.array([self._pool_and_resize()
|
||||||
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
for _ in range(self.frame_skip)])
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
return self.env.render(mode)
|
return self.env.render(mode)
|
||||||
@ -85,19 +111,15 @@ class preprocessing(object):
|
|||||||
self._grayscale_obs(self.screen_buffer[t_])
|
self._grayscale_obs(self.screen_buffer[t_])
|
||||||
|
|
||||||
observation.append(self._pool_and_resize())
|
observation.append(self._pool_and_resize())
|
||||||
while len(observation) > 0 and len(observation) < self.frame_skip:
|
if len(observation) == 0:
|
||||||
|
observation = [self._pool_and_resize()
|
||||||
|
for _ in range(self.frame_skip)]
|
||||||
|
while len(observation) > 0 and \
|
||||||
|
len(observation) < self.frame_skip:
|
||||||
observation.append(observation[-1])
|
observation.append(observation[-1])
|
||||||
if len(observation) > 0:
|
terminal = self.count >= self.max_episode_steps
|
||||||
observation = np.stack(observation, axis=-1)
|
return np.array(observation), total_reward, \
|
||||||
else:
|
(terminal or is_terminal), info
|
||||||
observation = np.stack([
|
|
||||||
self._pool_and_resize() for _ in range(self.frame_skip)],
|
|
||||||
axis=-1)
|
|
||||||
if self.count >= self.max_episode_steps:
|
|
||||||
terminal = True
|
|
||||||
else:
|
|
||||||
terminal = False
|
|
||||||
return observation, total_reward, (terminal or is_terminal), info
|
|
||||||
|
|
||||||
def _grayscale_obs(self, output):
|
def _grayscale_obs(self, output):
|
||||||
self.env.ale.getScreenGrayscale(output)
|
self.env.ale.getScreenGrayscale(output)
|
||||||
@ -108,9 +130,4 @@ class preprocessing(object):
|
|||||||
np.maximum(self.screen_buffer[0], self.screen_buffer[1],
|
np.maximum(self.screen_buffer[0], self.screen_buffer[1],
|
||||||
out=self.screen_buffer[0])
|
out=self.screen_buffer[0])
|
||||||
|
|
||||||
transformed_image = cv2.resize(self.screen_buffer[0],
|
return self.screen_buffer[0]
|
||||||
(self.size, self.size),
|
|
||||||
interpolation=cv2.INTER_AREA)
|
|
||||||
int_image = np.asarray(transformed_image, dtype=np.uint8)
|
|
||||||
# return np.expand_dims(int_image, axis=2)
|
|
||||||
return int_image
|
|
@ -8,11 +8,11 @@ from tianshou.policy import A2CPolicy
|
|||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
from tianshou.env.atari import create_atari_environment
|
|
||||||
|
|
||||||
from tianshou.utils.net.discrete import Actor, Critic
|
from tianshou.utils.net.discrete import Actor, Critic
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
|
|
||||||
|
from atari import create_atari_environment, preprocess_fn
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -45,19 +45,16 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def test_a2c(args=get_args()):
|
def test_a2c(args=get_args()):
|
||||||
env = create_atari_environment(
|
env = create_atari_environment(args.task)
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
|
||||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
args.action_shape = env.env.action_space.shape or env.env.action_space.n
|
args.action_shape = env.env.action_space.shape or env.env.action_space.n
|
||||||
# train_envs = gym.make(args.task)
|
# train_envs = gym.make(args.task)
|
||||||
train_envs = SubprocVectorEnv(
|
train_envs = SubprocVectorEnv(
|
||||||
[lambda: create_atari_environment(
|
[lambda: create_atari_environment(args.task)
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
|
||||||
for _ in range(args.training_num)])
|
for _ in range(args.training_num)])
|
||||||
# test_envs = gym.make(args.task)
|
# test_envs = gym.make(args.task)
|
||||||
test_envs = SubprocVectorEnv(
|
test_envs = SubprocVectorEnv(
|
||||||
[lambda: create_atari_environment(
|
[lambda: create_atari_environment(args.task)
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
|
||||||
for _ in range(args.test_num)])
|
for _ in range(args.test_num)])
|
||||||
# seed
|
# seed
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@ -76,8 +73,9 @@ def test_a2c(args=get_args()):
|
|||||||
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
|
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
|
||||||
# collector
|
# collector
|
||||||
train_collector = Collector(
|
train_collector = Collector(
|
||||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
policy, train_envs, ReplayBuffer(args.buffer_size),
|
||||||
test_collector = Collector(policy, test_envs)
|
preprocess_fn=preprocess_fn)
|
||||||
|
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
|
||||||
# log
|
# log
|
||||||
writer = SummaryWriter(args.logdir + '/' + 'a2c')
|
writer = SummaryWriter(args.logdir + '/' + 'a2c')
|
||||||
|
|
||||||
@ -99,7 +97,7 @@ def test_a2c(args=get_args()):
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
env = create_atari_environment(args.task)
|
env = create_atari_environment(args.task)
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
collector.close()
|
collector.close()
|
||||||
|
@ -9,7 +9,8 @@ from tianshou.env import SubprocVectorEnv
|
|||||||
from tianshou.utils.net.discrete import DQN
|
from tianshou.utils.net.discrete import DQN
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
from tianshou.env.atari import create_atari_environment
|
|
||||||
|
from atari import create_atari_environment, preprocess_fn
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -49,8 +50,8 @@ def test_dqn(args=get_args()):
|
|||||||
for _ in range(args.training_num)])
|
for _ in range(args.training_num)])
|
||||||
# test_envs = gym.make(args.task)
|
# test_envs = gym.make(args.task)
|
||||||
test_envs = SubprocVectorEnv([
|
test_envs = SubprocVectorEnv([
|
||||||
lambda: create_atari_environment(
|
lambda: create_atari_environment(args.task)
|
||||||
args.task) for _ in range(args.test_num)])
|
for _ in range(args.test_num)])
|
||||||
# seed
|
# seed
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
@ -68,8 +69,9 @@ def test_dqn(args=get_args()):
|
|||||||
target_update_freq=args.target_update_freq)
|
target_update_freq=args.target_update_freq)
|
||||||
# collector
|
# collector
|
||||||
train_collector = Collector(
|
train_collector = Collector(
|
||||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
policy, train_envs, ReplayBuffer(args.buffer_size),
|
||||||
test_collector = Collector(policy, test_envs)
|
preprocess_fn=preprocess_fn)
|
||||||
|
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
|
||||||
# policy.set_eps(1)
|
# policy.set_eps(1)
|
||||||
train_collector.collect(n_step=args.batch_size * 4)
|
train_collector.collect(n_step=args.batch_size * 4)
|
||||||
print(len(train_collector.buffer))
|
print(len(train_collector.buffer))
|
||||||
@ -101,7 +103,7 @@ def test_dqn(args=get_args()):
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
env = create_atari_environment(args.task)
|
env = create_atari_environment(args.task)
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
collector.close()
|
collector.close()
|
||||||
|
@ -8,10 +8,11 @@ from tianshou.policy import PPOPolicy
|
|||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
from tianshou.env.atari import create_atari_environment
|
|
||||||
from tianshou.utils.net.discrete import Actor, Critic
|
from tianshou.utils.net.discrete import Actor, Critic
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
|
|
||||||
|
from atari import create_atari_environment, preprocess_fn
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -44,17 +45,16 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def test_ppo(args=get_args()):
|
def test_ppo(args=get_args()):
|
||||||
env = create_atari_environment(
|
env = create_atari_environment(args.task)
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
|
||||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
args.action_shape = env.action_space().shape or env.action_space().n
|
args.action_shape = env.action_space().shape or env.action_space().n
|
||||||
# train_envs = gym.make(args.task)
|
# train_envs = gym.make(args.task)
|
||||||
train_envs = SubprocVectorEnv([lambda: create_atari_environment(
|
train_envs = SubprocVectorEnv([
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
lambda: create_atari_environment(args.task)
|
||||||
for _ in range(args.training_num)])
|
for _ in range(args.training_num)])
|
||||||
# test_envs = gym.make(args.task)
|
# test_envs = gym.make(args.task)
|
||||||
test_envs = SubprocVectorEnv([lambda: create_atari_environment(
|
test_envs = SubprocVectorEnv([
|
||||||
args.task, max_episode_steps=args.max_episode_steps)
|
lambda: create_atari_environment(args.task)
|
||||||
for _ in range(args.test_num)])
|
for _ in range(args.test_num)])
|
||||||
# seed
|
# seed
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@ -77,8 +77,9 @@ def test_ppo(args=get_args()):
|
|||||||
action_range=None)
|
action_range=None)
|
||||||
# collector
|
# collector
|
||||||
train_collector = Collector(
|
train_collector = Collector(
|
||||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
policy, train_envs, ReplayBuffer(args.buffer_size),
|
||||||
test_collector = Collector(policy, test_envs)
|
preprocess_fn=preprocess_fn)
|
||||||
|
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
|
||||||
# log
|
# log
|
||||||
writer = SummaryWriter(args.logdir + '/' + 'ppo')
|
writer = SummaryWriter(args.logdir + '/' + 'ppo')
|
||||||
|
|
||||||
@ -100,7 +101,7 @@ def test_ppo(args=get_args()):
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
env = create_atari_environment(args.task)
|
env = create_atari_environment(args.task)
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||||
result = collector.collect(n_step=2000, render=args.render)
|
result = collector.collect(n_step=2000, render=args.render)
|
||||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
collector.close()
|
collector.close()
|
||||||
|
@ -16,7 +16,7 @@ from tianshou.utils.net.common import Recurrent
|
|||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
parser.add_argument('--seed', type=int, default=1626)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from tianshou import data, env, utils, policy, trainer, \
|
from tianshou import data, env, utils, policy, trainer, \
|
||||||
exploration
|
exploration
|
||||||
|
|
||||||
__version__ = '0.2.3'
|
__version__ = '0.2.4'
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'env',
|
'env',
|
||||||
'data',
|
'data',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user