add option to show the running result of cartpole
This commit is contained in:
parent
764f7fb5f1
commit
f3aee448e0
@ -5,6 +5,7 @@ import tensorflow as tf
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
# our lib imports here! It's ok to append path in examples
|
# our lib imports here! It's ok to append path in examples
|
||||||
import sys
|
import sys
|
||||||
@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--render", action="store_true", default=False)
|
||||||
|
args = parser.parse_args()
|
||||||
env = gym.make('CartPole-v0')
|
env = gym.make('CartPole-v0')
|
||||||
observation_dim = env.observation_space.shape
|
observation_dim = env.observation_space.shape
|
||||||
action_dim = env.action_space.n
|
action_dim = env.action_space.n
|
||||||
@ -55,7 +59,7 @@ if __name__ == '__main__':
|
|||||||
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
||||||
|
|
||||||
### 3. define data collection
|
### 3. define data collection
|
||||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
|
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
|
||||||
|
|
||||||
### 4. start training
|
### 4. start training
|
||||||
config = tf.ConfigProto()
|
config = tf.ConfigProto()
|
||||||
|
@ -9,7 +9,7 @@ class Batch(object):
|
|||||||
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env, pi, reward_processors, networks): # how to name the function?
|
def __init__(self, env, pi, reward_processors, networks, render=False): # how to name the function?
|
||||||
"""
|
"""
|
||||||
constructor
|
constructor
|
||||||
:param env:
|
:param env:
|
||||||
@ -24,6 +24,7 @@ class Batch(object):
|
|||||||
|
|
||||||
self.reward_processors = reward_processors
|
self.reward_processors = reward_processors
|
||||||
self.networks = networks
|
self.networks = networks
|
||||||
|
self.render = render
|
||||||
|
|
||||||
self.required_placeholders = {}
|
self.required_placeholders = {}
|
||||||
for net in self.networks:
|
for net in self.networks:
|
||||||
@ -108,6 +109,8 @@ class Batch(object):
|
|||||||
ac = self._pi.act(ob, my_feed_dict)
|
ac = self._pi.act(ob, my_feed_dict)
|
||||||
actions.append(ac)
|
actions.append(ac)
|
||||||
|
|
||||||
|
if self.render:
|
||||||
|
self._env.render()
|
||||||
ob, reward, done, _ = self._env.step(ac)
|
ob, reward, done, _ = self._env.step(ac)
|
||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user