working on off-policy test. other parts of dqn_replay is runnable, but performance not tested.
This commit is contained in:
parent
e68dcd3c64
commit
92894d3853
@ -3,6 +3,8 @@ import tensorflow as tf
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# our lib imports here! It's ok to append path in examples
|
# our lib imports here! It's ok to append path in examples
|
||||||
import sys
|
import sys
|
||||||
@ -12,8 +14,9 @@ import tianshou.data.advantage_estimation as advantage_estimation
|
|||||||
import tianshou.core.policy.dqn as policy
|
import tianshou.core.policy.dqn as policy
|
||||||
import tianshou.core.value_function.action_value as value_function
|
import tianshou.core.value_function.action_value as value_function
|
||||||
|
|
||||||
from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
|
from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
|
||||||
from tianshou.data.data_collector import DataCollector
|
from tianshou.data.data_collector import DataCollector
|
||||||
|
from tianshou.data.tester import test_policy_in_env
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -33,7 +36,7 @@ if __name__ == '__main__':
|
|||||||
return None, action_values # no policy head
|
return None, action_values # no policy head
|
||||||
|
|
||||||
### 2. build policy, loss, optimizer
|
### 2. build policy, loss, optimizer
|
||||||
dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=200)
|
dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=800)
|
||||||
pi = policy.DQN(dqn)
|
pi = policy.DQN(dqn)
|
||||||
|
|
||||||
dqn_loss = losses.qlearning(dqn)
|
dqn_loss = losses.qlearning(dqn)
|
||||||
@ -43,7 +46,7 @@ if __name__ == '__main__':
|
|||||||
train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
|
train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
|
||||||
|
|
||||||
### 3. define data collection
|
### 3. define data collection
|
||||||
replay_buffer = VanillaReplayBuffer(capacity=1e5, nstep=1)
|
replay_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
|
||||||
|
|
||||||
process_functions = [advantage_estimation.nstep_q_return(1, dqn)]
|
process_functions = [advantage_estimation.nstep_q_return(1, dqn)]
|
||||||
managed_networks = [dqn]
|
managed_networks = [dqn]
|
||||||
@ -58,11 +61,11 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
### 4. start training
|
### 4. start training
|
||||||
# hyper-parameters
|
# hyper-parameters
|
||||||
batch_size = 256
|
batch_size = 128
|
||||||
replay_buffer_warmup = 1000
|
replay_buffer_warmup = 1000
|
||||||
epsilon_decay_interval = 200
|
epsilon_decay_interval = 500
|
||||||
epsilon = 0.3
|
epsilon = 0.6
|
||||||
test_interval = 1000
|
test_interval = 5000
|
||||||
|
|
||||||
seed = 0
|
seed = 0
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@ -74,11 +77,11 @@ if __name__ == '__main__':
|
|||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
|
|
||||||
# assign actor to pi_old
|
# assign actor to pi_old
|
||||||
pi.sync_weights() # TODO: automate this for policies with target network
|
pi.sync_weights() # TODO: rethink and redesign target network interface
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
pi.set_epsilon_train(epsilon)
|
pi.set_epsilon_train(epsilon)
|
||||||
data_collector.collect(num_timesteps=replay_buffer_warmup) # warm-up
|
data_collector.collect(num_timesteps=replay_buffer_warmup) # TODO: uniform random warm-up
|
||||||
for i in range(int(1e8)): # number of training steps
|
for i in range(int(1e8)): # number of training steps
|
||||||
# anneal epsilon step-wise
|
# anneal epsilon step-wise
|
||||||
if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
|
if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
|
||||||
@ -86,7 +89,7 @@ if __name__ == '__main__':
|
|||||||
pi.set_epsilon_train(epsilon)
|
pi.set_epsilon_train(epsilon)
|
||||||
|
|
||||||
# collect data
|
# collect data
|
||||||
data_collector.collect()
|
data_collector.collect(num_timesteps=4)
|
||||||
|
|
||||||
# update network
|
# update network
|
||||||
feed_dict = data_collector.next_batch(batch_size)
|
feed_dict = data_collector.next_batch(batch_size)
|
||||||
@ -95,7 +98,7 @@ if __name__ == '__main__':
|
|||||||
# test every 1000 training steps
|
# test every 1000 training steps
|
||||||
# tester could share some code with batch!
|
# tester could share some code with batch!
|
||||||
if i % test_interval == 0:
|
if i % test_interval == 0:
|
||||||
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
|
print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
|
||||||
# epsilon 0.05 as in nature paper
|
# epsilon 0.05 as in nature paper
|
||||||
pi.set_epsilon_test(0.05)
|
pi.set_epsilon_test(0.05)
|
||||||
#test(env, pi) # go for act_test of pi, not act
|
test_policy_in_env(pi, env, num_timesteps=1000)
|
||||||
|
@ -11,15 +11,18 @@ import argparse
|
|||||||
import sys
|
import sys
|
||||||
sys.path.append('..')
|
sys.path.append('..')
|
||||||
from tianshou.core import losses
|
from tianshou.core import losses
|
||||||
from tianshou.data.batch import Batch
|
|
||||||
import tianshou.data.advantage_estimation as advantage_estimation
|
import tianshou.data.advantage_estimation as advantage_estimation
|
||||||
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
|
import tianshou.core.policy.stochastic as policy
|
||||||
|
|
||||||
|
from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
|
||||||
|
from tianshou.data.data_collector import DataCollector
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--render", action="store_true", default=False)
|
parser.add_argument("--render", action="store_true", default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
env = gym.make('CartPole-v0')
|
env = gym.make('CartPole-v0')
|
||||||
observation_dim = env.observation_space.shape
|
observation_dim = env.observation_space.shape
|
||||||
action_dim = env.action_space.n
|
action_dim = env.action_space.n
|
||||||
@ -59,7 +62,7 @@ if __name__ == '__main__':
|
|||||||
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
||||||
|
|
||||||
### 3. define data collection
|
### 3. define data collection
|
||||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
|
data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
|
||||||
|
|
||||||
### 4. start training
|
### 4. start training
|
||||||
config = tf.ConfigProto()
|
config = tf.ConfigProto()
|
||||||
@ -73,15 +76,15 @@ if __name__ == '__main__':
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
# collect data
|
# collect data
|
||||||
training_data.collect(num_episodes=50)
|
data_collector.collect(num_episodes=50)
|
||||||
|
|
||||||
# print current return
|
# print current return
|
||||||
print('Epoch {}:'.format(i))
|
print('Epoch {}:'.format(i))
|
||||||
training_data.statistics()
|
data_collector.statistics()
|
||||||
|
|
||||||
# update network
|
# update network
|
||||||
for _ in range(num_batches):
|
for _ in range(num_batches):
|
||||||
feed_dict = training_data.next_batch(batch_size)
|
feed_dict = data_collector.next_batch(batch_size)
|
||||||
sess.run(train_op, feed_dict=feed_dict)
|
sess.run(train_op, feed_dict=feed_dict)
|
||||||
|
|
||||||
# assigning actor to pi_old
|
# assigning actor to pi_old
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
|
|
||||||
|
|
||||||
class ReplayBufferBase(object):
|
class DataBufferBase(object):
|
||||||
"""
|
"""
|
||||||
base class for replay buffer.
|
base class for data buffer, including replay buffer as in DQN and batched dataset as in on-policy algos
|
||||||
"""
|
"""
|
||||||
def add(self, frame):
|
def add(self, frame):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def remove(self):
|
def clear(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
24
tianshou/data/data_buffer/batch_set.py
Normal file
24
tianshou/data/data_buffer/batch_set.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from .base import DataBufferBase
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSet(DataBufferBase):
|
||||||
|
"""
|
||||||
|
class for batched dataset as used in on-policy algos
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.data = [[]]
|
||||||
|
self.index = [[]]
|
||||||
|
self.candidate_index = 0
|
||||||
|
|
||||||
|
self.size = 0 # number of valid data points (not frames)
|
||||||
|
|
||||||
|
self.index_lengths = [0] # for sampling
|
||||||
|
|
||||||
|
def add(self, frame):
|
||||||
|
self.data[-1].append(frame)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sample(self, batch_size):
|
||||||
|
pass
|
12
tianshou/data/data_buffer/replay_buffer_base.py
Normal file
12
tianshou/data/data_buffer/replay_buffer_base.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from .base import DataBufferBase
|
||||||
|
|
||||||
|
class ReplayBufferBase(DataBufferBase):
|
||||||
|
"""
|
||||||
|
base class for replay buffer.
|
||||||
|
"""
|
||||||
|
def remove(self):
|
||||||
|
"""
|
||||||
|
when size exceeds capacity, removes extra data points
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
@ -1,13 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import ReplayBufferBase
|
from .replay_buffer_base import ReplayBufferBase
|
||||||
|
|
||||||
STATE = 0
|
STATE = 0
|
||||||
ACTION = 1
|
ACTION = 1
|
||||||
REWARD = 2
|
REWARD = 2
|
||||||
DONE = 3
|
DONE = 3
|
||||||
|
|
||||||
|
# TODO: valid data points could be less than `nstep` timesteps
|
||||||
class VanillaReplayBuffer(ReplayBufferBase):
|
class VanillaReplayBuffer(ReplayBufferBase):
|
||||||
"""
|
"""
|
||||||
vanilla replay buffer as used in (Mnih, et al., 2015).
|
vanilla replay buffer as used in (Mnih, et al., 2015).
|
@ -3,7 +3,8 @@ import logging
|
|||||||
import itertools
|
import itertools
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from .replay_buffer.base import ReplayBufferBase
|
from .data_buffer.replay_buffer_base import ReplayBufferBase
|
||||||
|
from .data_buffer.batch_set import BatchSet
|
||||||
|
|
||||||
class DataCollector(object):
|
class DataCollector(object):
|
||||||
"""
|
"""
|
||||||
@ -31,10 +32,13 @@ class DataCollector(object):
|
|||||||
|
|
||||||
self.current_observation = self.env.reset()
|
self.current_observation = self.env.reset()
|
||||||
|
|
||||||
def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}):
|
def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
|
||||||
assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
|
assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
|
||||||
"One and only one collection number specification permitted!"
|
"One and only one collection number specification permitted!"
|
||||||
|
|
||||||
|
if isinstance(self.data_buffer, BatchSet) and auto_clear:
|
||||||
|
self.data_buffer.clear()
|
||||||
|
|
||||||
if num_timesteps > 0:
|
if num_timesteps > 0:
|
||||||
num_timesteps_ = int(num_timesteps)
|
num_timesteps_ = int(num_timesteps)
|
||||||
for _ in range(num_timesteps_):
|
for _ in range(num_timesteps_):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from replay_buffer.vanilla import VanillaReplayBuffer
|
from data_buffer.vanilla import VanillaReplayBuffer
|
||||||
|
|
||||||
capacity = 12
|
capacity = 12
|
||||||
nstep = 3
|
nstep = 3
|
||||||
|
@ -1,8 +1,75 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99):
|
||||||
|
|
||||||
|
assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
|
||||||
|
'One and only one collection number specification permitted!'
|
||||||
|
|
||||||
def test_policy_in_env(policy, env):
|
|
||||||
# make another env as the original is for training data collection
|
# make another env as the original is for training data collection
|
||||||
env_ = env
|
env_id = env.spec.id
|
||||||
|
env_ = gym.make(env_id)
|
||||||
|
|
||||||
pass
|
# test policy
|
||||||
|
returns = []
|
||||||
|
undiscounted_returns = []
|
||||||
|
current_return = 0.
|
||||||
|
current_undiscounted_return = 0.
|
||||||
|
|
||||||
|
if num_episodes > 0:
|
||||||
|
returns = [0.] * num_episodes
|
||||||
|
undiscounted_returns = [0.] * num_episodes
|
||||||
|
for i in range(num_episodes):
|
||||||
|
current_return = 0.
|
||||||
|
current_undiscounted_return = 0.
|
||||||
|
current_discount = 1.
|
||||||
|
observation = env_.reset()
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action = policy.act_test(observation)
|
||||||
|
observation, reward, done, _ = env_.step(action)
|
||||||
|
current_return += reward * current_discount
|
||||||
|
current_undiscounted_return += reward
|
||||||
|
current_discount *= discount_factor
|
||||||
|
|
||||||
|
returns[i] = current_return
|
||||||
|
undiscounted_returns[i] = current_undiscounted_return
|
||||||
|
|
||||||
|
# run for fix number of timesteps, only the first episode and finished episodes
|
||||||
|
# matters when calcuting average return
|
||||||
|
if num_timesteps > 0:
|
||||||
|
current_discount = 1.
|
||||||
|
observation = env_.reset()
|
||||||
|
for _ in range(num_timesteps):
|
||||||
|
action = policy.act_test(observation)
|
||||||
|
observation, reward, done, _ = env_.step(action)
|
||||||
|
current_return += reward * current_discount
|
||||||
|
current_undiscounted_return += reward
|
||||||
|
current_discount *= discount_factor
|
||||||
|
|
||||||
|
if done:
|
||||||
|
returns.append(current_return)
|
||||||
|
undiscounted_returns.append(current_undiscounted_return)
|
||||||
|
current_return = 0.
|
||||||
|
current_undiscounted_return = 0.
|
||||||
|
current_discount = 1.
|
||||||
|
observation = env_.reset()
|
||||||
|
|
||||||
|
# log
|
||||||
|
if returns: # has at least one finished episode
|
||||||
|
mean_return = np.mean(returns)
|
||||||
|
mean_undiscounted_return = np.mean(undiscounted_returns)
|
||||||
|
else: # the first episode is too long to finish
|
||||||
|
logging.warning('The first test episode is still not finished after {} timesteps. '
|
||||||
|
'Logging its return anyway.'.format(num_timesteps))
|
||||||
|
mean_return = current_return
|
||||||
|
mean_undiscounted_return = current_undiscounted_return
|
||||||
|
logging.info('Mean return: {}'.format(mean_return))
|
||||||
|
logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
|
||||||
|
|
||||||
|
# clear scene
|
||||||
|
env_.close()
|
Loading…
x
Reference in New Issue
Block a user