2017-12-10 14:53:57 +08:00
|
|
|
import sys
|
|
|
|
|
2017-12-17 12:52:00 +08:00
|
|
|
from tianshou.data.replay_buffer.naive import NaiveExperience
|
|
|
|
from tianshou.data.replay_buffer.proportional import PropotionalExperience
|
|
|
|
from tianshou.data.replay_buffer.rank_based import RankBasedExperience
|
|
|
|
|
|
|
|
|
|
|
|
def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
    """Build the replay buffer implementation selected by ``name``.

    The remaining arguments are passed, unchanged and in the same order,
    to the constructor of the selected class.

    Args:
        name: Selector string. Recognized values are ``'rank_based'``,
            ``'proportional'`` and ``'naive'``.
        env: Forwarded to the buffer constructor.
        policy: Forwarded to the buffer constructor.
        qnet: Forwarded to the buffer constructor.
        target_qnet: Forwarded to the buffer constructor.
        conf: Forwarded to the buffer constructor.

    Returns:
        A ``RankBasedExperience``, ``PropotionalExperience`` or
        ``NaiveExperience`` instance, or ``None`` when ``name`` is not one
        of the recognized values (a diagnostic is written to stderr in
        that case).
    """
    if name == 'rank_based':
        return RankBasedExperience(env, policy, qnet, target_qnet, conf)
    if name == 'proportional':
        return PropotionalExperience(env, policy, qnet, target_qnet, conf)
    if name == 'naive':
        return NaiveExperience(env, policy, qnet, target_qnet, conf)

    # Unknown selector: report it (including the offending value, and with
    # a terminating newline so the message does not run into later output)
    # and keep the historical contract of returning None instead of raising.
    sys.stderr.write('no such replay buffer: %r\n' % (name,))
    return None
|