PPO with the batch set also works! PPO now improves steadily; DQN is not so stable.
parent 6eb69c7867 · commit 498b55c051
@@ -6,6 +6,8 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -14,7 +16,7 @@ from tianshou.core import losses
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.stochastic as policy

 from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.batch_set import BatchSet
 from tianshou.data.data_collector import DataCollector
-

@@ -62,7 +64,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=pi,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.full_return],
+        managed_networks=[pi],
+    )

     ### 4. start training
     config = tf.ConfigProto()
@@ -80,7 +90,7 @@ if __name__ == '__main__':

             # print current return
             print('Epoch {}:'.format(i))
-            data_collector.statistics()
+            data_buffer.statistics()

             # update network
             for _ in range(num_batches):
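A minimal sketch (not part of this commit) of how the pieces above fit together at training time. `num_epochs`, `batch_size`, `sess`, the `num_episodes=50` value, and a `next_batch`-style method on the DataCollector that builds the feed dict are assumptions, not confirmed by this diff:

    for i in range(num_epochs):
        # gather fresh on-policy data into the BatchSet
        data_collector.collect(num_episodes=50)

        # print current return
        print('Epoch {}:'.format(i))
        data_buffer.statistics()

        # update network on minibatches sampled from the collected batch
        for _ in range(num_batches):
            feed_dict = data_collector.next_batch(batch_size)  # assumed method name
            sess.run(train_op, feed_dict=feed_dict)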
@@ -11,4 +11,16 @@ class DataBufferBase(object):
         raise NotImplementedError()

     def sample(self, batch_size):
-        raise NotImplementedError()
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
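To make the sampling scheme above concrete, a standalone toy run (the `index_lengths` values are made up): episodes are picked with probability proportional to how many valid data points they hold, then a frame index is drawn uniformly inside the chosen episode.

    import numpy as np

    index_lengths = [3, 5, 2]                       # valid data points per episode
    size = sum(index_lengths)                       # plays the role of self.size
    prob_episode = np.array(index_lengths) * 1. / size

    sampled_index = [[] for _ in index_lengths]
    for _ in range(4):                              # batch_size = 4
        episode_i = int(np.random.choice(len(index_lengths), p=prob_episode))
        frame_i = int(np.random.randint(index_lengths[episode_i]))
        sampled_index[episode_i].append(frame_i)

    print(sampled_index)  # e.g. [[2], [0, 4], [1]] -- frame indices grouped by episode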
@@ -1,11 +1,21 @@
+import gc
+import numpy as np
+import logging
+
 from .base import DataBufferBase

+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3
+
 class BatchSet(DataBufferBase):
     """
     class for batched dataset as used in on-policy algos
     """
-    def __init__(self):
+    def __init__(self, nstep=None):
+        self.nstep = nstep or 1  # RL has to look ahead at least one timestep
+
         self.data = [[]]
         self.index = [[]]
         self.candidate_index = 0
@@ -17,8 +27,80 @@ class BatchSet(DataBufferBase):
     def add(self, frame):
         self.data[-1].append(frame)
+
+        has_enough_frames = len(self.data[-1]) > self.nstep
+        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
+            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
+            self.index[-1] += trailing_index
+            self.size += len(trailing_index)
+            self.index_lengths[-1] += len(trailing_index)
+
+            # prepare for the next episode
+            self.data.append([])
+            self.index.append([])
+            self.candidate_index = 0
+
+            self.index_lengths.append(0)
+
+        elif has_enough_frames:  # add one valid data point
+            self.index[-1].append(self.candidate_index)
+            self.candidate_index += 1
+            self.size += 1
+            self.index_lengths[-1] += 1

     def clear(self):
-        pass
+        del self.data
+        del self.index
+        del self.index_lengths
+
+        gc.collect()
+
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+        self.size = 0
+        self.index_lengths = [0]

     def sample(self, batch_size):
-        pass
+        # TODO: move unified properties and methods to base. but this depends on how to deal with nstep
+
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
+
+    def statistics(self, discount_factor=0.99):
+        returns = []
+        undiscounted_returns = []
+
+        if len(self.data) == 1:
+            data = self.data
+            logging.warning('The first episode in BatchSet is still not finished. '
+                            'Logging its return anyway.')
+        else:
+            data = self.data[:-1]
+
+        for episode in data:
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
+            for frame in episode:
+                current_return += frame[REWARD] * current_discount
+                current_undiscounted_return += frame[REWARD]
+                current_discount *= discount_factor
+            returns.append(current_return)
+            undiscounted_returns.append(current_undiscounted_return)
+
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+
+        logging.info('Mean return: {}'.format(mean_return))
+        logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
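A hedged usage sketch of the BatchSet above. It assumes the lines of `__init__` skipped between the two hunks initialize `self.size = 0` and `self.index_lengths = [0]`, as `clear()` suggests; the toy frames are made up.

    from tianshou.data.data_buffer.batch_set import BatchSet

    buf = BatchSet(nstep=1)
    # frames follow the (observation, action, reward, done_flag) layout of STATE/ACTION/REWARD/DONE
    buf.add((0.0, 1, 1.0, False))
    buf.add((0.1, 0, 1.0, False))
    buf.add((0.2, 1, 0.0, True))     # episode ends; trailing frames become valid data points

    buf.statistics()                 # logs mean discounted and undiscounted return of finished episodes
    print(buf.sample(batch_size=2))  # e.g. [[0, 2], []] -- the fresh empty episode has probability 0
    buf.clear()                      # back to a single empty episode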
@@ -8,7 +8,7 @@ ACTION = 1
 REWARD = 2
 DONE = 3

-# TODO: valid data points could be less than `nstep` timesteps
+# TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper!
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).
@@ -1,10 +1,10 @@
 import numpy as np
 import logging
 import itertools
-import sys

 from .data_buffer.replay_buffer_base import ReplayBufferBase
 from .data_buffer.batch_set import BatchSet
+from .utils import internal_key_match

 class DataCollector(object):
     """
@@ -32,7 +32,7 @@ class DataCollector(object):

         self.current_observation = self.env.reset()

-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
+    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
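With the default changed to `num_timesteps=0`, callers must now specify exactly one of the two collection sizes (hypothetical call sites, consistent with the assertion above):

    data_collector.collect(num_episodes=50)      # OK: collect whole episodes
    data_collector.collect(num_timesteps=1000)   # OK: collect a fixed number of timesteps
    data_collector.collect()                     # AssertionError: "One and only one ..."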
@@ -76,20 +76,28 @@ class DataCollector(object):
         feed_dict = {}
         frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
         for key, placeholder in self.required_placeholders.items():
-            if key in frame_key_map.keys():  # access raw_data
-                frame_index = frame_key_map[key]
+            # check raw_data first
+            found, matched_key = internal_key_match(key, frame_key_map.keys())
+            if found:
+                frame_index = frame_key_map[matched_key]
                 flattened = []
                 for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
                     for i in index_episode:
                         flattened.append(data_episode[i][frame_index])
                 feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data_batch.keys():  # access processed minibatch data
-                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
+            else:
+                # then check processed minibatch data
+                found, matched_key = internal_key_match(key, self.data_batch.keys())
+                if found:
+                    flattened = list(itertools.chain.from_iterable(self.data_batch[matched_key]))
                     feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data.keys():  # access processed full data
+                else:
+                    # finally check processed full data
+                    found, matched_key = internal_key_match(key, self.data.keys())
+                    if found:
                         flattened = [0.] * batch_size  # float
                         i_in_batch = 0
-                for index_episode, data_episode in zip(sampled_index, self.data[key]):
+                        for index_episode, data_episode in zip(sampled_index, self.data[matched_key]):
                             for i in index_episode:
                                 flattened[i_in_batch] = data_episode[i]
                                 i_in_batch += 1
@@ -108,6 +116,3 @@ class DataCollector(object):
             feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std

         return feed_dict
-
-    def statistics(self):
-        pass
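The rewritten feed-dict code above resolves each required placeholder by fuzzy key matching in three stages: raw frame data first, then processed minibatch data, then processed full data. A self-contained toy of that lookup order, with `internal_key_match` simplified to a substring check (the real helper is imported from `tianshou.data.utils` and its exact behaviour is not shown in this diff):

    def toy_key_match(key, candidate_keys):
        # stand-in for internal_key_match: return the first candidate contained in `key`
        for candidate in candidate_keys:
            if candidate in key:
                return True, candidate
        return False, None

    frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
    data_batch = {'return': [[1.0, 0.5]]}  # made-up processed minibatch data

    for key in ['observation_ph', 'return_ph', 'advantage_ph']:
        found, matched = toy_key_match(key, frame_key_map.keys())
        if found:
            print(key, '-> raw frame field', frame_key_map[matched])
            continue
        found, matched = toy_key_match(key, data_batch.keys())
        if found:
            print(key, '-> processed minibatch key', matched)
        else:
            print(key, '-> falls through to processed full data')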