working on off-policy test. other parts of dqn_replay are runnable, but performance is not tested.

haoshengzou 2018-03-09 15:07:14 +08:00
parent e68dcd3c64
commit 92894d3853
10 changed files with 142 additions and 28 deletions

View File

@@ -3,6 +3,8 @@ import tensorflow as tf
 import gym
 import numpy as np
 import time
+import logging
+logging.basicConfig(level=logging.INFO)

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -12,8 +14,9 @@ import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy
 import tianshou.core.value_function.action_value as value_function

-from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
 from tianshou.data.data_collector import DataCollector
+from tianshou.data.tester import test_policy_in_env


 if __name__ == '__main__':
@@ -33,7 +36,7 @@ if __name__ == '__main__':
         return None, action_values  # no policy head

     ### 2. build policy, loss, optimizer
-    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=200)
+    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=800)
     pi = policy.DQN(dqn)

     dqn_loss = losses.qlearning(dqn)
@@ -43,7 +46,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)

     ### 3. define data collection
-    replay_buffer = VanillaReplayBuffer(capacity=1e5, nstep=1)
+    replay_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)

     process_functions = [advantage_estimation.nstep_q_return(1, dqn)]
     managed_networks = [dqn]
@@ -58,11 +61,11 @@ if __name__ == '__main__':
     ### 4. start training
     # hyper-parameters
-    batch_size = 256
+    batch_size = 128
     replay_buffer_warmup = 1000
-    epsilon_decay_interval = 200
-    epsilon = 0.3
-    test_interval = 1000
+    epsilon_decay_interval = 500
+    epsilon = 0.6
+    test_interval = 5000

     seed = 0
     np.random.seed(seed)
@@ -74,11 +77,11 @@ if __name__ == '__main__':
         sess.run(tf.global_variables_initializer())

         # assign actor to pi_old
-        pi.sync_weights()  # TODO: automate this for policies with target network
+        pi.sync_weights()  # TODO: rethink and redesign target network interface

         start_time = time.time()
         pi.set_epsilon_train(epsilon)
-        data_collector.collect(num_timesteps=replay_buffer_warmup)  # warm-up
+        data_collector.collect(num_timesteps=replay_buffer_warmup)  # TODO: uniform random warm-up
         for i in range(int(1e8)):  # number of training steps
             # anneal epsilon step-wise
             if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
@@ -86,7 +89,7 @@ if __name__ == '__main__':
                 pi.set_epsilon_train(epsilon)

             # collect data
-            data_collector.collect()
+            data_collector.collect(num_timesteps=4)

             # update network
             feed_dict = data_collector.next_batch(batch_size)
@@ -95,7 +98,7 @@ if __name__ == '__main__':
             # test every 1000 training steps
             # tester could share some code with batch!
             if i % test_interval == 0:
-                print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+                print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
                 # epsilon 0.05 as in nature paper
                 pi.set_epsilon_test(0.05)
-                #test(env, pi)  # go for act_test of pi, not act
+                test_policy_in_env(pi, env, num_timesteps=1000)
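For reference, the 1-step return produced by advantage_estimation.nstep_q_return(1, dqn) in this example corresponds to the standard Q-learning target r_t + gamma * max_a Q_target(s_{t+1}, a), computed with the target network controlled by the weight_update argument. A minimal illustrative sketch of that target (the helper below is hypothetical, not the library implementation):

import numpy as np

def one_step_q_target(reward, done, next_q_values, gamma=0.99):
    # next_q_values: Q_target(s_{t+1}, a) for every action, shape [num_actions]
    # terminal transitions do not bootstrap from the next state
    return reward + gamma * (1. - float(done)) * np.max(next_q_values)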

View File

@@ -11,15 +11,18 @@ import argparse
 import sys
 sys.path.append('..')
 from tianshou.core import losses
+from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.stochastic as policy
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_collector import DataCollector


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()

     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
@@ -59,7 +62,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)

     ### 4. start training
     config = tf.ConfigProto()
@@ -73,15 +76,15 @@ if __name__ == '__main__':
         start_time = time.time()
         for i in range(100):
             # collect data
-            training_data.collect(num_episodes=50)
+            data_collector.collect(num_episodes=50)

             # print current return
             print('Epoch {}:'.format(i))
-            training_data.statistics()
+            data_collector.statistics()

             # update network
             for _ in range(num_batches):
-                feed_dict = training_data.next_batch(batch_size)
+                feed_dict = data_collector.next_batch(batch_size)
                 sess.run(train_op, feed_dict=feed_dict)

             # assigning actor to pi_old
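The new imports suggest this on-policy example is meant to move from Batch to the same DataCollector pipeline used in the DQN example, with a BatchSet as the data buffer. A hypothetical sketch of that wiring (the BatchSet import and the constructor keyword names are assumptions; they are not shown in this diff):

from tianshou.data.data_buffer.batch_set import BatchSet

batch_set = BatchSet()
data_collector = DataCollector(
    env=env,                      # keyword names assumed, not taken from this commit
    policy=pi,
    data_buffer=batch_set,
    process_functions=[advantage_estimation.full_return],
    managed_networks=[pi],
)
data_collector.collect(num_episodes=50)            # BatchSet is cleared on each call by default
feed_dict = data_collector.next_batch(batch_size)  # then feed a batch to the training op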

View File

@@ -1,13 +1,13 @@
-class ReplayBufferBase(object):
+class DataBufferBase(object):
     """
-    base class for replay buffer.
+    base class for data buffer, including replay buffer as in DQN and batched dataset as in on-policy algos
     """
     def add(self, frame):
         raise NotImplementedError()

-    def remove(self):
+    def clear(self):
         raise NotImplementedError()

     def sample(self, batch_size):

View File

@@ -0,0 +1,24 @@
+from .base import DataBufferBase
+
+
+class BatchSet(DataBufferBase):
+    """
+    class for batched dataset as used in on-policy algos
+    """
+    def __init__(self):
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+        self.size = 0  # number of valid data points (not frames)
+
+        self.index_lengths = [0]  # for sampling
+
+    def add(self, frame):
+        self.data[-1].append(frame)
+
+    def clear(self):
+        pass
+
+    def sample(self, batch_size):
+        pass
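clear() and sample() are still stubs in this commit. A hypothetical completion, only to illustrate how the fields initialized above could be used; this is not the committed implementation, and it assumes add() also maintains index, size, and index_lengths:

import numpy as np

class BatchSetSketch(object):
    """Illustrative only; mirrors the fields of the committed BatchSet."""
    def __init__(self):
        self.data = [[]]            # one inner list of frames per episode
        self.index = [[]]           # valid frame indices per episode
        self.candidate_index = 0
        self.size = 0               # number of valid data points
        self.index_lengths = [0]    # len(self.index[i]) for each episode

    def add(self, frame):
        # here every added frame is treated as immediately valid,
        # which the real class may not do
        self.data[-1].append(frame)
        self.index[-1].append(len(self.data[-1]) - 1)
        self.index_lengths[-1] += 1
        self.size += 1

    def clear(self):
        # reset to the initial empty state
        self.__init__()

    def sample(self, batch_size):
        # draw (episode, step) pairs uniformly over all valid data points;
        # assumes at least one data point has been added
        probs = np.array(self.index_lengths, dtype=float) / self.size
        sampled_index = [[] for _ in self.data]
        for _ in range(batch_size):
            episode = np.random.choice(len(self.data), p=probs)
            step = np.random.randint(self.index_lengths[episode])
            sampled_index[episode].append(self.index[episode][step])
        return sampled_index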

View File

@@ -0,0 +1,12 @@
+from .base import DataBufferBase
+
+class ReplayBufferBase(DataBufferBase):
+    """
+    base class for replay buffer.
+    """
+    def remove(self):
+        """
+        when size exceeds capacity, removes extra data points
+        :return:
+        """
+        raise NotImplementedError()

View File

@@ -1,13 +1,14 @@
 import logging
 import numpy as np

-from .base import ReplayBufferBase
+from .replay_buffer_base import ReplayBufferBase

 STATE = 0
 ACTION = 1
 REWARD = 2
 DONE = 3

+# TODO: valid data points could be less than `nstep` timesteps
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).

View File

@@ -3,7 +3,8 @@ import logging
 import itertools
 import sys

-from .replay_buffer.base import ReplayBufferBase
+from .data_buffer.replay_buffer_base import ReplayBufferBase
+from .data_buffer.batch_set import BatchSet

 class DataCollector(object):
     """
@@ -31,10 +32,13 @@ class DataCollector(object):
         self.current_observation = self.env.reset()

-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}):
+    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"

+        if isinstance(self.data_buffer, BatchSet) and auto_clear:
+            self.data_buffer.clear()
+
         if num_timesteps > 0:
             num_timesteps_ = int(num_timesteps)
             for _ in range(num_timesteps_):
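The new auto_clear flag only affects BatchSet buffers: with the default auto_clear=True, each collect() call on an on-policy buffer first discards the previously collected batch, while replay buffers are never cleared here. A short usage sketch, assuming a data_collector whose data_buffer is a BatchSet:

# previous batch is cleared before fresh on-policy data is gathered
data_collector.collect(num_episodes=50)
# keep the previous batch and accumulate several collection calls instead
data_collector.collect(num_episodes=50, auto_clear=False)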

View File

@@ -1,6 +1,6 @@
 import numpy as np

-from replay_buffer.vanilla import VanillaReplayBuffer
+from data_buffer.vanilla import VanillaReplayBuffer

 capacity = 12
 nstep = 3

View File

@@ -1,8 +1,75 @@
 from __future__ import absolute_import

+import gym
+import logging
+import numpy as np


-def test_policy_in_env(policy, env):
+def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99):
+    assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
+        'One and only one collection number specification permitted!'
+
     # make another env as the original is for training data collection
-    env_ = env
+    env_id = env.spec.id
+    env_ = gym.make(env_id)

-    pass
+    # test policy
+    returns = []
+    undiscounted_returns = []
+    current_return = 0.
+    current_undiscounted_return = 0.
+
+    if num_episodes > 0:
+        returns = [0.] * num_episodes
+        undiscounted_returns = [0.] * num_episodes
+        for i in range(num_episodes):
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
+            observation = env_.reset()
+            done = False
+            while not done:
+                action = policy.act_test(observation)
+                observation, reward, done, _ = env_.step(action)
+                current_return += reward * current_discount
+                current_undiscounted_return += reward
+                current_discount *= discount_factor
+            returns[i] = current_return
+            undiscounted_returns[i] = current_undiscounted_return
+
+    # run for a fixed number of timesteps; only the first episode and finished episodes
+    # matter when calculating the average return
+    if num_timesteps > 0:
+        current_discount = 1.
+        observation = env_.reset()
+        for _ in range(num_timesteps):
+            action = policy.act_test(observation)
+            observation, reward, done, _ = env_.step(action)
+            current_return += reward * current_discount
+            current_undiscounted_return += reward
+            current_discount *= discount_factor
+            if done:
+                returns.append(current_return)
+                undiscounted_returns.append(current_undiscounted_return)
+                current_return = 0.
+                current_undiscounted_return = 0.
+                current_discount = 1.
+                observation = env_.reset()
+
+    # log
+    if returns:  # has at least one finished episode
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+    else:  # the first episode is too long to finish
+        logging.warning('The first test episode is still not finished after {} timesteps. '
+                        'Logging its return anyway.'.format(num_timesteps))
+        mean_return = current_return
+        mean_undiscounted_return = current_undiscounted_return
+
+    logging.info('Mean return: {}'.format(mean_return))
+    logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
+
+    # clear scene
+    env_.close()
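Typical usage of the new tester, matching how the DQN example above calls it; logging must be configured at INFO level (as that example now does) for the mean returns to appear. The policy is assumed to expose act_test(observation), as pi does in the examples:

import logging
import gym
logging.basicConfig(level=logging.INFO)

env = gym.make('CartPole-v0')
# `pi` is the policy built in the example above (exposes act_test)
test_policy_in_env(pi, env, num_timesteps=1000)                     # run a fixed number of timesteps
test_policy_in_env(pi, env, num_episodes=10, discount_factor=0.99)  # or a fixed number of episodes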