Improves typing in examples and tests, working towards mypy passing there. Introduces the SpaceInfo utility.
#!/usr/bin/env python3
#
# Adapted from
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
#
"""Convert Atari RL Unplugged datasets to HDF5 format.

Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.

For every training run we have recorded all 50 million transitions corresponding
to 200 million environment steps (4x factor because of frame skipping). There
are 5 separate datasets for each of the 46 games.

Every transition in the dataset is a tuple containing the following features:

* o_t: Observation at time t. Observations have been processed using the
    canonical Atari frame processing, including 4x frame stacking. The shape
    of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
* extras:
  * episode_id: Episode identifier.
  * episode_return: Total episode return computed using per-step [-1, 1]
      clipping.
"""
import os
from argparse import ArgumentParser, Namespace

import h5py
import numpy as np
import requests
import tensorflow as tf
from tqdm import tqdm

from tianshou.data import Batch

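# Hide GPUs from TensorFlow: the conversion is pure data processing and needs
# no GPU, so we avoid claiming device memory.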
tf.config.set_visible_devices([], "GPU")

# 9 tuning games.
TUNING_SUITE = [
    "BeamRider",
    "DemonAttack",
    "DoubleDunk",
    "IceHockey",
    "MsPacman",
    "Pooyan",
    "RoadRunner",
    "Robotank",
    "Zaxxon",
]

# 37 testing games.
TESTING_SUITE = [
    "Alien",
    "Amidar",
    "Assault",
    "Asterix",
    "Atlantis",
    "BankHeist",
    "BattleZone",
    "Boxing",
    "Breakout",
    "Carnival",
    "Centipede",
    "ChopperCommand",
    "CrazyClimber",
    "Enduro",
    "FishingDerby",
    "Freeway",
    "Frostbite",
    "Gopher",
    "Gravitar",
    "Hero",
    "Jamesbond",
    "Kangaroo",
    "Krull",
    "KungFuMaster",
    "NameThisGame",
    "Phoenix",
    "Pong",
    "Qbert",
    "Riverraid",
    "Seaquest",
    "SpaceInvaders",
    "StarGunner",
    "TimePilot",
    "UpNDown",
    "VideoPinball",
    "WizardOfWor",
    "YarsRevenge",
]

# Total of 46 games.
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"


def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
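    """Return a shard's file name, e.g. 'run_1-00000-of-00100'."""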
    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"


def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
    """Decode PNGs.

    :param pngs: String Tensor of size (4,) containing PNG encoded images.

    :returns: Tensor of size (4, 84, 84) containing decoded images.
    """
    # Statically unroll png decoding.
    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
    # NOTE: stack along the first axis to match tianshou's convention for
    # frame stacking, then squeeze out the singleton channel dimension.
    stacked = tf.squeeze(tf.stack(frames, axis=0))
    stacked.set_shape((4, 84, 84))
    return stacked


def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Create Tianshou batch with offline data.

    :param o_t: Observation at time t.
    :param a_t: Action at time t.
    :param r_t: Reward at time t.
    :param d_t: Discount at time t.
    :param o_tp1: Observation at time t+1.
    :param a_tp1: Action at time t+1.

    :returns: A tianshou.data.Batch object.
    """
    return Batch(
        obs=o_t.numpy(),
        act=a_t.numpy(),
        rew=r_t.numpy(),
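        # The discount d_t is 0.0 exactly when the episode terminates after
        # a_t, so the done flag is its complement.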
        done=1 - d_t.numpy(),
        obs_next=o_tp1.numpy(),
    )


def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
    """Create a tianshou Batch replay sample from a TF example."""
    # Parse tf.Example.
    feature_description = {
        "o_t": tf.io.FixedLenFeature([4], tf.string),
        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
        "a_t": tf.io.FixedLenFeature([], tf.int64),
        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
        "r_t": tf.io.FixedLenFeature([], tf.float32),
        "d_t": tf.io.FixedLenFeature([], tf.float32),
        "episode_id": tf.io.FixedLenFeature([], tf.int64),
        "episode_return": tf.io.FixedLenFeature([], tf.float32),
    }
    data = tf.io.parse_single_example(tf_example, feature_description)

    # Process data.
    o_t = _decode_frames(data["o_t"])
    o_tp1 = _decode_frames(data["o_tp1"])
    a_t = tf.cast(data["a_t"], tf.int32)
    a_tp1 = tf.cast(data["a_tp1"], tf.int32)

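    # NOTE: episode_id and episode_return are parsed above but not carried over
    # into the Batch, which keeps only obs/act/rew/done/obs_next.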
    # Build tianshou Batch replay sample.
    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)


# Adapted from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
def download(url: str, fname: str, chunk_size: int = 1024) -> None:
    if os.path.exists(fname):
        print(f"Found cached file at {fname}.")
        return
    resp = requests.get(url, stream=True)
    # Fail early rather than writing an HTTP error page into the cache file.
    resp.raise_for_status()
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as ofile, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = ofile.write(data)
            bar.update(size)


def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
    download(url, fname)
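    # Preallocate room for one full shard: 50M transitions per run split over
    # 100 shards is 500k transitions each. Note that the two observation arrays
    # alone occupy about 14 GB apiece (500000 * 4 * 84 * 84 bytes).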
    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
    act = np.ndarray((maxsize,), dtype="int64")
    rew = np.ndarray((maxsize,), dtype="float32")
    done = np.ndarray((maxsize,), dtype="bool")
    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
    i = 0
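    # Each shard is a GZIP-compressed TFRecord file of tf.train.Example records.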
    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
    for example in file_ds:
        if i >= maxsize:
            break
        batch = _tf_example_to_tianshou_batch(example)
        obs[i], act[i], rew[i], done[i], obs_next[i] = (
            batch.obs,
            batch.act,
            batch.rew,
            batch.done,
            batch.obs_next,
        )
        i += 1
        if i % 1000 == 0:
            print(f"...{i}", end="", flush=True)
    print("\nDataset size:", i)
    # Following D4RL dataset naming conventions. Slice to the i transitions
    # actually read so that a short shard does not leave uninitialized rows
    # in the output file.
    with h5py.File(ofname, "w") as f:
        f.create_dataset("observations", data=obs[:i], compression="gzip")
        f.create_dataset("actions", data=act[:i], compression="gzip")
        f.create_dataset("rewards", data=rew[:i], compression="gzip")
        f.create_dataset("terminals", data=done[:i], compression="gzip")
        f.create_dataset("next_observations", data=obs_next[:i], compression="gzip")


def process_dataset(
    task: str,
    download_path: str,
    dst_path: str,
    run_id: int = 1,
    shard_id: int = 0,
    total_num_shards: int = 100,
) -> None:
    fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
    url = f"{URL_PREFIX}/{fn}"
    filepath = f"{download_path}/{fn}"
    ofname = f"{dst_path}/{fn}.hdf5"
    process_shard(url, filepath, ofname)


def main(args: Namespace) -> None:
    if args.task not in ALL_GAMES:
        raise KeyError(f"`{args.task}` is not in the list of games.")
    # Apply the environment-variable overrides before deriving any paths from
    # the directory arguments.
    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
    if os.path.exists(dataset_path):
        raise OSError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
    cache_path = os.path.join(args.cache_dir, args.task)
    os.makedirs(cache_path, exist_ok=True)
    dst_path = os.path.join(args.dataset_dir, args.task)
    os.makedirs(dst_path, exist_ok=True)
    process_dataset(
        args.task,
        args.cache_dir,
        args.dataset_dir,
        run_id=args.run_id,
        shard_id=args.shard_id,
        total_num_shards=args.total_num_shards,
    )


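# Example invocation (the script filename below is illustrative):
#     python convert_rl_unplugged_atari.py --task Pong --run-id 1 --shard-id 0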
if __name__ == "__main__":
    parser = ArgumentParser(usage=__doc__)
    parser.add_argument("--task", required=True, help="Name of the Atari game.")
    parser.add_argument(
        "--run-id",
        type=int,
        default=1,
        help="Run id to download and convert. Value in [1..5].",
    )
    parser.add_argument(
        "--shard-id",
        type=int,
        default=0,
        help="Shard id to download and convert. Value in [0..99].",
    )
    parser.add_argument("--total-num-shards", type=int, default=100, help="Total number of shards.")
    parser.add_argument(
        "--dataset-dir",
        default=os.path.expanduser("~/.rl_unplugged/datasets"),
        help="Directory for converted hdf5 files.",
    )
    parser.add_argument(
        "--cache-dir",
        default=os.path.expanduser("~/.rl_unplugged/cache"),
        help="Directory for downloaded original datasets.",
    )
    args = parser.parse_args()
    main(args)