actor critic also works. fix some bugs in nstep_q_return. dqn still trains slow.

This commit is contained in:
haoshengzou 2018-03-11 15:07:41 +08:00
parent 498b55c051
commit a86354834c
4 changed files with 79 additions and 32 deletions

View File

@@ -5,18 +5,20 @@ import tensorflow as tf
 import time
 import numpy as np
 import gym
+import logging
+logging.basicConfig(level=logging.INFO)
 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.stochastic as policy
 import tianshou.core.value_function.state_value as value_function
+from tianshou.data.data_buffer.batch_set import BatchSet
+from tianshou.data.data_collector import DataCollector
+# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix

 if __name__ == '__main__':
     env = gym.make('CartPole-v0')

@@ -25,9 +27,9 @@ if __name__ == '__main__':
     clip_param = 0.2
     num_batches = 10
-    batch_size = 128
+    batch_size = 512
-    seed = 10
+    seed = 0
     np.random.seed(seed)
     tf.set_random_seed(seed)

@@ -36,13 +38,13 @@ if __name__ == '__main__':
     def my_network():
         # placeholders defined in this function would be very difficult to manage
-        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
-        action_logtis = tf.layers.dense(net, action_dim, activation=None)
+        action_logits = tf.layers.dense(net, action_dim, activation=None)
         value = tf.layers.dense(net, 1, activation=None)
-        return action_logtis, value
+        return action_logits, value
     # TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue`
     # maybe the most desired thing is to freely build policy and value function from any tensor?
     # but for now, only the outputs of the network matters

@@ -53,7 +55,7 @@ if __name__ == '__main__':
     actor_loss = losses.REINFORCE(actor)
     critic_loss = losses.value_mse(critic)
-    total_loss = actor_loss + critic_loss
+    total_loss = actor_loss + 1e-2 * critic_loss
     optimizer = tf.train.AdamOptimizer(1e-4)

@@ -63,10 +65,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=var_list)

     ### 3. define data collection
-    data_collector = Batch(env, actor,
-                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
-                           [actor, critic])
-    # TODO: refactor this, data_collector should be just the top-level abstraction
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.nstep_return(n=3, value_function=critic, return_advantage=True)],
+        managed_networks=[actor, critic],
+    )

     ### 4. start training
     config = tf.ConfigProto()

@@ -75,13 +82,13 @@ if __name__ == '__main__':
         sess.run(tf.global_variables_initializer())
         start_time = time.time()
-        for i in range(100):
+        for i in range(int(1e6)):
             # collect data
-            data_collector.collect(num_episodes=20)
+            data_collector.collect(num_episodes=50)

             # print current return
             print('Epoch {}:'.format(i))
-            data_collector.statistics()
+            data_buffer.statistics()

             # update network
             for _ in range(num_batches):
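Note for readers of this example: the combined objective above is a standard advantage actor-critic loss. Below is a minimal plain-TensorFlow-1.x sketch of what it amounts to; it is not tianshou's implementation of `losses.REINFORCE` / `losses.value_mse`, and the placeholder names (`action_ph`, `advantage_ph`, `return_ph`) and CartPole shapes are assumptions for illustration only.

import tensorflow as tf

# hypothetical placeholders, fed from the collected batch (names are assumptions)
observation_ph = tf.placeholder(tf.float32, [None, 4])   # CartPole observations
action_ph = tf.placeholder(tf.int32, [None])              # actions actually taken
advantage_ph = tf.placeholder(tf.float32, [None])         # e.g. n-step return minus V(s)
return_ph = tf.placeholder(tf.float32, [None])            # n-step return target for the critic

net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
action_logits = tf.layers.dense(net, 2, activation=None)  # 2 discrete actions in CartPole
value = tf.layers.dense(net, 1, activation=None)[:, 0]

# REINFORCE term: -log pi(a|s) weighted by the advantage
neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=action_ph, logits=action_logits)
actor_loss = tf.reduce_mean(neg_log_prob * advantage_ph)

# critic term: mean squared error between V(s) and the bootstrapped return
critic_loss = tf.reduce_mean(tf.square(value - return_ph))

# same weighting as in the diff: the critic loss is scaled down by 1e-2
total_loss = actor_loss + 1e-2 * critic_loss
train_op = tf.train.AdamOptimizer(1e-4).minimize(total_loss)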

View File

@@ -6,7 +6,8 @@ ACTION = 1
 REWARD = 2
 DONE = 3

+# modified for new interfaces
+# TODO: add discount_factor... maybe make it to be a global config?
 def full_return(buffer, indexes=None):
     """
     naively compute full return
@@ -67,18 +68,59 @@ class nstep_return:
     """
     compute the n-step return from n-step rewards and bootstrapped value function
     """
-    def __init__(self, n, value_function):
+    def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
         self.n = n
         self.value_function = value_function
+        self.return_advantage = return_advantage
+        self.discount_factor = discount_factor

-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+        advantages = []
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+                advantages_this = []
+                for i in index_this:
+                    current_discount_factor = 1.
+                    last_frame_index = i
+                    return_ = 0.
+                    for last_frame_index in range(i, min(len(episode), i + self.n)):
+                        return_ += current_discount_factor * episode[last_frame_index][REWARD]
+                        current_discount_factor *= self.discount_factor
+                        if episode[last_frame_index][DONE]:
+                            break
+                    if not episode[last_frame_index][DONE]:
+                        state = episode[last_frame_index + 1][STATE]
+                        v_sT = self.value_function.eval_value(state[None])
+                        return_ += current_discount_factor * v_sT
+                    returns_this.append(return_)
+                    if self.return_advantage:
+                        v_s0 = self.value_function.eval_value(episode[i][STATE][None])
+                        advantages_this.append(return_ - v_s0)
+                returns.append(returns_this)
+                advantages.append(advantages_this)
+            else:
+                returns.append([])
+                advantages.append([])
+
+        if self.return_advantage:
+            return {'return': returns, 'advantage':advantages}
+        else:
+            return {'return': returns}

 class ddpg_return:
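The new `nstep_return` computes, for each sampled step t, the truncated discounted return r_t + gamma*r_{t+1} + ... within the n-step window, adds a bootstrap term gamma^K * V(s_{t+K}) when the window does not hit a terminal frame, and optionally reports the advantage (return minus V(s_t)). Below is a buffer-free sketch of the same logic for sanity checking; the (state, action, reward, done) frame layout follows the STATE/ACTION/REWARD/DONE indices in this file, while `value_fn` is an assumed stand-in for `critic.eval_value`.

def nstep_return_toy(episode, i, n, value_fn, gamma=0.99):
    # episode is a list of (state, action, reward, done) frames, matching STATE/ACTION/REWARD/DONE
    discount = 1.
    ret = 0.
    last = i
    for last in range(i, min(len(episode), i + n)):
        ret += discount * episode[last][2]        # REWARD slot
        discount *= gamma
        if episode[last][3]:                      # DONE slot: stop at episode end
            break
    if not episode[last][3]:
        # bootstrap from the next state; like the code above, this assumes a
        # non-terminal frame always has a successor frame stored in the buffer
        ret += discount * value_fn(episode[last + 1][0])
    return ret

if __name__ == '__main__':
    # three steps of reward 1, terminal on the last frame; pretend V(s) = 10 everywhere
    episode = [((0.,), 0, 1., False), ((1.,), 0, 1., False), ((2.,), 0, 1., True)]
    v = lambda s: 10.
    print(nstep_return_toy(episode, 0, n=2, value_fn=v))  # 1 + 0.99*1 + 0.99**2 * 10, bootstrapped
    print(nstep_return_toy(episode, 1, n=5, value_fn=v))  # 1 + 0.99*1, terminal inside the window
    # with return_advantage=True the processor would additionally report return - V(s_t)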
@@ -128,18 +170,16 @@ class nstep_q_return:
             episode_q = []
             for i in index:
-                current_discount_factor = 1
+                current_discount_factor = 1.
                 last_frame_index = i
-                target_q = episode[i][REWARD]
-                for lfi in range(i, min(len(episode), i + self.n + 1)):
-                    if episode[lfi][DONE]:
-                        break
-                    target_q += current_discount_factor * episode[lfi][REWARD]
-                    current_discount_factor *= self.discount_factor
-                    last_frame_index = lfi
-                if last_frame_index > i:
-                    state = episode[last_frame_index][STATE]
+                target_q = 0.
+                for last_frame_index in range(i, min(len(episode), i + self.n)):
+                    target_q += current_discount_factor * episode[last_frame_index][REWARD]
+                    current_discount_factor *= self.discount_factor
+                    if episode[last_frame_index][DONE]:
+                        break
+                if not episode[last_frame_index][DONE]:  # not done will definitely have one frame later
+                    state = episode[last_frame_index + 1][STATE]
                     if self.use_target_network:
                         # [None] adds one dimension to the beginning
                         qpredict = self.action_value.eval_value_all_actions_old(state[None])
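For `nstep_q_return`, the fix makes the target a sum of up to n discounted rewards plus a bootstrap from the state after the last used frame; the old version appears to have double-counted the reward at step i and bootstrapped from the last used frame itself. Below is a buffer-free sketch under the same frame-layout assumption as the sketch above; `q_fn` is a stand-in for `action_value.eval_value_all_actions_old`, and the DQN-style max over actions is an assumption, since that part lies outside the shown hunk.

import numpy as np

def nstep_q_target_toy(episode, i, n, q_fn, gamma=0.99):
    # episode is a list of (state, action, reward, done) frames, as in the sketch above
    discount = 1.
    target_q = 0.
    last = i
    for last in range(i, min(len(episode), i + n)):
        target_q += discount * episode[last][2]            # REWARD slot
        discount *= gamma
        if episode[last][3]:                                # DONE slot
            break
    if not episode[last][3]:                                # bootstrap only for non-terminal windows
        next_state = episode[last + 1][0]                   # STATE slot of the successor frame
        target_q += discount * np.max(q_fn(next_state))     # assumed DQN-style max over actions
    return target_q

if __name__ == '__main__':
    episode = [((0.,), 0, 1., False), ((1.,), 1, 0., False), ((2.,), 0, 5., True)]
    q_fn = lambda s: np.array([2., 3.])                     # pretend Q(s, .) = [2, 3] everywhere
    print(nstep_q_target_toy(episode, 0, n=1, q_fn=q_fn))   # 1 + 0.99 * 3 = 3.97
    print(nstep_q_target_toy(episode, 1, n=3, q_fn=q_fn))   # 0 + 0.99 * 5 = 4.95, terminal: no bootstrap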