52 lines
1.2 KiB
Python
Raw Normal View History

2017-12-10 14:53:57 +08:00
class ReplayBuffer(object):
    """
    Common interface for replay buffers.

    Every method in this class is a no-op placeholder that returns None.
    NOTE(review): concrete buffers (the docstrings below mention
    "proportional" and "rank based" variants) presumably subclass this and
    override the methods they support -- confirm against the implementations.
    """

    def __init__(self, env, policy, qnet, target_qnet, conf):
        """
        Initialize a replay buffer with the parameters in conf.

        The base implementation stores nothing; env, policy, qnet,
        target_qnet and conf are accepted only to fix the constructor
        signature.
        """
        pass

    def add(self, data, priority):
        """
        Add data to the replay buffer with the given priority.
        """
        pass

    def collect(self):
        """
        Collect data from the current environment and policy.
        """
        pass

    def next_batch(self, batch_size):
        """
        Get a batch of data from the replay buffer.
        """
        pass

    def update_priority(self, indices, priorities):
        """
        Update the priority of the data located at indices.

        For a proportional replay buffer, priorities holds the new priority
        values themselves.
        For a rank based replay buffer, priorities holds the deltas used to
        update the stored priorities.
        """
        pass

    def reset_alpha(self, alpha):
        """
        Reset alpha.

        This function only works for a proportional replay buffer.
        """
        pass

    def sample(self, conf):
        """
        Sample from the replay buffer with the parameters in conf.
        """
        pass

    def rebalance(self):
        """
        Rebalance the sum tree of the priority queue.

        This is for a rank based priority replay buffer.
        """
        pass