67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
from .replay_buffer.base import ReplayBufferBase
|
|
|
|
class DataCollector(object):
    """
    A utility class to manage the interaction between the environment, the
    policy and the data buffer.

    The collector drives ``env`` with actions chosen by ``policy`` and stores
    every transition in ``data_buffer`` as a
    ``(observation, action, reward, done)`` tuple.

    :param env: environment object providing ``reset()`` and ``step(action)``;
        ``step`` must return a 4-tuple whose first three items are the next
        observation, the reward and the done flag.
    :param policy: object providing ``act(observation, my_feed_dict=...)``.
    :param data_buffer: object providing ``add(frame)`` and
        ``sample(batch_size)``.
    :param process_functions: stored as-is; not used by the code in this class
        yet.
    :param managed_networks: iterable of objects exposing a
        ``managed_placeholders`` dict.
    """

    def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
        self.env = env
        self.policy = policy
        self.data_buffer = data_buffer
        self.process_functions = process_functions
        self.managed_networks = managed_networks

        # Union of the placeholders declared by every managed network; on a
        # key clash the network that comes later in the iteration wins.
        self.required_placeholders = {}
        for net in self.managed_networks:
            self.required_placeholders.update(net.managed_placeholders)
        self.require_advantage = 'advantage' in self.required_placeholders

        if isinstance(self.data_buffer, ReplayBufferBase):  # process when sampling minibatch
            self.process_mode = 'minibatch'
        else:
            self.process_mode = 'batch'

        # Observation the next collected action will be computed from.
        self.current_observation = self.env.reset()

    def _collect_step(self, exploration, feed_dict):
        """Act once from ``self.current_observation`` and store the transition.

        When the step ends the episode the environment is reset, so that
        ``self.current_observation`` always belongs to a running episode.

        :return: the ``done`` flag reported by ``env.step``.
        """
        observation = self.current_observation
        action = self.policy.act(observation, my_feed_dict=feed_dict)
        if exploration:
            action = exploration(action)

        next_observation, reward, done, _ = self.env.step(action)
        self.data_buffer.add((observation, action, reward, done))

        # Never keep stepping a finished episode: start a fresh one instead.
        self.current_observation = self.env.reset() if done else next_observation
        return done

    def collect(self, num_timesteps=1, num_episodes=0, exploration=None, my_feed_dict=None):
        """Collect transitions from the environment into the data buffer.

        Exactly one of ``num_timesteps`` and ``num_episodes`` must be positive.
        Because ``num_timesteps`` defaults to 1, pass ``num_timesteps=0``
        explicitly when collecting by episodes.

        :param num_timesteps: number of environment steps to collect, resuming
            from the current observation. The environment is reset whenever an
            episode ends in the middle of the collection.
        :param num_episodes: number of complete episodes to collect. The
            environment is reset before the first episode (dropping any
            episode in progress) and after each finished one.
        :param exploration: optional callable applied to the policy's action
            before it is sent to the environment.
        :param my_feed_dict: optional dict forwarded to ``policy.act``; a new
            empty dict is used for each call when omitted.
        :raises AssertionError: if not exactly one of the two counts is
            positive.
        """
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
            "One and only one collection number specification permitted!"

        # A fresh dict per call avoids sharing one mutable default object
        # between unrelated calls.
        feed_dict = {} if my_feed_dict is None else my_feed_dict

        if num_timesteps > 0:
            for _ in range(num_timesteps):
                self._collect_step(exploration, feed_dict)

        if num_episodes > 0:
            self.current_observation = self.env.reset()
            for _ in range(num_episodes):
                while not self._collect_step(exploration, feed_dict):
                    pass

    def next_batch(self, batch_size):
        """Sample from the data buffer (unfinished).

        Only ``data_buffer.sample(batch_size)`` is called at the moment; the
        sampled result is not processed and the method returns ``None``.
        """
        sampled_index = self.data_buffer.sample(batch_size)
        if self.process_mode == 'minibatch':
            pass

        # TODO: flatten rank-2 list to numpy array, construct feed_dict

        return

    def statistics(self):
        """Placeholder; currently does nothing."""
        pass