#!/usr/bin/env python3
#
# Adapted from
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
#
"""Convert Atari RL Unplugged datasets to HDF5 format.

Examples in the dataset represent SARSA transitions stored during a
DQN training run, as described in https://arxiv.org/pdf/1907.04543.

For every training run, all 50 million transitions were recorded, corresponding
to 200 million environment steps (a 4x factor because of frame skipping). There
are 5 separate datasets for each of the 45 games.

Every transition in the dataset is a tuple containing the following features:

* o_t: Observation at time t. Observations have been processed using the
    canonical Atari frame processing, including 4x frame stacking. The shape
    of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action taken at time t+1.
* extras:
  * episode_id: Episode identifier.
  * episode_return: Total episode return computed using per-step [-1, 1]
      reward clipping.
"""
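
# Example invocation (script name assumed; the flags are defined by the
# argument parser at the bottom of this file):
#
#   python3 convert_rl_unplugged_atari.py --task Pong --run-id 1 --shard-id 0
#
# This downloads one gzipped TFRecord shard into --cache-dir and writes a
# D4RL-style HDF5 file under <dataset-dir>/Pong/.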
					
						
import os
from argparse import ArgumentParser, Namespace

import h5py
import numpy as np
import requests
import tensorflow as tf
from tqdm import tqdm

from tianshou.data import Batch

tf.config.set_visible_devices([], "GPU")

# 9 tuning games.
TUNING_SUITE = [
    "BeamRider",
    "DemonAttack",
    "DoubleDunk",
    "IceHockey",
    "MsPacman",
    "Pooyan",
    "RoadRunner",
    "Robotank",
    "Zaxxon",
]

# 36 testing games.
TESTING_SUITE = [
    "Alien",
    "Amidar",
    "Assault",
    "Asterix",
    "Atlantis",
    "BankHeist",
    "BattleZone",
    "Boxing",
    "Breakout",
    "Carnival",
    "Centipede",
    "ChopperCommand",
    "CrazyClimber",
    "Enduro",
    "FishingDerby",
    "Freeway",
    "Frostbite",
    "Gopher",
    "Gravitar",
    "Hero",
    "Jamesbond",
    "Kangaroo",
    "Krull",
    "KungFuMaster",
    "NameThisGame",
    "Phoenix",
    "Pong",
    "Qbert",
    "Riverraid",
    "Seaquest",
    "SpaceInvaders",
    "StarGunner",
    "TimePilot",
    "UpNDown",
    "VideoPinball",
    "WizardOfWor",
    "YarsRevenge",
]

# Total of 45 games.
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"


def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
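
# For example, _filename(1, 0) returns "run_1-00000-of-00100", which matches
# the shard names stored under URL_PREFIX.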
					
						
def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
    """Decode PNGs.

    :param pngs: String Tensor of size (4,) containing PNG encoded images.

    :returns: Tensor of size (4, 84, 84) containing decoded images.
    """
    # Statically unroll the decoding of the 4 stacked PNG frames.
    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
    # NOTE: stack along axis 0 so the result is channel-first, matching
    # tianshou's frame-stacking convention.
    frames = tf.squeeze(tf.stack(frames, axis=0))
    frames.set_shape((4, 84, 84))
    return frames
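
# Shape walk-through: each tf.image.decode_png call yields (84, 84, 1);
# stacking the 4 frames on axis 0 gives (4, 84, 84, 1), and the squeeze
# drops the trailing singleton channel to give (4, 84, 84).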
					
						
def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Create a Tianshou Batch from one offline transition.

    :param o_t: Observation at time t.
    :param a_t: Action at time t.
    :param r_t: Reward after a_t.
    :param d_t: Discount after a_t.
    :param o_tp1: Observation at time t+1.
    :param a_tp1: Action at time t+1 (parsed but not stored in the Batch).

    :returns: A tianshou.data.Batch object.
    """
    return Batch(
        obs=o_t.numpy(),
        act=a_t.numpy(),
        rew=r_t.numpy(),
        # The recorded discount is 1.0 while the episode continues and 0.0 at
        # termination, so 1 - d_t recovers the done flag.
        done=1 - d_t.numpy(),
        obs_next=o_tp1.numpy(),
    )
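
# For example, a transition recorded with discount d_t = 1.0 (episode
# continues) maps to done = 0, while a terminal transition with d_t = 0.0
# maps to done = 1.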
					
						
def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
    """Create a tianshou Batch replay sample from a TF example."""
    # Parse tf.Example.
    feature_description = {
        "o_t": tf.io.FixedLenFeature([4], tf.string),
        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
        "a_t": tf.io.FixedLenFeature([], tf.int64),
        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
        "r_t": tf.io.FixedLenFeature([], tf.float32),
        "d_t": tf.io.FixedLenFeature([], tf.float32),
        "episode_id": tf.io.FixedLenFeature([], tf.int64),
        "episode_return": tf.io.FixedLenFeature([], tf.float32),
    }
    data = tf.io.parse_single_example(tf_example, feature_description)

    # Process data.
    o_t = _decode_frames(data["o_t"])
    o_tp1 = _decode_frames(data["o_tp1"])
    a_t = tf.cast(data["a_t"], tf.int32)
    a_tp1 = tf.cast(data["a_tp1"], tf.int32)

    # Build tianshou Batch replay sample.
    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)
					
						
# Adapted from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
def download(url: str, fname: str, chunk_size: int = 1024) -> None:
    if os.path.exists(fname):
        print(f"Found cached file at {fname}.")
        return
    resp = requests.get(url, stream=True)
    # Fail fast on HTTP errors instead of writing an error page to disk.
    resp.raise_for_status()
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as ofile, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = ofile.write(data)
            bar.update(size)
					
						
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
    download(url, fname)
    obs = np.empty((maxsize, 4, 84, 84), dtype="uint8")
    act = np.empty((maxsize,), dtype="int64")
    rew = np.empty((maxsize,), dtype="float32")
    done = np.empty((maxsize,), dtype="bool")
    obs_next = np.empty((maxsize, 4, 84, 84), dtype="uint8")
    i = 0
    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
    for example in file_ds:
        if i >= maxsize:
            break
        batch = _tf_example_to_tianshou_batch(example)
        obs[i], act[i], rew[i], done[i], obs_next[i] = (
            batch.obs,
            batch.act,
            batch.rew,
            batch.done,
            batch.obs_next,
        )
        i += 1
        if i % 1000 == 0:
            print(f"...{i}", end="", flush=True)
    print("\nDataset size:", i)
    # Following D4RL dataset naming conventions. Slice to the first i rows in
    # case the shard contains fewer than maxsize transitions.
    with h5py.File(ofname, "w") as f:
        f.create_dataset("observations", data=obs[:i], compression="gzip")
        f.create_dataset("actions", data=act[:i], compression="gzip")
        f.create_dataset("rewards", data=rew[:i], compression="gzip")
        f.create_dataset("terminals", data=done[:i], compression="gzip")
        f.create_dataset("next_observations", data=obs_next[:i], compression="gzip")
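

# A minimal sketch of reading a converted shard back into a tianshou Batch.
# The helper below is purely illustrative (its name is ours, and nothing in
# this script calls it); the keys are the D4RL-style names written above.
def load_hdf5_to_batch(dataset_path: str) -> Batch:
    with h5py.File(dataset_path, "r") as f:
        return Batch(
            obs=f["observations"][:],
            act=f["actions"][:],
            rew=f["rewards"][:],
            done=f["terminals"][:],
            obs_next=f["next_observations"][:],
        )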
					
						
def process_dataset(
    task: str,
    download_path: str,
    dst_path: str,
    run_id: int = 1,
    shard_id: int = 0,
    total_num_shards: int = 100,
) -> None:
    fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
    url = f"{URL_PREFIX}/{fn}"
    filepath = f"{download_path}/{fn}"
    ofname = f"{dst_path}/{fn}.hdf5"
    process_shard(url, filepath, ofname)
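
# For example, process_dataset("Pong", cache_dir, dataset_dir) fetches
# {URL_PREFIX}/Pong/run_1-00000-of-00100 into {cache_dir}/Pong/ and writes
# {dataset_dir}/Pong/run_1-00000-of-00100.hdf5.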
					
						
def main(args: Namespace) -> None:
    if args.task not in ALL_GAMES:
        raise KeyError(f"`{args.task}` is not in the list of games.")
    # Apply environment-variable overrides before any paths are derived from
    # the directory arguments.
    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
    if os.path.exists(dataset_path):
        raise OSError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
    cache_path = os.path.join(args.cache_dir, args.task)
    os.makedirs(cache_path, exist_ok=True)
    dst_path = os.path.join(args.dataset_dir, args.task)
    os.makedirs(dst_path, exist_ok=True)
    process_dataset(
        args.task,
        args.cache_dir,
        args.dataset_dir,
        run_id=args.run_id,
        shard_id=args.shard_id,
        total_num_shards=args.total_num_shards,
    )
					
						


if __name__ == "__main__":
    parser = ArgumentParser(usage=__doc__)
    parser.add_argument("--task", required=True, help="Name of the Atari game.")
    parser.add_argument(
        "--run-id",
        type=int,
        default=1,
        help="Run id to download and convert. Value in [1..5].",
    )
    parser.add_argument(
        "--shard-id",
        type=int,
        default=0,
        help="Shard id to download and convert. Value in [0..99].",
    )
    parser.add_argument("--total-num-shards", type=int, default=100, help="Total number of shards.")
    parser.add_argument(
        "--dataset-dir",
        default=os.path.expanduser("~/.rl_unplugged/datasets"),
        help="Directory for converted hdf5 files.",
    )
    parser.add_argument(
        "--cache-dir",
        default=os.path.expanduser("~/.rl_unplugged/cache"),
        help="Directory for downloaded original datasets.",
    )
    args = parser.parse_args()
    main(args)