Michael Panchenko 600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Removes Python 3.8 support.

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI, the format-lint step fails if any
changes are made, so it also fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00

49 lines
1.6 KiB
Python

import d4rl
import gymnasium as gym
import h5py
import numpy as np
from tianshou.data import ReplayBuffer
from tianshou.utils import RunningMeanStd
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Build a replay buffer from the D4RL Q-learning dataset of a task.

    The environment is created with ``gym.make`` and handed to
    ``d4rl.qlearning_dataset``; the resulting arrays are forwarded to
    ``ReplayBuffer.from_data``.

    :param expert_data_task: environment id passed to ``gym.make``.
    :return: the buffer returned by ``ReplayBuffer.from_data``.
    """
    env = gym.make(expert_data_task)
    transitions = d4rl.qlearning_dataset(env)

    # The dataset's terminal flags feed both ``done`` and ``terminated``.
    terminals = transitions["terminals"]
    # ``truncated`` is filled with zeros, one entry per transition.
    never_truncated = np.zeros(len(terminals))

    return ReplayBuffer.from_data(
        obs=transitions["observations"],
        act=transitions["actions"],
        rew=transitions["rewards"],
        obs_next=transitions["next_observations"],
        done=terminals,
        terminated=terminals,
        truncated=never_truncated,
    )
def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Load a replay buffer from an HDF5 file.

    The file is opened read-only and is indexed with the keys
    ``observations``, ``actions``, ``rewards``, ``terminals`` and
    ``next_observations``.

    :param buffer_path: path of the HDF5 file to read.
    :return: the buffer returned by ``ReplayBuffer.from_data``.
    """
    # Maps each ``from_data`` argument to the dataset key it is read from.
    field_to_key = {
        "obs": "observations",
        "act": "actions",
        "rew": "rewards",
        "done": "terminals",
        "obs_next": "next_observations",
        "terminated": "terminals",
    }
    with h5py.File(buffer_path, "r") as h5_file:
        fields = {field: h5_file[key] for field, key in field_to_key.items()}
        # ``truncated`` is filled with zeros, one entry per stored terminal flag.
        fields["truncated"] = np.zeros(len(h5_file["terminals"]))
        # The buffer is created inside the ``with`` block, i.e. before the file is closed.
        return ReplayBuffer.from_data(**fields)
def normalize_all_obs_in_replay_buffer(
    replay_buffer: ReplayBuffer,
) -> tuple[ReplayBuffer, RunningMeanStd]:
    """Standardize ``obs`` and ``obs_next`` of a replay buffer in place.

    Mean and variance are computed from ``replay_buffer.obs`` only; the same
    statistics are then applied to both ``obs`` and ``obs_next``.

    :param replay_buffer: buffer whose stored observations are overwritten.
    :return: the same buffer object together with the ``RunningMeanStd``
        instance holding the statistics that were used.
    """
    stats = RunningMeanStd()
    stats.update(replay_buffer.obs)

    # A float32 machine epsilon is added to the variance before taking the
    # square root, so the divisor stays non-zero even for zero variance.
    eps = np.finfo(np.float32).eps.item()
    scale = np.sqrt(stats.var + eps)

    for field in ("obs", "obs_next"):
        replay_buffer._meta[field] = (getattr(replay_buffer, field) - stats.mean) / scale

    return replay_buffer, stats