Tianshou/examples/offline/convert_rl_unplugged_atari.py
Michael Panchenko 600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Remove python 3.8

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI the format-lint step fails if any
changes are done, so it fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00

292 lines
8.7 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Adapted from
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
#
"""Convert Atari RL Unplugged datasets to HDF5 format.
Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.
For every training run we have recorded all 50 million transitions corresponding
to 200 million environment steps (4x factor because of frame skipping). There
are 5 separate datasets for each of the 45 games.
Every transition in the dataset is a tuple containing the following features:
* o_t: Observation at time t. Observations have been processed using the
canonical Atari frame processing, including 4x frame stacking. The shape
of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
* extras:
* episode_id: Episode identifier.
* episode_return: Total episode return computed using per-step [-1, 1]
clipping.
"""
import os
from argparse import ArgumentParser
import h5py
import numpy as np
import requests
import tensorflow as tf
from tqdm import tqdm
from tianshou.data import Batch
tf.config.set_visible_devices([], "GPU")
# 9 tuning games.
TUNING_SUITE = [
"BeamRider",
"DemonAttack",
"DoubleDunk",
"IceHockey",
"MsPacman",
"Pooyan",
"RoadRunner",
"Robotank",
"Zaxxon",
]
# 36 testing games.
TESTING_SUITE = [
"Alien",
"Amidar",
"Assault",
"Asterix",
"Atlantis",
"BankHeist",
"BattleZone",
"Boxing",
"Breakout",
"Carnival",
"Centipede",
"ChopperCommand",
"CrazyClimber",
"Enduro",
"FishingDerby",
"Freeway",
"Frostbite",
"Gopher",
"Gravitar",
"Hero",
"Jamesbond",
"Kangaroo",
"Krull",
"KungFuMaster",
"NameThisGame",
"Phoenix",
"Pong",
"Qbert",
"Riverraid",
"Seaquest",
"SpaceInvaders",
"StarGunner",
"TimePilot",
"UpNDown",
"VideoPinball",
"WizardOfWor",
"YarsRevenge",
]
# Total of 45 games.
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"
def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
"""Decode PNGs.
:param pngs: String Tensor of size (4,) containing PNG encoded images.
:returns: Tensor of size (4, 84, 84) containing decoded images.
"""
# Statically unroll png decoding
frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
# NOTE: to match tianshou's convention for framestacking
frames = tf.squeeze(tf.stack(frames, axis=0))
frames.set_shape((4, 84, 84))
return frames
def _make_tianshou_batch(
o_t: tf.Tensor,
a_t: tf.Tensor,
r_t: tf.Tensor,
d_t: tf.Tensor,
o_tp1: tf.Tensor,
a_tp1: tf.Tensor,
) -> Batch:
"""Create Tianshou batch with offline data.
:param o_t: Observation at time t.
:param a_t: Action at time t.
:param r_t: Reward at time t.
:param d_t: Discount at time t.
:param o_tp1: Observation at time t+1.
:param a_tp1: Action at time t+1.
:returns: A tianshou.data.Batch object.
"""
return Batch(
obs=o_t.numpy(),
act=a_t.numpy(),
rew=r_t.numpy(),
done=1 - d_t.numpy(),
obs_next=o_tp1.numpy(),
)
def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
"""Create a tianshou Batch replay sample from a TF example."""
# Parse tf.Example.
feature_description = {
"o_t": tf.io.FixedLenFeature([4], tf.string),
"o_tp1": tf.io.FixedLenFeature([4], tf.string),
"a_t": tf.io.FixedLenFeature([], tf.int64),
"a_tp1": tf.io.FixedLenFeature([], tf.int64),
"r_t": tf.io.FixedLenFeature([], tf.float32),
"d_t": tf.io.FixedLenFeature([], tf.float32),
"episode_id": tf.io.FixedLenFeature([], tf.int64),
"episode_return": tf.io.FixedLenFeature([], tf.float32),
}
data = tf.io.parse_single_example(tf_example, feature_description)
# Process data.
o_t = _decode_frames(data["o_t"])
o_tp1 = _decode_frames(data["o_tp1"])
a_t = tf.cast(data["a_t"], tf.int32)
a_tp1 = tf.cast(data["a_tp1"], tf.int32)
# Build tianshou Batch replay sample.
return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)
# Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
def download(url: str, fname: str, chunk_size=1024):
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
if os.path.exists(fname):
print(f"Found cached file at {fname}.")
return
with open(fname, "wb") as ofile, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = ofile.write(data)
bar.update(size)
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
download(url, fname)
obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
act = np.ndarray((maxsize,), dtype="int64")
rew = np.ndarray((maxsize,), dtype="float32")
done = np.ndarray((maxsize,), dtype="bool")
obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
i = 0
file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
for example in file_ds:
if i >= maxsize:
break
batch = _tf_example_to_tianshou_batch(example)
obs[i], act[i], rew[i], done[i], obs_next[i] = (
batch.obs,
batch.act,
batch.rew,
batch.done,
batch.obs_next,
)
i += 1
if i % 1000 == 0:
print(f"...{i}", end="", flush=True)
print("\nDataset size:", i)
# Following D4RL dataset naming conventions
with h5py.File(ofname, "w") as f:
f.create_dataset("observations", data=obs, compression="gzip")
f.create_dataset("actions", data=act, compression="gzip")
f.create_dataset("rewards", data=rew, compression="gzip")
f.create_dataset("terminals", data=done, compression="gzip")
f.create_dataset("next_observations", data=obs_next, compression="gzip")
def process_dataset(
task: str,
download_path: str,
dst_path: str,
run_id: int = 1,
shard_id: int = 0,
total_num_shards: int = 100,
) -> None:
fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
url = f"{URL_PREFIX}/{fn}"
filepath = f"{download_path}/{fn}"
ofname = f"{dst_path}/{fn}.hdf5"
process_shard(url, filepath, ofname)
def main(args):
if args.task not in ALL_GAMES:
raise KeyError(f"`{args.task}` is not in the list of games.")
fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
if os.path.exists(dataset_path):
raise OSError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
cache_path = os.path.join(args.cache_dir, args.task)
os.makedirs(cache_path, exist_ok=True)
dst_path = os.path.join(args.dataset_dir, args.task)
os.makedirs(dst_path, exist_ok=True)
process_dataset(
args.task,
args.cache_dir,
args.dataset_dir,
run_id=args.run_id,
shard_id=args.shard_id,
total_num_shards=args.total_num_shards,
)
if __name__ == "__main__":
parser = ArgumentParser(usage=__doc__)
parser.add_argument("--task", required=True, help="Name of the Atari game.")
parser.add_argument(
"--run-id",
type=int,
default=1,
help="Run id to download and convert. Value in [1..5].",
)
parser.add_argument(
"--shard-id",
type=int,
default=0,
help="Shard id to download and convert. Value in [0..99].",
)
parser.add_argument("--total-num-shards", type=int, default=100, help="Total number of shards.")
parser.add_argument(
"--dataset-dir",
default=os.path.expanduser("~/.rl_unplugged/datasets"),
help="Directory for converted hdf5 files.",
)
parser.add_argument(
"--cache-dir",
default=os.path.expanduser("~/.rl_unplugged/cache"),
help="Directory for downloaded original datasets.",
)
args = parser.parse_args()
main(args)