Preparation for #914 and #920.

Changes formatting to ruff and black. Removes Python 3.8 support.

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. CI and Make now use pre-commit, reducing the duplication of linting calls
- Removed the check-docstyle option (ruff handles that)
- Merged format and lint. In CI the format-lint step fails if any changes are made, so it fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
		
			
				
	
	
		
			49 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			49 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import d4rl
 | |
| import gymnasium as gym
 | |
| import h5py
 | |
| import numpy as np
 | |
| 
 | |
| from tianshou.data import ReplayBuffer
 | |
| from tianshou.utils import RunningMeanStd
 | |
| 
 | |
| 
 | |
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Build a replay buffer from the D4RL Q-learning dataset of a task.

    The environment is created with ``gym.make`` and handed to
    ``d4rl.qlearning_dataset``; the resulting arrays are then wrapped in a
    :class:`ReplayBuffer`.

    :param expert_data_task: environment id passed to ``gym.make``.
    :return: a buffer populated via ``ReplayBuffer.from_data``.
    """
    env = gym.make(expert_data_task)
    transitions = d4rl.qlearning_dataset(env)

    # The "terminals" array feeds both ``done`` and ``terminated``;
    # ``truncated`` is filled with zeros of the same length.
    terminals = transitions["terminals"]
    no_truncation = np.zeros(len(terminals))

    return ReplayBuffer.from_data(
        obs=transitions["observations"],
        act=transitions["actions"],
        rew=transitions["rewards"],
        done=terminals,
        obs_next=transitions["next_observations"],
        terminated=terminals,
        truncated=no_truncation,
    )
 | |
| 
 | |
| 
 | |
def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Build a replay buffer from an HDF5 file.

    The file is opened read-only and must provide the datasets
    ``observations``, ``actions``, ``rewards``, ``terminals`` and
    ``next_observations``.

    :param buffer_path: path of the HDF5 file to read.
    :return: a buffer populated via ``ReplayBuffer.from_data``.
    """
    with h5py.File(buffer_path, "r") as h5_file:
        # Look the datasets up in the same order as the keyword arguments.
        # "terminals" is fetched separately for ``done`` and ``terminated``.
        fields = {
            "obs": h5_file["observations"],
            "act": h5_file["actions"],
            "rew": h5_file["rewards"],
            "done": h5_file["terminals"],
            "obs_next": h5_file["next_observations"],
            "terminated": h5_file["terminals"],
        }
        # ``truncated`` is all zeros, one entry per stored terminal flag.
        fields["truncated"] = np.zeros(len(fields["terminated"]))
        # Build the buffer while the file is still open.
        return ReplayBuffer.from_data(**fields)
 | |
| 
 | |
| 
 | |
def normalize_all_obs_in_replay_buffer(
    replay_buffer: ReplayBuffer,
) -> tuple[ReplayBuffer, RunningMeanStd]:
    """Standardize ``obs`` and ``obs_next`` of a replay buffer in place.

    Mean and variance are estimated from ``replay_buffer.obs`` only and the
    same statistics are applied to both ``obs`` and ``obs_next``. The
    normalized arrays are written back into ``replay_buffer._meta``.

    :param replay_buffer: the buffer whose observations are normalized;
        it is modified in place.
    :return: the same buffer object together with the ``RunningMeanStd``
        holding the statistics that were used.
    """
    # Gather the observation statistics.
    obs_rms = RunningMeanStd()
    obs_rms.update(replay_buffer.obs)

    # Float32 machine epsilon keeps the denominator non-zero when the
    # variance of a dimension is zero.
    _eps = np.finfo(np.float32).eps.item()
    scale = np.sqrt(obs_rms.var + _eps)
    center = obs_rms.mean

    # Overwrite the stored observations with their normalized versions.
    replay_buffer._meta["obs"] = (replay_buffer.obs - center) / scale
    replay_buffer._meta["obs_next"] = (replay_buffer.obs_next - center) / scale

    return replay_buffer, obs_rms
 |