diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py
index c37313d..82775b5 100644
--- a/examples/offline/atari_bcq.py
+++ b/examples/offline/atari_bcq.py
@@ -12,7 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from examples.offline.utils import load_buffer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -118,18 +119,19 @@ def test_discrete_bcq(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_dqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        if args.buffer_from_rl_unplugged:
-            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
-        else:
-            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
 
diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py
index c8300b3..794ebc8 100644
--- a/examples/offline/atari_cql.py
+++ b/examples/offline/atari_cql.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import QRDQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
@@ -57,6 +58,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -100,15 +104,19 @@ def test_discrete_cql(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
 
diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py
index 49ef9c3..b9af9c4 100644
--- a/examples/offline/atari_crr.py
+++ b/examples/offline/atari_crr.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
@@ -59,6 +60,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,15 +124,19 @@ def test_discrete_crr(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith(".pkl"):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
 
diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py
index 348f992..1d17c6d 100644
--- a/examples/offline/atari_il.py
+++ b/examples/offline/atari_il.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
+from examples.offline.utils import load_buffer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -50,6 +51,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -85,15 +89,19 @@ def test_il(args=get_args()):
         policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
         print("Loaded agent from: ", args.resume_path)
     # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run atari_qrdqn.py first to get expert's data buffer."
-    if args.load_buffer_name.endswith('.pkl'):
-        buffer = pickle.load(open(args.load_buffer_name, "rb"))
-    elif args.load_buffer_name.endswith('.hdf5'):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+    if args.buffer_from_rl_unplugged:
+        buffer = load_buffer(args.load_buffer_name)
     else:
-        print(f"Unknown buffer format: {args.load_buffer_name}")
-        exit(0)
+        assert os.path.exists(args.load_buffer_name), \
+            "Please run atari_dqn.py first to get expert's data buffer."
+        if args.load_buffer_name.endswith(".pkl"):
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+        elif args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            print(f"Unknown buffer format: {args.load_buffer_name}")
+            exit(0)
+    print("Replay buffer size:", len(buffer), flush=True)
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
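All four Atari offline example scripts now share the same buffer-loading branch. For reference, that logic reduces to roughly the sketch below; `load_expert_buffer` is a hypothetical helper name, the `args` fields mirror the scripts' command-line options, and the final `raise` stands in for the scripts' print-and-exit behaviour.

```python
import os
import pickle

from examples.offline.utils import load_buffer
from tianshou.data import VectorReplayBuffer


def load_expert_buffer(args):
    """Sketch of the buffer-loading branch shared by the Atari offline examples."""
    if args.buffer_from_rl_unplugged:
        # HDF5 shard converted by convert_rl_unplugged_atari.py (D4RL-style keys).
        return load_buffer(args.load_buffer_name)
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_dqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith(".pkl"):
        # Pickled Tianshou buffer.
        with open(args.load_buffer_name, "rb") as f:
            return pickle.load(f)
    if args.load_buffer_name.endswith(".hdf5"):
        # Tianshou-native HDF5 buffer saved by the online Atari examples.
        return VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    # The scripts print and exit(0) here; raising is used in this sketch for brevity.
    raise ValueError(f"Unknown buffer format: {args.load_buffer_name}")
```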
diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py
index d60ddb4..9fc8c25 100755
--- a/examples/offline/convert_rl_unplugged_atari.py
+++ b/examples/offline/convert_rl_unplugged_atari.py
@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
 #
-"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+"""Convert Atari RL Unplugged datasets to HDF5 format.
 
 Examples in the dataset represent SARSA transitions stored during a
 DQN training run as described in https://arxiv.org/pdf/1907.04543.
@@ -30,11 +30,13 @@ Every transition in the dataset is a tuple containing the following features:
 import os
 from argparse import ArgumentParser
 
+import h5py
+import numpy as np
 import requests
 import tensorflow as tf
 from tqdm import tqdm
 
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch
 
 tf.config.set_visible_devices([], 'GPU')
 
@@ -108,7 +110,7 @@ def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
       pngs: String Tensor of size (4,) containing PNG encoded images.
 
     Returns:
-      4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+      4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
     """
     # Statically unroll png decoding
     frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
@@ -195,17 +197,30 @@ def download(url: str, fname: str, chunk_size=1024):
 
 def process_shard(url: str, fname: str, ofname: str) -> None:
     download(url, fname)
+    maxsize = 500000
+    obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    act = np.ndarray((maxsize, ), dtype="int64")
+    rew = np.ndarray((maxsize, ), dtype="float32")
+    done = np.ndarray((maxsize, ), dtype="bool")
+    obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
+    i = 0
     file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
-    buffer = ReplayBuffer(500000)
-    cnt = 0
     for example in file_ds:
         batch = _tf_example_to_tianshou_batch(example)
-        buffer.add(batch)
-        cnt += 1
-        if cnt % 1000 == 0:
-            print(f"...{cnt}", end="", flush=True)
-    print("\nReplayBuffer size:", len(buffer))
-    buffer.save_hdf5(ofname, compression="gzip")
+        obs[i], act[i], rew[i], done[i], obs_next[i] = (
+            batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
+        )
+        i += 1
+        if i % 1000 == 0:
+            print(f"...{i}", end="", flush=True)
+    print("\nDataset size:", i)
+    # Following D4RL dataset naming conventions
+    with h5py.File(ofname, "w") as f:
+        f.create_dataset("observations", data=obs, compression="gzip")
+        f.create_dataset("actions", data=act, compression="gzip")
+        f.create_dataset("rewards", data=rew, compression="gzip")
+        f.create_dataset("terminals", data=done, compression="gzip")
+        f.create_dataset("next_observations", data=obs_next, compression="gzip")
 
 
 def process_dataset(
@@ -227,19 +242,19 @@ def main(args):
     if args.task not in ALL_GAMES:
         raise KeyError(f"`{args.task}` is not in the list of games.")
     fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
-    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
-    if os.path.exists(buffer_path):
-        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(dataset_path):
+        raise IOError(f"Found existing dataset at {dataset_path}. Will not overwrite.")
+    args.cache_dir = os.environ.get("RLU_CACHE_DIR", args.cache_dir)
     args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
-    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
-    dataset_path = os.path.join(args.dataset_dir, args.task)
-    os.makedirs(dataset_path, exist_ok=True)
-    dst_path = os.path.join(args.buffer_dir, args.task)
+    cache_path = os.path.join(args.cache_dir, args.task)
+    os.makedirs(cache_path, exist_ok=True)
+    dst_path = os.path.join(args.dataset_dir, args.task)
     os.makedirs(dst_path, exist_ok=True)
     process_dataset(
         args.task,
+        args.cache_dir,
         args.dataset_dir,
-        args.buffer_dir,
         run_id=args.run_id,
         shard_id=args.shard_id,
         total_num_shards=args.total_num_shards
@@ -267,12 +282,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--dataset-dir",
         default=os.path.expanduser("~/.rl_unplugged/datasets"),
-        help="Directory for downloaded original datasets.",
+        help="Directory for converted hdf5 files.",
     )
     parser.add_argument(
-        "--buffer-dir",
-        default=os.path.expanduser("~/.rl_unplugged/buffers"),
-        help="Directory for converted replay buffers.",
+        "--cache-dir",
+        default=os.path.expanduser("~/.rl_unplugged/cache"),
+        help="Directory for downloaded original datasets.",
     )
     args = parser.parse_args()
     main(args)
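With this change the converter writes plain HDF5 files with D4RL-style keys ("observations", "actions", "rewards", "terminals", "next_observations") instead of pickled Tianshou buffers, so a converted shard can be inspected directly with h5py. A minimal sketch, with an illustrative shard path:

```python
import h5py

# Illustrative path; substitute a shard written by convert_rl_unplugged_atari.py
# under --dataset-dir/<task>/.
with h5py.File("Pong/run_1-00000-of-00100.hdf5", "r") as f:
    for key in ("observations", "actions", "rewards", "terminals", "next_observations"):
        print(key, f[key].shape, f[key].dtype)
```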
diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py
index 434763e..eecf03f 100644
--- a/examples/offline/d4rl_bcq.py
+++ b/examples/offline/d4rl_bcq.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import BCQPolicy
 from tianshou.trainer import offline_trainer
@@ -211,23 +211,7 @@ def test_bcq():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py
index 2b7f5ff..af8b787 100644
--- a/examples/offline/d4rl_cql.py
+++ b/examples/offline/d4rl_cql.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import CQLPolicy
 from tianshou.trainer import offline_trainer
@@ -206,23 +206,7 @@ def test_cql():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py
index 710441a..54dde85 100644
--- a/examples/offline/d4rl_il.py
+++ b/examples/offline/d4rl_il.py
@@ -5,13 +5,13 @@ import datetime
 import os
 import pprint
 
-import d4rl
 import gym
 import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Batch, Collector, ReplayBuffer
+from examples.offline.utils import load_buffer_d4rl
+from tianshou.data import Collector
 from tianshou.env import SubprocVectorEnv
 from tianshou.policy import ImitationPolicy
 from tianshou.trainer import offline_trainer
@@ -148,23 +148,7 @@ def test_il():
         collector.collect(n_episode=1, render=1 / 35)
 
     if not args.watch:
-        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
-        dataset_size = dataset["rewards"].size
-
-        print("dataset_size", dataset_size)
-        replay_buffer = ReplayBuffer(dataset_size)
-
-        for i in range(dataset_size):
-            replay_buffer.add(
-                Batch(
-                    obs=dataset["observations"][i],
-                    act=dataset["actions"][i],
-                    rew=dataset["rewards"][i],
-                    done=dataset["terminals"][i],
-                    obs_next=dataset["next_observations"][i],
-                )
-            )
-        print("dataset loaded")
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
         # trainer
         result = offline_trainer(
             policy,
diff --git a/examples/offline/utils.py b/examples/offline/utils.py
new file mode 100644
index 0000000..757baf6
--- /dev/null
+++ b/examples/offline/utils.py
@@ -0,0 +1,29 @@
+import d4rl
+import gym
+import h5py
+
+from tianshou.data import ReplayBuffer
+
+
+def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
+    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
+    replay_buffer = ReplayBuffer.from_data(
+        obs=dataset["observations"],
+        act=dataset["actions"],
+        rew=dataset["rewards"],
+        done=dataset["terminals"],
+        obs_next=dataset["next_observations"]
+    )
+    return replay_buffer
+
+
+def load_buffer(buffer_path: str) -> ReplayBuffer:
+    with h5py.File(buffer_path, "r") as dataset:
+        buffer = ReplayBuffer.from_data(
+            obs=dataset["observations"],
+            act=dataset["actions"],
+            rew=dataset["rewards"],
+            done=dataset["terminals"],
+            obs_next=dataset["next_observations"]
+        )
+    return buffer
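The two new helpers are intended as drop-in loaders for the example scripts above. A minimal usage sketch; the D4RL task name and the shard path are placeholders:

```python
import os

from examples.offline.utils import load_buffer, load_buffer_d4rl

# D4RL: task name is a placeholder; d4rl resolves it via gym + d4rl.qlearning_dataset.
buffer = load_buffer_d4rl("halfcheetah-expert-v2")
print(len(buffer))

# RL Unplugged: placeholder path to an HDF5 shard written by convert_rl_unplugged_atari.py.
shard = os.path.expanduser("~/.rl_unplugged/datasets/Pong/run_1-00000-of-00100.hdf5")
buffer = load_buffer(shard)
print(len(buffer))
```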
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 83ef975..edbf539 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1019,6 +1019,31 @@ def test_multibuf_hdf5():
     os.remove(path)
 
 
+def test_from_data():
+    obs_data = np.ndarray((10, 3, 3), dtype="uint8")
+    for i in range(10):
+        obs_data[i] = i * np.ones((3, 3), dtype="uint8")
+    obs_next_data = np.zeros_like(obs_data)
+    obs_next_data[:-1] = obs_data[1:]
+    f, path = tempfile.mkstemp(suffix='.hdf5')
+    os.close(f)
+    with h5py.File(path, "w") as f:
+        obs = f.create_dataset("obs", data=obs_data)
+        act = f.create_dataset("act", data=np.arange(10, dtype="int32"))
+        rew = f.create_dataset("rew", data=np.arange(10, dtype="float32"))
+        done = f.create_dataset("done", data=np.zeros(10, dtype="bool"))
+        obs_next = f.create_dataset("obs_next", data=obs_next_data)
+        buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)
+    assert len(buf) == 10
+    batch = buf[3]
+    assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
+    assert batch.act == 3
+    assert batch.rew == 3.0
+    assert not batch.done
+    assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8"))
+    os.remove(path)
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
@@ -1032,3 +1057,4 @@ if __name__ == '__main__':
     test_cachedbuffer()
     test_multibuf_stack()
     test_multibuf_hdf5()
+    test_from_data()
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index c9aafc7..9ef99cd 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -99,6 +99,22 @@ class ReplayBuffer:
             buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
+    @classmethod
+    def from_data(
+        cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset,
+        done: h5py.Dataset, obs_next: h5py.Dataset
+    ) -> "ReplayBuffer":
+        size = len(obs)
+        assert all(len(dset) == size for dset in [obs, act, rew, done, obs_next]), \
+            "Lengths of all hdf5 datasets need to be equal."
+        buf = cls(size)
+        if size == 0:
+            return buf
+        batch = Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next)
+        buf.set_batch(batch)
+        buf._size = size
+        return buf
+
     def reset(self, keep_statistics: bool = False) -> None:
         """Clear all the data in replay buffer and episode statistics."""
         self.last_index = np.array([0])
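For completeness, a small sketch of calling the new `ReplayBuffer.from_data` classmethod directly, mirroring `test_from_data` above; the output file name and the zero-filled data are only illustrative:

```python
import h5py
import numpy as np

from tianshou.data import ReplayBuffer

# Illustrative only: write a tiny dataset and build a buffer from it.
size = 5
with h5py.File("demo.hdf5", "w") as f:
    obs = f.create_dataset("obs", data=np.zeros((size, 4, 84, 84), dtype="uint8"))
    act = f.create_dataset("act", data=np.zeros(size, dtype="int64"))
    rew = f.create_dataset("rew", data=np.zeros(size, dtype="float32"))
    done = f.create_dataset("done", data=np.zeros(size, dtype="bool"))
    obs_next = f.create_dataset("obs_next", data=np.zeros((size, 4, 84, 84), dtype="uint8"))
    # All five datasets must have equal length (see the assert in from_data).
    buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)
assert len(buf) == size
```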