actor critic also works. fix some bugs in nstep_q_return. dqn still trains slow.
commit a86354834c
parent 498b55c051
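Both files below revolve around the n-step return: the example script now asks advantage_estimation.nstep_return for 3-step returns with advantages, and the estimator itself gets a real implementation. As a minimal standalone sketch of the quantity being computed (illustration only, not tianshou code; the episode tuples, value_fn, and gamma below are invented): sum up to n discounted rewards starting at frame i, then bootstrap with the value of the first state after the window unless the episode terminated inside it.

# Toy sketch of an n-step return with value bootstrapping (illustration only, not library code).
def nstep_return_sketch(episode, i, n, value_fn, gamma=0.99):
    # episode is a list of (state, reward, done) tuples; value_fn maps a state to a scalar V(s).
    ret, discount, last = 0., 1., i
    for last in range(i, min(len(episode), i + n)):
        state, reward, done = episode[last]
        ret += discount * reward              # accumulate up to n discounted rewards
        discount *= gamma
        if done:                              # stop at episode termination
            break
    if not episode[last][2]:                  # not terminated: bootstrap from the next state
        ret += discount * value_fn(episode[last + 1][0])
    return ret

episode = [('s0', 1., False), ('s1', 1., False), ('s2', 1., False), ('s3', 0., True)]
print(nstep_return_sketch(episode, 0, 3, lambda s: 10.))  # 1 + 0.99 + 0.9801 + 0.970299 * 10 = 12.67309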
@@ -5,18 +5,20 @@ import tensorflow as tf
 import time
 import numpy as np
 import gym
+import logging
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.stochastic as policy
 import tianshou.core.value_function.state_value as value_function
 
+from tianshou.data.data_buffer.batch_set import BatchSet
+from tianshou.data.data_collector import DataCollector
 
-# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
 
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
@@ -25,9 +27,9 @@ if __name__ == '__main__':
 
     clip_param = 0.2
     num_batches = 10
-    batch_size = 128
+    batch_size = 512
 
-    seed = 10
+    seed = 0
     np.random.seed(seed)
     tf.set_random_seed(seed)
 
@@ -36,13 +38,13 @@ if __name__ == '__main__':
 
     def my_network():
         # placeholders defined in this function would be very difficult to manage
-        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
 
-        action_logtis = tf.layers.dense(net, action_dim, activation=None)
+        action_logits = tf.layers.dense(net, action_dim, activation=None)
         value = tf.layers.dense(net, 1, activation=None)
 
-        return action_logtis, value
+        return action_logits, value
     # TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue`
     # maybe the most desired thing is to freely build policy and value function from any tensor?
     # but for now, only the outputs of the network matters
@@ -53,7 +55,7 @@ if __name__ == '__main__':
 
     actor_loss = losses.REINFORCE(actor)
     critic_loss = losses.value_mse(critic)
-    total_loss = actor_loss + critic_loss
+    total_loss = actor_loss + 1e-2 * critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
 
@@ -63,10 +65,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=var_list)
 
     ### 3. define data collection
-    data_collector = Batch(env, actor,
-                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
-                           [actor, critic])
-    # TODO: refactor this, data_collector should be just the top-level abstraction
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.nstep_return(n=3, value_function=critic, return_advantage=True)],
+        managed_networks=[actor, critic],
+    )
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -75,13 +82,13 @@ if __name__ == '__main__':
         sess.run(tf.global_variables_initializer())
 
         start_time = time.time()
-        for i in range(100):
+        for i in range(int(1e6)):
             # collect data
-            data_collector.collect(num_episodes=20)
+            data_collector.collect(num_episodes=50)
 
             # print current return
             print('Epoch {}:'.format(i))
-            data_collector.statistics()
+            data_buffer.statistics()
 
             # update network
             for _ in range(num_batches):
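The remaining hunks are in the advantage-estimation module imported above (tianshou.data.advantage_estimation). The return_advantage=True flag requested by the script means each processed batch also carries return - V(s_0), which is the natural weight for the REINFORCE actor loss next to the critic's squared error. A small numpy sketch of how those pieces combine, assuming the usual REINFORCE-with-baseline surrogate (the arrays below are invented; this is not the losses.REINFORCE implementation):

import numpy as np

# Invented toy data: per-sample n-step returns, critic values V(s_0), and log pi(a|s).
returns = np.array([12.7, 8.3, 5.1])
values = np.array([10.0, 9.0, 6.0])
log_probs = np.array([-0.7, -1.2, -0.4])

advantages = returns - values                    # advantage = n-step return - baseline
actor_loss = -np.mean(log_probs * advantages)    # policy-gradient surrogate
critic_loss = np.mean((returns - values) ** 2)   # value regression (cf. losses.value_mse)
total_loss = actor_loss + 1e-2 * critic_loss     # same 1e-2 weighting as in the script above
print(total_loss)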
@@ -6,7 +6,8 @@ ACTION = 1
 REWARD = 2
 DONE = 3
 
-# modified for new interfaces
+
+# TODO: add discount_factor... maybe make it to be a global config?
 def full_return(buffer, indexes=None):
     """
     naively compute full return
@@ -67,18 +68,59 @@ class nstep_return:
     """
     compute the n-step return from n-step rewards and bootstrapped value function
     """
-    def __init__(self, n, value_function):
+    def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
         self.n = n
         self.value_function = value_function
+        self.return_advantage = return_advantage
+        self.discount_factor = discount_factor
 
-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
         each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+        advantages = []
+
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+                advantages_this = []
+
+                for i in index_this:
+                    current_discount_factor = 1.
+                    last_frame_index = i
+                    return_ = 0.
+                    for last_frame_index in range(i, min(len(episode), i + self.n)):
+                        return_ += current_discount_factor * episode[last_frame_index][REWARD]
+                        current_discount_factor *= self.discount_factor
+                        if episode[last_frame_index][DONE]:
+                            break
+                    if not episode[last_frame_index][DONE]:
+                        state = episode[last_frame_index + 1][STATE]
+                        v_sT = self.value_function.eval_value(state[None])
+                        return_ += current_discount_factor * v_sT
+                    returns_this.append(return_)
+                    if self.return_advantage:
+                        v_s0 = self.value_function.eval_value(episode[i][STATE][None])
+                        advantages_this.append(return_ - v_s0)
+
+                returns.append(returns_this)
+                advantages.append(advantages_this)
+            else:
+                returns.append([])
+                advantages.append([])
+
+        if self.return_advantage:
+            return {'return': returns, 'advantage': advantages}
+        else:
+            return {'return': returns}
 
 
 class ddpg_return:
@@ -128,18 +170,16 @@ class nstep_q_return:
                 episode_q = []
 
                 for i in index:
-                    current_discount_factor = 1
+                    current_discount_factor = 1.
                     last_frame_index = i
-                    target_q = episode[i][REWARD]
-                    for lfi in range(i, min(len(episode), i + self.n + 1)):
-                        if episode[lfi][DONE]:
-                            break
-                        target_q += current_discount_factor * episode[lfi][REWARD]
+                    target_q = 0.
+                    for last_frame_index in range(i, min(len(episode), i + self.n)):
+                        target_q += current_discount_factor * episode[last_frame_index][REWARD]
                         current_discount_factor *= self.discount_factor
-                        last_frame_index = lfi
-                    if last_frame_index > i:
-                        state = episode[last_frame_index][STATE]
+                        if episode[last_frame_index][DONE]:
+                            break
+                    if not episode[last_frame_index][DONE]:  # not done will definitely have one frame later
+                        state = episode[last_frame_index + 1][STATE]
                         if self.use_target_network:
                             # [None] adds one dimension to the beginning
                             qpredict = self.action_value.eval_value_all_actions_old(state[None])
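For reference, the nstep_q_return bugs mentioned in the commit message are all visible in the last hunk: the old loop seeded target_q with episode[i][REWARD] and then added that same reward again on its first iteration, checked DONE before adding the reward (so a terminal reward was dropped), and bootstrapped from the last state inside the window rather than the one after it. A self-contained toy comparison of the old and new accumulation (the episode, gamma, and window size are invented; bootstrapping is left out to isolate the double count):

# Toy before/after check of the reward accumulation fixed in nstep_q_return (not library code).
REWARD, DONE = 2, 3                                  # frame layout matches the module's constants
episode = [(None, None, 1., False), (None, None, 1., False), (None, None, 1., True)]
gamma, i, n = 0.99, 0, 2

old_target, discount = episode[i][REWARD], 1.        # old: seeded with the first reward...
for lfi in range(i, min(len(episode), i + n + 1)):
    if episode[lfi][DONE]:                           # ...and terminal rewards never added
        break
    old_target += discount * episode[lfi][REWARD]    # ...so the first reward is counted twice
    discount *= gamma

new_target, discount = 0., 1.                        # new: start from zero
for lfi in range(i, min(len(episode), i + n)):
    new_target += discount * episode[lfi][REWARD]    # add the reward before the DONE check
    discount *= gamma
    if episode[lfi][DONE]:
        break

print(old_target, new_target)                        # 2.99 vs 1.99: the gap is the double-counted reward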