Changes: - Disclaimer in README - Replaced all occurrences of Gym with Gymnasium - Removed code that is now dead since we no longer need to support the old step API - Updated type hints to only allow the new step API - Increased the required version of envpool to support Gymnasium - Increased the required version of PettingZoo to support Gymnasium - Updated `PettingZooEnv` to only use the new step API and removed the hack that also supported the old API - Added some `# type: ignore` comments because of the new type hinting in Gymnasium. I'm not that familiar with type hinting, but I believe the issue is on the Gymnasium side and we are looking into it. - Updated `MyTestEnv` to support the `options` kwarg - Skipped the NNI tests because they still use OpenAI Gym - Also allowed `PettingZooEnv` in vector environments - Updated the doc page about ReplayBuffer to also cover the terminated and truncated flags. Still to do: - Update the Jupyter notebooks in docs - Check the entire codebase for more dead code (left over from compatibility support) - Check the reset functions of all environments/wrappers in the codebase to make sure they use the `options` kwarg - Someone might want to check test_env_finite.py - Is it okay to allow `PettingZooEnv` in vector environments? If so, the docs may need updating.
53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
from typing import Tuple
|
|
|
|
import d4rl
|
|
import gymnasium as gym
|
|
import h5py
|
|
import numpy as np
|
|
|
|
from tianshou.data import ReplayBuffer
|
|
from tianshou.utils import RunningMeanStd
|
|
|
|
|
|
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Build a :class:`ReplayBuffer` from a D4RL task's Q-learning dataset.

    :param expert_data_task: environment id passed to ``gym.make``; the
        resulting env is handed to ``d4rl.qlearning_dataset``.
    :return: a buffer holding every transition of the dataset. The dataset's
        ``terminals`` array is used for both ``terminated`` and ``done``;
        ``truncated`` is all ``False``.
    """
    # NOTE(review): D4RL registers its tasks with the legacy ``gym`` registry;
    # confirm that ``gymnasium.make`` can resolve this id in the target setup.
    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
    terminals = dataset["terminals"]
    return ReplayBuffer.from_data(
        obs=dataset["observations"],
        act=dataset["actions"],
        rew=dataset["rewards"],
        # Nothing is marked truncated below, so done == terminated.
        done=terminals,
        obs_next=dataset["next_observations"],
        terminated=terminals,
        # The dataset provides no truncation signal. Truncation is a boolean
        # flag in the Gymnasium step API, and np.zeros defaults to float64.
        truncated=np.zeros(len(terminals), dtype=bool),
    )
|
|
|
|
|
|
def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Build a :class:`ReplayBuffer` from an HDF5 file of transitions.

    The file must contain the datasets ``observations``, ``actions``,
    ``rewards``, ``terminals`` and ``next_observations``.

    :param buffer_path: path to the HDF5 file, opened read-only.
    :return: a buffer holding every transition in the file. ``terminals`` is
        used for both ``terminated`` and ``done``; ``truncated`` is all
        ``False``.
    """
    with h5py.File(buffer_path, "r") as dataset:
        terminals = dataset["terminals"]
        # from_data must run while the file is still open, because the values
        # passed here are h5py dataset handles rather than in-memory arrays.
        buffer = ReplayBuffer.from_data(
            obs=dataset["observations"],
            act=dataset["actions"],
            rew=dataset["rewards"],
            # Nothing is marked truncated below, so done == terminated.
            done=terminals,
            obs_next=dataset["next_observations"],
            terminated=terminals,
            # The file provides no truncation signal. Truncation is a boolean
            # flag in the Gymnasium step API, and np.zeros defaults to float64.
            truncated=np.zeros(len(terminals), dtype=bool),
        )
    return buffer
|
|
|
|
|
|
def normalize_all_obs_in_replay_buffer(
    replay_buffer: ReplayBuffer
) -> Tuple[ReplayBuffer, RunningMeanStd]:
    """Standardize ``obs`` and ``obs_next`` in place using statistics of ``obs``.

    The mean and variance are estimated from the buffer's stored ``obs`` with
    :class:`RunningMeanStd`. Both ``obs`` and ``obs_next`` are then shifted by
    that mean and divided by ``sqrt(var + eps)``, where ``eps`` is float32
    machine epsilon.

    :param replay_buffer: buffer whose ``_meta`` observation arrays are
        overwritten in place.
    :return: the same buffer object, plus the fitted ``RunningMeanStd`` so
        callers can apply the identical transform to other observations.
    """
    stats = RunningMeanStd()
    # NOTE(review): this reads the whole storage array; presumably the buffer
    # is full (as when built via from_data), otherwise unused slots would bias
    # the statistics -- confirm for other callers.
    stats.update(replay_buffer.obs)

    # Adding float32 epsilon keeps the denominator strictly positive for
    # zero-variance features.
    eps = np.finfo(np.float32).eps.item()
    scale = np.sqrt(stats.var + eps)

    meta = replay_buffer._meta
    meta["obs"] = (replay_buffer.obs - stats.mean) / scale
    meta["obs_next"] = (replay_buffer.obs_next - stats.mean) / scale
    return replay_buffer, stats
|