21 lines
722 B
Python
Raw Normal View History

2017-12-10 14:53:57 +08:00
import sys
from tianshou.data.replay_buffer.naive import NaiveExperience
from tianshou.data.replay_buffer.proportional import PropotionalExperience
from tianshou.data.replay_buffer.rank_based import RankBasedExperience
def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
    """
    Get replay buffer according to the given name.

    The remaining arguments are forwarded positionally, unchanged, to the
    constructor of the selected replay-buffer class.

    :param name: which replay buffer to build; one of ``'rank_based'``,
        ``'proportional'`` or ``'naive'``.
    :param env: forwarded to the replay-buffer constructor.
    :param policy: forwarded to the replay-buffer constructor.
    :param qnet: forwarded to the replay-buffer constructor.
    :param target_qnet: forwarded to the replay-buffer constructor.
    :param conf: forwarded to the replay-buffer constructor.
    :return: a new ``RankBasedExperience``, ``PropotionalExperience`` or
        ``NaiveExperience`` instance.
    :raises ValueError: if ``name`` is not one of the supported names.
    """
    if name == 'rank_based':
        return RankBasedExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'proportional':
        return PropotionalExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'naive':
        return NaiveExperience(env, policy, qnet, target_qnet, conf)
    # Previously this only printed to stderr and silently returned None,
    # which made callers fail much later with an unrelated AttributeError.
    # Fail fast and say which name was wrong.
    raise ValueError(
        "no such replay buffer: {!r}; expected one of "
        "'rank_based', 'proportional', 'naive'".format(name)
    )