Finished all PPO examples. Training is remarkably slower than the version before Jan 13. More strangely, the gym example shows almost no improvement... but this problem takes a back seat to design work for now. I'll write actor-critic first.

This commit is contained in:
haoshengzou 2018-01-15 00:03:06 +08:00
parent fed3bf2a12
commit 983cd36074
8 changed files with 254 additions and 187 deletions

View File

@ -11,6 +11,9 @@ from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy.dqn as policy
# THIS EXAMPLE IS NOT FINISHED YET!!!
def policy_net(observation, action_dim):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf

View File

@ -33,55 +33,49 @@ if __name__ == '__main__':
tf.set_random_seed(seed)
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_policy():
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_mean = tf.layers.dense(net, action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
# value = tf.layers.dense(net, 1, activation=None)
return action_mean, action_logstd, None # None value head
# TODO: current implementation of passing function or overriding function has to return a value head
# to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
# imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
# with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
# not based on passing function or overriding function.
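# For contrast, an illustrative sketch (not part of this commit) of a callable that does share layers
# between the policy and a value head under the current convention; observation_ph and action_dim are
# assumed to be defined as in the example above:
def my_shared_policy():
    # both heads branch off the same hidden representation `net`
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_mean = tf.layers.dense(net, action_dim, activation=None)
    action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
    value = tf.layers.dense(net, 1, activation=None)  # value head reusing the shared `net`
    return action_mean, action_logstd, value  # the value head is no longer None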
### 2. build policy, losses, optimizers
### 2. build policy, loss, optimizer
pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
# action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
# advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
ppo_loss_clip = losses.ppo_clip(pi, clip_param)
total_loss = ppo_loss_clip
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
# ShihongSong: Replay(), see dqn_example.py
# maybe a dict to manage the elements to be collected
training_data = Batch(env, pi, advantage_estimation.full_return)
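# For reference, an illustrative sketch (not part of this commit) of the gae_lambda estimator mentioned
# in the notes above; `gae_lambda` is a hypothetical helper operating on a single episode, with `values`
# holding one extra entry for the state after the last reward:
import numpy as np

def gae_lambda(rewards, values, gamma=0.99, lam=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(deltas)
    gae = 0.0
    for t in reversed(range(len(deltas))):
        # A_t = delta_t + gamma * lambda * A_{t+1}
        gae = deltas[t] + gamma * lam * gae
        advantages[t] = gae
    return advantages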
### 4. start training
# init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# sync pi and pi_old
# assign pi to pi_old
pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time()
for i in range(100): # until some stopping criterion met...
for i in range(100):
# collect data
training_data.collect(num_episodes=20) # YouQiaoben, ShihongSong
training_data.collect(num_episodes=20)
# print current return
print('Epoch {}:'.format(i))
@ -89,7 +83,7 @@ if __name__ == '__main__':
# update network
for _ in range(num_batches):
feed_dict = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
feed_dict = training_data.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old

View File

@ -0,0 +1,112 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import time
import numpy as np
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: organize imports as in zhusuan so that importing the policy package alone suffices
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
# for tutorial purposes, placeholders are explicitly given the '_ph' suffix
# this example with batch_norm and dropout will almost surely not improve; it just shows how to use those
# layers and another way of writing networks
class MyPolicy(object):
def __init__(self, observation_ph, is_training_ph, keep_prob_ph, action_dim):
self.observation_ph = observation_ph
self.is_training_ph = is_training_ph
self.keep_prob_ph = keep_prob_ph
self.action_dim = action_dim
def __call__(self):
net = tf.layers.dense(self.observation_ph, 32, activation=None)
net = tf.layers.batch_normalization(net, training=self.is_training_ph)
net = tf.nn.relu(net)
net = tf.nn.dropout(net, keep_prob=self.keep_prob_ph)
net = tf.layers.dense(net, 32, activation=tf.nn.relu)
net = tf.layers.dropout(net, rate=1 - self.keep_prob_ph, training=self.is_training_ph)  # without the training flag this layer is a no-op
action_mean = tf.layers.dense(net, self.action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(self.action_dim,), dtype=tf.float32)
return action_mean, action_logstd, None
if __name__ == '__main__':
env = normalize(CartpoleEnv())
observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim
clip_param = 0.2
num_batches = 10
batch_size = 128
seed = 10
np.random.seed(seed)
tf.set_random_seed(seed)
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
is_training_ph = tf.placeholder(tf.bool, shape=())
keep_prob_ph = tf.placeholder(tf.float32, shape=())
my_policy = MyPolicy(observation_ph, is_training_ph, keep_prob_ph, action_dim)
### 2. build policy, loss, optimizer
pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
ppo_loss_clip = losses.ppo_clip(pi, clip_param)
total_loss = ppo_loss_clip
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
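# Note, as an illustrative sketch (not part of this commit): tf.layers.batch_normalization registers its
# moving-average updates in tf.GraphKeys.UPDATE_OPS, and they only run if the train op depends on them:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)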
### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return)
### 4. start training
feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8}
feed_dict_test = {is_training_ph: False, keep_prob_ph: 1}
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# assign pi to pi_old
pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time()
for i in range(100):
# collect data
training_data.collect(num_episodes=20, my_feed_dict=feed_dict_train)
# print current return
print('Epoch {}:'.format(i))
training_data.statistics()
# update network
for _ in range(num_batches):
feed_dict = training_data.next_batch(batch_size)
feed_dict.update(feed_dict_train)
sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old
pi.update_weights()
# approximate test mode
training_data.collect(num_episodes=10, my_feed_dict=feed_dict_test)
print('After training:')
training_data.statistics()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -14,28 +14,8 @@ from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: organize imports as in zhusuan so that importing the policy package alone suffices
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
def policy_net(observation, action_dim, scope=None):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
:param action_dim: int. The number of actions.
:param scope: str. Specifying the scope of the variables.
"""
# with tf.variable_scope(scope):
net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
act_logits = tf.layers.dense(net, action_dim, activation=None)
return act_logits
if __name__ == '__main__': # a clean version with only policy net, no value net
if __name__ == '__main__':
env = gym.make('CartPole-v0')
observation_dim = env.observation_space.shape
action_dim = env.action_space.n
@ -44,51 +24,52 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
num_batches = 10
batch_size = 512
seed = 10
seed = 5
np.random.seed(seed)
tf.set_random_seed(seed)
# 1. build network with pure tf
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
with tf.variable_scope('pi'):
action_logits = policy_net(observation, action_dim, 'pi')
train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
with tf.variable_scope('pi_old'):
action_logits_old = policy_net(observation, action_dim, 'pi_old')
pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old')
def my_policy():
net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
# 2. build losses, optimizers
pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
# for continuous action space, you may need to change an environment to run
pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation)
action_logits = tf.layers.dense(net, action_dim, activation=None)
action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions
advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
return action_logits, None # None value head
ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
# TODO: current implementation of passing function or overriding function has to return a value head
# to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
# imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
# with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
# not based on passing function or overriding function.
### 2. build policy, loss, optimizer
pi = policy.OnehotCategorical(my_policy, observation_placeholder=observation_ph, weight_update=0)
ppo_loss_clip = losses.ppo_clip(pi, clip_param)
total_loss = ppo_loss_clip
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=train_var_list)
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
# 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
# ShihongSong: Replay(), see dqn_example.py
# maybe a dict to manage the elements to be collected
### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return)
# 4. start training
### 4. start training
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# sync pi and pi_old
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
# assign pi to pi_old
pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time()
for i in range(100): # until some stopping criterion met...
for i in range(100):
# collect data
training_data.collect(num_episodes=50) # YouQiaoben, ShihongSong
training_data.collect(num_episodes=50)
# print current return
print('Epoch {}:'.format(i))
@ -96,12 +77,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
# update network
for _ in range(num_batches):
data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
# TODO: auto managing of the placeholders? or add this to params of data.Batch
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'],
advantage: data['returns']})
feed_dict = training_data.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
pi.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -1,91 +0,0 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import gym
# our lib imports here!
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: organize imports as in zhusuan so that importing the policy package alone suffices
def policy_net(observation, action_dim, scope=None):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
:param action_dim: int. The number of actions.
:param scope: str. Specifying the scope of the variables.
"""
# with tf.variable_scope(scope):
net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
net = tf.layers.dense(net, 256, activation=tf.nn.relu)
act_logits = tf.layers.dense(net, action_dim)
return act_logits
if __name__ == '__main__': # a clean version with only policy net, no value net
env = gym.make('PongNoFrameskip-v4')
observation_dim = env.observation_space.shape
action_dim = env.action_space.n
clip_param = 0.2
num_batches = 2
# 1. build network with pure tf
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
with tf.variable_scope('pi'):
action_logits = policy_net(observation, action_dim, 'pi')
train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
with tf.variable_scope('pi_old'):
action_logits_old = policy_net(observation, action_dim, 'pi_old')
# 2. build losses, optimizers
pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
# for continuous action space, you may need to change an environment to run
pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation)
action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
advantage = tf.placeholder(dtype=tf.float32, shape=[None]) # advantage values used in the Gradients
ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
total_loss = ppo_loss_clip
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(total_loss, var_list=train_var_list)
# 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
# ShihongSong: Replay(), see dqn_example.py
# maybe a dict to manage the elements to be collected
# 4. start training
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
minibatch_count = 0
collection_count = 0
while True: # until some stopping criterion met...
# collect data
training_data.collect(num_episodes=2) # YouQiaoben, ShihongSong
collection_count += 1
print('Collected {} times.'.format(collection_count))
# update network
for _ in range(num_batches):
data = training_data.next_batch(64) # YouQiaoben, ShihongSong
# TODO: auto managing of the placeholders? or add this to params of data.Batch
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
minibatch_count += 1
print('Trained {} minibatches.'.format(minibatch_count))
# TODO: assigning pi to pi_old is not implemented yet

View File

@ -11,7 +11,7 @@ def ppo_clip(policy, clip_param):
:param policy: current `policy` to be optimized; its internally managed old policy is used to compute the ppo loss as in Eqn. (7) in the paper
:param clip_param: float, the clipping parameter epsilon in Eqn. (7) in the paper
"""
action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder')
action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder')
advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
policy.managed_placeholders['action'] = action_ph
policy.managed_placeholders['processed_reward'] = advantage_ph
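# The rest of ppo_clip lies outside this hunk; purely for reference, an illustrative sketch of the clipped
# surrogate in Eqn. (7) (not necessarily this repo's implementation), assuming the policy exposes log_prob
# and log_prob_old as the stochastic policy classes below suggest:
ratio = tf.exp(policy.log_prob(action_ph) - policy.log_prob_old(action_ph))
clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))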

View File

@ -9,7 +9,8 @@ import tensorflow as tf
from .base import StochasticPolicy
# TODO: the following, especially the target network construction, should be refactored to be neater
# even if policy_callable doesn't return a distribution class
class OnehotCategorical(StochasticPolicy):
"""
The class of one-hot Categorical distribution.
@ -33,19 +34,62 @@ class OnehotCategorical(StochasticPolicy):
`[i, j, ..., k, :]` is a one-hot vector of the selected category.
"""
def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs):
self._logits = tf.convert_to_tensor(logits)
self._action = tf.multinomial(self.logits, num_samples=1)
def __init__(self,
policy_callable,
observation_placeholder,
weight_update=1,
group_ndims=1,
**kwargs):
self.managed_placeholders = {'observation': observation_placeholder}
self.weight_update = weight_update
self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.
if dtype is None:
dtype = tf.int32
# assert_same_float_and_int_dtype([], dtype)
with tf.variable_scope('network'):
logits, value_head = policy_callable()
self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
self._action = tf.multinomial(self.logits, num_samples=1)
# TODO: self._action should be exactly the action tensor to run that directly gives action_dim
tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank?
if value_head is not None:
pass
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
if self.weight_update == 1:
self.weight_update_ops = None
self.sync_weights_ops = None
else: # then we need to build another tf graph as target network
with tf.variable_scope('net_old'):
logits, value_head = policy_callable()
self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32)
if value_head is not None: # useful in DDPG
pass
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
# TODO: use a scope that the user will almost surely not use. so get_collection will return
# the correct weights and old_weights, since it filters by regular expression
# or we write a util to parse the variable names and use only the topmost scope
assert len(network_weights) == len(network_old_weights)
self.sync_weights_ops = [tf.assign(variable_old, variable)
for (variable_old, variable) in zip(network_old_weights, network_weights)]
if weight_update == 0:
self.weight_update_ops = self.sync_weights_ops
elif 0 < weight_update < 1: # as in DDPG
pass
else:
self.interaction_count = 0 # as in DQN
import math
self.weight_update = math.ceil(weight_update)
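# Illustrative sketch (not part of this commit) of the soft updates the DDPG-style branch above could
# build, assuming weight_update plays the role of the interpolation coefficient tau:
# self.weight_update_ops = [
#     tf.assign(variable_old, weight_update * variable + (1. - weight_update) * variable_old)
#     for (variable_old, variable) in zip(network_old_weights, network_weights)]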
tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank, e.g. RNN
self._n_categories = self._logits.get_shape()[-1].value
super(OnehotCategorical, self).__init__(
act_dtype=dtype,
act_dtype=tf.int32,
param_dtype=self._logits.dtype,
is_continuous=False,
observation_placeholder=observation_placeholder,
@ -62,12 +106,18 @@ class OnehotCategorical(StochasticPolicy):
"""The number of categories in the distribution."""
return self._n_categories
def _act(self, observation):
@property
def action_shape(self):
return ()
def _act(self, observation, my_feed_dict):
# TODO: this may be ugly; it could also be a huge problem when running in parallel
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
sampled_action = sess.run(self._action,
feed_dict={self._observation_placeholder: observation[None]})
feed_dict = {self._observation_placeholder: observation[None]}
feed_dict.update(my_feed_dict)
sampled_action = sess.run(self._action, feed_dict=feed_dict)
sampled_action = sampled_action[0, 0]
@ -76,10 +126,30 @@ class OnehotCategorical(StochasticPolicy):
def _log_prob(self, sampled_action):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)
def _prob(self, sampled_action):
return tf.exp(self._log_prob(sampled_action))
def log_prob_old(self, sampled_action):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old)
def update_weights(self):
"""
updates the weights of policy_old.
:return:
"""
if self.weight_update_ops is not None:
sess = tf.get_default_session()
sess.run(self.weight_update_ops)
def sync_weights(self):
"""
Syncs the weights of network_old by directly copying the weights of network.
:return:
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
sess.run(self.sync_weights_ops)
OnehotDiscrete = OnehotCategorical
@ -111,7 +181,7 @@ class Normal(StochasticPolicy):
shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean
# TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act
# TODO: self._action should be exactly the action tensor to run that directly gives action_dim
if value_head is not None:
pass
@ -131,7 +201,10 @@ class Normal(StochasticPolicy):
pass
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
# TODO: use a scope that the user will almost surely not use. so get_collection will return
# the correct weights and old_weights, since it filters by regular expression
assert len(network_weights) == len(network_old_weights)
self.sync_weights_ops = [tf.assign(variable_old, variable)
for (variable_old, variable) in zip(network_old_weights, network_weights)]
@ -168,12 +241,8 @@ class Normal(StochasticPolicy):
return self._logstd
@property
def action(self):
return self._action
@property
def action_dim(self):
return self.mean.shape.as_list()[-1]
def action_shape(self):
return tuple(self._mean.shape.as_list()[1:])
def _act(self, observation, my_feed_dict):
# TODO: getting the session like this may be ugly; it could also be a huge problem when running in parallel

View File

@ -14,7 +14,7 @@ class Batch(object):
self._advantage_estimation_function = advantage_estimation_function
self._is_first_collect = True
def collect(self, num_timesteps=0, num_episodes=0,
def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
apply_function=True): # specify how much data to collect here, or fix it in __init__()
assert sum(
[num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
@ -87,7 +87,7 @@ class Batch(object):
episode_start_flags.append(True)
while True:
ac = self._pi.act(ob)
ac = self._pi.act(ob, my_feed_dict)
actions.append(ac)
ob, reward, done, _ = self._env.step(ac)
@ -139,6 +139,7 @@ class Batch(object):
feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations']
feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions']
feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns']
# TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data
return feed_dict
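# Illustrative sketch (not part of this commit) of the key-driven lookup the TODO above describes;
# the mapping from placeholder keys to data keys is hypothetical:
# key_to_data = {'observation': 'observations', 'action': 'actions', 'processed_reward': 'returns'}
# feed_dict = {placeholder: current_batch[key_to_data[key]]
#              for key, placeholder in self._pi.managed_placeholders.items()}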