30 lines
833 B
Python
30 lines
833 B
Python
import d4rl
|
|
import gym
|
|
import h5py
|
|
|
|
from tianshou.data import ReplayBuffer
|
|
|
|
|
|
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Build a replay buffer from the D4RL dataset of a gym task.

    The environment is created with ``gym.make(expert_data_task)`` and its
    transitions are fetched through ``d4rl.qlearning_dataset``.

    Args:
        expert_data_task: Environment id passed verbatim to ``gym.make``.

    Returns:
        A ``ReplayBuffer`` created by ``ReplayBuffer.from_data`` from the
        dataset's observations, actions, rewards, terminals and next
        observations.
    """
    env = gym.make(expert_data_task)
    transitions = d4rl.qlearning_dataset(env)

    # ReplayBuffer.from_data keyword -> key in the D4RL dataset.
    field_to_key = {
        "obs": "observations",
        "act": "actions",
        "rew": "rewards",
        "done": "terminals",
        "obs_next": "next_observations",
    }
    fields = {name: transitions[key] for name, key in field_to_key.items()}
    return ReplayBuffer.from_data(**fields)
|
|
|
|
|
|
def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Build a replay buffer from an HDF5 file on disk.

    The file is opened read-only and must contain the datasets
    ``observations``, ``actions``, ``rewards``, ``terminals`` and
    ``next_observations``.

    Args:
        buffer_path: Path of the HDF5 file to read.

    Returns:
        A ``ReplayBuffer`` created by ``ReplayBuffer.from_data`` from the
        five datasets listed above.
    """
    # ReplayBuffer.from_data keyword -> dataset name inside the HDF5 file.
    field_to_key = {
        "obs": "observations",
        "act": "actions",
        "rew": "rewards",
        "done": "terminals",
        "obs_next": "next_observations",
    }
    with h5py.File(buffer_path, "r") as h5_file:
        fields = {name: h5_file[key] for name, key in field_to_key.items()}
        # Built while the file is still open; the handle is closed on exit.
        # NOTE(review): presumably from_data copies the data out of the
        # h5py datasets, since the buffer outlives the file -- confirm.
        replay_buffer = ReplayBuffer.from_data(**fields)
    return replay_buffer
|