add option to show the running result of cartpole

Dong Yan 2018-02-24 10:53:39 +08:00
parent 764f7fb5f1
commit f3aee448e0
2 changed files with 9 additions and 2 deletions
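With this change the cartpole example accepts a --render flag (off by default); when it is passed, the gym window is drawn at every environment step while data is being collected. A possible invocation, with the script path assumed here since the diff view does not show file names:

    python examples/ppo_example.py --render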


@@ -5,6 +5,7 @@ import tensorflow as tf
 import gym
 import numpy as np
 import time
+import argparse

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan

 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--render", action="store_true", default=False)
+    args = parser.parse_args()
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n

@@ -55,7 +59,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)

     ### 4. start training
     config = tf.ConfigProto()


@@ -9,7 +9,7 @@ class Batch(object):
     class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
     """

-    def __init__(self, env, pi, reward_processors, networks):  # how to name the function?
+    def __init__(self, env, pi, reward_processors, networks, render=False):  # how to name the function?
         """
         constructor
         :param env:
@@ -24,6 +24,7 @@ class Batch(object):

         self.reward_processors = reward_processors
         self.networks = networks
+        self.render = render

         self.required_placeholders = {}
         for net in self.networks:
@@ -108,6 +109,8 @@ class Batch(object):

            ac = self._pi.act(ob, my_feed_dict)
            actions.append(ac)

+           if self.render:
+               self._env.render()
            ob, reward, done, _ = self._env.step(ac)
            rewards.append(reward)
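For reference, the pattern this commit implements, an optional render flag stored on the collector and checked once per rollout step, can be sketched in isolation. The class and helper names below (SimpleCollector, rollout) are illustrative only, not tianshou's API, and the snippet assumes the classic gym interface (reset returning an observation, step returning a 4-tuple) that the repository used at the time.

    import gym

    class SimpleCollector:
        # illustrative stand-in for tianshou's Batch collector
        def __init__(self, env, policy, render=False):
            self._env = env
            self._policy = policy
            self.render = render  # rendering is opt-in, matching the diff above

        def rollout(self, max_steps=200):
            ob = self._env.reset()
            rewards = []
            for _ in range(max_steps):
                ac = self._policy(ob)
                if self.render:
                    self._env.render()  # draw the frame before stepping, as in Batch.collect
                ob, reward, done, _ = self._env.step(ac)
                rewards.append(reward)
                if done:
                    break
            return rewards

    if __name__ == '__main__':
        env = gym.make('CartPole-v0')
        collector = SimpleCollector(env, lambda ob: env.action_space.sample(), render=True)
        print(sum(collector.rollout()))

Rendering every step slows data collection, which is why the flag defaults to False and is only forwarded from the command line in the example script.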