diff --git a/examples/dqn_example.py b/examples/dqn_example.py index cf20d66..4b97ea8 100644 --- a/examples/dqn_example.py +++ b/examples/dqn_example.py @@ -11,6 +11,9 @@ from tianshou.data.replay_buffer.utils import get_replay_buffer import tianshou.core.policy.dqn as policy +# THIS EXAMPLE IS NOT FINISHED YET!!! + + def policy_net(observation, action_dim): """ Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py index ae7de41..05d317a 100755 --- a/examples/ppo_cartpole.py +++ b/examples/ppo_cartpole.py @@ -33,55 +33,49 @@ if __name__ == '__main__': tf.set_random_seed(seed) ### 1. build network with pure tf - observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) def my_policy(): net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + action_mean = tf.layers.dense(net, action_dim, activation=None) action_logstd = tf.get_variable('action_logstd', shape=(action_dim, )) - # value = tf.layers.dense(net, 1, activation=None) return action_mean, action_logstd, None # None value head - # TODO: current implementation of passing function or overriding function has to return a value head - # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' - # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact - # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is - # not based on passing function or overriding function. + # TODO: current implementation of passing function or overriding function has to return a value head + # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' + # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact + # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is + # not based on passing function or overriding function. - ### 2. build policy, losses, optimizers + ### 2. build policy, loss, optimizer pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) - # action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions - # advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients - - ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + ppo_loss_clip = losses.ppo_clip(pi, clip_param) total_loss = ppo_loss_clip optimizer = tf.train.AdamOptimizer(1e-4) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected + training_data = Batch(env, pi, advantage_estimation.full_return) ### 4. start training - # init = tf.global_variables_initializer() config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # sync pi and pi_old + # assign pi to pi_old pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() - for i in range(100): # until some stopping criterion met... + for i in range(100): # collect data - training_data.collect(num_episodes=20) # YouQiaoben, ShihongSong + training_data.collect(num_episodes=20) # print current return print('Epoch {}:'.format(i)) @@ -89,7 +83,7 @@ if __name__ == '__main__': # update network for _ in range(num_batches): - feed_dict = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong + feed_dict = training_data.next_batch(batch_size) sess.run(train_op, feed_dict=feed_dict) # assigning pi to pi_old diff --git a/examples/ppo_cartpole_alternative.py b/examples/ppo_cartpole_alternative.py new file mode 100755 index 0000000..b76e634 --- /dev/null +++ b/examples/ppo_cartpole_alternative.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import time +import numpy as np + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix +# this example with batch_norm and dropout almost surely cannot improve. it just shows how to use those +# layers and another way of writing networks. + +class MyPolicy(object): + def __init__(self, observation_ph, is_training_ph, keep_prob_ph, action_dim): + self.observation_ph = observation_ph + self.is_training_ph = is_training_ph + self.keep_prob_ph = keep_prob_ph + self.action_dim = action_dim + + def __call__(self): + net = tf.layers.dense(self.observation_ph, 32, activation=None) + net = tf.layers.batch_normalization(net, training=self.is_training_ph) + net = tf.nn.relu(net) + net = tf.nn.dropout(net, keep_prob=self.keep_prob_ph) + + net = tf.layers.dense(net, 32, activation=tf.nn.relu) + net = tf.layers.dropout(net, rate=1 - self.keep_prob_ph) + action_mean = tf.layers.dense(net, action_dim, activation=None) + action_logstd = tf.get_variable('action_logstd', shape=(self.action_dim,), dtype=tf.float32) + + return action_mean, action_logstd, None + + +if __name__ == '__main__': + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + is_training_ph = tf.placeholder(tf.bool, shape=()) + keep_prob_ph = tf.placeholder(tf.float32, shape=()) + + my_policy = MyPolicy(observation_ph, is_training_ph, keep_prob_ph, action_dim) + + ### 2. build policy, loss, optimizer + pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) + + ppo_loss_clip = losses.ppo_clip(pi, clip_param) + + total_loss = ppo_loss_clip + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) + + ### 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) + + ### 4. start training + feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8} + feed_dict_test = {is_training_ph: False, keep_prob_ph: 1} + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + # assign pi to pi_old + pi.sync_weights() # TODO: automate this for policies with target network + + start_time = time.time() + for i in range(100): + # collect data + training_data.collect(num_episodes=20, my_feed_dict=feed_dict_train) + + # print current return + print('Epoch {}:'.format(i)) + training_data.statistics() + + # update network + for _ in range(num_batches): + feed_dict = training_data.next_batch(batch_size) + feed_dict.update(feed_dict_train) + sess.run(train_op, feed_dict=feed_dict) + + # assigning pi to pi_old + pi.update_weights() + + # approximate test mode + training_data.collect(num_episodes=10, my_feed_dict=feed_dict_test) + print('After training:') + training_data.statistics() + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_cartpole_gym.py b/examples/ppo_cartpole_gym.py index 35ac275..42d1c13 100755 --- a/examples/ppo_cartpole_gym.py +++ b/examples/ppo_cartpole_gym.py @@ -14,28 +14,8 @@ from tianshou.data.batch import Batch import tianshou.data.advantage_estimation as advantage_estimation import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy -from rllab.envs.box2d.cartpole_env import CartpoleEnv -from rllab.envs.normalized_env import normalize - -def policy_net(observation, action_dim, scope=None): - """ - Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf - - :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) - :param action_dim: int. The number of actions. - :param scope: str. Specifying the scope of the variables. - """ - # with tf.variable_scope(scope): - net = tf.layers.dense(observation, 32, activation=tf.nn.tanh) - net = tf.layers.dense(net, 32, activation=tf.nn.tanh) - - act_logits = tf.layers.dense(net, action_dim, activation=None) - - return act_logits - - -if __name__ == '__main__': # a clean version with only policy net, no value net +if __name__ == '__main__': env = gym.make('CartPole-v0') observation_dim = env.observation_space.shape action_dim = env.action_space.n @@ -44,51 +24,52 @@ if __name__ == '__main__': # a clean version with only policy net, no value net num_batches = 10 batch_size = 512 - seed = 10 + seed = 5 np.random.seed(seed) tf.set_random_seed(seed) - # 1. build network with pure tf - observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) - with tf.variable_scope('pi'): - action_logits = policy_net(observation, action_dim, 'pi') - train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES - with tf.variable_scope('pi_old'): - action_logits_old = policy_net(observation, action_dim, 'pi_old') - pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old') + def my_policy(): + net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh) + net = tf.layers.dense(net, 64, activation=tf.nn.tanh) - # 2. build losses, optimizers - pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. - # for continuous action space, you may need to change an environment to run - pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation) + action_logits = tf.layers.dense(net, action_dim, activation=None) - action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions - advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients + return action_logits, None # None value head - ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + # TODO: current implementation of passing function or overriding function has to return a value head + # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' + # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact + # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is + # not based on passing function or overriding function. + + ### 2. build policy, loss, optimizer + pi = policy.OnehotCategorical(my_policy, observation_placeholder=observation_ph, weight_update=0) + + ppo_loss_clip = losses.ppo_clip(pi, clip_param) total_loss = ppo_loss_clip optimizer = tf.train.AdamOptimizer(1e-4) - train_op = optimizer.minimize(total_loss, var_list=train_var_list) + train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) - # 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected + ### 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) - # 4. start training + ### 4. start training config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # sync pi and pi_old - sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + # assign pi to pi_old + pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() - for i in range(100): # until some stopping criterion met... + for i in range(100): # collect data - training_data.collect(num_episodes=50) # YouQiaoben, ShihongSong + training_data.collect(num_episodes=50) # print current return print('Epoch {}:'.format(i)) @@ -96,12 +77,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net # update network for _ in range(num_batches): - data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong - # TODO: auto managing of the placeholders? or add this to params of data.Batch - sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], - advantage: data['returns']}) + feed_dict = training_data.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) # assigning pi to pi_old - sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + pi.update_weights() print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_example.py b/examples/ppo_example.py deleted file mode 100755 index 985c8f2..0000000 --- a/examples/ppo_example.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import - -import tensorflow as tf -import gym - -# our lib imports here! -import sys -sys.path.append('..') -from tianshou.core import losses -from tianshou.data.batch import Batch -import tianshou.data.advantage_estimation as advantage_estimation -import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy - - -def policy_net(observation, action_dim, scope=None): - """ - Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf - - :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) - :param action_dim: int. The number of actions. - :param scope: str. Specifying the scope of the variables. - """ - # with tf.variable_scope(scope): - net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu) - net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu) - net = tf.layers.flatten(net) - net = tf.layers.dense(net, 256, activation=tf.nn.relu) - - act_logits = tf.layers.dense(net, action_dim) - - return act_logits - - -if __name__ == '__main__': # a clean version with only policy net, no value net - env = gym.make('PongNoFrameskip-v4') - observation_dim = env.observation_space.shape - action_dim = env.action_space.n - - clip_param = 0.2 - num_batches = 2 - - # 1. build network with pure tf - observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input - - with tf.variable_scope('pi'): - action_logits = policy_net(observation, action_dim, 'pi') - train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES - with tf.variable_scope('pi_old'): - action_logits_old = policy_net(observation, action_dim, 'pi_old') - - # 2. build losses, optimizers - pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. - # for continuous action space, you may need to change an environment to run - pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation) - - action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions - advantage = tf.placeholder(dtype=tf.float32, shape=[None]) # advantage values used in the Gradients - - ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict - - total_loss = ppo_loss_clip - optimizer = tf.train.AdamOptimizer(1e-3) - train_op = optimizer.minimize(total_loss, var_list=train_var_list) - - # 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper - # ShihongSong: Replay(), see dqn_example.py - # maybe a dict to manage the elements to be collected - - # 4. start training - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - minibatch_count = 0 - collection_count = 0 - while True: # until some stopping criterion met... - # collect data - training_data.collect(num_episodes=2) # YouQiaoben, ShihongSong - collection_count += 1 - print('Collected {} times.'.format(collection_count)) - - # update network - for _ in range(num_batches): - data = training_data.next_batch(64) # YouQiaoben, ShihongSong - # TODO: auto managing of the placeholders? or add this to params of data.Batch - sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']}) - minibatch_count += 1 - print('Trained {} minibatches.'.format(minibatch_count)) - - # TODO: assigning pi to pi_old is not implemented yet \ No newline at end of file diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py index e4f56ce..e1c48d4 100644 --- a/tianshou/core/losses.py +++ b/tianshou/core/losses.py @@ -11,7 +11,7 @@ def ppo_clip(policy, clip_param): :param policy: current `policy` to be optimized :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper """ - action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder') + action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder') advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder') policy.managed_placeholders['action'] = action_ph policy.managed_placeholders['processed_reward'] = advantage_ph diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index d3ab8e7..33ee36a 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -9,7 +9,8 @@ import tensorflow as tf from .base import StochasticPolicy - +# TODO: the following, especially the target network construction should be refactored to be more neat +# even if policy_callable don't return a distribution class class OnehotCategorical(StochasticPolicy): """ The class of one-hot Categorical distribution. @@ -33,19 +34,62 @@ class OnehotCategorical(StochasticPolicy): `[i, j, ..., k, :]` is a one-hot vector of the selected category. """ - def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs): - self._logits = tf.convert_to_tensor(logits) - self._action = tf.multinomial(self.logits, num_samples=1) + def __init__(self, + policy_callable, + observation_placeholder, + weight_update=1, + group_ndims=1, + **kwargs): + self.managed_placeholders = {'observation': observation_placeholder} + self.weight_update = weight_update + self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. - if dtype is None: - dtype = tf.int32 - # assert_same_float_and_int_dtype([], dtype) + with tf.variable_scope('network'): + logits, value_head = policy_callable() + self._logits = tf.convert_to_tensor(logits, dtype=tf.float32) + self._action = tf.multinomial(self.logits, num_samples=1) + # TODO: self._action should be exactly the action tensor to run that directly gives action_dim - tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank? + if value_head is not None: + pass + + self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') + + if self.weight_update == 1: + self.weight_update_ops = None + self.sync_weights_ops = None + else: # then we need to build another tf graph as target network + with tf.variable_scope('net_old'): + logits, value_head = policy_callable() + self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32) + + if value_head is not None: # useful in DDPG + pass + + network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network') + network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + # TODO: use a scope that the user will almost surely not use. so get_collection will return + # the correct weights and old_weights, since it filters by regular expression + # or we write a util to parse the variable names and use only the topmost scope + + assert len(network_weights) == len(network_old_weights) + self.sync_weights_ops = [tf.assign(variable_old, variable) + for (variable_old, variable) in zip(network_old_weights, network_weights)] + + if weight_update == 0: + self.weight_update_ops = self.sync_weights_ops + elif 0 < weight_update < 1: # as in DDPG + pass + else: + self.interaction_count = 0 # as in DQN + import math + self.weight_update = math.ceil(weight_update) + + tf.assert_rank(self._logits, rank=2) # TODO: flexible policy output rank, e.g. RNN self._n_categories = self._logits.get_shape()[-1].value super(OnehotCategorical, self).__init__( - act_dtype=dtype, + act_dtype=tf.int32, param_dtype=self._logits.dtype, is_continuous=False, observation_placeholder=observation_placeholder, @@ -62,12 +106,18 @@ class OnehotCategorical(StochasticPolicy): """The number of categories in the distribution.""" return self._n_categories - def _act(self, observation): + @property + def action_shape(self): + return () + + def _act(self, observation, my_feed_dict): # TODO: this may be ugly. also maybe huge problem when parallel sess = tf.get_default_session() # observation[None] adds one dimension at the beginning - sampled_action = sess.run(self._action, - feed_dict={self._observation_placeholder: observation[None]}) + + feed_dict = {self._observation_placeholder: observation[None]} + feed_dict.update(my_feed_dict) + sampled_action = sess.run(self._action, feed_dict=feed_dict) sampled_action = sampled_action[0, 0] @@ -76,10 +126,30 @@ class OnehotCategorical(StochasticPolicy): def _log_prob(self, sampled_action): return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits) - def _prob(self, sampled_action): return tf.exp(self._log_prob(sampled_action)) + def log_prob_old(self, sampled_action): + return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old) + + def update_weights(self): + """ + updates the weights of policy_old. + :return: + """ + if self.weight_update_ops is not None: + sess = tf.get_default_session() + sess.run(self.weight_update_ops) + + def sync_weights(self): + """ + sync the weights of network_old. Direct copy the weights of network. + :return: + """ + if self.sync_weights_ops is not None: + sess = tf.get_default_session() + sess.run(self.sync_weights_ops) + OnehotDiscrete = OnehotCategorical @@ -111,7 +181,7 @@ class Normal(StochasticPolicy): shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std)) self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean - # TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act + # TODO: self._action should be exactly the action tensor to run that directly gives action_dim if value_head is not None: pass @@ -131,7 +201,10 @@ class Normal(StochasticPolicy): pass network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network') - self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old') + # TODO: use a scope that the user will almost surely not use. so get_collection will return + # the correct weights and old_weights, since it filters by regular expression + assert len(network_weights) == len(network_old_weights) self.sync_weights_ops = [tf.assign(variable_old, variable) for (variable_old, variable) in zip(network_old_weights, network_weights)] @@ -168,12 +241,8 @@ class Normal(StochasticPolicy): return self._logstd @property - def action(self): - return self._action - - @property - def action_dim(self): - return self.mean.shape.as_list()[-1] + def action_shape(self): + return tuple(self._mean.shape.as_list[1:]) def _act(self, observation, my_feed_dict): # TODO: getting session like this maybe ugly. also maybe huge problem when parallel diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2e372ea..24d0af7 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -14,7 +14,7 @@ class Batch(object): self._advantage_estimation_function = advantage_estimation_function self._is_first_collect = True - def collect(self, num_timesteps=0, num_episodes=0, + def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, apply_function=True): # specify how many data to collect here, or fix it in __init__() assert sum( [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" @@ -87,7 +87,7 @@ class Batch(object): episode_start_flags.append(True) while True: - ac = self._pi.act(ob) + ac = self._pi.act(ob, my_feed_dict) actions.append(ac) ob, reward, done, _ = self._env.step(ac) @@ -139,6 +139,7 @@ class Batch(object): feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations'] feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions'] feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns'] + # TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data return feed_dict