PPO with the batch set also works! PPO now improves steadily; DQN is still not so stable.

haoshengzou 2018-03-10 17:30:11 +08:00
parent 6eb69c7867
commit 498b55c051
5 changed files with 136 additions and 27 deletions

View File

@@ -6,6 +6,8 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -14,7 +16,7 @@ from tianshou.core import losses
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.stochastic as policy
-from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.batch_set import BatchSet
 from tianshou.data.data_collector import DataCollector
@@ -62,7 +64,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_buffer = BatchSet()
+    data_collector = DataCollector(
+        env=env,
+        policy=pi,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.full_return],
+        managed_networks=[pi],
+    )

     ### 4. start training
     config = tf.ConfigProto()
@@ -80,7 +90,7 @@ if __name__ == '__main__':
         # print current return
         print('Epoch {}:'.format(i))
-        data_collector.statistics()
+        data_buffer.statistics()

         # update network
         for _ in range(num_batches):
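For orientation, the epoch loop this hunk truncates presumably alternates collection and minibatch updates. A sketch under that assumption; num_batches, batch_size, sess, the episode count, and the next_batch method name are taken from elsewhere in the example and the collector code below, not from this hunk:

    # sketch, not the committed code: how the new pieces fit together per epoch
    for i in range(100):                             # epoch count is arbitrary here
        data_collector.collect(num_episodes=50)      # fill the BatchSet with fresh on-policy episodes

        print('Epoch {}:'.format(i))
        data_buffer.statistics()                     # log mean discounted/undiscounted returns

        for _ in range(num_batches):                 # several minibatch updates per epoch
            feed_dict = data_collector.next_batch(batch_size)
            sess.run(train_op, feed_dict=feed_dict)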

View File · tianshou/data/data_buffer/base.py

@@ -11,4 +11,16 @@ class DataBufferBase(object):
         raise NotImplementedError()

     def sample(self, batch_size):
         raise NotImplementedError()
+
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
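The sampling added here is episode-proportional: an episode is chosen with probability proportional to its count of valid data points, then a frame index is drawn uniformly inside that episode, giving one index sublist per episode. A self-contained illustration of the same draw (all numbers made up):

    import numpy as np

    index_lengths = [2, 5, 3]                           # valid data points per episode
    size = sum(index_lengths)                           # 10
    prob_episode = np.array(index_lengths) * 1. / size  # [0.2, 0.5, 0.3]

    sampled_index = [[] for _ in index_lengths]
    for _ in range(4):                                  # batch_size = 4
        ep = int(np.random.choice(len(index_lengths), p=prob_episode))
        sampled_index[ep].append(int(np.random.randint(index_lengths[ep])))

    print(sampled_index)  # e.g. [[1], [0, 4, 2], []]; draws are independent, so repeats can occur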

View File · tianshou/data/data_buffer/batch_set.py

@@ -1,11 +1,21 @@
+import gc
 import numpy as np
+import logging

 from .base import DataBufferBase

+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3
+

 class BatchSet(DataBufferBase):
     """
     class for batched dataset as used in on-policy algos
     """
-    def __init__(self):
+    def __init__(self, nstep=None):
+        self.nstep = nstep or 1  # RL has to look ahead at least one timestep
+
         self.data = [[]]
         self.index = [[]]
         self.candidate_index = 0
@@ -17,8 +27,80 @@ class BatchSet(DataBufferBase):
     def add(self, frame):
         self.data[-1].append(frame)
+        has_enough_frames = len(self.data[-1]) > self.nstep
+        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
+            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
+            self.index[-1] += trailing_index
+            self.size += len(trailing_index)
+            self.index_lengths[-1] += len(trailing_index)
+
+            # prepare for the next episode
+            self.data.append([])
+            self.index.append([])
+            self.candidate_index = 0
+            self.index_lengths.append(0)
+        elif has_enough_frames:  # add one valid data point
+            self.index[-1].append(self.candidate_index)
+            self.candidate_index += 1
+            self.size += 1
+            self.index_lengths[-1] += 1
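A hedged trace of this bookkeeping with the default nstep=1; frames are (state, action, reward, done) tuples with dummy values, and the import path follows the example script above:

    from tianshou.data.data_buffer.batch_set import BatchSet

    buf = BatchSet()               # nstep defaults to 1
    buf.add((0.0, 1, 1.0, False))  # 1 frame, not > nstep: no valid data point yet
    buf.add((0.1, 0, 1.0, False))  # 2 frames > nstep: frame 0 becomes a valid data point
    buf.add((0.2, 1, 1.0, True))   # done: trailing frames 1 and 2 become valid too

    assert buf.index == [[0, 1, 2], []]  # an empty list for the next episode was opened
    assert buf.size == 3 and buf.index_lengths == [3, 0]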
     def clear(self):
-        pass
+        del self.data
+        del self.index
+        del self.index_lengths
+        gc.collect()
+
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+        self.size = 0
+        self.index_lengths = [0]
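The explicit del plus gc.collect() presumably exists because a BatchSet can hold an entire epoch of frames: clear() drops the old episode lists and nudges the garbage collector before rebuilding the empty structures.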
     def sample(self, batch_size):
-        pass
+        # TODO: move unified properties and methods to base. but this depends on how to deal with nstep
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
+    def statistics(self, discount_factor=0.99):
+        returns = []
+        undiscounted_returns = []
+        if len(self.data) == 1:
+            data = self.data
+            logging.warning('The first episode in BatchSet is still not finished. '
+                            'Logging its return anyway.')
+        else:
+            data = self.data[:-1]
+
+        for episode in data:
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
+            for frame in episode:
+                current_return += frame[REWARD] * current_discount
+                current_undiscounted_return += frame[REWARD]
+                current_discount *= discount_factor
+            returns.append(current_return)
+            undiscounted_returns.append(current_undiscounted_return)
+
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+        logging.info('Mean return: {}'.format(mean_return))
+        logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
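As a quick check of the discounting: an episode with three reward-1 frames contributes 1 + 0.99 + 0.99^2 = 2.9701 to returns and 3.0 to undiscounted_returns.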

View File · tianshou/data/data_buffer/vanilla.py

@@ -8,7 +8,7 @@ ACTION = 1
 REWARD = 2
 DONE = 3

-# TODO: valid data points could be less than `nstep` timesteps
+# TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper!
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).

View File · tianshou/data/data_collector.py

@@ -1,10 +1,10 @@
 import numpy as np
 import logging
 import itertools
-import sys

 from .data_buffer.replay_buffer_base import ReplayBufferBase
+from .data_buffer.batch_set import BatchSet
 from .utils import internal_key_match


 class DataCollector(object):
     """
@@ -32,7 +32,7 @@ class DataCollector(object):
         self.current_observation = self.env.reset()

-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
+    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
@@ -76,26 +76,34 @@ class DataCollector(object):
         feed_dict = {}
         frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}

         for key, placeholder in self.required_placeholders.items():
-            if key in frame_key_map.keys():  # access raw_data
-                frame_index = frame_key_map[key]
+            # check raw_data first
+            found, matched_key = internal_key_match(key, frame_key_map.keys())
+            if found:
+                frame_index = frame_key_map[matched_key]
                 flattened = []
                 for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
                     for i in index_episode:
                         flattened.append(data_episode[i][frame_index])
                 feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data_batch.keys():  # access processed minibatch data
-                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
-                feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data.keys():  # access processed full data
-                flattened = [0.] * batch_size  # float
-                i_in_batch = 0
-                for index_episode, data_episode in zip(sampled_index, self.data[key]):
-                    for i in index_episode:
-                        flattened[i_in_batch] = data_episode[i]
-                        i_in_batch += 1
-                feed_dict[placeholder] = np.array(flattened)
             else:
-                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
+                # then check processed minibatch data
+                found, matched_key = internal_key_match(key, self.data_batch.keys())
+                if found:
+                    flattened = list(itertools.chain.from_iterable(self.data_batch[matched_key]))
+                    feed_dict[placeholder] = np.array(flattened)
+                else:
+                    # finally check processed full data
+                    found, matched_key = internal_key_match(key, self.data.keys())
+                    if found:
+                        flattened = [0.] * batch_size  # float
+                        i_in_batch = 0
+                        for index_episode, data_episode in zip(sampled_index, self.data[matched_key]):
+                            for i in index_episode:
+                                flattened[i_in_batch] = data_episode[i]
+                                i_in_batch += 1
+                        feed_dict[placeholder] = np.array(flattened)
+                    else:
+                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

         auto_standardize = (standardize_advantage is None) and self.require_advantage
         if standardize_advantage or auto_standardize:
@@ -108,6 +116,3 @@ class DataCollector(object):
             feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std

         return feed_dict
-
-    def statistics(self):
-        pass
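The rewritten feed-dict lookup relies on internal_key_match from tianshou/data/utils.py, which evidently returns a (found, matched_key) pair so that a required placeholder key (say, 'advantage') can match a produced key with a related name instead of requiring exact equality. Its actual matching rule is not part of this diff; a plausible minimal sketch:

    # sketch only: the real matching rule lives in tianshou/data/utils.py
    def internal_key_match(key, candidate_keys):
        # here, a candidate matches when one name contains the other,
        # e.g. 'advantage' would match 'gae_advantage'
        for candidate in candidate_keys:
            if key in candidate or candidate in key:
                return True, candidate
        return False, None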