finish ddpg. ppo, actor-critic, and dqn now work; ddpg is not working yet, check!
commit 52e6b09768 (parent a86354834c)
@@ -6,29 +6,33 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy as policy
 import tianshou.core.value_function.action_value as value_function
 import tianshou.core.opt as opt
 
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_collector import DataCollector
+from tianshou.data.tester import test_policy_in_env
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()
-    env = gym.make('Pendulum-v0')
+    env = gym.make('MountainCarContinuous-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.shape
 
-    clip_param = 0.2
-    num_batches = 10
-    batch_size = 512
+    batch_size = 32
 
     seed = 0
     np.random.seed(seed)
@@ -59,13 +63,22 @@ if __name__ == '__main__':
     critic_optimizer = tf.train.AdamOptimizer(1e-3)
     critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
 
-    dpg_grads = opt.DPG(actor, critic)  # not sure if it's correct
+    dpg_grads = opt.DPG(actor, critic)  # check which action to use in dpg
     actor_optimizer = tf.train.AdamOptimizer(1e-4)
     actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
 
     ### 3. define data collection
-    data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic],
-                           render = args.render)
+    data_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
+
+    process_functions = [advantage_estimation.ddpg_return(actor, critic)]
+
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=process_functions,
+        managed_networks=[actor, critic]
+    )
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -74,22 +87,25 @@ if __name__ == '__main__':
     sess.run(tf.global_variables_initializer())
 
     # assign actor to pi_old
-    actor.sync_weights()  # TODO: automate this for policies with target network
+    actor.sync_weights()
     critic.sync_weights()
 
     start_time = time.time()
     data_collector.collect(num_timesteps=1e3)  # warm-up
     for i in range(int(1e8)):
         # collect data
-        data_collector.collect()
+        data_collector.collect(num_timesteps=1)
 
         # update network
-        for _ in range(num_batches):
-            feed_dict = data_collector.next_batch(batch_size)
-            sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
+        feed_dict = data_collector.next_batch(batch_size)
+        sess.run(critic_train_op, feed_dict=feed_dict)
+        sess.run(actor_train_op, feed_dict=feed_dict)
 
-        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+        # update target networks
+        actor.sync_weights()
+        critic.sync_weights()
 
         # test every 1000 training steps
         if i % 1000 == 0:
-            test(env, actor)
+            print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
+            test_policy_in_env(actor, env, num_timesteps=100)
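Note: the critic loss and the deterministic policy gradient themselves come from tianshou's `losses` and `opt.DPG` helpers, which this diff does not show. As rough orientation only, here is a minimal, self-contained TF1 sketch of the same two updates; every name in it (actor_net, critic_net, the placeholders) is illustrative and not tianshou's API. The critic regresses Q(s, a) toward the bootstrapped return supplied by ddpg_return, while the actor only receives gradient through Q(s, pi(s)).

    import tensorflow as tf

    # Illustrative placeholders; shapes assume a 2-D observation and a 1-D continuous action.
    observation_ph = tf.placeholder(tf.float32, [None, 2])
    action_ph = tf.placeholder(tf.float32, [None, 1])
    return_ph = tf.placeholder(tf.float32, [None])   # targets produced by ddpg_return

    def actor_net(obs):
        h = tf.layers.dense(obs, 32, tf.nn.relu, name='hidden')
        return tf.layers.dense(h, 1, tf.nn.tanh, name='action')

    def critic_net(obs, act):
        h = tf.layers.dense(tf.concat([obs, act], axis=1), 32, tf.nn.relu, name='hidden')
        return tf.squeeze(tf.layers.dense(h, 1, name='q'), axis=1)

    with tf.variable_scope('actor'):
        pi = actor_net(observation_ph)
    with tf.variable_scope('critic'):
        q_taken = critic_net(observation_ph, action_ph)   # Q(s, a) for stored actions
    with tf.variable_scope('critic', reuse=True):
        q_pi = critic_net(observation_ph, pi)             # Q(s, pi(s)), same weights

    actor_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'actor')
    critic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'critic')

    # critic: move Q(s, a) toward r + gamma * Q'(s', pi'(s'))
    critic_loss = tf.reduce_mean(tf.square(return_ph - q_taken))
    critic_train_op = tf.train.AdamOptimizer(1e-3).minimize(critic_loss, var_list=critic_vars)

    # actor: ascend Q(s, pi(s)) with respect to the actor parameters only (the DPG update)
    actor_loss = -tf.reduce_mean(q_pi)
    actor_train_op = tf.train.AdamOptimizer(1e-4).minimize(actor_loss, var_list=actor_vars)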
@@ -19,6 +19,13 @@ class PolicyBase(object):
     def act(self, observation, my_feed_dict):
         raise NotImplementedError()
 
+    def reset(self):
+        """
+        for temporal correlated random process exploration, as in DDPG
+        :return:
+        """
+        pass
+
 
 class StochasticPolicy(PolicyBase):
     """
@@ -1,11 +1,12 @@
 import tensorflow as tf
 from .base import PolicyBase
+from ..random import OrnsteinUhlenbeckProcess
 
 class Deterministic(PolicyBase):
     """
     deterministic policy as used in deterministic policy gradient methods
     """
-    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
+    def __init__(self, policy_callable, observation_placeholder, weight_update=1, random_process=None):
         self._observation_placeholder = observation_placeholder
         self.managed_placeholders = {'observation': observation_placeholder}
         self.weight_update = weight_update
@@ -49,6 +50,9 @@ class Deterministic(PolicyBase):
             import math
            self.weight_update = math.ceil(weight_update)
 
+        self.random_process = random_process or OrnsteinUhlenbeckProcess(
+            theta=0.15, sigma=0.2, size=self.action.shape.as_list()[-1])
+
     @property
     def action_shape(self):
         return self.action.shape.as_list()[1:]
@@ -62,6 +66,21 @@ class Deterministic(PolicyBase):
         feed_dict.update(my_feed_dict)
         sampled_action = sess.run(self.action, feed_dict=feed_dict)
 
+        sampled_action = sampled_action[0] + self.random_process.sample()
+
+        return sampled_action
+
+    def reset(self):
+        self.random_process.reset_states()
+
+    def act_test(self, observation, my_feed_dict={}):
+        sess = tf.get_default_session()
+        # observation[None] adds one dimension at the beginning
+
+        feed_dict = {self._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        sampled_action = sess.run(self.action, feed_dict=feed_dict)
+
         sampled_action = sampled_action[0]
 
         return sampled_action
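In this interface, act() now adds Ornstein-Uhlenbeck exploration noise to the network output, act_test() returns the raw deterministic action, and reset() clears the noise state. A minimal sketch of how a caller is expected to drive it; the real loops live in DataCollector and test_policy_in_env, and run_episode below is a hypothetical helper, not part of tianshou:

    def run_episode(env, policy, explore=True):
        """Roll out one episode with a Deterministic policy as in the diff above."""
        observation = env.reset()
        policy.reset()            # restart the temporally correlated OU noise each episode
        done, episode_return = False, 0.0
        while not done:
            if explore:
                action = policy.act(observation, {})        # noisy action for training data
            else:
                action = policy.act_test(observation, {})   # greedy action for evaluation
            observation, reward, done, _ = env.step(action)
            episode_return += reward
        return episode_return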
tianshou/core/random.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+"""
+adapted from keras-rl
+"""
+
+from __future__ import division
+import numpy as np
+
+
+class RandomProcess(object):
+    def reset_states(self):
+        pass
+
+
+class AnnealedGaussianProcess(RandomProcess):
+    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
+        self.mu = mu
+        self.sigma = sigma
+        self.n_steps = 0
+
+        if sigma_min is not None:
+            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
+            self.c = sigma
+            self.sigma_min = sigma_min
+        else:
+            self.m = 0.
+            self.c = sigma
+            self.sigma_min = sigma
+
+    @property
+    def current_sigma(self):
+        sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
+        return sigma
+
+
+class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
+    def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
+        super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.size = size
+
+    def sample(self):
+        sample = np.random.normal(self.mu, self.current_sigma, self.size)
+        self.n_steps += 1
+        return sample
+
+
+# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
+    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
+        super(OrnsteinUhlenbeckProcess, self).__init__(
+            mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
+        self.theta = theta
+        self.mu = mu
+        self.dt = dt
+        self.x0 = x0
+        self.size = size
+        self.reset_states()
+
+    def sample(self):
+        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
+        self.x_prev = x
+        self.n_steps += 1
+        return x
+
+    def reset_states(self):
+        self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size)
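The Ornstein-Uhlenbeck process mean-reverts toward mu with rate theta while being perturbed by Gaussian noise scaled by sigma * sqrt(dt), so successive samples are correlated in time, which is the standard DDPG exploration noise. A quick usage sketch, assuming the module as added above:

    import numpy as np
    from tianshou.core.random import OrnsteinUhlenbeckProcess

    process = OrnsteinUhlenbeckProcess(theta=0.15, sigma=0.2, size=1)
    noise = np.array([process.sample() for _ in range(5)]).ravel()
    print(noise)               # temporally correlated samples drifting back toward mu=0
    process.reset_states()     # call at an episode boundary to start a fresh trajectory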
@@ -127,19 +127,50 @@ class ddpg_return:
     """
     compute the return as in DDPG. this seems to have to be special
     """
-    def __init__(self, actor, critic, use_target_network=True):
+    def __init__(self, actor, critic, use_target_network=True, discount_factor=0.99):
         self.actor = actor
         self.critic = critic
         self.use_target_network = use_target_network
+        self.discount_factor = discount_factor
 
-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+
+                for i in index_this:
+                    return_ = episode[i][REWARD]
+                    if not episode[i][DONE]:
+                        if self.use_target_network:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action_old(state)
+                            q_value = self.critic.eval_value_old(state, action)
+                            return_ += self.discount_factor * q_value
+                        else:
+                            state = episode[i + 1][STATE][None]
+                            action = self.actor.eval_action(state)
+                            q_value = self.critic.eval_value(state, action)
+                            return_ += self.discount_factor * q_value
+
+                    returns_this.append(return_)
+
+                returns.append(returns_this)
+            else:
+                returns.append([])
+
+        return {'return': returns}
+
 
 class nstep_q_return:
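The new __call__ computes, per sampled transition, the one-step DDPG target: r when the transition terminates the episode, and r + gamma * Q'(s', pi'(s')) otherwise (using the online networks instead of the target networks when use_target_network is False). For reference, a vectorized equivalent over a batch of transitions; ddpg_targets is a hypothetical helper, not part of tianshou:

    import numpy as np

    def ddpg_targets(rewards, dones, next_q, discount_factor=0.99):
        """next_q holds Q'(s', pi'(s')) per transition; terminal transitions bootstrap nothing."""
        rewards = np.asarray(rewards, dtype=np.float64)
        dones = np.asarray(dones, dtype=bool)
        next_q = np.asarray(next_q, dtype=np.float64)
        return rewards + discount_factor * next_q * (~dones)

    print(ddpg_targets([1.0, 0.5], [False, True], [2.0, 3.0]))   # targets: 2.98 and 0.5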
@@ -48,6 +48,7 @@ class DataCollector(object):
 
                 if done:
                     self.current_observation = self.env.reset()
+                    self.policy.reset()
                 else:
                     self.current_observation = next_observation
 
@@ -61,6 +62,7 @@ class DataCollector(object):
                     next_observation, reward, done, _ = self.env.step(action)
                     self.data_buffer.add((observation, action, reward, done))
                     observation = next_observation
+                self.current_observation = self.env.reset()
 
             if self.process_mode == 'full':
                 for processor in self.process_functions: