From 52e6b09768a36cf1886858d10c6511830aa76de5 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 11 Mar 2018 17:47:42 +0800
Subject: [PATCH] finish ddpg. now ppo, actor-critic, dqn works. ddpg is not working, check!

---
 examples/ddpg_example.py              | 46 ++++++++++++-------
 tianshou/core/policy/base.py          |  7 +++
 tianshou/core/policy/deterministic.py | 21 ++++++++-
 tianshou/core/random.py               | 64 +++++++++++++++++++++++++++
 tianshou/data/advantage_estimation.py | 37 ++++++++++++++--
 tianshou/data/data_collector.py       |  2 +
 6 files changed, 158 insertions(+), 19 deletions(-)
 create mode 100644 tianshou/core/random.py

diff --git a/examples/ddpg_example.py b/examples/ddpg_example.py
index faa8dcd..ca13423 100644
--- a/examples/ddpg_example.py
+++ b/examples/ddpg_example.py
@@ -6,29 +6,33 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)

 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy as policy
 import tianshou.core.value_function.action_value as value_function
 import tianshou.core.opt as opt
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_collector import DataCollector
+from tianshou.data.tester import test_policy_in_env
+

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()
-    env = gym.make('Pendulum-v0')
+
+    env = gym.make('MountainCarContinuous-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.shape

-    clip_param = 0.2
-    num_batches = 10
-    batch_size = 512
+    batch_size = 32

     seed = 0
     np.random.seed(seed)

@@ -59,13 +63,22 @@ if __name__ == '__main__':
     critic_optimizer = tf.train.AdamOptimizer(1e-3)
     critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)

-    dpg_grads = opt.DPG(actor, critic)  # not sure if it's correct
+    dpg_grads = opt.DPG(actor, critic)  # check which action to use in dpg
     actor_optimizer = tf.train.AdamOptimizer(1e-4)
     actor_train_op = actor_optimizer.apply_gradients(dpg_grads)

     ### 3. define data collection
-    data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic],
-                           render = args.render)
+    data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
+
+    process_functions = [advantage_estimation.ddpg_return(actor, critic)]
+
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=process_functions,
+        managed_networks=[actor, critic]
+    )

     ### 4. start training
     config = tf.ConfigProto()
@@ -74,22 +87,25 @@ if __name__ == '__main__':
     sess.run(tf.global_variables_initializer())

     # assign actor to pi_old
-    actor.sync_weights()  # TODO: automate this for policies with target network
+    actor.sync_weights()
     critic.sync_weights()

     start_time = time.time()
     data_collector.collect(num_timesteps=1e3)  # warm-up
     for i in range(int(1e8)):
         # collect data
-        data_collector.collect()
+        data_collector.collect(num_timesteps=1)

         # update network
-        for _ in range(num_batches):
-            feed_dict = data_collector.next_batch(batch_size)
-            sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
+        feed_dict = data_collector.next_batch(batch_size)
+        sess.run(critic_train_op, feed_dict=feed_dict)
+        sess.run(actor_train_op, feed_dict=feed_dict)

-        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+        # update target networks
+        actor.sync_weights()
+        critic.sync_weights()

         # test every 1000 training steps
         if i % 1000 == 0:
-            test(env, actor)
+            print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
+            test_policy_in_env(actor, env, num_timesteps=100)
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 6a060ce..ebada64 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -19,6 +19,13 @@ class PolicyBase(object):
     def act(self, observation, my_feed_dict):
         raise NotImplementedError()

+    def reset(self):
+        """
+        for temporal correlated random process exploration, as in DDPG
+        :return:
+        """
+        pass
+

 class StochasticPolicy(PolicyBase):
     """
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py
index 7cd7a1c..77391aa 100644
--- a/tianshou/core/policy/deterministic.py
+++ b/tianshou/core/policy/deterministic.py
@@ -1,11 +1,12 @@
 import tensorflow as tf
 from .base import PolicyBase
+from ..random import OrnsteinUhlenbeckProcess

 class Deterministic(PolicyBase):
     """
     deterministic policy as used in deterministic policy gradient methods
     """
-    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
+    def __init__(self, policy_callable, observation_placeholder, weight_update=1, random_process=None):
         self._observation_placeholder = observation_placeholder
         self.managed_placeholders = {'observation': observation_placeholder}
         self.weight_update = weight_update
@@ -49,6 +50,9 @@ class Deterministic(PolicyBase):
             import math
             self.weight_update = math.ceil(weight_update)

+        self.random_process = random_process or OrnsteinUhlenbeckProcess(
+            theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
+
     @property
     def action_shape(self):
         return self.action.shape.as_list()[1:]
@@ -62,6 +66,21 @@ class Deterministic(PolicyBase):
         feed_dict.update(my_feed_dict)
         sampled_action = sess.run(self.action, feed_dict=feed_dict)
+        sampled_action = sampled_action[0] + self.random_process.sample()
+
+        return sampled_action
+
+    def reset(self):
+        self.random_process.reset_states()
+
+    def act_test(self, observation, my_feed_dict={}):
+        sess = tf.get_default_session()
+        # observation[None] adds one dimension at the beginning
+
+        feed_dict = {self._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        sampled_action = sess.run(self.action, feed_dict=feed_dict)
+        sampled_action = sampled_action[0]

         return sampled_action

diff --git a/tianshou/core/random.py b/tianshou/core/random.py
new file mode 100644
index 0000000..fe6e5c7
--- /dev/null
+++ b/tianshou/core/random.py
@@ -0,0 +1,64 @@
+"""
+adapted from keras-rl
+"""
+
+from __future__ import division
+import numpy as np
+
+
+class RandomProcess(object):
+    def reset_states(self):
+        pass
+
+
+class AnnealedGaussianProcess(RandomProcess):
+    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
+        self.mu = mu
+        self.sigma = sigma
+        self.n_steps = 0
+
+        if sigma_min is not None:
+            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
+            self.c = sigma
+            self.sigma_min = sigma_min
+        else:
+            self.m = 0.
+            self.c = sigma
+            self.sigma_min = sigma
+
+    @property
+    def current_sigma(self):
+        sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
+        return sigma
+
+
+class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
+    def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
+        super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.size = size
+
+    def sample(self):
+        sample = np.random.normal(self.mu, self.current_sigma, self.size)
+        self.n_steps += 1
+        return sample
+
+# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
+    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
+        super(OrnsteinUhlenbeckProcess, self).__init__(
+            mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.theta = theta
+        self.mu = mu
+        self.dt = dt
+        self.x0 = x0
+        self.size = size
+        self.reset_states()
+
+    def sample(self):
+        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
+        self.x_prev = x
+        self.n_steps += 1
+        return x
+
+    def reset_states(self):
+        self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 621684e..010d68c 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -127,19 +127,50 @@ class ddpg_return:
     """
     compute the return as in DDPG. this seems to have to be special
     """
-    def __init__(self, actor, critic, use_target_network=True):
+    def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
         self.actor = actor
         self.critic = critic
         self.use_target_network = use_target_network
+        self.discount_factor = discount_factor

-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`.
             Not necessarily in order within each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+
+                for i in index_this:
+                    return_ = episode[i][REWARD]
+                    if not episode[i][DONE]:
+                        if self.use_target_network:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action_old(state)
+                            q_value = self.critic.eval_value_old(state, action)
+                            return_ += self.discount_factor * q_value
+                        else:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action(state)
+                            q_value = self.critic.eval_value(state, action)
+                            return_ += self.discount_factor * q_value
+
+                    returns_this.append(return_)
+
+                returns.append(returns_this)
+            else:
+                returns.append([])
+
+        return {'return': returns}


 class nstep_q_return:
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index c887b78..acd4a13 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -48,6 +48,7 @@ class DataCollector(object):

                 if done:
                     self.current_observation = self.env.reset()
+                    self.policy.reset()
                 else:
                     self.current_observation = next_observation

@@ -61,6 +62,7 @@ class DataCollector(object):
                     next_observation, reward, done, _ = self.env.step(action)
                     self.data_buffer.add((observation, action, reward, done))
                     observation = next_observation
+            self.current_observation = self.env.reset()

             if self.process_mode == 'full':
                 for processor in self.process_functions: