#!/usr/bin/env python
import tensorflow as tf
import gym
# our lib imports here!
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy.dqn as policy
def policy_net(observation, action_dim):
    """
    Constructs the policy network. NOT NEEDED IN THE LIBRARY! This is pure tf.

    :param observation: placeholder for the observation, a tensor of shape (batch_size, x, y, channels)
    :param action_dim: int, the number of actions
    """
    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, 256, activation=tf.nn.relu)
    q_values = tf.layers.dense(net, action_dim)
    return q_values


if __name__ == '__main__':
    env = gym.make('PongNoFrameskip-v4')
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    # 1. build network with pure tf
    # TODO:
    # pass the observation variable to the replay buffer, or find a more reasonable way to let
    # the replay buffer access this observation variable.
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")  # network input
    action = tf.placeholder(dtype=tf.int32, shape=(None,))  # batch of integer actions
    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
    with tf.variable_scope('target_net'):
        q_values_target = policy_net(observation, action_dim)
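    # 'target_net' is a second copy of the same architecture; in DQN its parameters are held
    # fixed between periodic syncs and are used to compute the bootstrap targets, which
    # stabilizes training.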

    # 2. build losses, optimizers
    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)  # YongRen: policy.DQN
    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN
    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
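    # One possible refinement for the TODO above (a sketch, not the original code): only the
    # online network should be updated by the optimizer, so the collection could be scoped, e.g.
    #     train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='q_net')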
    total_loss = dqn_loss
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=global_step)

    # 3. define data collection
    # configuration should be given as parameters; different replay buffers take different parameters.
    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})
    # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory,
    # interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN.
    # maybe a dict to manage the elements to be collected

    # 4. start training
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        minibatch_count = 0
        collection_count = 0
        # we need to collect some transitions before sampling; collect_freq must be larger than batch_size
        collect_freq = 100
        while True:  # until some stopping criterion is met...
            # collect data
            for i in range(0, collect_freq):
                replay_memory.collect()  # ShihongSong
                collection_count += 1
                print('Collected {} times.'.format(collection_count))
            # update network
            data = replay_memory.next_batch(10)  # YouQiaoben, ShihongSong
            # TODO: auto managing of the placeholders? or add this to params of data.Batch
            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'],
                                          target: data['target']})
            minibatch_count += 1
            print('Trained {} minibatches.'.format(minibatch_count))

            # TODO: assigning pi to pi_old is not implemented yet
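            # A minimal sketch of such a sync (an assumption, not the library's implementation):
            # copy the online variables into the target scope with tf.assign, building the op
            # once at graph-construction time, e.g.
            #     q_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_net')
            #     target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
            #     sync_op = tf.group(*[tf.assign(t, q) for t, q in zip(target_vars, q_vars)])
            # and then running it every fixed number of minibatches:
            #     if minibatch_count % 100 == 0:
            #         sess.run(sync_op)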