Preparation for #914 and #920. Changes formatting to ruff and black. Removes Python 3.8.

## Additional Changes

- Removed flake8 dependencies.
- Adjusted pre-commit. CI and Make now use pre-commit, reducing the duplication of linting calls.
- Removed the check-docstyle option (ruff is doing that).
- Merged format and lint. In CI the format-lint step fails if any changes are made, so it fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
		
			
				
	
	
		
			292 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			292 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/env python3
 | 
						|
#
 | 
						|
# Adapted from
 | 
						|
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 | 
						|
#
 | 
						|
"""Convert Atari RL Unplugged datasets to HDF5 format.
 | 
						|
 | 
						|
Examples in the dataset represent SARSA transitions stored during a
 | 
						|
DQN training run as described in https://arxiv.org/pdf/1907.04543.
 | 
						|
 | 
						|
For every training run we have recorded all 50 million transitions corresponding
 | 
						|
to 200 million environment steps (4x factor because of frame skipping). There
 | 
						|
are 5 separate datasets for each of the 45 games.
 | 
						|
 | 
						|
Every transition in the dataset is a tuple containing the following features:
 | 
						|
 | 
						|
* o_t: Observation at time t. Observations have been processed using the
 | 
						|
    canonical Atari frame processing, including 4x frame stacking. The shape
 | 
						|
    of a single observation is [84, 84, 4].
 | 
						|
* a_t: Action taken at time t.
 | 
						|
* r_t: Reward after a_t.
 | 
						|
* d_t: Discount after a_t.
 | 
						|
* o_tp1: Observation at time t+1.
 | 
						|
* a_tp1: Action at time t+1.
 | 
						|
* extras:
 | 
						|
  * episode_id: Episode identifier.
 | 
						|
  * episode_return: Total episode return computed using per-step [-1, 1]
 | 
						|
      clipping.
 | 
						|
"""
 | 
						|
import os
 | 
						|
from argparse import ArgumentParser
 | 
						|
 | 
						|
import h5py
 | 
						|
import numpy as np
 | 
						|
import requests
 | 
						|
import tensorflow as tf
 | 
						|
from tqdm import tqdm
 | 
						|
 | 
						|
from tianshou.data import Batch
 | 
						|
 | 
						|
# Give TensorFlow an empty list of visible GPU devices, so it uses none.
# NOTE(review): presumably because TF is only used here to read and decode
# records, so grabbing GPU memory would be wasteful -- confirm.
tf.config.set_visible_devices([], "GPU")
 | 
						|
 | 
						|
# 9 tuning games.
 | 
						|
TUNING_SUITE = [
 | 
						|
    "BeamRider",
 | 
						|
    "DemonAttack",
 | 
						|
    "DoubleDunk",
 | 
						|
    "IceHockey",
 | 
						|
    "MsPacman",
 | 
						|
    "Pooyan",
 | 
						|
    "RoadRunner",
 | 
						|
    "Robotank",
 | 
						|
    "Zaxxon",
 | 
						|
]
 | 
						|
 | 
						|
# 36 testing games.
 | 
						|
TESTING_SUITE = [
 | 
						|
    "Alien",
 | 
						|
    "Amidar",
 | 
						|
    "Assault",
 | 
						|
    "Asterix",
 | 
						|
    "Atlantis",
 | 
						|
    "BankHeist",
 | 
						|
    "BattleZone",
 | 
						|
    "Boxing",
 | 
						|
    "Breakout",
 | 
						|
    "Carnival",
 | 
						|
    "Centipede",
 | 
						|
    "ChopperCommand",
 | 
						|
    "CrazyClimber",
 | 
						|
    "Enduro",
 | 
						|
    "FishingDerby",
 | 
						|
    "Freeway",
 | 
						|
    "Frostbite",
 | 
						|
    "Gopher",
 | 
						|
    "Gravitar",
 | 
						|
    "Hero",
 | 
						|
    "Jamesbond",
 | 
						|
    "Kangaroo",
 | 
						|
    "Krull",
 | 
						|
    "KungFuMaster",
 | 
						|
    "NameThisGame",
 | 
						|
    "Phoenix",
 | 
						|
    "Pong",
 | 
						|
    "Qbert",
 | 
						|
    "Riverraid",
 | 
						|
    "Seaquest",
 | 
						|
    "SpaceInvaders",
 | 
						|
    "StarGunner",
 | 
						|
    "TimePilot",
 | 
						|
    "UpNDown",
 | 
						|
    "VideoPinball",
 | 
						|
    "WizardOfWor",
 | 
						|
    "YarsRevenge",
 | 
						|
]
 | 
						|
 | 
						|
# Total of 45 games.
 | 
						|
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
 | 
						|
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"
 | 
						|
 | 
						|
 | 
						|
def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
 | 
						|
    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
 | 
						|
 | 
						|
 | 
						|
def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
    """Decode a stack of four PNG-encoded frames.

    :param pngs: String Tensor of size (4,) containing PNG encoded images.

    :returns: Tensor of size (4, 84, 84) containing decoded images.
    """
    # Decode the four frames one by one (a static unroll of the decoding).
    decoded = []
    for idx in range(4):
        decoded.append(tf.image.decode_png(pngs[idx], channels=1))
    # NOTE: frames are stacked along the leading axis to match tianshou's
    # convention for framestacking; squeeze then removes size-1 dimensions
    # (each frame was decoded with a single channel).
    stacked = tf.stack(decoded, axis=0)
    frames = tf.squeeze(stacked)
    frames.set_shape((4, 84, 84))
    return frames
 | 
						|
 | 
						|
 | 
						|
def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Pack one offline transition into a Tianshou batch.

    :param o_t: Observation at time t.
    :param a_t: Action at time t.
    :param r_t: Reward at time t.
    :param d_t: Discount at time t.
    :param o_tp1: Observation at time t+1.
    :param a_tp1: Action at time t+1 (accepted but not stored in the batch).

    :returns: A tianshou.data.Batch object.
    """
    fields = {
        "obs": o_t.numpy(),
        "act": a_t.numpy(),
        "rew": r_t.numpy(),
        # The done flag is stored as the complement of the discount.
        "done": 1 - d_t.numpy(),
        "obs_next": o_tp1.numpy(),
    }
    return Batch(**fields)
 | 
						|
 | 
						|
 | 
						|
def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
    """Parse one serialized TF example and turn it into a tianshou Batch."""
    # Feature schema as (key, shape, dtype); both observations are four PNG strings.
    spec = (
        ("o_t", [4], tf.string),
        ("o_tp1", [4], tf.string),
        ("a_t", [], tf.int64),
        ("a_tp1", [], tf.int64),
        ("r_t", [], tf.float32),
        ("d_t", [], tf.float32),
        ("episode_id", [], tf.int64),
        ("episode_return", [], tf.float32),
    )
    schema = {}
    for key, shape, dtype in spec:
        schema[key] = tf.io.FixedLenFeature(shape, dtype)
    parsed = tf.io.parse_single_example(tf_example, schema)

    # Decode the image stacks and narrow the actions to int32.
    obs = _decode_frames(parsed["o_t"])
    obs_next = _decode_frames(parsed["o_tp1"])
    action = tf.cast(parsed["a_t"], tf.int32)
    action_next = tf.cast(parsed["a_tp1"], tf.int32)

    # Assemble the tianshou Batch replay sample.
    return _make_tianshou_batch(
        obs,
        action,
        parsed["r_t"],
        parsed["d_t"],
        obs_next,
        action_next,
    )
 | 
						|
 | 
						|
 | 
						|
# Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
 | 
						|
def download(url: str, fname: str, chunk_size: int = 1024) -> None:
    """Download ``url`` to ``fname`` with a progress bar, unless already cached.

    :param url: Address of the file to fetch.
    :param fname: Destination path. If it already exists, nothing is fetched.
    :param chunk_size: Number of bytes requested per streamed chunk.

    :raises requests.HTTPError: If the server answers with an error status.
    """
    # Look at the cache before touching the network: a cached shard must not
    # cost an HTTP request (and must stay usable offline).
    if os.path.exists(fname):
        print(f"Found cached file at {fname}.")
        return
    # Stream into a temporary sibling file and rename it only once complete, so
    # an interrupted download is never mistaken for a cached file later on.
    tmp_fname = f"{fname}.part"
    with requests.get(url, stream=True) as resp:
        # Fail loudly instead of saving an error page as if it were the shard.
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        with open(tmp_fname, "wb") as ofile, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = ofile.write(data)
                bar.update(size)
    os.replace(tmp_fname, fname)
 | 
						|
 | 
						|
 | 
						|
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
    """Download one shard and convert it to an HDF5 file.

    :param url: Address of the shard to download.
    :param fname: Local path of the downloaded shard; it is read as a
        GZIP-compressed TFRecord file.
    :param ofname: Path of the HDF5 file to write.
    :param maxsize: Upper bound on the number of transitions converted.
    """
    download(url, fname)
    # Buffers are sized for the worst case and left uninitialised; only the
    # first `count` rows are filled below.
    obs = np.empty((maxsize, 4, 84, 84), dtype="uint8")
    act = np.empty((maxsize,), dtype="int64")
    rew = np.empty((maxsize,), dtype="float32")
    done = np.empty((maxsize,), dtype="bool")
    obs_next = np.empty((maxsize, 4, 84, 84), dtype="uint8")
    count = 0
    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
    for example in file_ds:
        if count >= maxsize:
            break
        batch = _tf_example_to_tianshou_batch(example)
        obs[count] = batch.obs
        act[count] = batch.act
        rew[count] = batch.rew
        done[count] = batch.done
        obs_next[count] = batch.obs_next
        count += 1
        if count % 1000 == 0:
            print(f"...{count}", end="", flush=True)
    print("\nDataset size:", count)
    # Following D4RL dataset naming conventions.
    # Only the filled prefix is written: rows past `count` were never assigned
    # and would otherwise end up in the file as garbage transitions.
    with h5py.File(ofname, "w") as f:
        f.create_dataset("observations", data=obs[:count], compression="gzip")
        f.create_dataset("actions", data=act[:count], compression="gzip")
        f.create_dataset("rewards", data=rew[:count], compression="gzip")
        f.create_dataset("terminals", data=done[:count], compression="gzip")
        f.create_dataset("next_observations", data=obs_next[:count], compression="gzip")
 | 
						|
 | 
						|
 | 
						|
def process_dataset(
    task: str,
    download_path: str,
    dst_path: str,
    run_id: int = 1,
    shard_id: int = 0,
    total_num_shards: int = 100,
) -> None:
    """Derive the URL and local paths of one shard of ``task`` and convert it.

    :param task: Name of the game; used as the sub-directory for every path.
    :param download_path: Root directory for the downloaded shard.
    :param dst_path: Root directory for the converted ``.hdf5`` file.
    :param run_id: Run identifier used in the shard file name.
    :param shard_id: Shard index used in the shard file name.
    :param total_num_shards: Shard count used in the shard file name.
    """
    shard_name = _filename(run_id, shard_id, total_num_shards=total_num_shards)
    # The same "<task>/<shard>" suffix is appended to the URL and both roots.
    relative = f"{task}/{shard_name}"
    process_shard(
        url=f"{URL_PREFIX}/{relative}",
        fname=f"{download_path}/{relative}",
        ofname=f"{dst_path}/{relative}.hdf5",
    )
 | 
						|
 | 
						|
 | 
						|
def main(args):
    """Validate the request, prepare the directories and convert one shard.

    :param args: Parsed command-line namespace providing ``task``, ``run_id``,
        ``shard_id``, ``total_num_shards``, ``dataset_dir`` and ``cache_dir``.
        ``cache_dir`` and ``dataset_dir`` are overwritten in place when the
        ``RLU_CACHE_DIR`` / ``RLU_DATASET_DIR`` environment variables are set.

    :raises KeyError: If ``args.task`` is not one of the known games.
    :raises OSError: If the converted dataset already exists.
    """
    if args.task not in ALL_GAMES:
        raise KeyError(f"`{args.task}` is not in the list of games.")
    # Apply the environment overrides first, so the overwrite check below looks
    # at the directory the dataset will actually be written to.
    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
    dst_path = os.path.join(args.dataset_dir, args.task)
    dataset_path = os.path.join(dst_path, f"{fn}.hdf5")
    if os.path.exists(dataset_path):
        raise OSError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
    cache_path = os.path.join(args.cache_dir, args.task)
    os.makedirs(cache_path, exist_ok=True)
    os.makedirs(dst_path, exist_ok=True)
    process_dataset(
        args.task,
        args.cache_dir,
        args.dataset_dir,
        run_id=args.run_id,
        shard_id=args.shard_id,
        total_num_shards=args.total_num_shards,
    )
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Command-line entry point; the module docstring is passed as usage text.
    parser = ArgumentParser(usage=__doc__)
    parser.add_argument("--task", required=True, help="Name of the Atari game.")

    # Integer options as (flag, default, help).
    int_options = (
        ("--run-id", 1, "Run id to download and convert. Value in [1..5]."),
        ("--shard-id", 0, "Shard id to download and convert. Value in [0..99]."),
        ("--total-num-shards", 100, "Total number of shards."),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    # Directory options as (flag, default, help); "~" in the default is
    # expanded to the user's home directory.
    dir_options = (
        ("--dataset-dir", "~/.rl_unplugged/datasets", "Directory for converted hdf5 files."),
        ("--cache-dir", "~/.rl_unplugged/cache", "Directory for downloaded original datasets."),
    )
    for flag, default, help_text in dir_options:
        parser.add_argument(flag, default=os.path.expanduser(default), help=help_text)

    args = parser.parse_args()
    main(args)
 |