finish ddpg. now ppo, actor-critic, dqn work. ddpg is not working, check!

haoshengzou 2018-03-11 17:47:42 +08:00
parent a86354834c
commit 52e6b09768
6 changed files with 158 additions and 19 deletions


@@ -6,29 +6,33 @@ import gym
import numpy as np
import time
import argparse
import logging
logging.basicConfig(level=logging.INFO)
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy
import tianshou.core.value_function.action_value as value_function
import tianshou.core.opt as opt
from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
from tianshou.data.data_collector import DataCollector
from tianshou.data.tester import test_policy_in_env
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--render", action="store_true", default=False)
args = parser.parse_args()
env = gym.make('Pendulum-v0')
env = gym.make('MountainCarContinuous-v0')
observation_dim = env.observation_space.shape
action_dim = env.action_space.shape
clip_param = 0.2
num_batches = 10
batch_size = 512
batch_size = 32
seed = 0
np.random.seed(seed)
@@ -59,13 +63,22 @@ if __name__ == '__main__':
critic_optimizer = tf.train.AdamOptimizer(1e-3)
critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
dpg_grads = opt.DPG(actor, critic) # not sure if it's correct
dpg_grads = opt.DPG(actor, critic) # check which action to use in dpg
actor_optimizer = tf.train.AdamOptimizer(1e-4)
actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
### 3. define data collection
data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic],
render = args.render)
data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
process_functions = [advantage_estimation.ddpg_return(actor, critic)]
data_collector = DataCollector(
env=env,
policy=actor,
data_buffer=data_buffer,
process_functions=process_functions,
managed_networks=[actor, critic]
)
### 4. start training
config = tf.ConfigProto()
@@ -74,22 +87,25 @@ if __name__ == '__main__':
sess.run(tf.global_variables_initializer())
# assign actor to pi_old
actor.sync_weights() # TODO: automate this for policies with target network
actor.sync_weights()
critic.sync_weights()
start_time = time.time()
data_collector.collect(num_timesteps=1e3) # warm-up
for i in range(int(1e8)):
# collect data
data_collector.collect()
data_collector.collect(num_timesteps=1)
# update network
for _ in range(num_batches):
feed_dict = data_collector.next_batch(batch_size)
sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
feed_dict = data_collector.next_batch(batch_size)
sess.run(critic_train_op, feed_dict=feed_dict)
sess.run(actor_train_op, feed_dict=feed_dict)
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
# update target networks
actor.sync_weights()
critic.sync_weights()
# test every 1000 training steps
if i % 1000 == 0:
test(env, actor)
print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
test_policy_in_env(actor, env, num_timesteps=100)
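
The `# check which action to use in dpg` note above points at the key detail of the deterministic policy gradient: the actor update needs dQ/da evaluated at the actor's own output a = mu(s), not at the replayed action from the buffer (the replayed action is only needed when training the critic). A minimal TensorFlow 1.x sketch of that chain rule, with generic tensor names that are assumptions here rather than tianshou's actual internals:

import tensorflow as tf

def dpg_actor_gradients(q_at_mu, mu_action, actor_vars):
    # q_at_mu:   Q(s, mu(s)), built by feeding the actor's action tensor
    #            (not the replayed-action placeholder) into the critic
    # mu_action: the actor's deterministic output mu(s)
    dq_da = tf.gradients(q_at_mu, mu_action)[0]   # dQ/da at a = mu(s), shape [batch, action_dim]
    # chain rule: dJ/dtheta = dmu/dtheta * dQ/da; the minus sign makes
    # apply_gradients (a descent step) ascend on Q
    grads = tf.gradients(mu_action, actor_vars, grad_ys=-dq_da)
    return list(zip(grads, actor_vars))

The returned (gradient, variable) pairs plug straight into actor_optimizer.apply_gradients, matching how actor_train_op is constructed above.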


@@ -19,6 +19,13 @@ class PolicyBase(object):
def act(self, observation, my_feed_dict):
raise NotImplementedError()
def reset(self):
"""
for temporal correlated random process exploration, as in DDPG
:return:
"""
pass
class StochasticPolicy(PolicyBase):
"""


@@ -1,11 +1,12 @@
import tensorflow as tf
from .base import PolicyBase
from ..random import OrnsteinUhlenbeckProcess
class Deterministic(PolicyBase):
"""
deterministic policy as used in deterministic policy gradient methods
"""
def __init__(self, policy_callable, observation_placeholder, weight_update=1):
def __init__(self, policy_callable, observation_placeholder, weight_update=1, random_process=None):
self._observation_placeholder = observation_placeholder
self.managed_placeholders = {'observation': observation_placeholder}
self.weight_update = weight_update
@@ -49,6 +50,9 @@ class Deterministic(PolicyBase):
import math
self.weight_update = math.ceil(weight_update)
self.random_process = random_process or OrnsteinUhlenbeckProcess(
theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
@property
def action_shape(self):
return self.action.shape.as_list()[1:]
@@ -62,6 +66,21 @@ class Deterministic(PolicyBase):
feed_dict.update(my_feed_dict)
sampled_action = sess.run(self.action, feed_dict=feed_dict)
sampled_action = sampled_action[0] + self.random_process.sample()
return sampled_action
def reset(self):
self.random_process.reset_states()
def act_test(self, observation, my_feed_dict={}):
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
feed_dict = {self._observation_placeholder: observation[None]}
feed_dict.update(my_feed_dict)
sampled_action = sess.run(self.action, feed_dict=feed_dict)
sampled_action = sampled_action[0]
return sampled_action
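
A brief usage note on the two action paths above, assuming a default TF session is open and `actor` is the Deterministic policy built in the ddpg example (illustrative, not part of the commit): act() perturbs the deterministic output with an Ornstein-Uhlenbeck sample for exploration, while act_test() returns the unperturbed action for evaluation.

exploring_action = actor.act(observation, {})    # mu(s) + OU noise, used during data collection
greedy_action = actor.act_test(observation, {})  # mu(s) only, for evaluation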

tianshou/core/random.py (new file, 64 lines added)

@@ -0,0 +1,64 @@
"""
adapted from keras-rl
"""
from __future__ import division
import numpy as np
class RandomProcess(object):
def reset_states(self):
pass
class AnnealedGaussianProcess(RandomProcess):
def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
self.mu = mu
self.sigma = sigma
self.n_steps = 0
if sigma_min is not None:
self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
self.c = sigma
self.sigma_min = sigma_min
else:
self.m = 0.
self.c = sigma
self.sigma_min = sigma
@property
def current_sigma(self):
sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
return sigma
class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
self.size = size
def sample(self):
sample = np.random.normal(self.mu, self.current_sigma, self.size)
self.n_steps += 1
return sample
# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
super(OrnsteinUhlenbeckProcess, self).__init__(
mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
self.theta = theta
self.mu = mu
self.dt = dt
self.x0 = x0
self.size = size
self.reset_states()
def sample(self):
x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
self.x_prev = x
self.n_steps += 1
return x
def reset_states(self):
self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
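
Standalone illustration of the process above (mirroring the defaults the Deterministic policy picks, theta=0.15 and sigma=0.2): consecutive samples are correlated through x_prev, and reset_states() restarts the noise at x0 between episodes.

from tianshou.core.random import OrnsteinUhlenbeckProcess

ou = OrnsteinUhlenbeckProcess(theta=0.15, sigma=0.2, size=1)
for _ in range(3):                                    # e.g. three episodes
    ou.reset_states()                                 # back to x0 (zeros by default)
    noise_samples = [ou.sample() for _ in range(5)]   # temporally correlated exploration noise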


@@ -127,19 +127,50 @@ class ddpg_return:
"""
compute the one-step bootstrapped return as in DDPG, i.e. r + gamma * Q'(s', mu'(s')) from the target networks; this estimator is specific to DDPG, hence its own class
"""
def __init__(self, actor, critic, use_target_network=True):
def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
self.actor = actor
self.critic = critic
self.use_target_network = use_target_network
self.discount_factor = discount_factor
def __call__(self, buffer, index=None):
def __call__(self, buffer, indexes=None):
"""
:param buffer: buffer with properties `index` and `data`; `index` determines the current content of `buffer`.
:param indexes: (sampled) indexes for which to compute returns. Defaults to all the data in `buffer`. Not necessarily
ordered within each episode.
:return: dict with key 'return' and value the computed returns corresponding to `indexes`.
"""
pass
indexes = indexes or buffer.index
episodes = buffer.data
returns = []
for i_episode in range(len(indexes)):
index_this = indexes[i_episode]
if index_this:
episode = episodes[i_episode]
returns_this = []
for i in index_this:
return_ = episode[i][REWARD]
if not episode[i][DONE]:
if self.use_target_network:
state = episode[i + 1][STATE][None]
action = self.actor.eval_action_old(state)
q_value = self.critic.eval_value_old(state, action)
return_ += self.discount_factor * q_value
else:
state = episode[i + 1][STATE][None]
action = self.actor.eval_action(state)
q_value = self.critic.eval_value(state, action)
return_ += self.discount_factor * q_value
returns_this.append(return_)
returns.append(returns_this)
else:
returns.append([])
return {'return': returns}
class nstep_q_return:
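
For reference, the value that ddpg_return computes above is the standard one-step bootstrapped DDPG target y = r + gamma * Q'(s', mu'(s')), with no bootstrap on terminal transitions. A tiny numeric check with made-up numbers:

gamma = 0.99
reward, done = 1.0, False
q_next = 2.5                                    # stand-in for the target critic's Q'(s', mu'(s'))
y = reward + (0.0 if done else gamma * q_next)  # 1.0 + 0.99 * 2.5 = 3.475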


@@ -48,6 +48,7 @@ class DataCollector(object):
if done:
self.current_observation = self.env.reset()
self.policy.reset()
else:
self.current_observation = next_observation
@@ -61,6 +62,7 @@ class DataCollector(object):
next_observation, reward, done, _ = self.env.step(action)
self.data_buffer.add((observation, action, reward, done))
observation = next_observation
self.current_observation = self.env.reset()
if self.process_mode == 'full':
for processor in self.process_functions:
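
The two hunks above wire the new policy.reset() hook into data collection so that the exploration noise restarts whenever an episode ends. A simplified sketch of that pattern outside the library (not DataCollector's actual code; `env` and `policy` as in the ddpg example):

observation = env.reset()
policy.reset()                                  # fresh OU noise for the first episode
for _ in range(1000):                           # collect a fixed number of timesteps
    action = policy.act(observation, {})
    next_observation, reward, done, _ = env.step(action)
    # a real collector would also store (observation, action, reward, done) in the replay buffer
    if done:
        observation = env.reset()
        policy.reset()                          # restart the temporally correlated noise
    else:
        observation = next_observation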