add option to show the running result of cartpole
This commit is contained in:
parent
764f7fb5f1
commit
f3aee448e0
@ -5,6 +5,7 @@ import tensorflow as tf
|
||||
import gym
|
||||
import numpy as np
|
||||
import time
|
||||
import argparse
|
||||
|
||||
# our lib imports here! It's ok to append path in examples
|
||||
import sys
|
||||
@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--render", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
env = gym.make('CartPole-v0')
|
||||
observation_dim = env.observation_space.shape
|
||||
action_dim = env.action_space.n
|
||||
@ -55,7 +59,7 @@ if __name__ == '__main__':
|
||||
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
||||
|
||||
### 3. define data collection
|
||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
|
||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
|
||||
|
||||
### 4. start training
|
||||
config = tf.ConfigProto()
|
||||
|
@ -9,7 +9,7 @@ class Batch(object):
|
||||
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
||||
"""
|
||||
|
||||
def __init__(self, env, pi, reward_processors, networks): # how to name the function?
|
||||
def __init__(self, env, pi, reward_processors, networks, render=False): # how to name the function?
|
||||
"""
|
||||
constructor
|
||||
:param env:
|
||||
@ -24,6 +24,7 @@ class Batch(object):
|
||||
|
||||
self.reward_processors = reward_processors
|
||||
self.networks = networks
|
||||
self.render = render
|
||||
|
||||
self.required_placeholders = {}
|
||||
for net in self.networks:
|
||||
@ -108,6 +109,8 @@ class Batch(object):
|
||||
ac = self._pi.act(ob, my_feed_dict)
|
||||
actions.append(ac)
|
||||
|
||||
if self.render:
|
||||
self._env.render()
|
||||
ob, reward, done, _ = self._env.step(ac)
|
||||
rewards.append(reward)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user