finish ddpg. now ppo, actor-critic, dqn work. ddpg is not working, check!

haoshengzou 2018-03-11 17:47:42 +08:00
parent a86354834c
commit 52e6b09768
6 changed files with 158 additions and 19 deletions

View File

@@ -6,29 +6,33 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)
 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy as policy
 import tianshou.core.value_function.action_value as value_function
 import tianshou.core.opt as opt
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_collector import DataCollector
+from tianshou.data.tester import test_policy_in_env
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()
-    env = gym.make('Pendulum-v0')
+    env = gym.make('MountainCarContinuous-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.shape
-    clip_param = 0.2
-    num_batches = 10
-    batch_size = 512
+    batch_size = 32
     seed = 0
     np.random.seed(seed)
@@ -59,13 +63,22 @@ if __name__ == '__main__':
     critic_optimizer = tf.train.AdamOptimizer(1e-3)
     critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
-    dpg_grads = opt.DPG(actor, critic)  # not sure if it's correct
+    dpg_grads = opt.DPG(actor, critic)  # check which action to use in dpg
     actor_optimizer = tf.train.AdamOptimizer(1e-4)
     actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
     ### 3. define data collection
-    data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic],
-                           render = args.render)
+    data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
+    process_functions = [advantage_estimation.ddpg_return(actor, critic)]
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=process_functions,
+        managed_networks=[actor, critic]
+    )
     ### 4. start training
     config = tf.ConfigProto()
@@ -74,22 +87,25 @@ if __name__ == '__main__':
     sess.run(tf.global_variables_initializer())
     # assign actor to pi_old
-    actor.sync_weights()  # TODO: automate this for policies with target network
+    actor.sync_weights()
     critic.sync_weights()
     start_time = time.time()
     data_collector.collect(num_timesteps=1e3)  # warm-up
     for i in range(int(1e8)):
         # collect data
-        data_collector.collect()
+        data_collector.collect(num_timesteps=1)
         # update network
-        for _ in range(num_batches):
-            feed_dict = data_collector.next_batch(batch_size)
-            sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
-        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+        feed_dict = data_collector.next_batch(batch_size)
+        sess.run(critic_train_op, feed_dict=feed_dict)
+        sess.run(actor_train_op, feed_dict=feed_dict)
+        # update target networks
+        actor.sync_weights()
+        critic.sync_weights()
         # test every 1000 training steps
         if i % 1000 == 0:
-            test(env, actor)
+            print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
+            test_policy_in_env(actor, env, num_timesteps=100)
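For reference, the deterministic policy gradient that opt.DPG(actor, critic) is expected to build chains the critic's action gradient through the actor, and the "check which action to use in dpg" comment above points at the usual pitfall: the critic must be evaluated at the actor's own output mu(s), not at the replayed action from the buffer. A TF1-style sketch of the standard construction follows; the function and argument names here are illustrative, not tianshou's actual opt.DPG API.

import tensorflow as tf

def dpg_gradients(actor_action, critic_q, actor_vars):
    # actor_action: the actor's output mu(s); critic_q: Q(s, mu(s)) wired to that output.
    # dQ/da, evaluated at a = mu(s)
    dq_da = tf.gradients(critic_q, actor_action)[0]
    # Chain rule with a negated upstream gradient, so that applying these
    # gradients with a minimizing optimizer performs gradient ascent on Q.
    dmu_dtheta = tf.gradients(actor_action, actor_vars, grad_ys=-dq_da)
    return list(zip(dmu_dtheta, actor_vars))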

View File

@@ -19,6 +19,13 @@ class PolicyBase(object):
     def act(self, observation, my_feed_dict):
         raise NotImplementedError()
+    def reset(self):
+        """
+        for temporal correlated random process exploration, as in DDPG
+        :return:
+        """
+        pass
 class StochasticPolicy(PolicyBase):
     """

View File

@@ -1,11 +1,12 @@
 import tensorflow as tf
 from .base import PolicyBase
+from ..random import OrnsteinUhlenbeckProcess
 class Deterministic(PolicyBase):
     """
     deterministic policy as used in deterministic policy gradient methods
     """
-    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
+    def __init__(self, policy_callable, observation_placeholder, weight_update=1, random_process=None):
         self._observation_placeholder = observation_placeholder
         self.managed_placeholders = {'observation': observation_placeholder}
         self.weight_update = weight_update
@@ -49,6 +50,9 @@ class Deterministic(PolicyBase):
             import math
             self.weight_update = math.ceil(weight_update)
+        self.random_process = random_process or OrnsteinUhlenbeckProcess(
+            theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
     @property
     def action_shape(self):
         return self.action.shape.as_list()[1:]
@@ -62,6 +66,21 @@ class Deterministic(PolicyBase):
         feed_dict.update(my_feed_dict)
         sampled_action = sess.run(self.action, feed_dict=feed_dict)
+        sampled_action = sampled_action[0] + self.random_process.sample()
+        return sampled_action
+    def reset(self):
+        self.random_process.reset_states()
+    def act_test(self, observation, my_feed_dict={}):
+        sess = tf.get_default_session()
+        # observation[None] adds one dimension at the beginning
+        feed_dict = {self._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        sampled_action = sess.run(self.action, feed_dict=feed_dict)
         sampled_action = sampled_action[0]
         return sampled_action
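Two things happen here: act() now adds temporally correlated exploration noise to the deterministic output, while the new act_test() returns the noise-free action for evaluation; and weight_update together with the per-step sync_weights() calls in the example implements DDPG's target-network tracking. As a point of reference, the standard target updates look like the plain-NumPy sketch below; soft_update/hard_update are illustrative helpers, not tianshou API, and the exact semantics of weight_update may differ.

import numpy as np

def hard_update(source_weights):
    # copy the online network's weights into the target network
    return [np.array(w, copy=True) for w in source_weights]

def soft_update(target_weights, source_weights, tau=0.001):
    # theta_target <- tau * theta + (1 - tau) * theta_target, applied every step
    return [(1.0 - tau) * t + tau * s for t, s in zip(target_weights, source_weights)]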

tianshou/core/random.py (Normal file, 64 lines)
View File

@@ -0,0 +1,64 @@
+"""
+adapted from keras-rl
+"""
+from __future__ import division
+import numpy as np
+class RandomProcess(object):
+    def reset_states(self):
+        pass
+class AnnealedGaussianProcess(RandomProcess):
+    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
+        self.mu = mu
+        self.sigma = sigma
+        self.n_steps = 0
+        if sigma_min is not None:
+            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
+            self.c = sigma
+            self.sigma_min = sigma_min
+        else:
+            self.m = 0.
+            self.c = sigma
+            self.sigma_min = sigma
+    @property
+    def current_sigma(self):
+        sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
+        return sigma
+class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
+    def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
+        super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.size = size
+    def sample(self):
+        sample = np.random.normal(self.mu, self.current_sigma, self.size)
+        self.n_steps += 1
+        return sample
+# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
+    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
+        super(OrnsteinUhlenbeckProcess, self).__init__(
+            mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.theta = theta
+        self.mu = mu
+        self.dt = dt
+        self.x0 = x0
+        self.size = size
+        self.reset_states()
+    def sample(self):
+        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
+        self.x_prev = x
+        self.n_steps += 1
+        return x
+    def reset_states(self):
+        self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
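The new module can be exercised on its own to see the exploration noise DDPG adds at training time. A short usage sketch, assuming the module path introduced by this commit (tianshou/core/random.py) is importable:

import numpy as np
from tianshou.core.random import OrnsteinUhlenbeckProcess

noise = OrnsteinUhlenbeckProcess(theta=0.15, sigma=0.2, dt=1e-2, size=1)
samples = np.array([noise.sample() for _ in range(1000)])
print(samples.mean(), samples.std())  # temporally correlated, mean-reverting around mu=0
noise.reset_states()  # call at every episode boundary, as Deterministic.reset() does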

View File

@@ -127,19 +127,50 @@ class ddpg_return:
     """
     compute the return as in DDPG. this seems to have to be special
     """
-    def __init__(self, actor, critic, use_target_network=True):
+    def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
         self.actor = actor
         self.critic = critic
         self.use_target_network = use_target_network
+        self.discount_factor = discount_factor
-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+                for i in index_this:
+                    return_ = episode[i][REWARD]
+                    if not episode[i][DONE]:
+                        if self.use_target_network:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action_old(state)
+                            q_value = self.critic.eval_value_old(state, action)
+                            return_ += self.discount_factor * q_value
+                        else:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action(state)
+                            q_value = self.critic.eval_value(state, action)
+                            return_ += self.discount_factor * q_value
+                    returns_this.append(return_)
+                returns.append(returns_this)
+            else:
+                returns.append([])
+        return {'return': returns}
 class nstep_q_return:
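The loop above computes, per sampled transition, the standard one-step DDPG target using the target networks: the raw reward when the transition is terminal, and reward plus the discounted target-critic value at the target-actor's action otherwise (eval_action_old / eval_value_old are the target-network evaluations). A tiny self-contained sketch of that arithmetic, where q_next stands in for the target-critic value at the target-actor's action:

def ddpg_target(reward, done, q_next, gamma=0.99):
    # y = r                              if done
    # y = r + gamma * Q'(s', mu'(s'))    otherwise
    return reward if done else reward + gamma * q_next

print(ddpg_target(1.0, False, 5.0))  # 5.95
print(ddpg_target(1.0, True, 5.0))   # 1.0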

View File

@@ -48,6 +48,7 @@ class DataCollector(object):
             if done:
                 self.current_observation = self.env.reset()
+                self.policy.reset()
             else:
                 self.current_observation = next_observation
@@ -61,6 +62,7 @@ class DataCollector(object):
                 next_observation, reward, done, _ = self.env.step(action)
                 self.data_buffer.add((observation, action, reward, done))
                 observation = next_observation
+            self.current_observation = self.env.reset()
        if self.process_mode == 'full':
            for processor in self.process_functions:
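Condensed, the two additions in this file fix collection state: exploration noise is reset whenever an episode terminates, and the cached observation is re-initialized after episode-based collection so the next step-based collect() starts from a fresh environment state. An illustrative, stripped-down version of the step-collection logic (not the full DataCollector, and policy.act here omits my_feed_dict):

def collect_one_step(env, policy, data_buffer, current_observation):
    action = policy.act(current_observation)
    next_observation, reward, done, _ = env.step(action)
    data_buffer.add((current_observation, action, reward, done))
    if done:
        current_observation = env.reset()
        policy.reset()  # clear temporally correlated exploration noise
    else:
        current_observation = next_observation
    return current_observation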