From 983cd36074fa73b2b92a0ef2dfc0d1facdab6cd5 Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Mon, 15 Jan 2018 00:03:06 +0800 Subject: [PATCH] finished all ppo examples. Training is remarkably slower than the version before Jan 13. More strangely, in the gym example there's almost no improvement... but this problem comes behind design. I'll first write actor-critic. --- examples/dqn_example.py | 3 + examples/ppo_cartpole.py | 34 ++++---- examples/ppo_cartpole_alternative.py | 112 +++++++++++++++++++++++++++ examples/ppo_cartpole_gym.py | 83 ++++++++------------ examples/ppo_example.py | 91 ---------------------- tianshou/core/losses.py | 2 +- tianshou/core/policy/stochastic.py | 111 +++++++++++++++++++++----- tianshou/data/batch.py | 5 +- 8 files changed, 254 insertions(+), 187 deletions(-) create mode 100755 examples/ppo_cartpole_alternative.py delete mode 100755 examples/ppo_example.py diff --git a/examples/dqn_example.py b/examples/dqn_example.py index cf20d66..4b97ea8 100644 --- a/examples/dqn_example.py +++ b/examples/dqn_example.py @@ -11,6 +11,9 @@ from tianshou.data.replay_buffer.utils import get_replay_buffer import tianshou.core.policy.dqn as policy +# THIS EXAMPLE IS NOT FINISHED YET!!! + + def policy_net(observation, action_dim): """ Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py index ae7de41..05d317a 100755 --- a/examples/ppo_cartpole.py +++ b/examples/ppo_cartpole.py @@ -33,55 +33,49 @@ if __name__ == '__main__': tf.set_random_seed(seed) ### 1. build network with pure tf - observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) def my_policy(): net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + action_mean = tf.layers.dense(net, action_dim, activation=None) action_logstd = tf.get_variable('action_logstd', shape=(action_dim, )) - # value = tf.layers.dense(net, 1, activation=None) return action_mean, action_logstd, None # None value head - # TODO: current implementation of passing function or overriding function has to return a value head - # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' - # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact - # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is - # not based on passing function or overriding function. + # TODO: current implementation of passing function or overriding function has to return a value head + # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' + # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact + # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is + # not based on passing function or overriding function. - ### 2. build policy, losses, optimizers + ### 2. build policy, loss, optimizer pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) - # action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions - # advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients - - ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + ppo_loss_clip = losses.ppo_clip(pi, clip_param) total_loss = ppo_loss_clip optimizer = tf.train.AdamOptimizer(1e-4) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected + training_data = Batch(env, pi, advantage_estimation.full_return) ### 4. start training - # init = tf.global_variables_initializer() config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # sync pi and pi_old + # assign pi to pi_old pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() - for i in range(100): # until some stopping criterion met... + for i in range(100): # collect data - training_data.collect(num_episodes=20) # YouQiaoben, ShihongSong + training_data.collect(num_episodes=20) # print current return print('Epoch {}:'.format(i)) @@ -89,7 +83,7 @@ if __name__ == '__main__': # update network for _ in range(num_batches): - feed_dict = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong + feed_dict = training_data.next_batch(batch_size) sess.run(train_op, feed_dict=feed_dict) # assigning pi to pi_old diff --git a/examples/ppo_cartpole_alternative.py b/examples/ppo_cartpole_alternative.py new file mode 100755 index 0000000..b76e634 --- /dev/null +++ b/examples/ppo_cartpole_alternative.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import time +import numpy as np + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix +# this example with batch_norm and dropout almost surely cannot improve. it just shows how to use those +# layers and another way of writing networks. + +class MyPolicy(object): + def __init__(self, observation_ph, is_training_ph, keep_prob_ph, action_dim): + self.observation_ph = observation_ph + self.is_training_ph = is_training_ph + self.keep_prob_ph = keep_prob_ph + self.action_dim = action_dim + + def __call__(self): + net = tf.layers.dense(self.observation_ph, 32, activation=None) + net = tf.layers.batch_normalization(net, training=self.is_training_ph) + net = tf.nn.relu(net) + net = tf.nn.dropout(net, keep_prob=self.keep_prob_ph) + + net = tf.layers.dense(net, 32, activation=tf.nn.relu) + net = tf.layers.dropout(net, rate=1 - self.keep_prob_ph) + action_mean = tf.layers.dense(net, action_dim, activation=None) + action_logstd = tf.get_variable('action_logstd', shape=(self.action_dim,), dtype=tf.float32) + + return action_mean, action_logstd, None + + +if __name__ == '__main__': + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + is_training_ph = tf.placeholder(tf.bool, shape=()) + keep_prob_ph = tf.placeholder(tf.float32, shape=()) + + my_policy = MyPolicy(observation_ph, is_training_ph, keep_prob_ph, action_dim) + + ### 2. build policy, loss, optimizer + pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) + + ppo_loss_clip = losses.ppo_clip(pi, clip_param) + + total_loss = ppo_loss_clip + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) + + ### 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) + + ### 4. start training + feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8} + feed_dict_test = {is_training_ph: False, keep_prob_ph: 1} + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + # assign pi to pi_old + pi.sync_weights() # TODO: automate this for policies with target network + + start_time = time.time() + for i in range(100): + # collect data + training_data.collect(num_episodes=20, my_feed_dict=feed_dict_train) + + # print current return + print('Epoch {}:'.format(i)) + training_data.statistics() + + # update network + for _ in range(num_batches): + feed_dict = training_data.next_batch(batch_size) + feed_dict.update(feed_dict_train) + sess.run(train_op, feed_dict=feed_dict) + + # assigning pi to pi_old + pi.update_weights() + + # approximate test mode + training_data.collect(num_episodes=10, my_feed_dict=feed_dict_test) + print('After training:') + training_data.statistics() + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_cartpole_gym.py b/examples/ppo_cartpole_gym.py index 35ac275..42d1c13 100755 --- a/examples/ppo_cartpole_gym.py +++ b/examples/ppo_cartpole_gym.py @@ -14,28 +14,8 @@ from tianshou.data.batch import Batch import tianshou.data.advantage_estimation as advantage_estimation import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy -from rllab.envs.box2d.cartpole_env import CartpoleEnv -from rllab.envs.normalized_env import normalize - -def policy_net(observation, action_dim, scope=None): - """ - Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf - - :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) - :param action_dim: int. The number of actions. - :param scope: str. Specifying the scope of the variables. - """ - # with tf.variable_scope(scope): - net = tf.layers.dense(observation, 32, activation=tf.nn.tanh) - net = tf.layers.dense(net, 32, activation=tf.nn.tanh) - - act_logits = tf.layers.dense(net, action_dim, activation=None) - - return act_logits - - -if __name__ == '__main__': # a clean version with only policy net, no value net +if __name__ == '__main__': env = gym.make('CartPole-v0') observation_dim = env.observation_space.shape action_dim = env.action_space.n @@ -44,51 +24,52 @@ if __name__ == '__main__': # a clean version with only policy net, no value net num_batches = 10 batch_size = 512 - seed = 10 + seed = 5 np.random.seed(seed) tf.set_random_seed(seed) - # 1. build network with pure tf - observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) - with tf.variable_scope('pi'): - action_logits = policy_net(observation, action_dim, 'pi') - train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES - with tf.variable_scope('pi_old'): - action_logits_old = policy_net(observation, action_dim, 'pi_old') - pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old') + def my_policy(): + net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh) + net = tf.layers.dense(net, 64, activation=tf.nn.tanh) - # 2. build losses, optimizers - pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. - # for continuous action space, you may need to change an environment to run - pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation) + action_logits = tf.layers.dense(net, action_dim, activation=None) - action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions - advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients + return action_logits, None # None value head - ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + # TODO: current implementation of passing function or overriding function has to return a value head + # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' + # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact + # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is + # not based on passing function or overriding function. + + ### 2. build policy, loss, optimizer + pi = policy.OnehotCategorical(my_policy, observation_placeholder=observation_ph, weight_update=0) + + ppo_loss_clip = losses.ppo_clip(pi, clip_param) total_loss = ppo_loss_clip optimizer = tf.train.AdamOptimizer(1e-4) - train_op = optimizer.minimize(total_loss, var_list=train_var_list) + train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) - # 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected + ### 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) - # 4. start training + ### 4. start training config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # sync pi and pi_old - sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + # assign pi to pi_old + pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() - for i in range(100): # until some stopping criterion met... + for i in range(100): # collect data - training_data.collect(num_episodes=50) # YouQiaoben, ShihongSong + training_data.collect(num_episodes=50) # print current return print('Epoch {}:'.format(i)) @@ -96,12 +77,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net # update network for _ in range(num_batches): - data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong - # TODO: auto managing of the placeholders? or add this to params of data.Batch - sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], - advantage: data['returns']}) + feed_dict = training_data.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) # assigning pi to pi_old - sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + pi.update_weights() print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_example.py b/examples/ppo_example.py deleted file mode 100755 index 985c8f2..0000000 --- a/examples/ppo_example.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import - -import tensorflow as tf -import gym - -# our lib imports here! -import sys -sys.path.append('..') -from tianshou.core import losses -from tianshou.data.batch import Batch -import tianshou.data.advantage_estimation as advantage_estimation -import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy - - -def policy_net(observation, action_dim, scope=None): - """ - Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf - - :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) - :param action_dim: int. The number of actions. - :param scope: str. Specifying the scope of the variables. - """ - # with tf.variable_scope(scope): - net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu) - net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu) - net = tf.layers.flatten(net) - net = tf.layers.dense(net, 256, activation=tf.nn.relu) - - act_logits = tf.layers.dense(net, action_dim) - - return act_logits - - -if __name__ == '__main__': # a clean version with only policy net, no value net - env = gym.make('PongNoFrameskip-v4') - observation_dim = env.observation_space.shape - action_dim = env.action_space.n - - clip_param = 0.2 - num_batches = 2 - - # 1. build network with pure tf - observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input - - with tf.variable_scope('pi'): - action_logits = policy_net(observation, action_dim, 'pi') - train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES - with tf.variable_scope('pi_old'): - action_logits_old = policy_net(observation, action_dim, 'pi_old') - - # 2. build losses, optimizers - pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. - # for continuous action space, you may need to change an environment to run - pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation) - - action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions - advantage = tf.placeholder(dtype=tf.float32, shape=[None]) # advantage values used in the Gradients - - ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict - - total_loss = ppo_loss_clip - optimizer = tf.train.AdamOptimizer(1e-3) - train_op = optimizer.minimize(total_loss, var_list=train_var_list) - - # 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected - - # 4. start training - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - minibatch_count = 0 - collection_count = 0 - while True: # until some stopping criterion met... - # collect data - training_data.collect(num_episodes=2) # YouQiaoben, ShihongSong - collection_count += 1 - print('Collected {} times.'.format(collection_count)) - - # update network - for _ in range(num_batches): - data = training_data.next_batch(64) # YouQiaoben, ShihongSong - # TODO: auto managing of the placeholders? or add this to params of data.Batch - sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']}) - minibatch_count += 1 - print('Trained {} minibatches.'.format(minibatch_count)) - - # TODO: assigning pi to pi_old is not implemented yet \ No newline at end of file diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py index e4f56ce..e1c48d4 100644 --- a/tianshou/core/losses.py +++ b/tianshou/core/losses.py @@ -11,7 +11,7 @@ def ppo_clip(policy, clip_param): :param policy: current `policy` to be optimized :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper """ - action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder') + action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder') advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder') policy.managed_placeholders['action'] = action_ph policy.managed_placeholders['processed_reward'] = advantage_ph diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index d3ab8e7..33ee36a 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -9,7 +9,8 @@ import tensorflow as tf from .base import StochasticPolicy - +# TODO: the following, especially the target network construction should be refactored to be more neat +# even if policy_callable don't return a distribution class class OnehotCategorical(StochasticPolicy): """ The class of one-hot Categorical distribution. @@ -33,19 +34,62 @@ class OnehotCategorical(StochasticPolicy): `[i, j, ..., k, :]` is a one-hot vector of the selected category. """ - def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs): - self._logits = tf.convert_to_tensor(logits) - self._action = tf.multinomial(self.logits, num_samples=1) + def __init__(self, + policy_callable, + observation_placeholder, + weight_update=1, + group_ndims=1, + **kwargs): + self.managed_placeholders = {'observation': observation_placeholder} + self.weight_update = weight_update + self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. - if dtype is None: - dtype = tf.int32 - # assert_same_float_and_int_dtype([], dtype) + with tf.variable_scope('network'): + logits, value_head = policy_callable() + self._logits = tf.convert_to_tensor(logits, dtype=tf.float32) + self._action = tf.multinomial(self.logits, num_samples=1) + # TODO: self._action should be exactly the action tensor to run that directly gives action_dim - tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank? + if value_head is not None: + pass + + self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') + + if self.weight_update == 1: + self.weight_update_ops = None + self.sync_weights_ops = None + else: # then we need to build another tf graph as target network + with tf.variable_scope('net_old'): + logits, value_head = policy_callable() + self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32) + + if value_head is not None: # useful in DDPG + pass + + network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network') + network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + # TODO: use a scope that the user will almost surely not use. so get_collection will return + # the correct weights and old_weights, since it filters by regular expression + # or we write a util to parse the variable names and use only the topmost scope + + assert len(network_weights) == len(network_old_weights) + self.sync_weights_ops = [tf.assign(variable_old, variable) + for (variable_old, variable) in zip(network_old_weights, network_weights)] + + if weight_update == 0: + self.weight_update_ops = self.sync_weights_ops + elif 0 < weight_update < 1: # as in DDPG + pass + else: + self.interaction_count = 0 # as in DQN + import math + self.weight_update = math.ceil(weight_update) + + tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank, e.g. RNN self._n_categories = self._logits.get_shape()[-1].value super(OnehotCategorical, self).__init__( - act_dtype=dtype, + act_dtype=tf.int32, param_dtype=self._logits.dtype, is_continuous=False, observation_placeholder=observation_placeholder, @@ -62,12 +106,18 @@ class OnehotCategorical(StochasticPolicy): """The number of categories in the distribution.""" return self._n_categories - def _act(self, observation): + @property + def action_shape(self): + return () + + def _act(self, observation, my_feed_dict): # TODO: this may be ugly. also maybe huge problem when parallel sess = tf.get_default_session() # observation[None] adds one dimension at the beginning - sampled_action = sess.run(self._action, - feed_dict={self._observation_placeholder: observation[None]}) + + feed_dict = {self._observation_placeholder: observation[None]} + feed_dict.update(my_feed_dict) + sampled_action = sess.run(self._action, feed_dict=feed_dict) sampled_action = sampled_action[0, 0] @@ -76,10 +126,30 @@ class OnehotCategorical(StochasticPolicy): def _log_prob(self, sampled_action): return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits) - def _prob(self, sampled_action): return tf.exp(self._log_prob(sampled_action)) + def log_prob_old(self, sampled_action): + return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old) + + def update_weights(self): + """ + updates the weights of policy_old. + :return: + """ + if self.weight_update_ops is not None: + sess = tf.get_default_session() + sess.run(self.weight_update_ops) + + def sync_weights(self): + """ + sync the weights of network_old. Direct copy the weights of network. + :return: + """ + if self.sync_weights_ops is not None: + sess = tf.get_default_session() + sess.run(self.sync_weights_ops) + OnehotDiscrete = OnehotCategorical @@ -111,7 +181,7 @@ class Normal(StochasticPolicy): shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std)) self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean - # TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act + # TODO: self._action should be exactly the action tensor to run that directly gives action_dim if value_head is not None: pass @@ -131,7 +201,10 @@ class Normal(StochasticPolicy): pass network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network') - self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + # TODO: use a scope that the user will almost surely not use. so get_collection will return + # the correct weights and old_weights, since it filters by regular expression + assert len(network_weights) == len(network_old_weights) self.sync_weights_ops = [tf.assign(variable_old, variable) for (variable_old, variable) in zip(network_old_weights, network_weights)] @@ -168,12 +241,8 @@ class Normal(StochasticPolicy): return self._logstd @property - def action(self): - return self._action - - @property - def action_dim(self): - return self.mean.shape.as_list()[-1] + def action_shape(self): + return tuple(self._mean.shape.as_list[1:]) def _act(self, observation, my_feed_dict): # TODO: getting session like this maybe ugly. also maybe huge problem when parallel diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2e372ea..24d0af7 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -14,7 +14,7 @@ class Batch(object): self._advantage_estimation_function = advantage_estimation_function self._is_first_collect = True - def collect(self, num_timesteps=0, num_episodes=0, + def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, apply_function=True): # specify how many data to collect here, or fix it in __init__() assert sum( [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" @@ -87,7 +87,7 @@ class Batch(object): episode_start_flags.append(True) while True: - ac = self._pi.act(ob) + ac = self._pi.act(ob, my_feed_dict) actions.append(ac) ob, reward, done, _ = self._env.step(ac) @@ -139,6 +139,7 @@ class Batch(object): feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations'] feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions'] feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns'] + # TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data return feed_dict