diff --git a/examples/ddpg_example.py b/examples/ddpg_example.py
index ca13423..74d3cdd 100644
--- a/examples/ddpg_example.py
+++ b/examples/ddpg_example.py
@@ -28,13 +28,13 @@ if __name__ == '__main__':
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()

-    env = gym.make('MountainCarContinuous-v0')
+    env = gym.make('Pendulum-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.shape

     batch_size = 32

-    seed = 0
+    seed = 123
     np.random.seed(seed)
     tf.set_random_seed(seed)

@@ -43,13 +43,15 @@ if __name__ == '__main__':
     action_ph = tf.placeholder(tf.float32, shape=(None,) + action_dim)

     def my_network():
-        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
-        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+        net = tf.layers.dense(observation_ph, 16, activation=tf.nn.relu)
+        net = tf.layers.dense(net, 16, activation=tf.nn.relu)
+        net = tf.layers.dense(net, 16, activation=tf.nn.relu)
         action = tf.layers.dense(net, action_dim[0], activation=None)

         action_value_input = tf.concat([observation_ph, action_ph], axis=1)
         net = tf.layers.dense(action_value_input, 32, activation=tf.nn.relu)
         net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
         action_value = tf.layers.dense(net, 1, activation=None)

         return action, action_value
@@ -61,14 +63,20 @@ if __name__ == '__main__':

     critic_loss = losses.value_mse(critic)
     critic_optimizer = tf.train.AdamOptimizer(1e-3)
-    critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
+    # clip by norm
+    critic_grads, vars = zip(*critic_optimizer.compute_gradients(critic_loss, var_list=critic.trainable_variables))
+    critic_grads, _ = tf.clip_by_global_norm(critic_grads, 1.0)
+    critic_train_op = critic_optimizer.apply_gradients(zip(critic_grads, vars))

-    dpg_grads = opt.DPG(actor, critic)  # check which action to use in dpg
-    actor_optimizer = tf.train.AdamOptimizer(1e-4)
-    actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
+    dpg_grads_vars = opt.DPG(actor, critic)  # check which action to use in dpg
+    # clip by norm
+    dpg_grads, vars = zip(*dpg_grads_vars)
+    dpg_grads, _ = tf.clip_by_global_norm(dpg_grads, 1.0)
+    actor_optimizer = tf.train.AdamOptimizer(1e-3)
+    actor_train_op = actor_optimizer.apply_gradients(zip(dpg_grads, vars))

     ### 3. define data collection
-    data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
+    data_buffer = VanillaReplayBuffer(capacity=100000, nstep=1)

     process_functions = [advantage_estimation.ddpg_return(actor, critic)]

@@ -91,10 +99,10 @@ if __name__ == '__main__':
         critic.sync_weights()

         start_time = time.time()
-        data_collector.collect(num_timesteps=1e3)  # warm-up
+        data_collector.collect(num_timesteps=100)  # warm-up
         for i in range(int(1e8)):
             # collect data
-            data_collector.collect(num_timesteps=1)
+            data_collector.collect(num_timesteps=1, episode_cutoff=200)

             # update network
             feed_dict = data_collector.next_batch(batch_size)
@@ -108,4 +116,4 @@ if __name__ == '__main__':
             # test every 1000 training steps
             if i % 1000 == 0:
                 print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
-                test_policy_in_env(actor, env, num_timesteps=100)
+                test_policy_in_env(actor, env, num_episodes=5, episode_cutoff=200)
diff --git a/examples/dqn.py b/examples/dqn.py
index 9fb8b4f..883d483 100644
--- a/examples/dqn.py
+++ b/examples/dqn.py
@@ -61,7 +61,7 @@ if __name__ == '__main__':

     ### 4. start training
     # hyper-parameters
-    batch_size = 128
+    batch_size = 32
     replay_buffer_warmup = 1000
     epsilon_decay_interval = 500
     epsilon = 0.6
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py
index 77391aa..b020259 100644
--- a/tianshou/core/policy/deterministic.py
+++ b/tianshou/core/policy/deterministic.py
@@ -51,7 +51,7 @@ class Deterministic(PolicyBase):
         self.weight_update = math.ceil(weight_update)

         self.random_process = random_process or OrnsteinUhlenbeckProcess(
-            theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
+            theta=0.15, sigma=0.3, size=self.action.shape.as_list()[-1])

     @property
     def action_shape(self):
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index acd4a13..42efd8e 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -31,8 +31,9 @@ class DataCollector(object):
             self.process_mode = 'full'

         self.current_observation = self.env.reset()
+        self.step_count_this_episode = 0

-    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
+    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None):
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
            "One and only one collection number specification permitted!"

@@ -44,11 +45,15 @@ class DataCollector(object):
             for _ in range(num_timesteps_):
                 action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                 next_observation, reward, done, _ = self.env.step(action)
+                self.step_count_this_episode += 1
+                if episode_cutoff and self.step_count_this_episode >= episode_cutoff:
+                    done = True
                 self.data_buffer.add((self.current_observation, action, reward, done))

                 if done:
                     self.current_observation = self.env.reset()
                     self.policy.reset()
+                    self.step_count_this_episode = 0
                 else:
                     self.current_observation = next_observation

@@ -57,11 +62,18 @@ class DataCollector(object):
             for _ in range(num_episodes_):
                 observation = self.env.reset()
                 done = False
+                step_count = 0
                 while not done:
                     action = self.policy.act(observation, my_feed_dict=my_feed_dict)
                     next_observation, reward, done, _ = self.env.step(action)
+                    step_count += 1
+
+                    if episode_cutoff and step_count >= episode_cutoff:
+                        done = True
+
                     self.data_buffer.add((observation, action, reward, done))
                     observation = next_observation
+            self.current_observation = self.env.reset()


         if self.process_mode == 'full':
diff --git a/tianshou/data/tester.py b/tianshou/data/tester.py
index 7f55ab3..8983b7c 100644
--- a/tianshou/data/tester.py
+++ b/tianshou/data/tester.py
@@ -5,7 +5,7 @@
 import logging

 import numpy as np

-def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99):
+def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, episode_cutoff=None):
     assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
         'One and only one collection number specification permitted!'
@@ -29,12 +29,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
         current_discount = 1.
         observation = env_.reset()
         done = False
+        step_count = 0
         while not done:
             action = policy.act_test(observation)
             observation, reward, done, _ = env_.step(action)
             current_return += reward * current_discount
             current_undiscounted_return += reward
             current_discount *= discount_factor
+            step_count += 1
+            if episode_cutoff and step_count >= episode_cutoff:
+                break

         returns[i] = current_return
         undiscounted_returns[i] = current_undiscounted_return
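
The behavioral core of this patch is the optional episode_cutoff argument threaded through DataCollector.collect and test_policy_in_env, which forces done once a fixed number of environment steps has elapsed. Below is a minimal standalone sketch of that pattern, assuming only a gym-style env and a policy object exposing act(observation); the function name rollout_with_cutoff and the transitions list are illustrative stand-ins, not part of the tianshou API.

def rollout_with_cutoff(env, policy, episode_cutoff=200):
    # Collect one episode of (observation, action, reward, done) tuples,
    # treating the step cutoff as episode termination, mirroring the check
    # added to DataCollector.collect() and test_policy_in_env() above.
    observation = env.reset()
    done = False
    step_count = 0
    transitions = []
    while not done:
        action = policy.act(observation)
        next_observation, reward, done, _ = env.step(action)
        step_count += 1
        if episode_cutoff and step_count >= episode_cutoff:
            done = True  # force termination once the cutoff is reached
        transitions.append((observation, action, reward, done))
        observation = next_observation
    return transitions

# usage sketch: rollout_with_cutoff(gym.make('Pendulum-v0'), actor, episode_cutoff=200)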