actor critic also works. fix some bugs in nstep_q_return. dqn still trains slow.

This commit is contained in:
haoshengzou 2018-03-11 15:07:41 +08:00
parent 498b55c051
commit a86354834c
4 changed files with 79 additions and 32 deletions

View File

@ -5,18 +5,20 @@ import tensorflow as tf
import time
import numpy as np
import gym
import logging
logging.basicConfig(level=logging.INFO)
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.policy.stochastic as policy
import tianshou.core.value_function.state_value as value_function
from tianshou.data.data_buffer.batch_set import BatchSet
from tianshou.data.data_collector import DataCollector
# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
if __name__ == '__main__':
env = gym.make('CartPole-v0')
@ -25,9 +27,9 @@ if __name__ == '__main__':
clip_param = 0.2
num_batches = 10
batch_size = 128
batch_size = 512
seed = 10
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
@ -36,13 +38,13 @@ if __name__ == '__main__':
def my_network():
# placeholders defined in this function would be very difficult to manage
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
action_logtis = tf.layers.dense(net, action_dim, activation=None)
action_logits = tf.layers.dense(net, action_dim, activation=None)
value = tf.layers.dense(net, 1, activation=None)
return action_logtis, value
return action_logits, value
# TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue`
# maybe the most desired thing is to freely build policy and value function from any tensor?
# but for now, only the outputs of the network matters
@ -53,7 +55,7 @@ if __name__ == '__main__':
actor_loss = losses.REINFORCE(actor)
critic_loss = losses.value_mse(critic)
total_loss = actor_loss + critic_loss
total_loss = actor_loss + 1e-2 * critic_loss
optimizer = tf.train.AdamOptimizer(1e-4)
@ -63,10 +65,15 @@ if __name__ == '__main__':
train_op = optimizer.minimize(total_loss, var_list=var_list)
### 3. define data collection
data_collector = Batch(env, actor,
[advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
[actor, critic])
# TODO: refactor this, data_collector should be just the top-level abstraction
data_buffer = BatchSet()
data_collector = DataCollector(
env=env,
policy=actor,
data_buffer=data_buffer,
process_functions=[advantage_estimation.nstep_return(n=3, value_function=critic, return_advantage=True)],
managed_networks=[actor, critic],
)
### 4. start training
config = tf.ConfigProto()
@ -75,13 +82,13 @@ if __name__ == '__main__':
sess.run(tf.global_variables_initializer())
start_time = time.time()
for i in range(100):
for i in range(int(1e6)):
# collect data
data_collector.collect(num_episodes=20)
data_collector.collect(num_episodes=50)
# print current return
print('Epoch {}:'.format(i))
data_collector.statistics()
data_buffer.statistics()
# update network
for _ in range(num_batches):

View File

@ -6,7 +6,8 @@ ACTION = 1
REWARD = 2
DONE = 3
# modified for new interfaces
# TODO: add discount_factor... maybe make it to be a global config?
def full_return(buffer, indexes=None):
"""
naively compute full return
@ -67,18 +68,59 @@ class nstep_return:
"""
compute the n-step return from n-step rewards and bootstrapped value function
"""
def __init__(self, n, value_function):
def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
self.n = n
self.value_function = value_function
self.return_advantage = return_advantage
self.discount_factor = discount_factor
def __call__(self, buffer, index=None):
def __call__(self, buffer, indexes=None):
"""
:param buffer: buffer with property index and data. index determines the current content in `buffer`.
:param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
:param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
each episode.
:return: dict with key 'return' and value the computed returns corresponding to `index`.
"""
pass
indexes = indexes or buffer.index
episodes = buffer.data
returns = []
advantages = []
for i_episode in range(len(indexes)):
index_this = indexes[i_episode]
if index_this:
episode = episodes[i_episode]
returns_this = []
advantages_this = []
for i in index_this:
current_discount_factor = 1.
last_frame_index = i
return_ = 0.
for last_frame_index in range(i, min(len(episode), i + self.n)):
return_ += current_discount_factor * episode[last_frame_index][REWARD]
current_discount_factor *= self.discount_factor
if episode[last_frame_index][DONE]:
break
if not episode[last_frame_index][DONE]:
state = episode[last_frame_index + 1][STATE]
v_sT = self.value_function.eval_value(state[None])
return_ += current_discount_factor * v_sT
returns_this.append(return_)
if self.return_advantage:
v_s0 = self.value_function.eval_value(episode[i][STATE][None])
advantages_this.append(return_ - v_s0)
returns.append(returns_this)
advantages.append(advantages_this)
else:
returns.append([])
advantages.append([])
if self.return_advantage:
return {'return': returns, 'advantage':advantages}
else:
return {'return': returns}
class ddpg_return:
@ -128,18 +170,16 @@ class nstep_q_return:
episode_q = []
for i in index:
current_discount_factor = 1
current_discount_factor = 1.
last_frame_index = i
target_q = episode[i][REWARD]
for lfi in range(i, min(len(episode), i + self.n + 1)):
if episode[lfi][DONE]:
break
target_q += current_discount_factor * episode[lfi][REWARD]
target_q = 0.
for last_frame_index in range(i, min(len(episode), i + self.n)):
target_q += current_discount_factor * episode[last_frame_index][REWARD]
current_discount_factor *= self.discount_factor
last_frame_index = lfi
if last_frame_index > i:
state = episode[last_frame_index][STATE]
if episode[last_frame_index][DONE]:
break
if not episode[last_frame_index][DONE]: # not done will definitely have one frame later
state = episode[last_frame_index + 1][STATE]
if self.use_target_network:
# [None] adds one dimension to the beginning
qpredict = self.action_value.eval_value_all_actions_old(state[None])