Tianshou/examples/offline/convert_rl_unplugged_atari.py
Daniel Plop 8a0629ded6
Fix mypy issues in tests and examples (#1077)
Closes #952 

- `SamplingConfig` supports `batch_size=None`. #1077
- tests and examples are covered by `mypy`. #1077
- `NetBase` is used more widely, and typing is stricter because it is now generic. #1077
- `utils.net.common.Recurrent` now receives and returns a
`RecurrentStateBatch` instead of a dict. #1077

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-03 18:07:51 +02:00

293 lines
8.8 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Adapted from
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
#
"""Convert Atari RL Unplugged datasets to HDF5 format.
Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.
For every training run we have recorded all 50 million transitions corresponding
to 200 million environment steps (4x factor because of frame skipping). There
are 5 separate datasets for each of the 45 games.
Every transition in the dataset is a tuple containing the following features:
* o_t: Observation at time t. Observations have been processed using the
canonical Atari frame processing, including 4x frame stacking. The shape
of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
* extras:
* episode_id: Episode identifier.
* episode_return: Total episode return computed using per-step [-1, 1]
clipping.
"""
import os
from argparse import ArgumentParser, Namespace
import h5py
import numpy as np
import numpy.typing as npt
import requests
import tensorflow as tf
from tqdm import tqdm
from tianshou.data import Batch
# Hide every GPU from TensorFlow: in this script TF is only used to read and
# decode the TFRecord shards, so it runs on the CPU.
tf.config.set_visible_devices(devices=[], device_type="GPU")
# 9 tuning games.
TUNING_SUITE = [
"BeamRider",
"DemonAttack",
"DoubleDunk",
"IceHockey",
"MsPacman",
"Pooyan",
"RoadRunner",
"Robotank",
"Zaxxon",
]
# 36 testing games.
TESTING_SUITE = [
"Alien",
"Amidar",
"Assault",
"Asterix",
"Atlantis",
"BankHeist",
"BattleZone",
"Boxing",
"Breakout",
"Carnival",
"Centipede",
"ChopperCommand",
"CrazyClimber",
"Enduro",
"FishingDerby",
"Freeway",
"Frostbite",
"Gopher",
"Gravitar",
"Hero",
"Jamesbond",
"Kangaroo",
"Krull",
"KungFuMaster",
"NameThisGame",
"Phoenix",
"Pong",
"Qbert",
"Riverraid",
"Seaquest",
"SpaceInvaders",
"StarGunner",
"TimePilot",
"UpNDown",
"VideoPinball",
"WizardOfWor",
"YarsRevenge",
]
# Total of 45 games.
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"
def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
    """Decode a stack of four PNG-encoded single-channel frames.

    :param pngs: String Tensor of size (4,) containing PNG encoded images.
    :returns: Tensor of size (4, 84, 84) containing decoded images.
    """
    # Statically unroll png decoding; `channels=1` yields one trailing
    # channel axis per frame.
    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
    # NOTE: to match tianshou's convention for framestacking, the stack axis
    # comes first. Only the trailing channel axis is squeezed so that no other
    # size-1 dimension can be dropped by accident.
    stacked = tf.squeeze(tf.stack(frames, axis=0), axis=-1)
    stacked.set_shape((4, 84, 84))
    return stacked
def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Pack one parsed transition into a Tianshou ``Batch``.

    Every tensor is converted with ``.numpy()``. The discount ``d_t`` becomes
    the ``done`` field as ``1 - d_t``. ``a_tp1`` is part of the SARSA tuple
    but is not stored in the returned batch.

    :param o_t: Observation at time t.
    :param a_t: Action at time t.
    :param r_t: Reward at time t.
    :param d_t: Discount at time t.
    :param o_tp1: Observation at time t+1.
    :param a_tp1: Action at time t+1 (unused).
    :returns: A tianshou.data.Batch with obs, act, rew, done and obs_next.
    """
    obs, act, rew, discount, obs_next = (
        tensor.numpy() for tensor in (o_t, a_t, r_t, d_t, o_tp1)
    )
    return Batch(obs=obs, act=act, rew=rew, done=1 - discount, obs_next=obs_next)
def _tf_example_to_tianshou_batch(tf_example: tf.Tensor) -> Batch:
    """Parse one serialized RL Unplugged record into a Tianshou Batch.

    :param tf_example: Scalar string tensor holding a serialized
        ``tf.train.Example``, i.e. one element yielded by a ``TFRecordDataset``
        (this is what ``tf.io.parse_single_example`` consumes, not an
        ``Example`` proto object).
    :returns: A Batch with ``obs``, ``act``, ``rew``, ``done`` and ``obs_next``.
    """
    # Parse tf.Example. `episode_id` and `episode_return` are part of the
    # record schema but are not forwarded to the Batch.
    feature_description = {
        "o_t": tf.io.FixedLenFeature([4], tf.string),
        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
        "a_t": tf.io.FixedLenFeature([], tf.int64),
        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
        "r_t": tf.io.FixedLenFeature([], tf.float32),
        "d_t": tf.io.FixedLenFeature([], tf.float32),
        "episode_id": tf.io.FixedLenFeature([], tf.int64),
        "episode_return": tf.io.FixedLenFeature([], tf.float32),
    }
    data = tf.io.parse_single_example(tf_example, feature_description)
    # Process data: decode the 4 stacked PNG frames for both observations.
    o_t = _decode_frames(data["o_t"])
    o_tp1 = _decode_frames(data["o_tp1"])
    a_t = tf.cast(data["a_t"], tf.int32)
    a_tp1 = tf.cast(data["a_tp1"], tf.int32)
    # Build tianshou Batch replay sample.
    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)
# Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
def download(url: str, fname: str, chunk_size: int | None = 1024) -> None:
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
if os.path.exists(fname):
print(f"Found cached file at {fname}.")
return
with open(fname, "wb") as ofile, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = ofile.write(data)
bar.update(size)
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
    """Download one RL Unplugged shard and convert it to an HDF5 file.

    The GZIP-compressed TFRecord at ``fname`` is fetched from ``url`` if
    needed. At most ``maxsize`` transitions are read from it, and only the rows
    that were actually filled are written to ``ofname``, using D4RL dataset
    key names.

    :param url: Remote location of the shard.
    :param fname: Local path of the (cached) TFRecord shard.
    :param ofname: Output HDF5 path, opened with h5py mode ``"w"``.
    :param maxsize: Maximum number of transitions to convert.
    """
    download(url, fname)
    # Preallocated, uninitialized buffers; only the first `i` rows get filled.
    obs: npt.NDArray[np.uint8] = np.empty((maxsize, 4, 84, 84), dtype=np.uint8)
    act: npt.NDArray[np.int64] = np.empty((maxsize,), dtype=np.int64)
    rew: npt.NDArray[np.float32] = np.empty((maxsize,), dtype=np.float32)
    done: npt.NDArray[np.bool_] = np.empty((maxsize,), dtype=np.bool_)
    obs_next: npt.NDArray[np.uint8] = np.empty((maxsize, 4, 84, 84), dtype=np.uint8)
    i = 0
    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
    for example in file_ds:
        if i >= maxsize:
            break
        batch = _tf_example_to_tianshou_batch(example)
        obs[i], act[i], rew[i], done[i], obs_next[i] = (
            batch.obs,
            batch.act,
            batch.rew,
            batch.done,
            batch.obs_next,
        )
        i += 1
        if i % 1000 == 0:
            print(f"...{i}", end="", flush=True)
    print("\nDataset size:", i)
    # Following D4RL dataset naming conventions. Slice to the `i` filled rows
    # so that uninitialized memory is never written when the shard holds fewer
    # than `maxsize` transitions.
    with h5py.File(ofname, "w") as f:
        f.create_dataset("observations", data=obs[:i], compression="gzip")
        f.create_dataset("actions", data=act[:i], compression="gzip")
        f.create_dataset("rewards", data=rew[:i], compression="gzip")
        f.create_dataset("terminals", data=done[:i], compression="gzip")
        f.create_dataset("next_observations", data=obs_next[:i], compression="gzip")
def process_dataset(
    task: str,
    download_path: str,
    dst_path: str,
    run_id: int = 1,
    shard_id: int = 0,
    total_num_shards: int = 100,
) -> None:
    """Convert a single shard of ``task`` from the RL Unplugged bucket.

    The shard lives at ``<task>/<shard name>`` below ``URL_PREFIX``; the same
    relative path is used under ``download_path`` for the cached TFRecord and
    under ``dst_path`` (with an ``.hdf5`` suffix) for the converted output.

    :param task: Atari game name.
    :param download_path: Root directory for cached downloads.
    :param dst_path: Root directory for converted HDF5 files.
    :param run_id: Training run to convert.
    :param shard_id: Shard index within the run.
    :param total_num_shards: Number of shards in the run.
    """
    shard_name = _filename(run_id, shard_id, total_num_shards=total_num_shards)
    rel_path = "/".join((task, shard_name))
    process_shard(
        f"{URL_PREFIX}/{rel_path}",
        f"{download_path}/{rel_path}",
        f"{dst_path}/{rel_path}.hdf5",
    )
def main(args: Namespace) -> None:
    """Validate the CLI arguments and convert the requested shard.

    The ``RLU_CACHE_DIR`` and ``RLU_DATASET_DIR`` environment variables, when
    set, override ``args.cache_dir`` and ``args.dataset_dir`` respectively.

    :param args: Parsed command-line arguments.
    :raises KeyError: If ``args.task`` is not one of the known games.
    :raises FileExistsError: If the converted HDF5 file already exists
        (a subclass of ``OSError``).
    """
    if args.task not in ALL_GAMES:
        raise KeyError(f"`{args.task}` is not in the list of games.")
    # Apply environment overrides *before* the existence check, so the check
    # inspects the directory that will actually be written to.
    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
    if os.path.exists(dataset_path):
        raise FileExistsError(
            f"Found existing dataset at {dataset_path}. Will not overwrite.",
        )
    # process_dataset appends `<task>/` itself; create those subdirectories here.
    cache_path = os.path.join(args.cache_dir, args.task)
    os.makedirs(cache_path, exist_ok=True)
    dst_path = os.path.join(args.dataset_dir, args.task)
    os.makedirs(dst_path, exist_ok=True)
    process_dataset(
        args.task,
        args.cache_dir,
        args.dataset_dir,
        run_id=args.run_id,
        shard_id=args.shard_id,
        total_num_shards=args.total_num_shards,
    )
if __name__ == "__main__":
    # Command-line entry point: the module docstring doubles as usage text.
    cli = ArgumentParser(usage=__doc__)
    cli.add_argument("--task", required=True, help="Name of the Atari game.")
    cli.add_argument(
        "--run-id",
        default=1,
        type=int,
        help="Run id to download and convert. Value in [1..5].",
    )
    cli.add_argument(
        "--shard-id",
        default=0,
        type=int,
        help="Shard id to download and convert. Value in [0..99].",
    )
    cli.add_argument(
        "--total-num-shards",
        default=100,
        type=int,
        help="Total number of shards.",
    )
    cli.add_argument(
        "--dataset-dir",
        help="Directory for converted hdf5 files.",
        default=os.path.expanduser("~/.rl_unplugged/datasets"),
    )
    cli.add_argument(
        "--cache-dir",
        help="Directory for downloaded original datasets.",
        default=os.path.expanduser("~/.rl_unplugged/cache"),
    )
    main(cli.parse_args())