diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py
index e88a379..074331d 100755
--- a/examples/ppo_cartpole.py
+++ b/examples/ppo_cartpole.py
@@ -5,6 +5,7 @@ import tensorflow as tf
 import gym
 import numpy as np
 import time
+import argparse

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan


 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--render", action="store_true", default=False)
+    args = parser.parse_args()
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
@@ -55,7 +59,7 @@
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)

     ### 4. start training
     config = tf.ConfigProto()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 01ca78d..9c7405d 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -9,7 +9,7 @@ class Batch(object):
     class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
     """

-    def __init__(self, env, pi, reward_processors, networks):  # how to name the function?
+    def __init__(self, env, pi, reward_processors, networks, render=False):  # how to name the function?
         """
         constructor
         :param env:
@@ -24,6 +24,7 @@
         self.reward_processors = reward_processors
         self.networks = networks
+        self.render = render

         self.required_placeholders = {}
         for net in self.networks:
@@ -108,6 +109,8 @@
             ac = self._pi.act(ob, my_feed_dict)
             actions.append(ac)
+            if self.render:
+                self._env.render()

             ob, reward, done, _ = self._env.step(ac)
             rewards.append(reward)
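
Below is a small, self-contained sketch (not part of the diff) of the argparse pattern the first hunk adds to examples/ppo_cartpole.py; the parsed value is what gets forwarded to Batch(..., render=args.render). The argument list passed to parse_args here only simulates giving --render on the command line.

    import argparse

    parser = argparse.ArgumentParser()
    # store_true makes --render an opt-in switch that defaults to False,
    # so existing invocations of the example keep the old non-rendering behaviour.
    parser.add_argument("--render", action="store_true", default=False)

    # equivalent to running: python examples/ppo_cartpole.py --render
    args = parser.parse_args(["--render"])
    print(args.render)  # True

When the flag is set, Batch stores it as self.render and calls self._env.render() once per environment step during data collection, as the last hunk above shows.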