53 lines
1.7 KiB
Python
Raw Normal View History

from typing import Tuple
import d4rl
import gym
import h5py
import numpy as np
from tianshou.data import ReplayBuffer
from tianshou.utils import RunningMeanStd
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Build a ReplayBuffer from a D4RL dataset.

    Args:
        expert_data_task: gym environment id passed to ``gym.make``; it must
            name an environment that ``d4rl.qlearning_dataset`` can read a
            dataset from (presumably a D4RL-registered task id).

    Returns:
        A ReplayBuffer filled via ``ReplayBuffer.from_data`` with the arrays
        returned by ``d4rl.qlearning_dataset``. "terminals" populates both
        ``done`` and ``terminated``; no transition is marked as truncated.
    """
    env = gym.make(expert_data_task)
    try:
        dataset = d4rl.qlearning_dataset(env)
    finally:
        # The env is only needed to fetch the dataset; release its resources
        # explicitly instead of leaving them to the garbage collector.
        env.close()
    terminals = dataset["terminals"]
    return ReplayBuffer.from_data(
        obs=dataset["observations"],
        act=dataset["actions"],
        rew=dataset["rewards"],
        done=terminals,
        obs_next=dataset["next_observations"],
        terminated=terminals,
        # No truncation information is passed through: all zeros.
        truncated=np.zeros(len(terminals)),
    )
def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Read an HDF5 file of transitions into a ReplayBuffer.

    The file must contain the datasets "observations", "actions", "rewards",
    "terminals" and "next_observations". "terminals" populates both ``done``
    and ``terminated``; every transition is marked as not truncated.

    Args:
        buffer_path: Path of the HDF5 file to open (read-only).

    Returns:
        The ReplayBuffer produced by ``ReplayBuffer.from_data``.
    """
    with h5py.File(buffer_path, "r") as h5_file:
        # Fetch the datasets in the same order the keyword call used to.
        fields = {
            "obs": h5_file["observations"],
            "act": h5_file["actions"],
            "rew": h5_file["rewards"],
            "done": h5_file["terminals"],
            "obs_next": h5_file["next_observations"],
        }
        fields["terminated"] = h5_file["terminals"]
        fields["truncated"] = np.zeros(len(h5_file["terminals"]))
        # Build the buffer while the file is still open so the datasets
        # can be read.
        loaded = ReplayBuffer.from_data(**fields)
    return loaded
def normalize_all_obs_in_replay_buffer(
    replay_buffer: ReplayBuffer
) -> Tuple[ReplayBuffer, RunningMeanStd]:
    """Standardize the stored ``obs`` and ``obs_next`` of a buffer in place.

    Mean and variance are estimated from ``replay_buffer.obs`` only, then
    applied to both ``obs`` and ``obs_next`` as
    ``(x - mean) / sqrt(var + eps)``, with ``eps`` the float32 machine epsilon.

    Args:
        replay_buffer: Buffer whose stored observations are overwritten.

    Returns:
        The same (mutated) buffer, and the RunningMeanStd holding the
        statistics that were used.
    """
    stats = RunningMeanStd()
    stats.update(replay_buffer.obs)
    epsilon = np.finfo(np.float32).eps.item()
    std = np.sqrt(stats.var + epsilon)
    # Write straight into the underlying storage (_meta) so the normalized
    # arrays replace the raw ones.
    for key in ("obs", "obs_next"):
        raw = getattr(replay_buffer, key)
        replay_buffer._meta[key] = (raw - stats.mean) / std
    return replay_buffer, stats