PPO with the batch set also works! Now PPO improves steadily; DQN is not so stable.
This commit is contained in:
parent 6eb69c7867
commit 498b55c051

@@ -6,6 +6,8 @@ import gym
 import numpy as np
 import time
 import argparse
+import logging
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
@@ -14,7 +16,7 @@ from tianshou.core import losses
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.stochastic as policy
 
-from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.batch_set import BatchSet
 from tianshou.data.data_collector import DataCollector
 
 
@@ -62,7 +64,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=pi,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.full_return],
+        managed_networks=[pi],
+    )
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -80,7 +90,7 @@ if __name__ == '__main__':
 
             # print current return
             print('Epoch {}:'.format(i))
-            data_collector.statistics()
+            data_buffer.statistics()
 
             # update network
             for _ in range(num_batches):
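
Taken together, these hunks in the PPO example replace the old one-line Batch(...) collector with an explicit BatchSet buffer plus a keyword-argument DataCollector, and per-epoch return logging moves from the collector to the buffer. The sketch below is schematic only: env, pi, train_op, config, batch_size and num_batches come from the unchanged parts of the example, and the collect/next_batch calls are assumptions based on the DataCollector hunks further down, not lines touched by this commit.

# Schematic glue code, assuming DataCollector exposes collect() and a
# next_batch() that returns a feed_dict, as suggested by the DataCollector hunks below.
data_buffer = BatchSet()
data_collector = DataCollector(
    env=env,
    policy=pi,
    data_buffer=data_buffer,
    process_functions=[advantage_estimation.full_return],
    managed_networks=[pi],
)

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        data_collector.collect(num_timesteps=batch_size)  # assumed call: fill the BatchSet
        print('Epoch {}:'.format(i))
        data_buffer.statistics()                          # log mean (un)discounted return
        for _ in range(num_batches):
            feed_dict = data_collector.next_batch(64)     # assumed minibatch API
            sess.run(train_op, feed_dict=feed_dict)
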
@@ -11,4 +11,16 @@ class DataBufferBase(object):
         raise NotImplementedError()
 
     def sample(self, batch_size):
-        raise NotImplementedError()
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
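
The base-class sample() above weights episodes by their number of valid data points and then draws a frame uniformly within the chosen episode. A small self-contained sketch with made-up bookkeeping values (not taken from the commit) shows the resulting index structure:

import numpy as np

# Hypothetical bookkeeping: two episodes with 3 and 1 valid data points.
index_lengths = [3, 1]
size = sum(index_lengths)                           # 4
prob_episode = np.array(index_lengths) * 1. / size  # [0.75, 0.25]

sampled_index = [[] for _ in index_lengths]
for _ in range(8):  # batch_size = 8
    episode_i = int(np.random.choice(len(index_lengths), p=prob_episode))
    frame_i = int(np.random.randint(index_lengths[episode_i]))
    sampled_index[episode_i].append(frame_i)

print(sampled_index)  # e.g. [[2, 0, 1, 0, 2, 1], [0, 0]] -- frame indices grouped per episode
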
@ -1,11 +1,21 @@
|
|||||||
|
import gc
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
from .base import DataBufferBase
|
from .base import DataBufferBase
|
||||||
|
|
||||||
|
STATE = 0
|
||||||
|
ACTION = 1
|
||||||
|
REWARD = 2
|
||||||
|
DONE = 3
|
||||||
|
|
||||||
class BatchSet(DataBufferBase):
|
class BatchSet(DataBufferBase):
|
||||||
"""
|
"""
|
||||||
class for batched dataset as used in on-policy algos
|
class for batched dataset as used in on-policy algos
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self, nstep=None):
|
||||||
|
self.nstep = nstep or 1 # RL has to look ahead at least one timestep
|
||||||
|
|
||||||
self.data = [[]]
|
self.data = [[]]
|
||||||
self.index = [[]]
|
self.index = [[]]
|
||||||
self.candidate_index = 0
|
self.candidate_index = 0
|
||||||
@@ -17,8 +27,80 @@ class BatchSet(DataBufferBase):
     def add(self, frame):
         self.data[-1].append(frame)
 
+        has_enough_frames = len(self.data[-1]) > self.nstep
+        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
+            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
+            self.index[-1] += trailing_index
+            self.size += len(trailing_index)
+            self.index_lengths[-1] += len(trailing_index)
+
+            # prepare for the next episode
+            self.data.append([])
+            self.index.append([])
+            self.candidate_index = 0
+
+            self.index_lengths.append(0)
+
+        elif has_enough_frames:  # add one valid data point
+            self.index[-1].append(self.candidate_index)
+            self.candidate_index += 1
+            self.size += 1
+            self.index_lengths[-1] += 1
+
     def clear(self):
-        pass
+        del self.data
+        del self.index
+        del self.index_lengths
+
+        gc.collect()
+
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+        self.size = 0
+        self.index_lengths = [0]
 
     def sample(self, batch_size):
-        pass
+        # TODO: move unified properties and methods to base. but this depends on how to deal with nstep
+
+        prob_episode = np.array(self.index_lengths) * 1. / self.size
+        num_episodes = len(self.index)
+        sampled_index = [[] for _ in range(num_episodes)]
+
+        for _ in range(batch_size):
+            # sample which episode
+            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
+
+            # sample which data point within the sampled episode
+            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
+            sampled_index[sampled_episode_i].append(sampled_frame_i)
+
+        return sampled_index
+
+    def statistics(self, discount_factor=0.99):
+        returns = []
+        undiscounted_returns = []
+
+        if len(self.data) == 1:
+            data = self.data
+            logging.warning('The first episode in BatchSet is still not finished. '
+                            'Logging its return anyway.')
+        else:
+            data = self.data[:-1]
+
+        for episode in data:
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
+            for frame in episode:
+                current_return += frame[REWARD] * current_discount
+                current_undiscounted_return += frame[REWARD]
+                current_discount *= discount_factor
+            returns.append(current_return)
+            undiscounted_returns.append(current_undiscounted_return)
+
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+
+        logging.info('Mean return: {}'.format(mean_return))
+        logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
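
To make the new add/sample/statistics behaviour of BatchSet concrete, here is a minimal usage sketch. It assumes the repository is importable, that a frame is the (observation, action, reward, done_flag) tuple implied by the STATE/ACTION/REWARD/DONE constants, and that size and index_lengths are initialised in the part of __init__ not shown in this hunk.

from tianshou.data.data_buffer.batch_set import BatchSet

buf = BatchSet(nstep=1)

# One 3-step episode; each frame is (observation, action, reward, done_flag).
buf.add(([0.0], 0, 1.0, False))  # only 1 frame stored, no valid data point yet
buf.add(([0.1], 1, 1.0, False))  # now frame 0 becomes a valid data point
buf.add(([0.2], 0, 1.0, True))   # done: trailing frames 1 and 2 become valid as well

print(buf.size)       # 3
print(buf.sample(2))  # e.g. [[2, 0], []] -- frame indices grouped per episode
buf.statistics()      # logs mean return 2.9701 (gamma=0.99) and mean undiscounted return 3.0
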
@@ -8,7 +8,7 @@ ACTION = 1
 REWARD = 2
 DONE = 3
 
-# TODO: valid data points could be less than `nstep` timesteps
+# TODO: valid data points could be less than `nstep` timesteps. Check priority replay paper!
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).
@@ -1,10 +1,10 @@
 import numpy as np
 import logging
 import itertools
-import sys
 
 from .data_buffer.replay_buffer_base import ReplayBufferBase
 from .data_buffer.batch_set import BatchSet
+from .utils import internal_key_match
 
 class DataCollector(object):
     """
@@ -32,7 +32,7 @@ class DataCollector(object):
 
         self.current_observation = self.env.reset()
 
-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
+    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
 
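
With the default num_timesteps dropping from 1 to 0, callers of collect() must now pick exactly one collection mode explicitly; the existing assert enforces this. A tiny self-contained illustration of the check (a hypothetical helper mirroring the assert, not the real method):

def check_collect_args(num_timesteps=0, num_episodes=0):
    # Mirrors the assert in DataCollector.collect(): exactly one mode must be positive.
    assert sum([num_timesteps > 0, num_episodes > 0]) == 1, \
        "One and only one collection number specification permitted!"

check_collect_args(num_timesteps=100)  # ok: collect by timesteps
check_collect_args(num_episodes=5)     # ok: collect by episodes
# check_collect_args()                 # would now fail, since neither default is positive
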
@@ -76,26 +76,34 @@ class DataCollector(object):
         feed_dict = {}
         frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
         for key, placeholder in self.required_placeholders.items():
-            if key in frame_key_map.keys():  # access raw_data
-                frame_index = frame_key_map[key]
+            # check raw_data first
+            found, matched_key = internal_key_match(key, frame_key_map.keys())
+            if found:
+                frame_index = frame_key_map[matched_key]
                 flattened = []
                 for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
                     for i in index_episode:
                         flattened.append(data_episode[i][frame_index])
                 feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data_batch.keys():  # access processed minibatch data
-                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
-                feed_dict[placeholder] = np.array(flattened)
-            elif key in self.data.keys():  # access processed full data
-                flattened = [0.] * batch_size  # float
-                i_in_batch = 0
-                for index_episode, data_episode in zip(sampled_index, self.data[key]):
-                    for i in index_episode:
-                        flattened[i_in_batch] = data_episode[i]
-                        i_in_batch += 1
-                feed_dict[placeholder] = np.array(flattened)
             else:
-                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
+                # then check processed minibatch data
+                found, matched_key = internal_key_match(key, self.data_batch.keys())
+                if found:
+                    flattened = list(itertools.chain.from_iterable(self.data_batch[matched_key]))
+                    feed_dict[placeholder] = np.array(flattened)
+                else:
+                    # finally check processed full data
+                    found, matched_key = internal_key_match(key, self.data.keys())
+                    if found:
+                        flattened = [0.] * batch_size  # float
+                        i_in_batch = 0
+                        for index_episode, data_episode in zip(sampled_index, self.data[matched_key]):
+                            for i in index_episode:
+                                flattened[i_in_batch] = data_episode[i]
+                                i_in_batch += 1
+                        feed_dict[placeholder] = np.array(flattened)
+                    else:
+                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
 
         auto_standardize = (standardize_advantage is None) and self.require_advantage
         if standardize_advantage or auto_standardize:
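
The rewritten lookup routes every placeholder key through internal_key_match (newly imported from .utils), first against the raw frame fields, then the processed minibatch data, then the processed full data. The function itself is not part of this diff; judging only from its call sites it returns a (found, matched_key) pair, so a guessed sketch of that contract (not the real tianshou.data.utils implementation) could be:

def internal_key_match(key, candidate_keys):
    # Assumed contract inferred from the call sites above: report whether `key`
    # corresponds to one of the canonical candidate keys, and which one.
    for candidate in candidate_keys:
        # exact match, or the canonical name embedded in a longer scoped/decorated name
        if key == candidate or candidate in key:
            return True, candidate
    return False, None

print(internal_key_match('observation', ['observation', 'action']))  # (True, 'observation')
print(internal_key_match('advantage', ['observation', 'action']))    # (False, None)
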
@@ -108,6 +116,3 @@ class DataCollector(object):
             feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
 
         return feed_dict
-
-    def statistics(self):
-        pass