towards ddpg

haoshengzou 2018-03-28 18:47:41 +08:00
parent 52e6b09768
commit 75e7f14051
5 changed files with 40 additions and 16 deletions

View File

@@ -28,13 +28,13 @@ if __name__ == '__main__':
parser.add_argument("--render", action="store_true", default=False)
args = parser.parse_args()
- env = gym.make('MountainCarContinuous-v0')
+ env = gym.make('Pendulum-v0')
observation_dim = env.observation_space.shape
action_dim = env.action_space.shape
batch_size = 32
- seed = 0
+ seed = 123
np.random.seed(seed)
tf.set_random_seed(seed)
@@ -43,13 +43,15 @@ if __name__ == '__main__':
action_ph = tf.placeholder(tf.float32, shape=(None,) + action_dim)
def my_network():
- net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
- net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+ net = tf.layers.dense(observation_ph, 16, activation=tf.nn.relu)
+ net = tf.layers.dense(net, 16, activation=tf.nn.relu)
+ net = tf.layers.dense(net, 16, activation=tf.nn.relu)
action = tf.layers.dense(net, action_dim[0], activation=None)
action_value_input = tf.concat([observation_ph, action_ph], axis=1)
net = tf.layers.dense(action_value_input, 32, activation=tf.nn.relu)
net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+ net = tf.layers.dense(net, 32, activation=tf.nn.relu)
action_value = tf.layers.dense(net, 1, activation=None)
return action, action_value
@@ -61,14 +63,20 @@ if __name__ == '__main__':
critic_loss = losses.value_mse(critic)
critic_optimizer = tf.train.AdamOptimizer(1e-3)
- critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
+ # clip by norm
+ critic_grads, vars = zip(*critic_optimizer.compute_gradients(critic_loss, var_list=critic.trainable_variables))
+ critic_grads, _ = tf.clip_by_global_norm(critic_grads, 1.0)
+ critic_train_op = critic_optimizer.apply_gradients(zip(critic_grads, vars))
- dpg_grads = opt.DPG(actor, critic)  # check which action to use in dpg
- actor_optimizer = tf.train.AdamOptimizer(1e-4)
- actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
+ dpg_grads_vars = opt.DPG(actor, critic)  # check which action to use in dpg
+ # clip by norm
+ dpg_grads, vars = zip(*dpg_grads_vars)
+ dpg_grads, _ = tf.clip_by_global_norm(dpg_grads, 1.0)
+ actor_optimizer = tf.train.AdamOptimizer(1e-3)
+ actor_train_op = actor_optimizer.apply_gradients(zip(dpg_grads, vars))
### 3. define data collection
- data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
+ data_buffer = VanillaReplayBuffer(capacity=100000, nstep=1)
process_functions = [advantage_estimation.ddpg_return(actor, critic)]
@@ -91,10 +99,10 @@ if __name__ == '__main__':
critic.sync_weights()
start_time = time.time()
- data_collector.collect(num_timesteps=1e3)  # warm-up
+ data_collector.collect(num_timesteps=100)  # warm-up
for i in range(int(1e8)):
# collect data
- data_collector.collect(num_timesteps=1)
+ data_collector.collect(num_timesteps=1, episode_cutoff=200)
# update network
feed_dict = data_collector.next_batch(batch_size)
@@ -108,4 +116,4 @@ if __name__ == '__main__':
# test every 1000 training steps
if i % 1000 == 0:
print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
- test_policy_in_env(actor, env, num_timesteps=100)
+ test_policy_in_env(actor, env, num_episodes=5, episode_cutoff=200)
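
Note: the clip-by-global-norm pattern introduced above for both the critic and the actor can be exercised in isolation. Below is a minimal, hypothetical TF 1.x sketch with a dummy variable and loss (not the example's actual networks), showing the compute_gradients -> tf.clip_by_global_norm -> apply_gradients chain:

import tensorflow as tf

# dummy variable and loss, standing in for the critic/actor losses above
w = tf.Variable([3.0, -4.0])
loss = tf.reduce_sum(tf.square(w))

optimizer = tf.train.AdamOptimizer(1e-3)
# compute_gradients returns (gradient, variable) pairs; split them into two tuples
grads, variables = zip(*optimizer.compute_gradients(loss, var_list=[w]))
# rescale all gradients jointly so their global norm is at most 1.0
clipped_grads, _ = tf.clip_by_global_norm(grads, 1.0)
train_op = optimizer.apply_gradients(zip(clipped_grads, variables))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)  # one clipped gradient step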

View File

@@ -61,7 +61,7 @@ if __name__ == '__main__':
### 4. start training
# hyper-parameters
- batch_size = 128
+ batch_size = 32
replay_buffer_warmup = 1000
epsilon_decay_interval = 500
epsilon = 0.6

View File

@@ -51,7 +51,7 @@ class Deterministic(PolicyBase):
self.weight_update = math.ceil(weight_update)
self.random_process = random_process or OrnsteinUhlenbeckProcess(
- theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
+ theta=0.15, sigma=0.3, size=self.action.shape.as_list()[-1])
@property
def action_shape(self):
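
Note: the sigma bump above only changes the scale of the Ornstein-Uhlenbeck exploration noise. For reference, here is a generic numpy sketch built from the standard update x += theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1); it is an illustrative stand-in, not the library's OrnsteinUhlenbeckProcess class:

import numpy as np

class OUNoiseSketch:
    """Generic OU noise for continuous-action exploration (illustrative only)."""
    def __init__(self, size, theta=0.15, sigma=0.3, mu=0.0, dt=1.0):
        self.theta, self.sigma, self.mu, self.dt = theta, sigma, mu, dt
        self.x = np.full(size, mu)

    def reset(self):
        self.x = np.full_like(self.x, self.mu)

    def sample(self):
        # mean-reverting drift toward mu plus a Gaussian perturbation
        dx = self.theta * (self.mu - self.x) * self.dt \
             + self.sigma * np.sqrt(self.dt) * np.random.randn(*self.x.shape)
        self.x = self.x + dx
        return self.x.copy()

noise = OUNoiseSketch(size=1, theta=0.15, sigma=0.3)
print([float(noise.sample()) for _ in range(3)])  # temporally correlated samples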

View File

@@ -31,8 +31,9 @@ class DataCollector(object):
self.process_mode = 'full'
self.current_observation = self.env.reset()
+ self.step_count_this_episode = 0
- def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
+ def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None):
assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
"One and only one collection number specification permitted!"
@@ -44,11 +45,15 @@ class DataCollector(object):
for _ in range(num_timesteps_):
action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
next_observation, reward, done, _ = self.env.step(action)
+ self.step_count_this_episode += 1
+ if episode_cutoff and self.step_count_this_episode >= episode_cutoff:
+ done = True
self.data_buffer.add((self.current_observation, action, reward, done))
if done:
self.current_observation = self.env.reset()
self.policy.reset()
+ self.step_count_this_episode = 0
else:
self.current_observation = next_observation
@@ -57,11 +62,18 @@ class DataCollector(object):
for _ in range(num_episodes_):
observation = self.env.reset()
done = False
+ step_count = 0
while not done:
action = self.policy.act(observation, my_feed_dict=my_feed_dict)
next_observation, reward, done, _ = self.env.step(action)
+ step_count += 1
+ if episode_cutoff and step_count >= episode_cutoff:
+ done = True
self.data_buffer.add((observation, action, reward, done))
observation = next_observation
self.current_observation = self.env.reset()
if self.process_mode == 'full':
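
Note: the episode_cutoff additions above simply force done after a fixed number of steps. A minimal sketch of that pattern against the old gym API this code targets (reset() returning an observation, step() returning obs, reward, done, info), with a random policy standing in for policy.act:

import gym

env = gym.make('Pendulum-v0')
episode_cutoff = 200

observation = env.reset()
done = False
step_count = 0
while not done:
    action = env.action_space.sample()            # stand-in for policy.act(observation)
    observation, reward, done, _ = env.step(action)
    step_count += 1
    if episode_cutoff and step_count >= episode_cutoff:
        done = True                               # truncate the episode, as in collect()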

View File

@@ -5,7 +5,7 @@ import logging
import numpy as np
- def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99):
+ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, episode_cutoff=None):
assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
'One and only one collection number specification permitted!'
@@ -29,12 +29,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
current_discount = 1.
observation = env_.reset()
done = False
+ step_count = 0
while not done:
action = policy.act_test(observation)
observation, reward, done, _ = env_.step(action)
current_return += reward * current_discount
current_undiscounted_return += reward
current_discount *= discount_factor
+ step_count += 1
+ if episode_cutoff and step_count >= episode_cutoff:
+ break
returns[i] = current_return
undiscounted_returns[i] = current_undiscounted_return
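
Note: the running current_discount bookkeeping in the loop above accumulates exactly sum_t gamma^t * r_t. A tiny self-check of that equivalence:

gamma = 0.99
rewards = [1.0, 0.0, 2.0]

ret, disc = 0.0, 1.0
for r in rewards:
    ret += r * disc      # current_return += reward * current_discount
    disc *= gamma        # current_discount *= discount_factor

assert abs(ret - sum(gamma ** t * r for t, r in enumerate(rewards))) < 1e-12
print(ret)  # 1.0 + 0.99 * 0.0 + 0.99 ** 2 * 2.0 = 2.9602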